Testing FP8 Reduce Scatter Operations in ColossalAI

This test suite validates ColossalAI's FP8 reduce scatter operation against PyTorch's native reduce scatter implementation, verifying that FP8-compressed communication stays within a loose tolerance of the full-precision result when run across multiple GPUs.

Test Coverage Overview

The test suite covers reduce scatter behavior across both supported FP8 formats and a range of tensor and communication configurations.

Key areas tested include:
  • A 3D tensor shape (16, 8, 4) scattered along each of its dimensions
  • Different data types (bfloat16, float16)
  • FP8 formats (e4m3, e5m2), illustrated in the sketch after this list
  • Synchronous and asynchronous operations
  • 4-GPU distributed scenarios
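
The two FP8 formats split their 8 bits differently: e4m3 keeps more mantissa (precision), while e5m2 keeps more exponent (range), which is why the test accepts a fairly loose tolerance. Below is a minimal, standalone sketch of the round-trip error using PyTorch's native float8 dtypes; ColossalAI's own quantization utilities are expected to add per-tensor scaling on top of this, so the snippet only illustrates the formats themselves:

import torch

# Cast the same tensor shape used by the test through both FP8 formats and
# measure the worst-case round-trip error that assert_close must tolerate.
x = torch.rand(16, 8, 4, dtype=torch.bfloat16)

for fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    x_fp8 = x.to(fp8_dtype)            # quantize to 8-bit floating point
    x_back = x_fp8.to(torch.bfloat16)  # dequantize back to the working dtype
    max_err = (x - x_back).abs().max().item()
    print(f"{fp8_dtype}: max round-trip error = {max_err:.4f}")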

Implementation Analysis

The testing approach employs parameterized testing to validate reduce scatter operations across various configurations. It builds on PyTorch's distributed communication primitives and ColossalAI's FP8 implementation, comparing each FP8-based result against the standard reduce scatter output with a loose tolerance (rtol=0.1, atol=0.1).

Key implementation patterns include:
  • Distributed tensor operations
  • Parameterized test cases expanded into a full configuration cross-product (sketched below)
  • Asynchronous operation handling
  • Result comparison with tolerance
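
The stacked @parameterize decorators expand into the full cross-product of configurations, so the single check_4gpu() call in run_dist exercises every shape/dimension/dtype/format/sync combination on each rank. A simplified sketch of that expansion pattern (not ColossalAI's actual implementation):

import functools

def parameterize(name, values):
    """Hypothetical stand-in: rerun the wrapped function once per value."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(**kwargs):
            for value in values:
                func(**kwargs, **{name: value})
        return wrapper
    return decorator

@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check(fp8_format, async_op):
    print(fp8_format, async_op)

check()  # runs 2 x 2 = 4 configurations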

Technical Details

Testing infrastructure includes:
  • PyTorch distributed communications
  • ColossalAI accelerator framework
  • Custom FP8 quantization utilities
  • Distributed test spawning mechanism
  • Parameterized test decorators
  • Automatic port management for distributed testing (a generic pattern is sketched below)
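
The spawning and port-management pieces come from colossalai.testing; the following is a generic sketch of the underlying pattern, assuming a simple OS-assigned free port (the real helpers add retries and environment setup):

import socket

import torch.multiprocessing as mp

def find_free_port() -> int:
    # Binding to port 0 asks the OS for any currently free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]

def spawn_workers(run_fn, world_size: int):
    # Each child receives (rank, world_size, port), matching run_dist's signature.
    mp.spawn(run_fn, args=(world_size, find_free_port()), nprocs=world_size)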

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Comprehensive parameter space coverage
  • Proper handling of distributed resources
  • Automatic test retry mechanisms (pattern sketched after this list)
  • Clear separation of test setup and execution
  • Error tolerances sized for lossy low-precision comparisons
  • Efficient test parameterization
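
For example, the retry behavior provided by rerun_if_address_is_in_use() can be pictured as a decorator that reruns the test only when a port-collision error surfaces. A hypothetical sketch of that pattern, not the library's actual implementation:

import functools

def rerun_on_error(max_retries: int = 3, match: str = "Address already in use"):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as exc:
                    # Re-raise unrelated errors, or give up after the last attempt.
                    if match not in str(exc) or attempt == max_retries - 1:
                        raise
        return wrapper
    return decorator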

hpcaitech/colossalai

tests/test_fp8/test_fp8_reduce_scatter.py

import torch
from torch.distributed import reduce_scatter
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, scatter_dim, dtype, fp8_format, async_op):
    # Each rank builds its own random tensor and splits it into world-size chunks.
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
    input_list = [t.contiguous() for t in input_list]
    output_origin = torch.empty_like(input_list[0])
    output_fp8 = torch.empty_like(input_list[0])
    # Run the reference reduce scatter and the FP8 variant on identical inputs.
    origin_handle = reduce_scatter(output_origin, input_list, group=_get_default_group(), async_op=async_op)
    fp8_handle = reduce_scatter_fp8(
        output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
    )
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    # FP8 compression is lossy, so the outputs are compared with a loose tolerance.
    assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    # Initialize the distributed environment on each spawned process, then run the checks.
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_reduce_scatter():
    # Spawn four worker processes, one per GPU.
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_reduce_scatter()