Back to Repositories

Testing ZeRO-3 State Offloading Operations in DeepSpeed

This test suite validates the state offloading functionality in DeepSpeed’s ZeRO-3 optimizer, focusing on memory management and state preservation during model training. It tests the ability to offload and reload various model states including parameters, gradients, and optimizer states between GPU and CPU memory.

Test Coverage Overview

The test suite provides comprehensive coverage of DeepSpeed’s state offloading mechanisms.

Key areas tested include:
  • Parameter state offloading (both high and low precision)
  • Optimizer state management
  • Gradient buffer handling
  • Memory allocation verification
  • State restoration accuracy
Edge cases covered include non-blocking transfers and pinned memory scenarios.

Implementation Analysis

The tests use pytest parametrization to cover every combination of offloaded state type, pinned-memory setting, and blocking versus non-blocking transfer. Each case builds a SimpleModel with a configurable hidden dimension and trains it under DeepSpeed’s ZeRO-3 optimization with BFloat16 precision.

Key patterns include:
  • State device validation
  • Memory allocation tracking
  • Tensor equality verification
  • Distributed testing coordination

Technical Details

Testing tools and configuration:
  • PyTest for test orchestration
  • DeepSpeed ZeRO-3 optimizer
  • Distributed testing with 2 GPUs
  • BFloat16 precision training
  • Custom validation utilities
  • Memory tracking mechanisms

Best Practices Demonstrated

The test suite exemplifies robust testing practices for distributed deep learning systems.

Notable practices include:
  • Systematic state verification
  • Memory leak prevention
  • Comprehensive parameter coverage
  • Proper resource cleanup
  • Explicit device management
  • Deterministic test scenarios

microsoft/deepspeed

tests/unit/runtime/zero/test_offload_states.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
import torch

from unit.common import DistributedTest
from unit.simple_model import random_dataloader, SimpleModel

import deepspeed
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state
from deepspeed.runtime.zero.offload_states import get_state_devices


def validate_device(model, device: torch.device, include) -> None:
    """Assert that the selected offloadable states of ``model`` live on ``device``.

    Args:
        model: Object passed through to ``get_state_devices``.
        device: The device every checked state is expected to be on.
        include: Container of ``OffloadStateTypeEnum`` members to check, or
            ``None`` to check every member of the enum.

    Raises:
        AssertionError: If a checked state is not reported on exactly
            ``device``, or if the contiguous gradient buffer still reports any
            device when ``device`` is the CPU.
    """
    expecting_cpu = device == torch.device("cpu")

    for state in OffloadStateTypeEnum:
        # Skip states the caller did not ask about.
        if include is not None and state not in include:
            continue

        found_devices = get_state_devices(model, state)

        # When checking against CPU, the contiguous gradient buffer is expected
        # to report no device at all rather than to report the CPU.
        if expecting_cpu and state == OffloadStateTypeEnum.contiguous_grad_buffer:
            assert len(found_devices) == 0, f"State {state} must be removed after offload_states()"
            continue

        assert len(found_devices) == 1 and device in found_devices, f"State {state} is not on device {device}"


