Testing HuggingFace Model Split Functionality in ColossalAI

This test suite validates the model splitting functionality in ColossalAI, with a focus on HuggingFace model compatibility. It verifies that a traced model can be split into sequential parts whose chained outputs exactly match those of the original model.

Test Coverage Overview

The test coverage encompasses model splitting and output validation for HuggingFace models; a minimal sketch of the trace-and-split workflow follows the list below.

Key areas tested include:
  • Model tracing with ColoTracer
  • Balanced splitting of models into two parts
  • Output consistency between original and split models
  • Different output structures (logits, prediction_logits, last_hidden_state)
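
A minimal sketch of the trace-and-split workflow, using the same ColoTracer, balanced_split_pass, and split_with_split_nodes_pass calls that appear in the source file below. The BertModel/BertConfig choice and the input shape are illustrative assumptions, not part of the test file:

import torch
from torch.fx import GraphModule
from transformers import BertConfig, BertModel  # illustrative HF model choice

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

model = BertModel(BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256))
model.eval()

# trace with meta tensors so no real activation memory is allocated during tracing
kwargs = {"input_ids": torch.zeros((1, 16), dtype=torch.int64)}
meta_args = {k: v.to("meta") for k, v in kwargs.items()}
graph = ColoTracer().trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# annotate split points for two balanced parts, then materialize the submodules
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
model_part0, model_part1 = list(split_model.children())[:2]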

Implementation Analysis

The testing approach runs the original model and the chained split parts under the same RNG state and compares the resulting output tensors for exact equality, as illustrated by the sketch continuation after the list below.

Technical implementation features:
  • RNG state management for reproducibility
  • Meta device usage for tracing
  • Dynamic handling of model outputs based on signature inspection
  • Support for various HuggingFace output formats
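
Continuing the sketch above, the fragment below mirrors the RNG handling and signature-driven output forwarding used by the test; model, kwargs, model_part0, and model_part1 carry over from the previous snippet:

import inspect

# capture the CPU RNG state before the reference forward pass, then restore it
# before running the split parts so any stochastic ops behave identically
cpu_rng_state = torch.get_rng_state()
reference_output = model(**kwargs)

torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(**kwargs)

# pass only as many intermediate values as the second part's forward() accepts
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
    output_part1 = model_part1(output_part0)
else:
    if len(output_part0) > len(sig.parameters):
        output_part0 = output_part0[: len(sig.parameters)]
    output_part1 = model_part1(*output_part0)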

Technical Details

Testing infrastructure includes:
  • PyTorch FX for model transformation
  • ColoTracer for model analysis
  • GraphModule for model manipulation
  • Custom split passes (balanced_split_pass, split_with_split_nodes_pass)
  • Deterministic testing with fixed random seeds (see the seeding snippet after this list)
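
The fixed-seed setup mentioned in the last bullet is the same pattern used at import time in the file below; seeding all three RNG sources keeps the generated inputs and any remaining stochastic ops reproducible across runs:

import random
import numpy as np
import torch

MANUAL_SEED = 0
random.seed(MANUAL_SEED)        # Python's built-in RNG
np.random.seed(MANUAL_SEED)     # NumPy RNG, used by some data generators
torch.manual_seed(MANUAL_SEED)  # PyTorch RNG for CPU and CUDA devices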

Best Practices Demonstrated

The test implementation demonstrates careful error handling and reproducibility practices; the fragment after the list below illustrates the tracing error guard and the output-structure handling.

Notable features:
  • Comprehensive error handling for tracing failures
  • Controlled random state management
  • Flexible input/output handling
  • Clear separation of setup, execution, and validation steps
  • Support for different model output structures
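
As a hedged continuation of the earlier sketch, the fragment below shows the tracing error guard and the output-structure handling listed above; reference_output and meta_args carry over from the previous snippets:

# surface tracing failures with the model's class name for easier debugging
try:
    graph = ColoTracer().trace(root=model, meta_args=meta_args)
except Exception as e:
    raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")

# HuggingFace ModelOutput objects support dict-style membership tests, so the
# reference tensor can be selected by key regardless of the concrete output class
if "logits" in reference_output:
    output_to_compare = reference_output["logits"]
elif "prediction_logits" in reference_output:
    output_to_compare = reference_output["prediction_logits"]
else:
    output_to_compare = reference_output["last_hidden_state"]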

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_hf_model/hf_utils.py

import inspect
import random

import numpy as np
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)


def split_model_and_compare_output(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

    # get split model
    model_part0 = list(split_model.children())[0]
    model_part1 = list(split_model.children())[1]

    # set rng state and compute output of split model
    torch.set_rng_state(cpu_rng_state)
    output_part0 = model_part0(**kwargs)
    sig = inspect.signature(model_part1.forward)
    if isinstance(output_part0, torch.Tensor):
        output_part1 = model_part1(output_part0)
    else:
        if len(output_part0) > len(sig.parameters):
            output_part0 = output_part0[: len(sig.parameters)]
        output_part1 = model_part1(*output_part0)

    # get output tensor from HFOutput datastructure
    if "logits" in output:
        output_to_compare = output["logits"]
    elif "prediction_logits" in output:
        output_to_compare = output["prediction_logits"]
    else:
        output_to_compare = output["last_hidden_state"]

    # compare output
    if isinstance(output_part1, torch.Tensor):
        assert output_to_compare.equal(output_part1)
    elif isinstance(output_part1, (tuple, list)):
        assert output_to_compare.equal(output_part1[0])
    else:
        assert False
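
A hedged usage example for split_model_and_compare_output: the BertForMaskedLM configuration and the data_gen callable below are illustrative assumptions; the actual test modules in the same directory supply their own models and data generators.

import torch
from transformers import BertConfig, BertForMaskedLM

def data_gen():
    # tiny batch of token ids; batch size, sequence length, and mask are illustrative
    input_ids = torch.zeros((2, 16), dtype=torch.int64)
    attention_mask = torch.ones((2, 16), dtype=torch.int64)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
split_model_and_compare_output(BertForMaskedLM(config), data_gen)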