Testing CUDA Gated Activation Operations in DeepSpeed

This test suite validates the CUDA-based gated activation operations in DeepSpeed’s inference v2 kernels. It exercises the GEGLU, ReGLU, and SiGLU gated activation variants, verifying correct behavior across data types and tensor shapes.
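For orientation, a gated activation splits the projected channels into two interleaved halves, applies the activation function to one half (the gate), and multiplies it elementwise with the other, halving the channel dimension. The sketch below shows the core of the GEGLU reference used in the test file (the actual reference additionally upcasts to float32 and supports an optional bias); the function name is illustrative.

import torch

def geglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Gate and linear channels are interleaved along the last dimension,
    # so the output has half as many channels as the input.
    gate = torch.nn.functional.gelu(x[..., ::2], approximate="tanh")
    linear = x[..., 1::2]
    return gate * linear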

Test Coverage Overview

The test suite provides comprehensive coverage of gated activation operations:
  • Data type compatibility testing across different tensor shapes
  • Multiple activation function variants (GEGLU, ReGLU, SiGLU)
  • Bias addition functionality
  • Edge cases including maximum channel limits and error conditions

Implementation Analysis

The testing approach implements a reference-based validation strategy, comparing CUDA kernel outputs against PyTorch reference implementations computed in float32. Tests use pytest parametrization to cover multiple shapes, data types, and activation variants, and check numerical accuracy with a shared allclose helper.
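Condensed, every correctness test below follows the same recipe: compute a reference output with plain PyTorch ops in float32, run the CUDA kernel into a preallocated output tensor with half the channel count, and compare the two. A minimal sketch of that pattern, with shapes borrowed from test_act_fn and illustrative tolerances standing in for the shared allclose helper:

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.inference_utils import ActivationType
from deepspeed.inference.v2.kernels.core_ops import CUDAGatedActivation

device = get_accelerator().current_device_name()
x = torch.randn(832, 4096, dtype=torch.float16, device=device)

# Reference path: float32 math, interleaved gate/linear split, cast back.
x32 = x.to(torch.float32)
ref = (torch.nn.functional.gelu(x32[..., ::2], approximate="tanh") * x32[..., 1::2]).to(x.dtype)

# Kernel path: output channels are half the input channels.
kernel = CUDAGatedActivation(x.size(-1), x.dtype, ActivationType.GEGLU)
out = torch.empty(832, 2048, dtype=torch.float16, device=device)
kernel(out, x)

# The tests use a dtype-aware allclose helper; these tolerances are stand-ins.
assert torch.allclose(out, ref, rtol=3e-2, atol=2e-3)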

Technical Details

  • Testing Framework: pytest with inference_v2_ops markers
  • Core Components: CUDAGatedActivation kernel, ActivationType enums
  • Hardware Integration: device selection via DeepSpeed’s get_accelerator() abstraction
  • Data Types: float16 plus the other precisions returned by get_dtypes() (see the tolerance sketch below)
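The allclose helper imported from inference_test_utils applies dtype-dependent tolerances; its actual thresholds are not reproduced in this file. A minimal sketch of what such a dtype-aware comparison might look like, with purely illustrative tolerance values:

import torch

def allclose_sketch(ours: torch.Tensor, ref: torch.Tensor) -> bool:
    # Illustrative thresholds only: reduced-precision dtypes get looser bounds.
    # The real values live in tests/unit/inference/v2/inference_test_utils.py.
    tolerances = {
        torch.float16: (3e-2, 2e-3),
        torch.bfloat16: (3e-2, 2e-3),
        torch.float32: (1e-5, 1e-8),
    }
    rtol, atol = tolerances[ours.dtype]
    return torch.allclose(ours, ref.to(ours.dtype), rtol=rtol, atol=atol)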

Best Practices Demonstrated

The test suite exemplifies robust testing practices:
  • Coverage of both success paths and failure modes
  • Parametrized testing for multiple configurations
  • Clear separation of reference and implementation code
  • Explicit negative tests for unsupported dtypes, activation functions, channel alignment, and channel-count limits

microsoft/deepspeed

tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.kernels.core_ops import CUDAGatedActivation
from deepspeed.inference.v2.inference_utils import ActivationType
from ....v2.inference_test_utils import get_dtypes, allclose


def reference_geglu_implementation(input: torch.Tensor,
                                   bias: Optional[torch.Tensor] = None,
                                   act_fn: Optional[ActivationType] = ActivationType.GEGLU) -> torch.Tensor:
    act_func_map = {
        ActivationType.ReGLU: torch.nn.functional.relu,
        ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
        ActivationType.SiGLU: torch.nn.functional.silu,
    }

    dtype = input.dtype
    input = input.to(torch.float32)

    if bias is not None:
        bias = bias.to(torch.float32)
        input = input + bias

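    # Gate and linear projections are interleaved along the last dimension.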
    act_act = input[..., ::2]
    act_linear = input[..., 1::2]

    act_act = act_func_map[act_fn](act_act)

    return (act_act * act_linear).to(dtype)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("shape", [(1372, 16384), (2, 743, 22016)])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_dtypes(shape: Iterable[int], dtype: torch.dtype) -> None:
    input_tensor = torch.randn(shape, dtype=dtype, device=get_accelerator().current_device_name())

    # Reference output
    ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU)

    # Build kernel
    geglu = CUDAGatedActivation(input_tensor.size(-1), input_tensor.dtype, ActivationType.GEGLU)

    # New output
    output_shape = list(input_tensor.shape)
    output_shape[-1] //= 2
    output_tensor = torch.empty(output_shape, dtype=input_tensor.dtype, device=get_accelerator().current_device_name())
    geglu(output_tensor, input_tensor)

    # Check
    assert allclose(output_tensor, ref_output)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("act_fn", [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU])
def test_act_fn(act_fn: ActivationType) -> None:
    input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device())

    # Reference output
    ref_output = reference_geglu_implementation(input_tensor, act_fn=act_fn)

    cuda_act = CUDAGatedActivation(4096, torch.float16, act_fn)

    # New output
    output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device())
    cuda_act(output_tensor, input_tensor)

    assert allclose(output_tensor, ref_output)


@pytest.mark.inference_v2_ops
def test_act_with_bias():
    input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device())
    bias = torch.randn(4096, dtype=torch.float16, device=get_accelerator().current_device())

    # Reference output
    ref_output = reference_geglu_implementation(input_tensor, bias=bias, act_fn=ActivationType.GEGLU)

    cuda_act = CUDAGatedActivation(4096, torch.float16, ActivationType.GEGLU)

    # New output
    output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device())

    cuda_act(output_tensor, input_tensor, bias)

    assert allclose(output_tensor, ref_output)


@pytest.mark.inference_v2_ops
def test_max_channels():
    input_tensor = torch.randn(832, 48152, dtype=torch.float16, device=get_accelerator().current_device())

    ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU)

    cuda_act = CUDAGatedActivation(48152, torch.float16, ActivationType.GEGLU)

    output_tensor = torch.empty(832, 24076, dtype=torch.float16, device=get_accelerator().current_device())
    cuda_act(output_tensor, input_tensor)

    assert allclose(output_tensor, ref_output)


@pytest.mark.inference_v2_ops
def test_bad_dtype() -> None:
    with pytest.raises(ValueError):
        CUDAGatedActivation(128, torch.int8, ActivationType.GEGLU)


@pytest.mark.inference_v2_ops
def test_bad_act_fn() -> None:
    with pytest.raises(ValueError):
        CUDAGatedActivation(128, torch.float16, ActivationType.RELU)


@pytest.mark.inference_v2_ops
def test_bad_alignment() -> None:
    with pytest.raises(ValueError):
        CUDAGatedActivation(127, torch.float16, ActivationType.GEGLU)


@pytest.mark.inference_v2_ops
def test_too_many_channels() -> None:
    with pytest.raises(ValueError):
        CUDAGatedActivation(49160, torch.float16, ActivationType.GEGLU)