
Testing Meta Trace Implementation for Vision Models in ColossalAI

This test suite validates ColossalAI’s meta tracing functionality across a range of model architectures from the torchvision and timm libraries. It verifies that popular vision models trace correctly when their inputs are created on PyTorch’s meta device, which avoids allocating real tensor memory during tracing.

Test Coverage Overview

The test suite provides extensive coverage of meta tracing capabilities across multiple vision model architectures; a minimal example of the traced call appears after the list below.

Key areas tested include:
  • TorchVision models including VGG, ResNet, DenseNet, and MobileNet variants
  • TIMM models including ResNeSt, BEiT, ViT, ConvNeXt and Swin Transformer
  • Meta device compatibility and tracing behavior
  • Model initialization and forward pass validation
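
For a concrete sense of what each test exercises, the core tracing call for a single torchvision model looks like this. This is a minimal sketch mirroring the full test file shown at the end of this page, assuming a torch version new enough for meta tensors (1.12+); the meta_trace argument order follows that file:

import torch
import torchvision.models as tm

from colossalai.fx import meta_trace

# Instantiate one model and trace it with a meta-device input so no real
# activation memory is allocated during the forward pass.
model = tm.resnet18()
data = torch.rand(1000, 3, 224, 224, device="meta")
meta_trace(model, torch.device("cpu"), data)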

Implementation Analysis

The implementation iterates over predefined lists of model constructors, one list per library, to systematically verify meta tracing across architectures. Each test instantiates a model, builds an input tensor on the meta device, and passes both to ColossalAI’s meta_trace, so tracing proceeds without allocating real activation memory (a condensed sketch of these patterns follows the list below).

Key patterns include:
  • Conditional test execution based on torch version compatibility
  • Cache clearing between test runs
  • Standardized input tensor creation
  • Consistent device handling across models
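
A condensed sketch of how these patterns combine in a single test; the function name here is illustrative, and the full file below applies the same decorators to loops over the torchvision and timm model lists:

import pytest
import torch
import torchvision.models as tm

from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.testing import clear_cache_before_run

# meta_trace is only importable when the installed torch supports meta tensors.
if is_compatible_with_meta():
    from colossalai.fx import meta_trace


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_single_model_trace():
    model = tm.vgg11()
    data = torch.rand(1000, 3, 224, 224, device="meta")
    meta_trace(model, torch.device("cpu"), data)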

Technical Details

Testing tools and configuration (a short meta-device illustration follows the list):
  • PyTest for test orchestration
  • ColossalAI’s meta_trace and compatibility utilities
  • TorchVision and TIMM model libraries
  • PyTorch meta device functionality
  • Cache management decorators
  • Standard input shape of (1000, 3, 224, 224)
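
The meta device behavior these tests rely on is standard PyTorch (1.12+) functionality and can be checked independently of ColossalAI:

import torch

# A meta tensor records shape and dtype only; it allocates no storage,
# which is why a batch size of 1000 stays cheap during tracing.
data = torch.rand(1000, 3, 224, 224, device="meta")
print(data.shape)    # torch.Size([1000, 3, 224, 224])
print(data.is_meta)  # True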

Best Practices Demonstrated

The test suite exemplifies several testing best practices in deep learning model validation; an alternative, parametrized test layout is sketched after the list below.

Notable practices include:
  • Comprehensive model coverage across libraries
  • Proper resource management with cache clearing
  • Version compatibility checks
  • Modular test organization
  • Consistent input tensor configuration
  • Proper exception handling for incompatible environments
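
As a design note, the loop-based tests below report each library as a single pass-or-fail case. A hypothetical alternative, not what the repository does, would parametrize over the model constructors so each architecture shows up as its own test case; the cache-clearing decorator is omitted here for brevity:

import pytest
import torch
import torchvision.models as tm

from colossalai.fx._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
    from colossalai.fx import meta_trace

# Illustrative subset; the real file covers a longer list of constructors.
tm_models = [tm.vgg11, tm.resnet18, tm.densenet121]


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@pytest.mark.parametrize("model_fn", tm_models)
def test_torchvision_model_trace(model_fn):
    model = model_fn()
    data = torch.rand(1000, 3, 224, 224, device="meta")
    meta_trace(model, torch.device("cpu"), data)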

hpcaitech/colossalai

tests/test_fx/test_meta/test_meta_trace.py

import pytest
import timm.models as tmm
import torch
import torchvision.models as tm

from colossalai.fx._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
    from colossalai.fx import meta_trace

from colossalai.testing import clear_cache_before_run

tm_models = [
    tm.vgg11,
    tm.resnet18,
    tm.densenet121,
    tm.mobilenet_v3_small,
    tm.resnext50_32x4d,
    tm.wide_resnet50_2,
    tm.regnet_x_16gf,
    tm.mnasnet0_5,
    tm.efficientnet_b0,
]

tmm_models = [
    tmm.resnest.resnest50d,
    tmm.beit.beit_base_patch16_224,
    tmm.cait.cait_s24_224,
    tmm.efficientnet.efficientnetv2_m,
    tmm.resmlp_12_224,
    tmm.vision_transformer.vit_base_patch16_224,
    tmm.deit_base_distilled_patch16_224,
    tmm.convnext.convnext_base,
    tmm.vgg.vgg11,
    tmm.dpn.dpn68,
    tmm.densenet.densenet121,
    tmm.rexnet.rexnet_100,
    tmm.swin_transformer.swin_base_patch4_window7_224,
]


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_torchvision_models_trace():
    for m in tm_models:
        model = m()
        data = torch.rand(1000, 3, 224, 224, device="meta")
        meta_trace(model, torch.device("cpu"), data)


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_timm_models_trace():
    for m in tmm_models:
        model = m()
        data = torch.rand(1000, 3, 224, 224, device="meta")
        meta_trace(model, torch.device("cpu"), data)


if __name__ == "__main__":
    test_torchvision_models_trace()
    test_timm_models_trace()