
Testing FP8 All-to-All Communication Operations in ColossalAI

This test suite validates the FP8 all-to-all communication operations in ColossalAI by comparing their results against standard PyTorch distributed operations. It checks that the quantized transfers stay within tolerance of the full-precision reference across different tensor shapes, data types, and synchronization modes.

Test Coverage Overview

The test suite provides comprehensive coverage of all-to-all communication patterns in both even and uneven data distribution scenarios.

Key areas tested include:
  • Various tensor shapes and dimensions
  • Multiple data types (bfloat16, float16)
  • Synchronous and asynchronous operations (see the sketch after this list)
  • Even and uneven data splitting across processes
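
The synchronous/asynchronous distinction follows standard torch.distributed semantics: with async_op=True the collective returns a work handle that must be waited on before the output buffer is read. A minimal sketch of that pattern (an illustration only; it assumes a default process group has already been initialized) is shown below:

import torch
import torch.distributed as dist

def all_to_all_even(x: torch.Tensor, async_op: bool) -> torch.Tensor:
    # For an even all-to-all the output buffer has the same shape as the input.
    output = torch.empty_like(x)
    handle = dist.all_to_all_single(output, x, async_op=async_op)
    if async_op:
        # async_op=True returns a work handle; the output is only valid after wait().
        handle.wait()
    # async_op=False blocks until completion and returns None instead of a handle.
    return output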

Implementation Analysis

The test file uses parameterized testing to cover multiple combinations of input parameters and implements two main check functions:

  • check_all2all: Tests standard even distribution scenarios
  • check_all2all_uneven: Validates uneven data distribution cases

Each test compares the results of standard PyTorch all_to_all_single with ColossalAI’s FP8-optimized version.
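
The tolerances used in these comparisons (rtol=0.1, atol=0.1) are deliberately loose because casting to an 8-bit floating-point format before communication loses precision. A quick round-trip experiment, an illustration only that requires a PyTorch build with float8 dtypes and does not reproduce ColossalAI's actual scaling logic, shows the size of that quantization error:

import torch

x = torch.rand(1024, dtype=torch.bfloat16)
# Cast down to an 8-bit float format and back to measure the quantization error.
x_fp8 = x.to(torch.float8_e4m3fn)
x_roundtrip = x_fp8.to(torch.bfloat16)
print(f"max absolute round-trip error: {(x - x_roundtrip).abs().max().item():.4f}")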

Technical Details

Testing infrastructure includes:
  • PyTorch distributed communication framework
  • ColossalAI’s FP8 quantization module
  • Custom test decorators for parameter iteration (a simplified sketch follows this list)
  • Process spawning with configurable world size
  • Tolerance-based comparison with torch.testing.assert_close (rtol=0.1, atol=0.1)
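
The parameter iteration works by decorator stacking: each @parameterize layer re-invokes the wrapped function once per value, so stacked decorators enumerate the full Cartesian product of shapes, dtypes, and async flags. A simplified stand-in, not ColossalAI's actual implementation, conveys the mechanism:

import functools

def parameterize_sketch(arg_name, values):
    # Simplified illustration of a parameter-iterating test decorator.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                # Run the wrapped test once per value; stacking decorators nests these loops.
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

This is consistent with how check_all2all() and check_all2all_uneven() are invoked with no arguments in run_dist: the decorators supply every parameter combination.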

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Parameterized test cases for comprehensive coverage
  • Proper error tolerance handling for floating-point comparisons
  • Distributed testing environment setup and teardown (a spawning sketch follows this list)
  • Automatic test rerun capability for address conflicts
  • Clear separation of test scenarios and execution logic
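
The spawning-and-rerun pattern exists because the rendezvous port chosen for a test run can already be occupied on a shared machine. A minimal sketch of the spawning side, an illustration built on torch.multiprocessing rather than ColossalAI's spawn and rerun_if_address_is_in_use helpers, looks like this:

import socket
import torch.multiprocessing as mp

def find_free_port() -> int:
    # Bind to port 0 so the OS picks a currently free port, then reuse it for rendezvous.
    with socket.socket() as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]

def spawn_sketch(worker, world_size: int) -> None:
    # Each spawned process receives its rank as the first positional argument.
    mp.spawn(worker, args=(world_size, find_free_port()), nprocs=world_size)

There is still a race between releasing the probe socket and the distributed rendezvous binding it again, which is why an automatic rerun on "address already in use" remains useful.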

hpcaitech/colossalai

tests/test_fp8/test_all_to_all_single.py

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all(shape, dtype, async_op):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output = torch.empty_like(x)      # reference output from torch.distributed
    output_fp8 = torch.empty_like(x)  # output produced by the FP8-quantized collective
    origin_handle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all_uneven(shape, dtype, async_op):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    # Each rank sends 3 rows to ranks 0 and 1 and 1 row to ranks 2 and 3 (3 + 3 + 1 + 1 = 8 = shape[0]).
    input_split_sizes = [3, 3, 1, 1]
    if dist.get_rank() in [0, 1]:
        # Ranks 0 and 1 receive 3 rows from each of the 4 ranks (12 rows in total).
        output_split_sizes = [3, 3, 3, 3]
    else:
        # Ranks 2 and 3 receive 1 row from each of the 4 ranks (4 rows in total).
        output_split_sizes = [1, 1, 1, 1]
    output_shape = list(shape)
    output_shape[0] = sum(output_split_sizes)
    output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
    output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
    origin_handle = dist.all_to_all_single(
        output,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=_get_default_group(),
        async_op=async_op,
    )
    fp8_handle = all_to_all_single_fp8(
        output_fp8,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=_get_default_group(),
        async_op=async_op,
    )
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_all2all()
    check_all2all_uneven()


@rerun_if_address_is_in_use()
def test_all_to_all_single():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all_single()