Testing Logical Device Mesh Search Optimization in ColossalAI

This test suite validates the logical device mesh search functionality in ColossalAI’s device management system, focusing on the AlphaBetaProfiler’s ability to determine optimal device configurations for distributed computing environments.

Test Coverage Overview

The test suite covers the core functionality of the AlphaBetaProfiler class, specifically its ability to search for optimal logical device mesh configurations.

  • Tests multiple physical device combinations (4-GPU and 2-GPU configurations)
  • Verifies expected mesh outputs for different device configurations
  • Validates distributed computing setup and initialization
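To make the expected outputs above concrete, here is a minimal, hypothetical sketch of how a device list can be reshaped into a near-square 2D logical mesh. The real `AlphaBetaProfiler` additionally profiles communication latency (alpha) and bandwidth (beta) between devices before choosing a mesh; `search_logical_mesh` below is an illustration of the reshaping step only, not the library's algorithm.

```python
import math


def search_logical_mesh(devices):
    """Illustrative sketch: pick the most square 2D factorization of the
    device count and reshape the device list into rows. The actual
    profiler also weighs measured latency/bandwidth, which is omitted here."""
    n = len(devices)
    # largest divisor of n not exceeding sqrt(n) gives the row count
    rows = max(d for d in range(1, math.isqrt(n) + 1) if n % d == 0)
    cols = n // rows
    return [devices[i * cols:(i + 1) * cols] for i in range(rows)]


print(search_logical_mesh([0, 1, 2, 3]))  # [[0, 1], [2, 3]]
print(search_logical_mesh([0, 3]))        # [[0, 3]]
```

The two example calls mirror the device configurations the test suite asserts on.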

Implementation Analysis

The testing approach employs pytest’s parameterized testing capabilities to evaluate different device configurations within a distributed environment.

Key implementation patterns include:
  • Use of spawn() for multi-process testing
  • Custom decorators for distributed testing scenarios
  • Automatic port management and address reuse handling
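The custom `parameterize` decorator mentioned above runs a test body once per supplied value, similar to `pytest.mark.parametrize`. A minimal re-implementation of that pattern might look like the following; this is a hypothetical sketch, not ColossalAI's actual decorator code.

```python
import functools


def parameterize(arg_name, values):
    """Sketch of a ColossalAI-style parameterize decorator: one call to the
    wrapped function runs it for every value in `values`, each passed via
    the keyword `arg_name`."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


seen = []


@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]])
def fake_test(physical_devices):
    seen.append(physical_devices)


fake_test()
print(seen)  # [[0, 1, 2, 3], [0, 3]]
```

Unlike pytest's parametrization, which generates separate test items, this style expands the cases inside a single call, which composes cleanly with process-spawning test wrappers.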

Technical Details

  • Testing Framework: pytest
  • Key Libraries: ColossalAI device management modules
  • Setup: NCCL backend, localhost configuration
  • Custom Utilities: rerun_if_address_is_in_use, parameterize decorators
  • Process Management: spawn-based multi-GPU testing
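The "automatic port management" listed above commonly relies on binding to port 0 so the OS assigns a free ephemeral port, which is then handed to the distributed launcher. The helper below is an illustrative stdlib sketch of that pattern, not ColossalAI's internal implementation.

```python
import socket


def find_free_port():
    """Ask the OS for a free ephemeral port by binding to port 0,
    then return the assigned port number for use by a launcher."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # allow quick rebinding of recently used addresses
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("localhost", 0))
        return s.getsockname()[1]


port = find_free_port()
print(port)
```

There is a small race window between closing the probe socket and the launcher rebinding the port, which is one reason retry decorators like `rerun_if_address_is_in_use` exist.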

Best Practices Demonstrated

The test implementation showcases several best practices for testing distributed systems.

  • Parameterized test cases for multiple scenarios
  • Proper cleanup and logging management
  • Resilient test design with automatic retries
  • Clear assertion patterns for expected outcomes
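The "resilient test design with automatic retries" point can be sketched as a retry decorator in the spirit of `rerun_if_address_is_in_use`: rerun the test when a transient error (such as "address already in use") occurs, and re-raise once attempts are exhausted. This is a hypothetical helper, not the library's implementation.

```python
import functools


def rerun_on_error(max_attempts=3, exc_types=(OSError,)):
    """Rerun the wrapped function on transient errors, up to
    `max_attempts` times; re-raise after the final failed attempt."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_types:
                    if attempt == max_attempts:
                        raise
        return wrapper
    return decorator


calls = {"n": 0}


@rerun_on_error(max_attempts=3)
def flaky():
    # fails twice with a transient error, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise OSError("address already in use")
    return "ok"


print(flaky())  # ok
```

Scoping `exc_types` narrowly matters: retrying on broad exception types would silently mask genuine assertion failures.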

hpcaitech/colossalai

tests/test_device/test_search_logical_device_mesh.py

import pytest

from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_alpha_beta(rank, world_size, port, physical_devices):
    # silence logger output duplicated across spawned worker processes
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    profiler = AlphaBetaProfiler(physical_devices)
    best_logical_mesh = profiler.search_best_logical_mesh()

    # each physical device list has a known-optimal logical mesh
    if physical_devices == [0, 1, 2, 3]:
        assert best_logical_mesh == [[0, 1], [2, 3]]
    elif physical_devices == [0, 3]:
        assert best_logical_mesh == [[0, 3]]

@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
@pytest.mark.dist
@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]])
@rerun_if_address_is_in_use()
def test_profile_alpha_beta(physical_devices):
    spawn(check_alpha_beta, 4, physical_devices=physical_devices)


if __name__ == "__main__":
    test_profile_alpha_beta()