
Validating Normalization Layer Memory Patterns in ColossalAI

This test suite validates the memory usage and meta information for normalization layers in ColossalAI’s auto-parallel functionality. It focuses on testing BatchNorm2d and LayerNorm implementations with different tensor shapes and device configurations.

Test Coverage Overview

The test suite provides comprehensive coverage of normalization layer implementations:
  • BatchNorm2d memory testing in a distributed, multi-GPU environment
  • LayerNorm meta information verification with different tensor shapes
  • Memory allocation and peak usage measurement for forward and backward passes (see the sketch after this list)
  • Strategy validation across multiple devices in a mesh configuration
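The forward and backward measurements follow PyTorch's standard CUDA memory-delta pattern; a minimal, self-contained sketch of that pattern (the module and input shape mirror the BatchNorm test below but are otherwise illustrative):

import torch
import torch.nn as nn

# Illustrative module and input shape.
module = nn.BatchNorm2d(128).cuda()
data = torch.rand(4, 128, 64, 64, device="cuda", requires_grad=True)

# Forward pass: allocated delta and peak delta relative to a fresh baseline.
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
output = module(data)
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

# Backward pass: reset the peak counter and measure against a new baseline.
upstream_grad = torch.rand_like(output)
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
torch.autograd.backward(output, upstream_grad)
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0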

Implementation Analysis

The testing approach systematically validates memory patterns and meta information:
  • Utilizes DeviceMesh for distributed testing scenarios
  • Implements parameterized testing for different tensor shapes
  • Compares estimated vs. actual memory usage (see the sketch after this list)
  • Validates both forward and backward pass behaviors
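The estimated side of that comparison queries ColossalAI's meta_register with meta-device tensors, as the LayerNorm test below does; a condensed sketch of that flow (the [64, 128] shape is illustrative):

import torch
from colossalai.auto_parallel.meta_profiler import meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType

shape = [64, 128]  # illustrative; the test parameterizes over [256, 1024] and [1024, 256]

# Meta-device tensors carry only shape and dtype, so building them allocates no real memory.
input_data = OperationData(name="input", type=OperationDataType.ARG, data=torch.rand(*shape, device="meta"))
output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=torch.rand(*shape, device="meta"))
weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=torch.rand(shape[1], device="meta"))
bias_data = OperationData(name="bias", type=OperationDataType.PARAM, data=torch.rand(shape[1], device="meta"))

# The registered meta function returns cost estimates that the test then prints
# alongside the CUDA memory deltas measured on a real LayerNorm module.
meta_func = meta_register.get(torch.nn.LayerNorm)
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(
    input_data, output_data, weight_data, bias_data, inplace=False
)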

Technical Details

Key technical components and configurations:
  • PyTorch version requirement: >= 1.12.0 for aten level operations
  • NCCL backend for distributed communication
  • 2×2 device mesh configuration over four GPUs (set up as sketched after this list)
  • Meta device usage for estimation
  • CUDA memory tracking utilities
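For the distributed BatchNorm test, each spawned process joins an NCCL process group and the four ranks are mapped onto a 2×2 logical mesh; a sketch of that per-process setup, following the test below (the helper name is illustrative):

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers

def init_2x2_mesh(rank, world_size, port):
    # One process per GPU: join the NCCL process group first...
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # ...then arrange the 4 physical devices as a 2x2 logical device mesh.
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    return DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

In the test itself, spawn(_batchnorm_module_mem_test, 4) launches one such process per device.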

Best Practices Demonstrated

The test implementation showcases several testing best practices:
  • Proper separation of test configurations and execution logic
  • Comprehensive memory tracking and validation
  • Environment-specific test filtering (illustrated after this list)
  • Reusable test utilities and helper functions
  • Automatic reruns on address-in-use errors and explicit version compatibility checks
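Both the environment gating and the version gating live at the decorator level, so the test bodies stay free of branching; a condensed sketch of the pattern used in this file (test names are illustrative, bodies elided):

import pytest
import torch
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import rerun_if_address_is_in_use

@run_on_environment_flag(name="AUTO_PARALLEL")  # runs only when the AUTO_PARALLEL flag is set
@pytest.mark.dist
@rerun_if_address_is_in_use()  # retries if the chosen port is already bound
def test_distributed_case():
    ...

@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations")
def test_version_gated_case():
    ...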

hpcaitech/colossalai

tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py

import pytest
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results

if torch.__version__ >= "1.12.0":
    from colossalai.auto_parallel.meta_profiler import meta_register


def _batchnorm_module_mem_test(rank, world_size, port):
    """This function is for batchnorm memory test
    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL

    Args:
        rank: device rank
        bias: indicate whether conv module need bias
        world_size: number of devices
        port: port for initializing process group
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    model = nn.Sequential(nn.BatchNorm2d(128)).cuda()
    input = torch.rand(4, 128, 64, 64).cuda()
    input.requires_grad = True
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # index of target node in computation graph
    node_index = 1
    # total number of target node strategies
    strategy_number = 9
    mem_test_for_node_strategy(
        rank=rank,
        model=model,
        device_mesh=device_mesh,
        node_index=node_index,
        strategy_number=strategy_number,
        input_args=[input],
        meta_arg_names=["input"],
    )


@run_on_environment_flag(name="AUTO_PARALLEL")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_batchnorm_meta_concrete_info_match():
    spawn(_batchnorm_module_mem_test, 4)


@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations")
@parameterize(
    "tensor_shape",
    [
        [256, 1024],
        [1024, 256],
    ],
)
def test_layernorm_meta_info(tensor_shape):
    meta_func = meta_register.get(torch.nn.LayerNorm)

    # construct input
    input_tensor = torch.rand(*tensor_shape, device="meta")
    output_tensor = torch.rand(*tensor_shape, device="meta")
    weight_tensor = torch.rand(tensor_shape[1], device="meta")
    bias_tensor = torch.rand(tensor_shape[1], device="meta")

    # construct operation data
    input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor)

    output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor)

    weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor)

    bias_data = OperationData(name="bias", type=OperationDataType.PARAM, data=bias_tensor)

    # construct args and kwargs
    args = [input_data, output_data, weight_data, bias_data]
    kwargs = {"inplace": False}

    # estimated results
    compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)

    # actual results
    input_real_tensor = torch.rand(*tensor_shape, device="cuda:0")

    input_real_tensor.requires_grad = True

    ln_module = torch.nn.LayerNorm(tensor_shape[1]).cuda()

    # fwd
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    output_real_tensor = ln_module(input_real_tensor)
    fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    # bwd
    upstream_grad = torch.rand_like(output_real_tensor)
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    torch.autograd.backward(output_real_tensor, upstream_grad)
    bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    compute_cost: TrainCycleItem
    memory_cost: TrainCycleItem

    print_results(
        [input_real_tensor],
        [output_real_tensor],
        compute_cost,
        memory_cost,
        fwd_allocated,
        fwd_peak,
        bwd_allocated,
        bwd_peak,
    )


if __name__ == "__main__":
    test_batchnorm_meta_concrete_info_match()
    test_layernorm_meta_info()