Back to Repositories

Testing LLM Prompt Style Implementations in private-gpt

This test suite validates prompt styling and formatting in Private-GPT, covering the Llama 2, Llama 3, Mistral, ChatML, and custom tag prompt styles. The tests check that messages are formatted correctly and that system prompts are handled properly for each style.

Test Coverage Overview

The test suite provides comprehensive coverage of prompt style implementations and message formatting.

  • Tests prompt style selection and initialization
  • Validates message formatting for multiple LLM types
  • Covers system prompt handling and integration
  • Includes error handling for unknown prompt styles

Implementation Analysis

The testing approach uses pytest’s parametrize feature for efficient test case management and clear separation of concerns.

Key patterns include:
  • Parametrized test cases (pytest.mark.parametrize) for prompt style lookup
  • Systematic validation of prompt formatting
  • Explicit expected vs actual output comparison
  • Explicit error-case check for unknown prompt style names

Technical Details

Testing infrastructure includes:

  • pytest framework for test execution
  • ChatMessage and MessageRole from llama_index.core.llms
  • Custom prompt style classes for different LLM implementations
  • Assertion-based validation mechanisms

Best Practices Demonstrated

The test suite exemplifies high-quality testing practices through structured test organization and thorough validation.

  • Modular test case design
  • Edge case coverage, such as the default system prompt and multi-turn conversation histories
  • Clear test naming conventions
  • Effective use of pytest features

zylon-ai/private-gpt

tests/test_prompt_helper.py

            
import pytest
from llama_index.core.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import (
    ChatMLPromptStyle,
    DefaultPromptStyle,
    Llama2PromptStyle,
    Llama3PromptStyle,
    MistralPromptStyle,
    TagPromptStyle,
    get_prompt_style,
)


@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        ("default", DefaultPromptStyle),
        ("llama2", Llama2PromptStyle),
        # Llama3PromptStyle is imported and exercised by the formatting tests
        # below, so its name lookup is covered here as well.
        ("llama3", Llama3PromptStyle),
        ("tag", TagPromptStyle),
        ("mistral", MistralPromptStyle),
        ("chatml", ChatMLPromptStyle),
    ],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
    """Each supported style name resolves to an instance of its prompt style class.

    Args:
        prompt_style: The style name passed to ``get_prompt_style``.
        expected_prompt_style: The class the returned object must be an instance of.
    """
    assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)


def test_get_prompt_style_failure():
    """An unrecognized style name raises ValueError quoting the offending value."""
    bogus_style = "unknown"
    with pytest.raises(ValueError) as raised:
        get_prompt_style(bogus_style)
    expected_message = f"Unknown prompt_style='{bogus_style}'"
    assert str(raised.value) == expected_message


def test_tag_prompt_style_format():
    """TagPromptStyle renders each message as a "<|role|>: content" line.

    The prompt ends with an open "<|assistant|>: " tag (no trailing newline)
    so generation continues as the assistant turn.
    """
    prompt_style = TagPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<|system|>: You are an AI assistant.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_tag_prompt_style_format_with_system_prompt():
    """A custom system message is emitted verbatim under the "<|system|>" tag."""
    prompt_style = TagPromptStyle()
    messages = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<|system|>: FOO BAR Custom sys prompt from messages.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_mistral_prompt_style_format():
    """MistralPromptStyle merges system and user text into [INST] blocks.

    Consecutive system/user messages are joined with a newline inside one
    "<s>[INST] ... [/INST]" block. An assistant reply closes that turn with
    "</s>", and the next user message opens a fresh "<s>[INST]" block.
    """
    prompt_style = MistralPromptStyle()

    # Single turn: system + user share one instruction block.
    messages = [
        ChatMessage(content="A", role=MessageRole.SYSTEM),
        ChatMessage(content="B", role=MessageRole.USER),
    ]
    expected_prompt = "<s>[INST] A\nB [/INST]"
    assert prompt_style.messages_to_prompt(messages) == expected_prompt

    # Multi turn: the assistant reply terminates the first block.
    messages2 = [
        ChatMessage(content="A", role=MessageRole.SYSTEM),
        ChatMessage(content="B", role=MessageRole.USER),
        ChatMessage(content="C", role=MessageRole.ASSISTANT),
        ChatMessage(content="D", role=MessageRole.USER),
    ]
    expected_prompt2 = "<s>[INST] A\nB [/INST] C</s><s>[INST] D [/INST]"
    assert prompt_style.messages_to_prompt(messages2) == expected_prompt2


def test_chatml_prompt_style_format():
    """ChatMLPromptStyle wraps each message in <|im_start|>role ... <|im_end|>.

    The role name is followed by a newline. Each message ends with
    "<|im_end|>" plus a newline, and the prompt finishes with an open
    "<|im_start|>assistant" header.
    """
    prompt_style = ChatMLPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<|im_start|>system\n"
        "You are an AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "Hello, how are you doing?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama2_prompt_style_format():
    """Llama2PromptStyle embeds the system prompt in <<SYS>> tags in the first [INST].

    The system text is padded with a space on each side. A blank line
    separates "<</SYS>>" from the user message, which closes with " [/INST]".
    """
    prompt_style = Llama2PromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<s> [INST] <<SYS>>\n"
        " You are an AI assistant. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama2_prompt_style_with_system_prompt():
    """A custom system message replaces the system text inside the <<SYS>> block."""
    prompt_style = Llama2PromptStyle()
    messages = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<s> [INST] <<SYS>>\n"
        " FOO BAR Custom sys prompt from messages. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama3_prompt_style_format():
    """Llama3PromptStyle frames each message with header ids and <|eot_id|>.

    Every role header is followed by a blank line (two newlines). The prompt
    ends with an open assistant header so the model produces the reply.
    """
    prompt_style = Llama3PromptStyle()
    messages = [
        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful assistant<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Hello, how are you doing?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama3_prompt_style_with_default_system():
    """Without a system message, Llama3PromptStyle injects DEFAULT_SYSTEM_PROMPT.

    Goes through the public ``messages_to_prompt`` entry point, consistent
    with the other formatting tests, instead of the private helper.
    """
    prompt_style = Llama3PromptStyle()
    messages = [
        ChatMessage(content="Hello!", role=MessageRole.USER),
    ]
    expected = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    assert prompt_style.messages_to_prompt(messages) == expected


def test_llama3_prompt_style_with_assistant_response():
    """A trailing assistant message is closed with <|eot_id|>.

    When the last message is already an assistant reply, no extra open
    assistant header is appended after it.
    """
    prompt_style = Llama3PromptStyle()
    messages = [
        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
        ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
        ChatMessage(
            content="The capital of France is Paris.", role=MessageRole.ASSISTANT
        ),
    ]

    expected_prompt = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful assistant<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is the capital of France?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "The capital of France is Paris.<|eot_id|>"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt