
Validating State Manager Configuration Parameters in DeepSpeed

This test suite validates configuration parameters for the DeepSpeed State Manager (DSStateManagerConfig), focusing on boundary conditions and invalid inputs for ragged batch processing. It checks that invalid limits for sequence tracking and ragged batch sizing are rejected with a validation error as soon as the configuration object is constructed.

Test Coverage Overview

The test suite covers the following invalid configurations of DSStateManagerConfig, each of which must raise a ValidationError:

  • Negative and zero values for max_tracked_sequences
  • Negative and zero values for max_ragged_batch_size
  • Negative and zero values for max_ragged_sequence_count
  • A max_ragged_sequence_count (1024) that exceeds max_ragged_batch_size or max_tracked_sequences (512)
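To make the constraints listed above concrete, here is a minimal sketch of how rules like these can be declared in pydantic v2: Field(gt=0) for the positivity checks and a model_validator for the cross-field check. ToyStateManagerConfig is a hypothetical stand-in, not DeepSpeed's DSStateManagerConfig. Its defaults are placeholders chosen so that the default instance is valid, and allowing equal values in the cross-field check is an assumption, since the tests only show that 512 against 1024 is rejected.

```python
from pydantic import BaseModel, Field, ValidationError, model_validator


class ToyStateManagerConfig(BaseModel):
    """Hypothetical stand-in for DSStateManagerConfig; defaults are placeholders."""

    # gt=0 rejects both zero and negative values for each limit.
    max_tracked_sequences: int = Field(default=1024, gt=0)
    max_ragged_batch_size: int = Field(default=1024, gt=0)
    max_ragged_sequence_count: int = Field(default=256, gt=0)

    # Cross-field check: runs only after every field has passed its own validation.
    @model_validator(mode="after")
    def check_sequence_count(self) -> "ToyStateManagerConfig":
        count = self.max_ragged_sequence_count
        # pydantic reports a ValueError raised here as a ValidationError.
        # Equal values are allowed (an assumption of this sketch).
        if count > self.max_ragged_batch_size:
            raise ValueError("max_ragged_sequence_count exceeds max_ragged_batch_size")
        if count > self.max_tracked_sequences:
            raise ValueError("max_ragged_sequence_count exceeds max_tracked_sequences")
        return self


# Mirrors test_too_small_max_ragged_batch_size against the stand-in model.
try:
    ToyStateManagerConfig(max_ragged_batch_size=512, max_ragged_sequence_count=1024)
except ValidationError as err:
    print(err.error_count(), "validation error(s)")
```

With these placeholder defaults, each of the two relationship cases from the test file trips only its own cross-field check.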

Implementation Analysis

The tests use pytest’s pytest.raises context manager to assert that constructing DSStateManagerConfig with an invalid value raises an error. Each test sets one parameter (or, in the last two tests, a pair of parameters) and leaves the rest at their defaults.

  • Marks every test with pytest.mark.inference_v2 for test categorization
  • Expects pydantic’s ValidationError for every invalid input
  • Uses one isolated test function per invalid input

Technical Details

  • Testing Framework: pytest
  • Validation Library: pydantic
  • Test Scope: Unit tests
  • Component: DSStateManagerConfig
  • Module: deepspeed.inference.v2.ragged

Best Practices Demonstrated

The test suite illustrates several sound testing practices: each failure condition is isolated in its own small test, and each of the three limits is checked for both zero and negative values.

  • Individual test cases for each failure condition
  • Consistent test naming (test_<condition>_<parameter>)
  • Proper use of pytest markers
  • Clear validation error expectations
  • Explicit -> None return type hints on test functions
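The six zero/negative tests share one body, so pytest.mark.parametrize can generate them from a single function while pytest still collects and reports each failure condition as a separate test item. The sketch below assumes pydantic v2 and runs against a hypothetical ToyLimits model rather than DSStateManagerConfig. It also checks errors()[0]["loc"], so each case must fail on the intended field, something a bare pytest.raises(ValidationError) does not guarantee.

```python
import pytest
from pydantic import BaseModel, Field, ValidationError


class ToyLimits(BaseModel):
    """Hypothetical model with positive-only limits (not DeepSpeed's class)."""

    max_tracked_sequences: int = Field(default=1024, gt=0)
    max_ragged_batch_size: int = Field(default=1024, gt=0)
    max_ragged_sequence_count: int = Field(default=256, gt=0)


# Every (field, value) combination: three fields times {-1, 0}.
NON_POSITIVE_CASES = [
    (field, value)
    for field in ("max_tracked_sequences", "max_ragged_batch_size", "max_ragged_sequence_count")
    for value in (-1, 0)
]


# pytest collects one test item per (field, value) pair.
@pytest.mark.parametrize("field,value", NON_POSITIVE_CASES)
def test_non_positive_limit_rejected(field: str, value: int) -> None:
    with pytest.raises(ValidationError) as excinfo:
        ToyLimits(**{field: value})
    # Confirm the error points at the field under test, not at some other rule.
    assert excinfo.value.errors()[0]["loc"] == (field,)
```

The original file's one-function-per-case layout gives the most readable test names; parametrizing trades that for less duplication. Adapting this to the real suite would mean adding pytest.mark.inference_v2 and swapping in DSStateManagerConfig, and whether its error locations match the field names is an assumption worth verifying.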

microsoft/deepspeed

tests/unit/inference/v2/ragged/test_manager_configs.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

from pydantic import ValidationError

from deepspeed.inference.v2.ragged import DSStateManagerConfig


@pytest.mark.inference_v2
def test_negative_max_tracked_sequences() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_tracked_sequences=-1)


@pytest.mark.inference_v2
def test_zero_max_tracked_sequences() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_tracked_sequences=0)


@pytest.mark.inference_v2
def test_negative_max_ragged_batch_size() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_ragged_batch_size=-1)


@pytest.mark.inference_v2
def test_zero_max_ragged_batch_size() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_ragged_batch_size=0)


@pytest.mark.inference_v2
def test_negative_max_ragged_sequence_count() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_ragged_sequence_count=-1)


@pytest.mark.inference_v2
def test_zero_max_ragged_sequence_count() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_ragged_sequence_count=0)


@pytest.mark.inference_v2
def test_too_small_max_ragged_batch_size() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_ragged_batch_size=512, max_ragged_sequence_count=1024)


@pytest.mark.inference_v2
def test_too_small_max_tracked_sequences() -> None:
    with pytest.raises(ValidationError):
        DSStateManagerConfig(max_tracked_sequences=512, max_ragged_sequence_count=1024)