
Testing BERT Model Pipeline Splitting Functionality in ColossalAI

This test suite validates pipeline splitting for BERT models in ColossalAI. It covers several BERT variants and checks that each model can be split into pipeline stages while producing the same outputs as the unsplit model.

Test Coverage Overview

The test suite covers multiple BERT model variants including base BERT, PreTraining, LMHead, MaskedLM, SequenceClassification, and TokenClassification models.

Key functionality tested includes:
  • Model initialization with custom configurations
  • Input tensor handling and processing
  • Model splitting functionality
  • Output comparison validation (see the sketch after this list)
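
Output comparison itself is delegated to hf_utils.split_model_and_compare_output. A minimal sketch of what such a consistency check could look like is shown here; split_model is a hypothetical stand-in for whatever callable the splitting utilities produce, so this illustrates the idea rather than the helper's actual implementation:

    import torch

    def check_output_consistency(model, split_model, data_gen):
        # Feed identical inputs through the original and the split model,
        # then require their leading output tensors to match within tolerance.
        inputs = data_gen()
        with torch.no_grad():
            expected = model(**inputs)
            actual = split_model(**inputs)
        torch.testing.assert_close(actual[0], expected[0])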

Implementation Analysis

The testing approach uses the pytest framework with a focus on verifying model splitting. A data generator function supplies controlled input tensors with a fixed batch size of 2 and sequence length of 16, and the same generator is reused for every model variant.

Technical implementation includes:
  • Custom configuration setup for BERT models
  • Automated testing across multiple model variants
  • Structured input tensor generation

Technical Details

Testing infrastructure leverages:
  • PyTorch for tensor operations
  • Transformers library for BERT implementations
  • Custom hf_utils for model splitting utilities
  • Pytest for test execution

Configuration parameters (see the size-check sketch after this list):
  • Vocabulary size: 100
  • Hidden size: 128
  • Hidden layers: 4
  • Attention heads: 4
  • Intermediate size: 256
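
These deliberately small sizes keep all six variants lightweight, so splitting and comparing them stays fast. A quick, illustrative way to confirm the resulting model size under this configuration (a sketch, not part of the test file; the exact count varies slightly per variant head):

    import transformers

    config = transformers.BertConfig(
        vocab_size=100, hidden_size=128, num_hidden_layers=4,
        num_attention_heads=4, intermediate_size=256
    )
    model = transformers.BertModel(config)
    print(sum(p.numel() for p in model.parameters()))  # well under a million parameters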

Best Practices Demonstrated

The test implementation showcases several testing best practices including modular test design and comprehensive model coverage.

Notable practices include:
  • Parameterized model testing (see the sketch after this list)
  • Controlled test data generation
  • Isolated test environments
  • Clear test skip documentation
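
The file currently exercises all variants inside a single test function via a loop; an equivalent pytest-native parameterization could look like the sketch below. This is a hypothetical refactor for illustration, not the file's current structure, and it reuses the config and data_gen definitions from the listing at the bottom of this page:

    import pytest
    import transformers
    from hf_utils import split_model_and_compare_output

    # Hypothetical refactor: each BERT variant becomes its own test case,
    # so a failure in one head is reported independently of the others.
    # `config` and `data_gen` are assumed to be defined as in the test file.
    @pytest.mark.parametrize("model_cls", [
        transformers.BertModel,
        transformers.BertForPreTraining,
        transformers.BertLMHeadModel,
        transformers.BertForMaskedLM,
        transformers.BertForSequenceClassification,
        transformers.BertForTokenClassification,
    ])
    def test_single_sentence_bert(model_cls):
        split_model_and_compare_output(model_cls(config=config), data_gen)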

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_hf_model/test_bert.py

            
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 2
SEQ_LENGTH = 16


@pytest.mark.skip("balance split v2 is not ready")
def test_single_sentence_bert():
    MODEL_LIST = [
        transformers.BertModel,
        transformers.BertForPreTraining,
        transformers.BertLMHeadModel,
        transformers.BertForMaskedLM,
        transformers.BertForSequenceClassification,
        transformers.BertForTokenClassification,
    ]

    config = transformers.BertConfig(
        vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256
    )

    def data_gen():
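        # Zero-filled tensors give deterministic inputs of a fixed shape;
        # the consistency check only needs matching shapes and dtypes,
        # not meaningful token values.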
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return meta_args

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == "__main__":
    test_single_sentence_bert()