
Testing ZeRO-3 Parameter Context Management in DeepSpeed

This test suite validates the handling of external parameters and forward-pass return types in DeepSpeed’s ZeRO Stage 3 optimization context. It focuses on parameter partitioning, external parameter management, and the various output types a model may return in distributed training scenarios.

Test Coverage Overview

The test suite covers critical aspects of DeepSpeed’s ZeRO Stage 3 parameter handling:
  • External parameter management in nested model structures (see the sketch after this list)
  • Parameter persistence across forward/backward passes
  • Handling of variable output types under ZeRO Stage 3
  • Parameter status verification during model execution
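The pattern at the heart of these tests is a submodule that hands one of its own parameters back to its caller, which then consumes the parameter outside the submodule’s forward. The sketch below is illustrative only (the class names are invented, mirroring DanglingBias and ModelContainer in the test file); under ZeRO Stage 3 the returned bias becomes an external parameter of the caller.

import torch


class BiasReturningLinear(torch.nn.Linear):

    def forward(self, *inputs):
        out = super().forward(*inputs)
        # Returning the bias hands a partitioned parameter to the caller,
        # which ZeRO-3 must then track as an "external" parameter.
        return out, self.bias


class Caller(torch.nn.Module):

    def __init__(self, dim=16):
        super().__init__()
        self.child = BiasReturningLinear(dim, dim)

    def forward(self, x):
        out, bias = self.child(x)
        # The bias belongs to self.child but is consumed here, so it must
        # still be gathered at this point in the forward pass.
        return (out + bias).sum()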

Implementation Analysis

The testing approach implements multiple model classes that simulate real-world parameter-handling scenarios:
  • Uses inheritance patterns to test parameter propagation
  • Implements custom forward passes that validate parameter status (a condensed helper is sketched after this list)
  • Validates parameter persistence across the model hierarchy
  • Tests different return-type scenarios, including tensors, dictionaries, and custom objects
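The status checks inside those custom forward passes reduce to a small pattern: a parameter (or a tensor aliasing one) must report ZeroParamStatus.AVAILABLE at the point where it is used. The helper below is a hypothetical condensation of that check, not part of the test file.

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus


def assert_gathered(param):
    # ZeRO-3 marks partitioned parameters with ds_status; a tensor that merely
    # aliases one carries ds_param_alias pointing at the underlying parameter.
    assert hasattr(param, 'ds_status') or hasattr(param, 'ds_param_alias')
    z3_param = param if hasattr(param, 'ds_status') else param.ds_param_alias
    assert z3_param.ds_status == ZeroParamStatus.AVAILABLE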

Technical Details

The testing infrastructure utilizes:
  • The pytest framework for test organization
  • DeepSpeed’s distributed testing utilities (DistributedTest)
  • Custom model containers for parameter tracking
  • Mixed precision training configurations (FP16/BF16)
  • ZeRO Stage 3 optimization settings (the configuration is sketched after this list)
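The DeepSpeed configuration used by the tests (reproduced in full in the file below) enables Stage 3 with a parameter persistence threshold of 1, so essentially every parameter is partitioned, and selects FP16 or BF16 based on what the current accelerator supports:

from deepspeed.accelerator import get_accelerator

ds_config = {
    "train_batch_size": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "zero_optimization": {
        "stage": 3,
        # A threshold of 1 forces nearly all parameters to be partitioned
        # instead of being kept persistently on each rank.
        "stage3_param_persistence_threshold": 1,
    }
}

if get_accelerator().is_fp16_supported():
    ds_config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    ds_config["bf16"] = {"enabled": True}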

Best Practices Demonstrated

The test suite showcases several testing best practices:
  • Systematic parameter state validation
  • Comprehensive edge case coverage
  • Clear separation of test scenarios
  • Proper test isolation and environment setup
  • Parametrized testing over multiple scenarios (illustrated after this list)
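Parametrization follows the standard pytest pattern used by test_stage_3_output_type below. The trimmed, self-contained sketch here substitutes a fake forward pass for the DeepSpeed engine, so the function names are illustrative only:

import pytest
import torch


def fake_forward(output_type, dim=16):
    # Hypothetical stand-in for the engine forward pass: returns a dict,
    # a bare tensor, or None depending on the model's declared output type.
    act = torch.rand(dim, requires_grad=True).sum()
    if output_type is dict:
        return {'loss': act}
    if output_type is torch.tensor:
        return act
    return None


@pytest.mark.parametrize('output_type', [torch.tensor, dict, None])
def test_output_type_handling(output_type):
    out = fake_forward(output_type)
    if out is None:
        return  # nothing to backpropagate
    # Dict outputs carry the loss under a 'loss' key; tensors are the loss itself.
    loss = out['loss'] if isinstance(out, dict) else out
    loss.backward()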

microsoft/deepspeed

tests/unit/runtime/zero/test_zero_context_return.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from types import SimpleNamespace
import torch
import pytest
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from utils import setup_serial_env
from unit.common import DistributedTest, preferred_dtype


class DanglingBias(torch.nn.Linear):

    def forward(self, *inputs):
        out = super().forward(*inputs)
        # return the bias to trigger a dangling external param
        return out, self.bias


class DataClass:
    """Just wraps data in an object. """

    def __init__(self, out=None, bias=None):
        self.out = out
        self.bias = bias


class DanglingBiasClass(DanglingBias):

    def forward(self, *inputs):
        out, bias = super().forward(*inputs)
        return DataClass(out=out, bias=bias)


class DanglingAttention(torch.nn.Linear):

    def __init__(self, dim=16, return_obj=False):
        super().__init__(dim, dim)
        self.dim = dim
        self.return_obj = return_obj
        if return_obj:
            self.d_linear = DanglingBiasClass(dim, dim)
        else:
            self.d_linear = DanglingBias(dim, dim)

    def forward(self, input):
        out = super().forward(input)
        if self.return_obj:
            out_obj = self.d_linear(out)
            assert out_obj.bias.ds_status == ZeroParamStatus.AVAILABLE
            # forward the external param
            return out_obj.out, out_obj.bias
        else:
            out, bias = self.d_linear(out)
            assert hasattr(bias, 'ds_status') or hasattr(bias, 'ds_param_alias')
            z3_bias = bias if hasattr(bias, 'ds_status') else bias.ds_param_alias
            assert z3_bias.ds_status == ZeroParamStatus.AVAILABLE
            return out, bias


class ModelContainer(torch.nn.Module):

    def __init__(self, dim=16, return_obj=False):
        super().__init__()
        self.dim = dim
        self.linear1 = torch.nn.Linear(dim, dim)
        self.dangler = DanglingAttention(dim, return_obj=return_obj)

    def forward(self, input):
        act1 = self.linear1(input)
        # bias is actually dangler.d_linear.bias
        act2, bias = self.dangler(act1)
        return (act2 + bias).sum()


class DanglingExt(torch.nn.Module):

    def __init__(self, dim=16):
        super().__init__()
        self.dim = dim
        self.container = ModelContainer(dim)

    def forward(self, input):
        out = self.container(input)

        # Make sure it's at the right level of the stack
        assert len(self._external_params) == 0
        assert len(self.container._external_params) == 1
        assert len(self.container.dangler._external_params) == 0
        return out


class ModelContainerVariableOutputType(ModelContainer):

    def __init__(self, dim=16, output_type=dict):
        super().__init__()
        self.output_type = output_type
        self.dim = dim
        self.linear1 = torch.nn.Linear(dim, dim)

    def forward(self, input):
        act1 = self.linear1(input)
        if self.output_type is dict:
            return {'loss': act1.sum()}
        if self.output_type is torch.tensor:
            return act1.sum()


config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}


class TestReturnParam(DistributedTest):
    world_size = 1

    def test_ext_param_return(self):
        setup_serial_env()

        net = DanglingExt()

        args = SimpleNamespace(local_rank=0)
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(5):
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            engine.backward(loss)
            engine.step()

    @pytest.mark.skip('WIP')
    def test_ext_param_returnobj(self):
        setup_serial_env()
        print()

        net = ModelContainer(return_obj=True)

        args = SimpleNamespace(local_rank=0)
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(5):
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            assert len(net._external_params) == 1
            assert len(net.dangler._external_params) == 0
            engine.backward(loss)
            engine.step()

    @pytest.mark.parametrize('output_type', [torch.tensor, dict, None])
    def test_stage_3_output_type(self, output_type):
        setup_serial_env()
        print()

        net = ModelContainerVariableOutputType(output_type=output_type)

        args = SimpleNamespace(local_rank=0)
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(1):
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            if loss is not None:
                if isinstance(loss, dict):
                    loss = loss['loss']
                engine.backward(loss)
                engine.step()