Testing Memory Management Implementation in ColossalAI QA System
This test suite validates the memory management functionality in ColossalAI’s QA system, focusing on conversation buffer handling and document retrieval capabilities. The tests verify both long-term and short-term memory behaviors with document integration.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
hpcaitech/colossalai
applications/ColossalQA/tests/test_memory.py
import os
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_QA_ZH
from colossalqa.retriever import CustomRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
def test_memory_long():
    """Validate long-conversation memory management with a 600-token budget.

    Over 40 dialogue turns, the conversation buffer must (a) never let the
    stored conversation exceed the token budget left after the fixed
    retrieval prompt, and (b) eventually overflow and fall back to
    summarizing older turns.

    Requires env vars: EN_MODEL_PATH, EN_MODEL_NAME, TEST_DATA_PATH_EN,
    SQL_FILE_PATH.
    """
    model_path = os.environ.get("EN_MODEL_PATH")
    data_path = os.environ.get("TEST_DATA_PATH_EN")
    model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    # Ensure the scratch directory for the retriever's SQL index exists.
    if not os.path.exists(sql_file_path):
        os.makedirs(sql_file_path)

    colossal_api = ColossalAPI.get_api(model_name, model_path)
    llm = ColossalLLM(n=4, api=colossal_api)
    memory = ConversationBufferWithSummary(
        llm=llm, max_tokens=600, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
    )

    retriever_data = DocumentLoader([[data_path, "company information"]]).all_data
    # Split the loaded documents into small overlapping chunks for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    splits = text_splitter.split_documents(retriever_data)
    embedding = HuggingFaceEmbeddings(
        model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
    )
    # Create the retriever and index the document chunks.
    information_retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
    information_retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)

    memory.initiate_document_retrieval_chain(
        llm,
        PROMPT_RETRIEVAL_QA_ZH,
        information_retriever,
        chain_type_kwargs={
            "chat_history": "",
        },
    )

    # Using the same question every turn keeps the prompt length (excluding
    # the dialogue itself) constant, so `remain` is exactly the token budget
    # left for conversation history.
    docs = information_retriever.get_relevant_documents("this is a test input.")
    prompt_length = memory.chain.prompt_length(docs, **{"question": "this is a test input.", "chat_history": ""})
    remain = 600 - prompt_length
    have_summarization_flag = False
    for i in range(40):
        chat_history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
            "chat_history"
        ]
        # The buffered conversation must never exceed the remaining budget.
        assert memory.get_conversation_length() <= remain
        memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})
        if "A summarization of historical conversation:" in chat_history:
            have_summarization_flag = True
    # After 40 turns the buffer must have overflowed at least once and
    # triggered summarization of older turns.
    assert have_summarization_flag
def test_memory_short():
    """Validate short-conversation memory behavior (no token limit hit).

    While the buffer stays under its limit, every saved dialogue turn must be
    replayed verbatim — turn ``i`` of the loop expects exactly ``i`` copies of
    each human/assistant line in the loaded chat history.

    Requires env vars: EN_MODEL_PATH, EN_MODEL_NAME, TEST_DATA_PATH_EN,
    SQL_FILE_PATH.
    """
    model_path = os.environ.get("EN_MODEL_PATH")
    data_path = os.environ.get("TEST_DATA_PATH_EN")
    model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    # Make sure the retriever's SQL index directory exists.
    if not os.path.exists(sql_file_path):
        os.makedirs(sql_file_path)

    api = ColossalAPI.get_api(model_name, model_path)
    llm = ColossalLLM(n=4, api=api)
    memory = ConversationBufferWithSummary(
        llm=llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
    )

    documents = DocumentLoader([[data_path, "company information"]]).all_data
    # Chunk the documents before embedding them.
    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    doc_chunks = splitter.split_documents(documents)
    embedder = HuggingFaceEmbeddings(
        model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
    )
    # Build the retriever and index the chunks.
    retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
    retriever.add_documents(docs=doc_chunks, cleanup="incremental", mode="by_source", embedding=embedder)

    memory.initiate_document_retrieval_chain(
        llm,
        PROMPT_RETRIEVAL_QA_ZH,
        retriever,
        chain_type_kwargs={
            "chat_history": "",
        },
    )

    # A fixed question keeps the non-dialogue part of the prompt constant.
    docs = retriever.get_relevant_documents("this is a test input.", return_scores=True)
    for turn in range(4):
        history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
            "chat_history"
        ]
        # Each previously saved exchange appears exactly once per save.
        assert history.count("Assistant: this is a test output.") == turn
        assert history.count("Human: this is a test input.") == turn
        memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})
if __name__ == "__main__":
    # Smoke-run both memory tests directly (outside pytest). Both tests read
    # EN_MODEL_PATH, EN_MODEL_NAME, TEST_DATA_PATH_EN and SQL_FILE_PATH from
    # the environment and create files under SQL_FILE_PATH.
    test_memory_short()
    test_memory_long()