Validating Logical Process Group Communication in ColossalAI
This test suite validates logical process group initialization and communication in ColossalAI’s device mesh implementation. It exercises distributed collective operations and process group management across multiple GPU devices.
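For orientation, the 2 x 2 mesh used in the test lays the four ranks out as [[0, 1], [2, 3]]: along axis 0 the rank groups are {0, 2} and {1, 3}, and along axis 1 they are {0, 1} and {2, 3}. The short sketch below is illustrative only; the logical_groups helper is hypothetical and not part of the ColossalAI API. It derives those groups from the physical mesh id tensor, which is exactly the reasoning behind the test's assertion: every axis group has two members, so all-reducing a tensor of ones yields a tensor of twos.

import torch

def logical_groups(physical_mesh_id: torch.Tensor, mesh_shape: tuple):
    """Illustrative helper (not a ColossalAI API): list the rank groups
    formed along each axis of a logical device mesh."""
    mesh = physical_mesh_id.reshape(mesh_shape)
    groups = {}
    for axis in range(mesh.ndim):
        # Move the chosen axis last; every remaining index combination
        # then yields one group of ranks that communicate along that axis.
        moved = mesh.movedim(axis, -1).reshape(-1, mesh.shape[axis])
        groups[axis] = [row.tolist() for row in moved]
    return groups

if __name__ == "__main__":
    groups = logical_groups(torch.arange(4), (2, 2))
    print(groups)  # {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
    # Each group holds 2 ranks, so all-reducing torch.ones(4) along either
    # axis produces [2, 2, 2, 2] -- exactly what the test asserts.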
Source: hpcaitech/colossalai, tests/test_device/test_init_logical_pg.py
import pytest
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_layer(rank, world_size, port):
    # Initialize the distributed environment for this worker process.
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    assert rank == dist.get_rank()

    tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
    mesh_shape = (2, 2)
    # Logical layout of the four ranks:
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    for axis in range(len(mesh_shape)):
        tensor = torch.ones(4).cuda()
        # Each logical axis defines process groups of two ranks, so an
        # all-reduce of ones along that axis must produce twos.
        pg = device_mesh.get_process_group(axis=axis)
        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
        assert tensor.equal(tensor_to_check)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
    # Spawn 4 worker processes, one per device in the 2 x 2 mesh.
    spawn(check_layer, 4)


if __name__ == "__main__":
    test_logical_pg()
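Because each axis-level process group contains exactly two ranks, the all-reduced tensor along either axis equals [2, 2, 2, 2], matching tensor_to_check. The test is marked with pytest.mark.dist and wrapped in rerun_if_address_is_in_use() so that transient port conflicts trigger a retry; it can also be launched directly as a script via the __main__ guard, typically requiring four CUDA devices since spawn starts four ranks on the NCCL backend.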