
Testing DeepSpeed's ZeRO Stage 3 Optimization Implementation in microsoft/DeepSpeed

This test suite validates DeepSpeed's ZeRO stage 3 optimization using a small linear model built for debugging. It focuses on parameter partitioning and gradient computation in a multi-layer neural network.

Test Coverage Overview

The tests cover the initialization and training flow of a small linear neural network under DeepSpeed's ZeRO stage 3 optimization.

Key areas tested include:
  • Model initialization inside a deepspeed.zero.Init() context
  • Forward pass through multiple linear layers
  • Backward pass with gradient computation
  • Parameter partitioning verification

Implementation Analysis

The test uses a custom LinearStack model with verbose initialization tracking to debug ZeRO stage 3 behavior. It builds on PyTorch's ModuleList and Linear layers, constructed under the deepspeed.zero.Init() context manager, to validate parameter partitioning and gradient flow.

Technical patterns include:
  • Custom layer inheritance for debugging
  • Modular layer stacking
  • Gradient tracking setup
  • deepspeed.zero.Init() integration
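The verbose-initialization pattern can be sketched without DeepSpeed or PyTorch at all: subclass a layer type and trace when construction begins and ends. The `BaseLayer` and `TracedLayer` names below are hypothetical stand-ins for illustration, not part of the test file.

```python
class BaseLayer:
    """Stand-in for torch.nn.Linear in this illustration."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class TracedLayer(BaseLayer):
    """Logs entry and exit of __init__, mirroring VerboseLinear's approach."""

    def __init__(self, **kwargs):
        print('Begin TracedLayer.__init__')
        super().__init__(**kwargs)
        print('End TracedLayer.__init__')


layer = TracedLayer(in_features=2, out_features=4)
```

Because ZeRO stage 3 hooks into module construction, printing around `super().__init__()` makes it easy to see exactly when parameters are created and intercepted.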

Technical Details

Testing tools and configuration:
  • PyTorch framework for model definition
  • DeepSpeed ZeRO stage 3 optimization
  • Custom VerboseLinear layer extending torch.nn.Linear
  • Random tensor initialization for inputs and gradients
  • Model configuration: variable input, hidden, and output dimensions
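The test builds the model under zero.Init() without constructing a full engine. For reference, a minimal ZeRO stage 3 configuration that deepspeed.initialize() accepts is sketched below as a Python dict; the batch size value is illustrative, not taken from the test.

```python
# Minimal ZeRO stage 3 config sketch. Only the "stage" key is essential here;
# the batch size is an illustrative placeholder.
ds_config = {
    "train_batch_size": 1,
    "zero_optimization": {
        "stage": 3,  # partition optimizer state, gradients, and parameters
    },
}
```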

Best Practices Demonstrated

The test implementation showcases several testing best practices for deep learning model debugging.

Notable practices include:
  • Isolated component testing
  • Verbose initialization tracking
  • Controlled gradient flow verification
  • Modular model architecture
  • Clear separation of setup and execution phases
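One practice worth highlighting: the driver calls output.backward(grad_output) with a random upstream gradient instead of reducing to a scalar loss. A minimal stand-alone sketch of that vector-Jacobian pattern, in plain PyTorch with no DeepSpeed involved, looks like:

```python
import torch

# A single linear layer standing in for the LinearStack model.
layer = torch.nn.Linear(2, 4)

x = torch.rand(1, 2)            # test input; requires no grad itself
grad_output = torch.rand(1, 4)  # upstream gradient fed into backward()

out = layer(x)
out.backward(grad_output)       # vector-Jacobian product; no scalar loss needed

# Every parameter now holds a gradient of matching shape.
assert layer.weight.grad is not None
assert layer.weight.grad.shape == layer.weight.shape
```

This isolates the backward pass from any loss function, which is exactly what a partitioning debug test needs: gradient flow can be verified without caring about the loss value.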

microsoft/deepspeed

tests/small_model_debugging/stage3_test.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed

###################################
# Setup
###################################


class VerboseLinear(torch.nn.Linear):

    def __init__(self, **kwargs):
        print('Begin VerboseLinear.__init__')
        super().__init__(**kwargs)
        print('End VerboseLinear.__init__')


class LinearStack(torch.nn.Module):

    def __init__(self, input_dim=2, hidden_dim=4, output_dim=4, num_layers=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = VerboseLinear(in_features=self.input_dim, out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim, bias=False)
            for _ in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim)
        self.identity = torch.nn.Identity()

    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        x = self.identity(x)
        return x


###################################
# DRIVER
###################################


def test_driver():
    print()
    print('BUILDING MODEL')
    with deepspeed.zero.Init():
        model = LinearStack()
    print()

    # parted = [name for (name, p) in model.named_parameters() if p._partitioned]
    # not_parted = [name for (name, p) in model.named_parameters() if not p._partitioned]
    # print('partitioned: ', parted)
    # print('full: ', not_parted)
    # print()

    model.train()

    test_input = torch.rand(1, model.input_dim)
    grad_output = torch.rand(1, model.output_dim)

    grad_output.requires_grad = False
    test_input.requires_grad = False

    print()
    print('BEGINNING FORWARD')
    print()

    output = model(test_input)
    output.backward(grad_output)

    # parted = [name for (name, p) in model.named_parameters() if p._partitioned]
    # not_parted = [name for (name, p) in model.named_parameters() if not p._partitioned]
    # print('partitioned: ', parted)
    # print('full:' , not_parted)
    # print()

    #samyamspeed.disable()


test_driver()