
Testing PyTorch FX Model Splitting Implementation in ColossalAI

This test helper validates ColossalAI's FX-based model splitting by comparing the output of an original model with the output of the same model split into two parts. It checks that partitioning a deep learning model with the FX graph passes does not change the result.

Test Coverage Overview

The test coverage focuses on model splitting validation using PyTorch FX graph manipulation. Key areas include:

  • Model tracing and graph generation
  • Balanced splitting of models into two parts
  • Output consistency verification between original and split models
  • RNG state management for deterministic testing

Implementation Analysis

The helper builds on PyTorch’s FX framework and uses ColossalAI’s ColoTracer to trace the model. Its workflow:

  • Traces the original model using ColoTracer
  • Applies balanced split passes to partition the model
  • Handles both a single tensor and a sequence of tensors as the first part's output
  • Restores the saved CPU RNG state before running the split model, so both runs start from the same random state
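The workflow above can be sketched with stock PyTorch FX alone. This is a minimal illustration, not the ColossalAI implementation: it substitutes `torch.fx.symbolic_trace` for ColoTracer and `torch.fx.passes.split_module.split_module` for the balanced split passes. `TinyMLP` and `split_in_two` are hypothetical names, and cutting at half the node count is a naive stand-in for a real balancing policy.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class TinyMLP(nn.Module):
    """Hypothetical toy model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 4)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def split_in_two(model):
    """Trace `model` and cut its graph into two partitions by node count."""
    gm = symbolic_trace(model)
    # Only compute nodes are assigned to a partition; inputs and the output are skipped.
    compute_nodes = [n for n in gm.graph.nodes if n.op not in ("placeholder", "output")]
    half = len(compute_nodes) // 2
    partition_of = {n.name: (0 if i < half else 1) for i, n in enumerate(compute_nodes)}
    # split_module names the resulting children submod_<partition id>.
    return split_module(gm, model, lambda node: partition_of.get(node.name, 0))


torch.manual_seed(0)
model = TinyMLP().eval()
split = split_in_two(model)

x = torch.randn(2, 8)
with torch.no_grad():
    expected = model(x)
    # Feed the first partition's output into the second partition.
    actual = split.submod_1(split.submod_0(x))

assert torch.equal(expected, actual)
```

Because both partitions reuse the original parameters and run the same operations in the same order, the chained output is expected to match the original exactly, which is the property the helper asserts.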

Technical Details

Testing infrastructure includes:

  • PyTorch FX GraphModule for model manipulation
  • ColoTracer for model tracing
  • Custom split passes (balanced_split_pass and split_with_split_nodes_pass)
  • Fixed seeds (MANUAL_SEED = 0) for random, NumPy, and PyTorch, plus deterministic cuDNN
  • Python’s inspect module to match the first part's outputs to the parameters of the second part's forward method
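The signature-based input matching can be illustrated with the standard library alone. The sketch below assumes the same rule the helper applies: when the first stage returns more values than the second stage accepts, the extra trailing values are dropped. `call_with_matching_arity` and `second_stage` are hypothetical names introduced for this example.

```python
import inspect


def call_with_matching_arity(fn, outputs):
    """Call `fn` with `outputs`, dropping trailing values that `fn` cannot accept."""
    if not isinstance(outputs, (tuple, list)):
        # A single value is passed through unchanged.
        return fn(outputs)
    n_params = len(inspect.signature(fn).parameters)
    return fn(*outputs[:n_params])


def second_stage(hidden, mask):
    # Stand-in for the second partition's forward method.
    return hidden * mask


# The first stage produced three values, but second_stage takes only two.
result = call_with_matching_arity(second_stage, (3, 2, "unused"))
```

For a bound method such as `module.forward`, `inspect.signature` omits `self`, so the parameter count reflects only the real inputs.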

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Explicit random state management for reproducibility
  • Tracing failures re-raised as a RuntimeError that names the model class
  • Dynamic input handling for flexible model testing
  • Exact output equality checked with torch.Tensor.equal rather than a tolerance-based comparison
  • Clean separation of tracing, transformation, and validation steps
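The random state handling can be shown in isolation. The sketch below assumes CPU-only execution: `torch.get_rng_state` and `torch.set_rng_state` cover the CPU generator only, so CUDA generators would need separate handling. `run_twice_with_same_rng` is a hypothetical helper, not part of the test file.

```python
import torch


def run_twice_with_same_rng(fn):
    """Run `fn` twice, restoring the CPU RNG state before the second run."""
    state = torch.get_rng_state()
    first = fn()
    torch.set_rng_state(state)
    second = fn()
    return first, second


# Both calls draw from the same generator state, so the samples are identical.
first, second = run_twice_with_same_rng(lambda: torch.rand(4))
assert torch.equal(first, second)
```

Saving the state before the reference run and restoring it before the split run means any stochastic operation sees the same random stream in both runs.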

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_timm_model/timm_utils.py

            
import inspect
import random

import numpy as np
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True


def split_model_and_compare_output(model, data, meta_args=None):
    model.eval()

    # save the CPU RNG state, then compute the reference output of the original model
    cpu_rng_state = torch.get_rng_state()
    output = model(data)

    # tracing model
    tracer = ColoTracer()
    try:
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        # chain the original exception so its traceback is preserved
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") from e
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

    # the two partitions are the first two child modules of the split model
    model_part0, model_part1 = list(split_model.children())[:2]

    # set rng state and compute output of split model
    torch.set_rng_state(cpu_rng_state)
    output_part0 = model_part0(data)
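    # model_part1 may accept fewer inputs than model_part0 returns, so inspect its signature
    sig = inspect.signature(model_part1.forward)
    if isinstance(output_part0, torch.Tensor):
        output_part1 = model_part1(output_part0)
    else:
        # keep only as many outputs as model_part1.forward has parameters
        if len(output_part0) > len(sig.parameters):
            output_part0 = output_part0[: len(sig.parameters)]
        output_part1 = model_part1(*output_part0)
    assert output.equal(output_part1), "output of the split model differs from the original model"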