Back to Repositories

Validating Pipeline Parallel Topology Management in DeepSpeed

This test suite validates the topology management and process grid functionality in DeepSpeed’s pipeline parallel implementation. It verifies correct process coordination, rank mapping, and communication patterns across multi-dimensional parallel configurations.

Test Coverage Overview

The test suite provides comprehensive coverage of DeepSpeed’s topology and process grid features:
  • 2D and 3D topology configurations and rank mappings
  • Process coordinate systems and dimensional transformations
  • Pipeline parallel group formation and validation
  • Data parallel collective communications
  • Edge cases for invalid configurations and prime factorizations

Implementation Analysis

The testing approach employs both unit and distributed test patterns:
  • Systematic validation of topology construction and coordinate mapping
  • Verification of process group formation using PyTest fixtures
  • Integration with DeepSpeed’s communication primitives
  • Distributed testing with controlled world sizes

Technical Details

Key technical components include:
  • PyTest framework with distributed test support
  • PyTorch distributed communication backend
  • DeepSpeed topology and process grid implementations
  • Custom test fixtures for distributed environments
  • Accelerator-aware tensor operations

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Systematic decomposition of complex topology configurations
  • Comprehensive validation of dimensional transformations
  • Explicit error case handling
  • Clear separation of unit and distributed tests
  • Efficient test organization with proper setup/teardown

microsoft/deepspeed

tests/unit/runtime/pipe/test_topology.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

import torch
import deepspeed.comm as dist

from deepspeed.runtime.pipe.topology import PipelineParallelGrid as Grid
from deepspeed.runtime.pipe.topology import ProcessTopology as Topo
from deepspeed.runtime.pipe.topology import _prime_factors

from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest


def test_topology_2d():
    topo = Topo(axes=['row', 'col'], dims=[2, 2])

    assert topo.world_size() == 4

    assert topo.get_rank(row=0, col=0) == 0
    assert topo.get_rank(row=0, col=1) == 1
    assert topo.get_rank(row=1, col=0) == 2
    assert topo.get_rank(row=1, col=1) == 3

    assert topo.get_axis_list(axis='row', idx=0) == [0, 1]
    assert topo.get_axis_list(axis='row', idx=1) == [2, 3]
    assert topo.get_axis_list(axis='col', idx=0) == [0, 2]
    assert topo.get_axis_list(axis='col', idx=1) == [1, 3]


def test_topology_dims():
    topo = Topo(axes=['a', 'b', 'c'], dims=[2, 3, 4])
    assert topo.world_size() == 24
    assert topo.get_dim('a') == 2
    assert topo.get_dim('b') == 3
    assert topo.get_dim('c') == 4


def test_topology_match():
    topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
    print(topo.filter_match(pipe=0, data=1))
    assert topo.filter_match(pipe=0, data=1) == [2, 3]
    print([topo.get_coord(r) for r in topo.filter_match(pipe=0, data=1)])


def test_topology_rank_repr():
    topo = Topo(axes=['a', 'b'], dims=[2, 2])
    assert topo.get_rank_repr(rank=0) == 'a_00-b_00'
    assert topo.get_rank_repr(rank=1) == 'a_00-b_01'
    assert topo.get_rank_repr(rank=2) == 'a_01-b_00'
    assert topo.get_rank_repr(rank=3) == 'a_01-b_01'

    assert topo.get_rank_repr(rank=3, inner_sep='+') == 'a+01-b+01'
    assert topo.get_rank_repr(rank=3, inner_sep='🤗', outer_sep='_JEFF_') == 'a🤗01_JEFF_b🤗01'

    topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
    assert topo.get_rank_repr(rank=0) == ''
    assert topo.get_rank_repr(rank=1) == ''
    assert topo.get_rank_repr(rank=2) == ''
    assert topo.get_rank_repr(rank=3) == ''

    assert topo.get_rank_repr(rank=0, omit_axes=['pipe']) == 'data_00'
    assert topo.get_rank_repr(rank=1, omit_axes=['pipe']) == 'data_01'
    assert topo.get_rank_repr(rank=2, omit_axes=['pipe']) == 'data_00'
    assert topo.get_rank_repr(rank=3, omit_axes=['pipe']) == 'data_01'

    assert topo.get_rank_repr(rank=0, omit_axes=[]) == 'pipe_00-data_00'
    assert topo.get_rank_repr(rank=1, omit_axes=[]) == 'pipe_00-data_01'
    assert topo.get_rank_repr(rank=2, omit_axes=[]) == 'pipe_01-data_00'
    assert topo.get_rank_repr(rank=3, omit_axes=[]) == 'pipe_01-data_01'

    topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
    assert topo.get_rank_repr(rank=0) == 'model_00'
    assert topo.get_rank_repr(rank=1) == 'model_01'
    assert topo.get_rank_repr(rank=2) == 'model_00'
    assert topo.get_rank_repr(rank=3) == 'model_01'
    assert topo.get_rank_repr(rank=4) == 'model_00'
    assert topo.get_rank_repr(rank=5) == 'model_01'
    assert topo.get_rank_repr(rank=6) == 'model_00'
    assert topo.get_rank_repr(rank=7) == 'model_01'


