Back to Repositories

Testing Knowledge Graph Retrieval Implementation in LlamaIndex

This test suite validates the Knowledge Graph retriever functionality in LlamaIndex, focusing on graph-based information retrieval and embedding-based similarity search capabilities.

Test Coverage Overview

The test suite provides comprehensive coverage of the KGTableRetriever component, including:
  • Basic retrieval functionality with text inclusion
  • Retrieval without text inclusion
  • Similarity-based retrieval using embeddings
  • Custom query keyword extraction
  • Graph store integration testing

Implementation Analysis

The testing approach utilizes mock objects and patching to isolate the knowledge graph functionality:
  • MockEmbedding class for controlled embedding generation
  • Mocked triplet extraction for predictable graph construction
  • Structured validation of retrieval results and query formatting

Technical Details

Key technical components include:
  • SimpleGraphStore for graph data storage
  • StorageContext for managing test storage configuration
  • QueryBundle for structured query handling
  • Custom embedding dimensions and mock vector representations

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Isolation of external dependencies through mocking
  • Comprehensive edge case coverage
  • Clear test case organization
  • Explicit assertion statements
  • Per-test setup of a fresh graph store and storage context (no explicit teardown appears in the tests)

run-llama/llama_index

llama-index-core/tests/indices/knowledge_graph/test_retrievers.py

            
from typing import Any, List, Tuple
from unittest.mock import patch

from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.knowledge_graph.base import KnowledgeGraphIndex
from llama_index.core.indices.knowledge_graph.retrievers import KGTableRetriever
from llama_index.core.schema import Document, QueryBundle
from llama_index.core.storage.storage_context import StorageContext
from tests.mock_utils.mock_prompts import MOCK_QUERY_KEYWORD_EXTRACT_PROMPT


