
Testing Deep Learning Model Zoo Implementation in ColossalAI

This file defines the model zoo backing a validation suite for deep learning architectures drawn from the torchvision and timm libraries. It supports testing model compatibility and functionality across neural network families including CNNs, Transformers, and hybrid models.

Test Coverage Overview

The test suite covers a broad range of popular deep learning architectures, from classic CNNs through recent Transformer variants.

Key areas tested include:
  • Classic CNN architectures (AlexNet, ResNet, VGG)
  • Modern efficient architectures (MobileNet, EfficientNet)
  • Transformer-based models (ViT, BEiT)
  • Hybrid architectures (ConvNeXt)
Notable edge cases are handled by exclusion rather than special-casing: models with incompatible input requirements or known tracing failures are commented out of the lists, each with a brief reason. The sketch below shows how the remaining entries can be exercised uniformly.
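
Every entry in the zoo is a zero-argument constructor. A minimal sketch (not part of the repository) that instantiates one model from each category above and runs the standard input through it:

import timm.models as tmm
import torch
import torchvision.models as tm

x = torch.rand(1, 3, 224, 224)  # the zoo's standard input shape

# one representative per category: classic CNN, hybrid, Transformer
for factory in (tm.resnet18, tm.convnext_base, tmm.beit_base_patch16_224):
    model = factory().eval()  # random weights; no download required
    with torch.no_grad():
        out = model(x)
    print(factory.__name__, tuple(out.shape))  # e.g. resnet18 (1, 1000)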

Implementation Analysis

The testing approach draws model-zoo definitions from both the torchvision and timm frameworks, enabling architecture validation across two independent implementations.

Implementation features:
  • Standardized input shape (batch_size, 3, 224, 224)
  • Selective model inclusion with documented exclusions
  • Parallel testing of equivalent architectures across frameworks (a cross-framework sketch follows this list)
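
Several architectures, such as vgg11 and wide_resnet50_2, appear in both lists, so equivalent implementations can be checked side by side. A hedged sketch of such a comparison; the (batch, 1000) logits shape is an assumption based on both libraries' ImageNet defaults:

import timm.models as tmm
import torch
import torchvision.models as tm

x = torch.rand(2, 3, 224, 224)

# pairs present in both zoos
pairs = [(tm.vgg11, tmm.vgg11), (tm.wide_resnet50_2, tmm.wide_resnet50_2)]
for tv_factory, timm_factory in pairs:
    with torch.no_grad():
        tv_out = tv_factory().eval()(x)
        timm_out = timm_factory().eval()(x)
    # equivalent architectures should agree on the output contract
    assert tv_out.shape == timm_out.shape == (2, 1000)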

Technical Details

Testing infrastructure includes:
  • torchvision.models for PyTorch native architectures
  • timm.models for additional model implementations
  • Fixed input dimensions (224×224)
  • Batched inputs via a leading batch_size dimension
  • Commented exclusions for problematic models (a hypothetical consumer of the lists is sketched below)
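
The zoo itself defines no test functions; consuming test modules iterate over the lists. A hypothetical pytest consumer, where the import path and test name are assumptions rather than the repository's actual code:

import pytest
import torch

from zoo import tm_models, tmm_models  # assumes zoo.py is importable

@pytest.mark.parametrize('factory', tm_models + tmm_models)
def test_forward(factory):
    model = factory().eval()
    data = torch.rand(1, 3, 224, 224)  # fixed 224x224 dimensions
    with torch.no_grad():
        out = model(data)
    assert out is not None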

Best Practices Demonstrated

The implementation showcases several testing best practices:

  • Clear model categorization and organization
  • Explicit documentation of excluded models with reasons (the tracing rationale is illustrated after this list)
  • Consistent input specifications
  • Cross-framework validation approach
  • Modular test structure for easy maintenance
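
The exclusions marked "fx bad case" point at the file's purpose: the zoo feeds the graph-tracing tests under tests/test_analyzer/test_fx. As an illustration only, vanilla torch.fx can trace most of the included torchvision models; ColossalAI's analyzer uses its own tracer, which is not shown here:

import torch.fx
import torchvision.models as tm

# symbolic_trace succeeds for zoo members such as resnet18;
# models that break tracing are the ones commented out of the lists
traced = torch.fx.symbolic_trace(tm.resnet18())
print(traced.graph)  # the captured forward-pass graph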

hpcaitech/colossalai

tests/test_analyzer/test_fx/zoo.py

import timm.models as tmm
import torchvision.models as tm

# input shape: (batch_size, 3, 224, 224)
# torchvision constructors; commented-out entries are known-problematic models
tm_models = [
    tm.alexnet,
    tm.convnext_base,
    tm.densenet121,
    # tm.efficientnet_v2_s,
    # tm.googlenet,   # output bad case
    # tm.inception_v3,  # bad case
    tm.mobilenet_v2,
    tm.mobilenet_v3_small,
    tm.mnasnet0_5,
    tm.resnet18,
    tm.regnet_x_16gf,
    tm.resnext50_32x4d,
    tm.shufflenet_v2_x0_5,
    tm.squeezenet1_0,
    # tm.swin_s,  # fx bad case
    tm.vgg11,
    tm.vit_b_16,
    tm.wide_resnet50_2,
]

# timm constructors; commented-out entries are known-problematic models
tmm_models = [
    tmm.beit_base_patch16_224,
    tmm.beitv2_base_patch16_224,
    tmm.cait_s24_224,
    tmm.coat_lite_mini,
    tmm.convit_base,
    tmm.deit3_base_patch16_224,
    tmm.dm_nfnet_f0,
    tmm.eca_nfnet_l0,
    tmm.efficientformer_l1,
    # tmm.ese_vovnet19b_dw,
    tmm.gmixer_12_224,
    tmm.gmlp_b16_224,
    # tmm.hardcorenas_a,
    tmm.hrnet_w18_small,
    tmm.inception_v3,
    tmm.mixer_b16_224,
    tmm.nf_ecaresnet101,
    tmm.nf_regnet_b0,
    # tmm.pit_b_224,  # pretrained only
    # tmm.regnetv_040,
    # tmm.skresnet18,
    # tmm.swin_base_patch4_window7_224,     # fx bad case
    # tmm.tnt_b_patch16_224,    # bad case
    tmm.vgg11,
    tmm.vit_base_patch16_18x2_224,
    tmm.wide_resnet50_2,
]