
Validating CUDA Post Layer Normalization Kernels in DeepSpeed

This test suite validates the CUDA-accelerated Post Layer Normalization implementation in DeepSpeed’s inference v2 kernels. It ensures correct computation of layer normalization with residual connections across different tensor sizes and data types.
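Conceptually, post layer normalization adds the residual to the hidden states first and then normalizes the sum. A minimal PyTorch sketch of that computation (names and tolerances here are illustrative, not DeepSpeed's API):

```python
import torch

# Post-LN sketch: the residual is added to the hidden states *before*
# layer normalization is applied to the combined tensor.
def post_layer_norm(residual, hidden_states, gamma, beta, epsilon=1e-5):
    summed = residual + hidden_states
    return torch.nn.functional.layer_norm(
        summed, (summed.size(-1),), weight=gamma, bias=beta, eps=epsilon)

hidden = torch.randn(4, 8)
resid = torch.randn(4, 8)
# With gamma=1 and beta=0, each row of the output is normalized to
# roughly zero mean and unit variance.
out = post_layer_norm(resid, hidden, torch.ones(8), torch.zeros(8))
```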

Test Coverage Overview

The test suite provides comprehensive coverage of the CUDAFPPostLN kernel functionality.

Key areas tested include:
  • Various input tensor dimensions (tokens × channels)
  • Multiple data types (FP16, BF16, FP32)
  • Residual connection handling
  • Layer normalization accuracy
  • CUDA device compatibility

Implementation Analysis

The testing approach uses a parametrized pytest framework to validate the CUDA kernel implementation against a PyTorch reference.

Testing patterns include:
  • Dynamic tensor shape testing
  • Direct comparison with PyTorch’s native layer_norm
  • Accelerator-aware test execution
  • Comprehensive dtype coverage
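The comparison-against-reference pattern can be sketched independently of the DeepSpeed kernel. Here the "kernel" is a hypothetical stand-in and the shapes are arbitrary; the structure (stacked `parametrize` decorators, an upcast reference, a tolerance-based check) mirrors what the real test does:

```python
import pytest
import torch

def candidate(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the kernel under test; in the real suite this would be
    # the CUDA implementation being validated.
    return torch.nn.functional.layer_norm(x, (x.size(-1),))

@pytest.mark.parametrize("tokens, channels", [(1, 64), (7, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_against_reference(tokens: int, channels: int, dtype: torch.dtype):
    x = torch.randn(tokens, channels, dtype=dtype)
    # Reference computed in float32, then cast back to the test dtype.
    ref = torch.nn.functional.layer_norm(x.float(), (channels,)).to(dtype)
    out = candidate(x.float()).to(dtype)
    assert torch.allclose(out, ref, rtol=3e-2, atol=2e-3)
```

Stacking two `parametrize` decorators yields the Cartesian product of shapes and dtypes, so every shape is exercised at every precision.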

Technical Details

Testing infrastructure leverages:
  • PyTest for test organization and execution
  • DeepSpeed’s accelerator abstraction layer
  • Custom allclose utility for tensor comparison
  • CUDA kernel integration testing
  • Parameterized test configurations for multiple scenarios
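The custom `allclose` utility imported from `inference_test_utils` selects tolerances by dtype, since FP16 and BF16 results diverge more from an FP32 reference than FP32 results do. A hedged sketch of that idea (the exact tolerance values in DeepSpeed's helper may differ):

```python
import torch

# Hypothetical per-dtype tolerances: looser bounds for reduced precision.
_TOLERANCES = {
    torch.float32: (5e-4, 5e-4),   # (rtol, atol)
    torch.float16: (3e-2, 2e-3),
    torch.bfloat16: (3e-2, 2e-3),
}

def allclose(x: torch.Tensor, y: torch.Tensor) -> bool:
    rtol, atol = _TOLERANCES[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)
```

Centralizing the tolerance choice keeps individual tests free of magic numbers and makes the dtype coverage uniform across the kernel suite.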

Best Practices Demonstrated

The test implementation showcases several testing best practices.

Notable aspects include:
  • Systematic parameter variation
  • Reference implementation comparison
  • Explicit dtype handling
  • Device-aware testing
  • Clear test case organization
  • Efficient resource utilization

microsoft/deepspeed

tests/unit/inference/v2/kernels/core_ops/test_post_ln.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.kernels.core_ops import CUDAFPPostLN
from ....v2.inference_test_utils import get_dtypes, allclose


def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor,
                             beta: torch.Tensor, epsilon: float) -> torch.Tensor:
    residual_f = residual.to(torch.float32)
    hidden_states_f = hidden_states.to(torch.float32)
    gamma_f = gamma.to(torch.float32)
    beta_f = beta.to(torch.float32)
    return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ),
                                          weight=gamma_f,
                                          bias=beta_f,
                                          eps=epsilon).to(hidden_states.dtype)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_cuda_post_ln(tokens: int, channels: int, dtype: torch.dtype) -> None:

    # Input vals
    hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name())
    gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
    beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name())
    epsilon = 1e-5

    # Reference output
    ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon)

    # New output
    post_ln_kernel = CUDAFPPostLN(hidden_states.size(-1), residual.dtype)
    ds_output = torch.empty_like(residual)
    post_ln_kernel(ds_output, residual, hidden_states, gamma, beta)

    # Check
    assert allclose(ds_output, ref_output)