
Testing FP8 Casting Operations in ColossalAI

This test suite validates FP8 (8-bit floating point) casting operations in ColossalAI, focusing on conversion between FP8 and higher-precision floating point formats. The tests verify that tensors survive a cast to FP8 and back within loose numerical tolerances.

Test Coverage Overview

The test suite provides comprehensive coverage of FP8 casting operations across multiple tensor shapes, dtypes, and FP8 formats.

Key functionality tested includes:
  • Direct FP8 casting with cast_to_fp8 and cast_from_fp8 (round trip sketched after this list)
  • Pipeline-based FP8 casting workflow
  • Support for multiple input shapes and dtypes
  • Both e4m3 and e5m2 FP8 formats
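The direct path is a simple quantize/dequantize round trip. Below is a minimal sketch mirroring the calls made in the test file shown further down; the shape and dtype are arbitrary, and cast_to_fp8 is assumed, as in the test, to return the FP8 tensor together with an inverse scale.

import torch

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.rand(16, 32, dtype=torch.float16, device=get_accelerator().get_current_device())
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")  # quantize to FP8 (e4m3 layout)
x_back = cast_from_fp8(x_fp8, scale_inv, x.dtype)     # dequantize back to float16
print((x_back - x).abs().max())                       # reconstruction error should be small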

Implementation Analysis

The testing approach uses parameterized testing to validate FP8 casting across different configurations.

Technical implementation features:
  • Parameterized test decorators for shape, dtype, and format combinations (see the sketch after this list)
  • Direct and pipeline-based casting validation
  • Relative and absolute tolerance checks
  • Device-aware tensor operations
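The stacked @parameterize decorators from colossalai.testing call the wrapped function once per listed value, so nesting them covers every combination; calling the decorated function with no arguments, as the test's __main__ block does, runs the full sweep. A reduced sketch of that pattern, with illustrative parameter values only:

from colossalai.testing import parameterize

@parameterize("dtype", ["bf16", "fp16"])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def run(dtype, fp8_format):
    print(dtype, fp8_format)

run()  # invoked once per (dtype, fp8_format) pair: 2 x 2 = 4 calls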

Technical Details

Testing tools and configuration:
  • PyTorch testing utilities (assert_close)
  • ColossalAI accelerator framework
  • FP8 quantization module
  • Supported dtypes: bfloat16, float16, float32
  • Configurable relative/absolute tolerances (rtol=0.1, atol=0.1; the check is sketched after this list)
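Loose tolerances are appropriate here because FP8 carries very few mantissa bits. torch.testing.assert_close treats two tensors as equal when |actual - expected| <= atol + rtol * |expected| holds elementwise, so with both values set to 0.1 the round trip only has to be accurate to roughly ten percent:

import torch
from torch.testing import assert_close

expected = torch.tensor([1.00, 2.00, 3.00])
actual = torch.tensor([1.05, 1.95, 3.20])
assert_close(actual, expected, rtol=0.1, atol=0.1)  # passes: each error <= 0.1 + 0.1 * |expected|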

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Comprehensive parameter space coverage
  • Isolated test cases with clear assertions
  • Hardware-agnostic implementation using the accelerator abstraction (sketched after this list)
  • Proper tolerance handling for floating-point comparisons
  • Modular test structure with reusable components
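The hardware-agnostic piece is the accelerator abstraction: tensors are allocated on whatever device get_accelerator() reports as current, so the same test code runs on any backend ColossalAI supports without device-specific branches. A minimal sketch of that pattern:

import torch
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()  # whichever accelerator backend is active
x = torch.rand(8, 8, device=device)              # no hard-coded device string in the test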

hpcaitech/colossalai

tests/test_fp8/test_fp8_cast.py

import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
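    # Round trip: quantize the random input to FP8, dequantize, and compare against the original.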
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

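    # The pipeline variant mutates the dict in place; it is only exercised
    # when the trailing dimension is even.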
    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()