def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
    """Train ``model`` with DeepSpeed and verify an offload/reload round trip after every step.

    After each optimizer step the function snapshots the local fp32 parameters,
    the ``ds_tensor`` of each parameter, the flat gradient partition buffer and
    the Adam ``exp_avg`` / ``exp_avg_sq`` states. It then offloads the requested
    states, checks that allocated accelerator memory dropped and that the states
    are on the offload device, reloads them, checks that allocated memory grew
    again, and finally checks that every snapshot equals the reloaded value and
    that the states are back on the current accelerator device.

    Args:
        model: Model handed to ``deepspeed.initialize``.
        config_dict: DeepSpeed configuration passed to ``deepspeed.initialize``.
        hidden_dim: Hidden dimension forwarded to ``random_dataloader``.
        dtype: Data type forwarded to ``random_dataloader``.
        include: States forwarded to ``offload_states`` and ``validate_device``;
            ``validate_device`` treats ``None`` as "check every state".
        pin_memory: Forwarded to ``offload_states``.
        non_blocking: Forwarded to ``offload_states``.

    Raises:
        AssertionError: If any memory, device or equality check fails.
    """

    def assert_all_equal(expected_tensors, restored_tensors) -> None:
        # Uses its own loop names so the lists being compared are never rebound.
        for expected, restored in zip(expected_tensors, restored_tensors):
            assert torch.equal(expected, restored)

    # Currently we only support OffloadDeviceEnum.cpu
    offload_device = OffloadDeviceEnum.cpu

    model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    try:
        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

            # Snapshot (clone) every state before it is moved so the originals
            # can be compared against the reloaded tensors.
            hp_params_expected = [safe_get_local_fp32_param(p).clone() for p in model.parameters()]
            lp_params_expected = [p.ds_tensor.clone() for p in model.parameters()]
            lp_grads_expected = model.optimizer.grad_partitions_flat_buffer.clone()
            adam_exp_avg_expected = [
                safe_get_local_optimizer_state(p, "exp_avg").clone() for p in model.parameters()
            ]
            adam_exp_avg_sq_expected = [
                safe_get_local_optimizer_state(p, "exp_avg_sq").clone() for p in model.parameters()
            ]

            # Start offloading
            alloc_before_offload = get_accelerator().memory_allocated()
            model.offload_states(include=include,
                                 device=offload_device,
                                 pin_memory=pin_memory,
                                 non_blocking=non_blocking)
            alloc_after_offload = get_accelerator().memory_allocated()
            assert alloc_after_offload < alloc_before_offload, "Allocated memory should decrease after offload"

            validate_device(model, torch.device(offload_device.value), include)

            # Reload states
            model.reload_states()
            assert alloc_after_offload < get_accelerator().memory_allocated(
            ), "Allocated memory should increase after offload back"

            # Verify restored states
            assert_all_equal(hp_params_expected, [safe_get_local_fp32_param(p) for p in model.parameters()])
            assert_all_equal(lp_params_expected, [p.ds_tensor for p in model.parameters()])

            assert torch.equal(lp_grads_expected, model.optimizer.grad_partitions_flat_buffer)

            assert_all_equal(adam_exp_avg_expected,
                             [safe_get_local_optimizer_state(p, "exp_avg") for p in model.parameters()])
            assert_all_equal(adam_exp_avg_sq_expected,
                             [safe_get_local_optimizer_state(p, "exp_avg_sq") for p in model.parameters()])

            validate_device(model, torch.device(get_accelerator().current_device_name()), include)
    finally:
        # Needed in ZeRO 3. Not doing so can give memory leak.
        # Runs in ``finally`` so a failed check above does not skip the cleanup.
        model.destroy()


@pytest.mark.parametrize("included_state", [
    OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.optim_states,
    OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None
])
@pytest.mark.parametrize("pin_memory", [False, True])
@pytest.mark.parametrize("non_blocking", [False, True])
class TestOffloadStates(DistributedTest):
    """Offload/reload round-trip test, parametrized over state type, pinning and blocking mode.

    ``included_state`` is either a single ``OffloadStateTypeEnum`` member or
    ``None``; ``None`` is passed on as ``include=None``.
    """
    # Need multiple gpus to test possible hanging
    world_size = 2

    def test_offload_states(self, included_state, pin_memory, non_blocking):
        """Build a 4-layer ``SimpleModel`` under ZeRO stage 3 with bf16 and run the round-trip check."""
        hidden_dim = 1024

        # Same keys, values and insertion order as building the dict and adding
        # the "bf16" entry afterwards.
        zero3_bf16_config = {
            "train_micro_batch_size_per_gpu": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-6
                }
            },
            "zero_optimization": {
                "stage": 3,
            },
            "bf16": {
                "enabled": True
            },
        }

        # The model is constructed inside the zero.Init context with the same config.
        with deepspeed.zero.Init(config_dict_or_path=zero3_bf16_config):
            net = SimpleModel(hidden_dim, nlayers=4)

        # run_model takes a list of states, or None when no specific state was selected.
        if included_state is None:
            states_to_offload = None
        else:
            states_to_offload = [included_state]

        run_model(net, zero3_bf16_config, hidden_dim, torch.bfloat16, states_to_offload, pin_memory,
                  non_blocking)