Testing Checkpoint Reshaping Implementation in DeepSpeed
This test suite validates the checkpoint reshaping functionality in DeepSpeed, specifically focusing on 3D model parallelism transformations. It ensures correct redistribution of model states across different parallel processing configurations.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
microsoft/deepspeed
tests/unit/checkpoint/test_reshape_checkpoint.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.checkpoint import model_3d_desc
def _do_reshape(src_3d, tgt_3d):
assert src_3d.can_reshape(tgt_3d)
new_3d_map = src_3d.reshape(tgt_3d)
assert len(new_3d_map) == tgt_3d.dp_degree
for new_2d_map in new_3d_map:
assert new_2d_map.pp_degree == tgt_3d.pp_degree
assert new_2d_map.tp_degree == tgt_3d.tp_degree
return new_3d_map
# Specify 3d shape as pp/tp/dp
def test_reshape_222_to_111():
src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
tgt_3d = model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=1)
new_3d_map = _do_reshape(src_3d, tgt_3d)
assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5, 2, 6, 3, 7]
def test_reshape_222_to_121():
src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=1)
new_3d_map = _do_reshape(src_3d, tgt_3d)
assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 2, 6]
assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5, 3, 7]
def test_reshape_222_to_122():
src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=2)
new_3d_map = _do_reshape(src_3d, tgt_3d)
assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4]
assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5]
assert new_3d_map[1].get_data(pp_index=0, tp_index=0) == [2, 6]
assert new_3d_map[1].get_data(pp_index=0, tp_index=1) == [3, 7]
def test_reshape_222_to_211():
src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2)
tgt_3d = model_3d_desc(pp_degree=2, tp_degree=1, dp_degree=1)
new_3d_map = _do_reshape(src_3d, tgt_3d)
assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5]
assert new_3d_map[0].get_data(pp_index=1, tp_index=0) == [2, 6, 3, 7]