Validating ZeRO Stage 3 Parameter Inheritance in DeepSpeed
This test suite validates parameter initialization and inheritance behavior under DeepSpeed's ZeRO Stage 3 optimization, focusing on how parameters defined across a class hierarchy are partitioned. It checks that parameters created in parent-class constructors are handled correctly across multiple inheritance levels, and that a model built under zero.Init still works with deepspeed.initialize.
Source: microsoft/deepspeed, tests/unit/runtime/zero/test_zero_context_ancestry.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from utils import setup_serial_env
from unit.common import DistributedTest
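# ZeRO stage 3 config used by zero.Init below. The persistence threshold of 1 keeps
# even the tiny 5-element test params from being treated as persistent, so every
# param is expected to end up partitioned (as the first test asserts).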
config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 138.
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}
# test that sub-classes get params that aren't prematurely partitioned (which would then require gathering)
# fixed by https://github.com/microsoft/DeepSpeed/pull/1202
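# Three-level hierarchy: Son -> Pa -> GrandPa. Each __init__ creates its own parameter
# and also mutates the parameters created by its ancestors, which only works if those
# ancestor params are still intact (not yet partitioned) at that point.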
class GrandPa(torch.nn.Module):

    def __init__(self, *args):
        super().__init__(*args)
        self.param_grandpa = torch.nn.Parameter(torch.ones(5))
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned


class Pa(GrandPa):

    def __init__(self, *args):
        super().__init__(*args)
        self.param_pa = torch.nn.Parameter(torch.ones(5))
        self.param_pa.data = (self.param_pa.data + 1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned


class Son(Pa):

    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(5))
        self.param.data = (self.param.data + 1).data  # test param is not yet partitioned
        self.param_pa.data = (self.param_pa.data + 1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data + 1).data  # test param is not yet partitioned
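# Once model construction under zero.Init completes, every param should be partitioned
# (NOT_AVAILABLE), and the increments applied along the __init__ chain should be preserved:
# param == ones + 1, param_pa == ones + 2, param_grandpa == ones + 3.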
class TestSerialParamInit(DistributedTest):
    world_size = 1
    init_distributed = False
    set_dist_env = False

    def test_subclass_param_init(self):
        setup_serial_env()
        with deepspeed.zero.Init(config=config):
            model = Son().cpu()

            # test that all params have been partitioned
            assert model.param_grandpa.ds_status == ZeroParamStatus.NOT_AVAILABLE
            assert model.param_pa.ds_status == ZeroParamStatus.NOT_AVAILABLE
            assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE

            # test that the weight manipulation in each __init__ took effect, without needing to gather first
            ones = torch.ones(5).half().to(get_accelerator().device_name())
            with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))):
                assert torch.equal(model.param, ones + 1)
                assert torch.equal(model.param_pa, ones + 2)
                assert torch.equal(model.param_grandpa, ones + 3)
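# The second test checks that a model constructed inside zero.Init can still be handed to
# deepspeed.initialize, and that methods defined on the wrapped module (here `magic`)
# remain reachable through the returned engine.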
class TestDSInitWZinit(DistributedTest):
    world_size = 2

    def test(self):
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }

        class Model(torch.nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def magic(self):
                return 42

        with deepspeed.zero.Init():
            model = Model()
            engine, *_ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())
            assert engine.magic() == 42
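Both test classes are built on the DistributedTest harness from unit.common, which runs each test with the declared world_size: TestSerialParamInit runs as a single, serially configured process (setup_serial_env, with init_distributed and set_dist_env disabled), while TestDSInitWZinit runs with two ranks so deepspeed.initialize executes in a real distributed setting.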