
Testing Tensor Broadcasting and Sharding Operations in ColossalAI

This test suite validates broadcasting operations and tensor sharding specifications in ColossalAI’s auto-parallel system. It focuses on testing tensor broadcasting compatibility, shape calculations, and sharding spec recovery for distributed tensor operations.

Test Coverage Overview

The test suite covers three main areas of broadcasting functionality:
  • Broadcasting compatibility validation between tensor shapes
  • Broadcast shape calculations for tensor pairs
  • Sharding specification recovery for broadcast operations
The compatibility test asserts both broadcastable shape pairs and one non-broadcastable pair, (4, 2, 8) against (4, 8).
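The compatibility rule these tests exercise can be illustrated with a minimal pure-Python sketch. It assumes that `is_broadcastable` follows the usual NumPy/PyTorch broadcasting convention, which is consistent with the three assertions in the test file. The function name `is_broadcastable_sketch` is hypothetical and is not ColossalAI's implementation.

```python
from itertools import zip_longest


def is_broadcastable_sketch(shape_a, shape_b):
    """Return True if two shapes are compatible under NumPy-style broadcasting.

    Hypothetical helper for illustration only; not the ColossalAI utility.
    """
    # Compare dimensions from the trailing end; a missing leading
    # dimension is treated as size 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        # Two sizes are compatible when they are equal or one of them is 1.
        if a != b and a != 1 and b != 1:
            return False
    return True


# The same shape pairs as in test_is_broadcastable
print(is_broadcastable_sketch((4, 4, 8), (1, 8)))
print(is_broadcastable_sketch((4, 2, 8), (2, 8)))
print(is_broadcastable_sketch((4, 2, 8), (4, 8)))
```

The last pair fails because, after right-alignment, size 2 is compared against size 4 and neither is 1.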

Implementation Analysis

Each test builds PyTorch tensors with torch.rand, passes their shapes to the utility under test, and asserts on the result. The sharding spec test additionally constructs a DeviceMesh over four devices, shards the broadcast (logical) shape, and checks the spec recovered for the original (physical) tensor shape.
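As a rough illustration of the broadcast shape calculation, the following sketch computes the result shape for two operands under the standard broadcasting rules. The name `broadcast_shape_sketch` is hypothetical, and returning a list is an assumption that mirrors the list comparisons in the test file; the real `get_broadcast_shape` may differ in its details.

```python
from itertools import zip_longest


def broadcast_shape_sketch(shape_a, shape_b):
    """Compute the broadcast result shape of two shapes as a list.

    Hypothetical helper for illustration only; not the ColossalAI utility.
    """
    result = []
    # Right-align the shapes; a missing leading dimension counts as size 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"shapes {tuple(shape_a)} and {tuple(shape_b)} are not broadcastable")
        # A size-1 dimension stretches to match the other operand.
        result.append(b if a == 1 else a)
    # Dimensions were collected from the trailing end, so reverse them.
    return result[::-1]


# The shape pair used in test_recover_sharding_spec_for_broadcast_shape
print(broadcast_shape_sketch((4, 1, 8), (2, 8)))
```

For that pair, the size-1 middle dimension of the first operand stretches to 2 and the second operand gains a leading dimension of 4.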

Technical Details

Key technical components include:
  • PyTorch tensor operations and shape manipulation
  • DeviceMesh configuration with 2×2 mesh topology
  • ShardingSpec implementations for distributed tensors
  • Broadcast shape recovery utilities
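The idea behind sharding spec recovery can be sketched without ColossalAI at all. The sketch below assumes that a partition on a logical dimension is dropped whenever that dimension does not exist in the physical tensor or has a different size there (a size-1 dimension that was stretched), which matches the expectation in the test file. The helper names and the plain-dict representation are hypothetical and only mimic `dim_partition_dict` and `sharding_sequence`.

```python
def recover_partition_sketch(logical_partition, logical_shape, physical_shape):
    """Map a partition dict on the broadcast shape back to the physical shape.

    Hypothetical helper for illustration only; not the ColossalAI utility.
    """
    # Number of leading dimensions that broadcasting added to the physical tensor.
    offset = len(logical_shape) - len(physical_shape)
    physical_partition = {}
    for logical_dim, mesh_axes in logical_partition.items():
        physical_dim = logical_dim - offset
        if physical_dim < 0:
            continue  # dimension exists only in the broadcast shape
        if physical_shape[physical_dim] != logical_shape[logical_dim]:
            continue  # size-1 dimension was stretched, so it cannot be sharded
        physical_partition[physical_dim] = list(mesh_axes)
    return physical_partition


def sharding_sequence_sketch(partition, num_dims):
    """Render a partition dict as per-dimension labels: 'S<axes>' for sharded, 'R' for replicated."""
    return [
        "S" + "".join(str(axis) for axis in partition[dim]) if dim in partition else "R"
        for dim in range(num_dims)
    ]


# Mirrors the test: broadcast shape [4, 2, 8], physical shape (4, 1, 8)
partition = recover_partition_sketch({0: [0], 1: [1]}, [4, 2, 8], (4, 1, 8))
print(partition, sharding_sequence_sketch(partition, 3))
```

Here dimension 1 has size 1 in the physical tensor, so only the partition of dimension 0 along mesh axis 0 survives and the remaining dimensions are replicated.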

Best Practices Demonstrated

The test suite demonstrates strong testing practices including:
  • Isolated test functions for specific functionality
  • Edge cases such as operands of different rank and size-1 dimensions
  • Clear test case organization
  • Validation of both successful and failure scenarios

hpcaitech/colossalai

tests/test_auto_parallel/test_tensor_shard/test_broadcast.py

            
import torch

from colossalai.auto_parallel.tensor_shard.utils import (
    get_broadcast_shape,
    is_broadcastable,
    recover_sharding_spec_for_broadcast_shape,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec


def test_is_broadcastable():
    x1 = torch.rand(4, 4, 8)
    x2 = torch.rand(1, 8)
    assert is_broadcastable(x1.shape, x2.shape)

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(2, 8)
    assert is_broadcastable(x1.shape, x2.shape)

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(4, 8)
    assert not is_broadcastable(x1.shape, x2.shape)


def test_get_broadcast_shape():
    x1 = torch.rand(4, 4, 8)
    x2 = torch.rand(1, 8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 4, 8]

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(2, 8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]


def test_recover_sharding_spec_for_broadcast_shape():
    x1 = torch.rand(4, 1, 8)
    x2 = torch.rand(2, 8)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    broadcast_shape = get_broadcast_shape(x1.shape, x2.shape)
    logical_sharding_spec_for_x1 = ShardingSpec(
        device_mesh=device_mesh, dim_partition_dict={0: [0], 1: [1]}, entire_shape=broadcast_shape
    )
    physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
        logical_sharding_spec_for_x1, broadcast_shape, x1.shape
    )
    print(physical_sharding_spec_for_x1)

    assert physical_sharding_spec_for_x1.entire_shape == x1.shape
    # dim 1 of the physical tensor has size 1 (broadcast type MULTIPLE), so its sharding is dropped
    assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]}
    assert physical_sharding_spec_for_x1.sharding_sequence == ["S0", "R", "R"]