
Validating Distributed Convolutional Memory Management in ColossalAI

A test suite for validating convolutional operations in ColossalAI's auto-parallel functionality. The file compares estimated and actual memory usage for both the nn.Conv2d module and the functional conv2d implementation in a distributed environment.

Test Coverage Overview

The suite covers the memory behaviour of convolutional operations on a 4-GPU distributed setup.

Key areas tested include:
  • Conv2d module memory testing with a configurable bias flag
  • Functional conv2d memory validation
  • Strategy matching across distributed device meshes
  • Memory estimation accuracy verification (illustrated conceptually in the sketch after this list)
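
The last point, checking an analytical estimate against measured usage, can be illustrated with plain PyTorch memory counters. This is a minimal conceptual sketch, not the repository's mem_test_for_node_strategy utility; the helper name check_estimate and the 10% tolerance are illustrative assumptions.

import torch
import torch.nn as nn


def check_estimate(model: nn.Module, inp: torch.Tensor, estimated_bytes: int, tol: float = 0.1):
    """Compare the peak memory of one forward/backward pass against an analytical estimate."""
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()
    out = model(inp)          # model and inp are assumed to already live on the GPU
    out.sum().backward()
    measured = torch.cuda.max_memory_allocated() - baseline
    # Flag estimates that deviate from the measured peak by more than the tolerance.
    assert abs(measured - estimated_bytes) <= tol * measured, (measured, estimated_bytes)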

Implementation Analysis

The tests combine pytest markers with ColossalAI's distributed testing utilities: each test spawns four worker processes, builds a 2×2 device mesh (as sketched after the list below), and checks 16 sharding strategies for the target convolution node in both the module and functional variants.

Technical patterns include:
  • Dynamic process group initialization
  • NCCL backend communication
  • Distributed memory tracking
  • Strategy enumeration and validation
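
The 2×2 mesh construction used by both worker functions in the file, shown in isolation; the comments are explanatory additions.

import torch
from colossalai.device.device_mesh import DeviceMesh

# Four physical GPUs (ranks 0-3) arranged as a 2x2 logical mesh; the two mesh
# axes give the strategy generator two dimensions to shard the convolution over.
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# init_process_group=True builds the per-axis process groups on top of the
# NCCL backend initialized earlier by launch().
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)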

Technical Details

Testing infrastructure includes:
  • PyTest with distributed markers
  • Custom environment flags (AUTO_PARALLEL)
  • DeviceMesh configuration (4 GPUs)
  • Memory testing utilities
  • Process spawning with port management (see the sketch after this list)
  • NCCL backend for distributed communication
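
A minimal sketch of the decorator stack and spawn call shared by both test entrypoints in the file; _worker_fn is a placeholder standing in for one of the file's worker functions.

import pytest

from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn


def _worker_fn(rank, world_size, port):
    # Placeholder: in the real tests this is _conv_module_mem_test or _conv_function_mem_test.
    pass


@run_on_environment_flag(name="AUTO_PARALLEL")  # skipped unless the AUTO_PARALLEL env flag is set
@pytest.mark.dist                               # collected only in distributed test runs
@rerun_if_address_is_in_use()                   # retried on a fresh port if the address is taken
def test_example():
    # spawn launches 4 processes and passes (rank, world_size, port) to each worker.
    spawn(_worker_fn, 4)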

Best Practices Demonstrated

The test implementation showcases robust testing practices for distributed systems.

Notable practices include:
  • Proper resource cleanup and logger management
  • Automatic port conflict resolution
  • Parameterized worker functions (the module test takes a bias flag; see the sketch after this list)
  • Comprehensive memory tracking
  • Clear separation of module and functional testing
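
The bias sweep is not explicit in the file as shown (the module entrypoint defaults to bias=False). A hedged sketch of how both settings could be covered with standard pytest parameterization, assuming the worker function is importable from the test module, might look like this.

import pytest

from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.test_conv_metainfo import _conv_module_mem_test


@run_on_environment_flag(name="AUTO_PARALLEL")
@pytest.mark.dist
@pytest.mark.parametrize("bias", [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_mem_with_and_without_bias(bias):
    # bias is forwarded through spawn to the worker and on to nn.Conv2d.
    spawn(_conv_module_mem_test, 4, bias=bias)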

hpcaitech/colossalai

tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py

            
import pytest
import torch
import torch.nn as nn

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy


class ConvFunctionModule(nn.Module):
    def __init__(self, in_channels=4, out_channels=64, kernel_size=3):
        super().__init__()
        self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

    def forward(self, input):
        return nn.functional.conv2d(input, self.conv_weight)


def _conv_module_mem_test(rank, world_size, port, bias):
    """This function is for conv memory test
    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL

    Args:
    Args:
        rank: device rank
        bias: indicate whether conv module need bias
        world_size: number of devices
        port: port for initializing process group
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda()
    input = torch.rand(4, 4, 64, 64).cuda()
    input.requires_grad = True
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # index of target node in computation graph
    node_index = 1
    # total number of target node strategies
    strategy_number = 16
    mem_test_for_node_strategy(
        rank=rank,
        model=model,
        device_mesh=device_mesh,
        node_index=node_index,
        strategy_number=strategy_number,
        input_args=[input],
        meta_arg_names=["input"],
    )


@run_on_environment_flag(name="AUTO_PARALLEL")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_meta_concrete_info_match(bias=False):
    spawn(_conv_module_mem_test, 4, bias=bias)


def _conv_function_mem_test(rank, world_size, port):
    """This function is for conv function memory test
    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL

    Args:
        rank: device rank
        bias: indicate whether conv module need bias
        world_size: number of devices
        port: port for initializing process group
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    model = ConvFunctionModule().cuda()
    input = torch.rand(4, 4, 64, 64).cuda()
    input.requires_grad = True
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # index of target node in computation graph
    node_index = 2
    # total number of target node strategies
    strategy_number = 16
    mem_test_for_node_strategy(
        rank=rank,
        model=model,
        device_mesh=device_mesh,
        node_index=node_index,
        strategy_number=strategy_number,
        input_args=[input],
        meta_arg_names=["input"],
    )


@run_on_environment_flag(name="AUTO_PARALLEL")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_function_concrete_info_match():
    spawn(_conv_function_mem_test, 4)


if __name__ == "__main__":
    # test_conv_meta_concrete_info_match()
    test_conv_function_concrete_info_match()