Testing Deep Learning Model Zoo Implementation in ColossalAI
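This file defines the model zoo used by ColossalAI's FX analyzer tests. It contains two lists of model constructors: tm_models from torchvision and tmm_models from timm. Together they cover CNNs (AlexNet, ResNet, VGG, MobileNet, RegNet), Transformers (ViT, DeiT, BEiT, CaiT) and hybrid or MLP-style architectures (CoaT, ConViT, EfficientFormer, MLP-Mixer, gMLP). Models that currently fail, for example because torch.fx cannot trace them, stay in the lists as commented-out entries, most with a short note on the reason.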
hpcaitech/colossalai
tests/test_analyzer/test_fx/zoo.py
import timm.models as tmm
import torchvision.models as tm
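# All zoo models take inputs of shape (batch_size, 3, 224, 224).
# Commented-out entries are currently excluded; the reason is noted inline where known.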
tm_models = [
tm.alexnet,
tm.convnext_base,
tm.densenet121,
# tm.efficientnet_v2_s,
# tm.googlenet, # output bad case
# tm.inception_v3, # bad case
tm.mobilenet_v2,
tm.mobilenet_v3_small,
tm.mnasnet0_5,
tm.resnet18,
tm.regnet_x_16gf,
tm.resnext50_32x4d,
tm.shufflenet_v2_x0_5,
tm.squeezenet1_0,
# tm.swin_s, # fx bad case
tm.vgg11,
tm.vit_b_16,
tm.wide_resnet50_2,
]
tmm_models = [
tmm.beit_base_patch16_224,
tmm.beitv2_base_patch16_224,
tmm.cait_s24_224,
tmm.coat_lite_mini,
tmm.convit_base,
tmm.deit3_base_patch16_224,
tmm.dm_nfnet_f0,
tmm.eca_nfnet_l0,
tmm.efficientformer_l1,
# tmm.ese_vovnet19b_dw,
tmm.gmixer_12_224,
tmm.gmlp_b16_224,
# tmm.hardcorenas_a,
tmm.hrnet_w18_small,
tmm.inception_v3,
tmm.mixer_b16_224,
tmm.nf_ecaresnet101,
tmm.nf_regnet_b0,
# tmm.pit_b_224, # pretrained only
# tmm.regnetv_040,
# tmm.skresnet18,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm.vgg11,
tmm.vit_base_patch16_18x2_224,
tmm.wide_resnet50_2,
]
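The commented-out "fx bad case" entries suggest that some models were found to be untraceable. One way to sweep a whole list for such cases is sketched below. The function find_fx_failures and the two toy modules are hypothetical stand-ins rather than part of ColossalAI. DataDependentNet reproduces a common failure: Python control flow that depends on a tensor value, which torch.fx.symbolic_trace rejects.

```python
import torch
import torch.fx
import torch.nn as nn


def find_fx_failures(model_fns):
    """Hypothetical sweep: return {name: error} for every constructor whose
    model cannot be traced by torch.fx.symbolic_trace. Passing models are omitted."""
    failures = {}
    for fn in model_fns:
        name = getattr(fn, "__name__", repr(fn))
        try:
            torch.fx.symbolic_trace(fn().eval())
        except Exception as exc:  # fx raises TraceError; model code may raise others
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


# Toy stand-ins for zoo entries, so the sweep runs without building timm models.
class StaticNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x).mean(dim=(2, 3))


class DataDependentNet(nn.Module):
    def forward(self, x):
        # Branching on a tensor value cannot be resolved during symbolic tracing
        if x.sum() > 0:
            return x
        return -x


if __name__ == "__main__":
    print(find_fx_failures([StaticNet, DataDependentNet]))
```

Passing a real list such as find_fx_failures(tmm_models) works the same way. It instantiates every model, though, so it can be slow and memory-hungry.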