Back to Repositories

Validating Topology-Based Model Partitioning in ColossalAI

A comprehensive test suite for validating topology-based model partitioning in ColossalAI, focusing on OPT models and MLP architectures. This test suite ensures proper DAG generation and topology verification for distributed training scenarios.

Test Coverage Overview

The test suite provides extensive coverage for topology-based model partitioning:

  • Tests OPT and MLP model architectures
  • Validates DAG generation for different model configurations
  • Verifies topology consistency across model splits
  • Tests with varying batch sizes and sequence lengths

Implementation Analysis

Implements a systematic approach to topology testing using pytest framework:

  • Utilizes model configuration arrays for different architectures
  • Implements custom data generators for each model type
  • Employs split_model_and_get_DAG utility for topology analysis
  • Uses check_topo verification for topology validation

Technical Details

Technical implementation specifications:

  • PyTorch integration for tensor operations
  • Transformers library for OPT model implementation
  • Custom MLP implementation via topo_utils
  • Batch size and sequence length configurations
  • Shape propagation compatibility checks

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Modular test structure with clear separation of concerns
  • Comprehensive model configuration management
  • Dynamic data generation for different model types
  • Proper skip marker usage for incompatible PyTorch versions
  • Clear error handling and topology verification

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_topo/test_topo.py

            
import pytest
import torch
import transformers
from topo_utils import MLP, check_topo, split_model_and_get_DAG

BATCH_SIZE = 1
SEQ_LENGHT = 16


@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0")
def test_opt():
    MODEL_LIST = [
        MLP,
        transformers.OPTModel,
    ]

    CONFIGS = [
        {"dim": 10, "layers": 12},
        transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
    ]

    def data_gen_MLP():
        x = torch.zeros((16, 10))
        kwargs = dict(x=x)
        return kwargs

    def data_gen_OPT():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    DATAGEN = [
        data_gen_MLP,
        data_gen_OPT,
    ]

    for i, model_cls in enumerate(MODEL_LIST):
        model = model_cls(config=CONFIGS[i])
        top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
        # print(f'{top_mod=}
----
{topo=}')
        check_topo(top_mod, topo)


if __name__ == "__main__":
    test_opt()