
Testing MPI Communication Performance with OneBit Compression in DeepSpeed

This benchmark script evaluates the performance of MPI-based communication in DeepSpeed's 1-bit (onebit) compression implementation. It measures the latency of compressed allreduce operations using mpi4py and PyTorch tensors on the accelerator device, from which throughput can also be derived.
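
Because the script discovers its rank and world size through mpi4py's MPI.COMM_WORLD, it is typically launched with an MPI launcher, for example:

mpirun -np 4 python tests/onebit/test_mpi_perf.py

The rank count and path here are illustrative; each rank binds to a local accelerator via rank % device_count(), so one rank per device is assumed.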

Test Coverage Overview

The test focuses on measuring the latency of compressed allreduce operations in a multi-rank MPI environment.

Key areas tested include:
  • MPI backend initialization and configuration (sketched after this list)
  • CUDA-aware MPI communication settings
  • Tensor compression with error-feedback (worker/server error) buffers
  • Latency measurements across multiple iterations
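
The first two items map directly onto the top of the script. A minimal sketch of the initialization path, using the same calls as the full listing further down this page:

from mpi4py import MPI
import torch
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend
from deepspeed.accelerator import get_accelerator

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# DeepSpeed sets up the distributed process group; MpiBackend wraps the compressed collectives
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
backend = MpiBackend(cuda_aware=False)  # set to True to exercise CUDA-aware MPI

# One rank per local accelerator device
local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)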

Implementation Analysis

The script implements a systematic performance benchmark of DeepSpeed's compressed allreduce over MPI collective operations. Warm-up iterations are followed by timed runs that measure communication latency.

Key patterns include:
  • Distributed tensor initialization with rank-based bias
  • Worker and server error buffer management (see the error-feedback sketch after this list)
  • Synchronized wall clock timing measurements
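
The worker and server error buffers implement error feedback: whatever precision is lost when a gradient is quantized to signs is remembered and added back on the next call. The function below is a simplified, hypothetical illustration of that idea only, not DeepSpeed's kernel, which additionally packs the sign bits and splits the buffer into per-rank server chunks.

import torch

def sign_compress_with_error_feedback(grad: torch.Tensor, error: torch.Tensor):
    """Hypothetical illustration of 1-bit compression with error feedback."""
    compensated = grad + error                                # re-inject the previous residual
    scale = compensated.norm() / compensated.numel() ** 0.5   # one scale restores magnitude info
    signs = compensated.sign()                                # the 1-bit payload that is communicated
    new_error = compensated - scale * signs                   # residual carried into the next call
    return scale, signs, new_error

Roughly speaking, worker_error plays this role for the local compression step and server_error for the compression of each rank's reduced chunk.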

Technical Details

Testing infrastructure includes:
  • mpi4py for distributed communication
  • PyTorch for tensor operations
  • DeepSpeed’s SynchronizedWallClockTimer (usage sketched after this list)
  • Configurable tensor sizes and iteration counts
  • CUDA device management and local rank mapping
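
The timer usage reduces to a small named start/stop/elapsed pattern, shown here with the same calls and variable names as the listing below. SynchronizedWallClockTimer is intended to synchronize the accelerator around the measurement so asynchronous kernels are fully accounted for.

from deepspeed.utils.timer import SynchronizedWallClockTimer

timers = SynchronizedWallClockTimer()

timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
timers('compressed_allreduce').stop()
sample = timers('compressed_allreduce').elapsed()  # the script treats this as seconds (convert = 1e3 -> ms)

# Summary line for all named timers; normalizer divides the reported values
timers.log(names=['compressed_allreduce'], normalizer=1, memory_breakdown=None)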

Best Practices Demonstrated

The test implements several performance testing best practices for distributed systems.

Notable practices include:
  • Proper warm-up period before measurements
  • Statistical analysis of latency (min/max/mean)
  • Controlled tensor initialization to avoid numerical edge cases (illustrated after this list)
  • Proper device and rank management
  • Synchronized timing mechanisms
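
The controlled-initialization point deserves a closer look: a small rank-dependent bias keeps gradient elements away from zero, where sign compression carries little information, and the buffers are padded to a multiple of 8 * world size, presumably so sign bits can be packed eight per byte and the buffer split evenly across ranks. The lines below mirror the listing that follows (padded_size corresponds to right_tensor_size there):

# Rank-dependent bias keeps elements away from zero, where sign compression loses information
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Pad to a multiple of 8 * size before sizing the error buffers
padded_size = tensor_size
if tensor_size % (8 * size) != 0:
    padded_size = tensor_size + (8 * size - tensor_size % (8 * size))

worker_error = torch.zeros(padded_size, device=device)
server_error = torch.zeros(padded_size // size, device=device)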

microsoft/deepspeed

tests/onebit/test_mpi_perf.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from mpi4py import MPI
import torch
import deepspeed

from deepspeed.runtime.comm.mpi import MpiBackend

# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator

from statistics import mean

timers = SynchronizedWallClockTimer()

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)

tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
# Pad the tensor length to a multiple of 8 * size (sign bits are packed 8-per-byte
# and the padded buffer is chunked evenly across ranks)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2  # decimal places when rounding the mean latency
convert = 1e3  # seconds -> milliseconds
float_size = 4  # bytes per float32 element

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))
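
The listing above ends with latency reporting, but the float_size constant points at the natural follow-up: converting the mean latency into an effective throughput for the uncompressed fp32 payload. The sketch below shows how that could look, reusing the variables defined above; the exact reporting in the upstream script may differ.

if rank == 0:
    duration_s = meanlat / convert          # meanlat is in milliseconds
    bytes_moved = tensor_size * float_size  # fp32 elements * 4 bytes each
    tput = bytes_moved / duration_s
    print("algo throughput: {:.2f} GB/s".format(tput / 1e9))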