Validating Pipeline Parallel Checkpointing in DeepSpeed

This test suite validates checkpointing for DeepSpeed’s pipeline parallelism, focusing on model state preservation and topology transformation. It checks that checkpoints save and restore correctly under ZeRO optimization stages 0 and 1 and across pipeline configurations.

Test Coverage Overview

The test suite covers critical checkpoint functionality for pipeline-parallel models in DeepSpeed.

Key areas tested include:
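  • Checkpoint verification across ZeRO stages 0 and 1
  • Model state preservation when a checkpoint is reloaded under a different pipeline topology (the topology pairs are currently commented out in the test file)
  • Optimizer and learning rate scheduler state handling
  • Precision handling: FP32 at ZeRO stage 0 and FP16 at ZeRO stage 1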
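
The way the test ties precision to the ZeRO stage can be sketched as a small config builder. This is a minimal illustration based on the config_dict in the test file; the helper names build_config and expected_dtype are hypothetical, the scheduler block is omitted for brevity, and the dtype is returned as a string so the snippet runs without DeepSpeed or PyTorch installed.

```python
def build_config(zero_stage: int) -> dict:
    """Return a trimmed DeepSpeed-style config for the given ZeRO stage.

    Hypothetical helper: it mirrors the keys used in the test file but
    leaves out the OneCycle scheduler section.
    """
    return {
        "train_batch_size": 2,
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
        "zero_optimization": {"stage": zero_stage},
        # FP16 is enabled only when a ZeRO stage above 0 is requested.
        "fp16": {"enabled": zero_stage > 0},
    }


def expected_dtype(zero_stage: int) -> str:
    """Name of the dtype paired with this stage (string stand-in for a torch dtype)."""
    return "float16" if zero_stage > 0 else "float32"


for stage in (0, 1):
    cfg = build_config(stage)
    print(stage, cfg["fp16"]["enabled"], expected_dtype(stage))
```

Deriving both the fp16 flag and the dtype from the same zero_stage value keeps the two from drifting apart when new stages are added to the parametrization.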

Implementation Analysis

The implementation uses pytest parametrization to run each scenario. The tests build LinearStackPipe models from either a stage count or an explicit topology, and the module-level test saves and reloads state through DeepSpeed’s TorchCheckpointEngine.

Key patterns include:
  • Distributed test execution with a fixed world size of 4
  • Parametrized test configurations for ZeRO stages 0 and 1
  • Topology transformation validation
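
The topology check compares layers one by one, translating a local layer index on one partition into a global index and then into the local index on the other partition. The sketch below reproduces that arithmetic in plain Python. StagePartition and match_layers are hypothetical stand-ins, not DeepSpeed classes, and the sketch assumes the smaller partition's layer range lies entirely inside the larger one.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StagePartition:
    """Hypothetical stand-in for the slice of layers held by one pipeline stage."""
    local_start: int   # global index of the first locally held layer
    layers: List[str]  # locally held layer names, in order


def match_layers(a: StagePartition, b: StagePartition) -> List[Tuple[str, str]]:
    """Pair every layer in `a` with the same global layer in `b`.

    Assumes a's global range is contained in b's global range.
    """
    pairs = []
    for idx, a_layer in enumerate(a.layers):
        global_idx = idx + a.local_start          # local index in a -> global index
        b_local_idx = global_idx - b.local_start  # global index -> local index in b
        pairs.append((a_layer, b.layers[b_local_idx]))
    return pairs


all_layers = [f"linear_{i}" for i in range(8)]
# One rank under a 4-stage layout holds layers 2..3; under a 2-stage layout it holds 0..3.
small = StagePartition(local_start=2, layers=all_layers[2:4])
large = StagePartition(local_start=0, layers=all_layers[0:4])
print(match_layers(small, large))
```

Iterating over the smaller partition, as the test does, avoids indexing past the end of the other partition when the two topologies split the layers differently.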

Technical Details

Testing infrastructure includes:
  • PyTest framework with distributed testing support
  • Custom DistributedTest class configuration
  • DeepSpeed runtime checkpoint engine integration
  • Architecture-based test skipping via skip_on_arch(min_arch=7)
  • Configurable model architectures with LinearStackPipe

Best Practices Demonstrated

The test suite exemplifies robust testing practices for distributed deep learning systems.

Notable practices include:
  • Comprehensive configuration testing across optimization stages
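  • Explicit verification of model parameters with torch.allclose (atol=1e-07)
  • Layer-by-layer comparison across topology changes, mapping each local layer index through a global index
  • Shared checkpoint_correctness_verification helper for end-to-end save and load checks
  • Checkpoint files written to pytest's tmpdir fixture, so each test gets its own directory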
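
The parameter check in the test file relies on torch.allclose, which accepts a pair of values when |x - y| <= atol + rtol * |y|. The pure-Python analogue below shows what that tolerance means in practice. It is a simplified sketch for flat lists of floats: the function is a hypothetical reimplementation, it returns False on a length mismatch instead of broadcasting, and it ignores NaN handling.

```python
def allclose(xs, ys, rtol=1e-05, atol=1e-08):
    """Simplified analogue of torch.allclose for flat sequences of floats.

    Each pair passes when |x - y| <= atol + rtol * |y|.
    """
    if len(xs) != len(ys):
        # Simplification: the real function works on tensors and broadcasts.
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(xs, ys))


saved = [0.25, -1.5, 3.0]
reloaded_same = [0.25, -1.5, 3.0]
reloaded_drifted = [0.25, -1.5, 3.001]

# The test passes atol=1e-07 and leaves rtol at its default.
print(allclose(saved, reloaded_same, atol=1e-07))
print(allclose(saved, reloaded_drifted, atol=1e-07))
```

Because rtol keeps its default of 1e-05, the effective tolerance scales with the magnitude of each parameter. Passing rtol=0 as well would make the comparison purely absolute.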

microsoft/deepspeed

tests/unit/checkpoint/test_pipeline.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from unit.common import DistributedTest
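# Wildcard import supplies LinearStackPipe; torch and dist are not imported
# explicitly in this file, so they are expected to come from here as well.
from unit.simple_model import *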
from unit.checkpoint.common import checkpoint_correctness_verification
from unit.util import skip_on_arch

import pytest


class TestPipelineCheckpoint(DistributedTest):
    world_size = 4

    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
        skip_on_arch(min_arch=7)

        config_dict = {
            "train_batch_size": 2,
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-5
                }
            },
            "zero_optimization": {
                "stage": zero_stage
            },
            "fp16": {
                "enabled": zero_stage > 0
            },
            "scheduler": {
                "type": "OneCycle",
                "params": {
                    "cycle_first_step_size": 1000,
                    "cycle_first_stair_count": 500,
                    "cycle_second_step_size": 1000,
                    "cycle_second_stair_count": 500,
                    "decay_step_size": 1000,
                    "cycle_min_lr": 0.0001,
                    "cycle_max_lr": 0.0010,
                    "decay_lr_rate": 0.001,
                    "cycle_min_mom": 0.85,
                    "cycle_max_mom": 0.99,
                    "decay_mom_rate": 0.0
                }
            }
        }

        models = [LinearStackPipe(num_stages=2) for _ in range(2)]
        checkpoint_correctness_verification(config_dict=config_dict,
                                            models=models,
                                            hidden_dim=models[0].hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True,
                                            dtype=torch.float16 if zero_stage > 0 else torch.float32)

    @pytest.mark.parametrize(
        "base_topo,test_topo",
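        [
            # NOTE: every topology pair below is commented out, so this parameter
            # list is empty and pytest skips the test under its default settings.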
            #(PipeTopo(num_pp=1,
            #          num_dp=4),
            # PipeTopo(num_pp=4,
            #          num_dp=1)),
            #(PipeTopo(num_pp=2,
            #          num_dp=2),
            # PipeTopo(num_pp=2,
            #          num_dp=2)),
            #(PipeTopo(num_pp=4,
            #          num_dp=1),
            # PipeTopo(num_pp=2,
            #          num_dp=2)),
        ])
    def test_checkpoint_pipe_module(self, base_topo, test_topo, tmpdir):
        checkpoint_engine = TorchCheckpointEngine()
        base_model = LinearStackPipe(topology=base_topo)
        base_model.save_state_dict(tmpdir, checkpoint_engine=checkpoint_engine)

        # Wait until every rank has finished saving before any rank starts loading.
        dist.barrier()

        test_model = LinearStackPipe(topology=test_topo)
        test_model.load_state_dir(tmpdir, checkpoint_engine=checkpoint_engine)

        # Base and test can have different lengths, so make sure we map from the
        # smaller to larger model
        if len(base_model.forward_funcs) < len(test_model.forward_funcs):
            A = base_model
            B = test_model
        else:
            A = test_model
            B = base_model

        # Compare layers individually since partitions are different
        for idx, A_layer in enumerate(A.forward_funcs):
            if not hasattr(A_layer, 'parameters'):
                # Skip functionals, etc.
                continue

            # Find the corresponding layer in B
            global_idx = idx + A._local_start
            B_local_idx = global_idx - B._local_start
            B_layer = B.forward_funcs[B_local_idx]

            # Compare layer parameters
            for p0, p1 in zip(A_layer.parameters(), B_layer.parameters()):
                assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"