Testing FP8 Casting Operations in ColossalAI
This test suite validates FP8 (8-bit floating point) casting operations in ColossalAI, focusing on conversion between different floating point formats and precision levels. The tests ensure accurate data type conversions while maintaining numerical stability.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
hpcaitech/colossalai
tests/test_fp8/test_fp8_cast.py
import torch
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
out = cast_from_fp8(ret, scale_inv, x.dtype)
assert_close(out, x, rtol=0.1, atol=0.1)
if x.size(-1) % 2 == 0:
inp_dict = {"hidden_states": x.clone()}
cast_to_fp8_pipeline(inp_dict)
cast_from_fp8_pipeline(inp_dict)
assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)
if __name__ == "__main__":
test_fp8_cast()