
Validating MoE Checkpoint Operations in DeepSpeed

This test suite validates checkpoint functionality for Mixture of Experts (MoE) models in DeepSpeed, focusing on state preservation and recovery. It ensures proper handling of model parameters, optimizer states, and compatibility with DeepSpeed’s ZeRO optimization stages.

Test Coverage Overview

The test suite provides comprehensive coverage of MoE checkpoint operations:
  • Basic MoE checkpoint saving and loading with different expert configurations
  • Integration with ZeRO optimization stages
  • Verification of optimizer state preservation
  • Parameter splitting and grouping for MoE optimization (see the sketch after this list)
  • Edge cases for different expert partition sizes
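The parameter-splitting scenario mirrors the setup of the ZeRO test reproduced below; a minimal sketch (the helper name build_moe_optimizer is illustrative only), assuming an MoE model such as the test’s SimpleMoEModel:

import torch
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

def build_moe_optimizer(model):
    # The test notes that, for now, each param group needs a unique name.
    param_group = {'params': list(model.parameters()), 'name': 'random-unique-name'}
    # Split expert and non-expert parameters into separate optimizer groups
    # so DeepSpeed can treat them correctly under ZeRO.
    moe_groups = split_params_into_different_moe_groups_for_optimizer(param_group)
    return torch.optim.AdamW(params=moe_groups)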

Implementation Analysis

The tests use pytest’s parametrization to validate multiple configurations and drive a systematic verification process through the checkpoint_correctness_verification utility, covering both standalone MoE and MoE-with-ZeRO scenarios.
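For example, the ZeRO variant fans out over expert-parallel size and whether optimizer state is reloaded. A stand-alone sketch of that parametrization pattern (the test name and body here are placeholders; the real test appears in the listing below):

import pytest

# Each (ep_size, load_optim_states) pair becomes its own test case,
# mirroring the grid used by TestMoECheckpoint.test_checkpoint_moe_and_zero.
@pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
def test_parametrization_shape(ep_size, load_optim_states):
    assert ep_size in (2, 4) and isinstance(load_optim_states, bool)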

The implementation relies on PyTorch 1.8+ for MoE functionality and uses a small custom model definition (SimpleMoEModel) to keep the test environment controlled.
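The model fixtures are built as in the listing below; a minimal sketch, assuming SimpleMoEModel is available from the unit.simple_model test helpers:

import torch
from unit.simple_model import SimpleMoEModel  # DeepSpeed unit-test helper

# Two replicas of a small MoE model, matching the fixtures constructed in the
# tests: 16 hidden units, with num_experts equal to the expert-parallel size.
hidden_dim, ep_size = 16, 4
models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
optimizers = [torch.optim.AdamW(params=m.parameters()) for m in models]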

Technical Details

Key technical components include:
  • PyTorch 1.8+ requirement for MoE support
  • FP16 precision testing configuration
  • AdamW optimizer integration
  • Distributed testing environment with configurable world size
  • Custom checkpoint verification utilities
  • Integration with DeepSpeed’s MoE utilities and ZeRO optimization (the config used is sketched after this list)
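The FP16 and ZeRO settings listed above correspond to the DeepSpeed config used by the ZeRO variant of the test; reproduced here as a stand-alone dictionary for reference:

# FP16 training with a reduced initial loss scale, plus ZeRO stage 2 partitioning.
ds_config = {
    "train_batch_size": 8,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 0.00015, "betas": [0.8, 0.999], "eps": 1e-8, "weight_decay": 3e-7},
    },
    "fp16": {"enabled": True, "initial_scale_power": 8},
    "zero_optimization": {"stage": 2},
}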

Best Practices Demonstrated

The test suite exemplifies robust testing practices through:
  • Parameterized test cases for comprehensive coverage
  • Explicit version compatibility checks (illustrated after this list)
  • Systematic state verification
  • Controlled test environments with specific configurations
  • Integration of multiple optimization techniques
  • Clear separation of test scenarios
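The version compatibility check combines the required_torch_version helper with pytest.skip; a minimal sketch (the test name is a placeholder):

import pytest
from deepspeed.utils.torch import required_torch_version

def test_version_gate_sketch():
    # Mirrors the gate at the top of each MoE checkpoint test: skip when the
    # installed torch is older than 1.8, where DeepSpeed MoE is unsupported.
    if not required_torch_version(min_version=1.8):
        pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")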

microsoft/deepspeed

tests/unit/checkpoint/test_moe_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
from deepspeed.utils.torch import required_torch_version

from unit.common import DistributedTest
from unit.simple_model import *

from unit.checkpoint.common import checkpoint_correctness_verification

import pytest


class TestMoECheckpoint(DistributedTest):
    world_size = 4

    @pytest.mark.parametrize("ep_size", [4])
    def test_checkpoint_moe(self, tmpdir, ep_size):
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
        optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True,
                                            dtype=torch.float16)

    @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
    def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {
            "train_batch_size": 8,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 2,
            }
        }
        hidden_dim = 16

        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
        # param group must have a random unique name (for now)
        # TODO: clean-up this requirement, the unique name should not be required here
        param_groups = [{'params': [p for p in model.parameters()], 'name': 'random-unique-name'} for model in models]
        params = [split_params_into_different_moe_groups_for_optimizer(group) for group in param_groups]
        optimizers = [torch.optim.AdamW(params=param) for param in params]
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=load_optim_states,
                                            load_lr_scheduler_states=False,
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True,
                                            dtype=torch.float16)