
Testing FP8 Compression in DDP Communication Hooks in ColossalAI

This test suite validates the FP8 compression functionality in distributed data parallel (DDP) communication hooks within ColossalAI. It covers both synchronous and asynchronous gradient compression during multi-GPU training.

Test Coverage Overview

The test suite verifies that gradients produced with FP8-compressed communication stay within a fixed tolerance of an uncompressed baseline run.

Key areas tested include:
  • Synchronous and asynchronous gradient compression hooks
  • Gradient accuracy validation with compression
  • Multi-GPU distributed training scenarios
  • Comparison between compressed and uncompressed gradients
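To make the sync/async distinction concrete: a PyTorch DDP communication hook receives the process-group state and a gradient bucket and must return a `torch.futures.Future` of the reduced tensor. A synchronous hook does its work eagerly and returns a completed future, while an asynchronous hook chains decompression onto a pending future with `.then()`. The sketch below illustrates that pattern on CPU with plain `torch.futures`; the function names are illustrative stand-ins, not ColossalAI's actual hook implementations, and a float16 roundtrip stands in for FP8 quantization.

```python
import torch


def compress_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for FP8 quantize/dequantize: cast down, then back up.
    return t.to(torch.float16).to(torch.float32)


def sync_style(bucket_buffer: torch.Tensor) -> torch.futures.Future:
    # Synchronous flavor: do the work eagerly and hand back
    # an already-completed future.
    fut = torch.futures.Future()
    fut.set_result(compress_roundtrip(bucket_buffer))
    return fut


def async_style(bucket_buffer: torch.Tensor) -> torch.futures.Future:
    # Asynchronous flavor: chain decompression onto a pending future,
    # the way a real hook chains onto the all-reduce future.
    fut = torch.futures.Future()
    chained = fut.then(lambda f: compress_roundtrip(f.value()))
    fut.set_result(bucket_buffer)  # in DDP, the backend completes this
    return chained


grads = torch.randn(8)
print(torch.allclose(sync_style(grads).wait(), async_style(grads).wait()))
```

Both flavors converge on the same decompressed result; they differ only in when the work happens relative to the communication.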

Implementation Analysis

The testing approach defines a small two-layer ToyModel (two linear layers with a ReLU in between) to validate gradient compression behavior across distributed processes.

Key implementation patterns include:
  • Process group initialization and cleanup
  • Gradient comparison with relative tolerance
  • Hook registration for gradient compression
  • Controlled random seed for reproducibility
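The tolerance pattern is worth seeing in isolation: the test accepts compressed gradients that agree with the baseline within `rtol=0.1, atol=0.1` via `torch.testing.assert_close`. A minimal CPU sketch, again using a float16 roundtrip as a stand-in for the FP8 path, shows how loose tolerances absorb quantization noise while still catching gross numerical errors:

```python
import torch
from torch.testing import assert_close

torch.manual_seed(0)  # fixed seed, as in the test
reference = torch.randn(1000)

# Simulate compression loss with a half-precision roundtrip
# (float16 stands in for FP8 here; the real test uses ColossalAI's hooks).
compressed = reference.to(torch.float16).to(torch.float32)

# Quantization error is tiny relative to rtol=0.1 / atol=0.1,
# so the comparison passes; a sign flip or scale bug would not.
assert_close(compressed, reference, rtol=0.1, atol=0.1)
max_err = (compressed - reference).abs().max().item()
print(f"max abs error: {max_err:.6f}")
```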

Technical Details

Testing infrastructure utilizes:
  • PyTorch DDP and distributed utilities
  • NCCL backend for GPU communication
  • Custom FP8 compression hooks from ColossalAI
  • Multi-process spawning with torch.multiprocessing
  • MSE loss and SGD optimizer for training simulation
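The "training simulation" amounts to one forward/backward/step with `MSELoss` and `SGD`, after which gradients are snapshotted by parameter name. A single-process CPU sketch of that per-iteration capture (mirroring the test's `get_grads_after_one_iteration`, minus DDP and the barrier) looks like this:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Same shape as the test's ToyModel: 10 -> 10 -> ReLU -> 5.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5)
loss_fn(outputs, labels).backward()
optimizer.step()

# Snapshot gradients keyed by parameter name, as the test does.
grad_dict = {name: p.grad.clone() for name, p in model.named_parameters()}
print(sorted(grad_dict.keys()))
```

In the real test this dictionary is built once without a hook and once per compression hook, and the two are compared entry by entry on rank 0.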

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Proper cleanup of distributed resources
  • Deterministic testing with fixed random seeds
  • Isolation of compression effects through controlled comparisons
  • Appropriate tolerance levels for floating-point comparisons
  • GPU device management and process coordination
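Deterministic seeding is what makes the with-hook and without-hook runs comparable at all: reseeding before each run pins model initialization and input data, so any gradient difference is attributable to the communication hook alone. A small CPU illustration of that property:

```python
import torch
import torch.nn as nn


def grads_for_seed(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)  # pins both the layer init and the input batch
    layer = nn.Linear(10, 5)
    out = layer(torch.randn(20, 10))
    out.sum().backward()
    return layer.weight.grad.clone()


# Same seed -> identical init, identical data, bit-identical gradients;
# different seeds -> different gradients.
print(torch.equal(grads_for_seed(0), grads_for_seed(0)))  # True
print(torch.equal(grads_for_seed(0), grads_for_seed(1)))  # False
```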

hpcaitech/colossalai

tests/test_fp8/test_fp8_ddp_comm_hook.py

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)