
Testing FP8 Linear Layer Quantization in ColossalAI

This test suite validates the FP8 linear layer implementation in ColossalAI, checking the numerical accuracy of both outputs and gradients across several input configurations. It verifies that FP8-quantized linear operations stay within fixed precision tolerances of a standard PyTorch reference.
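To make "FP8 quantization" concrete, here is a minimal pure-Python sketch of rounding values to the E4M3 format (4 exponent bits, 3 mantissa bits, largest finite value 448) after per-tensor amax scaling. It assumes the OCP E4M3 encoding and a simple scale = 448 / max|x| recipe. The helpers round_to_e4m3 and quantize_dequantize are hypothetical and are not ColossalAI's API; linear_fp8 may use a different format or scaling scheme.

```python
import math

E4M3_MAX = 448.0           # largest finite E4M3 value
E4M3_MIN_NORMAL = 2.0**-6  # smallest normal E4M3 value
MANTISSA_BITS = 3


def round_to_e4m3(v: float) -> float:
    """Round v to the nearest E4M3-representable value, saturating at +/-448."""
    if v == 0.0 or math.isnan(v):
        return v
    sign = -1.0 if v < 0 else 1.0
    a = min(abs(v), E4M3_MAX)
    if a < E4M3_MIN_NORMAL:
        step = 2.0 ** (-6 - MANTISSA_BITS)  # subnormal spacing: 2^-9
    else:
        # Spacing between neighbours in the binade [2^e, 2^(e+1)) is 2^(e-3)
        step = 2.0 ** (math.floor(math.log2(a)) - MANTISSA_BITS)
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_dequantize(values):
    """Scale so max |value| maps to 448, round to E4M3, then undo the scale."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) / scale for v in values]


values = [0.1, 0.37, 1.0, 2.5]
restored = quantize_dequantize(values)
max_rel_err = max(abs(r - v) / abs(v) for r, v in zip(restored, values))
print(restored)
print(f"max relative error: {max_rel_err:.4f}")
```

With only 3 mantissa bits, round-to-nearest bounds the relative error of each normal value by 2^-4 (6.25%), which helps explain why FP8 comparisons need much looser tolerances than bfloat16 ones.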

Test Coverage Overview

The test suite covers FP8 linear operations across four parametrized configurations, testing:
  • Both biased and unbiased linear transformations
  • Batch and non-batch input scenarios
  • Forward pass computation accuracy
  • Backward pass gradient verification
Key integration points are torch.nn.functional.linear, which serves as the reference implementation, and linear_fp8 from colossalai.quantization.fp8.
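The four configurations come from stacking two @pytest.mark.parametrize decorators, which pytest expands into their Cartesian product. The pure-Python sketch below enumerates that grid with itertools.product and derives the output and gradient shapes each case should produce, reusing the constants from the test file. expected_shapes is a hypothetical helper for illustration, not part of the test.

```python
from itertools import product

# Constants copied from tests/test_fp8/test_fp8_linear.py
D_IN, D_OUT = 16, 32
B, S = 2, 64


def expected_shapes(use_bias: bool, use_batch: bool):
    """Return the output shape and the shape of every gradient the test checks."""
    x_shape = (B, S, D_IN) if use_batch else (S, D_IN)
    out_shape = x_shape[:-1] + (D_OUT,)  # same expression as the test's shape assert
    grads = {"x": x_shape, "w": (D_OUT, D_IN)}
    if use_bias:
        grads["bias"] = (D_OUT,)
    return out_shape, grads


# Stacked parametrize decorators yield every (use_bias, use_batch) combination.
cases = list(product([True, False], repeat=2))
for use_bias, use_batch in cases:
    out_shape, grads = expected_shapes(use_bias, use_batch)
    print(use_bias, use_batch, out_shape, sorted(grads))
```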

Implementation Analysis

The test stacks two @pytest.mark.parametrize decorators so that every combination of use_bias and use_batch runs. Each run compares linear_fp8 against the reference torch.nn.functional.linear on identical bfloat16 inputs, is skipped when the device's compute capability major version is below 9, and checks results against fixed tolerances (rtol=0.2, atol=0.1).
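torch.testing.assert_close treats two values as close when |actual - expected| <= atol + rtol * |expected|. The pure-Python sketch below applies that rule element-wise with the test's rtol=0.2 and atol=0.1. The name is_close is hypothetical, and the sketch ignores details the real function handles, such as NaN handling and dtype, device, and shape checks.

```python
RTOL, ATOL = 0.2, 0.1  # tolerances used by test_fp8_linear


def is_close(actual, expected, rtol=RTOL, atol=ATOL):
    """Element-wise |a - e| <= atol + rtol * |e|, the rule assert_close documents."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# A reference value of 4.0 allows a deviation of 0.1 + 0.2 * 4.0 = 0.9 ...
print(is_close([4.8], [4.0]), is_close([5.0], [4.0]))  # True False
# ... while a reference near zero allows only about atol = 0.1.
print(is_close([0.09], [0.0]), is_close([0.2], [0.0]))  # True False
```

Because the rtol term scales with the reference value, larger outputs get proportionally more room for FP8 rounding error. For comparison, assert_close's defaults for bfloat16 are rtol=1.6e-2 and atol=1e-5.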

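Notable patterns include the clone-and-compare methodology, in which each input is cloned and detached so the FP8 and reference paths build independent autograd graphs, and verification of the input, weight, and (when present) bias gradients after backpropagating a summed output.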
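Because both paths backpropagate out.sum(), the reference gradients have a simple closed form. Write the input as X with N rows (N = S without batching, or B·S once the leading dimensions are flattened), let Y be the output of F.linear, and let L be the sum of all entries of Y:

```latex
\begin{aligned}
Y &= X W^\top + \mathbf{1}_N b^\top, \qquad
L = \sum_{n=1}^{N} \sum_{o=1}^{D_{\text{out}}} Y_{no}, \\
\frac{\partial L}{\partial X_{ni}} &= \sum_{o=1}^{D_{\text{out}}} W_{oi}, \qquad
\frac{\partial L}{\partial W_{oi}} = \sum_{n=1}^{N} X_{ni}, \qquad
\frac{\partial L}{\partial b_o} = N.
\end{aligned}
```

Every row of the reference x.grad therefore repeats the column sums of W, every row of w.grad repeats the column sums of X, and every entry of the bias gradient equals N (64 without batching, 128 with it).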

Technical Details

Testing tools and configuration:
  • PyTest framework with parametrize decorators
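  • Device compute capability check (major version >= 9, e.g. Hopper-class GPUs); other devices skip the test
  • torch.bfloat16 for inputs, weights, bias, and the reference computation
  • Layer dimensions: D_IN=16, D_OUT=32
  • Batch and sequence sizes: B=2, S=64 (module-level constants)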
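The chosen dimensions and tolerances interact. torch.rand draws values uniformly from [0, 1), so each product x·w averages 1/4 and each output element averages about D_IN / 4. The back-of-the-envelope sketch below, which assumes independent uniform draws and ignores variance, turns that into the absolute deviation that rtol=0.2 and atol=0.1 allow. typical_output_and_slack is a hypothetical helper.

```python
D_IN = 16
RTOL, ATOL = 0.2, 0.1

MEAN_UNIFORM = 0.5                          # mean of U(0, 1), as drawn by torch.rand
MEAN_PRODUCT = MEAN_UNIFORM * MEAN_UNIFORM  # E[x * w] for independent draws


def typical_output_and_slack(use_bias: bool):
    """Approximate mean output element and the absolute deviation allowed at that value."""
    mean_out = D_IN * MEAN_PRODUCT + (MEAN_UNIFORM if use_bias else 0.0)
    return mean_out, ATOL + RTOL * mean_out


for use_bias in (False, True):
    mean_out, slack = typical_output_and_slack(use_bias)
    print(f"use_bias={use_bias}: mean output ~{mean_out:.2f}, allowed deviation ~{slack:.2f}")
```

In other words, a typical output element near 4 may deviate by roughly 0.9 before the forward comparison fails.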

Best Practices Demonstrated

The test implementation showcases several testing best practices:
  • Explicit device capability verification
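  • Full coverage of the use_bias × use_batch parameter grid
  • Proper tensor gradient tracking setup, with detached clones for the reference path
  • Explicit numerical tolerances (rtol=0.2, atol=0.1) on every comparison
  • Clear separation between tensor setup, the FP8 path, the reference path, and the assertions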

hpcaitech/colossalai

tests/test_fp8/test_fp8_linear.py

import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device

D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
    # Create inputs for the FP8 path, plus detached clones for the reference path
    w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    if use_batch:
        x_shape = (B, S, D_IN)
    else:
        x_shape = (S, D_IN)
    x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_x = x.clone().detach().requires_grad_()
    if use_bias:
        bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
        ref_bias = bias.clone().detach().requires_grad_()
    else:
        bias = None
        ref_bias = None

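    # FP8 path: forward through linear_fp8, check the output shape, then backprop
    out = linear_fp8(x, w, bias)
    assert out.shape == x_shape[:-1] + (D_OUT,)
    out.sum().backward()
    # Reference path: the same computation with F.linear on the bfloat16 clones
    ref_out = F.linear(ref_x, ref_w, ref_bias)
    ref_out.sum().backward()

    # Output and gradients must agree with the reference within loose FP8 tolerances
    assert_close(out, ref_out, rtol=0.2, atol=0.1)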
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
    if use_bias:
        assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)