Testing Sparse Gradient Optimization in DeepSpeed
This test suite validates sparse gradient handling in DeepSpeed, focusing on the interaction between a sparse embedding layer and the optimizer. It checks that sparse gradients are computed correctly and that the affected parameters remain synchronized across data-parallel ranks in distributed training.
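For context, the short sketch below (plain PyTorch, not part of the test file) illustrates why a sparse-aware optimizer is needed in the first place: with `sparse=True`, an `EmbeddingBag`'s gradient is a sparse COO tensor covering only the rows referenced in the batch, which `torch.optim.SparseAdam` can consume while the standard `torch.optim.Adam` rejects it.

```python
import torch

# A sparse EmbeddingBag produces a sparse COO gradient that covers only the
# embedding rows referenced in the batch.
emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
x = torch.tensor([1, 2, 4, 5], dtype=torch.long)
offsets = torch.tensor([0], dtype=torch.long)

emb(x, offsets).sum().backward()

print(emb.weight.grad.is_sparse)              # True
print(emb.weight.grad.coalesce().indices())   # indices of the touched rows only

# SparseAdam accepts these gradients; plain torch.optim.Adam would raise an
# error when it encounters a sparse gradient.
opt = torch.optim.SparseAdam(list(emb.parameters()))
opt.step()
```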
The page covers the following topics:
- Test Coverage Overview
- Implementation Analysis
- Technical Details
- Best Practices Demonstrated
Source: tests/unit/runtime/sparse_tensor/test_sparse_grads.py in the microsoft/deepspeed repository.
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import pytest
import deepspeed
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
import deepspeed.utils.groups as groups

# Sparse gradients are not supported on HPU, so skip the whole module there.
if get_accelerator().device_name() == 'hpu':
    pytest.skip("sparse_gradients not supported by HPU.", allow_module_level=True)


class Model(torch.nn.Module):
    """Tiny model: a sparse EmbeddingBag feeding a dense linear layer."""

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x, offsets):
        return self.linear(self.emb(x, offsets))


class Adam(torch.optim.Optimizer):
    """Composite optimizer: torch.optim.Adam for the dense parameters,
    torch.optim.SparseAdam for the sparse embedding parameters."""

    def __init__(self, dense_params, sparse_params):
        super().__init__(dense_params + sparse_params, defaults={})
        self.adam = torch.optim.Adam(dense_params)
        self.adam_sparse = torch.optim.SparseAdam(sparse_params)

    @torch.no_grad()
    def step(self, closure=None):
        loss_1 = self.adam.step(closure)
        loss_2 = self.adam_sparse.step(closure)
        if loss_1 is not None and loss_2 is not None:
            return loss_1 + loss_2
        return loss_1 or loss_2


class TestSparseAdam(DistributedTest):
    world_size = 2

    def test(self):
        config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True}
        model = Model()
        optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))
        engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)
        loss = torch.nn.BCEWithLogitsLoss()

        x = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long, device=engine.device)
        offsets = torch.tensor([0, 4], dtype=torch.long, device=engine.device)
        y = torch.tensor([[1.0], [0.0]], device=engine.device)

        res = engine(x, offsets)
        engine.backward(loss(res, y))
        engine.step()

        # After the step, the embedding weights must be identical on both
        # data-parallel ranks, i.e. the sparse gradients were synchronized.
        results = [engine.all_gather_scalar(i, groups._get_data_parallel_group()) for i in model.emb.parameters()]
        for res in results:
            assert torch.allclose(res[0], res[1])
```
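Outside the test harness, the same setup carries over to user code: enable `sparse_gradients` in the DeepSpeed config and pass a sparse-capable optimizer to `deepspeed.initialize`. The sketch below is illustrative rather than taken from the repository; it reuses the `Model` and composite `Adam` classes from the listing above, the config values are placeholders, and a distributed environment is assumed to have been launched already (for example with the `deepspeed` launcher).

```python
import deepspeed

# Illustrative config; "sparse_gradients" is the switch exercised by the test.
ds_config = {
    "train_batch_size": 16,
    "sparse_gradients": True,
}

model = Model()  # sparse EmbeddingBag + dense Linear, as defined above
optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               optimizer=optimizer,
                                               config=ds_config)

# A training step then follows the same pattern as the test:
#   out = engine(x, offsets)
#   engine.backward(criterion(out, y))
#   engine.step()
```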