
Testing Memory Management Implementation in ColossalAI QA System

This test suite validates the memory management functionality in ColossalAI’s QA system, focusing on conversation buffer handling and document retrieval capabilities. The tests verify both long-term and short-term memory behaviors with document integration.

Test Coverage Overview

The test suite provides comprehensive coverage of memory management features in ColossalQA.

Key areas tested include:
  • Long-term memory handling with token limits (pictured in the toy sketch after this list)
  • Short-term conversation buffer management
  • Document retrieval integration
  • Conversation history summarization
  • Memory context preservation
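
The contract these tests check is easiest to picture with a toy stand-in before reading the real code: keep recent turns verbatim and, once a token budget is exceeded, fold the oldest turns into a summary. The sketch below is purely illustrative, with hypothetical names, a whitespace tokenizer, and no LLM call; the real ConversationBufferWithSummary asks the LLM to write the summary.

def count_tokens(text: str) -> int:
    return len(text.split())  # crude whitespace stand-in for a real tokenizer


class ToyBufferWithSummary:
    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens
        self.turns = []
        self.summary = ""

    def save_context(self, question: str, output: str) -> None:
        self.turns.append(f"Human: {question}\nAssistant: {output}")
        # Fold the oldest turns into the summary until the dialogue fits the budget
        while self.turns and count_tokens("\n".join(self.turns)) > self.max_tokens:
            self.summary += " " + self.turns.pop(0).replace("\n", " ")

    def load_memory_variables(self) -> str:
        header = f"A summarization of historical conversation:{self.summary}\n" if self.summary else ""
        return header + "\n".join(self.turns)


buf = ToyBufferWithSummary(max_tokens=30)
for _ in range(5):
    buf.save_context("this is a test input.", "this is a test output.")
print(buf.load_memory_variables())  # summary header plus the two most recent turns

With 12-token turns and a 30-token budget, the third save pushes the oldest turn into the summary; the appearance of that summary header is exactly what test_memory_long() asserts on.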

Implementation Analysis

The testing approach implements two main test cases, test_memory_long() and test_memory_short(), to validate the two memory scenarios. Both assemble the same stack: a ColossalLLM driven through ColossalAPI, HuggingFace embeddings, and a CustomRetriever, each configured explicitly for document processing and conversation management.
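
The two cases differ chiefly in how the buffer is constructed. A minimal side-by-side sketch, mirroring the test file reproduced below (llm is the ColossalLLM instance the tests build from environment variables):

from colossalqa.memory import ConversationBufferWithSummary

long_memory = ConversationBufferWithSummary(
    llm=llm,
    max_tokens=600,  # summarization engages once the dialogue outgrows this budget
    llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
)
short_memory = ConversationBufferWithSummary(
    llm=llm,  # no explicit max_tokens: four short turns fit verbatim, so no summary appears
    llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
)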

Technical patterns include:
  • Environment-based configuration (see the guard sketch after this list)
  • Token-aware memory management
  • Document chunking and embedding
  • Conversation context validation
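
All model and data locations come from environment variables, so a missing variable only surfaces once a test body runs. One way to surface it earlier under pytest, sketched here as an assumption rather than code from the repository:

import os

import pytest  # assumption: pytest is the test runner

REQUIRED_VARS = ("EN_MODEL_PATH", "EN_MODEL_NAME", "TEST_DATA_PATH_EN", "SQL_FILE_PATH")


@pytest.fixture(autouse=True)
def require_environment():
    # Skip rather than error when model weights or data paths are not configured
    missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
    if missing:
        pytest.skip(f"missing environment variables: {', '.join(missing)}")

Skipping keeps runs green on machines without the model weights while still exercising the tests wherever the paths are configured.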

Technical Details

Testing tools and configurations:
  • RecursiveCharacterTextSplitter for document processing (demonstrated after this list)
  • HuggingFaceEmbeddings with m3e-base model
  • CustomRetriever for document management
  • ConversationBufferWithSummary for memory handling
  • Environment variables for model and data paths
  • SQL-based document storage
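
The splitter settings are easiest to see in isolation. A self-contained demo of the same configuration, with invented sample text:

from langchain.text_splitter import RecursiveCharacterTextSplitter

# 100-character chunks with 20 characters of overlap, so adjacent chunks
# share context across split boundaries
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
sample = "ColossalQA indexes company information as short passages. " * 4
for chunk in splitter.split_text(sample):
    print(len(chunk), repr(chunk))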

Best Practices Demonstrated

The test suite exemplifies robust testing practices for conversational AI systems.

Notable practices include:
  • Systematic validation of memory constraints
  • Comprehensive assertion checks
  • Environment-independent setup
  • Modular test organization (see the refactor sketch after this list)
  • Clear separation of long and short-term memory tests
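
Because both tests repeat the same setup, one possible refactor, an assumption rather than code from the repository, keeps the long and short cases as thin wrappers around a shared constructor:

from colossalqa.memory import ConversationBufferWithSummary

# Hypothetical helper: build either buffer variant in one place
def build_memory(llm, max_tokens=None):
    kwargs = dict(llm=llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True})
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens
    return ConversationBufferWithSummary(**kwargs)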

hpcaitech/colossalai

applications/ColossalQA/tests/test_memory.py

import os

from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_QA_ZH
from colossalqa.retriever import CustomRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter


def test_memory_long():
    model_path = os.environ.get("EN_MODEL_PATH")
    data_path = os.environ.get("TEST_DATA_PATH_EN")
    model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")

    # Ensure the path used for SQL-based document storage exists
    if not os.path.exists(sql_file_path):
        os.makedirs(sql_file_path)

    colossal_api = ColossalAPI.get_api(model_name, model_path)
    llm = ColossalLLM(n=4, api=colossal_api)
    memory = ConversationBufferWithSummary(
        llm=llm, max_tokens=600, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
    )
    retriever_data = DocumentLoader([[data_path, "company information"]]).all_data

    # Split documents into 100-character chunks with 20-character overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    splits = text_splitter.split_documents(retriever_data)

    embedding = HuggingFaceEmbeddings(
        model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
    )

    # Create retriever
    information_retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
    information_retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)

    memory.initiate_document_retrieval_chain(
        llm,
        PROMPT_RETRIEVAL_QA_ZH,
        information_retriever,
        chain_type_kwargs={
            "chat_history": "",
        },
    )

    # Retrieve a fixed document set so the prompt length, excluding the dialogue, stays constant
    docs = information_retriever.get_relevant_documents("this is a test input.")
    prompt_length = memory.chain.prompt_length(docs, **{"question": "this is a test input.", "chat_history": ""})
    remain = 600 - prompt_length  # token budget left for the dialogue itself
    have_summarization_flag = False
    for i in range(40):
        chat_history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
            "chat_history"
        ]

        # The buffered dialogue must never exceed the remaining token budget
        assert memory.get_conversation_length() <= remain
        memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})
        if "A summarization of historical conversation:" in chat_history:
            have_summarization_flag = True
    # Forty turns far exceed the budget, so summarization must have occurred
    assert have_summarization_flag


def test_memory_short():
    model_path = os.environ.get("EN_MODEL_PATH")
    data_path = os.environ.get("TEST_DATA_PATH_EN")
    model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")

    # Ensure the path used for SQL-based document storage exists
    if not os.path.exists(sql_file_path):
        os.makedirs(sql_file_path)

    colossal_api = ColossalAPI.get_api(model_name, model_path)
    llm = ColossalLLM(n=4, api=colossal_api)
    memory = ConversationBufferWithSummary(
        llm=llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
    )
    retriever_data = DocumentLoader([[data_path, "company information"]]).all_data

    # Split documents into 100-character chunks with 20-character overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    splits = text_splitter.split_documents(retriever_data)

    embedding = HuggingFaceEmbeddings(
        model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
    )

    # Create retriever
    information_retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
    information_retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)

    memory.initiate_document_retrieval_chain(
        llm,
        PROMPT_RETRIEVAL_QA_ZH,
        information_retriever,
        chain_type_kwargs={
            "chat_history": "",
        },
    )

    # Retrieve a fixed document set so the non-dialogue portion of the prompt stays constant
    docs = information_retriever.get_relevant_documents("this is a test input.", return_scores=True)

    for i in range(4):
        chat_history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
            "chat_history"
        ]
        # Without token pressure, every saved turn appears verbatim in the history
        assert chat_history.count("Assistant: this is a test output.") == i
        assert chat_history.count("Human: this is a test input.") == i
        memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})


if __name__ == "__main__":
    test_memory_short()
    test_memory_long()