
Validating BERT Model Tracing Compatibility in ColossalAI

A test suite for validating BERT model functionality within the ColossalAI framework, focusing on model tracing and output comparison. The suite checks that Hugging Face BERT models integrate cleanly with ColossalAI and produce consistent outputs when processed through its torch.fx-based tracing mechanism.

Test Coverage Overview

The test suite provides extensive coverage of BERT model variants from the Transformers library, excluding BertForQuestionAnswering.

Key areas covered include:
  • Model output consistency verification
  • Trace compatibility testing
  • Integration with the model zoo registry (see the registration sketch after this list)
  • Data generation and handling
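
The variants themselves are assumed to be pre-registered in the model zoo under a "transformers_bert" prefix; the registrations are not part of this file. A minimal sketch of what one could look like, assuming a register() method whose keyword names are inferred from the five-element tuples the test unpacks (not confirmed by this file):

import torch
import transformers

from tests.kit.model_zoo import model_zoo

# Small config so the example model builds quickly.
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)

# Hypothetical registration; the keyword names mirror the
# (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
# tuple that test_bert() unpacks from the sub-registry.
model_zoo.register(
    name="transformers_bert_for_sequence_classification",
    model_fn=lambda: transformers.BertForSequenceClassification(config),
    data_gen_fn=lambda: dict(input_ids=torch.zeros((2, 16), dtype=torch.int64)),
    output_transform_fn=lambda x: x,
    loss_fn=None,
    model_attribute=None,
)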

Implementation Analysis

The implementation iterates over the model_zoo registry's "transformers_bert" sub-registry, exercising each registered BERT variant in turn within a single pytest test. The trace_model_and_compare_output function serves as the core testing mechanism, validating output consistency between the original and traced models.
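
The helper itself lives in hf_tracer_utils and its body is not shown here. Conceptually, the check can be sketched as below; this is an illustrative stand-in using plain torch.fx, whereas the real helper relies on ColossalAI's Hugging Face-aware tracer, so beyond the overall shape every detail is an assumption:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


def trace_and_compare(model: nn.Module, data_gen_fn, ignore_data=()):
    # Illustrative stand-in for trace_model_and_compare_output.
    model.eval()
    # Drop inputs such as labels that only influence loss computation.
    data = {k: v for k, v in data_gen_fn().items() if k not in ignore_data}

    with torch.no_grad():
        eager_out = model(**data)

    # Plain torch.fx tracing; the real helper uses ColossalAI's tracer,
    # which copes with the input-dependent control flow in HF models.
    traced = symbolic_trace(model)
    with torch.no_grad():
        traced_out = traced(**data)

    # HF models return dict-like outputs; compare them field by field.
    if torch.is_tensor(eager_out):
        assert torch.allclose(eager_out, traced_out, atol=1e-5)
    else:
        for key, value in eager_out.items():
            if torch.is_tensor(value):
                assert torch.allclose(value, traced_out[key], atol=1e-5), key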

Notable patterns include:
  • Dynamic model instantiation
  • Selective test execution based on torch version
  • Automated cache clearing between test runs
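
The cache-clearing decorator is imported from colossalai.testing; its internals are not shown in this file. A plausible minimal sketch, assuming it simply frees Python garbage and the CUDA allocator cache before the test body runs:

import functools
import gc

import torch


def clear_cache_before_run():
    # Hypothetical sketch of the decorator, not the actual implementation:
    # collect Python garbage and release cached CUDA memory, then run the test.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return func(*args, **kwargs)

        return wrapper

    return decorator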

Technical Details

Testing infrastructure includes:
  • PyTorch version requirement >= 1.12.0
  • Custom cache clearing decorator
  • Model zoo registry integration
  • Hugging Face Transformers compatibility layer
  • Specialized data generation functions (see the sketch after this list)
  • Output comparison utilities
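
The data generation functions referenced above are registered alongside each model and are not defined in this file. A representative sketch for a BERT classifier, with assumed shapes and input keys:

import torch


def bert_data_gen():
    # Dummy batch: 2 sequences of 16 token ids drawn from a small vocab.
    input_ids = torch.randint(0, 1000, (2, 16), dtype=torch.int64)
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)
    # "labels" is generated too, but the test lists it in ignore_data,
    # so it is excluded from the traced-vs-eager comparison.
    labels = torch.zeros(2, dtype=torch.int64)
    return dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        labels=labels,
    )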

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Version-specific test skipping
  • Memory management through cache clearing
  • Modular test design with separate model and data generation
  • Explicit exception handling for specific model types
  • Clean separation of test setup and execution

hpcaitech/colossalai

tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py

import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo


# Tracing support in ColossalAI requires torch >= 1.12.0.
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@clear_cache_before_run()
def test_bert():
    # Fetch every BERT variant registered under the "transformers_bert" prefix.
    sub_registry = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()
        # BertForQuestionAnswering is explicitly excluded from tracing validation.
        if model.__class__.__name__ == "BertForQuestionAnswering":
            continue
        # Compare eager vs. traced outputs, ignoring label-like inputs
        # that only affect loss computation.
        trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"])


if __name__ == "__main__":
    test_bert()