
Testing ReActAgent Chat Implementation in llama_index

This test suite validates the ReActAgent implementation in the llama_index framework, focusing on chat functionality, streaming responses, and tool integration. It verifies both synchronous and asynchronous agent interactions using mock LLM implementations.

Test Coverage Overview

The test suite provides comprehensive coverage of the ReActAgent’s core functionality:
  • Basic chat interactions with synchronous and asynchronous methods
  • Streaming chat responses with token-by-token processing
  • Tool integration testing using a mock addition function (sketched after this list)
  • Chat history validation and message role verification
  • Error handling for malformed agent responses
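
The mock addition tool is built with FunctionTool.from_defaults, which infers the tool's name and description from the wrapped function's name and docstring. A minimal standalone version of the fixture's pattern:

from llama_index.core.tools.function_tool import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers and returns the result integer."""
    return a + b


# from_defaults infers the name "add" and uses the docstring as the description
add_tool = FunctionTool.from_defaults(fn=add)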

Implementation Analysis

The tests rely on pytest fixtures and mock implementations to isolate and verify agent behavior. Two MockLLM subclasses, MockChatLLM and MockStreamChatLLM, simulate language model responses, with careful attention to message formatting and response streaming mechanics.

The tests follow a structured pattern of setting up mock responses, initializing the agent with tools, and validating both the final output and intermediate states.
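
Concretely, the exchange the mocks script looks roughly like this; the Observation line is produced by the agent itself after it invokes the add tool, not by the mocked LLM:

Thought: I need to use a tool to help me answer the question.
Action: add
Action Input: {"a": 1, "b": 1}
Observation: 2
Thought: I have enough information to answer the question without using any more tools.
Answer: 2

The agent parses each Action/Action Input pair, executes the tool, records the result as an observation, and feeds the growing transcript back to the LLM until it sees an Answer line.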

Technical Details

Key technical components include:
  • pytest framework with async support via pytest.mark.asyncio
  • Custom mock implementations of LLM interfaces
  • FunctionTool wrapper for testing tool integration
  • ChatMessage and ChatResponse type validation
  • Stream processing using re.split for token simulation (demonstrated after this list)
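
The capturing group in the split pattern is what keeps the simulated stream lossless. A quick standalone check of that behavior:

import re

text = "Answer: 2 is the final answer."
tokens = re.split(r"(\s+)", text)
# the capturing group keeps each whitespace run as its own token,
# so concatenating the tokens reproduces the original text exactly
assert tokens[:4] == ["Answer:", " ", "2", " "]
assert "".join(tokens) == text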

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Isolation of external dependencies through comprehensive mocking
  • Explicit type checking and assertions
  • Separate test cases for sync/async implementations
  • Thorough validation of both success paths and error conditions
  • Clear separation of test setup, execution, and verification phases

run-llama/llama_index

llama-index-core/tests/agent/react/test_react_agent.py

import re
from typing import Any, List, Sequence

import pytest
from llama_index.core.agent.react.base import ReActAgent
from llama_index.core.agent.react.types import ObservationReasoningStep
from llama_index.core.agent.types import Task
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    MessageRole,
)
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.chat_engine.types import (
    AgentChatResponse,
    StreamingAgentChatResponse,
)
from llama_index.core.llms.mock import MockLLM
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.types import BaseTool


@pytest.fixture()
def add_tool() -> FunctionTool:
    def add(a: int, b: int) -> int:
        """Add two integers and returns the result integer."""
        return a + b

    return FunctionTool.from_defaults(fn=add)


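# Deterministic chat-LLM stub: returns its scripted messages in order,
# one per chat() call.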
class MockChatLLM(MockLLM):
    _i: int = PrivateAttr()
    _responses: List[ChatMessage] = PrivateAttr()

    def __init__(self, responses: List[ChatMessage]) -> None:
        super().__init__()
        self._i = 0  # call counter, determines which response to return
        self._responses = responses  # list of responses to return

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        del messages  # unused
        response = ChatResponse(
            message=self._responses[self._i],
        )
        self._i += 1
        return response


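# Scripted assistant outputs in the ReAct text format the agent parses:
# a tool-call step (Thought / Action / Action Input), then a final answer.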
MOCK_ACTION_RESPONSE = """\
Thought: I need to use a tool to help me answer the question.
Action: add
Action Input: {"a": 1, "b": 1}
"""

MOCK_FINAL_RESPONSE = """\
Thought: I have enough information to answer the question without using any more tools.
Answer: 2
"""


def test_chat_basic(
    add_tool: FunctionTool,
) -> None:
    mock_llm = MockChatLLM(
        responses=[
            ChatMessage(
                content=MOCK_ACTION_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
            ChatMessage(
                content=MOCK_FINAL_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
        ]
    )

    agent = ReActAgent.from_tools(
        tools=[add_tool],
        llm=mock_llm,
    )
    response = agent.chat("What is 1 + 1?")
    assert isinstance(response, AgentChatResponse)
    assert response.response == "2"

    chat_history = agent.chat_history
    assert chat_history == [
        ChatMessage(
            content="What is 1 + 1?",
            role=MessageRole.USER,
        ),
        ChatMessage(
            content="2",
            role=MessageRole.ASSISTANT,
        ),
    ]


@pytest.mark.asyncio()
async def test_achat_basic(
    add_tool: FunctionTool,
) -> None:
    mock_llm = MockChatLLM(
        responses=[
            ChatMessage(
                content=MOCK_ACTION_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
            ChatMessage(
                content=MOCK_FINAL_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
        ]
    )

    agent = ReActAgent.from_tools(
        tools=[add_tool],
        llm=mock_llm,
    )
    response = await agent.achat("What is 1 + 1?")
    assert isinstance(response, AgentChatResponse)
    assert response.response == "2"

    chat_history = agent.chat_history
    assert chat_history == [
        ChatMessage(
            content="What is 1 + 1?",
            role=MessageRole.USER,
        ),
        ChatMessage(
            content="2",
            role=MessageRole.ASSISTANT,
        ),
    ]


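# Streaming variant: splits each scripted message on whitespace and yields
# cumulative ChatResponse chunks to mimic token-by-token streaming.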
class MockStreamChatLLM(MockLLM):
    _i: int = PrivateAttr()
    _responses: List[ChatMessage] = PrivateAttr()

    def __init__(self, responses: List[ChatMessage]) -> None:
        super().__init__()
        self._i = 0  # call counter, determines which response to return
        self._responses = responses  # list of responses to return

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        del messages  # unused
        full_message = self._responses[self._i]
        self._i += 1

        role = full_message.role
        full_text = full_message.content or ""

        text_so_far = ""
        # create mock stream
        mock_stream = re.split(r"(\s+)", full_text)
        for token in mock_stream:
            text_so_far += token
            message = ChatMessage(
                content=text_so_far,
                role=role,
            )
            yield ChatResponse(
                message=message,
                delta=token,
            )


MOCK_STREAM_FINAL_RESPONSE = """\
Thought: I have enough information to answer the question without using any more tools.
Answer: 2 is the final answer.
"""


def test_stream_chat_basic(
    add_tool: FunctionTool,
) -> None:
    mock_llm = MockStreamChatLLM(
        responses=[
            ChatMessage(
                content=MOCK_ACTION_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
            ChatMessage(
                content=MOCK_STREAM_FINAL_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
        ]
    )

    agent = ReActAgent.from_tools(
        tools=[add_tool],
        llm=mock_llm,
    )
    response = agent.stream_chat("What is 1 + 1?")
    assert isinstance(response, StreamingAgentChatResponse)

    # exhaust stream
    for delta in response.response_gen:
        continue
    expected_answer = MOCK_STREAM_FINAL_RESPONSE.split("Answer: ")[-1].strip()
    assert response.response == expected_answer

    assert agent.chat_history == [
        ChatMessage(
            content="What is 1 + 1?",
            role=MessageRole.USER,
        ),
        ChatMessage(
            content="2 is the final answer.",
            role=MessageRole.ASSISTANT,
        ),
    ]


@pytest.mark.asyncio()
async def test_astream_chat_basic(
    add_tool: FunctionTool,
) -> None:
    mock_llm = MockStreamChatLLM(
        responses=[
            ChatMessage(
                content=MOCK_ACTION_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
            ChatMessage(
                content=MOCK_STREAM_FINAL_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
        ]
    )

    agent = ReActAgent.from_tools(
        tools=[add_tool],
        llm=mock_llm,
    )
    response = await agent.astream_chat("What is 1 + 1?")
    assert isinstance(response, StreamingAgentChatResponse)

    # exhaust stream
    async for delta in response.async_response_gen():
        continue
    expected_answer = MOCK_STREAM_FINAL_RESPONSE.split("Answer: ")[-1].strip()
    assert response.response == expected_answer

    assert agent.chat_history == [
        ChatMessage(
            content="What is 1 + 1?",
            role=MessageRole.USER,
        ),
        ChatMessage(
            content="2 is the final answer.",
            role=MessageRole.ASSISTANT,
        ),
    ]


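# Build a ReActAgent wired to the chat or streaming mock, preloaded with the
# standard scripted responses.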
def _get_agent(
    tools: List[BaseTool],
    streaming: bool = False,
) -> ReActAgent:
    if streaming:
        mock_llm = MockStreamChatLLM(
            responses=[
                ChatMessage(
                    content=MOCK_ACTION_RESPONSE,
                    role=MessageRole.ASSISTANT,
                ),
                ChatMessage(
                    content=MOCK_STREAM_FINAL_RESPONSE,
                    role=MessageRole.ASSISTANT,
                ),
            ]
        )
    else:
        mock_llm = MockChatLLM(
            responses=[
                ChatMessage(
                    content=MOCK_ACTION_RESPONSE,
                    role=MessageRole.ASSISTANT,
                ),
                ChatMessage(
                    content=MOCK_FINAL_RESPONSE,
                    role=MessageRole.ASSISTANT,
                ),
            ]
        )
    return ReActAgent.from_tools(
        tools=tools,
        llm=mock_llm,
    )


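# Collect the observation strings recorded in the task's current reasoning.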
def _get_observations(task: Task) -> List[str]:
    obs_steps = [
        s
        for s in task.extra_state["current_reasoning"]
        if isinstance(s, ObservationReasoningStep)
    ]
    return [s.observation for s in obs_steps]


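# A response missing the "Action Input:" line cannot be parsed; the agent
# should record the parse error as an observation instead of raising.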
def test_complaint_when_no_reasoning_step() -> None:
    runner = ReActAgent.from_tools(
        tools=[],
        llm=MockLLM(),
    )
    task = runner.create_task("lorem")
    chat_response = ChatResponse(
        message=ChatMessage(
            content="Thought: ipsum
Action: dolor", role=MessageRole.ASSISTANT
        )
    )
    current_reasoning, is_done = runner.agent_worker._process_actions(
        task, tools=[], output=chat_response
    )
    assert (
        current_reasoning[0].get_content()
        == "Observation: Error: Could not parse output. Please follow the thought-action-input format. Try again."
    )


def test_add_step(
    add_tool: FunctionTool,
) -> None:
    # sync
    agent = _get_agent([add_tool])
    task = agent.create_task("What is 1 + 1?")
    # first step
    _step_output = agent.run_step(task.task_id)
    # add human input (not used but should be in memory)
    _step_output = agent.run_step(task.task_id, input="tmp")
    observations = _get_observations(task)
    assert "tmp" in observations

    # stream_step
    agent = _get_agent([add_tool], streaming=True)
    task = agent.create_task("What is 1 + 1?")
    # first step
    _step_output = agent.stream_step(task.task_id)
    # add human input (not used but should be in memory)
    _step_output = agent.stream_step(task.task_id, input="tmp")
    observations = _get_observations(task)
    assert "tmp" in observations


@pytest.mark.asyncio()
async def test_async_add_step(
    add_tool: FunctionTool,
) -> None:
    # async
    agent = _get_agent([add_tool])
    task = agent.create_task("What is 1 + 1?")
    # first step
    _step_output = await agent.arun_step(task.task_id)
    # add human input (not used but should be in memory)
    _step_output = await agent.arun_step(task.task_id, input="tmp")
    observations = _get_observations(task)
    assert "tmp" in observations

    # async stream step
    agent = _get_agent([add_tool], streaming=True)
    task = agent.create_task("What is 1 + 1?")
    # first step
    _step_output = await agent.astream_step(task.task_id)
    # add human input (not used but should be in memory)
    _step_output = await agent.astream_step(task.task_id, input="tmp")
    observations = _get_observations(task)
    assert "tmp" in observations