
Validating ZeRO-3 Checkpoint Conversion in DeepSpeed

This test suite validates DeepSpeed's checkpoint conversion functionality, specifically converting ZeRO-3 (ZeRO Stage 3) distributed checkpoints into consolidated FP32 state dictionaries while preserving shared-weight relationships.
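
For orientation, DeepSpeed's zero_to_fp32 utilities also offer an in-memory variant of the same conversion. A minimal sketch, where the checkpoint path and tag are placeholders mirroring the test below, and model is assumed to be an ordinary non-sharded nn.Module:

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Reconstruct a consolidated fp32 state dict without writing it to disk.
# "./checkpoint_ds" and "checkpoint" mirror the directory/tag used in the test.
state_dict = get_fp32_state_dict_from_zero_checkpoint("./checkpoint_ds", tag="checkpoint")
model.load_state_dict(state_dict)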

Test Coverage Overview

The test suite covers critical checkpoint conversion scenarios in DeepSpeed, focusing on ZeRO Stage 3 optimization.

Key areas tested include:
  • Conversion from ZeRO-3 distributed checkpoints to FP32 format
  • Preservation of shared weight relationships during conversion (see the sketch after this list)
  • State dictionary integrity after conversion
  • Proper model reloading with converted checkpoints
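
The shared-weight check boils down to verifying that two state dict entries alias one tensor. A small self-contained sketch of that check; the commented-out key names match the model defined in the listing below:

import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Tensors that alias the same underlying storage start at the same address.
    return a.data_ptr() == b.data_ptr()

t = torch.randn(3, 3)
assert shares_storage(t, t)              # same tensor: shared
assert not shares_storage(t, t.clone())  # a clone gets fresh storage
# Against a converted checkpoint one would assert:
# shares_storage(state_dict["layer1.weight"], state_dict["layer2.weight"])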

Implementation Analysis

The testing approach implements a custom neural network model (ModelWithSharedWeights) whose layer1 and layer2 deliberately share a single weight tensor.
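
Weight tying in PyTorch is plain attribute aliasing: assigning one module's Parameter to another makes both hold the same object. A standalone sketch (shapes chosen to match here; the test's model is never run forward, so its mismatched declared sizes are harmless):

import torch.nn as nn

layer1 = nn.Linear(300, 300)
layer2 = nn.Linear(300, 300)
layer1.weight = layer2.weight           # both modules now hold one Parameter
assert layer1.weight is layer2.weight   # aliasing, not a copy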

Technical implementation includes:
  • DeepSpeed initialization with ZeRO-3 configuration (sketched after this list)
  • Checkpoint saving and conversion workflow
  • Verification of tensor sharing preservation
  • State dictionary loading validation
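
The save-then-convert workflow is compact. A sketch of the initialization and save steps mirroring the listing at the end of this page; model and optimizer are assumed to exist as in that listing, and the path is a placeholder:

import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_allow_untested_optimizer": True,  # torch.optim.Adam is not a DeepSpeed-vetted optimizer
    "zero_optimization": {"stage": 3},      # ZeRO-3: partition params, grads, and optimizer state
}

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
engine, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=optimizer)
engine.save_checkpoint("./checkpoint_ds", tag="checkpoint")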

Technical Details

Testing infrastructure utilizes:
  • PyTorch framework for model definition
  • DeepSpeed’s ZeRO-3 optimization configuration
  • DistributedTest base class for multi-GPU testing
  • Adam optimizer for model parameter management
  • pytest’s tmpdir fixture for temporary checkpoint storage (see the skeleton after this list)
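
A skeletal sketch of how these pieces fit together in DeepSpeed's test harness, following the pattern of the listing below; the class and method names here are illustrative:

from unit.common import DistributedTest

class TestExample(DistributedTest):
    world_size = 2  # the harness launches the test body on two ranks

    def test_something(self, tmpdir):
        # pytest's tmpdir fixture supplies a per-test temporary directory,
        # so checkpoint files never leak between test runs
        save_dir = tmpdir / "checkpoint_ds"
        ...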

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Proper distributed testing setup with world_size configuration
  • Comprehensive checkpoint validation workflow
  • Memory-efficient shared weight handling
  • Strict state dictionary validation (demonstrated after this list)
  • Clean temporary resource management
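
strict=True (PyTorch's default) makes load_state_dict raise rather than silently skip mismatched keys, which is what turns the final load into a real integrity check. A small self-contained demonstration:

import torch.nn as nn

model = nn.Linear(4, 4)
good = model.state_dict()
model.load_state_dict(good, strict=True)   # keys match exactly: loads cleanly

bad = {"weight": good["weight"]}           # drop the bias entry
try:
    model.load_state_dict(bad, strict=True)
except RuntimeError as err:
    print(err)                             # reports the missing "bias" key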

microsoft/deepspeed

tests/unit/checkpoint/test_convert_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(100, 100)
        self.layer1 = nn.Linear(200, 200)
        self.layer2 = nn.Linear(300, 300)
        # tie layer 1 and layer 2: both attributes now reference the same Parameter
        # (the model is never run forward, so the mismatched declared sizes are harmless)
        self.layer1.weight = self.layer2.weight


class TestCheckpointConvert(DistributedTest):
    world_size = 2

    def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
        config = {
            "train_micro_batch_size_per_gpu": 2,
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 3
            },
        }
        model = ModelWithSharedWeights()
        optimizer = torch.optim.Adam(model.parameters())

        deepspeed_engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            optimizer=optimizer,
        )
        ds_save_dir = tmpdir / "checkpoint_ds"
        deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")

        # fresh (non-sharded) model instance to load the converted weights into
        model = ModelWithSharedWeights()

        # convert the ZeRO-3 checkpoint to a single consolidated fp32 state dict
        fp32_save_dir = tmpdir / "checkpoint_fp32"
        convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)

        # load state_dict from fp32 checkpoint
        state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')

        # the tied weights must still alias the same tensor after conversion
        assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])

        # load state_dict into model
        model.load_state_dict(state_dict, strict=True)