Validating RMS Normalization Operations in DeepSpeed

This test suite validates the RMS (Root Mean Square) normalization kernels in DeepSpeed’s inference v2 engine. It covers both standard RMS normalization and pre-normalization with residual connections, across several data types and tensor shapes.

Test Coverage Overview

The suite compares two CUDA kernels, CUDARMSNorm and CUDARMSPreNorm, against PyTorch reference implementations defined in the same file.

Key areas tested include:
  • Standard RMS normalization and pre-normalization variants
  • Support for multiple data types (fp16, bf16)
  • Several tensor shapes, from a single row (1 × 4096) to 1024 × 6144, including a width that is not a power of two (14432)
  • Residual connection handling
  • Numerical accuracy verification against reference implementations
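
For orientation, the quantity being verified can be written out in a few lines of plain Python. This is a minimal sketch, not DeepSpeed code: `rms_norm` is a hypothetical name, it works on a single row stored as a list of floats, and it follows the same formula as the `reference_rms_norm` helper in the test file (mean of squares over the last dimension, `epsilon` added before the reciprocal square root, then scaling by `gamma`).

```python
import math
from typing import List


def rms_norm(row: List[float], gamma: List[float], epsilon: float = 1e-5) -> List[float]:
    """Normalize one row by its root mean square, then scale by gamma (illustrative only)."""
    mean_sq = sum(x * x for x in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [g * x * inv_rms for g, x in zip(gamma, row)]


row = [3.0, 4.0]
out = rms_norm(row, gamma=[1.0, 1.0])
# With unit gamma, the output row has a mean square very close to 1.
print(out, sum(y * y for y in out) / len(out))
```

The small gap between the printed mean square and 1 comes from `epsilon`, which keeps the division stable when a row is all zeros.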

Implementation Analysis

A single helper function, _rms_norm_testing_helper, builds random inputs, runs the kernel under test, and compares its output with a PyTorch reference. Two parametrized pytest functions call this helper across dtypes, tensor shapes, and both residual modes.

Notable patterns include:
  • Reference implementation comparison
  • CUDA kernel validation
  • Parametrized tensor shape testing
  • Dtype compatibility verification
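
Stacked `pytest.mark.parametrize` decorators expand to the Cartesian product of their argument lists. The sketch below enumerates, with `itertools.product`, the combinations that the shape test receives. The shape and residual values are copied from the test file; the variable names are illustrative, and the order in which pytest actually runs the cases may differ.

```python
from itertools import product

# Values copied from the parametrize decorators on test_rms_shapes.
shapes = [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]
do_residual_options = [True, False]

# One test case is generated per combination of the stacked decorators.
cases = [(rows, cols, do_residual)
         for (rows, cols), do_residual in product(shapes, do_residual_options)]

for rows, cols, do_residual in cases:
    print(f"rows={rows}, cols={cols}, do_residual={do_residual}")
```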

Technical Details

Testing infrastructure includes:
  • PyTorch tensor operations
  • DeepSpeed CUDA kernels (CUDARMSNorm, CUDARMSPreNorm)
  • Pytest parametrization for test case generation
  • Custom assertion utilities for tensor comparison
  • Accelerator-aware device management
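
The `allclose` helper is imported from `inference_test_utils`, and its body is not shown on this page. As a rough sketch of what a dtype-aware comparison utility can look like, the pure-Python function below applies the usual `|a - e| <= atol + rtol * |e|` rule with a per-dtype tolerance table. The function name, the string keys, and every tolerance value are assumptions chosen for illustration, not the values DeepSpeed uses.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical (rtol, atol) pairs; looser for lower-precision types.
TOLERANCES: Dict[str, Tuple[float, float]] = {
    "fp32": (1e-5, 1e-6),
    "fp16": (1e-2, 1e-3),
    "bf16": (1e-1, 1e-2),
}


def allclose_sketch(actual: Sequence[float], expected: Sequence[float], dtype: str) -> bool:
    """Return True if every element satisfies |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        return False
    rtol, atol = TOLERANCES[dtype]
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(allclose_sketch([1.004, 2.0], [1.0, 2.0], "fp16"))  # inside the fp16 tolerance
print(allclose_sketch([1.004, 2.0], [1.0, 2.0], "fp32"))  # outside the fp32 tolerance
```

Keying the tolerance on the dtype lets one assertion helper serve every parametrized case without loosening the check for higher-precision types.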

Best Practices Demonstrated

The suite illustrates several practices that are useful when testing numerical kernels.

Key practices include:
  • Systematic parameter space exploration
  • Reference implementation verification
  • Explicit error tolerance handling
  • Comprehensive dtype coverage
  • Modular test helper functions
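
Explicit tolerances matter because half-precision values cannot represent most real numbers exactly. Python's standard-library `struct` module supports the IEEE 754 binary16 format through the `"e"` format code, which makes the rounding error easy to observe without a GPU. The helper name `to_fp16` is illustrative and does not appear in the test file.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


x = 0.1
rounded = to_fp16(x)
rel_err = abs(rounded - x) / abs(x)

# Half precision stores 10 mantissa bits, so the relative rounding error
# for normal values is bounded by 2**-11.
print(rounded, rel_err, rel_err <= 2**-11)
```

Errors of this size accumulate over a row of several thousand channels, which is one reason an exact equality check between kernel and reference would be too strict.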

microsoft/deepspeed

tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.inference_utils import DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm
from ....v2.inference_test_utils import get_dtypes, allclose


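def reference_rms_norm(vals: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
    # Compute the mean of squares in fp32 to limit rounding error for fp16/bf16 inputs.
    variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True)
    # Type promotion makes this product fp32 even when vals is half precision.
    vals = vals * torch.rsqrt(variance + epsilon)

    # Cast back to the parameter dtype before scaling by gamma.
    if gamma.dtype in [torch.float16, torch.bfloat16]:
        vals = vals.to(gamma.dtype)

    return gamma * vals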


def reference_rms_pre_norm(vals: torch.Tensor,
                           residual: torch.Tensor,
                           gamma: torch.Tensor,
                           epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Returns (updated residual, normalized output); the builtin generic needs Python 3.9+.
    residual = residual + vals
    return residual, reference_rms_norm(residual, gamma, epsilon)


def _rms_norm_testing_helper(rows: int, channels: int, do_residual: bool, dtype: DtypeEnum) -> None:
    device = get_accelerator().current_device_name()
    t_dtype = dtype.value

    vals = torch.randn((rows, channels), dtype=t_dtype, device=device)
    gamma = torch.randn((channels, ), dtype=t_dtype, device=device)
    epsilon = 1e-5

    if do_residual:
        residual_in = torch.randn((rows, channels), dtype=t_dtype, device=device)
        ds_residual = residual_in.clone()

        ref_residual, ref_output = reference_rms_pre_norm(vals, residual_in, gamma, epsilon)

        kernel = CUDARMSPreNorm(channels, t_dtype, epsilon=epsilon)
        ds_out = torch.empty_like(ds_residual)

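        # The assertions below expect the updated residual in ds_residual and the normalized result in ds_out.
        kernel(ds_residual, ds_out, residual_in, vals, gamma)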

        assert allclose(ds_out, ref_output)
        assert allclose(ds_residual, ref_residual)
    else:

        ref_output = reference_rms_norm(vals, gamma, epsilon)

        kernel = CUDARMSNorm(channels, t_dtype, epsilon=epsilon)
        ds_out = torch.empty_like(vals)

        kernel(ds_out, vals, gamma)

        assert allclose(ds_out, ref_output)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("dtype", get_dtypes())
@pytest.mark.parametrize("do_residual", [True, False])
def test_rms_dtypes(dtype: DtypeEnum, do_residual: bool) -> None:
    _rms_norm_testing_helper(883, 1024, do_residual, DtypeEnum(dtype))


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("rows, cols", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)])
@pytest.mark.parametrize("do_residual", [True, False])
def test_rms_shapes(rows: int, cols: int, do_residual: bool) -> None:
    _rms_norm_testing_helper(rows, cols, do_residual, DtypeEnum.fp16)
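

The pre-norm path exercised with `do_residual=True` can be summarized in plain Python: the input is first added to the residual stream, and the updated residual is both returned and normalized. The sketch below mirrors `reference_rms_pre_norm` on a single row of floats. `rms_pre_norm_sketch` is a hypothetical name, and no DeepSpeed kernel is involved.

```python
import math
from typing import List, Tuple


def rms_pre_norm_sketch(vals: List[float],
                        residual: List[float],
                        gamma: List[float],
                        epsilon: float = 1e-5) -> Tuple[List[float], List[float]]:
    """Return (updated residual, RMS-normalized updated residual) for one row."""
    new_residual = [r + v for r, v in zip(residual, vals)]
    mean_sq = sum(x * x for x in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    normed = [g * x * inv_rms for g, x in zip(gamma, new_residual)]
    return new_residual, normed


new_residual, normed = rms_pre_norm_sketch([1.0, 2.0], [2.0, 2.0], [1.0, 1.0])
print(new_residual)  # the residual stream after the addition
print(normed)        # the normalized values derived from the updated residual
```

Returning both values matches what the test checks: one assertion for the normalized output and one for the updated residual.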