
Testing FP8 All-Gather Operations in ColossalAI

This test suite validates ColossalAI’s FP8 all-gather implementation, _all_gather_fp8, which gathers tensors across distributed ranks while communicating them in FP8 precision. The tests compare its results against the standard PyTorch dist.all_gather across different data types and FP8 formats.

Test Coverage Overview

The test suite runs the FP8 all_gather check under 8 configurations (1 shape × 2 data types × 2 FP8 formats × 2 execution modes).

Key areas tested include:
  • A single tensor shape, (3, 7, 16), in two data types (bfloat16, float16)
  • Different FP8 formats (e4m3, e5m2)
  • Synchronous and asynchronous operations
  • 4-GPU distributed setup validation
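Both execution modes matter because, with async_op=True, the output buffers are only valid after waiting. The sketch below illustrates the assumed pattern: an asynchronous FP8 gather returns an object whose wait() first completes every pending communication (assumed here to be an FP8 payload and its scales) and only then runs the deferred cast back to the original dtype. The Handle and FakeWork classes are hypothetical illustrations; ColossalAI’s actual handle type may differ.

```python
class Handle:
    """Hypothetical stand-in for the object an asynchronous FP8 all-gather returns."""

    def __init__(self, works, post_op=None):
        self.works = works  # pending communication requests, each exposing .wait()
        self.post_op = post_op  # deferred step, e.g. casting FP8 data back to the input dtype

    def wait(self):
        # First block until every communication request has finished
        for work in self.works:
            work.wait()
        # Then run the step that depends on the received bytes
        if self.post_op is not None:
            self.post_op()


class FakeWork:
    """Fake communication request that only records when it is waited on."""

    def __init__(self, log, name):
        self.log = log
        self.name = name

    def wait(self):
        self.log.append(f"wait {self.name}")


log = []
handle = Handle(
    [FakeWork(log, "payload"), FakeWork(log, "scales")],
    post_op=lambda: log.append("dequantize"),
)
handle.wait()
print(log)  # ['wait payload', 'wait scales', 'dequantize']
```

In the synchronous case the cast can run immediately after the blocking gather, which is consistent with the test calling wait() only when async_op is True.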

Implementation Analysis

The test stacks ColossalAI’s parameterize decorator so that check_4gpu runs once for every combination of shape, dtype, fp8_format, and async_op.
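Stacked decorators of this kind expand into nested loops. The sketch below uses a hypothetical, simplified parameterize (not ColossalAI’s implementation) and writes the dtypes as strings so it runs without PyTorch; it records the 1 × 2 × 2 × 2 = 8 configurations the test covers.

```python
import functools


def parameterize(arg_name, values):
    """Hypothetical simplified parameterize: call the wrapped function once per value,
    forwarding keyword arguments already chosen by outer decorators."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


seen = []


@parameterize("shape", [(3, 7, 16)])
@parameterize("dtype", ["bfloat16", "float16"])  # strings instead of torch dtypes
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check(shape, dtype, fp8_format, async_op):
    # Record each configuration instead of running a distributed check
    seen.append((shape, dtype, fp8_format, async_op))


check()
print(len(seen))  # 8
```

The outermost decorator forms the outermost loop, so async_op varies fastest and the first recorded configuration is ((3, 7, 16), "bfloat16", "e4m3", True).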

Technical implementation includes:
  • Distributed environment setup using ColossalAI launch utilities
  • Comparison of _all_gather_fp8 output against torch.distributed.all_gather on the same input tensor
  • Output validation with torch.testing.assert_close at rtol=0.1 and atol=0.1, loose enough to absorb FP8 rounding error
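The comparison can be illustrated without GPUs. The pure-Python sketch below simulates what an FP8 all-gather is assumed to do: each rank scales its shard into the FP8 range with a per-tensor scale, rounds it to the nearest e4m3 or e5m2 value, the (payload, scale) pairs are gathered, and every shard is dequantized back. The per-tensor amax scaling, the saturating overflow, and the helpers round_to_fp8, quantize, dequantize, and simulated_all_gather_fp8 are illustrative assumptions, not ColossalAI’s code. The final check mirrors assert_close’s element-wise rule |actual - expected| <= atol + rtol * |expected| with rtol = atol = 0.1.

```python
import math
import random

# (exponent bits, mantissa bits, largest finite value) for the two FP8 formats
FP8_FORMATS = {"e4m3": (4, 3, 448.0), "e5m2": (5, 2, 57344.0)}


def round_to_fp8(x, fmt):
    """Round a float to the nearest value representable in `fmt` (simulated in float64).
    Assumption: values beyond the largest finite value saturate instead of becoming inf/NaN."""
    exp_bits, man_bits, max_val = FP8_FORMATS[fmt]
    if x == 0.0:
        return 0.0
    min_exp = 2 - 2 ** (exp_bits - 1)  # exponent of the smallest normal number (1 - bias)
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)  # below min_exp, values fall onto the subnormal grid
    quantum = 2.0 ** (exp - man_bits)  # spacing between neighbouring FP8 values here
    q = round(abs(x) / quantum) * quantum  # round to nearest, ties to even
    return math.copysign(min(q, max_val), x)


def quantize(shard, fmt):
    """Hypothetical per-tensor scaling: map the shard's max |value| onto the FP8 max."""
    max_val = FP8_FORMATS[fmt][2]
    amax = max(abs(v) for v in shard)
    scale = max_val / amax if amax > 0 else 1.0
    return [round_to_fp8(v * scale, fmt) for v in shard], 1.0 / scale


def dequantize(payload, scale_inv):
    return [p * scale_inv for p in payload]


def simulated_all_gather_fp8(shards, fmt):
    """`shards[r]` is rank r's input; every rank ends up with all dequantized shards."""
    gathered = [quantize(shard, fmt) for shard in shards]  # what travels between ranks
    return [dequantize(payload, scale_inv) for payload, scale_inv in gathered]


def close(actual, expected, rtol=0.1, atol=0.1):
    """Element-wise tolerance rule used by torch.testing.assert_close."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


rng = random.Random(0)
shards = [[rng.random() for _ in range(3 * 7 * 16)] for _ in range(4)]  # 4 ranks, like the test
for fmt in ("e4m3", "e5m2"):
    result = simulated_all_gather_fp8(shards, fmt)
    print(fmt, all(close(r, s) for r, s in zip(result, shards)))  # prints True for both formats
```

In this simulation the worst-case relative rounding error for normal values is 2^-4 (6.25%) for e4m3 and 2^-3 (12.5%) for e5m2, since they keep 3 and 2 mantissa bits. For inputs in [0, 1), as torch.rand produces, the combined bound atol + rtol·|x| covers even the larger e5m2 error, although rtol = 0.1 alone would not.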

Technical Details

Testing infrastructure utilizes:
  • PyTorch distributed communication primitives
  • ColossalAI’s FP8 quantization module
  • ColossalAI’s spawn helper and rerun_if_address_is_in_use decorator for launching distributed tests
  • Stacked parameterize decorators covering multiple test scenarios

Best Practices Demonstrated

The test implementation showcases several testing best practices:
  • Parameterized testing to cover every configuration combination
  • Distributed environment setup through launch, inside worker processes created by spawn
  • Explicit tolerances for floating-point comparisons
  • Automatic test reruns when the chosen port is already in use
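The retry mechanism can be sketched as a decorator that reruns a test when it fails with a matching error message. rerun_on_exception below is a hypothetical stand-in written for illustration; ColossalAI’s rerun_if_address_is_in_use may match different exception types and patterns. The usage example fakes an "Address already in use" failure on the first two attempts.

```python
import functools
import re


def rerun_on_exception(exception_type=Exception, pattern=None, max_try=5):
    """Hypothetical retry decorator: rerun when `exception_type` is raised and its
    message matches `pattern`; re-raise anything else, or the last failure."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_try + 1):
                try:
                    return func(*args, **kwargs)
                except exception_type as exc:
                    # Non-matching errors are real failures and propagate immediately
                    if pattern is not None and not re.search(pattern, str(exc)):
                        raise
                    # Give up once the retry budget is exhausted
                    if attempt == max_try:
                        raise

        return wrapper

    return decorator


calls = {"n": 0}


@rerun_on_exception(OSError, pattern="Address already in use", max_try=5)
def flaky_launch():
    calls["n"] += 1
    if calls["n"] < 3:
        raise OSError("[Errno 98] Address already in use")
    return "launched"


result = flaky_launch()
print(result, calls["n"])  # launched 3
```

Matching on the message keeps genuine test failures from being masked: only port conflicts trigger a rerun.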

hpcaitech/colossalai

tests/test_fp8/test_fp8_allgather.py

            
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import _all_gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [(3, 7, 16)],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
    world_size = dist.get_world_size()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    # Separate output buffers for the reference gather and the FP8 gather
    output_list = [torch.empty_like(x) for _ in range(world_size)]
    output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
    # FP8 all-gather under test
    fp8_handle = _all_gather_fp8(
        output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
    )
    # Reference all-gather in the original dtype
    origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
    if async_op:
        # Both operations must complete before the buffers can be compared
        fp8_handle.wait()
        origin_handle.wait()
    # FP8 loses precision, so compare with loose tolerances
    assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)


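def run_dist(rank, world_size, port):
    # Initialize the distributed environment for this rank, then run every parameterized check
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


# Rerun the whole test if the port is already in use
@rerun_if_address_is_in_use()
def test_all_gather():
    # Spawn 4 worker processes, one per GPU
    spawn(run_dist, 4)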


if __name__ == "__main__":
    test_all_gather()