
Testing MiCS Configuration and Distributed Training in DeepSpeed

This test suite validates DeepSpeed's MiCS sharding configuration in a distributed training environment. MiCS is a variant of ZeRO stage 3 that partitions model states across smaller shard groups rather than the full data-parallel group. The suite focuses on model initialization, training, and optimization under different shard sizes and with hierarchical parameter gathering.

Test Coverage Overview

The test suite covers essential aspects of MiCS configuration and model training:
  • MiCS initialization and model sharding configuration (a configuration sketch follows this list)
  • Distributed data loading and batch processing
  • ZeRO optimization stages and parameter management
  • Training loop execution with loss computation and backward passes
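
The MiCS options live inside the standard ZeRO configuration block. Below is a minimal sketch of such a configuration, assuming the same settings the test exercises; the concrete values (batch size, shard size) are illustrative rather than required defaults.

# Illustrative ZeRO stage-3 configuration with MiCS sharding enabled.
# Values mirror the test's config_dict and are not mandated defaults.
mics_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,                               # MiCS builds on ZeRO stage 3
        "mics_shard_size": 4,                     # shard model states across groups of 4 ranks
        "mics_hierarchical_params_gather": True,  # hierarchical (intra-/inter-node) all-gather
    },
}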

Implementation Analysis

The testing approach implements a distributed training scenario using PyTorch and DeepSpeed:
  • Uses a minimal SimpleModel with a single linear layer and a cross-entropy loss
  • Implements distributed data loading with DistributedSampler (sketched after this list)
  • Overrides MiCS parameters (shard size, hierarchical gathering) through command-line arguments
  • Validates model behavior across several training steps
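
For the data path, the test builds a TensorDataset of random features and integer labels and wraps it in a DistributedSampler so each rank processes a disjoint shard of the samples. A minimal sketch of that pattern follows; the set_epoch note is an addition for multi-epoch runs, while the test itself stops after a handful of steps.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def build_loader(total_samples, hidden_dim, batch_size):
    # Random features and class labels matching the shapes used in the test.
    data = torch.randn(total_samples, hidden_dim)
    labels = torch.randint(0, hidden_dim, (total_samples,))
    dataset = TensorDataset(data, labels)
    sampler = DistributedSampler(dataset)  # splits the sample indices across ranks
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler), sampler

# For multi-epoch training, call sampler.set_epoch(epoch) at the start of each
# epoch so the per-rank shuffling differs between epochs.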

Technical Details

Key technical components include:
  • torch.nn modules for model definition
  • deepspeed.zero.MiCS_Init for partitioned model initialization (a parameter-inspection sketch follows this list)
  • DistributedSampler for partitioning data across ranks
  • Configuration management via a generated JSON file
  • Multi-GPU testing support (launched with torchrun on an 8-GPU node)
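
Because MiCS_Init partitions parameters while the model is being constructed, no single rank holds a full parameter tensor afterwards. The sketch below shows one way to inspect a sharded parameter using DeepSpeed's GatheredParameters context; it is a debugging aid under the same setup as the test, not part of the test itself.

import deepspeed

def inspect_first_parameter(model):
    p = next(model.parameters())
    # Under ZeRO-3/MiCS the locally stored tensor is only a shard; ds_shape
    # records the full, unpartitioned shape of the parameter.
    print("local shard numel:", p.numel(), "full shape:", tuple(p.ds_shape))
    # Temporarily all-gather the parameter so it can be read in full.
    with deepspeed.zero.GatheredParameters([p]):
        print("gathered shape:", tuple(p.shape))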

Best Practices Demonstrated

The test implements several testing best practices:
  • Proper initialization of the distributed environment
  • Per-rank random seed control for reproducible runs
  • Configurable test parameters via command-line flags
  • Clean separation of model, data, and configuration logic
  • Bounded training loop with rank-0 loss logging for step-by-step validation

microsoft/deepspeed

tests/small_model_debugging/test_mics_config.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Testing on an 8-GPU node
NDEV_PER_NODE=2 torchrun --nnodes 1 --nproc-per-node 8 test_mics_config.py
"""

import os
import json
import argparse
import torch
import deepspeed
from torch.utils.data.distributed import DistributedSampler
import deepspeed.comm as dist


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = x
        hidden = self.linear(hidden)
        return self.cross_entropy_loss(hidden, y)


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def get_data_loader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.float)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader


def get_args(tmpdir, config_dict):
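    # Command-line flags override the ZeRO/MiCS fields of the base config dict;
    # the merged config is written to a temporary JSON file and handed to
    # DeepSpeed via args.deepspeed_config.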
    parser = argparse.ArgumentParser()
    parser.add_argument('--zero', type=int, default=3)
    parser.add_argument('--local_rank', type=int)

    parser.add_argument('--mics_shard_size', default=2, type=int)
    parser.add_argument('--mics_hierarchical_params_gather', default=False, action='store_true')
    args = parser.parse_args()  #args=''

    config_dict["zero_optimization"]["stage"] = args.zero
    config_dict["zero_optimization"]["mics_shard_size"] = args.mics_shard_size
    config_dict["zero_optimization"]["mics_hierarchical_params_gather"] = args.mics_hierarchical_params_gather

    # print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
    config_path = create_config_from_dict(tmpdir, config_dict)

    args.deepspeed_config = config_path
    return args


def print0(msg):
    if dist.get_rank() == 0:
        print(msg, flush=True)


rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)

config_dict = {
    "train_batch_size": 8,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
        }
    },
    "fp16": {
        "enabled": False,
        "initial_scale_power": 15
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": 20,
        "mics_shard_size": 4,
        "mics_hierarchical_params_gather": True,
        "stage3_model_persistence_threshold": 10
    }
}
#        "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
hidden_dim = 32

with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
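    # Parameters are partitioned across the configured MiCS shard groups as
    # the model is constructed inside this context.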
    model = SimpleModel(hidden_dim, empty_grad=False)
# print('------> init model with deepspeed.zero.Init()')

model, _, _, _ = deepspeed.initialize(args=args,
                                      model=model,
                                      model_parameters=model.parameters(),
                                      dist_init_required=True)


def print_params(tag, model):
    if dist.get_rank() == 0:
        for n, p in model.named_parameters():
            print0("{} {}:{}".format(tag, n, p))


data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device)
#print_params('pre-train', model)
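# Run a few training steps; the DeepSpeed engine's backward() and step()
# handle gradient reduction and parameter gathering for the sharded model.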
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    if dist.get_rank() == 0:
        print("LOSS:", loss.item())
    model.backward(loss)
    model.step()
    #print_params('step={}'.format(n), model)
    if n == 5: break