
Testing DeepFM Model Symbolic Tracing in ColossalAI

This test suite validates the symbolic tracing functionality for DeepFM recommendation models in ColossalAI, focusing on the correct transformation and execution of TorchRec-based models. The tests verify that outputs of the traced GraphModule match those of the original model, for both plain tensor outputs and TorchRec keyed outputs.
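
At its core, the suite checks the basic FX contract: a symbolically traced module must produce the same outputs as the original module. The minimal sketch below illustrates that contract with plain torch.fx and a toy module (both are stand-ins for ColossalAI's symbolic_trace and the TorchRec models used in the real test):

import torch
import torch.fx


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


model = TinyModel().eval()
gm = torch.fx.symbolic_trace(model)  # plain torch.fx; the test uses colossalai._analyzer.fx.symbolic_trace
x = torch.rand(2, 10)
with torch.no_grad():
    assert torch.allclose(gm(x), model(x))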

Test Coverage Overview

The test suite exercises symbolic tracing for every DeepFM variant registered under the "deepfm" keyword in the project's test model zoo.

Key areas tested include:
  • Model tracing accuracy and consistency
  • Output transformation validation
  • Batch processing verification
  • Tensor comparison and validation
  • Control flow handling via meta arguments (see the sketch below)
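
For models whose forward pass contains control flow, the inputs are first converted to meta tensors so that tracing works from shapes and dtypes rather than concrete values. A minimal sketch of that conversion (the input dictionary here is illustrative; the real data comes from the model zoo's data_gen_fn):

import torch

data = {"dense_features": torch.rand(2, 10)}

# Move every input to the meta device: shape and dtype are preserved,
# but no real storage is allocated.
meta_args = {k: v.to("meta") for k, v in data.items()}
assert all(v.is_meta for v in meta_args.values())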

Implementation Analysis

The testing approach implements a systematic comparison between traced and non-traced model outputs. It uses ColossalAI's symbolic tracing with optional meta arguments for models that contain control flow, and enforces deterministic cuDNN behavior for reproducible results.

Key implementation patterns include:
  • Automatic model registration and discovery
  • Dynamic data generation
  • Flexible output transformation (sketched after this list)
  • Granular tensor comparison
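
As an illustration of the output-transformation pattern, a transform function only needs to map whatever a model returns onto a flat dictionary so the comparison loop can check entries key by key. The sketch below is illustrative and not the model zoo's actual implementation:

import torch


def output_transform_fn(output):
    # Wrap plain tensors under a single key and pass dict-like outputs through.
    # Other objects (e.g. TorchRec's KeyedTensor) are kept as single entries;
    # the comparison in trace_and_compare falls back to their .values() tensors.
    if torch.is_tensor(output):
        return {"output": output}
    if isinstance(output, dict):
        return output
    return {"output": output}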

Technical Details

Testing tools and configuration:
  • PyTorch FX (torch.fx) symbolic tracing
  • ColossalAI analyzer and FX modules
  • Custom model zoo infrastructure (see the TorchRec sketch after this list)
  • Cache clearing decorators
  • Deterministic CUDNN settings
  • Gradient-free inference testing
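
For context on what the model zoo's deepfm entries wrap, the sketch below builds one TorchRec DeepFM model and a matching sparse input by hand. It assumes TorchRec's SimpleDeepFMNN, EmbeddingBagCollection, and KeyedJaggedTensor APIs with hand-picked sizes; the zoo's own model_fn and data_gen_fn may differ in detail:

import torch
from torchrec.models import deepfm
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

BATCH, SHAPE = 2, 10

# One embedding table serving a single sparse feature "f1".
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]
        )
    ]
)
model = deepfm.SimpleDeepFMNN(
    num_dense_features=SHAPE,
    embedding_bag_collection=ebc,
    hidden_layer_size=SHAPE,
    deep_fm_dimension=SHAPE,
)

dense = torch.rand(BATCH, SHAPE)
# Two samples with two sparse ids each, encoded as a jagged tensor via offsets.
sparse = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 3, 4]),
    offsets=torch.tensor([0, 2, 4]),
)
logits = model(dense_features=dense, sparse_features=sparse)  # per-sample logits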

Best Practices Demonstrated

The test implementation showcases several testing best practices for deep learning models.

Notable practices include:
  • Proper model evaluation mode setting
  • Comprehensive error messages
  • Flexible tolerance handling
  • Memory management through cache clearing (sketched after this list)
  • Modular test structure
  • Automated test discovery
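
The cache-clearing practice comes from the clear_cache_before_run decorator imported from colossalai.testing. The snippet below is a simplified stand-in that shows the general idea, not ColossalAI's actual implementation:

import functools

import torch


def clear_cache_before_run_sketch():
    # Simplified stand-in for colossalai.testing.clear_cache_before_run:
    # release cached CUDA memory before the wrapped test runs so allocations
    # left over from earlier tests do not affect this one.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return fn(*args, **kwargs)

        return wrapper

    return decorator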

hpcaitech/colossalai

tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py

import torch

from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo

BATCH = 2
SHAPE = 10


def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
    # trace
    model = model_cls()

    # convert to eval for inference
    # it is important to set it to eval mode before tracing
    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
    model.eval()

    gm = symbolic_trace(model, meta_args=meta_args)
    gm.eval()
    # run forward
    with torch.no_grad():
        fx_out = gm(**data)
        non_fx_out = model(**data)

    # compare output
    transformed_fx_out = output_transform_fn(fx_out)
    transformed_non_fx_out = output_transform_fn(non_fx_out)

    assert len(transformed_fx_out) == len(transformed_non_fx_out)
    # the raw outputs are either plain tensors or TorchRec keyed containers
    # (e.g. KeyedTensor) that expose a .values() tensor
    if torch.is_tensor(fx_out):
        assert torch.allclose(
            fx_out, non_fx_out
        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
    else:
        assert torch.allclose(
            fx_out.values(), non_fx_out.values()
        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
    # compare each transformed output entry; keyed TorchRec outputs are
    # compared through their .values() tensors
    for key, fx_output_val in transformed_fx_out.items():
        non_fx_output_val = transformed_non_fx_out[key]
        if torch.is_tensor(fx_output_val):
            assert torch.allclose(
                fx_output_val, non_fx_output_val, atol=1e-5
            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"
        else:
            assert torch.allclose(
                fx_output_val.values(), non_fx_output_val.values()
            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"


@clear_cache_before_run()
def test_torchrec_deepfm_models():
    deepfm_models = model_zoo.get_sub_registry(keyword="deepfm", allow_empty=True)
    torch.backends.cudnn.deterministic = True

    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items():
        data = data_gen_fn()
        # models flagged with control flow are traced against meta tensors,
        # i.e. inputs that carry only shape and dtype information
        if attribute is not None and attribute.has_control_flow:
            meta_args = {k: v.to("meta") for k, v in data.items()}
        else:
            meta_args = None

        trace_and_compare(model_fn, data, output_transform_fn, meta_args)


if __name__ == "__main__":
    test_torchrec_deepfm_models()