Testing ColoProxy Tensor Operations in ColossalAI

This test suite validates the ColoProxy functionality within ColossalAI’s FX-based tracing system, focusing on tensor metadata handling and shape inference capabilities.

Test Coverage Overview

The test coverage focuses on validating core tensor proxy operations and metadata handling in the ColoProxy class. Key areas tested include:

  • Tensor shape and dimension access
  • Length operations on proxy objects
  • Data type verification
  • Size querying functionality

The test ensures proper integration between PyTorch’s FX graph system and ColossalAI’s custom proxy implementation; the sketch below illustrates the kind of metadata queries being exercised.
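
As a point of reference, the snippet below is a minimal standalone sketch (not part of the test file) showing the same four metadata queries performed directly on a PyTorch meta tensor; ColoProxy is expected to mirror this behavior through its meta_data attribute.

import torch

# A meta tensor carries shape and dtype information but allocates no storage.
t = torch.empty(4, 2, device="meta")

assert len(t) == 4                           # length of the leading dimension
assert t.shape[0] == 4 and t.shape[1] == 2   # shape and dimension access
assert t.dim() == 2                          # number of dimensions
assert t.dtype == torch.float32              # data type verification
assert t.size(0) == 4                        # size querying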

Implementation Analysis

The testing approach uses a simple Conv1D module (a GPT-2-style linear projection rather than an actual convolution) as the test case to generate a traced GraphModule. The implementation leverages PyTorch’s meta device for memory-efficient testing and combines FX tracing with ColossalAI’s custom proxy objects.

The test demonstrates proper handling of meta tensors and validates that the proxy object’s behavior matches the expected tensor operations.
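
To illustrate why the meta device keeps this memory-efficient, the standalone sketch below (illustrative names, not taken from the test) shows that operations on meta tensors propagate shapes and dtypes without allocating real data:

import torch

# Meta tensors describe shape and dtype only; no values are ever materialized.
a = torch.empty(3, 4, device="meta")
b = torch.empty(4, 5, device="meta")

c = a @ b  # shape inference still runs: (3, 4) @ (4, 5) -> (3, 5)

assert c.is_meta
assert c.shape == (3, 5)
assert c.dtype == torch.float32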

Technical Details

  • Testing Framework: PyTest
  • Key Components: ColoTracer, ColoProxy, GraphModule
  • Dependencies: PyTorch, ColossalAI
  • Test Setup: Uses @clear_cache_before_run decorator
  • Environment: Meta device for tensor operations

Best Practices Demonstrated

The test exhibits several testing best practices, including isolated test cases, proper cleanup via cache clearing (a sketch of this pattern follows the list below), and comprehensive assertion checking. It validates both basic functionality and metadata handling, ensuring the proxy system correctly mirrors tensor behavior.

  • Clean test isolation
  • Comprehensive assertion coverage
  • Memory-efficient testing with meta device
  • Clear test structure and organization
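
For context on the cache-clearing step, the sketch below shows one hypothetical way such a decorator could be written; it illustrates the pattern only and is not ColossalAI’s actual clear_cache_before_run implementation.

import functools
import gc

import torch


def clear_cache_before_run_sketch():
    """Hypothetical stand-in for a cache-clearing test decorator."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Free Python-level garbage and any cached CUDA memory before the test runs.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return func(*args, **kwargs)

        return wrapper

    return decorator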

hpcaitech/colossalai

tests/test_fx/test_coloproxy.py

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import clear_cache_before_run


# GPT-2-style Conv1D: effectively a linear layer whose weight is stored as (nx, nf).
class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.shape[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


@clear_cache_before_run()
def test_coloproxy():
    # Trace a small Conv1D model with meta inputs to produce an FX graph.
    tracer = ColoTracer()
    model = Conv1D(3, 3)
    input_sample = {"x": torch.rand(3, 3).to("meta")}

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # Wrap the first graph node (the input placeholder) in a ColoProxy and
    # attach a meta tensor as its metadata.
    node = list(gm.graph.nodes)[0]
    proxy = ColoProxy(node=node, tracer=tracer)
    proxy.meta_data = torch.empty(4, 2, device="meta")

    # The proxy should mirror the shape, rank, dtype, and size of its metadata.
    assert len(proxy) == 4
    assert proxy.shape[0] == 4 and proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4


if __name__ == "__main__":
    test_coloproxy()