Back to Repositories

Testing Protocol Agent Operations in AutoGPT

This test suite validates the ProtocolAgent class in AutoGPT, covering task management, step execution, and artifact handling with pytest and FastAPI. The async tests exercise the core agent protocol operations end to end.

Test Coverage Overview

The test suite provides extensive coverage of the ProtocolAgent’s core functionalities:
  • Task creation, retrieval, and listing operations
  • Step execution and retrieval (currently marked xfail because execute_step is not implemented)
  • Artifact handling including creation, storage, and retrieval
  • Tests for not-yet-implemented features (step execution) marked with pytest.mark.xfail
  • Integration with FastAPI’s UploadFile and async operations

Implementation Analysis

The tests use the pytest-asyncio plugin's pytest.mark.asyncio marker to run coroutine test functions. Setup is fixture-based, with dedicated fixtures for agent initialization and file uploads. FastAPI integration is exercised through UploadFile handling and async file operations.

Technical Details

  • Testing Framework: pytest with the pytest-asyncio plugin
  • Database: SQLite with SQLAlchemy integration
  • File Storage: Local file system implementation
  • Fixtures: agent (which builds on the tmp_project_root fixture) and file_upload, which closes its file handle on teardown
  • Async/await patterns for API testing

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Fixture-based setup (note: the agent fixture uses a shared on-disk test.db, so database state is not isolated between tests)
  • Explicit test case organization
  • Resource cleanup in file handling
  • Clear failure marking with pytest.mark.xfail
  • Validation of key fields on the returned objects

significant-gravitas/autogpt

classic/forge/forge/agent_protocol/agent_test.py

            
from pathlib import Path

import pytest
from fastapi import UploadFile

from forge.file_storage.base import FileStorageConfiguration
from forge.file_storage.local import LocalFileStorage

from .agent import ProtocolAgent
from .database.db import AgentDB
from .models.task import StepRequestBody, Task, TaskListResponse, TaskRequestBody


@pytest.fixture
def agent(tmp_project_root: Path):
    """Build a ProtocolAgent backed by SQLite and a local-filesystem workspace.

    Only the workspace is rooted at ``tmp_project_root``. The database URL is
    the fixed string ``sqlite:///test.db``. NOTE(review): this presumably
    resolves to a ``test.db`` file in the current working directory. Nothing
    here resets it, so data written by one test can be seen by later ones.
    """
    database = AgentDB("sqlite:///test.db")
    storage = LocalFileStorage(FileStorageConfiguration(root=tmp_project_root))
    return ProtocolAgent(database, storage)


@pytest.fixture
def file_upload():
    """Yield an ``UploadFile`` wrapping this test module, opened in binary mode.

    The handle is managed by a ``with`` block. It is closed on fixture
    teardown, and also when constructing the ``UploadFile`` raises. The
    previous manual open/close leaked the handle in that case.
    """
    this_file = Path(__file__)
    with this_file.open("rb") as file_handle:
        yield UploadFile(file_handle, filename=this_file.name)


@pytest.mark.asyncio
async def test_create_task(agent: ProtocolAgent):
    """A newly created task reports the same ``input`` it was created with."""
    payload = {
        "input": "test_input",
        "additional_input": {"input": "additional_test_input"},
    }
    created: Task = await agent.create_task(TaskRequestBody(**payload))
    assert created.input == "test_input"


@pytest.mark.asyncio
async def test_list_tasks(agent: ProtocolAgent):
    """After creating a task, ``list_tasks`` returns a ``TaskListResponse``."""
    await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )
    listing = await agent.list_tasks()
    assert isinstance(listing, TaskListResponse)


@pytest.mark.asyncio
async def test_get_task(agent: ProtocolAgent):
    """A task can be fetched back by the id returned from ``create_task``."""
    created = await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )
    fetched = await agent.get_task(created.task_id)
    assert fetched.task_id == created.task_id


@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_execute_step(agent: ProtocolAgent):
    """Executing a step echoes back the step request's inputs.

    Expected to fail (xfail) until ``execute_step`` is implemented.
    """
    parent = await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )
    executed = await agent.execute_step(
        parent.task_id,
        StepRequestBody(
            input="step_input",
            additional_input={"input": "additional_test_input"},
        ),
    )
    assert executed.input == "step_input"
    assert executed.additional_input == {"input": "additional_test_input"}


@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent: ProtocolAgent):
    """A step produced by ``execute_step`` can be fetched back by its id.

    Expected to fail (xfail) until ``execute_step`` is implemented.
    """
    parent = await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )
    executed = await agent.execute_step(
        parent.task_id,
        StepRequestBody(
            input="step_input",
            additional_input={"input": "additional_test_input"},
        ),
    )
    fetched = await agent.get_step(parent.task_id, executed.step_id)
    assert fetched.step_id == executed.step_id


@pytest.mark.asyncio
async def test_list_artifacts(agent: ProtocolAgent):
    """Listing artifacts for a task returns a list.

    The test creates its own task rather than picking the first task already
    stored in the database. The previous version only passed if other tests
    (or earlier runs) had populated ``test.db``. On a fresh database it failed
    with "No tasks in test.db", and it otherwise queried an arbitrary,
    unrelated task.
    """
    task = await agent.create_task(
        TaskRequestBody(
            input="test_input", additional_input={"input": "additional_test_input"}
        )
    )

    artifacts = await agent.list_artifacts(task.task_id)
    assert isinstance(artifacts.artifacts, list)


@pytest.mark.asyncio
async def test_create_artifact(agent: ProtocolAgent, file_upload: UploadFile):
    """``create_artifact`` records the upload's file name and requested path."""
    owner = await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )
    target_path = f"a_dir/{file_upload.filename}"
    created = await agent.create_artifact(
        task_id=owner.task_id,
        file=file_upload,
        relative_path=target_path,
    )
    assert created.file_name == file_upload.filename
    assert created.relative_path == f"a_dir/{file_upload.filename}"


@pytest.mark.asyncio
async def test_create_and_get_artifact(agent: ProtocolAgent, file_upload: UploadFile):
    """The bytes streamed back by ``get_artifact`` equal the bytes uploaded.

    The upload is rewound before being read for comparison, because
    ``create_artifact`` may already have consumed it.
    """
    owner = await agent.create_task(
        TaskRequestBody(
            input="test_input",
            additional_input={"input": "additional_test_input"},
        )
    )

    stored = await agent.create_artifact(
        task_id=owner.task_id,
        file=file_upload,
        relative_path=f"b_dir/{file_upload.filename}",
    )

    # Re-read the original upload from the start to get the expected bytes.
    await file_upload.seek(0)
    expected_bytes = await file_upload.read()

    # Drain the streamed response body into a single buffer.
    response = await agent.get_artifact(owner.task_id, stored.artifact_id)
    received = bytearray()
    async for chunk in response.body_iterator:
        received.extend(chunk)  # type: ignore
    assert received == expected_bytes