
Testing ZeRO Stage 3 Multiple Model Calls in DeepSpeed

This test suite validates DeepSpeed’s ZeRO Stage 3 optimization when multiple model calls are made within the same training iteration. It focuses on ensuring correct gradient handling and optimizer behavior when losses from separate forward passes are combined into a single backward pass.
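
In condensed form, the training step under test looks like this (model_engine is the DeepSpeed engine returned by deepspeed.initialize and batch is an (input, label) pair from the random dataloader; the complete test appears in the listing at the end of this page):

# Condensed sketch of the step exercised by the test; see the full listing below.
loss1 = model_engine(batch[0], batch[1])      # first forward pass, gradients enabled
with torch.no_grad():
    loss2 = model_engine(batch[0], batch[1])  # second call on the same batch, no gradient tracking
loss = loss1 + loss2                          # combine losses from both calls
model_engine.backward(loss)                   # single backward over the combined loss
model_engine.step()                           # optimizer step under ZeRO Stage 3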

Test Coverage Overview

The test suite covers ZeRO Stage 3 optimization with multiple model forward passes in a single training step.

Key functionality includes:
  • Multiple forward pass handling with loss combination
  • Gradient bookkeeping verification via the ds_grads_remaining counter
  • ZeRO Stage 3 parameter management
  • Mixed precision training support (FP16/BF16), with the selection logic sketched below
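
The precision branch mirrors the configuration logic in the test: preferred_dtype(), a helper from DeepSpeed's unit-test utilities (unit.common), decides whether an FP16 or BF16 section is added to the config. In condensed form:

import torch
from unit.common import preferred_dtype  # test helper that picks FP16 or BF16 for the platform

config_dict = {"zero_optimization": {"stage": 3}}  # base config, condensed from the full test
if preferred_dtype() is torch.float16:
    config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif preferred_dtype() is torch.bfloat16:
    config_dict["bf16"] = {"enabled": True}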

Implementation Analysis

The testing approach uses a SimpleModel implementation with a configurable hidden dimension and layer count; the model and data setup are sketched after the list below.

Key patterns include:
  • Iterative data loading with random samples
  • Combined loss computation from multiple forward passes
  • Gradient tracking verification
  • Explicit backward pass handling
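
The model and data pipeline are deliberately small. Their construction, condensed from the test below (SimpleModel and random_dataloader are helpers from unit.simple_model, and model_engine is the engine returned by deepspeed.initialize):

from unit.simple_model import SimpleModel, random_dataloader  # DeepSpeed unit-test helpers

# Small model: nlayers linear layers of width hidden_dim, as configured in the test.
hidden_dim, nlayers = 2048, 3
model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)

# Ten random samples generated on the engine's device drive the training loop.
data_loader = iter(
    random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim,
                      device=model_engine.device))
for n, batch in enumerate(data_loader):
    ...  # two forward passes, combined loss, backward, step (see the sketch above)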

Technical Details

Testing tools and configuration:
  • DeepSpeed initialization with ZeRO Stage 3 (configuration sketched after this list)
  • Adam optimizer configuration
  • Configurable precision training (FP16/BF16)
  • Custom distributed test framework (DistributedTest base class)
  • PyTorch-based model implementation
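
The core configuration and engine construction, condensed from the test below (the steps_per_print and precision entries of the full config are omitted here; model is the SimpleModel from the previous sketch):

import deepspeed

# ZeRO Stage 3 with an Adam optimizer; micro batch size and accumulation steps are both 1.
config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
model_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                             model=model,
                                             model_parameters=model.parameters())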

Best Practices Demonstrated

The test exhibits solid testing practice through explicit validation of ZeRO Stage 3's internal gradient state after each backward pass.

Notable practices include:
  • Explicit verification of each submodule's ds_grads_remaining counter (see the sketch following this list)
  • Explicit passing of model parameters to deepspeed.initialize
  • Controlled test environment via the DistributedTest harness (world_size = 1)
  • Assertion messages that identify the failing submodule and counter value
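
The central check walks the model's linear submodules after each backward pass and asserts that ZeRO Stage 3's per-module gradient counter (ds_grads_remaining) has returned to zero, as in the test below:

# After backward, every linear submodule should have no pending gradient bookkeeping.
for name, submodule in model_engine.module.linears._modules.items():
    assert hasattr(submodule, "ds_grads_remaining"), \
        f"linears.{name} does not have variable ds_grads_remaining"
    assert submodule.ds_grads_remaining == 0, \
        f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"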

microsoft/deepspeed

tests/unit/runtime/zero/test_zero_multiple_run.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel, random_dataloader


class TestZ3MultipleModelCall(DistributedTest):
    world_size = 1

    def test_z3_multiple_model_call(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": 3
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        if preferred_dtype() is torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif preferred_dtype() is torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        hidden_dim, nlayers = 2048, 3
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
        model_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                     model=model,
                                                     model_parameters=model.parameters())
        data_loader = iter(
            random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device))

        for n, batch in enumerate(data_loader):
            loss1 = model_engine(batch[0], batch[1])
            with torch.no_grad():
                loss2 = model_engine(batch[0], batch[1])
            loss = loss1 + loss2
            model_engine.backward(loss)
            for name, submodule in model_engine.module.linears._modules.items():
                assert hasattr(submodule, "ds_grads_remaining"), \
                  f"linears.{name} does not have variable ds_grads_remaining"
                assert submodule.ds_grads_remaining == 0, \
                  f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"
            model_engine.step()