
Testing FP8 All-to-All Communication in ColossalAI

This test suite validates the FP8 all-to-all communication functionality in ColossalAI, focusing on distributed tensor operations under reduced-precision formats. It checks that data exchanged between multiple GPUs through FP8-quantized communication stays numerically close to the result of a standard all-to-all.

Test Coverage Overview

The test suite exercises all-to-all communication with FP8 quantization across a 4-GPU group.

Key areas tested include:
  • Multiple tensor shapes and scatter dimensions
  • Different data types (bfloat16, float16)
  • FP8 format variations (e4m3, e5m2); see the round-trip sketch after this list
  • 4-GPU distributed setup verification
  • Numerical accuracy comparison between FP8 and regular all-to-all operations
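
The e4m3 and e5m2 variants trade mantissa bits against exponent range, which is why the accuracy comparison uses loose tolerances. Below is a minimal round-trip sketch, not part of the test file, assuming PyTorch 2.1+ exposes torch.float8_e4m3fn and torch.float8_e5m2:

    # Illustrative only: round-trip a tensor through each FP8 format used by the test.
    import torch

    x = torch.rand(16, 8, 4, dtype=torch.bfloat16)
    for fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # Scale values into the FP8 representable range before casting down.
        scale = x.abs().max().float() / torch.finfo(fp8_dtype).max
        x_fp8 = (x.float() / scale).to(fp8_dtype)
        x_restored = (x_fp8.to(torch.float32) * scale).to(torch.bfloat16)
        # FP8 keeps only 2-3 mantissa bits, so only a loose tolerance is realistic.
        torch.testing.assert_close(x_restored, x, rtol=0.1, atol=0.1)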

Implementation Analysis

The testing approach employs parameterized testing to evaluate multiple configurations systematically. It implements a distributed testing pattern using PyTorch’s distributed communication primitives and ColossalAI’s FP8 quantization features.
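
Assuming the stacked @parameterize decorators cross-multiply their argument lists, the configuration grid the test walks through can be enumerated like this (illustrative only, not ColossalAI's decorator machinery):

    # Illustrative only: the parameter grid the stacked decorators cover,
    # 1 shape x 3 scatter dims x 2 dtypes x 2 FP8 formats = 12 configurations.
    import itertools

    import torch

    shapes = [(16, 8, 4)]
    scatter_dims = [0, 1, 2]
    dtypes = [torch.bfloat16, torch.float16]
    fp8_formats = ["e4m3", "e5m2"]

    for shape, scatter_dim, dtype, fp8_format in itertools.product(shapes, scatter_dims, dtypes, fp8_formats):
        print(f"shape={shape}, scatter_dim={scatter_dim}, dtype={dtype}, fp8_format={fp8_format}")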

Key implementation aspects:
  • Parameterized test decoration for multiple test scenarios
  • Distributed process spawning and management
  • Tensor chunking and reconstruction verification (see the exchange sketch after this list)
  • Relative tolerance checking for numerical stability
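
The chunking mirrors what all-to-all does on the wire: each rank splits its input along the scatter dimension and exchanges the j-th chunk with rank j. A single-process sketch of that exchange, with illustrative names that are not from the test file:

    # Single-process simulation of all-to-all chunk exchange (illustrative only).
    import torch

    world_size, scatter_dim = 4, 1
    # One "rank-local" input tensor per simulated rank.
    inputs = [torch.rand(16, 8, 4) for _ in range(world_size)]
    # Each rank splits its input into world_size chunks along the scatter dimension.
    chunks = [list(torch.chunk(t, world_size, dim=scatter_dim)) for t in inputs]

    # all-to-all semantics: rank dst receives chunk dst from every source rank.
    received = [[chunks[src][dst] for src in range(world_size)] for dst in range(world_size)]

    # Every rank ends up with world_size chunks of shape (16, 8 // world_size, 4).
    assert all(len(r) == world_size for r in received)
    assert received[0][0].shape == (16, 2, 4)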

Technical Details

Testing infrastructure includes:
  • PyTorch distributed communication framework
  • ColossalAI launch utilities and accelerator management
  • Custom FP8 quantization implementations
  • Tensor comparison utilities with configurable tolerance
  • Process spawning mechanism for multi-GPU testing
  • Automatic port management and address reuse protection (a generic sketch follows below)
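
ColossalAI's spawn and rerun_if_address_is_in_use utilities handle port management in the test below; the generic pattern they address can be sketched roughly as follows (illustrative only, not the library's implementation):

    # Illustrative sketch only: pick a free port and retry if an earlier run
    # left the address in use. Not ColossalAI's actual implementation.
    import errno
    import socket

    def find_free_port() -> int:
        # Bind to port 0 so the OS assigns an unused port, then release it.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("localhost", 0))
            return s.getsockname()[1]

    def run_with_port_retry(fn, max_retries: int = 3):
        for attempt in range(max_retries):
            try:
                return fn(find_free_port())
            except OSError as exc:
                if exc.errno != errno.EADDRINUSE or attempt == max_retries - 1:
                    raise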

Best Practices Demonstrated

The test implementation showcases several testing best practices for distributed systems.

Notable practices include:
  • Isolated test environment for each configuration
  • Comprehensive parameter space coverage
  • Proper error tolerance handling for floating-point operations (see the assert_close example after this list)
  • Robust process management and cleanup
  • Automatic test retry mechanism for network address conflicts
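
Tolerance handling leans on torch.testing.assert_close, which accepts containers of tensors and compares them elementwise within the given rtol/atol. A minimal standalone example:

    # Minimal example of container-aware tolerance checking with assert_close.
    import torch
    from torch.testing import assert_close

    reference = [torch.ones(4), torch.full((4,), 2.0)]
    approximate = [t * 1.05 for t in reference]  # 5% relative error on every element

    # Passes with rtol=0.1; the default (much tighter) tolerances would reject it.
    assert_close(approximate, reference, rtol=0.1, atol=0.1)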

hpcaitech/colossalai

tests/test_fp8/test_fp8_all_to_all.py

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import _all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
    world_size = dist.get_world_size()
    input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    # Split the input along the scatter dimension, one contiguous chunk per rank.
    input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim))
    input_tensor_list = [x.contiguous() for x in input_tensor_list]
    output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
    output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
    # Run the FP8-quantized all-to-all and the regular all-to-all on identical inputs.
    _all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
    # FP8 quantization is lossy, so compare against the full-precision result with loose tolerances.
    assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    # Initialize the distributed environment on each spawned process, then run all parameterized checks.
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_to_all():
    # Spawn 4 worker processes (one per GPU); rerun automatically if the chosen port is already in use.
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all()
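
Running the file directly invokes test_all_to_all(), which spawns four worker processes (one per device) and executes every parameterized configuration on each rank, so the test expects a machine with at least four GPUs available.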