Validating Tensor Liveness Analysis in ColossalAI Auto-Parallel System

This test suite validates the liveness analysis functionality in ColossalAI’s auto-parallel tensor sharding system. It specifically examines how the framework handles variable lifetimes and in-place operations during graph analysis.

Test Coverage Overview

The test suite covers the core behaviors of liveness analysis, exercised on a small linear neural network model.

Key areas tested include:
  • Variable lifetime tracking across model stages
  • In-place operation detection and handling
  • Live variable counting and uniqueness verification
  • Stage coverage analysis and overlap detection
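To make the first bullet concrete, here is a hypothetical sketch (not the ColossalAI implementation) of variable lifetime tracking: each variable's live range runs from the step that defines it to the last step that reads it.

```python
# Hypothetical sketch (not ColossalAI's GraphAnalyser): compute live ranges
# for variables in a straight-line sequence of operations.
def live_ranges(ops):
    """ops: list of (output_name, input_names). Returns {var: (first_def, last_use)}."""
    ranges = {}
    for i, (out, inputs) in enumerate(ops):
        ranges.setdefault(out, [i, i])  # variable is born where it is defined
        for name in inputs:
            if name in ranges:
                ranges[name][1] = i  # extend lifetime to this use
    return {var: tuple(r) for var, r in ranges.items()}

# Mirrors LinearModel's forward: x1 flows through mul, linear1, relu, linear2.
ops = [
    ("mul", ["x1"]),
    ("linear1", ["mul"]),
    ("relu", ["linear1"]),
    ("linear2", ["relu"]),
    ("out", ["linear2", "x2"]),
]
print(live_ranges(ops)["mul"])  # (0, 1): defined at step 0, last used at step 1
```

The real analysis works over an FX graph rather than a flat list, but the live-range bookkeeping it performs per variable is analogous.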

Implementation Analysis

The testing approach uses a simple LinearModel with two linear layers and ReLU activation to validate liveness analysis.

Technical implementation includes:
  • ColoTracer for graph construction with bias addition splitting
  • Meta tensor-based argument handling
  • Shape propagation through the graph module
  • GraphAnalyser for liveness computation
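The "bias addition splitting" in the first bullet can be illustrated with a toy sketch (names and node format are hypothetical, not ColossalAI's node representation): a single linear node is decomposed into a matmul node plus an add node, so the solver can shard the weight and bias terms independently.

```python
# Illustrative sketch only: what bias addition splitting means conceptually.
def split_bias_addition(node):
    """node: ("linear", input, weight, bias) -> two finer-grained nodes."""
    _, x, w, b = node
    # xW^T + b becomes an explicit matmul followed by an explicit add
    return [("matmul", x, w), ("add", f"matmul({x}, {w})", b)]

print(split_bias_addition(("linear", "x1", "w1", "b1")))
# -> [('matmul', 'x1', 'w1'), ('add', 'matmul(x1, w1)', 'b1')]
```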

Technical Details

Testing infrastructure utilizes:
  • PyTest framework with a skip decorator for version compatibility (meta tensors are buggy under PyTorch 1.11)
  • Torch meta tensors for graph analysis
  • ColoGraphModule for graph manipulation
  • Custom cache clearing decorators
  • GraphAnalyser for stage analysis
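The test's final assertion hinges on the distinction between all live variables and unique live variables. The sketch below (hypothetical data, not the real LiveVariableVector API) shows why an in-place op makes the unique count exactly one smaller: the in-place ReLU reuses its input's storage, so two graph variables alias one buffer.

```python
# Hypothetical sketch: each graph variable paired with the buffer it occupies.
vars_with_storage = [
    ("mul", "buf0"),
    ("linear1", "buf1"),
    ("relu", "buf1"),  # in-place ReLU writes into linear1's buffer
    ("linear2", "buf2"),
    ("out", "buf3"),
]
all_live = [name for name, _ in vars_with_storage]
unique_live = {storage for _, storage in vars_with_storage}
print(len(all_live), len(unique_live))  # 5 4: exactly one in-place alias
assert len(unique_live) + 1 == len(all_live)  # mirrors the test's assertion
```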

Best Practices Demonstrated

The test exemplifies several testing best practices:

  • Isolated test environment with cache clearing
  • Clear test assertions with specific checks
  • Validation of both quantities (stage and live-variable counts) and properties (the in-place flag)
  • Proper handling of meta tensors and device placement
  • Structured model definition separate from test logic
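The cache-clearing isolation in the first bullet follows a common decorator pattern. This is a hypothetical stand-in for `colossalai.testing.clear_cache_before_run`, not its actual implementation: shared state is reset before each invocation so one test's traced graphs cannot leak into the next.

```python
import functools

# Hypothetical shared cache; stands in for whatever state the real
# decorator clears before a test runs.
_CACHE = {"traced_graphs": ["stale entry"]}

def clear_cache_before_run():
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            _CACHE["traced_graphs"].clear()  # fresh state per run
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@clear_cache_before_run()
def run_test():
    return len(_CACHE["traced_graphs"])

print(run_test())  # 0: the stale entry was cleared before the test body ran
```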

hpcaitech/colossalai

tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py

import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.testing import clear_cache_before_run


class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x1, x2):
        x1 = x1 * 2
        x1 = self.linear1(x1)
        x1 = self.relu(x1)
        x1 = self.linear2(x1)
        out = x1 + x2
        return out


@pytest.mark.skip("meta tensor has some bugs in 1.11")
@clear_cache_before_run()
def test_liveness_analysis():
    model = LinearModel()
    tracer = ColoTracer(bias_addition_split=True)
    meta_args = {"x1": torch.rand(4, 4, device="meta"), "x2": torch.rand(4, 4, device="meta")}
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())

    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    stage_count = len(liveness_list)

    # if a LiveStage is covered by another LiveStage, we just keep the larger one.
    assert stage_count == 1

    # a variable named `relu` must exist
    # and this live var must have inplace = True
    assert liveness_list[0].all_live_vars.exists("relu")
    relu_var = liveness_list[0].all_live_vars.get("relu")
    assert relu_var.is_inplace

    # the unique vars must be fewer than the all vars since in-place ops exist
    all_live_vars = liveness_list[0].all_live_vars
    unique_live_vars = liveness_list[0].unique_live_vars
    assert len(unique_live_vars) + 1 == len(all_live_vars)


if __name__ == "__main__":
    test_liveness_analysis()