Testing PyTorch Meta Tensor Module Patching in ColossalAI
This test suite validates the patched module functionality in ColossalAI’s FX tracer implementation, focusing on various PyTorch neural network components and their meta-tensor behavior. The tests ensure correct shape handling and exception management across different neural network layers and operations.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
hpcaitech/colossalai
tests/test_fx/test_tracer/test_patched_module.py
import torch
from colossalai.fx.tracer.meta_patch import patched_module
from colossalai.testing import clear_cache_before_run
def _run(data, module, patch_fn):
try:
if isinstance(data, dict):
output = patch_fn(module, **data)
if isinstance(data, tuple) or isinstance(data, list):
output = patch_fn(module, *data)
else:
output = patch_fn(module, data)
return output
except Exception as e:
return e
def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
    """Run ``patch_fn`` through ``_run`` and validate what comes back.

    Args:
        data: input(s) forwarded to ``_run``.
        module: the module forwarded to ``_run``.
        patch_fn: the patch function forwarded to ``_run``.
        expect_exception: when true, the call must have produced an
            ``AssertionError`` and no shape check is performed.
        output_shape: the expected shape; a sequence of shapes when the patch
            returns a tuple. Ignored when ``expect_exception`` is true.
    """
    output = _run(data, module, patch_fn)

    if expect_exception:
        assert isinstance(output, AssertionError), f"expected an AssertionError, got {output!r}"
        return

    assert not isinstance(output, Exception), f"patch function raised unexpectedly: {output!r}"

    if isinstance(output, tuple):
        # zip() stops at the shorter sequence, so a missing or extra output
        # would otherwise go unnoticed.
        assert len(output) == len(output_shape), (
            f"expected {len(output_shape)} outputs, got {len(output)}"
        )
        for item, shape in zip(output, output_shape):
            assert item.is_meta
            assert item.shape == shape
    else:
        assert output.is_meta
        assert output.shape == output_shape
@clear_cache_before_run()
def test_linear():
    """Check the linear patch with a matching and a mismatching input width."""
    linear = torch.nn.Linear(4, 2)
    patch = patched_module.torch_nn_linear

    # last dimension equals in_features: a meta output of shape (2, 2) is required
    matching_input = torch.rand(2, 4, device="meta")
    _assert_output_shape(
        data=matching_input,
        module=linear,
        patch_fn=patch,
        expect_exception=False,
        output_shape=torch.Size([2, 2]),
    )

    # last dimension differs from in_features: an AssertionError is required
    mismatching_input = torch.rand(2, 2, device="meta")
    _assert_output_shape(
        data=mismatching_input,
        module=linear,
        patch_fn=patch,
        expect_exception=True,
        output_shape=None,
    )
@clear_cache_before_run()
def test_rnn():
    """Check the RNN patch against a real forward pass and a mismatching input."""
    rnn = torch.nn.RNN(10, 20, 2)
    patch = patched_module.torch_nn_rnn

    # The expected shapes are taken from a materialized forward pass.
    output, hn = rnn(torch.randn(5, 3, 10), torch.randn(2, 3, 20))
    meta_data = (torch.randn(5, 3, 10).to("meta"), torch.randn(2, 3, 20).to("meta"))
    _assert_output_shape(meta_data, rnn, patch, False, (output.shape, hn.shape))

    # The last input dimension (1) does not match the RNN's input size (10),
    # so an AssertionError is required. The same module is reused; rebuilding
    # it and re-running a real forward pass added nothing to this check.
    bad_meta_data = (torch.randn(5, 3, 1).to("meta"), torch.randn(2, 3, 20).to("meta"))
    _assert_output_shape(bad_meta_data, rnn, patch, True, None)
