
Testing Checkpoint Reshaping Implementation in DeepSpeed

This test suite validates checkpoint reshaping in DeepSpeed for 3D parallelism, where a model is split across pipeline-parallel (pp), tensor-parallel (tp) and data-parallel (dp) ranks. It checks how the ranks of a source checkpoint are regrouped when the checkpoint is mapped onto a configuration with equal or smaller parallel degrees.

Test Coverage Overview

Every test starts from a 2x2x2 source layout and reshapes it to a smaller target layout. Shapes are written as pp x tp x dp, matching the "pp/tp/dp" comment in the test file.

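Key functionality tested:
  • 2x2x2 to 1x1x1 reshaping (all three dimensions collapsed to a single rank)
  • 2x2x2 to 1x2x1 reshaping (pipeline and data parallelism removed, tensor parallelism kept)
  • 2x2x2 to 1x2x2 reshaping (pipeline parallelism removed)
  • 2x2x2 to 2x1x1 reshaping (tensor and data parallelism removed)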
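The integers in the expected lists (0 to 7 for a 2x2x2 source) read naturally as source rank indices. Judging from the expected values, a source rank appears to be numbered as dp_index * (pp_degree * tp_degree) + pp_index * tp_degree + tp_index. That numbering is an assumption inferred from the test data, not taken from DeepSpeed's documentation, and the helper names below (source_rank, enumerate_layout) are hypothetical. A small sketch that enumerates the 2x2x2 source under that assumption:

```python
# Hypothetical helpers, not a DeepSpeed API: the numbering of source ranks is
# inferred from the expected values in the tests.
def source_rank(pp_index, tp_index, dp_index, pp_degree, tp_degree):
    # Each data-parallel replica owns a contiguous block of pp_degree * tp_degree
    # ranks; inside a replica, ranks are ordered by pipeline stage, then tensor slice.
    return dp_index * (pp_degree * tp_degree) + pp_index * tp_degree + tp_index


def enumerate_layout(pp_degree, tp_degree, dp_degree):
    """Map every (pp_index, tp_index, dp_index) position to its source rank."""
    return {
        (pp, tp, dp): source_rank(pp, tp, dp, pp_degree, tp_degree)
        for pp in range(pp_degree)
        for tp in range(tp_degree)
        for dp in range(dp_degree)
    }


if __name__ == "__main__":
    layout = enumerate_layout(2, 2, 2)
    # Print positions in rank order so the layout of ranks 0..7 is easy to read.
    for position, rank in sorted(layout.items(), key=lambda item: item[1]):
        print(f"rank {rank}: (pp, tp, dp) = {position}")
```

Under this numbering, ranks 0 to 3 form data-parallel replica 0 and ranks 4 to 7 form replica 1, so pairs such as 0, 4 or 2, 6 that sit side by side in the expected lists combine the same (pp, tp) position from both source replicas.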

Implementation Analysis

Each test builds a source and a target model_3d_desc, passes them to the _do_reshape helper, and compares the result against hard-coded expectations. reshape() returns a list with one 2D map per target data-parallel rank; calling get_data(pp_index=..., tp_index=...) on a map returns the list of source rank indices that are merged into that target position.
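To make the expected lists concrete, here is a self-contained sketch of a reshape that reproduces all four expectations in the file. It is not DeepSpeed's implementation: the names reshape_3d and _partition are hypothetical, and it assumes source ranks are numbered dp_index * (pp_degree * tp_degree) + pp_index * tp_degree + tp_index and that each dimension is reduced by merging contiguous groups of indices, in the order tp, then pp, then dp.

```python
# Sketch only, not DeepSpeed's implementation. Assumptions (inferred from the
# test expectations): source rank = dp_index * (pp_degree * tp_degree)
#   + pp_index * tp_degree + tp_index, and each dimension is reduced by merging
# contiguous groups of indices, in the order tp, pp, dp.


def _partition(items, num_parts):
    """Split items into num_parts contiguous chunks of equal size."""
    if len(items) % num_parts != 0:
        raise ValueError(f"{len(items)} items cannot be split into {num_parts} parts")
    size = len(items) // num_parts
    return [items[i:i + size] for i in range(0, len(items), size)]


def reshape_3d(src, tgt):
    """Reshape a (pp, tp, dp) layout to one with equal or smaller degrees.

    Returns one dict per target dp index, mapping (pp_index, tp_index) to the
    list of source ranks merged into that position.
    """
    src_pp, src_tp, src_dp = src
    tgt_pp, tgt_tp, tgt_dp = tgt
    if tgt_pp > src_pp or tgt_tp > src_tp or tgt_dp > src_dp:
        raise ValueError("only shrinking reshapes are handled by this sketch")

    # 1. Within one replica, merge contiguous tensor slices of each pipeline stage.
    tp_map = {}
    for pp in range(src_pp):
        stage_ranks = [pp * src_tp + tp for tp in range(src_tp)]
        for new_tp, group in enumerate(_partition(stage_ranks, tgt_tp)):
            tp_map[(pp, new_tp)] = group

    # 2. Merge contiguous pipeline stages for each (new) tensor index.
    map_2d = {}
    for new_tp in range(tgt_tp):
        for new_pp, stages in enumerate(_partition(list(range(src_pp)), tgt_pp)):
            map_2d[(new_pp, new_tp)] = [r for pp in stages for r in tp_map[(pp, new_tp)]]

    # 3. Expand each replica-local rank to its copies in all source replicas.
    replica_size = src_pp * src_tp
    flat = {
        key: [r + d * replica_size for r in ranks for d in range(src_dp)]
        for key, ranks in map_2d.items()
    }

    # 4. Split each flattened list contiguously across the target replicas.
    result = [{} for _ in range(tgt_dp)]
    for key, ranks in flat.items():
        for dp, chunk in enumerate(_partition(ranks, tgt_dp)):
            result[dp][key] = chunk
    return result


if __name__ == "__main__":
    print(reshape_3d((2, 2, 2), (1, 2, 2)))
```

Running it prints [{(0, 0): [0, 4], (0, 1): [1, 5]}, {(0, 0): [2, 6], (0, 1): [3, 7]}], the same four lists asserted in test_reshape_222_to_122.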

Technical Details

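Testing components:
  • model_3d_desc, imported from deepspeed.checkpoint, which describes a layout by its pp_degree, tp_degree and dp_degree
  • _do_reshape, a helper that asserts can_reshape(), calls reshape(), and returns the resulting maps
  • Shape assertions in _do_reshape: the result has dp_degree maps, each with the target pp_degree and tp_degree
  • Exact-list assertions on get_data() for each target (pp_index, tp_index) position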
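_do_reshape only checks the shape of the result. A stricter invariant, which all four expectations in the file happen to satisfy, is that every source rank is assigned to exactly one target position. The following hypothetical checker (check_reshape is not part of DeepSpeed or of the test file) sketches both kinds of check; it assumes the result is given as plain dicts keyed by (pp_index, tp_index), one per target dp index, rather than the map objects returned by reshape(), which would need get_data() calls instead.

```python
# Hypothetical checker, not part of DeepSpeed or of the test file. It works on
# plain dicts keyed by (pp_index, tp_index), one dict per target dp index,
# instead of the map objects returned by model_3d_desc.reshape().
def check_reshape(result, tgt_pp, tgt_tp, tgt_dp, world_size):
    # Shape checks analogous to those in _do_reshape.
    if len(result) != tgt_dp:
        raise ValueError(f"expected {tgt_dp} dp maps, got {len(result)}")
    grid = {(pp, tp) for pp in range(tgt_pp) for tp in range(tgt_tp)}
    for dp_index, map_2d in enumerate(result):
        if set(map_2d) != grid:
            raise ValueError(f"dp map {dp_index} does not cover the {tgt_pp}x{tgt_tp} grid")
    # Extra invariant: every source rank is assigned to exactly one target position.
    assigned = sorted(rank for map_2d in result for ranks in map_2d.values() for rank in ranks)
    if assigned != list(range(world_size)):
        raise ValueError("source ranks were lost or duplicated")


# Expectation from test_reshape_222_to_122, written as plain data.
expected_222_to_122 = [
    {(0, 0): [0, 4], (0, 1): [1, 5]},
    {(0, 0): [2, 6], (0, 1): [3, 7]},
]
check_reshape(expected_222_to_122, tgt_pp=1, tgt_tp=2, tgt_dp=2, world_size=8)
```

Raising ValueError instead of using bare assert statements keeps the checks active even when Python runs with optimizations (-O) enabled.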

Best Practices Demonstrated

The tests verify both the shape of the reshaped layout and the exact assignment of source ranks to each target position.

Notable practices:
  • Systematic validation of dimensional integrity
  • Several target layouts reshaped from a single 2x2x2 source
  • Clear test case organization
  • Explicit assertion checks for expected outcomes

microsoft/deepspeed

tests/unit/checkpoint/test_reshape_checkpoint.py

            
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.checkpoint import model_3d_desc


def _do_reshape(src_3d, tgt_3d):
    assert src_3d.can_reshape(tgt_3d)
    new_3d_map = src_3d.reshape(tgt_3d)

    assert len(new_3d_map) == tgt_3d.dp_degree
    for new_2d_map in new_3d_map:
        assert new_2d_map.pp_degree == tgt_3d.pp_degree
        assert new_2d_map.tp_degree == tgt_3d.tp_degree

    return new_3d_map


# Specify 3d shape as pp/tp/dp
def test_reshape_222_to_111():
    src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
    tgt_3d = model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=1)

    new_3d_map = _do_reshape(src_3d, tgt_3d)

    assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5, 2, 6, 3, 7]


def test_reshape_222_to_121():
    src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
    tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=1)

    new_3d_map = _do_reshape(src_3d, tgt_3d)

    assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 2, 6]
    assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5, 3, 7]


def test_reshape_222_to_122():
    src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
    tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=2)

    new_3d_map = _do_reshape(src_3d, tgt_3d)

    assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4]
    assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5]
    assert new_3d_map[1].get_data(pp_index=0, tp_index=0) == [2, 6]
    assert new_3d_map[1].get_data(pp_index=0, tp_index=1) == [3, 7]


def test_reshape_222_to_211():
    src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
    tgt_3d = model_3d_desc(pp_degree=2, tp_degree=1, dp_degree=1)

    new_3d_map = _do_reshape(src_3d, tgt_3d)

    assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5]
    assert new_3d_map[0].get_data(pp_index=1, tp_index=0) == [2, 6, 3, 7]