
Testing Automatic Casting Behavior in DeepSpeed Linear Module Operations

This test suite validates automatic casting (autocast) behavior in DeepSpeed’s runtime, focusing on the LinearModuleForZeroStage3 implementation. It checks the output data type of the module when no autocast context is present, when autocast is explicitly disabled, and when autocast is enabled.

Test Coverage Overview

The suite covers three autocast scenarios for DeepSpeed’s linear module, split across two test classes: TestAutoCastDisable and TestAutoCastEnable.

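Key functionality tested includes:
  • No autocast context (test_missing_amp_autocast): the output dtype must equal the weight dtype
  • Autocast explicitly disabled with amp.autocast(False) (test_disable_autocast_linear): the output dtype must equal the weight dtype
  • Autocast enabled (test_autocast_linear): the output must be torch.half or torch.bfloat16 for every combination of float32 and half-precision input and weights

Edge cases cover mismatched input and weight dtypes under autocast: a float32 input with half-precision weights, and the reverse.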
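The dtype rules listed above can be tried outside DeepSpeed. The following is a minimal sketch, not the suite's own code: it assumes a CPU-only PyTorch install, substitutes torch.nn.Linear for LinearModuleForZeroStage3, and uses torch.bfloat16 as the low-precision dtype in place of .half(). The helper name linear_output_dtype is hypothetical.

```python
import torch


def linear_output_dtype(input_dtype, weight_dtype, autocast_enabled):
    """Return the dtype produced by one forward pass of a small linear layer.

    Stand-in for the tests' setup: torch.nn.Linear replaces
    LinearModuleForZeroStage3, and everything runs on the CPU (assumptions).
    """
    layer = torch.nn.Linear(4, 4).to(weight_dtype)
    x = torch.randn(1, 4).to(input_dtype)
    # With enabled=False the context does nothing, like amp.autocast(False).
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=autocast_enabled):
        return layer(x).dtype


# Autocast off: the output dtype follows the weights.
print(linear_output_dtype(torch.float32, torch.float32, False))
print(linear_output_dtype(torch.bfloat16, torch.bfloat16, False))

# Autocast on: mixed input/weight dtypes still produce a low-precision output.
print(linear_output_dtype(torch.float32, torch.bfloat16, True))
print(linear_output_dtype(torch.bfloat16, torch.float32, True))
```

On an accelerator, the same check would use the device and AMP module returned by get_accelerator(), as the test file does.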

Implementation Analysis

The tests apply pytest’s parametrize decorator at class level, so every test method in a class runs once per parameter value. Both test classes inherit from DistributedTest, a helper imported from the repository’s unit.common module, and call LinearModuleForZeroStage3 directly.

TestAutoCastEnable is additionally guarded by pytest.mark.skipif, which skips the whole class when get_accelerator().amp() returns None.
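Because the parametrize decorator sits on the class, the number of collected cases is the number of test methods times the number of parameter values. The sketch below enumerates those cases with the standard library only. The helper expand_cases is hypothetical and is not part of pytest; it only mirrors how the cases multiply out.

```python
from itertools import product


def expand_cases(test_names, param_values):
    """Pair every test method with every parameter value (hypothetical helper)."""
    return [(name, value) for name, value in product(test_names, param_values)]


# TestAutoCastDisable: two methods, half_op in [False, True].
disable_cases = expand_cases(
    ["test_missing_amp_autocast", "test_disable_autocast_linear"],
    [False, True],
)

# TestAutoCastEnable: one method, all (half_input, half_weight) pairs.
dtype_pairs = list(product([False, True], repeat=2))
enable_cases = expand_cases(["test_autocast_linear"], dtype_pairs)

for name, value in disable_cases + enable_cases:
    print(name, value)
```

Generating the pairs with product yields the same four tuples, in the same order, as the literal list passed to parametrize in the test file.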

Technical Details

Testing tools and components include:
  • PyTorch for tensor operations and data type handling
  • DeepSpeed’s get_accelerator() for the device name and the AMP module
  • LinearModuleForZeroStage3 from deepspeed.runtime.zero.linear
  • Pytest for test organization and execution

Tensors and modules are moved to the device returned by get_accelerator().device_name(), and half precision is selected by calling .half() on them.

Best Practices Demonstrated

Each test follows the same short pattern: build an input and a module, run one forward pass, and assert on the output dtype.

Notable practices include:
  • Parametrized test cases that cover every dtype combination
  • Setup performed inside each test, with no state shared between cases
  • A single dtype assertion per test
  • One test class per autocast state (disabled and enabled)

microsoft/deepspeed

tests/unit/runtime/test_autocast.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest


@pytest.mark.parametrize('half_op', [False, True])
class TestAutoCastDisable(DistributedTest):

    def test_missing_amp_autocast(self, half_op):
        hidden_dim = 4
        if half_op:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
        else:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name())
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        # No autocast context: the output must keep the dtype of the weights.
        output = ds_linear(input)
        assert output.dtype == ds_linear.weight.dtype

    def test_disable_autocast_linear(self, half_op):
        amp = get_accelerator().amp()

        hidden_dim = 4
        if half_op:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
        else:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name())
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        # Autocast explicitly disabled: the output must keep the dtype of the weights.
        with amp.autocast(False):
            output = ds_linear(input)
            assert output.dtype == ds_linear.weight.dtype


@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed')
@pytest.mark.parametrize('half_input, half_weight', [(False, False), (False, True), (True, False), (True, True)])
class TestAutoCastEnable(DistributedTest):

    def test_autocast_linear(self, tmpdir, half_input, half_weight):
        amp = get_accelerator().amp()

        hidden_dim = 4
        input = torch.randn(hidden_dim).to(get_accelerator().device_name())
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        if half_input:
            input = input.half()

        if half_weight:
            ds_linear = ds_linear.half()

        # Autocast enabled: the output must be low precision for every input/weight dtype pair.
        with amp.autocast():
            output = ds_linear(input)
            assert output.dtype == torch.half or output.dtype == torch.bfloat16