Testing MiCS Optimizer Checkpoint Implementation in DeepSpeed

This test suite validates the MiCS (Minimizing Communication Scale) optimizer checkpoint functionality in DeepSpeed, focusing on saving and loading model state under different shard configurations. The tests ensure proper handling of optimizer states, module-only loading, and checkpoint saving on the first partition group.

Test Coverage Overview

The test suite covers MiCS checkpoint operations at shard sizes of 1, 2, and 4, all run under a fixed world size of 4.

Key functionality tested includes:
  • Loading optimizer states
  • Non-optimizer state loading
  • Module-only loading scenarios
  • Checkpoint saving on first partition groups

Edge cases span the three shard-size configurations and verify checkpoint correctness across the distributed environment; the sketch below illustrates how shard size determines the first partition group.
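
As a rough illustration in plain Python (not a DeepSpeed API), here is how shard size maps to the first partition group under the test's world size of 4. This mirrors the assertion at the end of test_save_checkpoint_on_first_partition_group in the file below:

world_size = 4  # matches TestMiCSCheckpoint.world_size

for shard_size in (1, 2, 4):
    # Ranks whose global rank falls below the shard size form the first
    # partition group; those are the ranks expected to write the non-ZeRO
    # checkpoint files.
    first_group = [rank for rank in range(world_size) if rank < shard_size]
    print(f"shard_size={shard_size}: first partition group = {first_group}")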

Implementation Analysis

The testing approach uses a distributed test framework with a world size of 4, implementing parameterized tests for different shard configurations.

Technical patterns include:
  • Toy model configuration with customizable shard sizes
  • DeepSpeed ZeRO-3 optimization integration
  • MiCS initialization wrapper (deepspeed.zero.MiCS_Init; sketched after this list)
  • Distributed test inheritance structure
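
In outline, the initialization wrapper looks like the following. This is a minimal sketch assuming a distributed launch; the config keys mirror _toy_model_config in the file below, and SimpleModel comes from the unit.simple_model test helpers:

import deepspeed

from unit.simple_model import SimpleModel

# Minimal sketch: parameters created inside MiCS_Init are partitioned
# according to mics_shard_size rather than across the full world size.
config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 3, "mics_shard_size": 2},
}
with deepspeed.zero.MiCS_Init(config_dict_or_path=config):
    model = SimpleModel(hidden_dim=10, empty_grad=False)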

Technical Details

Testing tools and configuration:
  • pytest framework with parametrize decorators
  • DeepSpeed initialization with custom configurations
  • FP16 mixed precision settings
  • Adam optimizer integration
  • Distributed test environment with 4 processes
  • Temporary directory fixtures for checkpoint storage (a save/load round-trip sketch follows this list)
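
A typical round trip looks roughly like this. It is a sketch of the flow that the checkpoint_correctness_verification helper wraps, assuming config_dict, model, and tmpdir are already in scope:

import deepspeed

# Sketch: initialize an engine, save a checkpoint, then reload it.
engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                       model=model,
                                       model_parameters=model.parameters())
engine.save_checkpoint(tmpdir)

# Reload; the flags mirror the scenarios tested below (optimizer state
# on/off, module-only loading).
engine.load_checkpoint(tmpdir,
                       load_optimizer_states=True,
                       load_module_only=False)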

Best Practices Demonstrated

The test suite exemplifies high-quality testing practices through systematic validation of checkpoint functionality.

Notable practices include:
  • Parametrized test cases for each shard configuration (reduced to a skeleton after this list)
  • Proper test isolation using temporary directories
  • Comprehensive verification of checkpoint correctness
  • Clear separation of test scenarios
  • Proper handling of distributed testing environments
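
Reduced to a skeleton, the isolation pattern looks like this (the test name and body are placeholders, not the real helper calls):

import pytest

@pytest.mark.parametrize('shard_size', [1, 2, 4])
def test_checkpoint_roundtrip(tmpdir, shard_size):
    # pytest hands each (test, shard_size) pair its own tmpdir, so
    # checkpoints from one configuration never leak into another.
    ...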

microsoft/deepspeed

tests/unit/checkpoint/test_mics_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import deepspeed

from deepspeed.utils.torch import required_torch_version
from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import *

import pytest

if not required_torch_version(max_version=2.0):
    pytest.skip("Skipping until we resolve problems with torch 2.1", allow_module_level=True)


class TestMiCSCheckpoint(DistributedTest):
    world_size = 4

    def _toy_model_config(self, shard_size):

        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "wall_clock_breakdown": True,
            "zero_optimization": {
                "stage": 3,
                "mics_shard_size": shard_size
            }
        }

        hidden_dim = 10
        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
            models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        return config_dict, hidden_dim, models

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_not_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_module_only(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size):
        config_dict, _, models = self._toy_model_config(shard_size)
        ds_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                  model=models[0],
                                                  model_parameters=models[0].parameters(),
                                                  optimizer=None)

        ds_engine.save_checkpoint(tmpdir)
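        # Only ranks in the first partition group (global_rank < shard_size)
        # are expected to write the non-ZeRO checkpoint files.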
        if ds_engine.global_rank < shard_size:
            assert ds_engine.save_non_zero_checkpoint
        else:
            assert not ds_engine.save_non_zero_checkpoint