
Validating the Pipeline Topology Utilities in ColossalAI

This helper module supports the tests for the topology utilities and pipeline partitioning in ColossalAI's legacy pipeline middleware. It traces a model, splits it into pipeline partitions, builds the directed acyclic graph (DAG) of data flow between those partitions, and checks that each partition's recorded inputs and outputs match the split graph.

Test Coverage Overview

The helpers exercise the topology utilities used by ColossalAI's FX-based pipeline parallelism.

Key areas tested include:
  • Model splitting and DAG generation
  • Input/output partition validation
  • Submodule topology verification
  • Graph module transformation and tracing
The checks also cover placeholder (model input) nodes and the connections between partitions.
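The input/output and submodule checks can be illustrated without ColossalAI. The dependency-free sketch below uses a hypothetical ToyNode class in place of torch.fx.Node and a hypothetical partition table in place of Partition objects. It reproduces only the counting rule that check_submod uses to match a partition id to a submod_* node, then compares each partition's input and output counts with that node's args and users.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: a name, an op kind, inputs, and consumers."""
    name: str
    op: str
    args: tuple = field(default=(), repr=False)
    users: list = field(default_factory=list, repr=False)


def build_pipeline_graph(num_stages):
    """Chain x -> submod_0 -> ... -> output, recording each node's consumers in users."""
    prev = ToyNode("x", "placeholder")
    nodes = [prev]
    for name in [f"submod_{k}" for k in range(num_stages)] + ["output"]:
        op = "output" if name == "output" else "call_module"
        node = ToyNode(name, op, args=(prev,))
        prev.users.append(node)
        nodes.append(node)
        prev = node
    return nodes


def find_stage_node(nodes, part_id):
    """Counting rule from check_submod: the k-th submod_* node (from 0) matches part_id k + 2."""
    cnt = 1
    for node in nodes:
        if node.name.startswith("submod"):
            cnt += 1
        if cnt == part_id:
            return node
    return None


def check_stage(nodes, part_id, num_inputs, num_outputs):
    """Compare a partition's input/output counts with its node's args and users."""
    node = find_stage_node(nodes, part_id)
    assert node is not None, f"no submodule node for partition {part_id}"
    assert num_inputs == len(node.args)
    assert num_outputs == len(node.users)


nodes = build_pipeline_graph(2)
# Hypothetical partition table: part_id -> (number of input vals, number of output vals).
for part_id, (n_in, n_out) in {2: (1, 1), 3: (1, 1)}.items():
    check_stage(nodes, part_id, n_in, n_out)
print([find_stage_node(nodes, p).name for p in (2, 3)])  # ['submod_0', 'submod_1']
```

Because the counter starts at 1, a part_id of 1 would match the first node of the graph rather than a submodule, so this rule only makes sense for ids of 2 and above.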

Implementation Analysis

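The helpers use PyTorch's FX framework for graph tracing and transformation. Verification runs in three steps: trace the model with ColoTracer using meta-device inputs, split the resulting GraphModule into two partitions with balanced_split_pass and split_with_split_nodes_pass, and compare the topology returned by get_fx_topology against the nodes of the split top-level graph.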
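The trace, split, and inspect flow can be reproduced with stock PyTorch. This sketch swaps in substitutes and should not be read as ColossalAI's code path: it uses torch.fx.symbolic_trace instead of ColoTracer, and torch.fx.passes.split_module.split_module with a hand-written stage callback (the hypothetical stage_of) instead of balanced_split_pass and split_with_split_nodes_pass. Its MLP also takes keyword arguments rather than a config dict. It shows the shape of the top-level graph that the checks walk: placeholder inputs, one submod_* call per partition, and an output node.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class MLP(torch.nn.Module):
    # Same structure as the MLP in topo_utils.py, but configured with keyword arguments.
    def __init__(self, dim=4, layers=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(dim, dim, bias=False) for _ in range(layers)
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


torch.manual_seed(0)
model = MLP().eval()

# 1. Trace: the Python loop is unrolled into one call_module node per layer.
gm = symbolic_trace(model)


# 2. Split: layers 0-1 go to partition 0, layers 2-3 to partition 1 (hypothetical callback).
def stage_of(node):
    return 0 if node.target in ("layers.0", "layers.1") else 1


split_gm = split_module(gm, model, stage_of)

# 3. Inspect the top-level graph: op, name, number of inputs, number of consumers.
for node in split_gm.graph.nodes:
    print(node.op, node.name, len(node.args), len(node.users))
```

Because each partition here has a single input and a single output, every placeholder and submod_* node has exactly one user; these are the counts that check_input and check_submod compare against the Partition records.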

Key patterns include:
  • ColoTracer for model graph generation
  • balanced_split_pass to divide the traced graph into two partitions
  • Partition input/output validation
  • Topology relationship verification
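This file does not define what "balanced" means. One common criterion is parameter count, so as an illustration only, the hypothetical greedy_balanced_split below assigns consecutive layers to stages so that each stage holds roughly an equal share of a per-layer cost. It is not claimed to match the exact algorithm inside balanced_split_pass.

```python
def greedy_balanced_split(costs, num_stages):
    """Assign consecutive layers to pipeline stages with roughly equal total cost.

    costs[i] is a per-layer cost (for example its parameter count). The result holds
    one stage id per layer; ids never decrease and every stage gets at least one layer.
    """
    if not 1 <= num_stages <= len(costs):
        raise ValueError("need between 1 and len(costs) stages")
    target = sum(costs) / num_stages
    stage_ids, stage, acc = [], 0, 0
    for i, cost in enumerate(costs):
        stage_ids.append(stage)
        acc += cost
        layers_left = len(costs) - i - 1
        stages_left = num_stages - stage - 1
        # Close the current stage once it reaches its share, or when every remaining
        # layer is needed to give each remaining stage at least one layer.
        if stages_left > 0 and (acc >= target or layers_left == stages_left):
            stage += 1
            acc = 0
    return stage_ids


# Four bias-free Linear(4, 4) layers hold 16 parameters each.
print(greedy_balanced_split([16, 16, 16, 16], 2))  # [0, 0, 1, 1]
```

Keeping stages contiguous preserves execution order, so each stage only consumes values produced by earlier stages.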

Technical Details

Testing tools and configuration:
  • PyTorch FX for graph manipulation
  • ColoTracer for model tracing
  • Custom MLP model (a stack of bias-free Linear layers) for testing
  • Fixed random seeds (random, numpy, torch) for reproducibility
  • GraphModule transformation utilities
  • Partition and Topo middleware classes

Best Practices Demonstrated

The test implementation showcases several testing best practices for distributed systems.

Notable practices include:
  • Systematic validation of partition inputs and outputs
  • Topology verification against the nodes of the split graph
  • Clear separation of concerns across helper functions
  • Descriptive errors when tracing fails, plus assertions on partition I/O counts
  • Reproducible setup through fixed random seeds

hpcaitech/colossalai

tests/test_fx/test_pipeline/test_topo/topo_utils.py

import random

import numpy as np
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.legacy.pipeline.middleware import Partition, Topo
from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)


class MLP(torch.nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        # config must provide "dim" (feature size) and "layers" (number of Linear layers)
        dim = config["dim"]
        layers = config["layers"]
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(torch.nn.Linear(dim, dim, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # tracing model
    tracer = ColoTracer()
    try:
        # move sample inputs to the meta device so tracing records shapes without allocating data
        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") from e
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

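    # apply transform passes: annotate split points for 2 partitions, then split the graph
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    # derive the partition DAG and attach it to every split GraphModule
    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, "_topo", topo)

    # return topo directly instead of reading it back from split_submodules[0],
    # which would fail if the first entry were not a GraphModule
    return top_module, topo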


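def check_input(top_module, input_partition: Partition):
    # the input partition emits one output val per model input (placeholder node);
    # each val must be routed to as many destinations as the placeholder has users
    partition_output = input_partition.get_output_vals()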
    arg_pos = 0
    for node in top_module.graph.nodes:
        if node.op == "placeholder":
            cur_checkee = partition_output[arg_pos]
            to_partition_and_offset = cur_checkee.get()
            assert len(to_partition_and_offset) == len(node.users.keys())
            arg_pos += 1

    assert arg_pos == len(partition_output)


def check_submod(top_module, part_id, mid_partition: Partition):
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

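    # cnt starts at 1 and is incremented at each submod_* node, so the k-th
    # submodule node (counting from 0) is selected when part_id == k + 2
    cnt = 1
    cur_node = None
    for node in top_module.graph.nodes:
        if node.name.startswith("submod"):
            cnt += 1
        if cnt == part_id:
            cur_node = node
            break

    assert cur_node is not None, f"no submodule node found for partition {part_id}"
    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)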


def check_topo(top_module, topo: Topo):
    input_partition = topo.get_input_partition()
    mid_partitions = topo.get_mid_partitions()

    check_input(top_module, input_partition)
    for part_id, submod in mid_partitions.items():
        check_submod(top_module, part_id, submod)
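split_model_and_get_DAG moves every sample input to the meta device before tracing. Meta tensors record shape and dtype but allocate no storage, yet operators still compute output shapes for them. A minimal illustration with plain PyTorch follows; ColoTracer itself is not used, and the input key "x" and the tensor sizes are hypothetical.

```python
import torch

# A sample input like the dict data_gen() returns (key and shape are hypothetical),
# moved to the meta device the same way split_model_and_get_DAG does.
sample = {"x": torch.randn(8, 16)}
meta_args = {k: v.to("meta") for k, v in sample.items()}

# A layer created directly on the meta device: its weight has a shape but no data.
layer = torch.nn.Linear(16, 16, bias=False, device="meta")
out = layer(meta_args["x"])

print(out.shape, out.device)  # torch.Size([8, 16]) meta
```

This is why tracing on meta inputs stays cheap even for large models: shape information flows through the graph while no activation memory is allocated.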