Back to Repositories

Testing SummaryIndex Retrieval Modes in llama_index

This test suite validates the retrieval functionality in the SummaryIndex class of llama_index, focusing on different retrieval modes including default, embedding-based, and LLM-based retrieval methods.

Test Coverage Overview

The test suite provides comprehensive coverage of the SummaryIndex retriever functionality across multiple retrieval modes:

  • Default retrieval testing with document content verification
  • Embedding-based retrieval with similarity matching
  • LLM-based retrieval with different batch sizes
  • Edge cases for each retrieval mode

Implementation Analysis

The testing approach utilizes mock objects and patching to isolate retrieval functionality:

  • Mock embeddings implementation for controlled similarity testing
  • Patched LLM predictor for deterministic responses
  • Structured test cases for each retrieval mode

Technical Details

Key technical components include:

  • unittest.mock for dependency isolation
  • Custom embedding simulation function
  • MockLLM implementation
  • Token text splitter configuration
  • Document and Node data structures

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Isolated test cases with clear boundaries
  • Consistent assertion patterns
  • Proper mock object usage
  • Controlled test environment setup
  • Comprehensive retrieval mode coverage

run-llama/llama_index

llama-index-core/tests/indices/list/test_retrievers.py

            
from typing import Any, Dict, List, Tuple
from unittest.mock import patch

from llama_index.core.indices.list.base import SummaryIndex
from llama_index.core.indices.list.retrievers import SummaryIndexEmbeddingRetriever
from llama_index.core.llms.mock import MockLLM
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.schema import BaseNode, Document


def _get_embeddings(
    query_str: str, nodes: List[BaseNode]
) -> Tuple[List[float], List[List[float]]]:
    """Return a fixed query embedding plus one canned embedding per node.

    Each known node text maps to a distinct one-hot basis vector; the query
    is always embedded as the first basis vector, so "Hello world." is the
    guaranteed nearest neighbor in the embedding-retrieval test.
    """
    canned_embeddings: Dict[str, List[float]] = {
        "Hello world.": [1.0, 0.0, 0.0, 0.0, 0.0],
        "This is a test.": [0.0, 1.0, 0.0, 0.0, 0.0],
        "This is another test.": [0.0, 0.0, 1.0, 0.0, 0.0],
        "This is a test v2.": [0.0, 0.0, 0.0, 1.0, 0.0],
    }
    node_embeddings = [canned_embeddings[node.get_content()] for node in nodes]
    query_embedding = [1.0, 0, 0, 0, 0]
    return query_embedding, node_embeddings


def test_retrieve_default(documents: List[Document], patch_token_text_splitter) -> None:
    """Test default list retrieval: nodes come back in document order.

    The token splitter is patched (fixture) to split on newlines, so each
    retrieved node should correspond to one line of the source document.
    """
    index = SummaryIndex.from_documents(documents)

    query_str = "What is?"
    retriever = index.as_retriever(retriever_mode="default")
    nodes = retriever.retrieve(query_str)

    # Fix: the separator must be the escape sequence "\n" — the source had a
    # literal line break inside the string literal, which is a SyntaxError.
    for node_with_score, line in zip(nodes, documents[0].get_content().split("\n")):
        assert node_with_score.node.get_content() == line


@patch.object(
    SummaryIndexEmbeddingRetriever,
    "_get_embeddings",
    side_effect=_get_embeddings,
)
def test_embedding_query(
    _patch_get_embeddings: Any, documents: List[Document], patch_token_text_splitter
) -> None:
    """Test embedding-mode retrieval with patched one-hot embeddings.

    The patched `_get_embeddings` embeds the query as the first basis
    vector, which matches "Hello world." exactly, so top-1 retrieval must
    return that single node.
    """
    index = SummaryIndex.from_documents(documents)

    retriever = index.as_retriever(retriever_mode="embedding", similarity_top_k=1)
    results = retriever.retrieve("What is?")

    assert len(results) == 1
    assert results[0].node.get_content() == "Hello world."


def mock_llmpredictor_predict(
    self: Any, prompt: BasePromptTemplate, **prompt_args: Any
) -> str:
    """Deterministic stand-in for ``MockLLM.predict``.

    Ignores the prompt entirely and always reports that document 2 is
    relevant with score 5, giving LLM-based retrieval tests a fixed answer.
    """
    canned_response = "Doc: 2, Relevance: 5"
    return canned_response


@patch.object(
    MockLLM,
    "predict",
    mock_llmpredictor_predict,
)
def test_llm_query(documents: List[Document], patch_token_text_splitter) -> None:
    """Test llm-mode retrieval with a patched, deterministic LLM.

    The patched ``predict`` always answers "Doc: 2, Relevance: 5", so the
    retriever picks the second document of each choice batch.
    """
    index = SummaryIndex.from_documents(documents)
    query_str = "What is?"

    # Default choice batch size: one LLM call over all docs selects doc 2.
    retriever = index.as_retriever(retriever_mode="llm")
    results = retriever.retrieve(query_str)
    assert len(results) == 1
    assert results[0].node.get_content() == "This is a test."

    # Batch size 2: the LLM answers once per batch of two, picking the
    # second doc of each batch, so two nodes come back.
    retriever = index.as_retriever(retriever_mode="llm", choice_batch_size=2)
    results = retriever.retrieve(query_str)
    assert len(results) == 2
    assert results[0].node.get_content() == "This is a test."
    assert results[1].node.get_content() == "This is a test v2."