Testing ContextChatEngine Chat Operations in LlamaIndex
This test suite validates the ContextChatEngine functionality in LlamaIndex, covering both synchronous and asynchronous chat operations with mock components. The tests ensure proper handling of chat history, system prompts, and streaming responses.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
run-llama/llama_index
llama-index-core/tests/chat_engine/test_context.py
import pytest
from llama_index.core import MockEmbedding
from llama_index.core.chat_engine.context import (
ContextChatEngine,
)
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.llms.mock import MockLLM
from llama_index.core.schema import Document
# Distinctive system prompt; tests assert it is echoed back by MockLLM,
# proving the engine injects the system prompt into every request.
SYSTEM_PROMPT = "Talk like a pirate."
@pytest.fixture()
def chat_engine() -> ContextChatEngine:
    """Build a ContextChatEngine wired to mock components.

    Uses MockEmbedding (fixed-dim fake vectors) and MockLLM (echoes its
    prompt) so tests can inspect exactly what the engine sent to the LLM.
    """
    docs = [Document.example()]
    vector_index = VectorStoreIndex.from_documents(
        docs, embed_model=MockEmbedding(embed_dim=3)
    )
    return ContextChatEngine.from_defaults(
        vector_index.as_retriever(),
        llm=MockLLM(),
        system_prompt=SYSTEM_PROMPT,
    )
def test_chat(chat_engine: ContextChatEngine):
    """Synchronous chat: the system prompt and all prior user messages
    appear in each response, and history grows by two per turn."""
    first = str(chat_engine.chat("Hello World!"))
    assert SYSTEM_PROMPT in first
    assert "Hello World!" in first
    # one user message + one assistant reply recorded
    assert len(chat_engine.chat_history) == 2

    second = str(chat_engine.chat("What is the capital of the moon?"))
    assert SYSTEM_PROMPT in second
    # earlier turn is carried forward in the prompt MockLLM echoes back
    assert "Hello World!" in second
    assert "What is the capital of the moon?" in second
    assert len(chat_engine.chat_history) == 4
def test_chat_stream(chat_engine: ContextChatEngine):
    """Streaming chat: the generator yields many tokens, the assembled
    response contains prompt/history text, and history grows per turn."""
    response = chat_engine.stream_chat("Hello World!")
    token_count = sum(1 for _ in response.response_gen)
    assert token_count > 10
    text = str(response)
    assert SYSTEM_PROMPT in text
    assert "Hello World!" in text
    assert len(chat_engine.chat_history) == 2

    response = chat_engine.stream_chat("What is the capital of the moon?")
    token_count = sum(1 for _ in response.response_gen)
    assert token_count > 10
    text = str(response)
    assert SYSTEM_PROMPT in text
    assert "Hello World!" in text
    assert "What is the capital of the moon?" in text
    assert len(chat_engine.chat_history) == 4
@pytest.mark.asyncio()
async def test_achat(chat_engine: ContextChatEngine):
    """Async chat mirrors the sync contract: prompt and prior turns are
    echoed in each response and history grows by two per exchange."""
    reply = await chat_engine.achat("Hello World!")
    reply_text = str(reply)
    assert SYSTEM_PROMPT in reply_text
    assert "Hello World!" in reply_text
    assert len(chat_engine.chat_history) == 2

    reply = await chat_engine.achat("What is the capital of the moon?")
    reply_text = str(reply)
    assert SYSTEM_PROMPT in reply_text
    # first turn remains part of the conversation context
    assert "Hello World!" in reply_text
    assert "What is the capital of the moon?" in reply_text
    assert len(chat_engine.chat_history) == 4
@pytest.mark.asyncio()
async def test_chat_astream(chat_engine: ContextChatEngine):
    """Async streaming chat: tokens stream via the async generator, the
    final text includes prompt/history, and history grows per turn."""
    response = await chat_engine.astream_chat("Hello World!")
    token_count = 0
    async for _ in response.async_response_gen():
        token_count += 1
    assert token_count > 10
    text = str(response)
    assert SYSTEM_PROMPT in text
    assert "Hello World!" in text
    assert len(chat_engine.chat_history) == 2

    response = await chat_engine.astream_chat("What is the capital of the moon?")
    token_count = 0
    async for _ in response.async_response_gen():
        token_count += 1
    assert token_count > 10
    text = str(response)
    assert SYSTEM_PROMPT in text
    assert "Hello World!" in text
    assert "What is the capital of the moon?" in text
    assert len(chat_engine.chat_history) == 4