
Testing Shared Weight Checkpoint Management in DeepSpeed

This test suite validates the handling of shared weights in DeepSpeed checkpoints, focusing on model state preservation during save and load operations. It ensures that weight-sharing relationships are maintained when using DeepSpeed’s ZeRO optimizer at stage 2.

Test Coverage Overview

The test suite covers checkpoint functionality for models with shared weights using DeepSpeed’s ZeRO optimizer at stage 2.

Key areas tested include:
  • Saving model state with tied layer weights
  • Loading checkpoint state while preserving weight sharing
  • Integration with DeepSpeed’s FP32 state dict conversion
  • Verification that the recovered state dict loads with strict=True
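
The tied layer weights mentioned in the first bullet rely on how weight tying behaves in PyTorch. Below is a minimal sketch, independent of DeepSpeed and with illustrative layer sizes, showing that assigning one layer’s Parameter to another makes both names reference the same tensor — which is exactly the relationship a checkpoint round-trip must preserve.

import torch
import torch.nn as nn

# Minimal weight-tying sketch (sizes are illustrative, not from the test):
# assigning one layer's Parameter to another leaves both modules holding
# the same underlying tensor.
src = nn.Linear(16, 16)
dst = nn.Linear(16, 16)
dst.weight = src.weight

assert dst.weight is src.weight               # one Parameter, two names
with torch.no_grad():
    src.weight.add_(1.0)
assert torch.equal(dst.weight, src.weight)    # updates are visible through both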

Implementation Analysis

The testing approach implements a custom neural network model (ModelWithSharedWeights) that explicitly ties the weights of two layers (layer1.weight = layer2.weight). The test validates DeepSpeed’s checkpoint mechanism by:

  • Initializing a model with shared weights using DeepSpeed
  • Saving the model state to a checkpoint
  • Loading and verifying the state maintains weight sharing integrity
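
The last bullet can be illustrated without DeepSpeed: what the test relies on is that a freshly constructed model re-establishes the tie in its constructor, so loading a recovered state dict with strict=True both requires every tied key to be present and restores a single shared Parameter. A minimal, hypothetical sketch (the TiedPair class and its sizes are illustrative, not taken from the test):

import torch
import torch.nn as nn

class TiedPair(nn.Module):
    # Stand-in for ModelWithSharedWeights: two layers sharing one weight.
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer2 = nn.Linear(8, 8)
        self.layer1.weight = self.layer2.weight

original = TiedPair()
recovered = original.state_dict()    # stand-in for the consolidated fp32 dict

reloaded = TiedPair()
reloaded.load_state_dict(recovered, strict=True)

# The constructor re-establishes the tie, so both names resolve to one
# Parameter holding the checkpointed values.
assert reloaded.layer1.weight is reloaded.layer2.weight
assert torch.equal(reloaded.layer1.weight, original.layer2.weight)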

Technical Details

Testing infrastructure utilizes:
  • PyTorch’s nn.Module for model definition
  • DeepSpeed’s initialize() API with ZeRO Stage 2 optimization
  • Custom DistributedTest base class for multi-GPU testing
  • get_fp32_state_dict_from_zero_checkpoint utility for state conversion
  • PyTorch’s Adam optimizer passed to DeepSpeed as the client optimizer
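
The get_fp32_state_dict_from_zero_checkpoint utility exercised here is the same one commonly used to consolidate a ZeRO checkpoint into a plain PyTorch state dict after training. A brief sketch of that standalone use; the directory, tag, and output path below are placeholders, not values from the test:

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Sketch: consolidate a sharded ZeRO checkpoint into a single fp32 state dict
# and save it as an ordinary PyTorch checkpoint. Paths and tag are placeholders.
checkpoint_dir = "./my_run/checkpoints"     # directory passed to save_checkpoint()
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag="checkpoint")
torch.save(state_dict, "./my_run/pytorch_model_fp32.bin")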

Best Practices Demonstrated

The test exemplifies several testing best practices:

  • Isolated test environment using temporary path fixtures
  • Explicit configuration of distributed training parameters
  • Verification of strict state dict loading
  • Proper cleanup and resource management
  • Clear separation of model definition and test logic
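
As context for the strict loading point above (and the test’s final line), a small DeepSpeed-independent sketch of what strict=True buys: a key missing from the state dict makes load_state_dict raise instead of being silently skipped.

import torch.nn as nn

# Illustration of strict state dict loading (independent of the test):
# dropping a key from an otherwise complete state dict makes
# load_state_dict(strict=True) raise, while strict=False only reports it.
model = nn.Linear(4, 4)
incomplete = model.state_dict()
incomplete.pop("bias")

try:
    model.load_state_dict(incomplete, strict=True)
except RuntimeError as err:
    print("strict=True raised on the missing key:", err)

result = model.load_state_dict(incomplete, strict=False)
print("strict=False reports missing keys:", result.missing_keys)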

microsoft/deepspeed

tests/unit/checkpoint/test_shared_weights.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(100, 100)
        self.layer1 = nn.Linear(200, 200)
        self.layer2 = nn.Linear(300, 300)
        # tie layer 1 and layer 2
        self.layer1.weight = self.layer2.weight


class TestCheckpointSharedWeights(DistributedTest):
    world_size = 2

    def test_checkpoint_shared_weights(self, tmp_path):
        config = {
            "train_micro_batch_size_per_gpu": 2,
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 2
            },
        }
        model = ModelWithSharedWeights()
        optimizer = torch.optim.Adam(model.parameters())

        deepspeed_engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            optimizer=optimizer,
        )
        filename = tmp_path / "checkpoint.pt"
        deepspeed_engine.save_checkpoint(filename, tag="checkpoint")

        model = ModelWithSharedWeights()
        state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint")
        model.load_state_dict(state_dict, strict=True)