
Validating Configuration Model Management in DeepSpeed

This test suite validates DeepSpeed's configuration handling: loading JSON config files through DeepSpeedConfig, validating parameters against Pydantic models, and the backward-compatibility features of the DeepSpeedConfigModel base class. It covers both basic parsing and more advanced features such as field deprecation and field aliasing.

Test Coverage Overview

The test suite covers the following parts of DeepSpeed's configuration system:
  • Acceptance of a minimal config file that sets only the required field, train_batch_size, with the other batch parameters derived from it
  • Migration of deprecated parameters to their replacements
  • Field aliasing
  • Rejection of invalid configurations (wrong value types and unknown keys)
  • Duplicate key detection in config files
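In test_only_required_fields, a config that sets only train_batch_size = 64 ends up with train_micro_batch_size_per_gpu == 64 and gradient_accumulation_steps == 1. DeepSpeed's configuration documentation describes train_batch_size as the product of train_micro_batch_size_per_gpu, gradient_accumulation_steps, and the number of GPUs, and a non-distributed unit test runs with a world size of 1. The sketch below fills in missing values under that identity. The helper resolve_batch_sizes is hypothetical and only approximates DeepSpeed's own logic, whose precedence rules and error messages may differ.

```python
from typing import Optional, Tuple


def resolve_batch_sizes(train_batch: Optional[int],
                        micro_batch: Optional[int],
                        grad_acc: Optional[int],
                        world_size: int = 1) -> Tuple[int, int, int]:
    """Hypothetical helper: fill in missing batch parameters so that
    train_batch == micro_batch * grad_acc * world_size, then check it."""
    if train_batch is None and micro_batch is None:
        raise ValueError("set train_batch_size or train_micro_batch_size_per_gpu")
    if train_batch is not None:
        if micro_batch is None and grad_acc is None:
            # Only the global batch is known: assume no gradient accumulation.
            grad_acc = 1
            micro_batch = train_batch // world_size
        elif micro_batch is None:
            micro_batch = train_batch // (grad_acc * world_size)
        elif grad_acc is None:
            grad_acc = train_batch // (micro_batch * world_size)
    else:
        # train_batch is None, so micro_batch must be set at this point.
        if grad_acc is None:
            grad_acc = 1
        train_batch = micro_batch * grad_acc * world_size

    # Integer division above can lose information, so verify the identity.
    if train_batch != micro_batch * grad_acc * world_size:
        raise ValueError(
            f"train_batch_size {train_batch} != {micro_batch} (micro) * "
            f"{grad_acc} (grad. acc.) * {world_size} (world size)")
    return train_batch, micro_batch, grad_acc


print(resolve_batch_sizes(64, None, None))                # (64, 64, 1)
print(resolve_batch_sizes(64, 8, None, world_size=2))     # (64, 8, 4)
```

The final consistency check is what turns an impossible combination, such as a global batch of 64 split across 3 processes, into an error instead of a silently truncated micro batch.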

Implementation Analysis

The tests use pytest's tmpdir fixture to write temporary config files and @pytest.mark.parametrize to run a single failure test against several invalid inputs. They include both positive cases (valid configs that must parse) and negative cases (configs that must raise an exception), relying on Pydantic models for type validation and schema enforcement.

A small test model, SimpleConf, subclasses the DeepSpeedConfigModel base class so that the tests can exercise its inherited validation, parameter deprecation, and field aliasing directly.
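The deprecation and alias behavior that SimpleConf declares can be approximated without Pydantic. The sketch below is a hypothetical, dictionary-based illustration, not DeepSpeed's implementation: a deprecated key is moved to its replacement through a conversion function (playing the role of new_param_fn in SimpleConf), supplying both the old and the new key trips an assertion (the behavior test_config_base_deprecatedfail expects), and an alias is mapped onto its canonical field name. The DEPRECATED and ALIASES tables and the normalize function are assumptions made for this example.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical tables mirroring SimpleConf's field metadata.
# deprecated name -> (replacement name, conversion applied to the old value)
DEPRECATED: Dict[str, Tuple[str, Callable[[Any], Any]]] = {
    "param_2_old": ("param_2", lambda x: [x]),
}
# alias -> canonical field name
ALIASES: Dict[str, str] = {"param_3_alias": "param_3"}


def normalize(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of raw with aliases and deprecated keys rewritten."""
    data = dict(raw)
    for alias, name in ALIASES.items():
        if alias in data:
            # In this sketch the alias wins if both spellings are present.
            data[name] = data.pop(alias)
    for old, (new, convert) in DEPRECATED.items():
        if old in data:
            assert new not in data, (
                f"Cannot provide deprecated parameter '{old}' and "
                f"replacing parameter '{new}' together")
            data[new] = convert(data.pop(old))
    return data


print(normalize({"param_2_old": "DS"}))  # {'param_2': ['DS']}
print(normalize({"param_3_alias": 10}))  # {'param_3': 10}
```

Normalizing the raw input before type validation keeps the validator itself unaware of legacy names; only the canonical fields need a type.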

Technical Details

Key technical components include:
  • pytest framework for test organization
  • Pydantic for configuration validation
  • JSON-based configuration files
  • tmpdir fixture for temporary file handling
  • DeepSpeed's DeepSpeedConfigModel base class, a Pydantic BaseModel subclass
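Duplicate-key detection cannot rely on a plain dict, because Python's json module silently keeps the last value for a repeated key. A loader can instead pass an object_pairs_hook, which receives every key/value pair of an object, in order, before a dict is built. The sketch below uses only the standard-library json module; the hook name reject_duplicate_keys is hypothetical, and DeepSpeed's own loader may differ in the parser it uses and in its error message. Raising ValueError matches what test_config_duplicate_key expects.

```python
import json
from collections import Counter
from typing import Any, Dict, List, Tuple


def reject_duplicate_keys(pairs: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """object_pairs_hook that raises instead of keeping the last value."""
    counts = Counter(key for key, _ in pairs)
    duplicates = sorted(key for key, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(f"Duplicate keys in config: {duplicates}")
    return dict(pairs)


text = '{"train_batch_size": 24, "train_batch_size": 24}'

print(json.loads(text))  # default behavior: {'train_batch_size': 24}
try:
    json.loads(text, object_pairs_hook=reject_duplicate_keys)
except ValueError as err:
    print(err)  # Duplicate keys in config: ['train_batch_size']
```

Because json.JSONDecodeError is itself a subclass of ValueError, a test that only expects ValueError also passes on a plain syntax error such as a trailing comma, so the test input should be otherwise valid for whatever parser the loader uses.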

Best Practices Demonstrated

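The test suite follows several good testing practices:
  • Isolated file-based tests: each one writes its config into pytest's per-test tmpdir
  • Negative tests for wrong types, unknown keys, duplicate keys, and conflicting deprecated and replacement parameters
  • Descriptive test names (test_config_base_deprecatedfield, test_config_base_aliasfield, and so on)
  • Parametrization to run one test body over several invalid inputs
  • Checks on both successful parses and the specific exception each failure should raise (ValueError, ValidationError, AssertionError)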

microsoft/deepspeed

tests/unit/runtime/test_ds_config_model.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import json
import os
from typing import List, Optional

from pydantic import Field, ValidationError

from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


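class SimpleConf(DeepSpeedConfigModel):
    # Minimal model used to exercise DeepSpeedConfigModel features.
    param_1: int = 0
    # Deprecated field: a value given here is converted by new_param_fn
    # (str -> [str]) and moved to param_2.
    param_2_old: Optional[str] = Field(None,
                                       json_schema_extra={
                                           "deprecated": True,
                                           "new_param": "param_2",
                                           "new_param_fn": (lambda x: [x])
                                       })
    param_2: Optional[List[str]] = None
    # Accepted under its own name or under the alias "param_3_alias".
    param_3: int = Field(0, alias="param_3_alias")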


def test_only_required_fields(tmpdir):
    '''Ensure that a config containing only the required fields is accepted.'''
    cfg_json = tmpdir.mkdir('ds_config_unit_test').join('minimal.json')

    with open(cfg_json, 'w') as f:
        required_fields = {'train_batch_size': 64}
        json.dump(required_fields, f)

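    run_cfg = ds_config.DeepSpeedConfig(cfg_json)
    assert run_cfg is not None
    assert run_cfg.train_batch_size == 64
    # Outside a distributed launch the world size is 1, so
    # train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps
    # resolves to 64 = 64 * 1.
    assert run_cfg.train_micro_batch_size_per_gpu == 64
    assert run_cfg.gradient_accumulation_steps == 1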


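def test_config_duplicate_key(tmpdir):
    '''A config file that repeats a key must be rejected with ValueError.'''
    # No trailing comma: a parse error (json.JSONDecodeError is a ValueError
    # subclass) must not satisfy pytest.raises(ValueError) by accident.
    config_dict = '''
    {
        "train_batch_size": 24,
        "train_batch_size": 24
    }
    '''
    config_path = os.path.join(tmpdir, 'temp_config.json')

    with open(config_path, 'w') as jf:
        jf.write(config_dict)

    with pytest.raises(ValueError):
        ds_config.DeepSpeedConfig(config_path)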


def test_config_base():
    config = SimpleConf(**{"param_1": 42})
    assert config.param_1 == 42


def test_config_base_deprecatedfield():
    '''A deprecated field is converted by new_param_fn and moved to its replacement.'''
    config = SimpleConf(**{"param_2_old": "DS"})
    assert config.param_2 == ["DS"]


def test_config_base_aliasfield():
    '''A field can be set either by its own name or by its alias.'''
    config = SimpleConf(**{"param_3": 10})
    assert config.param_3 == 10

    config = SimpleConf(**{"param_3_alias": 10})
    assert config.param_3 == 10


@pytest.mark.parametrize("config_dict", [{"param_1": "DS"}, {"param_2": "DS"}, {"param_1_typo": 0}])
def test_config_base_literalfail(config_dict):
    '''Wrong value types and unknown keys raise a pydantic ValidationError.'''
    with pytest.raises(ValidationError):
        SimpleConf(**config_dict)


def test_config_base_deprecatedfail():
    '''Setting both a deprecated field and its replacement is an error.'''
    with pytest.raises(AssertionError):
        SimpleConf(**{"param_2": ["DS"], "param_2_old": "DS"})
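test_config_base_literalfail relies on two properties of DeepSpeedConfigModel: each value must match its declared field type, and unknown keys such as param_1_typo are rejected rather than ignored. The standard-library sketch below imitates both checks for a flat schema. SCHEMA, ConfigError, and validate are hypothetical names, and the sketch checks types strictly, whereas Pydantic's default lax mode coerces some inputs (for example the string "42" to the integer 42).

```python
from typing import Any, Callable, Dict


class ConfigError(ValueError):
    """Hypothetical stand-in for pydantic.ValidationError."""


def _is_int(v: Any) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(v, int) and not isinstance(v, bool)


# Hypothetical flat schema mirroring SimpleConf: field name -> value check.
SCHEMA: Dict[str, Callable[[Any], bool]] = {
    "param_1": _is_int,
    "param_2": lambda v: v is None or (
        isinstance(v, list) and all(isinstance(s, str) for s in v)),
    "param_3": _is_int,
}


def validate(data: Dict[str, Any]) -> Dict[str, Any]:
    """Reject unknown keys first, then values that fail their field's check."""
    unknown = sorted(set(data) - set(SCHEMA))
    if unknown:
        raise ConfigError(f"Extra fields not permitted: {unknown}")
    for name, value in data.items():
        if not SCHEMA[name](value):
            raise ConfigError(f"Invalid value for {name!r}: {value!r}")
    return data


# The same three inputs the parametrized test uses; each one is rejected.
for bad in ({"param_1": "DS"}, {"param_2": "DS"}, {"param_1_typo": 0}):
    try:
        validate(bad)
    except ConfigError as err:
        print("rejected:", err)
```

Rejecting unknown keys is what catches typos such as param_1_typo; without that check, a misspelled option would be silently ignored and the default value used instead.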