
Testing Tensor Metadata Propagation in ColossalAI FX Module

This test validates MetaInfoProp, the meta information propagation pass in ColossalAI's FX module. After a model has been symbolically traced, MetaInfoProp executes the resulting graph and records tensor metadata on its nodes. The test checks that the recorded shape, dtype, stride, and numel match those of the actual input and output tensors.

Test Coverage Overview

The suite consists of a single test, test_meta_info_prop, which propagates metadata through a one-layer model (torch.nn.Linear) and checks the metadata recorded on the graph's input and output nodes.

Key areas tested include:
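  • The shape change performed by a Linear layer, from (BATCH_SIZE, DIM_IN) to (BATCH_SIZE, DIM_OUT)
  • Metadata recorded on the traced graph by MetaInfoProp
  • Consistency of shape, dtype, stride, and numel with the real tensors
  • Inputs allocated on the meta device (optionally wrapped in MetaTensor)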
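Of the four fields meta_check compares, only dtype is unrelated to shape: for a freshly allocated, contiguous tensor, stride and numel follow from the shape alone. That is why a meta-device tensor, which holds no data, can still report all four. A stdlib-only sketch of that arithmetic for the test's input and output shapes follows; contiguous_stride and expected_meta are hypothetical helpers, and treating zero-sized dimensions as size 1 follows PyTorch's convention for contiguous strides.

```python
from math import prod

BATCH_SIZE, DIM_IN, DIM_OUT = 2, 4, 16


def contiguous_stride(shape):
    """Row-major strides (in elements) of a freshly allocated tensor."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        # Zero-sized dims are treated as 1 so that strides stay positive.
        running *= max(size, 1)
    return tuple(reversed(strides))


def expected_meta(shape):
    """Hypothetical helper: the shape-derived fields that meta_check compares."""
    return {"shape": tuple(shape), "stride": contiguous_stride(shape), "numel": prod(shape)}


print(expected_meta((BATCH_SIZE, DIM_IN)))   # {'shape': (2, 4), 'stride': (4, 1), 'numel': 8}
print(expected_meta((BATCH_SIZE, DIM_OUT)))  # {'shape': (2, 16), 'stride': (16, 1), 'numel': 32}
```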

Implementation Analysis

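The test creates a torch.nn.Linear model and an input tensor on the meta device, which records shape, dtype, and stride but allocates no data. When is_compatible_with_meta() returns True, the input is additionally wrapped in ColossalAI's MetaTensor with fake_device="cpu". The test obtains a reference output by calling the model directly, traces the model with torch.fx.symbolic_trace, and runs MetaInfoProp on the resulting GraphModule with the same input. Finally, it compares the TensorMetadata stored in node.meta["tensor_meta"] on the placeholder (input) node with the input tensor, and on the output node with the reference output.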

Key patterns include:
  • Meta device tensor initialization
  • Symbolic graph tracing
  • A reusable metadata comparison helper (meta_check)
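PyTorch ships a comparable pass, torch.fx.passes.shape_prop.ShapeProp, which interprets a traced graph and stores a TensorMetadata record in node.meta["tensor_meta"]. The sketch below swaps it in for MetaInfoProp so the trace, propagate, and compare flow can run without ColossalAI installed; it illustrates the pattern and is not the ColossalAI test. Two assumptions differ from the original: the Linear layer itself is built on the meta device (the original keeps its weights on CPU and relies on MetaTensor), and because the upstream TensorMetadata has no numel field, numel is derived from the recorded shape. meta_check_upstream is a hypothetical name.

```python
import math

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

BATCH_SIZE, DIM_IN, DIM_OUT = 2, 4, 16

# Assumption of this sketch: parameters also live on the meta device, so the
# eager call and the interpreted graph never mix meta and CPU tensors.
model = torch.nn.Linear(DIM_IN, DIM_OUT, device="meta")
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="meta")
orig_output = model(input_sample)  # reference output: metadata only, no data

gm = symbolic_trace(model)
# ShapeProp (stand-in for MetaInfoProp) executes each node and stores a
# TensorMetadata record in node.meta["tensor_meta"].
ShapeProp(gm).propagate(input_sample)


def meta_check_upstream(meta, tensor):
    """Hypothetical helper: compare an upstream TensorMetadata with a tensor."""
    assert meta.shape == tensor.shape
    assert meta.dtype == tensor.dtype
    assert meta.stride == tensor.stride()
    # numel is derived from the recorded shape in this sketch.
    assert math.prod(meta.shape) == tensor.numel()


for node in gm.graph.nodes:
    if node.op == "placeholder":
        meta_check_upstream(node.meta["tensor_meta"], input_sample)
    elif node.op == "output":
        meta_check_upstream(node.meta["tensor_meta"], orig_output)
```

The meta device keeps the example cheap regardless of layer size, which is the same motivation the original test has for allocating its input there.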

Technical Details

Testing infrastructure includes:
  • PyTorch FX for symbolic tracing
  • ColossalAI’s MetaInfoProp and TensorMetadata classes
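  • A meta_check helper that asserts equality of shape, dtype, stride, and numel
  • The clear_cache_before_run decorator, which clears cached memory before the test body runs
  • Module-level constants for batch size and layer dimensions (BATCH_SIZE, DIM_IN, DIM_OUT)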
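The test is decorated with @clear_cache_before_run(), written with parentheses, so clear_cache_before_run is a decorator factory: calling it returns the decorator that wraps the test. The sketch below shows that pattern with a hypothetical clear_cache_before_run_sketch that only runs Python's garbage collector; it is not ColossalAI's implementation, which may also release accelerator memory.

```python
import functools
import gc


def clear_cache_before_run_sketch():
    """Hypothetical decorator factory: call it, then apply the returned decorator."""
    def decorator(fn):
        @functools.wraps(fn)  # keep the test's name so test runners still find it
        def wrapper(*args, **kwargs):
            gc.collect()  # free unreachable Python objects before the test body
            # A GPU-aware version might also call torch.cuda.empty_cache() here.
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@clear_cache_before_run_sketch()
def test_example():
    return "ran"


print(test_example.__name__, test_example())  # test_example ran
```

Using functools.wraps matters for test discovery and reporting, since the wrapper would otherwise replace the test's name.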

Best Practices Demonstrated

The test is short and self-contained: it builds its own model and input, does not depend on other tests, and checks every recorded metadata field with an explicit assertion.

Notable practices:
  • Clear separation of setup and verification logic
  • The same meta_check helper applied to both the input and the output node
  • Guarding the MetaTensor import and wrapping behind is_compatible_with_meta()
  • Clearing cached memory before the test runs

hpcaitech/colossalai

tests/test_fx/test_meta_info_prop.py

import torch
from torch.fx import symbolic_trace

from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from colossalai.testing import clear_cache_before_run

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16


def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
    assert meta_info_spec.shape == orig_tensor.shape
    assert meta_info_spec.dtype == orig_tensor.dtype
    assert meta_info_spec.stride == orig_tensor.stride()
    assert meta_info_spec.numel == orig_tensor.numel()


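@clear_cache_before_run()
def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    # Allocate the input on the meta device: shape, dtype, and stride only, no data.
    input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="meta")
    if is_compatible_with_meta():
        input_sample = MetaTensor(input_sample, fake_device="cpu")
    # Reference output from calling the model directly.
    orig_output = model(input_sample)
    # Trace the model, then record tensor metadata on the graph's nodes.
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    # Compare the recorded metadata with the real input and output tensors.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            meta_check(node.meta["tensor_meta"], input_sample)
        if node.op == "output":
            meta_check(node.meta["tensor_meta"], orig_output)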


if __name__ == "__main__":
    test_meta_info_prop()