
Testing OPT Model Pipeline Splitting in ColossalAI

This test suite evaluates the pipeline functionality of OPT models within ColossalAI, focusing on model splitting and output comparison. It covers both the base OPTModel and the OPTForCausalLM implementation, verifying that each model produces the same outputs after being split into pipeline stages.

Test Coverage Overview

The test suite covers critical aspects of OPT model implementations in ColossalAI.

Key areas include:
  • Model splitting verification for OPTModel and OPTForCausalLM
  • Output comparison between original and split models
  • Construction of a scaled-down custom configuration shared by both model classes
  • Validation of attention mask and input processing

Implementation Analysis

The test verifies model splitting against reference OPT implementations from the HuggingFace Transformers library. A small data generation helper with a fixed batch size and sequence length supplies identical inputs to every model variant, while the shared split_model_and_compare_output utility performs the split and checks the results; a hedged sketch of what such a helper might look like follows the list below.

Technical patterns include:
  • Dynamic model instantiation with configurable parameters
  • Batch processing verification
  • Attention mask handling
  • Model output comparison logic
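
The split_model_and_compare_output helper is imported from a shared hf_utils module and is not shown in this file; per the skip reason, ColossalAI's own implementation depends on a "balance split v2" pass that is still in development. Purely as an illustration of the split-and-compare idea, the sketch below uses stock torch.fx utilities (symbolic_trace from transformers.utils.fx and split_module from torch.fx.passes); the function name and the naive half-way split are assumptions, not the real API.

import torch
from torch.fx.passes.split_module import split_module
from transformers.utils.fx import symbolic_trace


def split_and_compare_sketch(model, data_gen):
    # Deterministic eval mode so the reference and split runs match.
    model.eval()
    kwargs = data_gen()

    # Reference forward pass on the original, unsplit model.
    with torch.no_grad():
        reference = model(**kwargs)

    # Trace the HuggingFace model into a torch.fx GraphModule.
    traced = symbolic_trace(model, input_names=list(kwargs.keys()))

    # Naive two-stage split: the first half of the nodes (in topological order)
    # goes to stage 0, the rest to stage 1, so no circular dependencies arise.
    nodes = list(traced.graph.nodes)
    cutoff = len(nodes) // 2
    stage_of = {node: (0 if i < cutoff else 1) for i, node in enumerate(nodes)}
    pipeline = split_module(traced, model, lambda node: stage_of[node])

    with torch.no_grad():
        split_out = pipeline(**kwargs)

    # ModelOutput objects and fx-produced dicts both support string indexing.
    key = "last_hidden_state" if "last_hidden_state" in reference else "logits"
    assert torch.allclose(reference[key], split_out[key], atol=1e-5)

The property under test is simply that the split pipeline reproduces the outputs of the unsplit model on identical inputs.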

Technical Details

Testing infrastructure includes:
  • pytest as the testing framework
  • HuggingFace Transformers library for model implementations
  • PyTorch for tensor operations
  • Custom scaled-down configuration with vocab_size=100 and hidden_size=128 (see the size check after this list)
  • 4 hidden layers and 4 attention heads
  • Batch size of 1 and sequence length of 16
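
The scaled-down configuration keeps the models small enough for fast, CPU-only testing. As a purely illustrative size check (not part of the test itself):

import transformers

# A few million parameters with the scaled-down config, versus roughly 125M for
# the smallest released OPT checkpoint (facebook/opt-125m).
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
model = transformers.OPTModel(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")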

Best Practices Demonstrated

The test implementation showcases several testing best practices in deep learning model validation.

Notable practices include:
  • Modular test structure with clear separation of concerns
  • Deterministic all-zero inputs for reproducible output comparison
  • Scaled-down model configuration for fast execution
  • Efficient test data generation
  • Clear skip marker documentation for incomplete features

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_hf_model/test_opt.py

import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGTH = 16


@pytest.mark.skip("balance split v2 is not ready")
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

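    # Deterministic dummy inputs: all-zero token ids and an all-zero attention mask.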
    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

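    # Instantiate each tiny OPT variant and verify the split model matches its output.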
    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == "__main__":
    test_opt()
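
Note that pytest marks only take effect under the pytest runner: invoking pytest on tests/test_fx/test_pipeline/test_hf_model/test_opt.py reports the test as skipped until the marker is removed, whereas executing the file directly through its __main__ guard calls test_opt() unconditionally.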