Testing T5 Transformer Model Split Operations in ColossalAI
This test suite exercises T5 transformer models under ColossalAI's FX-based pipeline splitting. For each T5 variant (T5Model, T5ForConditionalGeneration, T5EncoderModel), the model is split and its outputs are compared against the unsplit model, covering both encoder-decoder and encoder-only architectures and checking that splitting preserves output consistency.
Source: tests/test_fx/test_pipeline/test_hf_model/test_t5.py (hpcaitech/colossalai)
import pytest
import torch
import transformers

from hf_utils import split_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGTH = 16


@pytest.mark.skip("balance split v2 is not ready")
def test_t5():
    MODEL_LIST = [
        transformers.T5Model,
        transformers.T5ForConditionalGeneration,
        transformers.T5EncoderModel,
    ]

    # Small T5 configuration to keep the test lightweight
    config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2)

    def data_gen():
        # Dummy inputs for the encoder-decoder variants
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        return kwargs

    def data_gen_for_encoder_only():
        # Dummy inputs for the encoder-only variant
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        # T5EncoderModel has no decoder, so it only needs encoder inputs
        if isinstance(model, transformers.T5EncoderModel):
            data_gen_func = data_gen_for_encoder_only
        else:
            data_gen_func = data_gen
        split_model_and_compare_output(model, data_gen_func)


if __name__ == "__main__":
    test_t5()
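The helper split_model_and_compare_output is imported from hf_utils and is not shown in this file. Conceptually it runs the model on the generated dummy inputs, splits the traced graph into pipeline stages, runs the split model on the same inputs, and asserts that both produce the same outputs. The sketch below illustrates that flow with plain torch.fx utilities (transformers.utils.fx.symbolic_trace and torch.fx.passes.split_module.split_module) rather than ColossalAI's internal split passes; the function body, the naive even partitioning, and the output-flattening comparison are illustrative assumptions, not the actual hf_utils implementation.

import torch
from torch.fx.passes.split_module import split_module
from transformers.utils.fx import symbolic_trace


def flatten_tensors(output):
    # Collect all tensors from a ModelOutput / dict / tuple, in traversal order.
    if torch.is_tensor(output):
        return [output]
    values = output.values() if hasattr(output, "values") else output
    tensors = []
    for value in values:
        if torch.is_tensor(value):
            tensors.append(value)
        elif isinstance(value, (tuple, list)):
            tensors.extend(flatten_tensors(value))
    return tensors


def split_model_and_compare_output(model, data_gen, nparts=2):
    # Sketch only: trace the model, cut its FX graph into `nparts` submodules,
    # and check that the split model reproduces the original outputs.
    model.eval()
    kwargs = data_gen()

    # Reference forward pass on the unsplit model.
    with torch.no_grad():
        reference = model(**kwargs)

    # HuggingFace's FX tracer handles the control flow inside T5.
    traced = symbolic_trace(model, input_names=list(kwargs.keys()))

    # Naive even partitioning by node order (a stand-in for a balanced split pass).
    nodes = list(traced.graph.nodes)
    part_of = {node: i * nparts // len(nodes) for i, node in enumerate(nodes)}
    split = split_module(traced, model, lambda node: part_of[node])

    # Forward pass through the chained submodules on the same inputs.
    with torch.no_grad():
        split_output = split(**kwargs)

    # The split model must agree with the unsplit model.
    for ref, out in zip(flatten_tensors(reference), flatten_tensors(split_output)):
        assert torch.allclose(ref, out, atol=1e-5)

Note also the pytest.mark.skip marker at the top of the test: the whole test is currently disabled because the balance split v2 pass it depends on is not ready.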