Testing Automatic Casting Behavior in DeepSpeed Linear Module Operations
This test suite validates the automatic casting functionality in DeepSpeed’s runtime environment, specifically focusing on the LinearModuleForZeroStage3 implementation. It verifies proper handling of different data types and automatic casting behaviors in both enabled and disabled states.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
microsoft/deepspeed
tests/unit/runtime/test_autocast.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest
@pytest.mark.parametrize('half_op', [False, True])
class TestAutoCastDisable(DistributedTest):
def test_missing_amp_autocast(self, half_op):
hidden_dim = 4
if half_op:
input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
else:
input = torch.randn(hidden_dim).to(get_accelerator().device_name())
ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())
output = ds_linear(input)
assert output.dtype == ds_linear.weight.dtype
def test_disable_autocast_linear(self, half_op):
amp = get_accelerator().amp()
hidden_dim = 4
if half_op:
input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
else:
input = torch.randn(hidden_dim).to(get_accelerator().device_name())
ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())
with amp.autocast(False):
output = ds_linear(input)
assert output.dtype == ds_linear.weight.dtype
@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed')
@pytest.mark.parametrize('half_input, half_weight', [(False, False), (False, True), (True, False), (True, True)])
class TestAutoCastEnable(DistributedTest):
def test_autocast_linear(self, tmpdir, half_input, half_weight):
amp = get_accelerator().amp()
hidden_dim = 4
input = torch.randn(hidden_dim).to(get_accelerator().device_name())
ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())
if half_input:
input = input.half()
if half_weight:
ds_linear = ds_linear.half()
with amp.autocast():
output = ds_linear(input)
assert output.dtype == torch.half or output.dtype == torch.bfloat16