
Testing FP8 All-Reduce Operations in ColossalAI

This test suite validates the FP8 all-reduce functionality in ColossalAI. It compares the standard PyTorch distributed all-reduce (dist.all_reduce) with ColossalAI's FP8 version (all_reduce_fp8). The tests check that the FP8 results stay within a fixed tolerance of the full-precision results across data types, FP8 formats, and tensor shapes.

Test Coverage Overview

The test suite runs the FP8 all-reduce check once for every combination of the configuration options below.

Key areas tested include:
  • Seven tensor shapes: four 2D ((3, 7), (4, 7), (7, 4), (8, 9)) and three 1D (3, 7, and 8 elements)
  • Multiple data types (float16, bfloat16)
  • Different FP8 formats (e4m3, e5m2)
  • Synchronous and asynchronous operations (async_op=False and async_op=True)
  • Sum (the default op) and average (dist.ReduceOp.AVG) reductions
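The size of this grid can be counted with a short stdlib-only sketch. It assumes that stacked @parameterize decorators expand to the full Cartesian product of their values. Dtypes are written as strings so the snippet runs without PyTorch. The bare (3) entry from the test is written as the 1D tuple (3,); torch.rand produces the same shape for both.

```python
from itertools import product

# Parameter values copied from the @parameterize decorators in the test file
# (dtypes as strings so this runs without PyTorch).
shapes = [(3, 7), (4, 7), (7, 4), (8, 9), (3,), (7,), (8,)]
dtypes = ["float16", "bfloat16"]
fp8_formats = ["e4m3", "e5m2"]
async_modes = [True, False]

# Assumption: every combination becomes one call to check_4gpu.
configs = list(product(shapes, dtypes, fp8_formats, async_modes))

# Each call performs two comparisons: SUM, then AVG.
reductions_per_config = 2

print(len(configs))                          # 56
print(len(configs) * reductions_per_config)  # 112
```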

Implementation Analysis

The testing approach uses parameterized tests to check FP8 all-reduce against the standard PyTorch distributed all-reduce. The test spawns four processes, one per GPU. Each process does the following:

  • creates a random tensor and clones it
  • reduces the original with dist.all_reduce and the clone with all_reduce_fp8
  • checks that the two results agree within rtol=0.1 and atol=0.1
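The loose tolerances make sense once you look at how coarse FP8 is. The sketch below is a simplified, hypothetical model of the comparison in plain Python. It does three things:

  • rounds values to the e4m3 and e5m2 grids (3 and 2 mantissa bits, with largest finite values 448 and 57344)
  • sums the contributions of four ranks
  • applies the elementwise rule torch.testing.assert_close uses: |actual - expected| <= atol + rtol * |expected|

The quantize, sum, quantize order in the hypothetical simulated_fp8_all_reduce is an assumption for illustration only. It omits any scaling the real implementation may apply and does not reproduce ColossalAI's communication pattern. Saturating at the largest finite value is also a modelling choice.

```python
import math

# Simplified FP8 model (assumption, not ColossalAI's kernel):
# (mantissa bits, smallest normal exponent, largest finite value)
FP8_FORMATS = {
    "e4m3": (3, -6, 448.0),
    "e5m2": (2, -14, 57344.0),
}


def quantize_fp8(x: float, fmt: str) -> float:
    """Round x to the nearest value on the FP8 grid of `fmt`, saturating at the max."""
    man_bits, min_exp, max_val = FP8_FORMATS[fmt]
    if x == 0.0:
        return 0.0
    # Exponent of the leading bit; tiny values share min_exp (subnormal spacing).
    exp = max(math.floor(math.log2(abs(x))), min_exp)
    step = 2.0 ** (exp - man_bits)  # gap between neighbouring FP8 values here
    q = round(x / step) * step      # Python's round() breaks ties to even
    return max(-max_val, min(max_val, q))


def within_tolerance(actual, expected, rtol=0.1, atol=0.1):
    """Elementwise assert_close rule: |a - e| <= atol + rtol * |e|."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def simulated_fp8_all_reduce(rank_tensors, fmt):
    """Hypothetical model: quantize each rank's input, sum exactly, quantize the sum."""
    return [quantize_fp8(sum(quantize_fp8(v, fmt) for v in column), fmt)
            for column in zip(*rank_tensors)]


# Four "ranks", each holding a two-element tensor.
ranks = [[0.1, 0.35], [0.2, 0.35], [0.3, 0.35], [0.4, 0.35]]
reference = [sum(column) for column in zip(*ranks)]  # what a SUM all-reduce computes

for fmt in ("e4m3", "e5m2"):
    result = simulated_fp8_all_reduce(ranks, fmt)
    print(fmt, result, within_tolerance(result, reference))
```

With these inputs, e4m3 rounds 0.35 to 0.34375 and e5m2 rounds it to 0.375. Both simulated sums still fall within rtol=0.1, atol=0.1 of the exact sum.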

Key patterns include:
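  • Stacked @parameterize decorators that run check_4gpu once per combination of shape, dtype, FP8 format, and async mode
  • Direct comparison of dist.all_reduce and all_reduce_fp8 applied to identical copies of the same input
  • Asynchronous handling: with async_op=True both calls return work handles, and the test calls wait() on each before comparing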
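Stacked decorators multiply into a test grid. A hypothetical parameterize_sketch decorator shows how. It is not ColossalAI's colossalai.testing.parameterize, only a minimal stand-in with the same call shape. Each layer loops over its values and passes them as a keyword argument to the layer beneath.

```python
import functools


def parameterize_sketch(arg_name, values):
    """Hypothetical stand-in for a parameterize decorator (not ColossalAI's code)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Call the wrapped function once per value, passed as a keyword argument.
            for value in values:
                func(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


calls = []


@parameterize_sketch("fp8_format", ["e4m3", "e5m2"])
@parameterize_sketch("async_op", [True, False])
def check(fp8_format, async_op):
    calls.append((fp8_format, async_op))


check()  # called with no arguments, like check_4gpu() in run_dist
print(calls)  # [('e4m3', True), ('e4m3', False), ('e5m2', True), ('e5m2', False)]
```

The outermost decorator varies slowest, so in this sketch the calls are grouped by FP8 format first.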

Technical Details

Testing infrastructure includes:
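  • torch.distributed for the reference all-reduce and the ReduceOp.AVG reduction
  • ColossalAI's launch (per-process distributed initialization) and spawn (starting the four worker processes) utilities
  • ColossalAI's FP8 quantization module (colossalai.quantization.fp8.all_reduce_fp8)
  • torch.testing.assert_close with explicit rtol and atol tolerances
  • Device placement through get_accelerator().get_current_device()
  • The rerun_if_address_is_in_use decorator, which reruns the test when the communication port is already in use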
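A retry-on-port-conflict decorator can be sketched in a few lines. The names rerun_on_address_in_use and flaky_launch are hypothetical. The sketch matches on the OS error text "Address already in use", and this is not necessarily how rerun_if_address_is_in_use is implemented.

```python
import functools


def rerun_on_address_in_use(max_tries=5):
    """Hypothetical sketch: rerun a test whose communication port was already taken."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_tries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    # Only port clashes are retried; other errors, or the final
                    # failed attempt, propagate to the caller.
                    if "Address already in use" not in str(exc) or attempt == max_tries:
                        raise

        return wrapper

    return decorator


attempts = []


@rerun_on_address_in_use(max_tries=3)
def flaky_launch():
    # Hypothetical test body that hits a port clash on its first two runs.
    attempts.append(len(attempts) + 1)
    if len(attempts) < 3:
        raise RuntimeError("Address already in use")
    return "launched"


result = flaky_launch()
print(result, attempts)  # launched [1, 2, 3]
```

A retry only helps if each attempt can use a different port. In this test, spawn supplies the port to run_dist rather than it being hard-coded.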

Best Practices Demonstrated

The test implementation showcases several testing best practices for distributed systems.

Notable practices include:
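  • Exhaustive sweep of the parameter space through stacked decorators
  • Explicit floating-point tolerances (rtol=0.1, atol=0.1) loose enough to absorb FP8 quantization error
  • A self-contained distributed setup: spawn starts four processes, and each calls launch before running the checks
  • Separation of test configuration (the decorators on check_4gpu) from execution (run_dist and test_all_reduce)
  • Automatic reruns on port conflicts instead of failing on a transient address clash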

hpcaitech/colossalai

tests/test_fp8/test_fp8_allreduce.py

            
import torch
import torch.distributed as dist
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [
        (3, 7),
        (4, 7),
        (7, 4),
        (8, 9),
        (3,),
        (7,),
        (8,),
    ],
)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
    # Random input on this rank's device; x_fp8 is an identical copy for the FP8 path.
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    x_fp8 = x.clone()

    # SUM (the default op): reference all-reduce vs. FP8 all-reduce, both in place.
    origin_handle = dist.all_reduce(x, async_op=async_op)
    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
    if async_op:
        # Async calls return work handles; results are only valid after wait().
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)

    # AVG, starting from the tensors produced by the SUM step.
    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)


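def run_dist(rank, world_size, port):
    # Initialize this process's distributed context, then run every parameterized check.
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_reduce():
    # Start four worker processes, each executing run_dist.
    spawn(run_dist, 4)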


if __name__ == "__main__":
    test_all_reduce()
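The async_op branch in check_4gpu follows the torch.distributed contract:

  • With async_op=True, all_reduce returns a work handle, and the in-place result is only guaranteed after wait().
  • With async_op=False, it returns None once the result is ready.

The stdlib-only sketch below mimics that contract with a thread pool. WorkHandle and fake_all_reduce are hypothetical stand-ins. The "all-reduce" assumes every rank holds identical data, so the SUM reduces to multiplying by the world size.

```python
from concurrent.futures import ThreadPoolExecutor


class WorkHandle:
    """Hypothetical stand-in for the work object an async collective returns."""

    def __init__(self, future):
        self._future = future

    def wait(self):
        # Block until the background reduction has finished (re-raises its errors).
        self._future.result()


def fake_all_reduce(tensor, world_size, executor, async_op=False):
    """Hypothetical in-place SUM all-reduce in which every rank holds identical data."""

    def work():
        for i in range(len(tensor)):
            tensor[i] *= world_size  # summing world_size identical copies

    if async_op:
        # Return immediately; the tensor is only guaranteed to be updated after wait().
        return WorkHandle(executor.submit(work))
    work()
    return None  # synchronous call: the tensor already holds the result


results = {}
with ThreadPoolExecutor(max_workers=2) as executor:
    for async_op in (True, False):
        x = [0.25, 0.5]
        x_copy = list(x)
        # Issue both operations before waiting on either, as check_4gpu does.
        origin_handle = fake_all_reduce(x, 4, executor, async_op=async_op)
        copy_handle = fake_all_reduce(x_copy, 4, executor, async_op=async_op)
        if async_op:
            origin_handle.wait()
            copy_handle.wait()
        results[async_op] = (x, x_copy)

print(results[True], results[False])
```

Issuing both operations before waiting lets them overlap. Calling wait() on each handle before the comparison keeps the check from reading a partially updated tensor.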