
Testing DLRM Model Symbolic Tracing in ColossalAI

This test suite validates symbolic tracing of DLRM (Deep Learning Recommendation Model) implementations in ColossalAI. Each model is traced into a graph module, and the traced and original (non-traced) versions are checked for matching outputs.

Test Coverage Overview

The test suite traces the DLRM models registered in the model zoo and checks each traced graph module against the eager model it was built from.

Key areas tested include:
  • Model tracing accuracy for recommendation models
  • Output consistency verification
  • Tracing of both the full DLRM model and its component architectures
  • Comparison of tensor outputs and keyed (dictionary-style) outputs

Implementation Analysis

The test traces each model with ColossalAI's symbolic_trace and compares the outputs of the resulting graph module with those of the original model, handling tensor outputs and keyed outputs separately. The model is switched to eval mode before tracing so that batch normalization is recorded in inference mode, and models with data-dependent control flow are traced with meta arguments.
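The trace-then-compare flow can be sketched with stock torch.fx instead of ColossalAI's analyzer. This is a minimal sketch: TinyRecModel and trace_and_check are hypothetical names, not part of ColossalAI or torchrec. One caveat: stock torch.fx keeps nn.BatchNorm1d as a leaf module, so this sketch does not reproduce the training-mode pitfall that the eval() comment in the test guards against; it only mirrors the order of operations.

```python
import torch
import torch.fx


class TinyRecModel(torch.nn.Module):
    """Hypothetical stand-in for a DLRM sub-module: a dense layer followed by BatchNorm."""

    def __init__(self, in_features: int = 10, out_features: int = 4):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.bn = torch.nn.BatchNorm1d(out_features)

    def forward(self, dense_features: torch.Tensor) -> torch.Tensor:
        return self.bn(self.linear(dense_features))


def trace_and_check(model: torch.nn.Module, data: dict, atol: float = 1e-5) -> bool:
    # Switch to eval mode before tracing, mirroring the order used in the test.
    model.eval()
    gm = torch.fx.symbolic_trace(model)
    gm.eval()
    with torch.no_grad():
        fx_out = gm(**data)      # traced graph module
        eager_out = model(**data)  # original model
    return torch.allclose(fx_out, eager_out, atol=atol)


torch.manual_seed(0)
data = {"dense_features": torch.randn(2, 10)}
print(trace_and_check(TinyRecModel(), data))  # True
```

The graph module reuses the original submodules, so both paths share the same weights and running statistics; any mismatch would point at the tracing itself.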

Technical Details

Key technical components include:
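  • cuDNN deterministic mode (torch.backends.cudnn.deterministic = True)
  • Symbolic tracing via colossalai._analyzer.fx.symbolic_trace
  • Meta arguments (inputs moved to the "meta" device) for models whose attribute sets has_control_flow
  • Cache clearing before the test via the clear_cache_before_run decorator
  • Numerical comparison with torch.allclose, using a fixed absolute tolerance of 1e-5 for per-key outputs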
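The meta-argument conversion in the test is a single dictionary comprehension. The sketch below, which assumes made-up input names (dense_features, sparse_ids) rather than the model zoo's real ones, shows what that conversion produces: tensors that keep their shape and dtype but hold no data.

```python
import torch


def to_meta_args(data: dict) -> dict:
    # Same conversion the test applies for models flagged with has_control_flow:
    # every input tensor is moved to the "meta" device.
    return {k: v.to("meta") for k, v in data.items()}


# Hypothetical inputs; the real ones come from each model's data_gen_fn.
data = {
    "dense_features": torch.randn(2, 10),
    "sparse_ids": torch.randint(0, 100, (2, 3)),
}
meta_args = to_meta_args(data)

for name, t in meta_args.items():
    # Meta tensors carry shape and dtype metadata but allocate no storage.
    print(name, t.device, tuple(t.shape), t.dtype, t.is_meta)
```

Because meta tensors carry only metadata, a shape-aware tracer can evaluate shape-dependent logic without running real computation.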

Best Practices Demonstrated

The test implementation showcases several testing best practices:
  • Setting eval mode before tracing
  • Separate handling of tensor and keyed output types
  • Assertion messages that name the model class and print both outputs
  • A reusable trace_and_compare helper kept separate from the test loop
  • Test cases drawn from the model zoo registry with get_sub_registry
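Handling different output types can be factored into one recursive helper. This is a hypothetical sketch (outputs_match is not a ColossalAI function): it compares tensors directly and mappings key by key with torch.allclose, whose check is |a - b| <= atol + rtol * |b| elementwise. The actual test instead calls .values() on non-tensor outputs, which assumes a torchrec KeyedTensor-like object.

```python
from typing import Any, Mapping

import torch


def outputs_match(a: Any, b: Any, atol: float = 1e-5) -> bool:
    """Hypothetical helper: compare tensors directly and mappings key by key."""
    if torch.is_tensor(a) and torch.is_tensor(b):
        # Check shapes first so allclose never silently broadcasts.
        return a.shape == b.shape and torch.allclose(a, b, atol=atol)
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        # Same keys, and every value matches recursively.
        return a.keys() == b.keys() and all(outputs_match(a[k], b[k], atol) for k in a)
    raise TypeError(f"unsupported output types: {type(a).__name__}, {type(b).__name__}")


x = torch.zeros(3)
print(outputs_match(x, x + 5e-6))                          # True: 5e-6 is within atol
print(outputs_match({"logits": x}, {"logits": x + 1e-3}))  # False: 1e-3 exceeds atol
```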

hpcaitech/colossalai

tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py

            
import torch

from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo

BATCH = 2
SHAPE = 10


def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
    # build the model
    model = model_cls()

    # switch to eval mode for inference *before* tracing;
    # otherwise torch.nn.functional.batch_norm is recorded in training mode
    model.eval()

    gm = symbolic_trace(model, meta_args=meta_args)
    gm.eval()
    # run forward
    with torch.no_grad():
        fx_out = gm(**data)
        non_fx_out = model(**data)

    # compare output
    transformed_fx_out = output_transform_fn(fx_out)
    transformed_non_fx_out = output_transform_fn(non_fx_out)

    assert len(transformed_fx_out) == len(transformed_non_fx_out)
    if torch.is_tensor(fx_out):
        assert torch.allclose(
            fx_out, non_fx_out, atol=1e-5
        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
    else:
        # non-tensor outputs are assumed to be torchrec KeyedTensor-like objects,
        # whose values() returns a single tensor
        assert torch.allclose(
            fx_out.values(), non_fx_out.values(), atol=1e-5
        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
    for key in transformed_fx_out.keys():
        fx_output_val = transformed_fx_out[key]
        non_fx_output_val = transformed_non_fx_out[key]
        if torch.is_tensor(fx_output_val):
            assert torch.allclose(
                fx_output_val, non_fx_output_val, atol=1e-5
            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"
        else:
            assert torch.allclose(
                fx_output_val.values(), non_fx_output_val.values()
            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"


@clear_cache_before_run()
def test_torchrec_dlrm_models():
    torch.backends.cudnn.deterministic = True
    dlrm_models = model_zoo.get_sub_registry(keyword="dlrm", allow_empty=True)

    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items():
        data = data_gen_fn()

        # dlrm_interactionarch is not supported
        # TODO(FrankLeeeee): support this model
        if name == "dlrm_interactionarch":
            continue

        if attribute is not None and attribute.has_control_flow:
            meta_args = {k: v.to("meta") for k, v in data.items()}
        else:
            meta_args = None

        trace_and_compare(model_fn, data, output_transform_fn, meta_args)


if __name__ == "__main__":
    test_torchrec_dlrm_models()
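The selection logic in the test loop (keyword filter, name-based skip, meta-argument decision) can be checked without PyTorch. The sketch below is a hypothetical re-implementation: filter_registry, plan_runs, Attribute, and the registry entries are illustrative stand-ins, and the substring match is an assumption about how the keyword filter behaves. It shows that a name-based skip such as the dlrm_interactionarch check only takes effect if the keyword filter admits that name.

```python
from typing import Dict, NamedTuple, Tuple


class Attribute(NamedTuple):
    """Hypothetical stand-in for a model-zoo attribute record."""
    has_control_flow: bool = False


# Hypothetical entries shaped like the 5-tuples the test unpacks:
# (model_fn, data_gen_fn, output_transform_fn, <unused>, attribute).
REGISTRY: Dict[str, Tuple] = {
    "dlrm": (None, None, None, None, None),
    "dlrm_interactionarch": (None, None, None, None, None),
    "dlrm_sparsearch": (None, None, None, None, Attribute(has_control_flow=True)),
    "deepfm_densearch": (None, None, None, None, Attribute()),
}


def filter_registry(registry: Dict[str, Tuple], keyword: str, allow_empty: bool = False) -> Dict[str, Tuple]:
    # Keep entries whose name contains the keyword (substring match assumed).
    sub = {name: entry for name, entry in registry.items() if keyword in name}
    if not sub and not allow_empty:
        raise KeyError(f"no models match keyword {keyword!r}")
    return sub


def plan_runs(registry: Dict[str, Tuple], keyword: str) -> Dict[str, str]:
    # Mirror the test loop: skip one model, use meta args for control-flow models.
    plan = {}
    for name, (_, _, _, _, attribute) in filter_registry(registry, keyword, allow_empty=True).items():
        if name == "dlrm_interactionarch":
            plan[name] = "skip"
        elif attribute is not None and attribute.has_control_flow:
            plan[name] = "trace with meta_args"
        else:
            plan[name] = "trace"
    return plan


print(plan_runs(REGISTRY, "dlrm"))
print(plan_runs(REGISTRY, "deepfm"))
```

With the "deepfm" keyword in this toy registry, no dlrm_* entry is selected, so the skip branch never runs.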