
Testing Bilingual Retrieval QA System in ColossalAI

This test suite exercises the bilingual retrieval QA functionality of ColossalAI's ColossalQA application. It sends one English and one Chinese question through the UniversalRetrievalConversation system. Each test prints the returned answer and does not assert on its content.

Test Coverage Overview

The test suite covers the retrieval QA system's bilingual capabilities with one end-to-end query per language.

  • Sends an English query and prints the generated response
  • Sends the equivalent Chinese query and prints the response
  • Reads data, model, and SQL file paths from environment variables
  • Constructs the session with both the English and Chinese models in each test

Implementation Analysis

Each test reads its configuration from environment variables and passes it to a UniversalRetrievalConversation instance, which handles QA processing.

  • Separate test functions for English and Chinese queries
  • Identical configuration setup in both tests, repeated in each function
  • Environment-based path configuration for flexibility
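Both tests call the same `run` method and select the language only through `which_language="en"` or `which_language="zh"`. As a hypothetical sketch (not ColossalQA's implementation), a dispatcher of that kind can map each language key to its own handler. `LanguageRouter`, `answer_en`, and `answer_zh` are illustrative names, and the handlers are placeholders for the real retrieval chains.

```python
from typing import Callable, Dict


# Hypothetical stand-ins for the per-language retrieval chains; in a real
# system they would be backed by the configured English and Chinese models.
def answer_en(question: str) -> str:
    return f"[en] {question}"


def answer_zh(question: str) -> str:
    return f"[zh] {question}"


class LanguageRouter:
    """Route a query to a language-specific handler (illustrative sketch)."""

    def __init__(self, handlers: Dict[str, Callable[[str], str]]) -> None:
        self.handlers = handlers

    def run(self, question: str, which_language: str) -> str:
        handler = self.handlers.get(which_language)
        if handler is None:
            # Reject unknown language keys instead of guessing a default.
            raise ValueError(f"unsupported language: {which_language!r}")
        return handler(question)


router = LanguageRouter({"en": answer_en, "zh": answer_zh})
print(router.run("which company runs business in hotel industry?", which_language="en"))
# prints: [en] which company runs business in hotel industry?
```

Keeping the per-language handlers behind one entry point is what lets the two tests differ only in the question text and the `which_language` value.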

Technical Details

  • Reads configuration with os.environ.get; an unset variable yields None
  • Instantiates the UniversalRetrievalConversation class for QA processing
  • Configures separate model paths and model names for English and Chinese
  • Supplies a data path and a newline separator for each language's data file
  • Passes an SQL file path (SQL_FILE_PATH) to the session for data retrieval
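Each entry in `files_en` and `files_zh` pairs a data path with a `separator`, which the tests set to a newline. Assuming the loader uses the separator to split a text file into individual records (an assumption, since the loader is not shown here), the effect is roughly the following. `split_records` and the file name `companies_en.txt` are hypothetical.

```python
from typing import Dict, List


def split_records(text: str, separator: str = "\n") -> List[str]:
    """Split raw file contents into one record per separator, dropping blanks.

    Assumption: this only approximates how a loader might use the
    "separator" field; it is not ColossalQA's loader code.
    """
    return [chunk.strip() for chunk in text.split(separator) if chunk.strip()]


# Hypothetical file spec in the same shape the tests pass as files_en/files_zh.
file_spec: Dict[str, str] = {
    "data_path": "companies_en.txt",
    "name": "company information",
    "separator": "\n",
}

sample = "Company A runs hotels.\nCompany B builds cars.\n\n"
records = split_records(sample, file_spec["separator"])
print(records)  # ['Company A runs hotels.', 'Company B builds cars.']
```

Note that the separator must be written as the escaped literal `"\n"` in source code; a raw line break inside the quotes is a syntax error.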

Best Practices Demonstrated

The test suite exemplifies several testing best practices for multilingual AI systems.

  • Clear separation of concerns between language-specific tests
  • Consistent configuration patterns
  • Environment variable usage for flexible deployment
  • One test function per language, each passing which_language explicitly
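Both test functions repeat the same seven `os.environ.get` calls, and a missing variable silently becomes `None` before reaching the constructor. A minimal sketch, assuming you want to share that setup and fail early with a clear message: `QAConfig`, `load_qa_config`, and `REQUIRED_VARS` are hypothetical helpers, not part of ColossalQA.

```python
import os
from dataclasses import dataclass
from typing import List, Mapping, Optional

# Hypothetical: the variables both tests read via os.environ.get.
REQUIRED_VARS = (
    "TEST_DATA_PATH_EN",
    "TEST_DATA_PATH_ZH",
    "EN_MODEL_PATH",
    "ZH_MODEL_PATH",
    "EN_MODEL_NAME",
    "ZH_MODEL_NAME",
    "SQL_FILE_PATH",
)


@dataclass(frozen=True)
class QAConfig:
    """Settings shared by the English and Chinese tests (hypothetical)."""

    data_path_en: str
    data_path_zh: str
    en_model_path: str
    zh_model_path: str
    en_model_name: str
    zh_model_name: str
    sql_file_path: str


def load_qa_config(env: Optional[Mapping[str, str]] = None) -> QAConfig:
    """Read all required variables once, failing fast with every missing name."""
    env = os.environ if env is None else env
    # Treat unset and empty values alike instead of passing None downstream.
    missing: List[str] = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("missing environment variables: " + ", ".join(missing))
    return QAConfig(
        data_path_en=env["TEST_DATA_PATH_EN"],
        data_path_zh=env["TEST_DATA_PATH_ZH"],
        en_model_path=env["EN_MODEL_PATH"],
        zh_model_path=env["ZH_MODEL_PATH"],
        en_model_name=env["EN_MODEL_NAME"],
        zh_model_name=env["ZH_MODEL_NAME"],
        sql_file_path=env["SQL_FILE_PATH"],
    )


fake_env = {name: f"/models/{name.lower()}" for name in REQUIRED_VARS}
print(load_qa_config(fake_env).en_model_path)  # /models/en_model_path
```

Each test could then start with `cfg = load_qa_config()` and pass its fields to `UniversalRetrievalConversation`. Under pytest, the `RuntimeError` could instead become a `pytest.skip(...)` call so that machines without the models skip these tests rather than fail.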

hpcaitech/colossalai

applications/ColossalQA/tests/test_retrieval_qa.py

import os

from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation


def test_en_retrievalQA():
    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
    en_model_path = os.environ.get("EN_MODEL_PATH")
    zh_model_path = os.environ.get("ZH_MODEL_PATH")
    zh_model_name = os.environ.get("ZH_MODEL_NAME")
    en_model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    qa_session = UniversalRetrievalConversation(
        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
        zh_model_path=zh_model_path,
        en_model_path=en_model_path,
        zh_model_name=zh_model_name,
        en_model_name=en_model_name,
        sql_file_path=sql_file_path,
    )
    ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
    print(ans)


def test_zh_retrievalQA():
    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
    en_model_path = os.environ.get("EN_MODEL_PATH")
    zh_model_path = os.environ.get("ZH_MODEL_PATH")
    zh_model_name = os.environ.get("ZH_MODEL_NAME")
    en_model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    qa_session = UniversalRetrievalConversation(
        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
        zh_model_path=zh_model_path,
        en_model_path=en_model_path,
        zh_model_name=zh_model_name,
        en_model_name=en_model_name,
        sql_file_path=sql_file_path,
    )
    ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
    print(ans)


if __name__ == "__main__":
    test_en_retrievalQA()
    test_zh_retrievalQA()