
Validating ZeRO Optimization Parameter Inheritance in DeepSpeed

This test suite validates parameter initialization and inheritance behavior under DeepSpeed's ZeRO Stage 3 optimization, focusing on class hierarchies and parameter partitioning. It ensures that parameters defined across multiple inheritance levels are handled correctly and verifies DeepSpeed's initialization process.

Test Coverage Overview

The test suite covers two main scenarios in DeepSpeed's ZeRO Stage 3 implementation:
  • Parameter initialization across a class inheritance chain (GrandPa -> Pa -> Son)
  • DeepSpeed engine initialization for a model built inside the zero.Init context manager
Key functionality includes parameter partitioning, in-place data manipulation during __init__, and subsequent parameter gathering; edge cases include multi-level inheritance and parameter access patterns (a condensed sketch of the inheritance scenario follows).
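
The following is a minimal sketch of that first scenario, condensed from the test file listed below. The class names Base and Child are hypothetical stand-ins for the GrandPa/Pa/Son hierarchy, and a serial or distributed environment (as provided by setup_serial_env in the test) is assumed to already be in place:

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

# Hypothetical two-level stand-in for the GrandPa -> Pa -> Son chain in the test.
class Base(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(5))
        self.w.data = self.w.data + 1   # mutate while the param is still whole

class Child(Base):
    def __init__(self):
        super().__init__()
        self.w.data = self.w.data + 1   # the parent's param must still be accessible here

with deepspeed.zero.Init():             # ZeRO Stage 3 construction context
    model = Child()

# Only after the outermost __init__ finishes does ZeRO partition the parameters away.
assert model.w.ds_status == ZeroParamStatus.NOT_AVAILABLE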

Implementation Analysis

The testing approach exercises both serial and distributed environments to validate ZeRO Stage 3 behavior. The tests build a three-level class hierarchy to verify parameter initialization and partitioning; the key mechanics, sketched after this list, are:
  • Partitioning-state checks via ZeroParamStatus
  • The GatheredParameters context manager to temporarily re-materialize parameters
  • Validation of parameter values across the inheritance chain
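
The verification pattern can be sketched in isolation. In this minimal example, Tiny is a hypothetical one-parameter module and the environment setup is again assumed; it mirrors the GatheredParameters usage in the test below:

import torch
import deepspeed

class Tiny(torch.nn.Module):  # hypothetical one-parameter module
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(5) + 1)  # value 2.0 everywhere

with deepspeed.zero.Init():  # assumes the serial/distributed env is already set up
    model = Tiny()

# A partitioned parameter cannot be read directly; GatheredParameters temporarily
# re-materializes the full tensor so its value can be checked.
with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))):
    expected = torch.full_like(model.param, 2.0)  # same dtype/device as the gathered param
    assert torch.equal(model.param, expected)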

Technical Details

Testing tools and configuration (an annotated copy of the configuration follows this list):
  • PyTorch framework integration
  • DeepSpeed ZeRO Stage 3 configuration
  • FP16 optimization enabled
  • Custom DistributedTest base class
  • Serial and distributed environment setup
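
For reference, this is the ZeRO Stage 3 / FP16 configuration used by the serial test below, reproduced with brief annotations (the comments are editorial, not part of the original file):

config = {
    "train_batch_size": 1,        # single-sample batches for the single-process test
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 0.00015}
    },
    "fp16": {
        "enabled": True,          # parameters are handled in half precision
        "loss_scale": 138.        # non-zero value selects a fixed (static) loss scale
    },
    "zero_optimization": {
        "stage": 3,               # partition parameters as well as gradients and optimizer states
        # params smaller than this threshold (in elements) stay resident;
        # 1 forces every parameter to be partitioned
        "stage3_param_persistence_threshold": 1,
    }
}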

Best Practices Demonstrated

The test suite demonstrates several testing best practices (a sketch of the serial-isolation pattern follows the list):
  • Isolated test environments for different scenarios
  • Comprehensive parameter state verification
  • Clear inheritance structure testing
  • Proper test configuration management
  • Effective use of context managers for parameter handling
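
For instance, the serial test below keeps itself independent of the distributed launcher by overriding the DistributedTest class attributes. A minimal sketch of that isolation pattern (the test class name and body are hypothetical; the attribute semantics follow their use in the file below):

from utils import setup_serial_env
from unit.common import DistributedTest

class TestSerialExample(DistributedTest):  # hypothetical test class
    world_size = 1            # run in a single process
    init_distributed = False  # skip initializing a process group
    set_dist_env = False      # skip setting the distributed environment variables

    def test_example(self):
        setup_serial_env()    # provide the minimal env vars DeepSpeed expects
        # ... build the model under deepspeed.zero.Init and assert on parameter state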

microsoft/deepspeed

tests/unit/runtime/zero/test_zero_context_ancestry.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from utils import setup_serial_env
from unit.common import DistributedTest

config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 138.
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}


# test that sub-classes get params that aren't prematurely partitioned (which would then require gathering)
# fixed by https://github.com/microsoft/DeepSpeed/pull/1202
class GrandPa(torch.nn.Module):

    def __init__(self, *args):
        super().__init__(*args)
        self.param_grandpa = torch.nn.Parameter(torch.ones(5))
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned


class Pa(GrandPa):

    def __init__(self, *args):
        super().__init__(*args)
        self.param_pa = torch.nn.Parameter(torch.ones(5))
        self.param_pa.data = (self.param_pa.data + 1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned


class Son(Pa):

    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(5))
        self.param.data = (self.param.data + 1).data  # test param is not yet partitioned
        self.param_pa.data = (self.param_pa.data + 1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned


class TestSerialParamInit(DistributedTest):
    world_size = 1
    init_distributed = False
    set_dist_env = False

    def test_subclass_param_init(self):
        setup_serial_env()
        with deepspeed.zero.Init(config=config):
            model = Son().cpu()

        # test that all params have been partitioned
        assert model.param_grandpa.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert model.param_pa.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE

        # test that the weight manipulation during each __init__ took effect, without needing to gather first
        ones = torch.ones(5).half().to(get_accelerator().device_name())
        with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))):
            assert torch.equal(model.param, ones + 1)
            assert torch.equal(model.param_pa, ones + 2)
            assert torch.equal(model.param_grandpa, ones + 3)


class TestDSInitWZinit(DistributedTest):
    world_size = 2

    def test(self):
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }

        class Model(torch.nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def magic(self):
                return 42

        with deepspeed.zero.Init():
            model = Model()
            engine, *_ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())
        assert engine.magic() == 42