
Testing Post-Layer Normalization Module Implementation in DeepSpeed

This test suite validates the post-layer normalization (post-LN) module in DeepSpeed’s inference V2 system. It checks that layer normalization applied to the sum of a residual and the hidden states is computed correctly across several tensor shapes and data types.
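Post-layer normalization normalizes the sum of the residual and the hidden states over the channel dimension, then applies a per-channel scale (gamma) and shift (beta). The following is a minimal pure-Python sketch for a single token, written to mirror what reference_implementation computes with torch.nn.functional.layer_norm; post_layer_norm is a hypothetical name, not part of DeepSpeed.

```python
import math
from typing import List


def post_layer_norm(residual: List[float], hidden: List[float],
                    gamma: List[float], beta: List[float],
                    eps: float = 1e-5) -> List[float]:
    """Hypothetical post-LN for one token: LayerNorm(residual + hidden)."""
    # 1. Residual add: the sum is what gets normalized.
    x = [r + h for r, h in zip(residual, hidden)]
    n = len(x)
    # 2. Mean and biased variance over the channels, as layer_norm uses.
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    # 3. Normalize, then apply the per-channel scale and shift.
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


out = post_layer_norm([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0],
                      [1.0] * 4, [0.0] * 4)
print(out)
```

With gamma set to ones and beta to zeros, the output has zero mean and (up to eps) unit variance across the channels.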

Test Coverage Overview

The suite exercises the CUDA post-layer normalization module (registered as cuda_post_ln) through a single parametrized test, test_cuda_post_ln_module.

Key areas tested include:
  • Four input shapes (tokens, channels): (1, 2048), (37, 8192), (1280, 768) and (2048, 5120)
  • Every data type returned by get_dtypes()
  • Residual connection handling (the residual is added before normalization)
  • Numerical accuracy against a float32 reference implementation
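Because the two parametrize decorators are stacked, pytest runs the Cartesian product of shapes and data types. A small sketch of that grid, assuming a placeholder pair of dtype names in place of whatever get_dtypes() returns on a given machine:

```python
from itertools import product

# Token/channel shapes copied from the test's parametrize decorator.
SHAPES = [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]
# Assumption: placeholder dtype names; the real list comes from get_dtypes()
# and may differ depending on what the accelerator supports.
DTYPES = ["float16", "bfloat16"]
# Matches DSNormConfig(max_tokens=2048) in the test.
MAX_TOKENS = 2048

# Stacked parametrize decorators yield every (shape, dtype) combination.
cases = list(product(SHAPES, DTYPES))

# Every token count must fit within the module's configured max_tokens.
assert all(shape[0] <= MAX_TOKENS for shape, _dtype in cases)

for (tokens, channels), dtype in cases:
    print(f"tokens={tokens:5d} channels={channels:5d} dtype={dtype}")
```

Note that the largest case (2048 tokens) sits exactly at the configured max_tokens limit.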

Implementation Analysis

The test uses pytest parametrization to sweep every combination of input shape and data type. For each case it builds a DSNormConfig, instantiates the module through DSPostNormRegistry, and compares its output with a reference that upcasts the inputs to float32 and applies torch.nn.functional.layer_norm to residual + hidden_states.
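The final check compares the module output with the reference within a tolerance. A sketch of the element-wise rule that torch.allclose applies (|actual - expected| <= atol + rtol * |expected|), assuming the helper picks (rtol, atol) by dtype; TOLERANCES and allclose_sketch are hypothetical, and the numbers are illustrative rather than DeepSpeed's actual settings.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical per-dtype (rtol, atol) pairs, for illustration only.
TOLERANCES: Dict[str, Tuple[float, float]] = {
    "float32": (1e-5, 1e-6),
    "float16": (1e-2, 1e-3),
}


def allclose_sketch(actual: Sequence[float], expected: Sequence[float],
                    dtype: str) -> bool:
    """Element-wise |a - e| <= atol + rtol * |e|, the torch.allclose rule."""
    rtol, atol = TOLERANCES[dtype]
    return len(actual) == len(expected) and all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(allclose_sketch([1.0, 2.0], [1.0005, 2.0], "float16"))
```

Looser tolerances for lower-precision dtypes reflect that a half-precision output cannot match a float32 reference bit for bit.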

Key patterns include:
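  • Dynamic tensor shape handling (token counts up to the configured max_tokens of 2048)
  • Data type conversion management (float32 reference, gamma and beta passed through transform_param)
  • Configuration-based module instantiation via ConfigBundle
  • Accelerator-aware testing via get_accelerator()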
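The test never names a module class directly: it wraps a name and a config in a ConfigBundle and asks DSPostNormRegistry to build the matching implementation. A toy version of that name-to-implementation registry pattern, in which ToyBundle, ToyRegistry and ToyPostLN are all hypothetical stand-ins rather than DeepSpeed's own classes:

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class ToyBundle:
    """Hypothetical stand-in for a (name, config) pair such as ConfigBundle."""
    name: str
    config: Dict[str, Any]


class ToyRegistry:
    """Hypothetical name -> implementation registry; not DeepSpeed's code."""
    _impls: Dict[str, Type] = {}

    @classmethod
    def register(cls, impl: Type) -> Type:
        # Store the class under its declared name so bundles can refer to it.
        cls._impls[impl.name] = impl
        return impl

    @classmethod
    def instantiate_config(cls, bundle: ToyBundle) -> Any:
        # Look up the implementation by name and build it from the config.
        if bundle.name not in cls._impls:
            raise KeyError(f"no implementation registered as {bundle.name!r}")
        return cls._impls[bundle.name](bundle.config)


@ToyRegistry.register
class ToyPostLN:
    name = "toy_post_ln"

    def __init__(self, config: Dict[str, Any]) -> None:
        self.channels = config["channels"]
        self.eps = config.get("eps", 1e-5)


module = ToyRegistry.instantiate_config(ToyBundle("toy_post_ln", {"channels": 768}))
print(type(module).__name__, module.channels)
```

Decoupling the test from concrete classes this way means the same test body could target another registered implementation by changing only the bundle name.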

Technical Details

Testing infrastructure includes:
  • pytest framework with parametrization
  • DeepSpeed accelerator utilities
  • Custom configuration bundles
  • Tensor comparison utilities (allclose from inference_test_utils)
  • Device placement through get_accelerator().current_device_name()
  • Floating-point precision control
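The reference implementation upcasts every input to float32 before the residual add and normalization, then casts the result back once. A pure-Python illustration of why high-precision accumulation matters, using the standard library's struct half-precision format ('e') to emulate fp16 rounding; to_fp16 is a hypothetical helper, the values are chosen only to make the rounding visible, and this does not describe how the CUDA kernel itself accumulates.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's 'e' format packs a binary16; unpacking returns the rounded value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


values = [2048.0, 0.5, 0.5, 0.5, 0.5]

# Round after every addition, as a pure-fp16 accumulator would.
# At 2048 the fp16 spacing is 2.0, so each +0.5 is rounded away.
acc16 = 0.0
for v in values:
    acc16 = to_fp16(acc16 + to_fp16(v))

# Accumulate at full precision and round only the final result.
acc_hi = to_fp16(sum(values))

print(acc16, acc_hi)
```

The same effect applies to the mean and variance reductions inside layer normalization, which sum over every channel; computing the reference in float32 keeps that error out of the expected values.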

Best Practices Demonstrated

The test implementation showcases several testing best practices for deep learning components.

Notable practices include:
  • Systematic parameter space exploration
  • Reference implementation comparison
  • Explicit dtype handling
  • Modular test configuration
  • Clear separation of setup and verification

microsoft/deepspeed

tests/unit/inference/v2/modules/test_post_ln_module.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.modules import ConfigBundle
from deepspeed.inference.v2.modules.configs import DSNormConfig
from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry
from ...v2.inference_test_utils import get_dtypes, allclose


def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor,
                             beta: torch.Tensor, epsilon: float) -> torch.Tensor:
    residual_f = residual.to(torch.float32)
    hidden_states_f = hidden_states.to(torch.float32)
    gamma_f = gamma.to(torch.float32)
    beta_f = beta.to(torch.float32)
    return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ),
                                          weight=gamma_f,
                                          bias=beta_f,
                                          eps=epsilon).to(hidden_states.dtype)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_cuda_post_ln_module(tokens: int, channels: int, dtype: torch.dtype) -> None:
    config = DSNormConfig(max_tokens=2048,
                          type="layer_norm",
                          channels=channels,
                          residual_dtype=dtype,
                          input_dtype=dtype,
                          output_dtype=dtype,
                          eps=1e-5)
    bundle = ConfigBundle(name='cuda_post_ln', config=config)

    # Input vals
    hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    gamma = torch.randn((channels, ), dtype=torch.float32, device=get_accelerator().current_device_name())
    beta = torch.rand((channels, ), dtype=torch.float32, device=get_accelerator().current_device_name())
    epsilon = 1e-5

    # Reference output
    ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon)

    # DeepSpeed module output
    post_ln_module = DSPostNormRegistry.instantiate_config(bundle)
    gamma = post_ln_module.transform_param(gamma)
    beta = post_ln_module.transform_param(beta)
    ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta)

    # Check
    assert allclose(ds_output, ref_output)