Testing Memory Management for ZeroStage3 Linear Operations in DeepSpeed

This test suite validates memory management and linear module functionality in DeepSpeed's ZeroStage3 implementation. It monitors memory allocation patterns during forward and backward passes over half-precision (float16) tensors.

Test Coverage Overview

The test suite covers critical memory management aspects of DeepSpeed’s LinearModuleForZeroStage3.

Key areas tested include:
  • Memory allocation tracking during model operations
  • Half-precision tensor handling
  • Forward and backward pass memory patterns
  • Weight zeroing impact on memory usage (sketched just below)
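
The weight-zeroing check is the most distinctive of these: the test swaps the module's full weight storage for a one-element placeholder and watches the allocator statistics respond. A minimal sketch of the idea in plain PyTorch, with a stock nn.Linear standing in for DeepSpeed's module, smaller illustrative sizes, and a CUDA device assumed:

import torch

# Stand-in layer; the real test uses LinearModuleForZeroStage3 at
# 16384x16384. Sizes here are smaller purely for illustration.
layer = torch.nn.Linear(4096, 4096).cuda().half()

before = torch.cuda.memory_allocated()
# Swap the full weight storage for a one-element placeholder. With no
# other live reference to it (no forward has run yet), the old storage
# can be reclaimed by the caching allocator.
layer.weight.data = torch.zeros(1, dtype=torch.half, device="cuda")
after = torch.cuda.memory_allocated()
print(f"released ~{(before - after) / 2**20:.1f} MiB")  # ~32 MiB here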

Implementation Analysis

The test uses direct memory monitoring through DeepSpeed's accelerator interface, following a systematic workflow that records allocation at multiple checkpoints:

  • A custom see_memory_usage helper that logs current and peak allocation
  • Linear module initialization and conversion to half precision
  • Sequential memory snapshots before the forward pass, after it, and after the weights are zeroed (outlined in the sketch below)
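
Put together, the workflow reduces to the shape below. This is a simplified reconstruction that uses plain torch.cuda counters and a stock nn.Linear in place of the accelerator interface and LinearModuleForZeroStage3:

import torch

def snapshot(tag):
    # Report current and peak allocation in GiB at a named checkpoint.
    gib = 1024 ** 3
    print(f"{tag}: allocated={torch.cuda.memory_allocated() / gib:.3f} GiB, "
          f"peak={torch.cuda.max_memory_allocated() / gib:.3f} GiB")

x = torch.rand(1024, 4096, dtype=torch.half, device="cuda")
model = torch.nn.Linear(4096, 4096).cuda().half()

snapshot("before forward")
y = model(x)
snapshot("after forward")
y.backward(torch.ones_like(y))
snapshot("after backward")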

Technical Details

Testing components and configuration:
  • PyTorch tensor operations with half precision (float16)
  • DeepSpeed’s LinearModuleForZeroStage3 implementation
  • Custom memory usage tracking utilizing accelerator interface
  • Tensor dimensions: a 1024×16384 input fed through a 16384×16384 linear layer
  • Device-agnostic testing through the accelerator abstraction (see the idiom sketched below)
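
The device-agnostic behavior comes from routing every device decision through get_accelerator() rather than hard-coding torch.cuda; the idiom, in brief (assuming a DeepSpeed installation with a supported backend):

import torch
from deepspeed.accelerator import get_accelerator

# get_accelerator() resolves to the active backend (CUDA, ROCm, and so
# on), so the same code runs unchanged across hardware.
device = torch.device(get_accelerator().device_name())
x = torch.rand(1024, 16384, dtype=torch.half, device=device)
print("allocated bytes:", get_accelerator().memory_allocated())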

Best Practices Demonstrated

The test exemplifies robust debugging practices for deep learning memory management:

  • Systematic memory tracking at critical execution points (extended in the sketch below)
  • Explicit release of large tensor storage by swapping in a placeholder
  • Consistent logging and monitoring patterns
  • Hardware-agnostic implementation through abstraction layers
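
A natural extension of these practices, not something the test itself does, is resetting the peak counters between checkpoints so each phase reports its own maximum. A sketch, assuming the active accelerator exposes reset_peak_memory_stats() as the CUDA backend does:

from deepspeed.accelerator import get_accelerator

def phase_peak(tag):
    # Report the peak allocation observed since the last reset, then
    # clear the counter so the next phase is measured in isolation.
    gib = 1024 ** 3
    print(f"{tag}: peak={get_accelerator().max_memory_allocated() / gib:.3f} GiB")
    get_accelerator().reset_peak_memory_stats()

Calling phase_peak after the forward pass, the weight swap, and the backward pass would separate the activation peak from the gradient peak instead of folding everything into one global maximum.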

microsoft/deepspeed

tests/small_model_debugging/test.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.pt.deepspeed_linear import LinearModuleForZeroStage3
from deepspeed.pt.log_utils import logger
from deepspeed.accelerator import get_accelerator


def see_memory_usage(message):

    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        "Memory Allocated %s GigaBytes",
        get_accelerator().memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max Memory Allocated %s GigaBytes",
        get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Cache Allocated %s GigaBytes",
        get_accelerator().memory_cached() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max Cache Allocated %s GigaBytes",
        get_accelerator().max_memory_cached() / (1024 * 1024 * 1024),
    )


tens = torch.rand(1024, 16384, dtype=torch.half, device=torch.device(get_accelerator().device_name()))
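# Detached copy of the input, reused below as the upstream gradient for backward().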
tens_back = tens.detach().clone()

#linear_bk = torch.nn.functional.linear
#torch.nn.functional.linear = deepspeed.pt.deepspeed_linear.LinearFunctionForZeroStage3.apply
model = LinearModuleForZeroStage3(16384, 16384)

model.to(get_accelerator().device_name()).half()

see_memory_usage("Before forward")
y = model(tens)

see_memory_usage("After forward")

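# Swap the full 16384x16384 weight storage for a one-element placeholder
# and observe how the allocator statistics respond.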
model.weight.data = torch.zeros(1, dtype=torch.half, device=torch.device(get_accelerator().device_name()))

see_memory_usage("After weight zero")

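# Backward pass using the detached input copy as the incoming gradient.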
y.backward(tens_back)