
Testing Multi-Node Runner Implementation in DeepSpeed

This test suite validates DeepSpeed's multi-node runners, the classes that build launch commands for different cluster launchers: PDSH, OpenMPI, MPICH, SLURM, and MVAPICH. The tests check that each runner generates the expected command and, for PDSH, the expected environment configuration.

Test Coverage Overview

The test suite covers five launcher backends:
  • PDSH command generation and SSH configuration (PDSH_RCMD_TYPE set to ssh)
  • OpenMPI command generation, including the default eth0 network interface and user overrides via btl_tcp_if_include
  • MPICH command generation (mpirun as the launcher)
  • SLURM command generation (srun as the launcher)
  • MVAPICH command generation (mpirun as the launcher)
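Most of these assertions reduce to one check: the first token of the generated command names the expected launcher binary. The sketch below tabulates those expectations as taken from the test file. EXPECTED_LAUNCHER and launcher_matches are hypothetical helpers, and the command lists are hand-written placeholders rather than real runner output.

```python
from typing import Dict, List

# Launcher binary each test expects at cmd[0], taken from the assertions
# in test_multinode_runner.py (the runner keys are illustrative labels).
EXPECTED_LAUNCHER: Dict[str, str] = {
    "pdsh": "pdsh",
    "openmpi": "mpirun",
    "mpich": "mpirun",
    "slurm": "srun",
    "mvapich": "mpirun",
}


def launcher_matches(runner_name: str, cmd: List[str]) -> bool:
    """Return True if cmd is non-empty and starts with the expected launcher."""
    expected = EXPECTED_LAUNCHER.get(runner_name)
    return expected is not None and len(cmd) > 0 and cmd[0] == expected


# Hand-written command lists standing in for real runner output.
print(launcher_matches("slurm", ["srun", "python", "test_launcher.py"]))  # True
print(launcher_matches("pdsh", ["mpirun", "python", "test_launcher.py"]))  # False
```

OpenMPI, MPICH, and MVAPICH all expect mpirun, so a cmd[0] check alone does not distinguish between them; only the OpenMPI tests go further and inspect the interface arguments.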

Implementation Analysis

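All tests share a single pytest fixture, runner_info, which builds a two-host resource pool (worker-0 and worker-1, four slots each), the encoded world info, a copy of the process environment, and default launcher arguments for a user script named test_launcher.py. Because the fixture uses pytest's default function scope, each test receives freshly built objects.

Rather than parametrizing one test over all runners, the suite defines a separate test function per runner. Most assertions check the launcher binary at cmd[0]; the OpenMPI tests additionally check which network interface appears in the command, and the PDSH test checks the returned environment.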
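Function scope means the fixture body runs again for every test, so no test sees another test's changes. The sketch below mimics that with a plain factory function, make_runner_info (hypothetical), and swaps deepcopy(os.environ) for os.environ.copy(), which returns a plain dict snapshot. World info and argument parsing are left out to keep the example self-contained.

```python
import os
from typing import Dict, Tuple


def make_runner_info() -> Tuple[Dict[str, str], Dict[str, int]]:
    """Hypothetical stand-in for the runner_info fixture: fresh objects on every call."""
    hosts = {"worker-0": 4, "worker-1": 4}  # two hosts, four slots each
    env = os.environ.copy()  # plain dict snapshot; editing it leaves os.environ alone
    return env, hosts


# Two "tests" each get their own objects, as with a function-scoped fixture.
env_a, hosts_a = make_runner_info()
env_b, hosts_b = make_runner_info()

env_a["RUNNER_INFO_SKETCH_KEY"] = "1"  # simulate a runner mutating its copy
hosts_a["worker-2"] = 4

print(env_b.get("RUNNER_INFO_SKETCH_KEY"))  # None
print(hosts_b)  # {'worker-0': 4, 'worker-1': 4}
```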

Technical Details

Testing tools and configuration:
  • pytest as the primary testing framework
  • runner_info fixture providing a shared two-host setup for every runner test
  • Environment variable management through os.environ
  • Deep copy utilities for environment isolation
  • DeepSpeed's parse_args and encode_world_info helpers (deepspeed.launcher.runner) for building launcher arguments and world info
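The fixture passes the host map through encode_world_info before handing it to the runners, so the world info can travel as a single command-line-safe string. The sketch below assumes a JSON plus URL-safe base64 encoding; that is our reading of deepspeed/launcher/runner.py and should be checked against the installed version. encode_hosts and decode_hosts are hypothetical names.

```python
import base64
import json
from typing import Dict


def encode_hosts(hosts: Dict[str, int]) -> str:
    """Hypothetical helper: serialize a host -> slot-count map to URL-safe base64 JSON."""
    raw = json.dumps(hosts).encode("utf-8")
    return base64.urlsafe_b64encode(raw).decode("utf-8")


def decode_hosts(encoded: str) -> Dict[str, int]:
    """Inverse of encode_hosts: base64-decode, then parse the JSON."""
    raw = base64.urlsafe_b64decode(encoded.encode("utf-8"))
    return json.loads(raw.decode("utf-8"))


hosts = {"worker-0": 4, "worker-1": 4}
world_info = encode_hosts(hosts)
print(decode_hosts(world_info) == hosts)  # True
```

URL-safe base64 avoids characters such as + and /, and the result contains no spaces, so it can be passed as one shell argument without quoting.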

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Isolation of test cases using fixtures
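  • Consistent per-test setup; no explicit teardown is needed because each test works on its own copy of the environment
  • Specific assertion checks for command generation
  • One test function per runner implementation
  • Coverage of network interface selection, including both the single-dash (-mca) and double-dash (--mca) spellings of the btl_tcp_if_include override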
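The two btl_nic tests pin down one behavior: the OpenMPI command includes eth0 by default, but a btl_tcp_if_include value supplied through the launcher arguments replaces it, whether written with -mca or --mca. The sketch below is a hypothetical approximation of that logic, not DeepSpeed's OpenMPIRunner code; build_mpirun_cmd, DEFAULT_NIC, and the -n flag are assumptions. Because the tests use list membership ('eth1' in cmd), the launcher argument string has to be split into separate tokens, which the sketch does with shlex.split.

```python
import shlex
from typing import List

DEFAULT_NIC = "eth0"  # interface the tests expect when the user gives none


def build_mpirun_cmd(launcher_args: str, num_procs: int) -> List[str]:
    """Hypothetical approximation of OpenMPI command assembly with a NIC default."""
    # '-mca btl_tcp_if_include eth1' -> ['-mca', 'btl_tcp_if_include', 'eth1']
    user_args = shlex.split(launcher_args)
    cmd = ["mpirun", "-n", str(num_procs)]
    # Inject the default interface only if the user did not choose one;
    # checking for the parameter name works for both '-mca' and '--mca'.
    if "btl_tcp_if_include" not in user_args:
        cmd += ["--mca", "btl_tcp_if_include", DEFAULT_NIC]
    return cmd + user_args + ["python", "test_launcher.py"]


# 2 hosts x 4 slots = 8 processes, matching the fixture's resource pool.
print("eth0" in build_mpirun_cmd("", 8))  # True
print("eth0" in build_mpirun_cmd("--mca btl_tcp_if_include eth1", 8))  # False
```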

microsoft/deepspeed

tests/unit/launcher/test_multinode_runner.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from copy import deepcopy
from deepspeed.launcher import multinode_runner as mnrunner
from deepspeed.launcher.runner import encode_world_info, parse_args
import os
import pytest


@pytest.fixture
def runner_info():
    hosts = {'worker-0': 4, 'worker-1': 4}
    world_info = encode_world_info(hosts)
    env = deepcopy(os.environ)
    args = parse_args(['test_launcher.py'])
    return env, hosts, world_info, args


def test_pdsh_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.PDSHRunner(args, world_info)
    cmd, kill_cmd, env = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'pdsh'
    assert env['PDSH_RCMD_TYPE'] == 'ssh'


def test_openmpi_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'mpirun'
    assert 'eth0' in cmd


def test_btl_nic_openmpi_runner(runner_info):
    env, resource_pool, world_info, _ = runner_info
    args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py'])

    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert 'eth0' not in cmd
    assert 'eth1' in cmd


def test_btl_nic_two_dashes_openmpi_runner(runner_info):
    env, resource_pool, world_info, _ = runner_info
    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert 'eth0' not in cmd
    assert 'eth1' in cmd


def test_mpich_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.MPICHRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'mpirun'


def test_slurm_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.SlurmRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'srun'


def test_mvapich_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.MVAPICHRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'mpirun'