Back to Repositories

Validating Pipeline Schedule Operations in DeepSpeed

This test suite validates the pipeline scheduling functionality in DeepSpeed, focusing on inference and training schedules across multiple pipeline stages. It ensures proper handling of micro-batches, data flow between stages, and schedule coordination.

Test Coverage Overview

The test suite provides comprehensive coverage of pipeline scheduling mechanisms in DeepSpeed.

Key areas tested include:
  • Single-stage inference and training schedules
  • Multi-stage pipeline configurations
  • First, middle, and last stage behaviors
  • Micro-batch handling across different micro-batch counts (1, 3, 8, and 10)
  • Stage-specific operations and transitions

Implementation Analysis

The suite uses pytest parametrization to validate schedule behavior across several micro-batch counts and pipeline-stage positions.

Key patterns include:
  • Buffer operation verification
  • Command sequence validation
  • Stage-specific instruction checking
  • Schedule iteration validation

Technical Details

Testing infrastructure includes:
  • pytest framework for test organization
  • DeepSpeed pipeline scheduling module
  • A helper function (`_count_type`) that counts instructions of a given type
  • Parametrized test cases for different micro-batch counts
  • Stage-specific schedule implementations

Best Practices Demonstrated

The test suite exemplifies robust testing practices through:
  • Systematic validation of edge cases
  • Comprehensive stage boundary testing
  • Parametrized test scenarios
  • Clear separation of test cases
  • Efficient test helper utilities

microsoft/deepspeed

tests/unit/runtime/pipe/test_pipe_schedule.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import deepspeed.runtime.pipe.schedule as schedule


def _count_type(cmds, classtype):
    """Count the instructions in ``cmds`` whose type is exactly ``classtype``.

    An exact type match (not ``isinstance``) is used, so instances of
    subclasses of ``classtype`` are not counted.

    Args:
        cmds: Iterable of schedule instructions for a single step.
        classtype: Instruction class to count.

    Returns:
        int: Number of elements of ``cmds`` whose type is ``classtype``.
    """
    # Identity comparison is the idiomatic exact-type check (pycodestyle E721).
    return sum(1 for c in cmds if type(c) is classtype)


def test_pipe_inference_schedule_singlestage():
    """A single-stage inference schedule loads and runs forward once per micro-batch.

    Every step must be exactly ``[LoadMicroBatch, ForwardPass]`` acting on the
    same buffer, and the schedule must have one step per micro-batch.
    """
    sched = schedule.InferenceSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    full = list(iter(sched))
    # Check the step count first so a wrong count is reported before any
    # per-step detail.
    assert len(full) == sched.num_micro_batches
    for cmds in full:
        assert len(cmds) == 2
        load, forward = cmds
        assert type(load) is schedule.LoadMicroBatch
        assert type(forward) is schedule.ForwardPass
        # The forward pass must consume the buffer that was just loaded.
        assert load.buffer_id == forward.buffer_id


