
Testing Meta Tensor Backward Pass Computation in ColossalAI

This test suite checks that the backward pass runs on meta tensors for a range of deep learning models in ColossalAI. It covers both TorchVision and TIMM architectures. Meta tensors record shape and dtype but allocate no storage, so the tests can push very large inputs through forward and backward passes without consuming memory for them.
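To illustrate the idea with plain PyTorch, the minimal sketch below builds a small linear layer on the meta device and runs a backward pass through it. It deliberately leaves out ColossalAI's MetaTensor wrapper, uses arbitrary layer sizes, and assumes a recent PyTorch release (2.x) in which linear layers and their gradients have meta-device kernels.

```python
import torch
import torch.nn as nn

# Assumes a PyTorch version with meta kernels for linear layers (e.g. 2.x).
# Parameters on the meta device carry shapes and dtypes only, no storage.
layer = nn.Linear(4, 2, device="meta")

# A "batch" of one million rows costs no memory because nothing is materialized.
x = torch.rand(1_000_000, 4, device="meta")

loss = layer(x).sum()  # forward pass propagates shapes only
loss.backward()        # backward pass also runs on metadata alone

grad = layer.weight.grad  # a meta tensor: correct shape, no values
print(x.is_meta, loss.dim(), tuple(grad.shape), grad.is_meta)
```

Reading grad.item() or any other value would fail here, which is why this style of test can only confirm that shapes and dtypes flow through the backward pass, not that gradient values are correct.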

Test Coverage Overview

The test suite provides comprehensive coverage of backward pass operations across multiple model architectures:

  • Tests 9 TorchVision models including VGG, ResNet, DenseNet families
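  • Tests 13 TIMM models, including vision transformers (ViT, DeiT, BEiT, CaiT, Swin), ConvNeXt, ResMLP, and convolutional networks such as ResNeSt, DPN, and ReXNet
  • Checks that .sum().backward() completes on meta-tensor inputs; only shapes and dtypes are propagated, so no numerical gradient values are compared
  • Skips both tests when is_compatible_with_meta() is false, i.e. on torch versions older than 1.12.0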
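The skip condition comes down to a version comparison. As a hypothetical sketch (meets_min_version is an illustrative helper, not ColossalAI's is_compatible_with_meta(), whose implementation is not shown here), a minimal check against 1.12.0 could parse the numeric prefix of a version string such as torch.__version__:

```python
import re


def meets_min_version(version: str, minimum: tuple = (1, 12, 0)) -> bool:
    """Hypothetical helper: True if `version` (e.g. "1.13.1+cu117") is >= `minimum`."""
    # Keep only the leading "major.minor[.patch]" digits; ignore local/pre-release suffixes.
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0.
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts >= minimum


print(meets_min_version("1.11.0"))       # False
print(meets_min_version("1.12.0"))       # True
print(meets_min_version("2.1.2+cu121"))  # True
```

Tuple comparison keeps the logic short; the trade-off is that a pre-release such as "1.12.0a0" is treated as the final 1.12.0 release.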

Implementation Analysis

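The suite keeps the models in two plain Python lists (tm_models and tmm_models) and loops over each list inside one test function per library. It does not use pytest's @pytest.mark.parametrize, so the first model that fails stops the loop and the remaining models in that list go untested. For each model, the test instantiates it, creates a synthetic input of shape (100000, 3, 224, 224) on the meta device, wraps the input in a MetaTensor, reduces the model output to a scalar with .sum(), and calls .backward() on that scalar. Because the input lives on the meta device, the 100,000-image batch never occupies real memory.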
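A quick back-of-the-envelope calculation shows why the input has to stay on the meta device. Assuming float32 elements (4 bytes each, the default dtype of torch.rand), materializing the (100000, 3, 224, 224) batch would require:

```python
# Size of the synthetic input if it were materialized instead of kept on the meta device.
batch, channels, height, width = 100_000, 3, 224, 224
bytes_per_element = 4  # assumption: float32, torch.rand's default dtype

num_elements = batch * channels * height * width
num_bytes = num_elements * bytes_per_element

print(f"{num_elements:,} elements")    # 15,052,800,000 elements
print(f"{num_bytes / 2**30:.1f} GiB")  # 56.1 GiB
```

On the meta device the same tensor reports this shape without allocating storage for it.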

Both test functions are wrapped in ColossalAI's @clear_cache_before_run() decorator, which clears cached memory before the test body runs, so each test starts from a clean state regardless of what ran before it.
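The decorator pattern behind @clear_cache_before_run() can be sketched as follows. This is a hypothetical stand-in (clear_cache_before_run_sketch), not ColossalAI's implementation. It assumes that "clearing the cache" means running Python's garbage collector and, when CUDA is available, releasing PyTorch's cached GPU memory. Like the original, it is a decorator factory, so it is applied with parentheses.

```python
import functools
import gc

try:
    import torch
except ImportError:  # torch is optional for this sketch
    torch = None


def clear_cache_before_run_sketch():
    """Hypothetical stand-in for a cache-clearing decorator factory."""
    def decorator(fn):
        @functools.wraps(fn)  # keep the test's name so pytest reports it correctly
        def wrapper(*args, **kwargs):
            gc.collect()  # drop unreachable Python objects left by earlier tests
            if torch is not None and torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached, unused CUDA blocks
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@clear_cache_before_run_sketch()
def run_example_test():
    return "ran"


print(run_example_test(), run_example_test.__name__)
```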

Technical Details

  • Testing Framework: pytest
  • Key Dependencies: torch, timm, torchvision
  • Meta Tensor Configuration: input created with device="meta", then wrapped as MetaTensor(data, fake_device=torch.device("cpu"))
  • Cache Management: @clear_cache_before_run decorator
  • Version Compatibility: @pytest.mark.skipif(not is_compatible_with_meta(), ...) to skip on torch versions older than 1.12.0

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Proper test isolation through cache clearing
  • Version compatibility checks
  • Separate test functions for the TorchVision and TIMM model families
  • Memory-efficient testing using meta tensors
  • Broad coverage: 22 models spanning convolutional, transformer, and MLP-based architectures
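The file groups models by family but iterates over each list in a loop, so one failing model hides results for the models after it. A hedged sketch of a parametrized alternative follows. It uses two small hypothetical nn.Sequential models in place of the TorchVision/TIMM constructors and plain PyTorch meta tensors instead of ColossalAI's MetaTensor, and it assumes PyTorch 2.x and pytest are installed.

```python
import pytest
import torch
import torch.nn as nn

# Hypothetical small model factories standing in for tm.* / tmm.* constructors.
MODEL_FACTORIES = {
    "linear": lambda: nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)),
    "mlp": lambda: nn.Sequential(
        nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU(), nn.Linear(64, 10)
    ),
}


@pytest.mark.parametrize("name", sorted(MODEL_FACTORIES))
def test_backward_on_meta(name):
    # Move parameters to the meta device so no real weight or gradient memory is used.
    model = MODEL_FACTORIES[name]().to("meta")
    data = torch.rand(100_000, 3, 32, 32, device="meta")
    model(data).sum().backward()
    # Every parameter should receive a meta gradient with the parameter's shape.
    for param in model.parameters():
        assert param.grad is not None
        assert param.grad.is_meta
        assert param.grad.shape == param.shape
```

Run with pytest -v: each model then reports as its own test case (for example test_backward_on_meta[mlp]), so a failure in one model no longer hides the others.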

hpcaitech/colossalai

tests/test_fx/test_meta/test_backward.py

import pytest
import timm.models as tmm
import torch
import torchvision.models as tm

from colossalai.fx._compatibility import is_compatible_with_meta

# MetaTensor is only importable when torch supports meta tensors (>= 1.12.0).
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

from colossalai.testing import clear_cache_before_run

tm_models = [
    tm.vgg11,
    tm.resnet18,
    tm.densenet121,
    tm.mobilenet_v3_small,
    tm.resnext50_32x4d,
    tm.wide_resnet50_2,
    tm.regnet_x_16gf,
    tm.mnasnet0_5,
    tm.efficientnet_b0,
]

tmm_models = [
    tmm.resnest.resnest50d,
    tmm.beit.beit_base_patch16_224,
    tmm.cait.cait_s24_224,
    tmm.efficientnet.efficientnetv2_m,
    tmm.resmlp_12_224,
    tmm.vision_transformer.vit_base_patch16_224,
    tmm.deit_base_distilled_patch16_224,
    tmm.convnext.convnext_base,
    tmm.vgg.vgg11,
    tmm.dpn.dpn68,
    tmm.densenet.densenet121,
    tmm.rexnet.rexnet_100,
    tmm.swin_transformer.swin_base_patch4_window7_224,
]


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_torchvision_models():
    for m in tm_models:
        model = m()
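        # Meta tensors hold shape/dtype only, so this 100,000-image batch allocates no storage.
        data = torch.rand(100000, 3, 224, 224, device="meta")
        # Wrap as a MetaTensor reporting a CPU device, reduce the output to a scalar, and backprop.
        model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward()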


@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_timm_models():
    for m in tmm_models:
        model = m()
        data = torch.rand(100000, 3, 224, 224, device="meta")
        model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward()


if __name__ == "__main__":
    test_torchvision_models()
    test_timm_models()