class MockEmbedding(BaseEmbedding):
    """Deterministic embedding model for the knowledge-graph retriever tests.

    Every recognised text maps to a fixed 4-dimensional one-hot vector and
    every query maps to one fixed 4-dimensional vector, so text and query
    embeddings share a dimension and can be compared with each other.
    The async methods return exactly what their sync counterparts return.
    """

    @classmethod
    def class_name(cls) -> str:
        return "MockEmbedding"

    @staticmethod
    def _mock_text_vector(text: str) -> List[float]:
        """Return the fixed vector for *text*.

        Shared by the sync and async text-embedding methods so the two
        cannot drift apart.

        Raises:
            ValueError: if *text* is not one of the strings the tests expect.
        """
        # A fresh dict (and therefore fresh lists) is built on every call so
        # callers can never mutate a shared vector.
        vectors = {
            "('foo', 'is', 'bar')": [1.0, 0.0, 0.0, 0.0],
            "('hello', 'is not', 'world')": [0.0, 1.0, 0.0, 0.0],
            "('Jane', 'is mother of', 'Bob')": [0.0, 0.0, 1.0, 0.0],
            "foo": [0.0, 0.0, 0.0, 1.0],
        }
        vector = vectors.get(text)
        if vector is None:
            raise ValueError("Invalid text for `mock_get_text_embedding`.")
        return vector

    @staticmethod
    def _mock_query_vector() -> List[float]:
        """Return the fixed query vector (same dimension as the text vectors)."""
        # Previously 5-dimensional, which could not be compared against the
        # 4-dimensional text vectors; the non-zero component stays at index 2.
        return [0.0, 0.0, 1.0, 0.0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async mock query embedding; the query string is ignored."""
        del query
        return self._mock_query_vector()

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async mock text embedding; see `_mock_text_vector`."""
        return self._mock_text_vector(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Mock get text embedding."""
        return self._mock_text_vector(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Mock get query embedding; the query string is ignored."""
        del query
        return self._mock_query_vector()


def mock_extract_triplets(text: str) -> List[Tuple[str, str, str]]:
    """Parse one ``(subject, predicate, object)`` triplet per line of *text*.

    Each non-blank line is expected to look like ``(subj, pred, obj)``: the
    first and last characters are discarded as the enclosing parentheses and
    the remainder is split on commas, with whitespace trimmed from each part.

    Args:
        text: newline-separated triplet lines.

    Returns:
        The parsed triplets in input order; an empty list when *text* holds
        no non-blank lines.

    Raises:
        ValueError: if a non-blank line does not split into exactly three
            comma-separated parts.
    """
    triplets: List[Tuple[str, str, str]] = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            # Blank lines (e.g. from a trailing newline) carry no triplet.
            continue
        # line[1:-1] removes the surrounding "(" and ")".
        subj, pred, obj = (token.strip() for token in line[1:-1].split(","))
        triplets.append((subj, pred, obj))
    return triplets


@patch.object(
    KnowledgeGraphIndex, "_extract_triplets", side_effect=mock_extract_triplets
)
def test_as_retriever(_patch_extract_triplets: Any, documents: List[Document]) -> None:
    """Retrieve through ``index.as_retriever()`` and check the relationship node.

    Triplet extraction is patched with ``mock_extract_triplets`` so the graph
    built from ``documents`` (a fixture defined outside this file) is
    predictable.
    """
    graph_store = SimpleGraphStore()
    storage_context = StorageContext.from_defaults(graph_store=graph_store)
    index = KnowledgeGraphIndex.from_documents(
        documents, storage_context=storage_context
    )
    retriever: KGTableRetriever = index.as_retriever()  # type: ignore
    nodes = retriever.retrieve(QueryBundle("foo"))
    # when include_text is True, the first node is the raw text
    # the second node is the query
    rel_initial_text = (
        "The following are knowledge sequence in max depth"
        f" {retriever.graph_store_query_depth} "
        "in the form of directed graph like:\n"
        "`subject -[predicate]->, object, <-[predicate_next_hop]-,"
        " object_next_hop ...`"
    )

    raw_text = "['foo', 'is', 'bar']"
    query = rel_initial_text + "\n" + raw_text
    assert len(nodes) == 2
    assert nodes[1].node.get_content() == query


@patch.object(
    KnowledgeGraphIndex, "_extract_triplets", side_effect=mock_extract_triplets
)
def test_retrievers(_patch_extract_triplets: Any, documents: List[Document]) -> None:
    """Retrieve with an explicitly constructed ``KGTableRetriever``.

    The retriever is given the mock keyword-extraction prompt and the same
    graph store the index was built on; the second returned node must hold
    the relationship text for the ``foo`` triplet.
    """
    # test specific retriever class
    graph_store = SimpleGraphStore()
    storage_context = StorageContext.from_defaults(graph_store=graph_store)

    index = KnowledgeGraphIndex.from_documents(
        documents, storage_context=storage_context
    )
    retriever = KGTableRetriever(
        index,
        query_keyword_extract_template=MOCK_QUERY_KEYWORD_EXTRACT_PROMPT,
        graph_store=graph_store,
    )
    query_bundle = QueryBundle(query_str="foo", custom_embedding_strs=["foo"])
    nodes = retriever.retrieve(query_bundle)
    assert (
        nodes[1].node.get_content()
        == "The following are knowledge sequence in max depth 2"
        " in the form of directed graph like:\n"
        "`subject -[predicate]->, object, <-[predicate_next_hop]-,"
        " object_next_hop ...`"
        "\n['foo', 'is', 'bar']"
    )


@patch.object(
    KnowledgeGraphIndex, "_extract_triplets", side_effect=mock_extract_triplets
)
def test_retriever_no_text(
    _patch_extract_triplets: Any, documents: List[Document]
) -> None:
    """Retrieve with ``include_text=False``.

    With text inclusion disabled, the relationship text for the ``foo``
    triplet is expected in the first returned node rather than the second.
    """
    # test specific retriever class
    graph_store = SimpleGraphStore()
    storage_context = StorageContext.from_defaults(graph_store=graph_store)

    index = KnowledgeGraphIndex.from_documents(
        documents, storage_context=storage_context
    )
    retriever = KGTableRetriever(
        index,
        query_keyword_extract_template=MOCK_QUERY_KEYWORD_EXTRACT_PROMPT,
        include_text=False,
        graph_store=graph_store,
    )
    query_bundle = QueryBundle(query_str="foo", custom_embedding_strs=["foo"])
    nodes = retriever.retrieve(query_bundle)
    assert (
        nodes[0].node.get_content()
        == "The following are knowledge sequence in max depth 2"
        " in the form of directed graph like:\n"
        "`subject -[predicate]->, object, <-[predicate_next_hop]-,"
        " object_next_hop ...`"
        "\n['foo', 'is', 'bar']"
    )


@patch.object(
    KnowledgeGraphIndex, "_extract_triplets", side_effect=mock_extract_triplets
)
def test_retrieve_similarity(
    _patch_extract_triplets: Any, documents: List[Document]
) -> None:
    """Retrieve from an index built with triplet embeddings.

    The index is built with ``include_embeddings=True`` and ``MockEmbedding``
    so the stored vectors are deterministic; the retriever is limited to
    ``similarity_top_k=2``.
    """
    graph_store = SimpleGraphStore()
    storage_context = StorageContext.from_defaults(graph_store=graph_store)

    index = KnowledgeGraphIndex.from_documents(
        documents,
        include_embeddings=True,
        storage_context=storage_context,
        embed_model=MockEmbedding(),
    )
    retriever = KGTableRetriever(index, similarity_top_k=2, graph_store=graph_store)

    # returns only two rel texts to use for generating response
    # uses hybrid query by default
    nodes = retriever.retrieve(QueryBundle("foo"))
    assert (
        nodes[1].node.get_content()
        == "The following are knowledge sequence in max depth 2"
        " in the form of directed graph like:\n"
        "`subject -[predicate]->, object, <-[predicate_next_hop]-,"
        " object_next_hop ...`"
        "\n['foo', 'is', 'bar']"
    )