Testing FP8 Quantization Hook Implementation in ColossalAI
This test suite validates the FP8 hook functionality in Colossal-AI, focusing on the parameter-op hook mechanism that rewrites linear operations to their FP8 quantized implementation. It verifies that a call to F.linear on a ColoParameter weight is dispatched to the linear_fp8 kernel and produces an output of the expected shape.
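The mechanism under test: while ColoParamOpHookManager.use_hooks(hook) is active, the functional op about to run on a ColoParameter is passed through the hook's rewrite_op method, which may return a replacement callable. Below is a minimal sketch of that contract, with a hypothetical class name and body used purely for illustration; the shipped FP8Hook implements the real mapping.

import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8


class SketchHook:  # hypothetical illustration, not the shipped FP8Hook
    def rewrite_op(self, func):
        # Map F.linear to the FP8 kernel; leave every other op untouched.
        if func is F.linear:
            return linear_fp8
        return func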
Source: hpcaitech/colossalai, tests/test_fp8/test_fp8_hook.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

# Flags flipped during the test to prove (a) the hook rewrote the op and
# (b) the rewritten op actually ran.
REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
    # Wrapper around linear_fp8 that records that the substituted op executed.
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        # Let the base FP8Hook pick its replacement; if it chose linear_fp8,
        # swap in the instrumented wrapper instead.
        func = super().rewrite_op(func)
        if func is linear_fp8:
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    # create tensors
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    # Promote the weight to ColoParameter in place so the param-op hook sees it.
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED
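Two details are worth noting. The weight is promoted to ColoParameter in place by reassigning __class__ and re-running __init__; the param-op hook machinery tracks ColoParameters, so this step is what makes the F.linear call visible to the hook. The REPLACED and TRIGGERED flags separate the two claims being verified: that FP8Hook.rewrite_op selected linear_fp8 for the linear op, and that the substituted kernel actually executed. The skipif guard restricts the test to devices with compute capability 9.0 or higher (NVIDIA Hopper and newer).

In practice the hook is meant to be enabled around a model's forward pass rather than a single functional call. The following is a minimal usage sketch, assuming weights already promoted to ColoParameter and an FP8-capable device; it is an illustration, not an excerpt from Colossal-AI's documentation.

import torch
import torch.nn as nn

from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

# A single linear layer whose weight is promoted to ColoParameter using the
# same in-place __class__ swap as the test above.
layer = nn.Linear(16, 32, bias=False, device=get_current_device(), dtype=torch.bfloat16)
layer.weight.__class__ = ColoParameter
layer.weight.__init__(layer.weight, requires_grad=True)

x = torch.rand(2, 64, 16, device=get_current_device(), dtype=torch.bfloat16)
with ColoParamOpHookManager.use_hooks(FP8Hook()):
    y = layer(x)  # F.linear inside the layer is rewritten to linear_fp8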