Back to Repositories

Testing AI Model Integration Workflows in cover-agent

This test suite validates the AICaller class in the cover-agent repository, focusing on model calls, error handling, and integration with AI model APIs through litellm. The tests verify that AICaller handles API calls, streaming responses, and W&B logging correctly, including when those dependencies fail.

Test Coverage Overview

The test suite covers the core behavior of AICaller.call_model.

Key areas tested include:
  • Basic model calls, both a simplified call (with call_model itself mocked) and full calls through a mocked litellm.completion
  • Error handling for API calls and streaming responses
  • Weights & Biases (W&B) logging integration
  • Support for different model types (an "openai/"-prefixed model and o1-preview)
  • Edge cases like missing prompt keys and logging failures

Implementation Analysis

The testing approach employs pytest fixtures and extensive mocking to isolate components and simulate various scenarios.

Notable patterns include:
  • Use of @pytest.fixture for consistent test setup
  • Mock patching for external dependencies (litellm, W&B)
  • Environment variable patching (patch.dict on os.environ, setting WANDB_API_KEY) to exercise the W&B logging path
  • Structured error validation using pytest.raises

Technical Details

Testing infrastructure includes:
  • pytest as the primary testing framework
  • unittest.mock for dependency mocking
  • litellm as the model client, with its completion and stream_chunk_builder calls patched out
  • Environment variable configuration for W&B integration
  • Mock objects for simulating API responses and streaming data

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Isolated test cases with clear boundaries
  • Comprehensive error scenario coverage
  • Fixture-based setup (the ai_caller fixture builds a fresh AICaller for each test)
  • Thorough assertion checking
  • Mock usage for external dependencies
  • Clear test naming conventions

codium-ai/cover-agent

tests/test_AICaller.py

            
import os

import pytest
from unittest.mock import patch, Mock
from cover_agent.AICaller import AICaller


