Back to Repositories

Validating DeviceMesh Process Group Management in ColossalAI

This test suite validates the DeviceMesh functionality in ColossalAI, focusing on device mesh initialization, process group management, and rank mapping capabilities. The tests ensure proper handling of both 1D and 2D device mesh configurations in distributed computing environments.

Test Coverage Overview

The test suite provides comprehensive coverage of DeviceMesh functionality:
  • Device mesh initialization and configuration validation
  • Process group management for 1D and 2D mesh structures
  • Rank mapping and conversion between global and local ranks
  • Integration with PyTorch distributed communication backend

Implementation Analysis

The testing approach implements systematic validation of DeviceMesh components using pytest fixtures and distributed testing patterns. The implementation leverages PyTorch’s distributed computing capabilities and NCCL backend for process group management.

Key patterns include spawn-based distributed testing, process group initialization, and mesh topology verification.

Technical Details

Testing infrastructure includes:
  • PyTest framework with distributed testing support
  • PyTorch distributed communication (torch.distributed)
  • NCCL backend for process groups
  • Custom DeviceMesh class implementation
  • Spawn-based multi-process test execution

Best Practices Demonstrated

The test suite exemplifies robust testing practices for distributed systems:
  • Isolated test cases for different mesh configurations
  • Comprehensive assertion checks for mesh properties
  • Error handling for process group initialization
  • Clean setup and teardown of distributed environment
  • Reusable test utilities for address management

hpcaitech/colossalai

tests/test_device/test_device_mesh.py

            
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn


def test_device_mesh():
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7],
    #  [8, 9, 10,11],
    #  [12,13,14,15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


def check_1d_device_mesh():
    # check for 1D device mesh
    process_group = dist.GroupMember.WORLD
    device_mesh = DeviceMesh.from_process_group(process_group)

    # checks
    assert device_mesh.shape == [4]
    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, "Expected 1 axis for the process group dict"
    assert device_mesh.get_process_group(axis=0) == process_group, "Expected world process group"
    assert device_mesh.is_initialized
    assert device_mesh.num_devices == 4
    assert device_mesh.is_initialized
    assert device_mesh.logical_mesh_id is None
    assert device_mesh._is_init_from_process_group


def check_2d_device_mesh():
    # create process group for 2D device mesh
    first_row_ranks = [0, 1]
    second_row_ranks = [2, 3]
    first_col_ranks = [0, 2]
    second_col_ranks = [1, 3]

    first_row_pg = dist.new_group(first_row_ranks, backend="nccl")
    second_row_pg = dist.new_group(second_row_ranks, backend="nccl")
    first_col_pg = dist.new_group(first_col_ranks, backend="nccl")
    second_col_pg = dist.new_group(second_col_ranks, backend="nccl")

    # check for
    current_rank = dist.get_rank()

    if current_rank in first_row_ranks:
        row_pg = first_row_pg
    else:
        row_pg = second_row_pg

    if current_rank in first_col_ranks:
        col_pg = first_col_pg
    else:
        col_pg = second_col_pg

    device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])

    # checks
    assert device_mesh.shape == [2, 2]
    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, "Expected 2 axes for the process group dict"
    assert device_mesh.get_process_group(axis=0) == col_pg, "Expected column process group"
    assert device_mesh.get_process_group(axis=1) == row_pg, "Expected row process group"
    assert device_mesh.num_devices == 4
    assert device_mesh.is_initialized
    assert device_mesh.logical_mesh_id is None
    assert device_mesh._is_init_from_process_group


def check_init_from_process_group(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_device_mesh_from_process_group():
    spawn(check_init_from_process_group, 4)


if __name__ == "__main__":
    test_device_mesh()
    test_device_mesh_from_process_group()