Testing FLOP Counting Mechanisms in ColossalAI

This test suite validates the FLOP counting functionality in ColossalAI for both PyTorch modules and standalone functions. It checks that the forward and backward pass FLOP counts are positive across a range of model architectures and operations; it does not compare the counts against reference values.

Test Coverage Overview

The test suite exercises FLOP counting across several kinds of PyTorch components:

  • Tests multiple torchvision and timm model architectures
  • Validates both forward and backward pass FLOP calculations
  • Covers edge cases such as in-place ReLU, dilated max pooling, and the conditional torch.where
  • Instantiates each model under MetaTensorMode to save test time
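To make the dilated max pooling edge case concrete, the sketch below works out the output shape that operation produces for the test input, using the standard floor-mode pooling shape formula. The helper name `pool_output_size` is hypothetical and not part of ColossalAI or PyTorch; the parameter values are the ones listed in the test's `odd_cases`.

```python
def pool_output_size(size, kernel_size, stride, padding, dilation):
    """Output length along one spatial dimension for floor-mode pooling."""
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# Parameters from the max_pool2d entry in odd_cases, applied to a 224 x 224 input.
out = pool_output_size(224, kernel_size=3, stride=2, padding=1, dilation=2)

# Batch and channel dimensions are unchanged by pooling.
output_shape = (2, 3, out, out)
print(output_shape)  # (2, 3, 111, 111)
```

With dilation 2 the 3 x 3 window covers a 5 x 5 region, which is why the result is 111 rather than the 112 that the same call without dilation would give.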

Implementation Analysis

The tests use pytest’s parametrization to run the same check over many models and operations. Models are instantiated inside MetaTensorMode to save time, and modules and individual functions are covered by separate test functions.

  • Parametrized tests for scalable test cases
  • Version-specific test skipping for compatibility
  • Custom test configurations for different PyTorch operations
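The version gate only works if versions are compared numerically rather than as strings. The sketch below illustrates the pitfall using only the standard library. `release_tuple` is a hypothetical, simplified helper written for this illustration; the test file itself relies on `packaging.version.parse`, which also handles pre-release and local version segments properly.

```python
import re


def release_tuple(version_string):
    """Hypothetical helper: leading numeric release as ints, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"no numeric release found in {version_string!r}")
    return tuple(int(part) for part in match.group().split("."))


MINIMUM = release_tuple("1.12.0")

# Plain string comparison is lexicographic, so "1.9.0" sorts after "1.12.0".
string_says_old = "1.9.0" < "1.12.0"                  # False: the wrong answer
tuple_says_old = release_tuple("1.9.0") < MINIMUM     # True: 9 < 12 numerically

# Build suffixes are simply ignored by this simplified parser.
new_enough = release_tuple("2.1.0+cu118") >= MINIMUM  # True
```

A string comparison would therefore run the tests on PyTorch 1.9.0 instead of skipping them, which is the failure mode that parsing the version avoids.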

Technical Details

  • Testing Framework: pytest
  • Key Dependencies: torch, torchvision, timm, packaging
  • Test Environment: PyTorch >= 1.12.0 (the tests are skipped on older versions)
  • Custom Components: MetaTensorMode, flop_count from ColossalAI
  • Test Data: Random tensors of shape (2, 3, 224, 224)
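For a sense of the magnitudes involved with a (2, 3, 224, 224) input, the sketch below hand-counts the multiply-accumulates of a single convolution. The layer parameters are those of the first convolution in torchvision's ResNet-18 (7 x 7 kernel, 3 to 64 channels, stride 2, padding 3, no bias). The helper names are hypothetical, and counting conventions differ (one multiply-accumulate may be counted as one or two FLOPs), so the figure is illustrative and is not claimed to match what `flop_count` reports.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Output length along one spatial dimension of an undilated convolution."""
    return (size + 2 * padding - kernel_size) // stride + 1


def conv2d_macs(batch, in_channels, out_channels, kernel_size, out_height, out_width):
    """Multiply-accumulates for a dense, ungrouped 2D convolution without bias."""
    # Each output element is a dot product over in_channels * k * k inputs.
    macs_per_output = in_channels * kernel_size * kernel_size
    num_outputs = batch * out_channels * out_height * out_width
    return macs_per_output * num_outputs


out = conv_output_size(224, kernel_size=7, stride=2, padding=3)
macs = conv2d_macs(batch=2, in_channels=3, out_channels=64,
                   kernel_size=7, out_height=out, out_width=out)
print(out, macs)  # 112 236027904
```

Even this one layer accounts for roughly 2.4e8 multiply-accumulates in the forward pass, so a check for a strictly positive count is a very loose lower bound on the expected result.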

Best Practices Demonstrated

The test suite exemplifies several testing best practices in deep learning frameworks:

  • Proper version checking and test skipping
  • Assertion messages that name the failing model or function and report its FLOP count
  • Modular test organization for modules and functions
  • Efficient test parameterization
  • Clear separation of test cases and configurations

hpcaitech/colossalai

tests/test_analyzer/test_subclasses/test_flop_tensor.py

            
import pytest
import torch
import torch.nn.functional as F
import torchvision.models as tm
from packaging import version

from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

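try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except ImportError:
    # Tolerate a missing analyzer package at import time; the tests below are version-gated.
    pass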


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@pytest.mark.parametrize("m", tm_models + tmm_models)
def test_flop_count_module(m):
    x = torch.rand(2, 3, 224, 224)
    with MetaTensorMode():  # save time for testing
        module = m()
    rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
    assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}"
    assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}"


odd_cases = [
    (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}),
    (
        F.max_pool2d,
        (torch.rand(2, 3, 224, 224, requires_grad=True),),
        {"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2},
    ),
    (
        torch.where,
        (
            torch.rand(2, 3, 224, 224) > 0.5,
            torch.rand(2, 3, 224, 224, requires_grad=True),
            torch.rand(2, 3, 224, 224, requires_grad=True),
        ),
        {},
    ),
]


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@pytest.mark.parametrize("func, args, kwargs", odd_cases)
def test_flop_count_function(func, args, kwargs):
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}"
    assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}"


if __name__ == "__main__":
    test_flop_count_module(tm.resnet18)
    test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True})