
Testing FP8 Quantization Hook Implementation in ColossalAI

This test file validates the FP8 hook in Colossal-AI, which rewrites parameter operations so that linear layers run through an FP8 quantized kernel. Its single test, test_fp8_hook, checks that while an FP8Hook subclass is active, a call to F.linear with a ColoParameter weight is routed to linear_fp8 and still returns a tensor of the expected shape.

Test Coverage Overview

The suite consists of one test, test_fp8_hook, which exercises the FP8 hook's operation rewriting for linear layers.

Key areas tested include:
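  • A custom FP8TestHook subclass of FP8Hook that overrides rewrite_op
  • Rewriting of F.linear to linear_fp8, recorded through the REPLACED flag
  • Output shape validation: the result must have shape (B, S, D_OUT) = (2, 64, 32)
  • Confirmation that the replacement function actually ran, recorded through the TRIGGERED flag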
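F.linear computes x @ w.T (plus an optional bias), so only the last dimension of the input is contracted against the weight's second dimension and all leading dimensions pass through. The shape the test asserts can be derived with a small helper; `linear_output_shape` is a hypothetical name used only for illustration and is not part of PyTorch or ColossalAI.

```python
def linear_output_shape(x_shape, w_shape):
    """Return the output shape of F.linear(x, w) for the given shapes.

    Follows PyTorch's convention: w has shape (out_features, in_features)
    and the last dimension of x must equal in_features.
    """
    *batch_dims, in_features = x_shape
    out_features, w_in = w_shape
    if in_features != w_in:
        raise ValueError(f"in_features mismatch: {in_features} != {w_in}")
    return (*batch_dims, out_features)


# Constants copied from the test.
D_IN, D_OUT = 16, 32
B, S = 2, 64

print(linear_output_shape((B, S, D_IN), (D_OUT, D_IN)))  # (2, 64, 32)
```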

Implementation Analysis

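The test defines FP8TestHook, a subclass of FP8Hook whose rewrite_op first calls the parent implementation. The assertions imply that the parent maps F.linear to linear_fp8; when that happens, the subclass sets the global REPLACED flag and returns new_linear_fp8, a thin wrapper that sets the global TRIGGERED flag before delegating to linear_fp8. Checking both flags shows that the op was rewritten and that the rewritten op was actually executed, combining inheritance with operation interception.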
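The interception pattern can be reproduced without a GPU. The sketch below is a pure-Python stand-in that assumes, as the test implies, that the base hook maps the linear op to its FP8 counterpart. `linear`, `linear_fp8`, `BaseHook`, `TracingHook`, and `dispatch` are hypothetical names operating on plain lists. In ColossalAI the rewrite is triggered by operations on a ColoParameter rather than by an explicit loop like `dispatch`.

```python
# Stand-ins for F.linear and linear_fp8 on plain Python lists (no tensors).
def linear(x, w):
    """Reference op: y[j] = sum_i x[i] * w[j][i], like F.linear without bias."""
    return [sum(xi * wji for xi, wji in zip(x, row)) for row in w]


def linear_fp8(x, w):
    """Placeholder for a quantized kernel; here it just calls the reference op."""
    return linear(x, w)


class BaseHook:
    """Mimics the assumed FP8Hook behaviour: swap linear for linear_fp8."""

    def rewrite_op(self, func):
        return linear_fp8 if func is linear else func


REPLACED = False
TRIGGERED = False


def traced_linear_fp8(x, w):
    # Record that the rewritten op was executed, then delegate.
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w)


class TracingHook(BaseHook):
    """Same structure as FP8TestHook: defer to the parent, then wrap its result."""

    def rewrite_op(self, func):
        global REPLACED
        func = super().rewrite_op(func)
        if func is linear_fp8:
            REPLACED = True
            return traced_linear_fp8
        return func


def dispatch(func, hooks, *args):
    """Hypothetical dispatcher: let every active hook rewrite the op, then call it."""
    for hook in hooks:
        func = hook.rewrite_op(func)
    return func(*args)


out = dispatch(linear, [TracingHook()], [1.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(out, REPLACED, TRIGGERED)  # [1.0, 2.0, 3.0] True True
```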

Framework features utilized:
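  • A PyTorch nn.Parameter, converted in place to a ColoParameter
  • Hook activation through ColoParamOpHookManager.use_hooks
  • torch.bfloat16 tensors for both the input and the weight
  • A device compute-capability check via get_accelerator().get_device_capability()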
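The skip condition compares only the major version of the device's compute capability against 9. A minimal sketch of the same gate, assuming the capability is exposed as a (major, minor) tuple as torch.cuda.get_device_capability() returns; `meets_capability` is a hypothetical helper. Comparing whole tuples is equivalent to the test's `[0] < 9` check for a (9, 0) threshold, and would also handle thresholds with a nonzero minor version.

```python
def meets_capability(capability, minimum=(9, 0)):
    """Return True if a (major, minor) compute-capability tuple is >= minimum.

    Tuples compare lexicographically, so (8, 9) < (9, 0) <= (10, 0).
    """
    return tuple(capability) >= tuple(minimum)


for cap in [(8, 0), (8, 9), (9, 0), (10, 0)]:
    print(cap, meets_capability(cap))
```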

Technical Details

Testing components include:
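  • pytest for test execution and conditional skipping (pytest.mark.skipif)
  • PyTorch tensor operations (torch.rand, F.linear)
  • ColoParameter for parameter management
  • ColoParamOpHookManager for the hook context
  • Device capability verification (major version >= 9)
  • BFloat16 dtype configuration (DTYPE = torch.bfloat16)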
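The hook is applied through a context manager, so it is active only inside the `with` block. A save-and-restore registry of that kind can be sketched with contextlib. `HookManager` is a hypothetical stand-in loosely modeled on how the test uses ColoParamOpHookManager, not ColossalAI's source; the choice to replace (rather than extend) the active hooks inside a nested block is an assumption of this sketch.

```python
from contextlib import contextmanager


class HookManager:
    """Hypothetical global hook registry; not ColossalAI's implementation."""

    hooks = ()

    @classmethod
    @contextmanager
    def use_hooks(cls, *hooks):
        # Replace the active hooks, then restore the previous set on exit,
        # even if the body raises.
        previous = cls.hooks
        cls.hooks = hooks
        try:
            yield
        finally:
            cls.hooks = previous


outer, inner = object(), object()
with HookManager.use_hooks(outer):
    with HookManager.use_hooks(inner):
        assert HookManager.hooks == (inner,)
    assert HookManager.hooks == (outer,)

try:
    with HookManager.use_hooks(outer):
        raise RuntimeError("failure inside the hooked region")
except RuntimeError:
    pass

print(HookManager.hooks)  # ()
```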

Best Practices Demonstrated

The test implementation showcases several testing best practices:

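  • Hook scope limited to a ColoParamOpHookManager.use_hooks context manager, so the hook is active only for the F.linear call under test and is deactivated when the block exits
  • Hardware gating: pytest.mark.skipif skips the test when the device's compute-capability major version is below 9
  • Explicit state verification through the REPLACED and TRIGGERED flags
  • Explicit output-shape validation against (B, S, D_OUT)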
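REPLACED and TRIGGERED are module-level globals that the test never resets, so another test in the same module that reused them could pass on stale state. One alternative, sketched here with pure-Python stand-ins, is to keep the flags on the hook instance so every test constructs fresh state. `linear` and `RecordingHook` are hypothetical, and the traced op calls the reference op rather than an FP8 kernel.

```python
def linear(x, w):
    """Stand-in for F.linear without bias, on plain lists."""
    return [sum(a * b for a, b in zip(x, row)) for row in w]


class RecordingHook:
    """Hypothetical hook that keeps its replaced/triggered state per instance."""

    def __init__(self):
        self.replaced = False
        self.triggered = False

    def rewrite_op(self, func):
        if func is not linear:
            return func
        self.replaced = True

        def traced(x, w):
            # Closure over self: marks this particular hook as triggered.
            self.triggered = True
            return linear(x, w)

        return traced


first = RecordingHook()
out = first.rewrite_op(linear)([1.0, 2.0], [[3.0, 4.0]])
second = RecordingHook()  # a fresh hook starts clean; nothing leaks from `first`

print(first.replaced, first.triggered, out)  # True True [11.0]
print(second.replaced, second.triggered)     # False False
```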

hpcaitech/colossalai

tests/test_fp8/test_fp8_hook.py

            
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

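# Module-level flags: REPLACED is set by the hook, TRIGGERED by the wrapper.
REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
    # Wrapper around linear_fp8 that records that the rewritten op was called.
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        # Let FP8Hook perform its own rewrite first (F.linear -> linear_fp8).
        func = super().rewrite_op(func)
        if func is linear_fp8:
            # Record the replacement and substitute the tracing wrapper.
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func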


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
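    # Create a bf16 weight and input on the current device.
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    # Convert the nn.Parameter into a ColoParameter in place so that operations
    # on it are routed through the hooks registered with ColoParamOpHookManager.
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    # The hook is active only inside this block; F.linear should be rewritten here.
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)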
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED