Testing BLAS Linear Operations Implementation in DeepSpeed

A comprehensive unit test suite for validating BLAS linear operations in DeepSpeed's inference v2 module, focusing on matrix multiplication across different data types and problem shapes. Because the kernel sits directly on top of a BLAS library, the tests are chiefly concerned with verifying that tensor strides, contiguity, and output shapes are handled correctly at each precision level.

Test Coverage Overview

The test suite exercises BLAS linear operations across an extensive set of problem shapes and both half-precision data types.

  • Tests both FP16 and BF16 data types
  • Validates matrix multiplication across various tensor dimensions
  • Includes problem shapes from small (a 1×1024 activation) to large (a 32768×8192 weight matrix); see the shape-mapping sketch after this list
  • Tests both regular and transposed weight configurations
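
Each 4-tuple in the shape list unpacks as (batch, seq_len, in_features, out_features). A minimal sketch of how one of the larger shapes maps onto the actual tensor dimensions (values taken directly from the test file below):

# Shape mapping used by the suite: (batch, seq_len, in_features, out_features)
batch, seq_len, in_features, out_features = (1, 733, 8192, 32768)

# hidden_states: (batch, seq_len, in_features)  -> (1, 733, 8192)
# weights:       (out_features, in_features)    -> (32768, 8192)
# output:        (batch, seq_len, out_features) -> (1, 733, 32768)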

Implementation Analysis

The testing approach implements a dual-validation strategy, comparing DeepSpeed's BlasLibLinear kernel against a plain PyTorch reference implementation (hidden_states @ weights.t()).

  • Uses pytest parametrization for comprehensive coverage
  • Implements custom tolerance checking via allclose() (a hypothetical sketch follows this list)
  • Handles tensor stride and contiguity edge cases
  • Validates the forward pass with both contiguous and transposed (non-contiguous) weight layouts
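
The allclose() helper comes from inference_test_utils rather than relying on torch.allclose defaults, since FP16 and BF16 need looser bounds than FP32. The real tolerances are not reproduced here, so the following is only a hypothetical sketch of dtype-aware tolerance checking; the rtol/atol values are assumptions:

import torch

# Hypothetical dtype-keyed tolerances; the actual helper lives in
# tests/unit/inference/v2/inference_test_utils.py and its values may differ.
_TOLERANCES = {
    torch.float16: (3e-2, 2e-3),   # (rtol, atol) -- assumed
    torch.bfloat16: (3e-2, 8e-3),  # (rtol, atol) -- assumed
}

def allclose_sketch(x: torch.Tensor, y: torch.Tensor) -> bool:
    rtol, atol = _TOLERANCES[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)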

Technical Details

  • pytest framework with a custom marker (inference_v2_ops)
  • DeepSpeed accelerator integration for device management
  • Custom BLAS kernel implementation (BlasLibLinear)
  • Tensor operations with controlled scaling factors (0.1 for activations, 0.01 for weights; see the magnitude estimate after this list)
  • Explicit output tensor preallocation via torch.empty
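
A likely rationale for the 0.1/0.01 scaling, inferred rather than stated in the test: it keeps GEMM accumulations comfortably inside FP16's finite range even at the largest reduction dimension. A quick back-of-the-envelope check:

import math

# Each output element is a dot product over in_features i.i.d. terms.
# With unit-normal inputs scaled by 0.1 (activations) and 0.01 (weights),
# each product term has std 0.1 * 0.01 = 1e-3, so the accumulated sum
# stays far below FP16's max finite value (~65504).
in_features = 32768                        # largest reduction dim in the suite
expected_std = math.sqrt(in_features) * 1e-3
print(f"{expected_std:.2f}")               # ~0.18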

Best Practices Demonstrated

The test suite exemplifies robust testing practices for high-performance computing operations.

  • Systematic problem shape coverage
  • Explicit dtype handling
  • Device-agnostic implementation
  • Clear reference implementation comparison
  • Modular test structure with parametrization
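
Because every case is reachable through the inference_v2_ops marker and stacked parametrization (2 dtypes × 8 shapes × 2 test functions = 32 cases), the suite can be run in isolation with a standard pytest invocation:

pytest -m inference_v2_ops tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py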

microsoft/deepspeed

tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Tuple

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear
from ....v2.inference_test_utils import allclose

# Note: only testing with FP16 and BF16 because we use TF32 on Ampere and we don't have a good
# set of tolerances. Since this is just on top of BLAS though, the test is more about
# making sure the stride/contiguity is correct and that's data type agnostic.


def reference_implementation(hidden_states, weights):
    return hidden_states @ weights.t()


problem_shapes = [
    (1, 1, 1024, 1024),
    (1, 1024, 1024, 1024),
    (2, 1024, 1024, 1024),
    (1, 128, 768, 3072),
    (1, 128, 3072, 768),
    (1, 1024, 8192, 8192),
    (1, 733, 8192, 32768),
    (1, 13, 32768, 8192),
]


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("problem_shape", problem_shapes)
def test_blas_linear(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]):
    batch, seq_len, in_features, out_features = problem_shape
    hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype,
                                device=get_accelerator().current_device()) * 0.1
    weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01
    ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device())

    ds_kernel = BlasLibLinear(fp_dtype)

    ds_output = ds_kernel(ds_output, hidden_states, weights)
    ref_output = reference_implementation(hidden_states, weights)

    assert allclose(ds_output, ref_output)


@pytest.mark.inference_v2_ops
@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("problem_shape", problem_shapes)
def test_blas_linear_t(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]):
    batch, seq_len, in_features, out_features = problem_shape
    hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype,
                                device=get_accelerator().current_device()) * 0.1
    weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01
    ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device())

    ds_kernel = BlasLibLinear(fp_dtype)

    # Transpose the weights then revert to the format we expect.
    weights = weights.t().contiguous()
    weights = weights.t()
    ds_output = ds_kernel(ds_output, hidden_states, weights)

    ref_output = reference_implementation(hidden_states, weights)

    assert allclose(ds_output, ref_output)
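
For reference, the weights.t().contiguous().t() sequence in test_blas_linear_t yields a tensor with the same shape and values but column-major strides, which is precisely the non-contiguous layout the comment at the top of the file says the suite exists to exercise. A small standalone demonstration:

import torch

w = torch.randn(4, 8)          # (out_features, in_features), row-major
w_t = w.t().contiguous().t()   # same shape and values, column-major strides
print(w.stride(), w.is_contiguous())      # (8, 1) True
print(w_t.stride(), w_t.is_contiguous())  # (1, 4) False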