Testing T5 Transformer Model Split Operations in ColossalAI

This test suite validates the functionality of T5 transformer models within the ColossalAI framework, focusing on model splitting and output comparison across different T5 variants. It ensures proper handling of encoder-only and encoder-decoder architectures while maintaining output consistency.

Test Coverage Overview

The test suite covers three T5 model variants: T5Model, T5ForConditionalGeneration, and T5EncoderModel.

Key functionality tested includes:
  • Model configuration and initialization
  • Input handling for different model architectures
  • Output comparison after model splitting
  • Batch processing validation

Implementation Analysis

The testing approach systematically verifies each T5 model variant using the pytest framework. Separate data-generation functions for encoder-only and encoder-decoder architectures ensure each model type receives the inputs it expects.

The test utilizes custom configurations with controlled parameters (vocab_size=100, d_model=128, num_layers=2) to ensure consistent testing conditions.
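The controlled parameters translate directly into the dummy inputs: all-zero int64 tensors are valid token IDs for any vocabulary, which suffices here because the test compares outputs rather than model quality. A minimal sketch (constant names mirror the test below):

```python
import torch

BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE = 1, 16, 100

# All-zero token IDs: trivially valid indices into a 100-entry embedding table.
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)

assert input_ids.shape == (BATCH_SIZE, SEQ_LENGTH)
assert int(input_ids.max()) < VOCAB_SIZE
```

Because the comparison is exact-shape, exact-dtype, zeros are as good as random IDs and keep the test deterministic.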

Technical Details

Testing tools and configuration:
  • pytest for test execution and management
  • Transformers library for T5 model implementations
  • PyTorch for tensor operations
  • Custom split_model_and_compare_output utility
  • Configured batch size (1) and sequence length (16)
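The split_model_and_compare_output helper lives in the suite's internal hf_utils module, so its implementation is not shown here. A minimal sketch of the comparison contract it upholds might look like the following; the splitting step itself is elided, and the function name compare_outputs is purely illustrative:

```python
import torch


def compare_outputs(model: torch.nn.Module, data_gen_func, atol: float = 1e-5) -> bool:
    """Illustrative stand-in for the comparison half of the utility:
    run the model twice on the same deterministic inputs and check the
    outputs agree elementwise. The real helper would run the original
    model and the recombined split model instead of two identical runs."""
    model.eval()  # disable dropout so both runs are deterministic
    kwargs = data_gen_func()
    with torch.no_grad():
        first = model(**kwargs)
        second = model(**kwargs)  # stand-in for the split model's forward pass
    return bool(torch.allclose(first, second, atol=atol))
```

The key design point this mirrors is that the data-generation function, not the comparison code, decides which keyword arguments each architecture receives.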

Best Practices Demonstrated

The test implementation showcases several testing best practices including modular test design, clear separation of concerns, and robust input data generation.

Notable practices:
  • Dynamic model configuration
  • Isolated test cases for different architectures
  • Controlled test environment
  • Explicit skip marker (pytest.mark.skip) for functionality that is not yet ready

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_hf_model/test_t5.py

import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGTH = 16


@pytest.mark.skip("balance split v2 is not ready")
def test_t5():
    MODEL_LIST = [
        transformers.T5Model,
        transformers.T5ForConditionalGeneration,
        transformers.T5EncoderModel,
    ]

    # A deliberately small config keeps the test lightweight while still
    # exercising the full split-and-compare path.
    config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2)

    def data_gen():
        # Encoder-decoder models need both encoder and decoder token IDs.
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        return kwargs

    def data_gen_for_encoder_only():
        # Encoder-only models take just input_ids.
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)

        if isinstance(model, transformers.T5EncoderModel):
            data_gen_func = data_gen_for_encoder_only
        else:
            data_gen_func = data_gen

        split_model_and_compare_output(model, data_gen_func)


if __name__ == "__main__":
    test_t5()