def test_topology_3d():
    topo = Topo(axes=['a', 'b', 'c'], dims=[2, 2, 2])

    assert topo.get_rank(a=0, b=0, c=0) == 0
    assert topo.get_rank(a=0, b=0, c=1) == 1
    assert topo.get_rank(a=0, b=1, c=0) == 2
    assert topo.get_rank(a=0, b=1, c=1) == 3
    assert topo.get_rank(a=1, b=0, c=0) == 4
    assert topo.get_rank(a=1, b=0, c=1) == 5
    assert topo.get_rank(a=1, b=1, c=0) == 6
    assert topo.get_rank(a=1, b=1, c=1) == 7

    assert topo.get_axis_list('a', 0) == [0, 1, 2, 3]
    assert topo.get_axis_list('a', 1) == [4, 5, 6, 7]
    assert topo.get_axis_list('b', 0) == [0, 1, 4, 5]
    assert topo.get_axis_list('b', 1) == [2, 3, 6, 7]
    assert topo.get_axis_list('c', 0) == [0, 2, 4, 6]
    assert topo.get_axis_list('c', 1) == [1, 3, 5, 7]

    assert topo.get_coord(0) == topo.ProcessCoord(0, 0, 0)
    assert topo.get_coord(1) == topo.ProcessCoord(0, 0, 1)
    assert topo.get_coord(2) == topo.ProcessCoord(0, 1, 0)
    assert topo.get_coord(3) == topo.ProcessCoord(0, 1, 1)
    assert topo.get_coord(4) == topo.ProcessCoord(1, 0, 0)
    assert topo.get_coord(5) == topo.ProcessCoord(1, 0, 1)
    assert topo.get_coord(6) == topo.ProcessCoord(1, 1, 0)
    assert topo.get_coord(7) == topo.ProcessCoord(1, 1, 1)

    assert topo.filter_match(a=0) == [0, 1, 2, 3]
    assert topo.filter_match(b=1, c=1) == [3, 7]
    assert topo.filter_match(a=1, b=1, c=1) == [7]

    # Easy access method
    assert topo.get_coord(0).a == 0


def test_topology_comm_list():
    topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])

    assert topo.get_rank(pipe=0, data=0, model=0) == 0
    assert topo.get_rank(pipe=0, data=0, model=1) == 1
    assert topo.get_rank(pipe=0, data=1, model=0) == 2
    assert topo.get_rank(pipe=0, data=1, model=1) == 3
    assert topo.get_rank(pipe=1, data=0, model=0) == 4
    assert topo.get_rank(pipe=1, data=0, model=1) == 5
    assert topo.get_rank(pipe=1, data=1, model=0) == 6
    assert topo.get_rank(pipe=1, data=1, model=1) == 7

    pipe_list = [
        [0, 4],  # data=0, model=0
        [1, 5],  # data=0, model=1
        [2, 6],  # data=1, model=0
        [3, 7],  # data=1, model=1
    ]
    assert topo.get_axis_comm_lists('pipe') == pipe_list

    data_list = [
        [0, 2],  # pipe=0, model=0
        [1, 3],  # pipe=0, model=1
        [4, 6],  # pipe=1, model=0
        [5, 7],  # pipe=1, model=1
    ]
    assert topo.get_axis_comm_lists('data') == data_list

    model_list = [
        [0, 1],  # pipe=0, data=0
        [2, 3],  # pipe=0, data=1
        [4, 5],  # pipe=1, data=0
        [6, 7],  # pipe=1, data=1
    ]
    assert topo.get_axis_comm_lists('model') == model_list

    # Handle nonsense. We don't want to RuntimeError because it allows us to write more
    # generalized code for data/model/pipe parallelism
    assert topo.get_axis_comm_lists('jeff') == []


class TestDistributedTopology(DistributedTest):
    world_size = 4

    def test_grid_pipe_data(self):
        topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
        grid = Grid(topology=topo)

        assert grid._is_grid_valid()

        rank = dist.get_rank()

        assert grid.is_first_stage == (grid.get_stage_id() == 0)
        assert grid.is_last_stage == (grid.get_stage_id() == grid.get_pipe_parallel_world_size() - 1)

        # Test collectives along the pipeline parallel process groups
        rank_tensor = torch.LongTensor(data=[rank]).to(get_accelerator().device_name())
        dist.all_reduce(rank_tensor, group=grid.get_pipe_parallel_group())
        pipe_group = grid.pp_group
        assert torch.all(rank_tensor == sum(pipe_group))

        # Test collectives along the data parallel process groups
        rank_tensor = torch.LongTensor(data=[rank]).to(get_accelerator().device_name())
        dist.all_reduce(rank_tensor, group=grid.get_data_parallel_group())
        data_group = grid.dp_group
        assert torch.all(rank_tensor == sum(data_group))

    def test_stage_to_global(self):
        topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
        grid = Grid(topology=topo)

        assert grid._is_grid_valid()

        assert grid.stage_to_global(stage_id=0, data=0) == 0
        assert grid.stage_to_global(stage_id=0, data=1) == 1
        assert grid.stage_to_global(stage_id=1, data=0) == 2
        assert grid.stage_to_global(stage_id=1, data=1) == 3

        me = topo.get_coord(rank=dist.get_rank())
        if me.data == 0:
            assert grid.stage_to_global(stage_id=0) == 0
            assert grid.stage_to_global(stage_id=1) == 2
        else:
            assert grid.stage_to_global(stage_id=0) == 1
            assert grid.stage_to_global(stage_id=1) == 3


def test_primes():
    """ Test prime factorizations. """

    def _product(ps):
        p = 1
        for num in ps:
            p *= num
        return p

    with pytest.raises(ValueError):
        _prime_factors(0)

    for x in range(1, 30):
        primes = _prime_factors(x)
        assert _product(primes) == x
        for p in primes:
            assert _prime_factors(p) == [p]