Testing FP8 Precision Composability with ZeRO Optimization in DeepSpeed

This test suite evaluates FP8 precision composability across DeepSpeed's ZeRO optimization stages, focusing on Transformer Engine integration and the three supported base data types (FP16, BF16, FP32).

Test Coverage Overview

The test suite covers FP8 precision compatibility across all four ZeRO optimization stages (0-3).

Key functionality includes:
  • Testing with different base data types (FP16, BF16, FP32)
  • Verification of loss consistency across ZeRO stages (see the sketch after this list)
  • Integration with Transformer Engine’s fp8_autocast
  • Validation of gradient accumulation and optimization
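
As a condensed sketch of that consistency check (run_zero is the per-stage training helper defined in the full test below), the comparison reduces to:

    # Train the same model under each ZeRO stage and compare final losses
    losses = [run_zero(stage, model_dtype) for stage in (0, 1, 2, 3)]
    # All stages must agree within tight tolerances
    assert all(torch.allclose(loss, losses[0], rtol=1e-07, atol=1e-05) for loss in losses)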

Implementation Analysis

The implementation uses a parametrized test built on the DistributedTest framework, enabling systematic validation across data-type and ZeRO-stage configurations.

Notable patterns include:
  • Dynamic model initialization with configurable data types
  • FP8 recipe configuration with DelayedScaling (snippet after this list)
  • Consistent batch processing with loss computation
  • Cross-stage loss comparison for validation
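
The recipe setup is short; the snippet below is taken from the test and uses Transformer Engine's DelayedScaling with the HYBRID format (E4M3 for forward tensors, E5M2 for gradients):

    from transformer_engine.common import recipe

    # Delayed scaling: FP8 scale factors are derived from a history of amax values
    fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
                                       amax_history_len=16,      # length of the amax window
                                       amax_compute_algo="max")  # reduce the window with max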

Technical Details

Testing infrastructure leverages:
  • PyTest for test organization and execution (skeleton excerpted after this list)
  • Transformer Engine for FP8 computation support
  • DeepSpeed initialization with custom configurations
  • Torch random seed management for reproducibility
  • Custom skip_on_arch utility for architecture compatibility
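
Put together, the scaffolding looks like this excerpt from the file: a module-level skip when Transformer Engine is unavailable, and a parametrized DistributedTest subclass that pins the world size.

    import pytest
    from unit.common import DistributedTest

    try:
        import transformer_engine.pytorch as transformer_engine
    except ImportError:
        pytest.skip("Transformer Engine package is missing, skipping tests", allow_module_level=True)

    @pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"])
    class TestFp8ComposabilityAcrossZero(DistributedTest):
        world_size = 1  # single process; DistributedTest handles launch and teardown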

Best Practices Demonstrated

The test implementation showcases several testing best practices in deep learning contexts.

Notable practices include:
  • Parametrized testing for multiple scenarios
  • Proper resource cleanup and initialization
  • Controlled random state management
  • Precise numerical comparison with appropriate tolerances
  • Clear separation of configuration and test logic (config dict reproduced below)
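
The configuration/test split is visible in the DeepSpeed config dict, reproduced from the test; only the ZeRO stage and the precision flags vary between runs, while the training loop stays untouched:

    config = {
        "train_batch_size": 16,
        "gradient_accumulation_steps": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 0.00001}},
        "zero_optimization": {"stage": stage},  # swept over stages 0-3
        "fp16": {"enabled": model_dtype == torch.float16, "loss_scale": 0.1},
        "bf16": {"enabled": model_dtype == torch.bfloat16},
    }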

microsoft/deepspeed

tests/unit/runtime/half_precision/test_fp8.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from unit.util import skip_on_arch

try:
    import transformer_engine.pytorch as transformer_engine
    from transformer_engine.common import recipe
except ImportError:
    pytest.skip("Transformer Engine package is missing, skipping tests", allow_module_level=True)


@pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"])
class TestFp8ComposabilityAcrossZero(DistributedTest):
    world_size = 1

    def test(self, base_datatype):
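        # FP8 kernels require compute capability 9.0+ (Hopper-class GPUs)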
        skip_on_arch(min_arch=9)

        def run_zero(stage, model_dtype):
            num_batches = 128
            batch_size = 16
            hidden_dim = 768
            # Seed must be set before model creation for reproducible weights
            torch.random.manual_seed(42)
            enable_fp16 = model_dtype == torch.float16
            enable_bf16 = model_dtype == torch.bfloat16
            # TransformerEngine Model
            model = transformer_engine.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype)

            # Create FP8 recipe. Note: All input args are optional.
            fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
                                               amax_history_len=16,
                                               amax_compute_algo="max")
            config = {
                "train_batch_size": batch_size,
                "gradient_accumulation_steps": 1,
                "optimizer": {
                    "type": "Adam",
                    "params": {
                        "lr": 0.00001
                    }
                },
                "zero_optimization": {
                    "stage": stage
                },
                "fp16": {
                    "enabled": enable_fp16,
                    "loss_scale": 0.1
                },
                "bf16": {
                    "enabled": enable_bf16
                }
            }
            # Init DeepSpeed
            model, optimizer, _, _ = deepspeed.initialize(args=None,
                                                          model=model,
                                                          model_parameters=model.parameters(),
                                                          config=config)

            batches = torch.randn(num_batches, batch_size, hidden_dim, device=model.device, dtype=model_dtype)
            for batch in batches:
                # Enables autocasting for the forward pass
                with transformer_engine.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                    out = model(batch)
                loss = out.mean()
                model.backward(loss)
                model.step()
            return loss

        if base_datatype == "fp16":
            model_dtype = torch.float16
        elif base_datatype == "bf16":
            model_dtype = torch.bfloat16
        else:
            model_dtype = torch.float32

        # Train under each ZeRO stage and collect the final losses
        zero_stages = [0, 1, 2, 3]
        losses = []
        for stage in zero_stages:
            loss = run_zero(stage, model_dtype)
            losses.append(loss)
        all_equal = all(torch.allclose(loss, losses[0], rtol=1e-07, atol=1e-05) for loss in losses)
        assert all_equal