Testing QuantizedParameter Implementation in DeepSpeed Linear Quantization

This test suite validates the QuantizedParameter functionality in DeepSpeed’s linear quantization module, focusing on data type handling, device management, and parameter cloning capabilities.

Test Coverage Overview

The test suite covers essential aspects of QuantizedParameter implementation including:
  • Data type validation for torch.half and torch.float
  • Gradient requirement validation
  • Device migration capabilities
  • HuggingFace compatibility through parameter cloning
Key edge cases include unsupported dtype handling and device transition verification; the sketch below walks through the lifecycle these checks cover.
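
A minimal sketch of that lifecycle, assuming DeepSpeed is installed and the FP quantizer op is available on the current accelerator; every behavior shown here is asserted in the test file reproduced below, so this is illustration rather than additional repo code.

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.linear.quantization import QuantizedParameter

# bfloat16 inputs are accepted; the tests below show that torch.half and
# torch.float inputs raise an AssertionError once the parameter is moved to
# the accelerator, and that requires_grad=True raises a ValueError.
weight = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
qp = QuantizedParameter(weight)
assert qp.device == torch.device('cpu')

# Moving to the accelerator triggers quantization: afterwards the parameter's
# storage dtype is uint8, as asserted in test_move_to_accelerator.
qp = qp.to(get_accelerator().current_device_name())
assert qp.dtype == torch.uint8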

Implementation Analysis

The testing approach uses pytest parametrization for comprehensive dtype coverage and dedicates a distinct test method to each behavior under test.
  • Leverages pytest.raises for exception validation (see the condensed example after this list)
  • Uses the DistributedTest base class to run tests in a distributed setup (world_size = 1 here)
  • Implements device-aware testing patterns via the accelerator abstraction
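
As a condensed illustration of the parametrize-plus-raises pattern (this mirrors the test file below rather than adding new repo code), note that the exception is expected from the .to() call, not from constructing the parameter:

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.linear.quantization import QuantizedParameter

@pytest.mark.parametrize('dtype', [torch.half, torch.float])
def test_unsupported_dtypes(dtype):
    # Construction succeeds on CPU; the AssertionError is raised only when
    # the parameter is moved to the accelerator, so pytest.raises wraps .to().
    data = torch.rand(5, 5, device='cpu', dtype=dtype)
    qp = QuantizedParameter(data)
    with pytest.raises(AssertionError):
        qp.to(get_accelerator().current_device_name())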

Technical Details

Testing infrastructure includes:
  • PyTest framework with parametrization
  • DeepSpeed’s accelerator utilities
  • DeepSpeed's QuantizationConfig for controlling quantization settings (q_bits=6 in the clone test)
  • FPQuantizerBuilder compatibility check gating the whole module (both are sketched after this list)
  • Distributed testing capabilities
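
A short sketch of those two pieces, following the same patterns used in the file below: the module-level guard that skips everything when the FP quantizer kernel is unavailable, and a QuantizationConfig constructed with a non-default bit width.

import pytest
import deepspeed

from deepspeed.linear.config import QuantizationConfig
from deepspeed.ops.op_builder import FPQuantizerBuilder

# Skip the entire module when the FP quantizer op cannot be built or loaded
# on this system (same guard as at the top of test_quant_param.py).
if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)

# Only q_bits is exercised by the tests; other fields keep their defaults.
quantization_config = QuantizationConfig(q_bits=6)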

Best Practices Demonstrated

The test suite exemplifies robust testing practices through:
  • Isolated test cases for specific functionality
  • Proper exception handling validation
  • Hardware-agnostic device management (see the sketch after this list)
  • Framework compatibility testing
  • Comprehensive configuration validation
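
Device handling in these tests goes through DeepSpeed's accelerator abstraction rather than hard-coded backend strings; a minimal sketch of the two calls the tests rely on (the 'cuda:0' device-name format in the comment is an assumption for CUDA systems):

import torch

from deepspeed.accelerator import get_accelerator

# current_device() returns the active device, current_device_name() a string
# form of it (e.g. 'cuda:0' on CUDA systems); both refer to the same device,
# which is what test_move_to_accelerator relies on when comparing them.
device = get_accelerator().current_device()
device_name = get_accelerator().current_device_name()

x = torch.rand(5, 5, device=device_name, dtype=torch.bfloat16)
assert x.device == torch.device(device)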

microsoft/deepspeed

tests/unit/linear/test_quant_param.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed

from deepspeed.accelerator import get_accelerator
from deepspeed.linear.quantization import QuantizedParameter
from deepspeed.linear.config import QuantizationConfig

from deepspeed.ops.op_builder import FPQuantizerBuilder

from unit.common import DistributedTest

if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)


class TestQuantParam(DistributedTest):
    world_size = 1

    @pytest.mark.parametrize('dtype', [torch.half, torch.float])
    def test_unsupported_dtypes(self, dtype):
        device = get_accelerator().current_device_name()
        data = torch.rand(5, 5, device='cpu', dtype=dtype)
        qp = QuantizedParameter(data)
        with pytest.raises(AssertionError):
            qp.to(device)

    def test_requires_grad(self):
        data = torch.rand(5, 5, dtype=torch.bfloat16)
        with pytest.raises(ValueError):
            QuantizedParameter(data, requires_grad=True)

    def test_move_to_accelerator(self):
        device = get_accelerator().current_device()
        data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
        qp = QuantizedParameter(data)
        assert qp.device == torch.device('cpu')
        qp = qp.to(get_accelerator().current_device_name())
        assert qp.device == torch.device(device)
        assert qp.dtype == torch.uint8

    def test_hf_clone(self):
        device = get_accelerator().current_device_name()
        data = torch.rand(5, 5, device=device, dtype=torch.bfloat16)

        quantization_config = QuantizationConfig(q_bits=6)
        qp = QuantizedParameter(data, quantization_config=quantization_config)

        # should be able to clone parameter via dict, HF expects this to work
        qp_copy = QuantizedParameter(qp.data, **qp.__dict__)

        assert all(qp.data == qp_copy.data)
        assert qp.quantization_config == qp_copy.quantization_config