@clear_cache_before_run()
def test_embedding():
    """Check the normalization patch on several normalization layers.

    Despite its name, this test exercises ``patched_module.torch_nn_normalize``
    with LayerNorm, GroupNorm and BatchNorm1d/2d/3d modules.
    """
    normalize_patch = patched_module.torch_nn_normalize

    # layer norm and group norm are checked with the same 2-D meta input
    shared_input = torch.rand(2, 4, device="meta")
    for norm_layer in (torch.nn.LayerNorm(4), torch.nn.GroupNorm(4, num_channels=8)):
        _assert_output_shape(shared_input, norm_layer, normalize_patch, False, shared_input.shape)

    # Each entry: (module, [(input dims, whether an AssertionError is required), ...])
    batch_norm_cases = (
        # batch norm 1d
        (
            torch.nn.BatchNorm1d(4),
            (((2, 4), False), ((2, 4), False), ((2, 3, 4), False), ((1, 2, 3, 4), True)),
        ),
        # batch norm 2d
        (
            torch.nn.BatchNorm2d(4),
            (((1, 2, 3, 4), False), ((2, 3, 4), True)),
        ),
        # batch norm 3d
        (
            torch.nn.BatchNorm3d(4),
            (((1, 1, 2, 3, 4), False), ((1, 2, 3, 4), True)),
        ),
    )
    for batch_norm, cases in batch_norm_cases:
        for dims, must_fail in cases:
            meta_input = torch.rand(*dims, device="meta")
            _assert_output_shape(
                data=meta_input,
                module=batch_norm,
                patch_fn=normalize_patch,
                expect_exception=must_fail,
                # on success the output must keep the input's shape
                output_shape=None if must_fail else meta_input.shape,
            )
