Validating DeviceMesh Process Group Management in ColossalAI
This test suite validates the DeviceMesh functionality in ColossalAI, focusing on device mesh initialization, process group management, and rank mapping capabilities. The tests ensure proper handling of both 1D and 2D device mesh configurations in distributed computing environments.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
hpcaitech/colossalai
tests/test_device/test_device_mesh.py
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_device_mesh():
    """Check rank lookups on a 4x4 DeviceMesh built from device ids 0..15."""
    # Row-major layout of the 16 ids under the (4, 4) mesh shape:
    #   [[ 0,  1,  2,  3],
    #    [ 4,  5,  6,  7],
    #    [ 8,  9, 10, 11],
    #    [12, 13, 14, 15]]
    mesh = DeviceMesh(torch.arange(0, 16), (4, 4))

    # global rank -> expected [row, col] coordinate inside the mesh
    expected_local_ranks = {
        5: [1, 1],
        11: [2, 3],
    }
    for global_rank, local_rank in expected_local_ranks.items():
        assert mesh.global_rank_to_local_rank(global_rank) == local_rank

    # Along axis 1, global rank 2 shares its process group with the whole first row.
    first_row = [0, 1, 2, 3]
    assert mesh.get_ranks_in_process_group(axis=1, global_rank=2) == first_row
def check_1d_device_mesh():
    """Verify a DeviceMesh created from the world process group.

    Builds the mesh with ``DeviceMesh.from_process_group`` using
    ``dist.GroupMember.WORLD`` and asserts that it is a 1D mesh of 4 devices
    whose only axis maps back to the world group.

    Must be called from inside an initialized ``torch.distributed`` job; the
    assertions hard-code a world size of 4.
    """
    world_pg = dist.GroupMember.WORLD
    device_mesh = DeviceMesh.from_process_group(world_pg)

    # A single process group yields exactly one mesh axis.
    assert device_mesh.shape == [4]
    assert len(device_mesh.get_process_group_for_all_axes()) == 1, "Expected 1 axis for the process group dict"
    assert device_mesh.get_process_group(axis=0) == world_pg, "Expected world process group"

    # The original asserted ``is_initialized`` twice; one check is sufficient.
    assert device_mesh.is_initialized
    assert device_mesh.num_devices == 4

    # Expected state of a mesh that was constructed from process groups.
    assert device_mesh.logical_mesh_id is None
    assert device_mesh._is_init_from_process_group
def check_2d_device_mesh():
    """Verify a 2x2 DeviceMesh created from a column and a row process group.

    Four NCCL process groups are created (two rows, two columns), the pair the
    current rank belongs to is selected, and the mesh is built with
    ``DeviceMesh.from_process_group([col_pg, row_pg])`` so that axis 0 is the
    column group and axis 1 is the row group.

    Must be called from inside an initialized ``torch.distributed`` job; the
    rank lists hard-code a world size of 4.
    """
    # Rank layout:  rows -> [0, 1] / [2, 3],  columns -> [0, 2] / [1, 3]
    row_rank_sets = ([0, 1], [2, 3])
    col_rank_sets = ([0, 2], [1, 3])

    # Every rank creates all four groups, rows first and then columns, so the
    # creation order is identical on every process.
    top_row_pg, bottom_row_pg = [dist.new_group(ranks, backend="nccl") for ranks in row_rank_sets]
    left_col_pg, right_col_pg = [dist.new_group(ranks, backend="nccl") for ranks in col_rank_sets]

    # Pick the row/column group that contains the current rank.
    my_rank = dist.get_rank()
    row_pg = top_row_pg if my_rank in row_rank_sets[0] else bottom_row_pg
    col_pg = left_col_pg if my_rank in col_rank_sets[0] else right_col_pg

    mesh = DeviceMesh.from_process_group([col_pg, row_pg])

    # Two process groups yield a 2D mesh with one group per axis.
    assert mesh.shape == [2, 2]
    assert len(mesh.get_process_group_for_all_axes()) == 2, "Expected 2 axes for the process group dict"
    assert mesh.get_process_group(axis=0) == col_pg, "Expected column process group"
    assert mesh.get_process_group(axis=1) == row_pg, "Expected row process group"

    # Expected state of a mesh that was constructed from process groups.
    assert mesh.num_devices == 4
    assert mesh.is_initialized
    assert mesh.logical_mesh_id is None
    assert mesh._is_init_from_process_group
def check_init_from_process_group(rank, world_size, port):
    """Worker entry point: launch the distributed runtime, then run the mesh checks.

    Args:
        rank: Rank of this worker process.
        world_size: Total number of worker processes.
        port: Port on ``localhost`` used for the rendezvous.
    """
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Previously this function returned right after launching, so neither
    # check helper ever ran and the distributed test asserted nothing.
    check_1d_device_mesh()
    check_2d_device_mesh()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_device_mesh_from_process_group():
    """Spawn 4 worker processes that each run ``check_init_from_process_group``."""
    num_workers = 4
    spawn(check_init_from_process_group, num_workers)
if __name__ == "__main__":
    # Allow running this file directly: execute both tests in definition order.
    for run_test in (test_device_mesh, test_device_mesh_from_process_group):
        run_test()