Back to Repositories

Testing Bias Addition Module Operations in ColossalAI

This test suite validates how ColoTracer and ColoGraphModule in the ColossalAI framework handle modules that add a bias. It traces a linear model and a convolutional model on meta tensors and asserts that the nodes of the traced graph (weight, bias, the bias-free linear/convolution call, and the explicit bias addition) carry metadata with the expected shapes.

Test Coverage Overview

The test suite provides comprehensive coverage of bias addition operations across different neural network architectures.

  • Tests linear layer bias addition with shape validation
  • Tests convolutional layer bias addition with dimension broadcasting
  • Verifies that every non-output node in the traced graph carries metadata (`_meta_data`)
  • Validates shape consistency through the computational graph

Implementation Analysis

The implementation uses ColoTracer to create computational graphs and validates their transformation.

Key patterns include:
  • Graph-based representation of neural network operations
  • Meta-data tracking for tensor shapes
  • Explicit validation of intermediate tensor dimensions
  • Memory-efficient testing using meta device tensors

Technical Details

Testing infrastructure includes:

  • PyTorch framework for tensor operations
  • ColoTracer for computational graph generation
  • ColoGraphModule for graph compilation
  • Clear cache decorator for memory management
  • Meta device usage for efficient shape testing

Best Practices Demonstrated

The test suite exemplifies robust testing practices for deep learning frameworks.

  • Systematic validation of shape transformations
  • Comprehensive assertion checks
  • Memory-efficient testing approaches
  • Clear separation of test cases
  • Explicit documentation of expected graph structures

hpcaitech/colossalai

tests/test_fx/test_tracer/test_bias_addition_module.py

            
import torch

from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run


class LinearModel(torch.nn.Module):
    """A single ``torch.nn.Linear`` layer whose output is multiplied by 2."""

    def __init__(self, in_features, out_features):
        """Create the wrapped linear layer.

        Args:
            in_features: Forwarded to ``torch.nn.Linear`` as ``in_features``.
            out_features: Forwarded to ``torch.nn.Linear`` as ``out_features``.
        """
        super().__init__()
        self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result times 2."""
        projected = self.linear(x)
        return projected * 2


class ConvModel(torch.nn.Module):
    """A single ``torch.nn.Conv2d`` layer whose output is multiplied by 2."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        """Create the wrapped convolution layer.

        Args:
            in_channels: Forwarded to ``torch.nn.Conv2d`` as ``in_channels``.
            out_channels: Forwarded to ``torch.nn.Conv2d`` as ``out_channels``.
            kernel_size: Forwarded to ``torch.nn.Conv2d`` as ``kernel_size``.
            bias: Forwarded to ``torch.nn.Conv2d`` as ``bias``.
        """
        super().__init__()
        conv_options = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "bias": bias,
        }
        self.conv = torch.nn.Conv2d(**conv_options)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result times 2."""
        features = self.conv(x)
        return features * 2


@clear_cache_before_run()
def test_linear_module():
    """Trace ``LinearModel`` on a meta tensor and check the ``_meta_data`` shapes of its nodes."""
    model = LinearModel(3, 6)
    meta_input = torch.rand(3, 3).to("meta")

    # Expected traced graph (the bias is applied by a separate add node):
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = ColoTracer().trace(root=model, meta_args={"x": meta_input})

    # Expected generated code after recompiling:
    # def forward(self, x : torch.Tensor):
    #     linear_weight = self.linear.weight
    #     linear_bias = self.linear.bias
    #     linear = torch._C._nn.linear(x, linear_weight);  x = linear_weight = None
    #     add = linear + linear_bias;  linear = linear_bias = None
    #     mul = add * 2;  add = None
    #     return mul
    graph_module = ColoGraphModule(model, graph)
    graph_module.recompile()

    nodes = list(graph.nodes)

    # Every node except the output node must expose `_meta_data`.
    for node in nodes:
        if node.op != "output":
            assert hasattr(node, "_meta_data")

    # Positions follow the expected graph above: index 0 is the `x` placeholder,
    # then weight, bias, the bias-free linear call and the explicit add.
    checked_nodes = [nodes[position] for position in (1, 2, 3, 4)]
    expected_shapes = [(6, 3), (6,), (3, 6), (3, 6)]
    for node, shape in zip(checked_nodes, expected_shapes):
        assert node._meta_data.shape == shape


@clear_cache_before_run()
def test_conv_module():
    """Trace ``ConvModel`` on a meta tensor and check the ``_meta_data`` shapes of its nodes."""
    model = ConvModel(3, 6, 2)
    meta_input = torch.rand(4, 3, 64, 64).to("meta")

    # Expected traced graph (the bias is reshaped with `view`, then applied by a separate add node):
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = ColoTracer().trace(root=model, meta_args={"x": meta_input})

    # Expected generated code after recompiling:
    # def forward(self, x : torch.Tensor):
    #     conv_weight = self.conv.weight
    #     conv_bias = self.conv.bias
    #     conv2d = torch.conv2d(x, conv_weight);  x = conv_weight = None
    #     view = conv_bias.view([1, -1, 1, 1]);  conv_bias = None
    #     add = conv2d + view;  conv2d = view = None
    #     mul = add * 2;  add = None
    #     return mul
    graph_module = ColoGraphModule(model, graph)
    graph_module.recompile()

    nodes = list(graph.nodes)

    # Every node except the output node must expose `_meta_data`.
    for node in nodes:
        if node.op != "output":
            assert hasattr(node, "_meta_data")

    # Positions follow the expected graph above: index 0 is the `x` placeholder,
    # then weight, bias, the bias-free conv2d call, the bias view and the explicit add.
    # NOTE(review): the graph comment shows `view([1, -1, 1, 1])`, which would yield
    # (1, 6, 1, 1), while the view node is asserted to be (6, 1, 1) -- one of the two
    # looks stale; confirm against what the tracer actually emits.
    checked_nodes = [nodes[position] for position in (1, 2, 3, 4, 5)]
    expected_shapes = [(6, 3, 2, 2), (6,), (4, 6, 63, 63), (6, 1, 1), (4, 6, 63, 63)]
    for node, shape in zip(checked_nodes, expected_shapes):
        assert node._meta_data.shape == shape


if __name__ == "__main__":
    # Run both test cases in order when the file is executed as a script.
    for run_case in (test_linear_module, test_conv_module):
        run_case()