
Validating Logical Process Group Communication in ColossalAI

This test suite validates logical process group initialization and communication in ColossalAI’s device mesh implementation. It focuses on distributed tensor operations and process group management across multiple GPU devices.

Test Coverage Overview

The test suite provides comprehensive coverage of logical process group initialization and communication patterns in a distributed GPU environment.

Key areas tested include:
  • Device mesh creation and configuration
  • Process group initialization across multiple axes (see the mesh-layout sketch after this list)
  • All-reduce operations validation
  • Tensor synchronization across devices
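
A quick CPU-only illustration (separate from the test itself) of how the four physical ranks map onto the 2×2 logical mesh. The row-major layout matches the comment in the test; the axis-to-row/column convention used here is an assumption, but under either convention every axis group holds exactly two ranks, which is what the all-reduce check relies on.

import torch

physical_mesh_id = torch.arange(0, 4)
mesh = physical_mesh_id.reshape(2, 2)                    # [[0, 1], [2, 3]]
axis0_groups = [mesh[:, c].tolist() for c in range(2)]   # assumed convention: columns -> [[0, 2], [1, 3]]
axis1_groups = [mesh[r, :].tolist() for r in range(2)]   # rows -> [[0, 1], [2, 3]]
assert axis0_groups == [[0, 2], [1, 3]]
assert axis1_groups == [[0, 1], [2, 3]]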

Implementation Analysis

The testing approach follows a spawn-based distributed pattern: a single PyTest entry point forks one worker process per GPU, and every worker performs the same verification of device mesh operations through ColossalAI’s testing utilities. The expected all-reduce result is worked out in the sketch after the feature list below.

Technical implementation features:
  • NCCL backend communication
  • 2×2 mesh topology testing
  • Tensor equality verification
  • Process rank validation
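
The key numeric expectation is easy to verify offline: on a 2×2 mesh each logical group along an axis contains exactly two ranks, and each rank contributes torch.ones(4) to a SUM reduction, so every element of the reduced tensor must equal 2. A minimal CPU-only sketch that emulates the reduction:

import torch

mesh_shape = (2, 2)
for axis, group_size in enumerate(mesh_shape):
    # every rank in the axis group contributes a tensor of ones
    contributions = [torch.ones(4) for _ in range(group_size)]
    reduced = torch.stack(contributions).sum(dim=0)      # emulates ReduceOp.SUM
    assert reduced.equal(torch.tensor([2.0, 2.0, 2.0, 2.0]))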

Technical Details

Testing infrastructure includes:
  • PyTest framework with distributed testing decorators
  • NCCL backend for GPU communication
  • Custom launch configuration with localhost setup (see the spawn/launch sketch after this list)
  • 4-GPU device configuration
  • ColossalAI’s DeviceMesh implementation
  • PyTorch distributed communication primitives
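
For context, a hedged sketch of what a spawn-style helper generally does; colossalai.testing.spawn’s actual implementation may differ (for example, it may pick a free port automatically). The helper name spawn_sketch and the default port 29500 are assumptions. Each worker receives (rank, world_size, port), which matches the signature check_layer expects.

import torch.multiprocessing as mp


def spawn_sketch(worker_fn, world_size, port=29500):
    # torch.multiprocessing.spawn calls worker_fn(rank, world_size, port)
    # once per rank and waits for all workers to finish
    mp.spawn(worker_fn, args=(world_size, port), nprocs=world_size, join=True)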

Best Practices Demonstrated

The test implementation showcases several testing best practices for distributed systems.

Notable practices include:
  • Isolated process group testing
  • Deterministic test scenarios
  • Port conflict handling (a generic free-port approach is sketched after this list)
  • Clean process group initialization
  • Systematic tensor operation validation
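
Port conflict handling deserves a note: the rerun_if_address_is_in_use decorator retries the test when the rendezvous address is already taken. A common complementary technique, sketched generically below (not necessarily what ColossalAI’s helpers do internally; get_free_port is a hypothetical helper), is to ask the OS for an ephemeral port and pass it to the launcher.

import socket


def get_free_port() -> int:
    # binding to port 0 lets the OS choose an unused port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]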

hpcaitech/colossalai

tests/test_device/test_init_logical_pg.py

import pytest
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_layer(rank, world_size, port):
    # each spawned worker initializes its own distributed environment
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    assert rank == dist.get_rank()

    tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    for axis in range(len(mesh_shape)):
        tensor = torch.ones(4).cuda()
        # each logical group along an axis of the 2x2 mesh holds two ranks,
        # so a SUM all-reduce of ones(4) must yield [2, 2, 2, 2]
        pg = device_mesh.get_process_group(axis=axis)
        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
        assert tensor.equal(tensor_to_check)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
    # spawn one worker process per GPU (4 in total), each running check_layer
    spawn(check_layer, 4)


if __name__ == "__main__":
    test_logical_pg()
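
Because check_layer uses the NCCL backend and spawn(check_layer, 4) starts four ranks, running the test requires a machine with at least four CUDA devices; it can then be invoked either directly as a script or through pytest on tests/test_device/test_init_logical_pg.py.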