Validating BERT Model Tracing Compatibility in ColossalAI
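A test that validates Hugging Face BERT models within the ColossalAI framework by tracing them and comparing outputs. It covers every BERT variant registered in ColossalAI's model zoo under "transformers_bert", except BertForQuestionAnswering. For each variant, the test builds the model and passes it to the helper trace_model_and_compare_output, which traces the model and checks that the traced version's outputs match those of the original. The label inputs "labels" and "next_sentence_label" are excluded, and the test is skipped on PyTorch versions older than 1.12.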
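The core idea can be sketched with plain torch.fx: trace a model into a graph, then check that the graph reproduces the eager outputs. This is a minimal illustration, not ColossalAI's helper:

- TinyEncoder and trace_and_compare are hypothetical names.
- The sketch uses torch.fx.symbolic_trace rather than ColossalAI's own tracer.
- Hugging Face models typically need a more specialised tracer because of input-dependent control flow, so a toy model stands in for BERT here.

```python
import torch
import torch.fx


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for a BERT-like model: embedding, linear layer, GELU."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.proj(self.embed(input_ids)))


def trace_and_compare(model: torch.nn.Module, inputs: dict, atol: float = 1e-5) -> bool:
    """Hypothetical helper: trace `model` with torch.fx and check the traced module matches it."""
    model.eval()  # eval mode so layers such as dropout behave deterministically
    traced = torch.fx.symbolic_trace(model)  # builds a GraphModule via proxy-based tracing
    with torch.no_grad():
        expected = model(**inputs)
        actual = traced(**inputs)
    return torch.allclose(expected, actual, atol=atol)


torch.manual_seed(0)
model = TinyEncoder()
inputs = {"input_ids": torch.randint(0, 100, (2, 8))}
print(trace_and_compare(model, inputs))  # True
```

Calling eval() before the comparison matters for real models. In train mode, dropout would make the two forward passes differ even if tracing were correct.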
Repository: hpcaitech/colossalai
File: tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
```python
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo


# Skip the whole test on PyTorch releases older than 1.12.
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12")
@clear_cache_before_run()
def test_bert():
    # All BERT variants registered under "transformers_bert" in the model zoo.
    sub_registry = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()
        # BertForQuestionAnswering is excluded from this check.
        if model.__class__.__name__ == "BertForQuestionAnswering":
            continue
        # Trace the model and compare traced vs. original outputs, ignoring the label fields.
        trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"])


if __name__ == "__main__":
    test_bert()
```
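The control flow of test_bert can be mimicked without ColossalAI or PyTorch. It iterates over registry entries, skips one model class by name, and drops label keys from the generated inputs. In this sketch, the registry, its entry names, and the stand-in classes are all hypothetical. In particular, treating ignore_data as "remove these keys from the input dict before calling the model" is an assumption about what trace_model_and_compare_output does.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


class BertModel:  # hypothetical stand-in, not the Hugging Face class
    def __call__(self, **kwargs: Any) -> List[str]:
        return sorted(kwargs)  # the "output" is just the input names the model received


class BertForQuestionAnswering(BertModel):  # hypothetical stand-in
    pass


def make_data() -> Dict[str, Any]:
    # Hypothetical data generator that also emits label fields.
    return {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "labels": [0], "next_sentence_label": [1]}


Entry = Tuple[Callable[[], Any], Callable[[], Dict[str, Any]], Any, Any, Any]

# Hypothetical sub-registry: name -> 5-tuple, of which only model_fn and data_gen_fn are used.
sub_registry: Dict[str, Entry] = {
    "transformers_bert": (BertModel, make_data, None, None, None),
    "transformers_bert_for_question_answering": (BertForQuestionAnswering, make_data, None, None, None),
}


def run_all(registry: Dict[str, Entry], ignore_data: Sequence[str], skip: Sequence[str]) -> Dict[str, List[str]]:
    results: Dict[str, List[str]] = {}
    for name, (model_fn, data_gen_fn, _, _, _) in registry.items():
        model = model_fn()
        if model.__class__.__name__ in skip:
            continue  # mirrors the BertForQuestionAnswering exclusion
        # Assumption: ignored keys are dropped before the model is called.
        data = {k: v for k, v in data_gen_fn().items() if k not in ignore_data}
        results[name] = model(**data)
    return results


print(run_all(sub_registry, ignore_data=["labels", "next_sentence_label"], skip=["BertForQuestionAnswering"]))
# {'transformers_bert': ['attention_mask', 'input_ids']}
```

Matching on __class__.__name__ rather than using isinstance keeps the skip exact: a subclass of the skipped class would still be tested.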