Testing NCCL Backend Compressed Communication in DeepSpeed
This test validates the NcclBackend implementation in DeepSpeed's distributed communication system, focusing on the 1-bit compressed allreduce and its error-feedback (compensation) buffers. It checks, on every rank, that the backend's compressed gradient communication matches a reference simulation built from plain deepspeed.comm collectives.
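Before reading the file, it helps to see the compression scheme under test in isolation. The following is a minimal, single-process sketch (illustrative only, not DeepSpeed code) of 1-bit compression with error feedback: each element collapses to its sign, one scalar scale preserves the tensor's L2 norm, and the residual is retained to compensate the next step.

import torch

a = torch.randn(1024)
scale = a.norm() / a.numel() ** 0.5   # single fp32 scale for the whole tensor
a_compressed = scale * a.sign()       # one bit of information per element
error = a - a_compressed              # residual kept for error feedback
# The scale is chosen so the compressed tensor keeps the original L2 norm
# (exact zeros aside, since sign(0) == 0):
assert torch.allclose(a.norm(), a_compressed.norm())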
Source: tests/onebit/test_nccl_backend.py (microsoft/deepspeed)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os
from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.accelerator import get_accelerator
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()
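# Initialize the process group with the accelerator's native backend
# (NCCL on CUDA devices) and pin each process to its local device.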
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)
size = dist.get_world_size()
rank = dist.get_rank()
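# NcclBackend implements the 1-bit compressed allreduce being tested.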
backend = NcclBackend()
local_rank = args.local_rank
# A simulated compression function using deepspeed.comm
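# It mirrors the two-stage algorithm of compressed_allreduce: a worker stage
# (sign + scale, keeping the residual as worker_error), an allreduce of the
# compressed tensor, and a server stage that re-compresses each rank's chunk
# and keeps that residual as server_error.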
def torch_sim(a):
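    # Map every element to +/-1; the add_(1).bool() trick sends exact zeros
    # to +1 instead of 0, so each element carries exactly one bit.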
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
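    # "Server" stage: average the compressed tensors across ranks, then
    # re-compress each world-size chunk with its own sign and scale.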
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
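# 300M-element input; it is padded up to a multiple of 8 * world_size so that
# each rank's chunk length is divisible by 8 (sign bits pack into whole bytes).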
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Bias the random gradient by rank so that no element is vanishingly small,
# which would make the sign-based compression numerically unstable.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
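# Persistent error-feedback buffers; compressed_allreduce reads and updates
# them in place across calls.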
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()
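# Run the real NCCL-backend compressed allreduce on the same input.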
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
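# Compare against the simulation: only positions that differ by more than
# threshold, and whose compensated magnitude exceeds magnitude_threshold,
# count as failures.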
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True
# If a value in the compensated server tensor is very small (e.g. 1e-8),
# calling sign() on it is numerically unreliable, so the test skips
# positions whose compensated magnitude is below magnitude_threshold.
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
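Note that this is a standalone distributed script rather than a pytest case: it must be launched with a multi-process launcher that sets the LOCAL_RANK environment variable for each process (for example the deepspeed or torchrun launchers), with one process per GPU; the exact flags depend on the cluster setup.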