def test_pipe_train_schedule_singlestage():
    """Single-stage training alternates forward steps and backward steps.

    Even-numbered steps are ``[LoadMicroBatch, ForwardPass]`` on one buffer.
    Odd-numbered steps begin with a ``BackwardPass`` and hold either one or
    four commands. There are two steps per micro-batch.
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    steps = list(iter(sched))
    for step_id, step_cmds in enumerate(steps):
        is_forward_step = step_id % 2 == 0
        if not is_forward_step:
            assert len(step_cmds) in (1, 4)
            assert type(step_cmds[0]) is schedule.BackwardPass
            continue
        assert len(step_cmds) == 2
        first, second = step_cmds[0], step_cmds[1]
        assert type(first) is schedule.LoadMicroBatch
        assert type(second) is schedule.ForwardPass
        assert first.buffer_id == second.buffer_id
    assert len(steps) == 2 * sched.num_micro_batches


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_firststage(micro_batches, stages=3):
    """The first inference stage follows a fill / steady / drain pattern.

    - Step 0 only loads and runs forward; no activation exists to send yet.
    - Steps 1 .. num_micro_batches-1 load, run forward, and send.
    - Step num_micro_batches only sends the final activation.
    - Every later step is empty.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=0)
    assert sched.num_micro_batches == micro_batches
    steps = list(iter(sched))
    final_send_step = sched.num_micro_batches
    for step_id, step_cmds in enumerate(steps):
        if step_id == 0:
            # Nothing has been computed yet, so there is nothing to send.
            assert len(step_cmds) == 2
            load, forward = step_cmds[0], step_cmds[1]
            assert type(load) is schedule.LoadMicroBatch
            assert type(forward) is schedule.ForwardPass
            assert load.buffer_id == forward.buffer_id
        elif step_id == final_send_step:
            # Drain: only the last activation is forwarded downstream.
            assert len(step_cmds) == 1
            assert type(step_cmds[0]) is schedule.SendActivation
        elif step_id > final_send_step:
            assert len(step_cmds) == 0
        else:
            # Steady state: load, forward, and send in a single step.
            assert len(step_cmds) == 3
            for expected in (schedule.LoadMicroBatch, schedule.ForwardPass, schedule.SendActivation):
                assert _count_type(step_cmds, expected) == 1
    assert len(steps) == micro_batches + stages - 1


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_midstage(micro_batches, stages=3):
    """A middle inference stage receives, runs forward, and relays activations.

    - Steps before ``stage`` are empty.
    - Step ``stage`` receives and runs forward; there is nothing to send yet.
    - Steps up to ``stage + num_micro_batches - 1`` also send the previous
      activation.
    - Step ``stage + num_micro_batches`` only sends the final activation.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=1)
    assert sched.num_micro_batches == micro_batches

    full = list(iter(sched))
    first_active = sched.stage
    final_send = sched.stage + sched.num_micro_batches
    for idx, cmds in enumerate(full):
        if idx < first_active or idx > final_send:
            assert len(cmds) == 0
            continue
        if idx == final_send:
            assert len(cmds) == 1
            assert type(cmds[0]) is schedule.SendActivation
            continue
        assert _count_type(cmds, schedule.LoadMicroBatch) == 0
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.RecvActivation) == 1
        # The first active step has not produced an activation yet, so it
        # must not send. Every later active step sends exactly one.
        expected_sends = 0 if idx == first_active else 1
        assert _count_type(cmds, schedule.SendActivation) == expected_sends
        # Reject any unexpected extra instructions (recv + forward + sends).
        assert len(cmds) == 2 + expected_sends
    assert len(full) == micro_batches + stages - 1


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_laststage(micro_batches, stages=3):
    """The last inference stage loads, receives, and runs forward but never sends.

    It is active for exactly ``num_micro_batches`` steps starting at ``stage``.
    Each active step holds exactly one LoadMicroBatch, one RecvActivation and
    one ForwardPass. All other steps are empty.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=2)
    assert sched.num_micro_batches == micro_batches
    full = list(iter(sched))
    # The last stage has no trailing send-only step, so its active window is
    # the half-open range [stage, stage + num_micro_batches).
    active_end = sched.stage + sched.num_micro_batches
    for idx, cmds in enumerate(full):
        if idx < sched.stage or idx >= active_end:
            assert len(cmds) == 0
            continue
        assert _count_type(cmds, schedule.LoadMicroBatch) == 1
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.RecvActivation) == 1
        assert _count_type(cmds, schedule.SendActivation) == 0
        # Reject any unexpected extra instructions.
        assert len(cmds) == 3
    assert len(full) == micro_batches + stages - 1


def test_pipe_schedule_firststage():
    """The first training stage never sends gradients or receives activations.

    Every buffer-based instruction must also reference a buffer index within
    ``[0, num_pipe_buffers())``.
    """
    sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
    buffer_count = sched.num_pipe_buffers()
    for step_cmds in sched:
        kinds = [type(instr) for instr in step_cmds]
        assert schedule.SendGrad not in kinds
        assert schedule.RecvActivation not in kinds
        buffer_ops = (op for op in step_cmds if isinstance(op, schedule.BufferOpInstruction))
        for op in buffer_ops:
            assert 0 <= op.buffer_id < buffer_count


def test_pipe_schedule_laststage():
    """The last training stage never sends activations or receives gradients.

    The schedule must also span ``2 * (micro_batches + stages - 1)`` steps.
    """
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
    steps = list(iter(sched))
    expected_steps = 2 * (sched.micro_batches + sched.stages - 1)
    assert len(steps) == expected_steps
    for step_cmds in steps:
        assert not any(type(instr) is schedule.SendActivation for instr in step_cmds)
        assert not any(type(instr) is schedule.RecvGrad for instr in step_cmds)


def test_pipe_stagequery():
    """Check ``is_first_stage`` / ``is_last_stage`` for each stage of a 3-stage pipeline."""
    # stage_id -> (expected is_first_stage, expected is_last_stage)
    expectations = {
        0: (True, False),
        1: (False, False),
        2: (False, True),
    }
    for stage_id, (expect_first, expect_last) in expectations.items():
        sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=stage_id)
        assert bool(sched.is_first_stage) is expect_first
        assert bool(sched.is_last_stage) is expect_last