Testing FP8 Linear Layer Quantization in ColossalAI
This test suite validates the FP8 linear layer implementation in ColossalAI. It compares the linear_fp8 kernel against torch.nn.functional.linear in bfloat16, checking both the forward output and the gradients of the input, weight, and bias for batched and unbatched inputs, with and without a bias term, under tolerances loosened to account for FP8 precision.
Contents
- Test Coverage Overview
- Implementation Analysis
- Technical Details
- Best Practices Demonstrated
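Before the test itself, it helps to see what FP8 linear quantization does conceptually: each operand is scaled so its largest magnitude fits the FP8 range, cast to an 8-bit floating-point format such as E4M3, and then multiplied. The sketch below illustrates that scheme with per-tensor scales. It is not ColossalAI's linear_fp8 implementation; the names quantize_fp8 and linear_fp8_sketch are invented here for illustration, the code dequantizes before the matmul rather than calling an FP8 GEMM kernel, and it assumes PyTorch 2.1+ for the torch.float8_e4m3fn dtype.

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn


def quantize_fp8(t: torch.Tensor):
    # Per-tensor scale chosen so the largest magnitude lands near the E4M3 maximum.
    scale = t.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX
    q = (t.float() / scale).to(torch.float8_e4m3fn)
    return q, scale


def linear_fp8_sketch(x: torch.Tensor, w: torch.Tensor, bias=None):
    # Quantize both operands, dequantize, and run a plain matmul. A production
    # FP8 kernel would keep the operands in FP8 and pass the two scales to the
    # GEMM instead of dequantizing first; the resulting numerics are roughly
    # comparable, since both compute sx * sw * (xq @ wq.T).
    xq, sx = quantize_fp8(x)
    wq, sw = quantize_fp8(w)
    out = (xq.float() * sx) @ (wq.float() * sw).t()
    if bias is not None:
        out = out + bias.float()
    return out.to(x.dtype)

The test below exercises the real linear_fp8 against this kind of quantization error budget, using torch.nn.functional.linear as the reference.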
Source file: tests/test_fp8/test_fp8_linear.py (hpcaitech/colossalai)
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device

D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
    # create tensors for the FP8 path and detached clones for the bfloat16 reference path
    w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    if use_batch:
        x_shape = (B, S, D_IN)
    else:
        x_shape = (S, D_IN)
    x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_x = x.clone().detach().requires_grad_()
    if use_bias:
        bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
        ref_bias = bias.clone().detach().requires_grad_()
    else:
        bias = None
        ref_bias = None

    # forward and backward through the FP8 kernel
    out = linear_fp8(x, w, bias)
    assert out.shape == x_shape[:-1] + (D_OUT,)
    out.sum().backward()

    # forward and backward through the bfloat16 reference
    ref_out = F.linear(ref_x, ref_w, ref_bias)
    ref_out.sum().backward()

    # compare outputs and gradients within FP8-appropriate tolerances
    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
    if use_bias:
        assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)
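The two parametrize markers expand to four cases (bias on/off times batched/unbatched input), and the skipif marker restricts the test to accelerators reporting device capability 9.0 or higher (e.g. NVIDIA Hopper), where hardware FP8 support is available. The loose tolerances (rtol=0.2, atol=0.1) reflect the FP8 precision budget: E4M3 carries only 3 mantissa bits, so values in its normal range round with relative error on the order of 2^-4, near-zero entries fall into the subnormal range and can only be matched in absolute terms, and error accumulates further through the matmul and the backward pass. A quick round-trip check of per-tensor E4M3 quantization illustrates the scale of that rounding error; this snippet is illustrative only, not part of the test suite, and assumes PyTorch 2.1+ for torch.float8_e4m3fn.

import torch

# Round-trip a random tensor through per-tensor FP8 (E4M3) quantization and
# measure the resulting error.
t = torch.rand(64, 16)
scale = t.abs().max() / 448.0                                # 448 = max finite E4M3 value
rt = (t / scale).to(torch.float8_e4m3fn).float() * scale     # quantize, then dequantize
err = (rt - t).abs()
print("max abs error:", err.max().item())
print("max rel error (entries > 0.01):", (err[t > 0.01] / t[t > 0.01]).max().item())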