Testing Device Mesh Management System in ColossalAI

This test suite validates the functionality of the DeviceMeshManager in ColossalAI, focusing on device mesh creation and configuration for distributed computing environments. The tests ensure proper initialization and management of device meshes with specific shapes and physical device IDs.

Test Coverage Overview

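The suite exercises one path through the DeviceMeshManager class: creating a device mesh from an explicit shape and validating the result. A second case, in which the mesh shape is not supplied, is present in the source but commented out because it depends on the hardware.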

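  • Creates a device mesh from physical device IDs [0, 1, 2, 3] with an explicit (2, 2) shape
  • Validates that the logical mesh ID mapping is [[0, 1], [2, 3]]
  • Verifies that the resulting mesh reports the shape (2, 2)
  • Initializes the distributed environment with the NCCL backend before creating the mesh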
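
The mapping asserted by the test can be illustrated without any GPU. The following is a minimal sketch, not ColossalAI's implementation: it assumes the logical mesh is the flat list of physical IDs laid out row by row, which is consistent with the expected value [[0, 1], [2, 3]] in the test. The helper name build_logical_mesh is hypothetical.

```python
def build_logical_mesh(physical_ids, mesh_shape):
    """Arrange a flat list of device IDs into a row-major 2D grid.

    Hypothetical helper for illustration only; it is not part of ColossalAI.
    """
    rows, cols = mesh_shape
    if rows * cols != len(physical_ids):
        raise ValueError(
            f"mesh shape {mesh_shape} needs {rows * cols} devices, "
            f"got {len(physical_ids)}"
        )
    # Slice the flat list into consecutive rows of length `cols`.
    return [physical_ids[r * cols:(r + 1) * cols] for r in range(rows)]


mesh = build_logical_mesh([0, 1, 2, 3], (2, 2))
print(mesh)  # [[0, 1], [2, 3]]
```

The shape check mirrors why the test passes exactly four IDs for a 2×2 mesh: the number of devices has to equal the product of the mesh dimensions.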

Implementation Analysis

The test uses spawn-based parallel execution: spawn(check_device_mesh_manager, 4) starts four processes on a single host, and each process runs the check function with its own rank, the shared world size, and a port.

Key implementation patterns include:
  • Process spawning with rank and world size configuration
  • NCCL backend initialization
  • Device mesh creation with explicit shape specifications
  • Logical to physical device mapping validation
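
The first pattern in the list above can be sketched with the standard library alone. This is an assumption-laden illustration of the arguments each worker receives, matching the signature check_device_mesh_manager(rank, world_size, port) in the test; it does not start any processes and is not how colossalai.testing.spawn is implemented. The names find_free_port and plan_worker_args are hypothetical.

```python
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        return sock.getsockname()[1]


def plan_worker_args(world_size: int):
    """Build one (rank, world_size, port) tuple per worker process.

    All workers share the same port so that they can rendezvous with
    each other; only the rank differs between them.
    """
    port = find_free_port()
    return [(rank, world_size, port) for rank in range(world_size)]


for rank, world_size, port in plan_worker_args(4):
    print(f"worker rank={rank} world_size={world_size} port={port}")
```

A real launcher would pass each tuple to a separate process. The point of the sketch is only that the rank is unique per worker while the world size and port are shared.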

Technical Details

Testing infrastructure leverages:

  • ColossalAI’s launch utility for distributed setup
  • DeviceMeshInfo for mesh configuration
  • NCCL backend for process communication
  • 4-GPU setup with 2×2 mesh configuration
  • Spawn-based test execution framework

Best Practices Demonstrated

The test implementation showcases several testing best practices for distributed systems.

  • Proper process isolation through spawn mechanism
  • Explicit device configuration management
  • Clear assertion-based validation
  • Modular test structure with separate check function
  • Reduced log noise by calling disable_existing_loggers() before launching each process

hpcaitech/colossalai

tests/test_cluster/test_device_mesh_manager.py

            
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import spawn


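def check_device_mesh_manager(rank, world_size, port):
    """Per-process check: set up the distributed context, then build and verify a device mesh."""
    disable_existing_loggers()
    # Initialize the distributed environment for this rank using the NCCL backend.
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    device_mesh_manager = DeviceMeshManager()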
    # TODO(ver217): this test strictly relies on the hardware, so it is temporarily skipped
    # device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],)
    # device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto)
    # assert device_mesh_auto.shape == (2, 2)
    # assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]]

    device_mesh_info_with_shape = DeviceMeshInfo(
        physical_ids=[0, 1, 2, 3],
        mesh_shape=(2, 2),
    )
    device_mesh_with_shape = device_mesh_manager.create_device_mesh("1", device_mesh_info_with_shape)

    assert device_mesh_with_shape.shape == (2, 2)
    assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]


def test_device_mesh_manager():
    spawn(check_device_mesh_manager, 4)


if __name__ == "__main__":
    test_device_mesh_manager()