
Testing No-Sync Context Manager Implementation in DeepSpeed

This test suite validates the no_sync context manager in DeepSpeed, which controls whether gradients are synchronized across ranks during backward passes. The tests exercise the context manager across ZeRO optimization stages and data types to ensure correct behavior under different distributed training configurations.
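For orientation, here is a minimal usage sketch of no_sync during gradient accumulation. It is illustrative only and not taken from the test file: engine stands for the DeepSpeed engine returned by deepspeed.initialize, and micro_batches is an assumed list of (inputs, labels) pairs.

# Illustrative sketch: skip gradient all-reduce while accumulating, then
# synchronize on the final micro-batch and take the optimizer step.
with engine.no_sync():
    for inputs, labels in micro_batches[:-1]:
        loss = engine(inputs, labels)
        engine.backward(loss)      # gradients accumulate locally, no all-reduce

inputs, labels = micro_batches[-1]
loss = engine(inputs, labels)
engine.backward(loss)              # synchronization happens here as usual
engine.step()                      # calling step() inside no_sync would raise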

Test Coverage Overview

The test suite provides comprehensive coverage of the no_sync context manager across ZeRO optimization stages (0-3) and data types (float16, bfloat16, float32).

Key functionality tested includes:
  • ZeRO stage compatibility validation (see the condensed excerpt after this list)
  • Engine step behavior within no_sync context
  • Multiple context manager interactions
  • Context manager reentry prevention
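The stage-compatibility check combines pytest.raises with contextlib.nullcontext so that stages 2 and 3, which partition gradients, are expected to raise while stages 0 and 1 run normally. Condensed from the source below (imports and setup as in the file):

invalid_cfg = zero_stage > 1  # ZeRO stages 2 and 3 partition gradients
with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo:
    with model.no_sync():
        for _, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
if invalid_cfg:
    assert "no_sync context manager is incompatible" in str(assertinfo)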

Implementation Analysis

The testing approach uses pytest’s parametrize feature to systematically test combinations of ZeRO stages and data types.

Testing patterns include:
  • Configuration validation checks
  • Gradient accumulation verification
  • Loss comparison across multiple contexts (sketched after this list)
  • Exception handling for invalid operations
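The multiple-context comparison runs the same dataloader through two separate no_sync contexts and asserts that per-step losses and full gradient norms match exactly. A condensed sketch follows; run_pass is a hypothetical helper introduced here for brevity, while the actual test inlines both loops:

from deepspeed.utils import safe_get_full_grad

def run_pass(model, data_loader, params):
    losses, grad_norms = [], []
    with model.no_sync():
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            losses.append(loss.item())
            model.backward(loss)
            grad_norms.append(sum(safe_get_full_grad(p).norm() for p in params).item())
    return losses, grad_norms

params = list(model.parameters())
first_losses, first_norms = run_pass(model, data_loader, params)
model.zero_grad()
second_losses, second_norms = run_pass(model, data_loader, params)
assert first_losses == second_losses and first_norms == second_norms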

Technical Details

Testing tools and configuration:
  • PyTest framework with distributed test support
  • DeepSpeed initialization with various optimization stages (config construction sketched after this list)
  • Custom SimpleModel and random_dataloader utilities
  • Distributed barrier synchronization
  • Gradient norm calculations and comparisons
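Each test builds a small configuration dict and passes it to deepspeed.initialize. A condensed sketch of that construction; build_config is a hypothetical helper (the tests inline this dict in each method), with zero_stage and dtype supplied by parametrization:

import torch

def build_config(zero_stage, dtype):
    # Single micro-batch, no gradient accumulation, Adam, and the requested ZeRO stage.
    config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": zero_stage},
    }
    # Mixed-precision flags mirror the parametrized dtype.
    if dtype == torch.bfloat16:
        config["bf16"] = {"enabled": True}
    elif dtype == torch.float16:
        config["fp16"] = {"enabled": True, "initial_scale_power": 8}
    return config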

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Systematic parameter variation using pytest.mark.parametrize (excerpted after this list)
  • Proper exception handling and validation
  • Distributed test synchronization
  • Comprehensive edge case coverage
  • Clear test organization and separation of concerns
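The parameter variation comes from stacking pytest.mark.parametrize decorators, which pytest expands into the full cross-product of cases (four ZeRO stages times three dtypes, i.e. twelve cases for test_zero_stage). Excerpted from the source below:

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
def test_zero_stage(self, zero_stage, dtype):
    ...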

microsoft/deepspeed

tests/unit/runtime/test_no_sync_ctxt.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

from contextlib import nullcontext
import torch

from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest

import deepspeed
import deepspeed.comm as dist
from deepspeed.utils import safe_get_full_grad


class TestNoSyncCtxt(DistributedTest):
    world_size = 2

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
    def test_zero_stage(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

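        # ZeRO stages 2 and 3 partition gradients, so no_sync is expected to raise an AssertionError.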
        invalid_cfg = zero_stage > 1
        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

        with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo:
            with model.no_sync():
                for _, batch in enumerate(data_loader):
                    loss = model(batch[0], batch[1])
                    model.backward(loss)
        if invalid_cfg:
            assert ("no_sync context manager is incompatible" in str(assertinfo))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_engine_step(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

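        # Calling Engine.step() inside a no_sync context is illegal and should raise.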
        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                model.backward(loss)
                with pytest.raises(AssertionError) as assertinfo:
                    model.step()
                assert ("It is illegal to call Engine.step() inside no_sync context manager" in str(assertinfo))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_multiple_ctxts(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

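        # First pass: record per-step losses and full gradient norms under no_sync.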
        param_list = list(model.parameters())
        first_losses = []
        first_grad_norms = []
        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                first_losses.append(loss.item())
                model.backward(loss)
                grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
                first_grad_norms.append(grad_norm.item())

        second_losses = []
        second_grad_norms = []

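        # Second pass over the same data should reproduce identical losses and gradient norms.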
        model.zero_grad()
        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                second_losses.append(loss.item())
                model.backward(loss)
                grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
                second_grad_norms.append(grad_norm.item())

        assert len(first_losses) == len(second_losses)
        for x, y in zip(first_losses, second_losses):
            assert x == y

        assert len(first_grad_norms) == len(second_grad_norms)
        for x, y in zip(first_grad_norms, second_grad_norms):
            assert x == y

    def test_reentry(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": 1,
            },
        }

        hidden_dim = 64
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        dist.barrier()

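        # Re-entering no_sync while already inside it is unsupported and should raise.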
        with model.no_sync():
            with pytest.raises(AssertionError) as assertinfo:
                with model.no_sync():
                    pass
            assert ("no_sync context manager reentry is unsupported" in str(assertinfo))