class TestAICaller:
    """Tests for ``AICaller.call_model`` with its external calls patched.

    ``litellm.completion`` and ``litellm.stream_chunk_builder`` are replaced
    with mocks so no model API is contacted, and ``Trace.log`` is replaced in
    the W&B tests so no real logging happens.
    """

    @pytest.fixture
    def ai_caller(self):
        """Return an AICaller for a dummy model/API string with retries disabled."""
        # NOTE(review): enable_retry=False presumably lets error-path tests see
        # the first exception directly instead of retrying -- confirm in AICaller.
        return AICaller("test-model", "test-api", enable_retry=False)

    @patch("cover_agent.AICaller.AICaller.call_model")
    def test_call_model_simplified(self, mock_call_model):
        """call_model returns a (response, prompt_tokens, response_tokens) tuple.

        call_model itself is patched here, so this test only checks the call
        signature and the shape of the returned tuple, not AICaller's real
        implementation (the other tests cover that via a mocked litellm).
        """
        # Set up the mock to return a predefined response
        mock_call_model.return_value = ("Hello world!", 2, 10)
        prompt = {"system": "", "user": "Hello, world!"}

        ai_caller = AICaller("test-model", "test-api", enable_retry=False)
        # Explicitly provide the default value of max_tokens
        response, prompt_tokens, response_tokens = ai_caller.call_model(
            prompt, max_tokens=4096
        )

        # Assertions to check if the returned values are as expected
        assert response == "Hello world!"
        assert prompt_tokens == 2
        assert response_tokens == 10

        # Check if call_model was called correctly
        mock_call_model.assert_called_once_with(prompt, max_tokens=4096)

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_with_error(self, mock_completion, ai_caller):
        """An exception raised by litellm.completion propagates out of call_model."""
        mock_completion.side_effect = Exception("Test exception")
        prompt = {"system": "", "user": "Hello, world!"}

        with pytest.raises(Exception) as exc_info:
            ai_caller.call_model(prompt)

        assert str(exc_info.value) == "Test exception"

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_error_streaming(self, mock_completion, ai_caller):
        """A malformed streaming response makes call_model raise."""
        # The (only) completion call returns the bare string "results" instead
        # of an iterable of chunk objects; stream_chunk_builder is NOT mocked,
        # so the real litellm code fails while assembling the response.
        mock_completion.side_effect = ["results"]
        prompt = {"system": "", "user": "Hello, world!"}

        with pytest.raises(Exception) as exc_info:
            ai_caller.call_model(prompt)

        # The exact error text comes from litellm internals and has changed
        # between litellm versions; accept each message observed so far.
        assert str(exc_info.value) in {
            "'NoneType' object is not subscriptable",  # current litellm
            "list index out of range",  # older litellm
        }

    # Stacked @patch decorators inject mocks bottom-up (Trace.log first, then
    # litellm.completion); @patch.dict injects no argument.
    @patch("cover_agent.AICaller.litellm.completion")
    @patch.dict(os.environ, {"WANDB_API_KEY": "test_key"})
    @patch("cover_agent.AICaller.Trace.log")
    def test_call_model_wandb_logging(self, mock_log, mock_completion, ai_caller):
        """With WANDB_API_KEY set, a successful call is logged via Trace.log once."""
        mock_completion.return_value = [
            {"choices": [{"delta": {"content": "response"}}]}
        ]
        prompt = {"system": "", "user": "Hello, world!"}
        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
            mock_builder.return_value = {
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
            }
            response, prompt_tokens, response_tokens = ai_caller.call_model(prompt)
            assert response == "response"
            assert prompt_tokens == 2
            assert response_tokens == 10
            mock_log.assert_called_once()

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_api_base(self, mock_completion, ai_caller):
        """An "openai/"-prefixed model returns content and token counts normally."""
        # NOTE(review): judging by the test name, the "openai/" prefix presumably
        # routes through AICaller's api_base handling -- confirm in AICaller.
        mock_completion.return_value = [
            {"choices": [{"delta": {"content": "response"}}]}
        ]
        ai_caller.model = "openai/test-model"
        prompt = {"system": "", "user": "Hello, world!"}
        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
            mock_builder.return_value = {
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
            }
            response, prompt_tokens, response_tokens = ai_caller.call_model(prompt)
            assert response == "response"
            assert prompt_tokens == 2
            assert response_tokens == 10

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_with_system_key(self, mock_completion, ai_caller):
        """A non-empty system message is accepted and the call succeeds."""
        mock_completion.return_value = [
            {"choices": [{"delta": {"content": "response"}}]}
        ]
        prompt = {"system": "System message", "user": "Hello, world!"}
        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
            mock_builder.return_value = {
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
            }
            response, prompt_tokens, response_tokens = ai_caller.call_model(prompt)
            assert response == "response"
            assert prompt_tokens == 2
            assert response_tokens == 10

    def test_call_model_missing_keys(self, ai_caller):
        """A prompt without a 'system' key raises KeyError with a clear message."""
        prompt = {"user": "Hello, world!"}
        with pytest.raises(KeyError) as exc_info:
            ai_caller.call_model(prompt)
        # str(KeyError(msg)) is repr(msg), hence the surrounding double quotes.
        assert (
            str(exc_info.value)
            == "\"The prompt dictionary must contain 'system' and 'user' keys.\""
        )

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_o1_preview(self, mock_completion, ai_caller):
        """The o1-preview model works with a non-streaming (stream=False) response."""
        ai_caller.model = "o1-preview"
        prompt = {"system": "System message", "user": "Hello, world!"}
        # Non-streaming responses are read via attributes, so use Mock objects
        # rather than the dict chunks used by the streaming tests.
        mock_response = Mock()
        mock_response.choices = [Mock(message=Mock(content="response"))]
        mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
        mock_completion.return_value = mock_response

        response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=False)
        assert response == "response"
        assert prompt_tokens == 2
        assert response_tokens == 10

    @patch("cover_agent.AICaller.litellm.completion")
    def test_call_model_streaming_response(self, mock_completion, ai_caller):
        """With stream=True, chunks are assembled into the final response and counts."""
        prompt = {"system": "", "user": "Hello, world!"}
        # Mock the response to be an iterable of chunks
        mock_chunk = Mock()
        mock_chunk.choices = [Mock(delta=Mock(content="response part"))]
        mock_completion.return_value = [mock_chunk]
        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
            mock_builder.return_value = {
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
            }
            response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=True)
            assert response == "response"
            assert prompt_tokens == 2
            assert response_tokens == 10

    # Stacked @patch decorators inject mocks bottom-up (Trace.log first, then
    # litellm.completion); @patch.dict injects no argument.
    @patch("cover_agent.AICaller.litellm.completion")
    @patch.dict(os.environ, {"WANDB_API_KEY": "test_key"})
    @patch("cover_agent.AICaller.Trace.log")
    def test_call_model_wandb_logging_exception(self, mock_log, mock_completion, ai_caller):
        """A Trace.log failure is printed and does not break call_model."""
        mock_completion.return_value = [
            {"choices": [{"delta": {"content": "response"}}]}
        ]
        mock_log.side_effect = Exception("Logging error")
        prompt = {"system": "", "user": "Hello, world!"}
        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
            mock_builder.return_value = {
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
            }
            with patch("builtins.print") as mock_print:
                response, prompt_tokens, response_tokens = ai_caller.call_model(prompt)
                assert response == "response"
                assert prompt_tokens == 2
                assert response_tokens == 10
                mock_print.assert_any_call("Error logging to W&B: Logging error")