
Testing Custom Layer Normalization Module Registration in DeepSpeed

This test suite validates that a custom Layer Normalization (LN) module can be registered and instantiated through the module registry of DeepSpeed’s inference v2 framework. It confirms that an implementation defined outside the core repository integrates with the registry and produces output that matches a reference implementation.

Test Coverage Overview

The suite consists of a single test, test_custom_registration, which covers both registration and execution of the custom module.

  • Validates custom module registration in DSPostNormRegistry
  • Tests implementation accuracy against reference Layer Normalization
  • Exercises float16 inputs and residuals together with float32 gamma and beta parameters
  • Passes gamma and beta through the module's transform_param before invoking the module

Implementation Analysis

The test defines CustomPostLNModule, a subclass of DSPostLNCUDAModule that overrides only name() to return 'custom_post_ln', and registers it with DSPostNormRegistry.

Key implementation patterns include:
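  • Custom module registration using the @DSPostNormRegistry.register_module decorator
  • Bundling the implementation name and a DSNormConfig into a ConfigBundle for instantiation
  • Conversion to float32 in the reference path, with the result cast back to the input dtype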
  • Reference implementation comparison using allclose
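
The decorator-based registration pattern listed above can be sketched in plain Python. This is a minimal illustration only: ModuleRegistry, BaseNorm, CustomNorm, and the instantiate method are hypothetical names invented for this sketch, and the code does not reproduce how DSPostNormRegistry is actually implemented.

```python
from typing import Dict, Type


class ModuleRegistry:
    """Hypothetical minimal registry; not DeepSpeed's actual implementation."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, cls: Type) -> Type:
        # Used as a class decorator: store the class under its declared name
        # and return it unchanged so the definition behaves as usual.
        self._modules[cls.name()] = cls
        return cls

    def instantiate(self, name: str, config: dict):
        # Look up the implementation by name and build it from the config.
        if name not in self._modules:
            raise KeyError(f"unknown module implementation: {name}")
        return self._modules[name](config)


registry = ModuleRegistry()


class BaseNorm:
    """Hypothetical base class standing in for a core implementation."""

    def __init__(self, config: dict) -> None:
        self.config = config


@registry.register_module
class CustomNorm(BaseNorm):
    # Only the name differs from the base class, mirroring the test's approach.
    @staticmethod
    def name() -> str:
        return "custom_norm"


module = registry.instantiate("custom_norm", {"channels": 4096, "eps": 1e-5})
print(type(module).__name__, module.config["channels"])
```

Because the decorator returns the class unchanged, a module defined in any file becomes available by name as soon as that file is imported, which is the property the test relies on.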

Technical Details

Testing infrastructure utilizes:

  • PyTest framework with inference_v2_ops marker
  • DeepSpeed accelerator utilities
  • DeepSpeed's ConfigBundle and DSNormConfig for module configuration
  • Tensors allocated on the current accelerator device via get_accelerator().current_device_name()
  • PyTorch’s nn.functional layer_norm for reference
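
The reference computation adds the residual to the hidden states and then applies layer normalization over the channel dimension. As a rough sketch of that arithmetic for a single token, assuming plain Python lists instead of tensors (the function name post_layer_norm is made up for this illustration):

```python
import math
from typing import List


def post_layer_norm(residual: List[float], hidden: List[float],
                    gamma: List[float], beta: List[float],
                    eps: float) -> List[float]:
    # Add the residual first, then normalize over the channel dimension.
    summed = [r + h for r, h in zip(residual, hidden)]
    mean = sum(summed) / len(summed)
    # Layer normalization uses the biased (population) variance.
    var = sum((x - mean) ** 2 for x in summed) / len(summed)
    inv_std = 1.0 / math.sqrt(var + eps)
    # Scale by gamma and shift by beta, element by element.
    return [(x - mean) * inv_std * g + b
            for x, g, b in zip(summed, gamma, beta)]


out = post_layer_norm(residual=[1.0, 2.0, 3.0, 4.0],
                      hidden=[0.0, 0.0, 0.0, 0.0],
                      gamma=[1.0, 1.0, 1.0, 1.0],
                      beta=[0.0, 0.0, 0.0, 0.0],
                      eps=1e-5)
print(out)
```

With unit gamma and zero beta the output has approximately zero mean and unit variance, which makes a quick sanity check before comparing against a fused kernel.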

Best Practices Demonstrated

The test illustrates several practices that are useful when testing deep learning kernels.

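  • Explicit dtype conversion: the reference computes in float32 and casts the result back to the input dtype
  • Parameters prepared with transform_param before they are passed to the module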
  • Modular test structure with clear separation of concerns
  • Reference implementation comparison for accuracy validation
  • Device-aware tensor handling
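
Comparing against a reference usually means accepting small numerical differences, especially in float16. The internals of the allclose helper imported by the test are not shown here, so the following is only a generic sketch of a tolerance check using the element-wise criterion documented for torch.allclose; the function name close_enough and the tolerance values are assumptions chosen for illustration.

```python
from typing import Sequence


def close_enough(actual: Sequence[float], expected: Sequence[float],
                 rtol: float, atol: float) -> bool:
    # Element-wise criterion: |actual - expected| <= atol + rtol * |expected|
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))


# Hypothetical tolerances; a real test would pick them per dtype.
print(close_enough([1.0, 2.001], [1.0, 2.0], rtol=1e-3, atol=1e-5))
print(close_enough([1.0, 2.1], [1.0, 2.0], rtol=1e-3, atol=1e-5))
```

The relative term lets large values differ by proportionally more, while the absolute term keeps the check meaningful for values near zero.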

microsoft/deepspeed

tests/unit/inference/v2/modules/test_custom_module.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.modules import ConfigBundle
from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry
from deepspeed.inference.v2.modules.configs import DSNormConfig
from deepspeed.inference.v2.modules.implementations import cuda_post_ln
from ...v2.inference_test_utils import allclose


def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor,
                             beta: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Compute in float32 for accuracy, then cast back to the input dtype.
    residual_f = residual.to(torch.float32)
    hidden_states_f = hidden_states.to(torch.float32)
    gamma_f = gamma.to(torch.float32)
    beta_f = beta.to(torch.float32)
    return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ),
                                          weight=gamma_f,
                                          bias=beta_f,
                                          eps=epsilon).to(hidden_states.dtype)


@DSPostNormRegistry.register_module
class CustomPostLNModule(cuda_post_ln.DSPostLNCUDAModule):

    @staticmethod
    def name():
        return 'custom_post_ln'


"""
Here, we explicitly register an LN implementation outside the core deepspeed repo. This should
validate that the registry is working as expected and we can implement modules outside the core
repo.
"""


@pytest.mark.inference_v2_ops
def test_custom_registration():
    channels = 4096
    dtype = torch.float16
    tokens = 1024

    config = DSNormConfig(max_tokens=2048,
                          type="layer_norm",
                          channels=channels,
                          residual_dtype=dtype,
                          input_dtype=dtype,
                          output_dtype=dtype,
                          eps=1e-5)
    bundle = ConfigBundle(name='custom_post_ln', config=config)

    # Input vals
    hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name())
    beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name())
    epsilon = 1e-5

    # Reference output
    ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon)

    # New output
    post_ln_module = DSPostNormRegistry.instantiate_config(bundle)
    gamma = post_ln_module.transform_param(gamma)
    beta = post_ln_module.transform_param(beta)
    ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta)

    # Check
    assert allclose(ds_output, ref_output)