
Testing ZeRO Optimization Compilation Workflows in DeepSpeed

This test suite validates DeepSpeed's ZeRO optimization stages when used together with PyTorch compilation. It runs combinations of precision type, ZeRO stage, and optimizer offload device, and is skipped entirely unless PyTorch 2.1 or later is installed.

Test Coverage Overview

The test matrix covers the following ZeRO configurations:
  • ZeRO stages 1, 2, and 3
  • Three precision types: bfloat16, float16, and float32
  • Three optimizer offload settings: none, CPU, and NVMe (NVMe runs only with stage 3 and is skipped for stages 1 and 2)
  • Every combination that is not skipped is passed to the shared compare_loss helper from unit.runtime.compile.util
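
To make the size of this matrix concrete, the following minimal sketch enumerates the combinations in plain Python. It uses strings in place of the torch dtypes and OffloadDeviceEnum members, and the helper name is_skipped is hypothetical. It models only the NVMe/stage rule; hardware-dependent skips (bfloat16 support, CPU accelerator) are not counted.

```python
from itertools import product

# String stand-ins for the parametrized values (assumption: the real test
# uses torch dtypes and OffloadDeviceEnum members instead of strings).
DTYPES = ["bfloat16", "float16", "float32"]
ZERO_STAGES = [1, 2, 3]
OFFLOAD_DEVICES = ["none", "cpu", "nvme"]


def is_skipped(zero_stage, offload_device):
    """Hypothetical helper: NVMe offload is only run with ZeRO stage 3."""
    return offload_device == "nvme" and zero_stage != 3


# Stacked parametrize decorators produce the Cartesian product of all values.
all_cases = list(product(DTYPES, ZERO_STAGES, OFFLOAD_DEVICES))
runnable = [case for case in all_cases if not is_skipped(case[1], case[2])]

print(len(all_cases))  # 27 generated cases
print(len(runnable))   # 21 cases left after the NVMe/stage rule
```

Six of the 27 generated cases (three dtypes times two stages) are skipped by the NVMe rule alone.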

Implementation Analysis

The implementation stacks three pytest parametrize decorators to build a test matrix across dtype, ZeRO stage, and offload device. The test class TestZeRO inherits from DistributedTest and sets world_size = 2, so each case runs across two ranks in a multi-GPU setting.
  • Generates every combination of the three parameters
  • Skips bfloat16 cases on architectures below min_arch=8 or when bf16_required_version_check() fails
  • Skips all cases on the CPU accelerator, and NVMe offload for ZeRO stages other than 3
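
The skip conditions are evaluated in order, before any configuration is built. As a rough sketch of that ordering, the function below returns the reason a combination would be skipped, or None if it should run. The name skip_reason and its bf16_supported flag are hypothetical: the flag stands in for the combined result of skip_on_arch and bf16_required_version_check, and the messages are illustrative rather than copied from the test.

```python
from typing import Optional


def skip_reason(dtype: str, zero_stage: int, offload_device: str,
                device_name: str, bf16_supported: bool) -> Optional[str]:
    """Hypothetical helper: return why a case is skipped, or None to run it."""
    # bfloat16 needs both hardware and library support.
    if dtype == "bfloat16" and not bf16_supported:
        return "bfloat16 is not supported in this environment"
    # The CPU accelerator is excluded for every combination.
    if device_name == "cpu":
        return "CPU does not support this test yet"
    # NVMe offload is only run with ZeRO stage 3.
    if offload_device == "nvme" and zero_stage != 3:
        return f"NVMe offload not supported for ZeRO stage {zero_stage}"
    return None


print(skip_reason("float16", 3, "nvme", "cuda", True))
print(skip_reason("float32", 2, "nvme", "cuda", True))
```

Keeping the decision in one pure function like this is a design choice of the sketch, not of the test itself, which calls pytest.skip inline.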

Technical Details

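The testing infrastructure includes:
  • A module-level skip (pytestmark with required_torch_version) when PyTorch is older than 2.1
  • A bfloat16 prerequisite check: NCCL >= 2.10.3, CUDA >= 11.0, and hardware support for bfloat16
  • A configuration dictionary built per test case from the parametrized values
  • pytest's tmpdir fixture, which supplies nvme_path for NVMe offload
  • The compare_loss helper imported from unit.runtime.compile.util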
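
The per-case configuration can be pictured as a small builder function. The sketch below is an illustration under assumptions: build_config is a hypothetical name (the test builds the dictionary inline), strings replace the torch dtypes and OffloadDeviceEnum members, and tempfile stands in for pytest's tmpdir fixture. The keys and values match those used in the test.

```python
import json
import tempfile


def build_config(zero_stage, dtype, offload_device, nvme_path=None):
    """Hypothetical helper: assemble a DeepSpeed config for one test case."""
    config = {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
        "zero_optimization": {"stage": zero_stage},
    }
    # Optimizer offload is only added when a device is requested.
    if offload_device == "cpu":
        config["zero_optimization"]["offload_optimizer"] = {"device": "cpu"}
    elif offload_device == "nvme":
        config["zero_optimization"]["offload_optimizer"] = {
            "device": "nvme",
            "nvme_path": str(nvme_path),  # NVMe offload needs a storage path
        }
    # float32 needs no precision section; fp16 and bf16 each add their own.
    if dtype == "float16":
        config["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif dtype == "bfloat16":
        config["bf16"] = {"enabled": True}
    return config


# A temporary directory plays the role of pytest's tmpdir fixture here.
with tempfile.TemporaryDirectory() as scratch:
    example = build_config(3, "bfloat16", "nvme", nvme_path=scratch)
    print(json.dumps(example["zero_optimization"], indent=2))
```

Starting from one base dictionary and adding sections conditionally keeps the float32, no-offload case as the plain default.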

Best Practices Demonstrated

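The test illustrates several practices worth reusing:
  • A fresh configuration dictionary and temporary directory for each case
  • Parameter coverage through stacked parametrize decorators rather than hand-written cases
  • Hardware and library checks performed before any configuration is built
  • Skip messages that state why a configuration is unsupported
  • One base configuration extended conditionally instead of duplicated per case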

microsoft/deepspeed

tests/unit/runtime/compile/test_compile_zero.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator

from unit.runtime.compile.util import compare_loss
from unit.common import DistributedTest
from unit.util import bf16_required_version_check, skip_on_arch

# Skip every test in this module unless PyTorch 2.1 or later is installed.
pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1),
                                reason="Compile tests require PyTorch version 2.1 or above")


class TestZeRO(DistributedTest):
    world_size = 2
    non_daemonic_procs = True

    @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
    @pytest.mark.parametrize('zero_stage', [1, 2, 3])
    @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
    def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
        # bfloat16 needs a supported architecture and recent enough libraries.
        if dtype == torch.bfloat16:
            skip_on_arch(min_arch=8)
            if not bf16_required_version_check():
                pytest.skip(
                    "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >= 11.0, and HW support for BFloat16 to run correctly"
                )
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU does not support this test yet")

        # NVMe offload is only run with ZeRO stage 3.
        if offload_device == OffloadDeviceEnum.nvme and zero_stage != 3:
            pytest.skip(f"NVMe offload not supported for ZeRO stage {zero_stage}")

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            }
        }

        # Add optimizer offload when requested; NVMe also needs a storage path.
        if offload_device == OffloadDeviceEnum.cpu:
            config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
        elif offload_device == OffloadDeviceEnum.nvme:
            config_dict["zero_optimization"]["offload_optimizer"] = {
                "device": offload_device,
                "nvme_path": str(tmpdir)
            }
        if dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}

        compare_loss(self, config_dict, dtype)
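
The body of compare_loss is not shown in this file. Purely as an illustration of what a tolerance-based loss comparison can look like, and assuming (not confirmed by the source) that such a helper checks two sequences of per-step losses against each other, a minimal standalone sketch follows. The function name losses_match and the tolerance values are hypothetical.

```python
import math


def losses_match(baseline, candidate, rel_tol=1e-2, abs_tol=1e-3):
    """Hypothetical check: per-step losses agree within the given tolerances."""
    # Runs of different length cannot be compared step by step.
    if len(baseline) != len(candidate):
        return False
    return all(
        math.isclose(b, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for b, c in zip(baseline, candidate)
    )


# Illustrative numbers only, not measured results.
print(losses_match([1.0, 0.5], [1.001, 0.5004]))  # True
print(losses_match([1.0, 0.5], [1.5, 0.5]))       # False
```

Comparing with a tolerance instead of exact equality matters for reduced-precision dtypes, where small numerical differences between two runs are expected.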