
Testing the Node Arguments Converting Pass in ColossalAI's Auto-Parallel System

This test suite validates node_args_converting_pass in ColossalAI's auto-parallel system, the pass that rewrites the arguments of graph nodes so they match sharded tensor shapes. It focuses on sharding specifications and device mesh operations, ensuring tensor transformations are handled correctly in a distributed setting.

Test Coverage Overview

The test suite exercises node_args_converting_pass end to end: a (4, 8) input is sharded along dimension 0 of a 2×2 device mesh, and the pass must convert the view node's arguments so the reshaped output matches the local shard, which the test asserts via the final (2, 4, 2) output shape.

Key areas tested include:
  • Tensor view transformations
  • Sharding specification handling
  • Device mesh operations
  • Narrow operation insertion to emulate a local shard (see the sketch after this list)
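
On a real cluster each rank would hold only its slice of a sharded tensor; the test emulates this on a single process by narrowing dimension 0. A minimal PyTorch sketch of that idea (variable names are illustrative, not taken from the test):

import torch

full = torch.rand(4, 8)                 # global tensor, logically sharded on dim 0
num_shards = 2                          # a mesh axis of size 2 halves dim 0
shard_len = full.size(0) // num_shards

# Rank 0's local shard: rows [0, 2) of the global tensor.
local = full.narrow(0, 0, shard_len)
assert local.shape == torch.Size([2, 8])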

Implementation Analysis

The test defines a TestModule whose forward reshapes its input with view, traces it into an FX graph using meta arguments, and validates that the pass converts node arguments correctly for the distributed layout.

Key implementation patterns include:
  • ColoTracer for graph generation
  • ShardingSpec configuration for tensor partitioning
  • Dynamic graph modification with node insertion (sketched after this list)
  • Device mesh initialization and management
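
The node-insertion pattern is plain torch.fx and can be shown in isolation. The following self-contained toy (not the test itself) inserts a narrow node after a graph's input placeholder and reroutes its consumer, the same pattern used by the test's insert_narrow helper:

import torch
import torch.fx


class Small(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


gm = torch.fx.symbolic_trace(Small())
graph = gm.graph
x_node = next(iter(graph.nodes))    # the placeholder node for `x`
user = list(x_node.users)[0]        # the add node consuming `x`

# Insert `x.narrow(0, 0, 2)` right after the input, then reroute its consumer.
with graph.inserting_after(x_node):
    shard = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={})
user.args = (shard,) + user.args[1:]
gm.recompile()

out = gm(torch.rand(4, 8))
assert out.shape == torch.Size([2, 8])    # only the first two rows flow through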

Technical Details

Testing infrastructure utilizes:
  • PyTorch framework for tensor operations
  • ColossalAI’s auto_parallel package components
  • Custom DeviceMesh configuration (a 2×2 mesh over four physical devices)
  • Meta device for argument initialization (both shown in the sketch after this list)
  • Graph manipulation utilities
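
Two of these pieces can be shown in isolation with the same calls the test makes: the 2×2 mesh maps four physical device ids onto a logical grid, and meta tensors carry shape and dtype without allocating storage:

import torch

from colossalai.device.device_mesh import DeviceMesh

# Four physical devices (ids 0..3) arranged as a 2x2 logical mesh,
# exactly as in the test below.
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# A meta tensor records shape and dtype but owns no storage, so the
# tracer can propagate shapes without touching real memory.
x = torch.rand(4, 8).to("meta")
assert x.is_meta and x.shape == torch.Size([4, 8])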

Best Practices Demonstrated

The test implementation showcases testing practices that are particularly useful for distributed systems.

Notable practices include:
  • Cache clearing via the @clear_cache_before_run() decorator (a generic sketch follows this list)
  • Systematic graph node manipulation
  • A single end-to-end shape assertion validating the pass's result
  • Structured test setup, with inputs traced on the meta device
  • Isolated, single-process configuration that requires no real multi-device cluster
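
ColossalAI's clear_cache_before_run is applied as a parameterized decorator in the test below. Its internals are not part of this test, but the general shape of such a decorator is easy to sketch (a generic illustration with a stand-in cache, not ColossalAI's actual implementation):

import functools

_CACHE = {}  # stand-in for framework-internal cached state


def clear_cache_before_run():
    # Parameterized decorator: empty the cache, then run the wrapped test.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            _CACHE.clear()  # each test starts from a clean slate
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@clear_cache_before_run()
def test_something():
    assert not _CACHE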

hpcaitech/colossalai

tests/test_auto_parallel/test_pass/test_node_converting_pass.py

import torch

from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run


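# Minimal module under test: reshape a (4, 8) input into (4, 4, 2).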
class TestModule(torch.nn.Module):
    def forward(self, x):
        x = x.view(4, 4, 2)
        return x


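# Insert `x.narrow(0, 0, 2)` right after the input node and reroute the view
# node to consume the narrowed tensor, emulating a single device's local shard.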
def insert_narrow(gm, x_node):
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={})
    view_node = list(x_node.users.keys())[0]
    new_args = list(view_node.args)
    new_args[0] = shard_node
    view_node.args = tuple(new_args)
    return gm


@clear_cache_before_run()
def test_node_args_converting_pass():
    model = TestModule()
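    # Arrange four physical devices (ids 0..3) as a 2x2 logical mesh.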
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {"x": torch.rand(4, 8).to("meta")}
    input = torch.rand(4, 8)
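    # Trace with meta tensors: shapes are recorded without allocating real memory.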
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)

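    # Attach a sharding spec that partitions dim 0 of the (4, 8) input across mesh axis 0.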
    x_node = list(graph.nodes)[0]
    view_node = list(graph.nodes)[1]
    sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, "sharding_spec", sharding_spec)
    setattr(view_node, "sharding_spec", sharding_spec)

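    # Run the pass so the view's target shape matches the sharded input,
    # then insert the narrow to emulate one device's local shard.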
    gm = ColoGraphModule(model, graph)
    gm = node_args_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    output = gm(input)
    assert output.shape == torch.Size([2, 4, 2])


if __name__ == "__main__":
    test_node_args_converting_pass()