
Testing Hybrid Engine Model State Transitions in DeepSpeed

This test file validates the DeepSpeed Hybrid Engine by exercising model initialization, inference, and training with the OPT-350M model. It demonstrates memory-usage tracking and model state transitions between the evaluation and training phases.

Test Coverage Overview

The test suite provides comprehensive coverage of the Hybrid Engine’s core operations:

  • Model initialization with HybridEngine configuration
  • Memory usage monitoring during setup
  • Inference mode execution
  • Training mode execution
  • State transitions between eval and train modes
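The eval/train transitions listed above follow PyTorch's `nn.Module` semantics: `train(mode)` recursively sets a `training` flag on the module and its children, and `eval()` is shorthand for `train(False)`. A minimal stdlib stand-in for that toggling (an illustration of the semantics, not DeepSpeed's or PyTorch's implementation):

```python
class Module:
    # Minimal stand-in for torch.nn.Module's train/eval toggling,
    # which the test exercises on the DeepSpeed engine.
    def __init__(self):
        self.training = True
        self.children = []

    def train(self, mode=True):
        # Flip the flag on this module and recurse into children.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self):
        # eval() is just train(False).
        return self.train(False)

m = Module()
m.children.append(Module())
m.eval()
assert not m.training and not m.children[0].training
m.train()
assert m.training and m.children[0].training
```

In the real test, these calls switch the Hybrid Engine between its inference-optimized and training execution paths.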

Implementation Analysis

The test uses the facebook/opt-350M model as its subject and runs both an inference and a training forward pass. It follows DeepSpeed's standard initialization pattern with the Hybrid Engine enabled (`enable_hybrid_engine=True` passed to `deepspeed.initialize`), demonstrating configuration handling and model state management. Key aspects verified:

  • Hybrid Engine initialization verification
  • Memory tracking implementation
  • Model state transition validation
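The memory tracking in the test brackets an expensive allocation with `see_memory_usage` checkpoints. The same before/after pattern can be sketched with the stdlib `tracemalloc` module (a host-memory stand-in; DeepSpeed's utility reports accelerator memory instead):

```python
import tracemalloc

def see_host_memory_usage(tag):
    # Loose host-side analogue of deepspeed.runtime.utils.see_memory_usage:
    # print traced allocations at a named checkpoint.
    current, peak = tracemalloc.get_traced_memory()
    print(f"{tag}: current={current} B, peak={peak} B")
    return current, peak

tracemalloc.start()
pre = see_host_memory_usage("pre test")
buffer = [0] * 100_000  # stand-in for loading model weights
post = see_host_memory_usage("post test")
```

Comparing the two checkpoints makes a regression in setup memory visible, which is the point of the `'pre test'`/`'post test'` calls in the real script.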

Technical Details

  • Frameworks: DeepSpeed, PyTorch, Transformers
  • Model: facebook/opt-350M
  • Testing Tools: DeepSpeed runtime utilities
  • Configuration: Hybrid Engine enabled via initialize parameters
  • Hardware: CUDA-compatible device
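Scripts of this shape are normally launched with a JSON config passed via `--deepspeed_config`. A minimal illustrative `ds_config.json` (the values are assumptions for this sketch; note the Hybrid Engine itself is enabled in code via `enable_hybrid_engine=True`, not in the config):

```json
{
  "train_batch_size": 1,
  "fp16": {
    "enabled": true
  }
}
```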

Best Practices Demonstrated

The test exemplifies several testing best practices for deep learning systems:

  • Explicit memory usage tracking
  • Clear model state transitions
  • Proper resource initialization and management
  • Validation of both inference and training paths
  • Integration with argument parsing for configuration
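The argument-parsing integration works by letting DeepSpeed attach its own flags (such as `--deepspeed` and `--deepspeed_config`) to a caller-owned parser via `deepspeed.add_config_arguments`. A stdlib-only sketch of that pattern (the helper below is a hypothetical stand-in, not DeepSpeed's function):

```python
import argparse

def add_config_arguments(parser):
    # Hypothetical stand-in mirroring the pattern of
    # deepspeed.add_config_arguments: the library attaches its
    # flags to a parser the application owns.
    group = parser.add_argument_group("deepspeed", "DeepSpeed-style configuration")
    group.add_argument("--deepspeed", action="store_true",
                       help="enable the DeepSpeed runtime")
    group.add_argument("--deepspeed_config", type=str, default=None,
                       help="path to a JSON config file")
    return parser

parser = argparse.ArgumentParser(description="hybrid engine test")
parser = add_config_arguments(parser)
args = parser.parse_args(["--deepspeed", "--deepspeed_config", "ds_config.json"])
print(args.deepspeed, args.deepspeed_config)
```

This keeps the test script's own options separate from the runtime's, so the same parser serves both.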

microsoft/deepspeed

tests/hybrid_engine/hybrid_engine_test.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from transformers import AutoModelForCausalLM
import deepspeed
import argparse
from deepspeed.accelerator import get_accelerator

# Report accelerator memory before the model is created.
deepspeed.runtime.utils.see_memory_usage('pre test', force=True)

model = AutoModelForCausalLM.from_pretrained('facebook/opt-350M').half().to(get_accelerator().device_name())
parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed, --deepspeed_config, etc.
args = parser.parse_args()

# Report accelerator memory after the model has been loaded.
deepspeed.runtime.utils.see_memory_usage('post test', force=True)

# Wrap the model in a DeepSpeed engine with the Hybrid Engine enabled.
m, _, _, _ = deepspeed.initialize(model=model, args=args, enable_hybrid_engine=True)

# Inference path ('inputs' avoids shadowing the builtin 'input').
m.eval()
inputs = torch.ones(1, 16, device=get_accelerator().device_name(), dtype=torch.long)
out = m(inputs)

# Training path on the same inputs.
m.train()
out = m(inputs)
print(out['logits'], out['logits'].norm())