
Validating DS4Sci_EvoformerAttention Operations in DeepSpeed

This test suite validates the DS4Sci_EvoformerAttention implementation in DeepSpeed, the fused attention kernel used in Evoformer-style architectures. It checks both forward outputs and backward gradients against a pure-PyTorch reference across two data types and two tensor shapes.
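As a quick orientation, here is a minimal call sketch assembled from the first parametrized case of the test below (shapes and dtype are taken from it): DS4Sci_EvoformerAttention takes Q, K, and V of shape [batch, n, seq_len, heads, dim] plus a list of additive attention biases, and returns a tensor shaped like Q.

import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator

device = get_accelerator().device_name()
# Shapes follow the first parametrized case below: [batch, n, seq_len, heads, dim]
Q = torch.randn(1, 256, 256, 4, 32, dtype=torch.float16, device=device, requires_grad=True)
K = torch.randn_like(Q, requires_grad=True)
V = torch.randn_like(Q, requires_grad=True)
mask = torch.randint(0, 2, (1, 256, 1, 1, 256), dtype=torch.float16, device=device)
mask_bias = 1e9 * (mask - 1)  # -1e9 at masked positions, 0 elsewhere
bias = torch.randn(1, 1, 4, 256, 256, dtype=torch.float16, device=device, requires_grad=True)

out = DS4Sci_EvoformerAttention(Q, K, V, [mask_bias, bias])  # same shape as Q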

Test Coverage Overview

The test suite provides thorough coverage of the DS4Sci_EvoformerAttention operation:
  • Tests both the float16 and bfloat16 data types
  • Validates two five-dimensional input shapes, (1, 256, 256, 4, 32) and (1, 512, 256, 8, 8), via stacked parametrization (sketched after this list)
  • Compares outputs against a pure-PyTorch reference implementation
  • Verifies the gradients of Q, K, V, and the pair bias through backpropagation
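The four covered configurations come from stacking two pytest parametrize decorators, which take the Cartesian product of their argument lists. A minimal self-contained sketch of that pattern (the test body here is a placeholder, not the real test):

import pytest
import torch

# Stacked parametrize decorators multiply out: 2 dtypes x 2 shapes = 4 test cases.
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)])
def test_case_matrix(dtype, tensor_shape):
    batch, n, seq_len, heads, dim = tensor_shape
    x = torch.zeros(tensor_shape, dtype=dtype)
    assert x.shape == (batch, n, seq_len, heads, dim)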

Implementation Analysis

The testing approach uses pytest’s parametrization to validate multiple scenarios systematically: two dtypes crossed with two tensor shapes yield four test cases. A pure-PyTorch attention_reference function implements the same computation for comparison, and the test checks both forward-pass outputs and backward-pass gradients against dtype-specific epsilon tolerances (1e-2 for float16, 5e-2 for bfloat16, with 2x headroom for the bias gradient).
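A minimal sketch of that comparison logic; max_abs_diff and tolerance are hypothetical helper names (the actual test inlines these expressions):

import torch

def max_abs_diff(ref: torch.Tensor, out: torch.Tensor) -> float:
    # Largest element-wise deviation between reference and kernel output.
    return torch.max(torch.abs(ref - out)).item()

def tolerance(dtype: torch.dtype) -> float:
    # bfloat16 keeps fewer mantissa bits than float16, so it gets a looser bound.
    return 1e-2 if dtype == torch.float16 else 5e-2

# Usage: assert max_abs_diff(ref_out, out) < tolerance(out.dtype)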

Technical Details

Key technical components include:
  • PyTorch tensor operations and the torch.nn.functional module
  • DeepSpeed op-builder compatibility checks and accelerator integration (see the sketch after this list)
  • Dtype-specific epsilon tolerances for each precision level
  • Architecture-specific skip logic via skip_on_arch
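The hardware-awareness pieces come straight from the test module: a module-level skip when the fused kernel cannot be built on the current system, and device selection through DeepSpeed's accelerator abstraction:

import pytest
import deepspeed
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator

# Skip every test in the module if the kernel is not compilable here.
if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]:
    pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system",
                allow_module_level=True)

# All tensors in the test are allocated on the detected accelerator (e.g. "cuda").
device = get_accelerator().device_name()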

Best Practices Demonstrated

The test demonstrates several testing best practices:
  • Comprehensive parameter validation across dtypes and shapes
  • Precise numerical comparisons with dtype-appropriate tolerances
  • Proper gradient-checking methodology built on a shared dummy cotangent (sketched below)
  • Hardware-aware test configuration (module-level skips and architecture gating)
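A sketch of the gradient-checking pattern; leaf_grads is a hypothetical helper that mirrors the clone-and-reset sequence in the test. The key point is that the same random cotangent (dummy_out in the test) drives both the reference and the fused-kernel backward passes, making the resulting gradients directly comparable:

import torch

def leaf_grads(out: torch.Tensor, cotangent: torch.Tensor, leaves):
    # Backprop a fixed cotangent, then clone and clear each leaf gradient
    # so the next backward pass starts from a clean slate.
    out.backward(cotangent)
    grads = []
    for leaf in leaves:
        grads.append(leaf.grad.clone())
        leaf.grad = None
    return grads

# Usage: ref = leaf_grads(ref_out, dummy_out, [Q, K, V, bias])
#        got = leaf_grads(out, dummy_out, [Q, K, V, bias])
# ...then compare ref[i] and got[i] element-wise against eps.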

microsoft/deepspeed

tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import pytest
import torch
from torch.nn import functional as F
import deepspeed
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator
from unit.util import skip_on_arch

if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]:
    pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system", allow_module_level=True)


def attention_reference(
        q_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        k_input: torch.Tensor,  # [*, Dim_K, H, C_hid]
        v_input: torch.Tensor,  # [*, Dim_K, H, C_hid]
        biases: List[torch.Tensor],
        sm_scale: float) -> torch.Tensor:
    # Bring heads forward: [*, Dim, H, C_hid] -> [*, H, Dim, C_hid]
    q = q_input.transpose(-2, -3)
    k = k_input.transpose(-2, -3)
    v = v_input.transpose(-2, -3)
    k_t = k.transpose(-1, -2)
    a = torch.matmul(q, k_t) * sm_scale  # scaled logits [*, H, Dim_Q, Dim_K]

    # Additive biases (mask bias, pair bias) broadcast over the logits.
    for b in biases:
        a += b

    a = F.softmax(a, dim=-1)
    a_v = torch.matmul(a, v)
    o = a_v.transpose(-2, -3)  # back to [*, Dim_Q, H, C_hid]

    return o


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)])
def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
    # bfloat16 kernels require compute capability >= 8; float16 requires >= 7.
    skip_on_arch(8 if dtype == torch.bfloat16 else 7)
    batch, n, seq_len, heads, dim = tensor_shape
    Q = torch.randn(batch,
                    n,
                    seq_len,
                    heads,
                    dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    K = torch.randn(batch,
                    n,
                    seq_len,
                    heads,
                    dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    V = torch.randn(batch,
                    n,
                    seq_len,
                    heads,
                    dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    mask = torch.randint(0, 2, (batch, n, 1, 1, seq_len), dtype=dtype, device=get_accelerator().device_name())
    mask_bias = 1e9 * (mask - 1)  # 0 where mask == 1, -1e9 where masked out
    bias = torch.randn(batch,
                       1,
                       heads,
                       seq_len,
                       seq_len,
                       dtype=dtype,
                       device=get_accelerator().device_name(),
                       requires_grad=True)
    # Fixed upstream gradient (cotangent) reused by both backward passes below.
    dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name())
    ref_out = attention_reference(Q, K, V, [mask_bias, bias], 1 / (dim**0.5))
    ref_out.backward(dummy_out)
    ref_dv, V.grad = V.grad.clone(), None
    ref_dk, K.grad = K.grad.clone(), None
    ref_dq, Q.grad = Q.grad.clone(), None
    ref_db, bias.grad = bias.grad.clone(), None

    out = DS4Sci_EvoformerAttention(Q, K, V, [mask_bias, bias])
    out.backward(dummy_out)
    dv, V.grad = V.grad.clone(), None
    dk, K.grad = K.grad.clone(), None
    dq, Q.grad = Q.grad.clone(), None
    db, bias.grad = bias.grad.clone(), None

    # Looser tolerance for bfloat16 (fewer mantissa bits); bias grads get 2x headroom.
    eps = 1e-2 if dtype == torch.float16 else 5e-2

    assert torch.max(torch.abs(ref_out - out)).item() < eps, f"out eps: {torch.max(torch.abs(ref_out - out))}"
    assert torch.max(torch.abs(ref_dv - dv)).item() < eps, f"dv eps: {torch.max(torch.abs(ref_dv - dv))}"
    assert torch.max(torch.abs(ref_dk - dk)).item() < eps, f"dk eps: {torch.max(torch.abs(ref_dk - dk))}"
    assert torch.max(torch.abs(ref_dq - dq)).item() < eps, f"dq eps: {torch.max(torch.abs(ref_dq - dq))}"
    assert torch.max(torch.abs(ref_db - db)).item() < 2 * eps, f"db eps: {torch.max(torch.abs(ref_db - db))}"