
Testing GPT-2 Model Pipeline Implementations in ColossalAI

This test suite validates GPT-2 model variants from the Hugging Face Transformers library within ColossalAI’s pipeline framework. It focuses on ensuring that each variant can be split into pipeline stages and that the split model produces the same outputs as the original.

Test Coverage Overview

The test suite covers four GPT-2 variants: the base model (GPT2Model), the language-modeling head (GPT2LMHeadModel), the double-heads model (GPT2DoubleHeadsModel), and the token-classification head (GPT2ForTokenClassification).

Key functionality tested:
  • Model initialization from a custom, downsized GPT2Config
  • Generation of fixed-shape input tensors (input_ids, token_type_ids, attention_mask)
  • Output comparison between split and original models
  • Support for multiple GPT-2 architecture variants

Implementation Analysis

The test relies on a simple data generator function rather than pytest fixtures to build fixed-shape input batches, and delegates validation to the split_model_and_compare_output utility imported from the neighbouring hf_utils module, which compares each pipeline-split model against its original.
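The hf_utils helper itself is not shown in this file, so the following is only a minimal sketch of the comparison idea, not the actual implementation: run identical inputs through the original and the split model and check every tensor field of the Hugging Face output object with torch.allclose. The helper name compare_outputs and its signature are illustrative.

import torch

def compare_outputs(original_model, split_model, data_gen, rtol=1e-5, atol=1e-5):
    """Feed the same batch to both models and compare every tensor output field."""
    inputs = data_gen()
    original_model.eval()
    split_model.eval()
    with torch.no_grad():
        ref_out = original_model(**inputs)
        test_out = split_model(**inputs)
    # Hugging Face models return ModelOutput mappings; compare field by field,
    # skipping non-tensor entries such as past_key_values.
    for key, ref_val in ref_out.items():
        if isinstance(ref_val, torch.Tensor):
            assert torch.allclose(ref_val, test_out[key], rtol=rtol, atol=atol), (
                f"mismatch in output field '{key}'"
            )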

Technical patterns include:
  • Fixed-shape batch generation via a reusable data_gen callable
  • Output consistency checks delegated to the comparison utility
  • Model configuration customization (small layer and head counts)
  • Test case generation by iterating over a list of model classes (see the parametrization sketch below)
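The file generates its cases by looping over MODEL_LIST inside a single test function; an equivalent pattern, sketched here with pytest.mark.parametrize rather than taken from the file, would report each GPT-2 variant as its own test case (the small batch used below is only for illustration):

import pytest
import torch
import transformers

GPT2_VARIANTS = [
    transformers.GPT2Model,
    transformers.GPT2LMHeadModel,
    transformers.GPT2DoubleHeadsModel,
    transformers.GPT2ForTokenClassification,
]

@pytest.mark.parametrize("model_cls", GPT2_VARIANTS, ids=lambda cls: cls.__name__)
def test_gpt2_variant(model_cls):
    # A deliberately small configuration keeps each parametrized case fast.
    config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)
    model = model_cls(config=config)
    input_ids = torch.zeros((2, 16), dtype=torch.int64)
    output = model(input_ids=input_ids)
    assert output is not None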

Technical Details

Testing tools and configuration:
  • PyTest framework for test organization
  • Hugging Face Transformers library for model implementations
  • PyTorch for tensor operations
  • Batch size: 64
  • Sequence length: 16
  • GPT-2 config: 4 layers, 8 attention heads (see the shape sanity check sketched below)
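As a quick sanity check on those numbers, a batch of the stated shape pushed through the base model should come back with one hidden vector per token. This sketch assumes the default hidden size of 768, which the test configuration does not override:

import torch
import transformers

BATCH_SIZE, SEQ_LENGTH = 64, 16

config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)
model = transformers.GPT2Model(config)

input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
with torch.no_grad():
    hidden = model(input_ids=input_ids).last_hidden_state

# n_embd defaults to 768, so the base model returns one 768-dim vector per token.
assert hidden.shape == (BATCH_SIZE, SEQ_LENGTH, 768)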

Best Practices Demonstrated

The test implementation showcases several testing best practices for deep learning models.

Notable practices:
  • Modular test case design
  • Comprehensive model variant coverage
  • Controlled test environment with fixed parameters
  • Clear separation of data generation and testing logic
  • Explicit skipping of the test while the balance-split pass is unfinished, with unsupported variants left out of the model list (see the skip sketch below)
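The only skipping used in the file is the module-level pytest.mark.skip decorator shown in the listing below; as a hedged illustration (not taken from the file), the same mechanism also has an imperative form that could skip individual unsupported variants at runtime instead of commenting them out of MODEL_LIST:

import pytest

# Decorator form: the whole test is reported as skipped with the given reason,
# exactly as the marker in the file below does.
@pytest.mark.skip(reason="balance split v2 is not ready")
def test_whole_suite_disabled():
    ...

# Imperative form: decide at runtime whether the current case should be skipped.
def test_variant_checked_at_runtime():
    variant_supported = False  # placeholder condition, purely for illustration
    if not variant_supported:
        pytest.skip("this GPT-2 variant is not supported by the split pass yet")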

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_hf_model/test_gpt.py

            
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 64   # synthetic batch dimensions used by data_gen below
SEQ_LENGTH = 16
NUM_EPOCHS = 2
NUM_CHUNKS = 1


@pytest.mark.skip("balance split v2 is not ready")
def test_gpt():
    MODEL_LIST = [
        transformers.GPT2Model,
        transformers.GPT2LMHeadModel,
        transformers.GPT2DoubleHeadsModel,
        transformers.GPT2ForTokenClassification,
        # transformers.GPT2ForSequenceClassification, # not supported yet
    ]
    # Downsized configuration: 4 layers and 8 heads, with positions capped at 64.
    config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)

    def data_gen():
        # Zero-filled tensors are enough here; only the shapes and dtypes matter
        # for comparing the split model against the original.
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return kwargs

    # Build each variant from the shared config and check the split model's outputs.
    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == "__main__":
    test_gpt()