
Validating GELU Activation Function Implementation in DeepSpeed

This test suite validates the GELU (Gaussian Error Linear Unit) activation function implementation in DeepSpeed’s inference operations. It compares DeepSpeed’s GELU implementation against PyTorch’s reference implementation, ensuring numerical accuracy across different data types and tensor configurations.

Test Coverage Overview

The test suite provides comprehensive coverage of GELU activation function testing across various dimensions and configurations.

  • Tests different batch sizes (1, 2)
  • Validates sequence lengths (1, 128, 255)
  • Tests channel dimensions (512, 1232, 4096)
  • Supports both standard and Triton-optimized implementations
  • Handles float16 data type precision
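
Taken together, the parametrizations listed above expand to 2 (batch sizes) × 3 (sequence lengths) × 3 (channel sizes) × 1 (dtype) × 2 (standard/Triton) = 36 test cases per run.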

Implementation Analysis

The testing approach employs a systematic comparison between DeepSpeed’s GELU implementation and PyTorch’s reference implementation.

The suite uses pytest's parametrize feature to generate the full grid of test combinations, selects the version-appropriate PyTorch GELU reference (the explicit tanh approximation on PyTorch 1.12 and later), and applies dtype-specific tolerances so that float16 outputs are compared more loosely than float32 outputs.
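
For reference, the tanh approximation that PyTorch exposes via approximate='tanh' on 1.12+, and that the reference helper below expects the DeepSpeed kernel to match, is

GELU(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))

which approximates the exact definition GELU(x) = x · Φ(x), where Φ is the standard normal CDF.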

Technical Details

  • Testing Framework: pytest
  • Primary Dependencies: DeepSpeed, PyTorch
  • Key Components: InferenceBuilder, BiasGeluOp
  • Configuration: DeepSpeedInferenceConfig for dtype settings
  • Custom Utilities: allclose() for precision-specific comparisons

Best Practices Demonstrated

The test implementation showcases several testing best practices for deep learning operations.

  • Precise numerical comparison with dtype-specific tolerances
  • Comprehensive parameter space coverage
  • Version-aware implementation handling
  • Hardware acceleration compatibility checks
  • Modular test function organization
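
One way to exercise just this suite is to filter on its marker programmatically; the sketch below assumes it is run from the repository root and that the inference_ops marker carried by the test is registered in the repo's pytest configuration.

import pytest

# Run only this file, restricted to tests carrying the inference_ops marker
exit_code = pytest.main([
    "tests/unit/ops/transformer/inference/test_gelu.py",
    "-m", "inference_ops",
    "-v",
])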

microsoft/deepspeed

tests/unit/ops/transformer/inference/test_gelu.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
from deepspeed.utils.torch import required_torch_version

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)


def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def version_appropriate_gelu(activations):
    # gelu behavior changes (correctly) in torch 1.12
    if required_torch_version(min_version=1.12):
        return torch.nn.functional.gelu(activations, approximate='tanh')
    else:
        return torch.nn.functional.gelu(activations)


def run_gelu_reference(activations):
    # Expected behavior is that of casting to float32 internally and using the tanh approximation
    return version_appropriate_gelu(activations.to(torch.float32)).to(activations.dtype)


def run_gelu_ds(activations, use_triton_ops=False):
    if use_triton_ops:
        from deepspeed.ops.transformer.inference.triton import gelu
        return gelu(activations)

    # Zero-valued bias so the fused bias+GELU op reduces to a plain GELU over the activations
    device = deepspeed.accelerator.get_accelerator().device_name()
    channels = activations.shape[-1]
    bias = torch.zeros((channels), dtype=activations.dtype, device=device)
    config = DeepSpeedInferenceConfig(dtype=activations.dtype)
    return BiasGeluOp(config)(activations, bias)


@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
    device = deepspeed.accelerator.get_accelerator().device_name()
    activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
    activations_ref = activations_ds.clone().detach()

    # Only the Triton code path needs Triton; skip just those cases when it is unavailable
    if use_triton_ops and not deepspeed.get_accelerator().is_triton_supported():
        pytest.skip("triton is not supported on this system")
    ds_out = run_gelu_ds(activations_ds, use_triton_ops)
    ref_out = run_gelu_reference(activations_ref)
    assert allclose(ds_out, ref_out)