Back to Repositories

Testing RaggedLogitsGather Kernel Operations in DeepSpeed

This test suite validates the RaggedLogitsGather kernel functionality in DeepSpeed’s inference v2 implementation, focusing on gathering logits from hidden states with varying sequence lengths and model dimensions.

Test Coverage Overview

The test suite provides comprehensive coverage of the RaggedLogitsGather kernel functionality.

Key areas tested include:
  • Support for different data types (dtype validation)
  • Multi-sequence input handling with varying sequence lengths
  • Different model dimension configurations
  • Edge cases with single-token sequences and varied batch sizes

Implementation Analysis

The testing approach employs a baseline implementation for result verification against the kernel implementation.

Technical patterns include:
  • Parametrized testing using pytest
  • Dynamic tensor generation for different configurations
  • Accelerator-aware device placement
  • Numerical accuracy validation using custom allclose utility

Technical Details

Testing infrastructure includes:
  • PyTest framework with custom markers
  • DeepSpeed accelerator utilities
  • Custom testing utilities for tensor comparison
  • Support for multiple dtype configurations
  • Ragged tensor operation testing utilities

Best Practices Demonstrated

The test suite exemplifies high-quality testing practices.

Notable elements include:
  • Systematic parameter variation
  • Baseline implementation for verification
  • Clear test case organization
  • Comprehensive edge case coverage
  • Efficient test data generation

microsoft/deepspeed

tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.kernels.ragged_ops import RaggedLogitsGather
from ....v2.inference_test_utils import allclose, get_dtypes
from .ragged_testing_utils import build_simple_batch


def baseline_implementation(hidden_states: torch.Tensor, seq_lens: List[int]) -> torch.Tensor:
    output = torch.empty((len(seq_lens), hidden_states.shape[1]),
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)

    offset = 0
    for i, seq_len in enumerate(seq_lens):
        output[i] = hidden_states[offset + seq_len - 1]
        offset += seq_len

    return output


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize('dtype', get_dtypes())
def test_supported_dtypes(dtype: torch.dtype) -> None:
    """
    Validate support on nominally supported data types.
    """
    model_dim = 4096

    batch = build_simple_batch([256], padding=False)
    hidden_states = torch.randn((batch.current_tokens, model_dim),
                                dtype=dtype,
                                device=get_accelerator().current_device())

    reference_result = baseline_implementation(hidden_states, [256])

    kernel = RaggedLogitsGather(model_dim, dtype)
    output = torch.empty_like(reference_result)
    kernel(output, hidden_states, batch)

    assert allclose(output, reference_result)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1],
                                      [63, 27, 74, 83, 32, 17, 1, 1, 1, 1, 1]])
def test_multiple_sequences(seq_lens: List[int]) -> None:
    """
    Validate support on more multi-sequence inputs.
    """
    model_dim = 4096
    dtype = torch.float16

    batch = build_simple_batch(seq_lens, padding=False)
    hidden_states = torch.randn((batch.current_tokens, model_dim),
                                dtype=dtype,
                                device=get_accelerator().current_device())

    reference_result = baseline_implementation(hidden_states, seq_lens)

    kernel = RaggedLogitsGather(model_dim, dtype)
    output = torch.empty_like(reference_result)
    kernel(output, hidden_states, batch)

    assert allclose(output, reference_result)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("model_dim", [1024, 6144, 6784])
def test_problem_size_permutations(model_dim: int) -> None:
    """
    Validate for different embedding sizes.
    """
    dtype = torch.float16
    seq_lens = [128, 64, 192, 32]

    batch = build_simple_batch(seq_lens, padding=False)
    hidden_states = torch.randn((batch.current_tokens, model_dim),
                                dtype=dtype,
                                device=get_accelerator().current_device())

    reference_result = baseline_implementation(hidden_states, seq_lens)

    kernel = RaggedLogitsGather(model_dim, dtype)
    output = torch.empty_like(reference_result)
    kernel(output, hidden_states, batch)

    assert allclose(output, reference_result)