
Testing Sparse Gradient Optimization in DeepSpeed

This test suite validates sparse gradient handling in DeepSpeed, focusing on the integration between a sparse embedding layer and the optimizer. It checks that sparse gradients are computed correctly and that the resulting parameter updates stay synchronized across processes in distributed training.

Test Coverage Overview

The test suite covers sparse gradient computation and optimization in distributed settings.

  • Tests an EmbeddingBag layer configured to produce sparse gradients
  • Validates a custom Adam optimizer that combines dense and sparse parameter updates (the sketch after this list shows why both are needed)
  • Verifies gradient synchronization across multiple processes
  • Tests distributed training with sparse gradients enabled
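
To see why the test pairs a sparse embedding with a hybrid optimizer, the minimal sketch below (a standalone PyTorch example, not taken from the test file) shows that an EmbeddingBag created with sparse=True produces a sparse gradient tensor, which torch.optim.Adam rejects but torch.optim.SparseAdam accepts.

# Standalone sketch: sparse=True makes the embedding gradient a sparse tensor.
import torch

emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
linear = torch.nn.Linear(3, 1)

x = torch.tensor([1, 2, 4, 5], dtype=torch.long)
offsets = torch.tensor([0, 2], dtype=torch.long)

linear(emb(x, offsets)).sum().backward()

print(emb.weight.grad.is_sparse)     # True  -> needs torch.optim.SparseAdam
print(linear.weight.grad.is_sparse)  # False -> plain torch.optim.Adam is fine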

Implementation Analysis

The implementation combines PyTorch’s sparse tensor operations with DeepSpeed’s distributed training engine.

Key patterns include:
  • Custom Model class with mixed sparse/dense layers
  • Hybrid optimizer implementation for handling both parameter types
  • DeepSpeed initialization with sparse_gradients configuration
  • Distributed gradient verification using all_gather operations (sketched after this list)
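
The final check gathers the embedding parameters from every data-parallel rank and asserts that all copies match. The test does this through the DeepSpeed engine; the sketch below shows the same idea with plain torch.distributed, assuming a process group is already initialized (the helper name assert_param_in_sync is hypothetical).

# Illustrative verification pattern using plain torch.distributed, not the
# DeepSpeed engine helper used in the actual test.
import torch
import torch.distributed as dist

def assert_param_in_sync(param: torch.Tensor, group=None) -> None:
    """Gather `param` from every rank and check that all copies are identical."""
    world_size = dist.get_world_size(group=group)
    gathered = [torch.empty_like(param) for _ in range(world_size)]
    dist.all_gather(gathered, param.detach(), group=group)
    for other in gathered[1:]:
        assert torch.allclose(gathered[0], other), "parameter diverged across ranks"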

Technical Details

  • Testing Framework: pytest with DistributedTest class
  • World Size: 2 (testing multi-process scenarios)
  • Key Components: torch.nn.EmbeddingBag, a custom optimizer wrapping torch.optim.Adam and torch.optim.SparseAdam
  • DeepSpeed Configuration: sparse_gradients enabled (shown below)
  • Device Compatibility: skipped on HPU, which does not support sparse gradients
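
The configuration used by the test is minimal; the dict below reproduces it. The comment on sparse_gradients is a short summary of its documented role, enabling DeepSpeed's sparse handling of embedding gradients during gradient averaging.

# DeepSpeed configuration used in the test.
config_dict = {
    "train_batch_size": 2,
    "steps_per_print": 1,
    "sparse_gradients": True,  # enable sparse handling of embedding gradients
}

# Passed to deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)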

Best Practices Demonstrated

The test demonstrates several best practices for testing distributed deep learning systems.

  • Proper isolation of sparse and dense parameter handling
  • Comprehensive gradient synchronization verification
  • Clean separation of model, optimizer, and training logic
  • Explicit device placement and tensor type management
  • Proper distributed test setup and teardown (a generic sketch of the pattern follows)
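
DeepSpeed's DistributedTest helper (imported from unit.common) handles process spawning and group setup for the test. As a rough illustration only, assuming nothing about that helper's internals, the sketch below shows the kind of setup/teardown it automates, here with a single-process gloo group under pytest.

# Generic single-process sketch of distributed setup/teardown with pytest.
# This is NOT DeepSpeed's DistributedTest helper, which spawns multiple
# worker processes; it only shows the init/destroy pattern such helpers wrap.
import os
import pytest
import torch.distributed as dist

@pytest.fixture
def process_group():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    yield
    dist.destroy_process_group()

def test_group_is_initialized(process_group):
    assert dist.get_world_size() == 1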

microsoft/deepspeed

tests/unit/runtime/sparse_tensor/test_sparse_grads.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import pytest
import deepspeed
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
import deepspeed.utils.groups as groups

if get_accelerator().device_name() == 'hpu':
    pytest.skip("sparse_gradients not supported by HPU.", allow_module_level=True)


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x, offsets):
        return self.linear(self.emb(x, offsets))


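# Hybrid optimizer: dense parameters are updated by torch.optim.Adam, while the
# sparse embedding parameters are updated by torch.optim.SparseAdam.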
class Adam(torch.optim.Optimizer):

    def __init__(self, dense_params, sparse_params):
        super().__init__(dense_params + sparse_params, defaults={})
        self.adam = torch.optim.Adam(dense_params)
        self.adam_sparse = torch.optim.SparseAdam(sparse_params)

    @torch.no_grad()
    def step(self, closure=None):
        loss_1 = self.adam.step(closure)
        loss_2 = self.adam_sparse.step(closure)

        if loss_1 is not None and loss_2 is not None:
            return loss_1 + loss_2
        return loss_1 or loss_2


class TestSparseAdam(DistributedTest):
    world_size = 2

    def test(self):
        config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True}

        model = Model()
        optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))
        engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)
        loss = torch.nn.BCEWithLogitsLoss()
        x = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long, device=engine.device)
        offsets = torch.tensor([0, 4], dtype=torch.long, device=engine.device)
        y = torch.tensor([[1.0], [0.0]], device=engine.device)
        res = engine(x, offsets)
        engine.backward(loss(res, y))
        engine.step()

        results = [engine.all_gather_scalar(i, groups._get_data_parallel_group()) for i in model.emb.parameters()]
        for res in results:
            assert torch.allclose(res[0], res[1])