Validating Pipeline Schedule Operations in DeepSpeed
This test suite validates the pipeline scheduling functionality in DeepSpeed, focusing on inference and training schedules across multiple pipeline stages. It ensures proper handling of micro-batches, data flow between stages, and schedule coordination.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
microsoft/deepspeed
tests/unit/runtime/pipe/test_pipe_schedule.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import deepspeed.runtime.pipe.schedule as schedule
def _count_type(cmds, classtype):
return len(list(filter(lambda c: type(c) == classtype, cmds)))
def test_pipe_inference_schedule_singlestage():
sched = schedule.InferenceSchedule(micro_batches=4, stages=1, stage_id=0)
assert sched.num_micro_batches == 4
full = list(iter(sched))
for idx, cmds in enumerate(full):
assert len(cmds) == 2
assert type(cmds[0]) == schedule.LoadMicroBatch
assert type(cmds[1]) == schedule.ForwardPass
assert cmds[0].buffer_id == cmds[1].buffer_id
assert len(full) == sched.num_micro_batches
def test_pipe_train_schedule_singlestage():
sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
assert sched.num_micro_batches == 4
full = list(iter(sched))
for idx, cmds in enumerate(full):
if (idx % 2) != 0:
assert (len(cmds) == 1) or (len(cmds) == 4)
assert type(cmds[0]) == schedule.BackwardPass
else:
assert len(cmds) == 2
assert type(cmds[0]) == schedule.LoadMicroBatch
assert type(cmds[1]) == schedule.ForwardPass
assert cmds[0].buffer_id == cmds[1].buffer_id
assert len(full) == sched.num_micro_batches * 2
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_firststage(micro_batches, stages=3):
sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=0)
assert sched.num_micro_batches == micro_batches
full = list(iter(sched))
for idx, cmds in enumerate(full):
# Ensure we don't send an activation the first step
if idx == 0:
assert len(cmds) == 2
assert type(cmds[0]) == schedule.LoadMicroBatch
assert type(cmds[1]) == schedule.ForwardPass
assert cmds[0].buffer_id == cmds[1].buffer_id
continue
# the last active step is only a send
if idx == sched.num_micro_batches:
assert len(cmds) == 1
assert type(cmds[0]) == schedule.SendActivation
continue
# no work later on
if idx > sched.num_micro_batches:
assert len(cmds) == 0
continue
# Normally we need to load/forward/send
assert len(cmds) == 3
assert _count_type(cmds, schedule.LoadMicroBatch) == 1
assert _count_type(cmds, schedule.ForwardPass) == 1
assert _count_type(cmds, schedule.SendActivation) == 1
assert len(full) == micro_batches + stages - 1
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_midstage(micro_batches, stages=3):
sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=1)
full = list(iter(sched))
for idx, cmds in enumerate(full):
if idx < sched.stage:
assert len(cmds) == 0
continue
if idx == sched.stage + sched.num_micro_batches:
assert len(cmds) == 1
assert type(cmds[0]) == schedule.SendActivation
continue
if idx > sched.stage + sched.num_micro_batches:
assert len(cmds) == 0
continue
assert _count_type(cmds, schedule.LoadMicroBatch) == 0
assert _count_type(cmds, schedule.ForwardPass) == 1
assert _count_type(cmds, schedule.RecvActivation) == 1
if idx > sched.stage:
assert _count_type(cmds, schedule.SendActivation) == 1
assert len(full) == micro_batches + stages - 1
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_laststage(micro_batches, stages=3):
sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=2)
full = list(iter(sched))
for idx, cmds in enumerate(full):
if idx < sched.stage or idx > sched.stage + sched.num_micro_batches:
assert len(cmds) == 0
continue
assert _count_type(cmds, schedule.LoadMicroBatch) == 1
assert _count_type(cmds, schedule.ForwardPass) == 1
assert _count_type(cmds, schedule.RecvActivation) == 1
assert _count_type(cmds, schedule.SendActivation) == 0
assert len(full) == micro_batches + stages - 1
def test_pipe_schedule_firststage():
sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
for cmds in sched:
assert all(instr.__class__ != schedule.SendGrad for instr in cmds)
assert all(instr.__class__ != schedule.RecvActivation for instr in cmds)
for instr in cmds:
if isinstance(instr, schedule.BufferOpInstruction):
assert 0 <= instr.buffer_id < sched.num_pipe_buffers()
def test_pipe_schedule_laststage():
sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
assert len(list(iter(sched))) == 2 * (sched.micro_batches + sched.stages - 1)
for cmds in sched:
assert all(instr.__class__ != schedule.SendActivation for instr in cmds)
assert all(instr.__class__ != schedule.RecvGrad for instr in cmds)
def test_pipe_stagequery():
sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=0)
assert sched.is_first_stage
assert not sched.is_last_stage
sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=1)
assert not sched.is_first_stage
assert not sched.is_last_stage
sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
assert not sched.is_first_stage
assert sched.is_last_stage