
Testing TIMM Model Architecture Integration in ColossalAI

This test suite validates various TIMM (PyTorch Image Models) architectures within the ColossalAI framework. It covers models both with and without control flow, ensuring that each architecture can be split properly and that the split model's output stays consistent with the original.

Test Coverage Overview

The test suite provides comprehensive coverage of multiple TIMM model architectures, divided into two main categories:

  • Models without control flow: ResNeSt, BEiT, CaiT, ConvMixer, EfficientNetV2, ResMLP, ViT, and DeiT
  • Models with control flow: ConvNeXt, VGG, DPN, DenseNet, ReXNet, and Swin Transformer

Each model is tested for proper splitting and output consistency using a random input tensor of shape (2, 3, 224, 224), as illustrated in the sketch below.
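
As a concrete illustration of that pattern, the following standalone sketch instantiates one of the listed models and runs a forward pass on such a random batch; it exercises the model directly rather than through the splitting utility:

import timm.models as tm
import torch

# A random batch of two 224x224 RGB images, matching the test suite.
data = torch.rand(2, 3, 224, 224)

# One architecture from the "without control flow" list.
model = tm.resnest.resnest50d()
model.eval()

with torch.no_grad():
    out = model(data)

print(out.shape)  # torch.Size([2, 1000]) with the default 1000-class head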

Implementation Analysis

The tests are organized with the PyTest framework, one function per model category. Each function follows the same pattern: instantiate a model, generate a random input tensor, and validate output consistency through the split_model_and_compare_output utility. Models with control flow receive special handling: the input is converted to a meta tensor and passed as meta_args so shape-dependent branches can be resolved during tracing, and CUDNN is configured to run deterministically so outputs can be compared exactly. Note that both test functions are currently skipped via @pytest.mark.skip because balance split v2 is not yet ready.
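
The split_model_and_compare_output helper is imported from a sibling timm_utils module and is not shown on this page. The following is a rough, illustrative stand-in for what such a utility does, assuming plain torch.fx tracing rather than ColossalAI's actual pipeline-splitting passes: trace the model, partition the graph into stages (in the real implementation), and assert that the output still matches the original:

import torch
from torch.fx import symbolic_trace

def split_model_and_compare_output(model, data, meta_args=None):
    # Illustrative stand-in: the real utility uses ColossalAI's tracer and
    # balanced-split passes to produce pipeline stages. meta_args would be
    # fed to the tracer for control-flow models; it is unused in this sketch.
    model.eval()
    gm = symbolic_trace(model)

    # A real implementation would split `gm` into stages here; this sketch
    # keeps a single stage and only verifies that the outputs agree.
    with torch.no_grad():
        expected = model(data)
        actual = gm(data)
    torch.testing.assert_close(actual, expected)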

Technical Details

Key technical components include:

  • PyTest for test organization and execution
  • TIMM library for accessing pre-implemented model architectures
  • PyTorch for tensor operations and model handling
  • Custom split_model_and_compare_output utility for validation
  • CUDNN deterministic configuration for consistent results (illustrated below)
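
The last item and the meta-argument handling described above are easy to demonstrate. A meta tensor carries shape and dtype metadata but allocates no storage, which lets a tracer reason about shape-dependent branches without running the model on real data, while the deterministic flag forces CUDNN to choose reproducible algorithms so output comparisons are stable:

import torch

# Deterministic CUDNN algorithms produce identical outputs across runs,
# which exact output comparison after splitting relies on.
torch.backends.cudnn.deterministic = True

# A meta tensor has shape and dtype but no underlying data.
data = torch.rand(2, 3, 224, 224)
meta = data.to("meta")
print(meta.shape)   # torch.Size([2, 3, 224, 224])
print(meta.device)  # meta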

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Modular test organization with separate functions for different model categories (a parametrized variant is sketched after this list)
  • Consistent input tensor dimensions across all tests
  • Proper use of PyTest decorators for test management
  • Explicit handling of deterministic behavior for control flow models
  • Comprehensive model coverage across different architecture types
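
The looping style keeps the suite compact; an equivalent, slightly more granular organization (a hypothetical variant, not the repository's code) would parametrize over the model list so each architecture reports its pass/fail status independently:

import pytest
import timm.models as tm
import torch

MODEL_LIST = [
    tm.resnest.resnest50d,
    tm.beit.beit_base_patch16_224,
]

@pytest.mark.skip("balance split v2 is not ready")
@pytest.mark.parametrize("model_cls", MODEL_LIST)
def test_timm_model_split(model_cls):
    data = torch.rand(2, 3, 224, 224)
    model = model_cls()
    # split_model_and_compare_output(model, data) would run here, as in
    # the original suite; a plain forward pass keeps the sketch
    # self-contained.
    assert model(data) is not None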

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_timm_model/test_timm.py

import pytest
import timm.models as tm
import torch
from timm_utils import split_model_and_compare_output


@pytest.mark.skip("balance split v2 is not ready")
def test_timm_models_without_control_flow():
    MODEL_LIST = [
        tm.resnest.resnest50d,
        tm.beit.beit_base_patch16_224,
        tm.cait.cait_s24_224,
        tm.convmixer.convmixer_768_32,
        tm.efficientnet.efficientnetv2_m,
        tm.resmlp_12_224,
        tm.vision_transformer.vit_base_patch16_224,
        tm.deit_base_distilled_patch16_224,
    ]

    data = torch.rand(2, 3, 224, 224)

    for model_cls in MODEL_LIST:
        model = model_cls()
        split_model_and_compare_output(model, data)


@pytest.mark.skip("balance split v2 is not ready")
def test_timm_models_with_control_flow():
    # CUDNN must pick deterministic algorithms so the split model's
    # output can be compared exactly against the original.
    torch.backends.cudnn.deterministic = True

    MODEL_LIST_WITH_CONTROL_FLOW = [
        tm.convnext.convnext_base,
        tm.vgg.vgg11,
        tm.dpn.dpn68,
        tm.densenet.densenet121,
        tm.rexnet.rexnet_100,
        tm.swin_transformer.swin_base_patch4_window7_224,
    ]

    data = torch.rand(2, 3, 224, 224)

    # Meta tensors carry shape information without real data, letting the
    # tracer resolve shape-dependent control flow in these models.
    meta_args = {"x": data.to("meta")}

    for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
        model = model_cls()
        split_model_and_compare_output(model, data, meta_args)


if __name__ == "__main__":
    test_timm_models_without_control_flow()
    test_timm_models_with_control_flow()