Testing ALBERT Model Splitting and Validation in ColossalAI
This test suite validates ALBERT model variants in the ColossalAI framework, focusing on splitting each model into pipeline stages and comparing the split model's output against the original. It covers five configurations built from the same small AlbertConfig: the base model, pre-training, masked language modeling, sequence classification, and token classification.
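The comparison logic lives in split_model_and_compare_output, imported from the test suite's hf_utils helper module and not reproduced on this page; it presumably traces each model with ColossalAI's torch.fx-based tooling, splits it into pipeline stages, and checks that the split model still matches the original. The following is only a rough, hypothetical sketch of that split-and-compare pattern, using plain torch.fx and a toy MLP in place of ColossalAI's tracer and balanced split passes, so every name below is illustrative rather than ColossalAI API:

import torch
import torch.fx
from torch.fx.passes.split_module import split_module

def split_and_compare(model: torch.nn.Module, example_input: torch.Tensor) -> None:
    # Trace the model into a GraphModule. (HuggingFace models need a
    # specialized tracer; a plain feed-forward module traces with the stock one.)
    traced = torch.fx.symbolic_trace(model)

    # Assign the first half of the graph's nodes to stage 0 and the rest to stage 1.
    nodes = list(traced.graph.nodes)
    midpoint = len(nodes) // 2
    stage_of = {node: (0 if i < midpoint else 1) for i, node in enumerate(nodes)}
    split = split_module(traced, model, lambda node: stage_of[node])

    # The split module must stay numerically equivalent to the original.
    expected = model(example_input)
    actual = split(example_input)
    assert torch.allclose(expected, actual), "split model output differs from the original"

# Toy usage: a small MLP stands in for the ALBERT variants exercised by the test.
toy = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
split_and_compare(toy, torch.randn(2, 8))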
Repository: hpcaitech/colossalai
File: tests/test_fx/test_pipeline/test_hf_model/test_albert.py
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGTH = 16


@pytest.mark.skip("balance split v2 is not ready")
def test_single_sentence_albert():
    # ALBERT variants covered by the test: base model, pre-training,
    # masked LM, sequence classification, and token classification.
    MODEL_LIST = [
        transformers.AlbertModel,
        transformers.AlbertForPreTraining,
        transformers.AlbertForMaskedLM,
        transformers.AlbertForSequenceClassification,
        transformers.AlbertForTokenClassification,
    ]

    # A deliberately small configuration keeps the test fast.
    config = transformers.AlbertConfig(
        vocab_size=100,
        embedding_size=128,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
    )

    def data_gen():
        # Zero-filled tensors only need the right shapes and dtypes;
        # they serve as meta inputs for tracing and splitting.
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return meta_args

    # Split each variant and compare its output against the unsplit model.
    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == "__main__":
    test_single_sentence_albert()
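As an aside (not part of the test file), the tiny configuration and zero-filled inputs above can be exercised directly with transformers alone, which makes it easy to see what kind of output the comparison operates on; the snippet below picks AlbertForSequenceClassification as a representative variant.

import torch
import transformers

config = transformers.AlbertConfig(
    vocab_size=100,
    embedding_size=128,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
)
model = transformers.AlbertForSequenceClassification(config=config)
model.eval()

inputs = dict(
    input_ids=torch.zeros((2, 16), dtype=torch.int64),
    token_type_ids=torch.zeros((2, 16), dtype=torch.int64),
    attention_mask=torch.zeros((2, 16), dtype=torch.int64),
)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # torch.Size([2, 2]) -- batch of 2, default num_labels=2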