@clear_cache_before_run()
def test_conv1d():
    """Compare the Conv1d patch's meta output shape with a real forward pass."""
    sample = torch.rand(2, 3, 4)

    # in_channels / out_channels / kernel_size are shared by every case;
    # only these extra keyword arguments vary.
    extra_kwargs_per_case = (
        {},
        {"padding": 1},
        {"padding": 1, "dilation": 2, "padding_mode": "reflect"},
    )
    for extra_kwargs in extra_kwargs_per_case:
        conv = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        expected_shape = conv(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=conv,
            patch_fn=patched_module.torch_nn_conv1d,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_conv2d():
    """Compare the Conv2d patch's meta output shape with a real forward pass.

    Decorated with ``clear_cache_before_run`` like every other test in this
    module; it was the only one missing the decorator.
    """
    data = torch.rand(2, 3, 4, 4)

    # in_channels / out_channels / kernel_size are shared by every case;
    # only these extra keyword arguments vary.
    extra_kwargs_per_case = (
        {},
        {"padding": 1},
        {"padding": 1, "dilation": 2},
        {"padding": 1, "dilation": 2, "padding_mode": "reflect"},
    )
    for extra_kwargs in extra_kwargs_per_case:
        conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        materialized_output = conv2d(data)
        # NOTE(review): the materialized tensor (not a meta copy) is handed to
        # the patch here, exactly as before; only the output is required to be meta.
        _assert_output_shape(
            data=data,
            module=conv2d,
            patch_fn=patched_module.torch_nn_conv2d,
            expect_exception=False,
            output_shape=materialized_output.shape,
        )
@clear_cache_before_run()
def test_conv3d():
    """Compare the Conv3d patch's meta output shape with a real forward pass."""
    sample = torch.rand(2, 3, 4, 4, 4)

    # in_channels / out_channels / kernel_size are shared by every case;
    # only these extra keyword arguments vary.
    extra_kwargs_per_case = (
        {},
        {"padding": 1},
        {"padding": 1, "dilation": 2},
        {"padding": 1, "dilation": 2, "padding_mode": "reflect"},
    )
    for extra_kwargs in extra_kwargs_per_case:
        conv = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        expected_shape = conv(sample).shape
        _assert_output_shape(
            data=sample,
            module=conv,
            patch_fn=patched_module.torch_nn_conv3d,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_conv_transpose1d():
    """Compare the ConvTranspose1d patch's meta output shape with a real forward pass."""
    sample = torch.rand(2, 3, 4)

    # first without padding, then with padding=1
    for extra_kwargs in ({}, {"padding": 1}):
        deconv = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        expected_shape = deconv(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=deconv,
            patch_fn=patched_module.torch_nn_convtranspose1d,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_conv_transpose2d():
    """Compare the ConvTranspose2d patch's meta output shape with a real forward pass."""
    sample = torch.rand(2, 3, 4, 4)

    # first without padding, then with padding=1
    for extra_kwargs in ({}, {"padding": 1}):
        deconv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        expected_shape = deconv(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=deconv,
            patch_fn=patched_module.torch_nn_convtranspose2d,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_conv_transpose3d():
    """Compare the ConvTranspose3d patch's meta output shape with a real forward pass."""
    sample = torch.rand(2, 3, 4, 4, 4)

    # first without padding, then with padding=1
    for extra_kwargs in ({}, {"padding": 1}):
        deconv = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        expected_shape = deconv(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=deconv,
            patch_fn=patched_module.torch_nn_convtranspose3d,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_pool1d():
    """Check the 1-D max/avg pooling patches on accepted and rejected input ranks."""
    layer_and_patch = (
        (torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d),
        (torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d),
    )
    accepted_dims = ((2, 3, 4), (2, 4))

    for layer_cls, patch_func in layer_and_patch:
        pooler = layer_cls(kernel_size=3)

        # 3-D and 2-D inputs: the patch output must match the real pooler's shape
        for dims in accepted_dims:
            sample = torch.rand(*dims)
            _assert_output_shape(
                data=sample,
                module=pooler,
                patch_fn=patch_func,
                expect_exception=False,
                output_shape=pooler(sample).shape,
            )

        # 4-D input: an AssertionError is required
        sample = torch.rand(2, 3, 4, 4)
        _assert_output_shape(
            data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
        )
@clear_cache_before_run()
def test_pool2d():
    """Check the 2-D max/avg pooling patches on accepted and rejected input ranks."""
    layer_and_patch = (
        (torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d),
        (torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d),
    )
    accepted_dims = ((2, 3, 4, 4), (2, 4, 4))

    for layer_cls, patch_func in layer_and_patch:
        pooler = layer_cls(kernel_size=3)

        # 4-D and 3-D inputs: the patch output must match the real pooler's shape
        for dims in accepted_dims:
            sample = torch.rand(*dims)
            _assert_output_shape(
                data=sample,
                module=pooler,
                patch_fn=patch_func,
                expect_exception=False,
                output_shape=pooler(sample).shape,
            )

        # 5-D input: an AssertionError is required
        sample = torch.rand(2, 3, 4, 4, 4)
        _assert_output_shape(
            data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
        )
@clear_cache_before_run()
def test_pool3d():
    """Check the 3-D max/avg pooling patches on accepted and rejected input ranks."""
    layer_and_patch = (
        (torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d),
        (torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d),
    )
    accepted_dims = ((2, 3, 4, 4, 4), (2, 4, 4, 4))

    for layer_cls, patch_func in layer_and_patch:
        pooler = layer_cls(kernel_size=3)

        # 5-D and 4-D inputs: the patch output must match the real pooler's shape
        for dims in accepted_dims:
            sample = torch.rand(*dims)
            _assert_output_shape(
                data=sample,
                module=pooler,
                patch_fn=patch_func,
                expect_exception=False,
                output_shape=pooler(sample).shape,
            )

        # 3-D input: an AssertionError is required
        sample = torch.rand(2, 3, 4)
        _assert_output_shape(
            data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
        )
# adaptive pooling is different from other pooling, so test it individually
@clear_cache_before_run()
def test_adaptive_pooling_1d():
    """Check the 1-D adaptive pooling patch on accepted and rejected input ranks."""
    pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_1d

    # 2-D and 3-D inputs: the patch output must match the real pooler's shape
    for dims in ((3, 4), (2, 3, 4)):
        sample = torch.rand(*dims)
        expected_shape = pooler(sample).shape
        _assert_output_shape(
            data=sample,
            module=pooler,
            patch_fn=patch_func,
            expect_exception=False,
            output_shape=expected_shape,
        )

    # 4-D input: an AssertionError is required
    sample = torch.rand(2, 3, 4, 5)
    _assert_output_shape(
        data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
    )
@clear_cache_before_run()
def test_adaptive_pooling_2d():
    """Check the 2-D adaptive pooling patch on rejected and accepted input ranks."""
    pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_2d

    # 2-D input: an AssertionError is required
    sample = torch.rand(3, 4)
    _assert_output_shape(
        data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
    )

    # 3-D and 4-D inputs: the patch output must match the real pooler's shape
    for dims in ((2, 3, 4), (2, 3, 4, 5)):
        sample = torch.rand(*dims)
        expected_shape = pooler(sample).shape
        _assert_output_shape(
            data=sample,
            module=pooler,
            patch_fn=patch_func,
            expect_exception=False,
            output_shape=expected_shape,
        )
@clear_cache_before_run()
def test_adaptive_pooling_3d():
    """Check the 3-D adaptive pooling patch on rejected and accepted input ranks."""
    pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_3d

    # 3-D input: an AssertionError is required
    sample = torch.rand(3, 4, 5)
    _assert_output_shape(
        data=sample, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None
    )

    # 4-D and 5-D inputs: the patch output must match the real pooler's shape
    for dims in ((2, 3, 4, 5), (2, 3, 4, 5, 6)):
        sample = torch.rand(*dims)
        expected_shape = pooler(sample).shape
        _assert_output_shape(
            data=sample,
            module=pooler,
            patch_fn=patch_func,
            expect_exception=False,
            output_shape=expected_shape,
        )