Testing Multi-Output Model Behavior in DeepSpeed

This test suite validates DeepSpeed's handling of models that produce multiple outputs, each with its own loss. The tests verify gradient accumulation, loss scaling, and backward-pass behavior for both two-output and three-output scenarios.

Test Coverage Overview

The test suite provides comprehensive coverage of multi-output model scenarios in DeepSpeed:
  • Two-output model testing with gradient accumulation
  • Three-output model validation with a longer gradient accumulation schedule (and hence a larger effective batch size)
  • Loss calculation and scaling verification (the scaling relationship is sketched after this list)
  • Mixed precision training support (FP16/BF16)
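
The loss-scaling check at the heart of both tests reduces to a simple relationship: the backward pass is expected to return the summed per-output loss divided by the number of gradient accumulation steps. A minimal sketch of that expectation in plain PyTorch (2.302734375 is the known per-output loss the tests assert against):

import torch

grad_accumulation_steps = 2
loss_tuple = (torch.tensor(2.302734375), torch.tensor(2.302734375))  # one scalar loss per output
summed_loss = sum(loss_tuple)
# DeepSpeed's engine.backward() is expected to return the loss scaled this way.
expected_scaled_loss = summed_loss.float() / grad_accumulation_steps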

Implementation Analysis

The testing approach implements distributed testing patterns using pytest:
  • Uses DistributedTest class for managing test environments
  • Implements custom data loading with multi_output_dataloader (a hypothetical stand-in is sketched after this list)
  • Validates model initialization, forward pass, and backward propagation
  • Verifies loss scaling across different gradient accumulation steps
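
The multi_output_dataloader helper lives in the shared test utilities and is not reproduced here; the tests rely only on its contract that each batch packs the inputs first and the matching targets second, so a batch splits cleanly at its midpoint. A hypothetical stand-in with that contract (the name and signature are illustrative, not the real helper) might look like:

import torch
from torch.utils.data import DataLoader, TensorDataset

def toy_multi_output_dataloader(total_samples, hidden_dim, device, inputs, targets, batch_size=1):
    # One input tensor per output head, then one target tensor per head, so each
    # yielded batch has an even length and splits at the midpoint into (inputs, targets).
    columns = [torch.full((total_samples, hidden_dim), value, device=device) for value in inputs]
    columns += [torch.full((total_samples,), target, dtype=torch.long, device=device) for target in targets]
    return DataLoader(TensorDataset(*columns), batch_size=batch_size)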

Technical Details

Key technical components include:
  • PyTest framework with distributed testing support
  • DeepSpeed initialization with custom configurations
  • Dynamic precision handling (FP16/BF16 based on accelerator support)
  • Custom MultiOutputModel implementation
  • Batch size and gradient accumulation configuration (see the config sketch after this list)
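
The configuration ties the effective batch size to the accumulation schedule and selects the precision mode from the accelerator's capabilities. A sketch distilled from the test configuration:

from deepspeed.accelerator import get_accelerator

micro_batch_size, grad_accumulation_steps, world_size = 1, 2, 1
config_dict = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": grad_accumulation_steps,
    # train_batch_size must equal micro batch size x accumulation steps x world size
    "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
}
if get_accelerator().is_fp16_supported():
    config_dict["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
    config_dict["bf16"] = {"enabled": True}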

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Proper test isolation and setup
  • Comprehensive assertion checks for expected behaviors
  • Flexible configuration handling
  • Proper resource cleanup
  • Clear separation of test cases for different output scenarios

microsoft/deepspeed

tests/unit/runtime/test_multi_output_model.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from pytest import approx
from unit.common import DistributedTest, preferred_dtype
from unit.multi_output_model import MultiOutputModel, multi_output_dataloader


class TestTwoOutputModel(DistributedTest):
    world_size = 1

    def test(self, tmpdir):
        grad_accumulation_steps = 2
        micro_batch_size = 1
        world_size = self.world_size
        config_dict = {
            "train_micro_batch_size_per_gpu": micro_batch_size,
            "gradient_accumulation_steps": grad_accumulation_steps,
            "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
        }
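        # Prefer FP16 when the accelerator supports it; otherwise fall back to BF16.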
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        hidden_dim = 10
        weight_value = 0.1

        model = MultiOutputModel(hidden_dim, weight_value)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        total_samples = 4
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0],
                                              targets=[1, 2])
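        # Each batch concatenates inputs followed by their targets; the loop splits it at the midpoint.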
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                 "multi_output_dataloader failed to return even number of data samples (input+target)"

            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)

            expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            summed_loss = sum(loss_tuple)
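            # engine.backward() is expected to return the summed loss scaled by 1/gradient_accumulation_steps.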
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / grad_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()


class TestThreeOutputModel(DistributedTest):
    world_size = 1

    def test(self, tmpdir):
        grad_accumulation_steps = 3
        micro_batch_size = 1
        world_size = 1
        config_dict = {
            "train_micro_batch_size_per_gpu": micro_batch_size,
            "gradient_accumulation_steps": grad_accumulation_steps,
            "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        hidden_dim = 10
        weight_value = 0.1

        model = MultiOutputModel(hidden_dim, weight_value)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())

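        # Enough samples for two full gradient-accumulation cycles.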
        total_samples = grad_accumulation_steps * micro_batch_size * 2
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0, 3.0],
                                              targets=[1, 2, 3])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                 "multi_output_dataloader failed to return even number of data samples (input+target)"

            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)
            assert len(loss_tuple) == 3

            expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)

            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / grad_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()