Validating ZeRO Context Parameter Management in DeepSpeed
A unit test suite for DeepSpeed's ZeRO context managers (deepspeed.zero.Init and deepspeed.zero.GatheredParameters), covering parameter gathering, scattering, and memory management in distributed training. The tests verify ZeRO Stage 3 partitioning behavior and parameter handling across different model configurations.
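At the heart of the suite are two context managers from deepspeed.zero: Init, which partitions parameters across ranks as soon as they are constructed, and GatheredParameters, which temporarily reassembles the full tensors and frees them again on exit. A minimal sketch of that pattern (not part of the test file below; the config values and layer size are illustrative placeholders) might look like this:

import torch
import deepspeed

# Illustrative ZeRO stage-3 config; values are placeholders.
ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}

# Modules built under zero.Init() have their parameters partitioned immediately;
# each rank keeps only its shard, and the full tensor appears 0-sized.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    layer = torch.nn.Linear(8, 8)

assert layer.weight.numel() == 0  # partitioned placeholder outside any gather

# GatheredParameters all-gathers the full weight for the duration of the block ...
with deepspeed.zero.GatheredParameters(layer.weight):
    assert layer.weight.numel() == 8 * 8  # full (8, 8) weight is available here

# ... and re-partitions (frees) it again on exit.
assert layer.weight.numel() == 0

The test file exercises exactly these guarantees, plus subclass handling, MiCS sharding, throughput-timer bookkeeping, and cross-rank updates.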
microsoft/deepspeed
tests/unit/runtime/zero/test_zero_context.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from types import SimpleNamespace

import torch
import pytest

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel

from utils import setup_serial_env
# Test that no sub-class or super-class is missed
class ConvX(torch.nn.Conv1d):

    def __init__(self, *args):
        super().__init__(*args)
        # This would not be partitioned before bugfix 5ca8167
        self.param_in = torch.nn.Parameter(torch.FloatTensor(5).uniform_())

    def forward(self, x):
        return x


class ConvNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = ConvX(1, 3, 4)
        self.param = torch.nn.Parameter(torch.FloatTensor(5).uniform_())

    def forward(self, x):
        return x
# Shared ZeRO stage-3 config for the tests below; a persistence threshold of 1
# ensures that even tiny parameters are partitioned rather than kept resident.
config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}
class TestZeroGatheredParametersFree(DistributedTest):
    world_size = 1

    def test(self):
        config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}
        hidden_dim = 10

        class MyModel(torch.nn.Module):

            def __init__(self, hidden_dim):
                super(MyModel, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model = MyModel(hidden_dim)

        with deepspeed.zero.GatheredParameters(list(model.parameters())):
            assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

        # on exit from `GatheredParameters` the gathered params should be freed and not leak memory
        assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"
class TestMiCSGatheredParametersFree(DistributedTest):
    world_size = 1

    def test(self):
        config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3, "mics_shard_size": 1}}
        hidden_dim = 10

        class MyModel(torch.nn.Module):

            def __init__(self, hidden_dim):
                super(MyModel, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)

        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
            model = MyModel(hidden_dim)

        with deepspeed.zero.GatheredParameters(list(model.parameters())):
            assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

        # on exit from `GatheredParameters` the gathered params should be freed and not leak memory
        assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"
class TestSerialContext(DistributedTest):
    world_size = 1
    init_distributed = False
    set_dist_env = False

    def test_subclass_param(self):
        setup_serial_env()
        with deepspeed.zero.Init(config=config):
            model = ConvNet()

        assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert model.conv1.param_in.ds_status == ZeroParamStatus.NOT_AVAILABLE

    def test_scattered_init_dist(self):
        setup_serial_env()
        assert not dist.is_initialized()
        with deepspeed.zero.Init():
            assert dist.is_initialized()

    def test_scatter_halftype(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        setup_serial_env()
        with deepspeed.zero.Init():
            l = torch.nn.Linear(10, 10)
            assert l.weight.ds_tensor.dtype == torch.float16

            y = torch.LongTensor([3, 3])
            assert y.dtype == torch.long
    def test_throughput_calculation(self):
        setup_serial_env()

        train_micro_batch_size_per_gpu = 7
        gradient_accumulation_steps = 6
        config_dict = {
            "train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.001,
                }
            },
            "zero_optimization": {
                "stage": 0
            },
        }

        args = SimpleNamespace(local_rank=0)
        net = SimpleModel(hidden_dim=4)
        engine, _, _, _ = deepspeed.initialize(args=args,
                                               config=config_dict,
                                               model=net,
                                               model_parameters=net.parameters())

        assert engine.tput_timer.batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps
        assert not engine.tput_timer.initialized
        assert not engine.tput_timer.started
        assert engine.tput_timer.start_step == 2
        assert engine.tput_timer.start_time == 0
        assert engine.tput_timer.micro_step_count == 0
        assert engine.tput_timer.global_step_count == 0
        assert engine.tput_timer.total_elapsed_time == 0

        # calling stop() while uninitialized - has no effect
        engine.tput_timer.stop()
        assert not engine.tput_timer.initialized
        assert not engine.tput_timer.started
        assert engine.tput_timer.start_time == 0
        assert engine.tput_timer.micro_step_count == 0
        assert engine.tput_timer.global_step_count == 0
        assert engine.tput_timer.total_elapsed_time == 0

        # any call to start() (from dataloader or not) initializes the timer
        engine.tput_timer.start()
        assert engine.tput_timer.initialized
        assert engine.tput_timer.started
        assert engine.tput_timer.start_time == 0
        assert engine.tput_timer.micro_step_count == 0
        assert engine.tput_timer.global_step_count == 0
        assert engine.tput_timer.total_elapsed_time == 0

        # calling stop() after initialized - increments the local micro step counter
        engine.tput_timer.stop()
        assert engine.tput_timer.initialized
        assert not engine.tput_timer.started
        assert engine.tput_timer.start_time == 0
        assert engine.tput_timer.micro_step_count == 1
        assert engine.tput_timer.global_step_count == 0
        assert engine.tput_timer.total_elapsed_time == 0

        # calling start()/stop() to increment the step counter until start_step
        while engine.tput_timer.micro_step_count < (gradient_accumulation_steps * engine.tput_timer.start_step):
            engine.tput_timer.start()
            global_step = (engine.tput_timer.micro_step_count + 1) % gradient_accumulation_steps == 0
            engine.tput_timer.stop(global_step=global_step)
        assert engine.tput_timer.global_step_count == engine.tput_timer.start_step
        assert engine.tput_timer.total_elapsed_time == 0

        # calling start()/stop() accumulates duration during gradient accumulation
        while engine.tput_timer.global_step_count == engine.tput_timer.start_step:
            engine.tput_timer.start()
            current_duration = engine.tput_timer.step_elapsed_time
            total_duration = engine.tput_timer.total_elapsed_time

            global_step = (engine.tput_timer.micro_step_count + 1) % gradient_accumulation_steps == 0
            engine.tput_timer.stop(global_step=global_step)
            duration = engine.tput_timer.end_time - engine.tput_timer.start_time
            # step elapsed time is reset after gradient accumulation steps
            assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count
                                                           != engine.tput_timer.start_step else current_duration +
                                                           duration)
            assert engine.tput_timer.total_elapsed_time == total_duration + duration
    def test_ext_param_getattr(self):
        setup_serial_env()

        class ExtLinear(torch.nn.Module):

            def __init__(self, dim=16):
                super().__init__()
                self.dim = dim
                self.linear1 = torch.nn.Linear(dim, dim)
                self.linear2 = torch.nn.Linear(dim, dim)

            def forward(self, input):
                A = self.linear1(input)
                B = self.linear2(A)

                # external use of self.linear1.weight
                C = torch.nn.functional.linear(B, self.linear1.weight)
                return C.sum()

        net = ExtLinear()

        args = SimpleNamespace(local_rank=0)
        engine, optim, _, _ = deepspeed.initialize(args=args,
                                                   model=net,
                                                   model_parameters=net.parameters(),
                                                   config=config)

        with deepspeed.zero.GatheredParameters(net.linear1.weight):
            assert net.linear1.weight.numel() == net.dim**2

        input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
        loss = engine(input)
        engine.backward(loss)
        engine.step()
class TestScatterGather(DistributedTest):
    world_size = 2

    def test(self):
        with deepspeed.zero.Init():
            l = torch.nn.Linear(6, 3)
        assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert l.weight.shape == torch.Size(partitioned_param_data_shape)

        # Ensure there is no impact outside the context
        l2 = torch.nn.Linear(6, 3)
        assert not hasattr(l2.weight, 'ds_status')
        assert l2.weight.numel() == l2.in_features * l2.out_features

        with deepspeed.zero.GatheredParameters(l.weight):
            assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
            assert l.weight.numel() == l.in_features * l.out_features
class TestGatherUpdate(DistributedTest):
    world_size = 2

    def test(self):
        with deepspeed.zero.Init():
            l = torch.nn.Linear(4, 2)
        assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

        # Gather and make a change
        with deepspeed.zero.GatheredParameters(l.weight, modifier_rank=1):
            assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
            if dist.get_rank() == 1:
                with torch.no_grad():
                    l.weight.zero_()
        # should now be scattered again

        # Now gather again and ensure the change is global
        with deepspeed.zero.GatheredParameters(l.weight):
            # all ranks compare
            assert torch.equal(l.weight, torch.zeros_like(l.weight))