Testing Bias-Based Gated Activation Functions in DeepSpeed

This test suite validates the implementation of bias-based gated activation functions (gated GeLU, i.e. GEGLU, and gated SiLU) in DeepSpeed’s inference operations. It verifies that these activation mechanisms behave correctly across different input dimensions and data types by comparing DeepSpeed’s optimized implementation against reference PyTorch implementations.
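
The gating pattern itself is simple: add the bias, split the last dimension in half, and scale one half by the activated other half. The snippet below is a minimal, DeepSpeed-free sketch of that pattern in plain PyTorch, mirroring the reference functions in the test file further down.

import torch

def gated_activation_reference(activations, bias, act_fn):
    # activations: (batch, sequence, 2 * channels); bias: (2 * channels,)
    activations = activations + bias.reshape(1, 1, -1)
    # Split channels into the hidden half and the gate half
    hidden_states, gate = activations.chunk(2, dim=-1)
    # Activate the gate in float32 for stability, then cast back and scale
    return hidden_states * act_fn(gate.to(torch.float32)).to(activations.dtype)

# Example: gated GeLU (GEGLU) on a small CPU tensor
x = torch.randn(1, 4, 16)
b = torch.randn(16)
out = gated_activation_reference(x, b, torch.nn.functional.gelu)  # shape (1, 4, 8)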

Test Coverage Overview

The test suite provides comprehensive coverage of bias-based gated activation functions:
  • Tests both gated GeLU (GEGLU) and gated SiLU activation variants
  • Validates across multiple batch sizes (1, 2)
  • Tests varying sequence lengths (1, 128, 255)
  • Covers different channel dimensions (512, 1232, 4096)
  • Supports multiple data types through parametrized testing

Implementation Analysis

The testing approach employs pytest’s parametrization to systematically verify activation functions:
  • Implements reference functions using PyTorch’s native operations
  • Compares DeepSpeed’s optimized implementations against reference outputs
  • Utilizes a custom allclose comparison for numerical stability across dtypes (see the sketch after this list)
  • Handles device-specific acceleration through get_accelerator()
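
The exact tolerances live in inference_test_utils.allclose and are not shown in this file. The helper below is a hypothetical sketch of the dtype-aware comparison pattern such a utility typically implements; the tolerance values are illustrative assumptions, not DeepSpeed’s actual settings.

import torch

# Illustrative per-dtype tolerances (assumed values, not DeepSpeed's)
_TOLERANCES = {
    torch.float32: (1e-4, 1e-5),
    torch.float16: (5e-2, 5e-3),
    torch.bfloat16: (5e-1, 5e-2),
}

def allclose_sketch(x, y):
    # Pick tolerances from the input dtype, then compare in float32
    rtol, atol = _TOLERANCES.get(x.dtype, (1e-4, 1e-5))
    return torch.allclose(x.float(), y.float(), rtol=rtol, atol=atol)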

Technical Details

Key technical components include:
  • pytest framework with the inference_ops marker
  • DeepSpeed’s InferenceBuilder and GatedActivationOp
  • Custom dtype handling through get_dtypes() (a hedged sketch follows this list)
  • Accelerator-aware device placement
  • Tensor manipulation with reshape and chunk operations
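
get_dtypes() is provided by the shared inference test utilities and is not reproduced in this file. The snippet below is an assumed sketch of how such a helper can pick dtypes based on what the current accelerator supports; the function body is illustrative, not DeepSpeed’s actual implementation.

import torch
from deepspeed.accelerator import get_accelerator

def get_dtypes_sketch():
    # Always exercise the common inference dtypes; add bfloat16 only if supported
    dtypes = [torch.float16, torch.float32]
    if get_accelerator().is_bf16_supported():
        dtypes.append(torch.bfloat16)
    return dtypes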

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Modular test functions with clear separation of concerns
  • Comprehensive parameter space coverage
  • Reference implementation comparison for validation
  • Module-level skip condition when inference ops are unavailable
  • Consistent testing patterns across different activation types

microsoft/deepspeed

tests/unit/ops/transformer/inference/test_bias_geglu.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.transformer.inference.op_binding.gated_activation import GatedActivationOp
from deepspeed.utils.types import ActivationFuncType
from .inference_test_utils import allclose, get_dtypes

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)


def run_bias_geglu_reference(activations, bias):
    # Expected behavior is that of casting to float32 internally
    # Explicitly using the default GeLU
    activations = activations + bias.reshape(1, 1, -1)
    hidden_states, gate = activations.chunk(2, dim=-1)
    return hidden_states * torch.nn.functional.gelu(gate.to(torch.float32)).to(activations.dtype)


def run_bias_geglu_ds(activation, bias):
    return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_GELU)


@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_geglu(batch, sequence, channels, dtype):
    activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name())
    bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name())

    ds_out = run_bias_geglu_ds(activation, bias)
    ref_out = run_bias_geglu_reference(activation, bias)
    assert allclose(ds_out, ref_out)


def run_gated_silu_reference(activations, bias):
    # Expected behavior is that of casting to float32 internally
    # Explicitly using the default SiLU
    activations = activations + bias.reshape(1, 1, -1)
    hidden_states, gate = activations.chunk(2, dim=-1)
    return hidden_states * torch.nn.functional.silu(gate.to(torch.float32)).to(activations.dtype)


def run_gated_silu_ds(activation, bias):
    return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_SILU)


@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_gated_silu(batch, sequence, channels, dtype):
    activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name())
    bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name())

    ds_out = run_gated_silu_ds(activation, bias)
    ref_out = run_gated_silu_reference(activation, bias)
    assert allclose(ds_out, ref_out)
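
The suite can be run on its own with the inference_ops marker; the call below is a small programmatic equivalent of the usual command-line invocation and assumes it is executed from the DeepSpeed repository root (the path comes from the file header above).

import pytest

# Select only tests tagged with the inference_ops marker in this module
pytest.main(["-m", "inference_ops", "tests/unit/ops/transformer/inference/test_bias_geglu.py"])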