
Testing FSDP Checkpoint Operations in ColossalAI

This test suite validates checkpoint I/O operations for PyTorch FSDP (Fully Sharded Data Parallel) in ColossalAI. It verifies that model and optimizer state survive a save/load round trip unchanged, so that distributed training can resume exactly where it left off.

Test Coverage Overview

The test suite provides comprehensive coverage of FSDP checkpoint operations:
  • Model state dictionary saving and loading
  • Optimizer state preservation across save/load cycles
  • Both sharded and non-sharded checkpoint scenarios (see the sketch after this list)
  • Distributed training state consistency validation
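
The two scenarios differ only in the shard flag passed to the Booster checkpoint API. A minimal sketch, reusing the exact calls the test below exercises:

# Non-sharded: write one consolidated checkpoint file
booster.save_model(fsdp_model, model_ckpt_path, shard=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)

# Sharded: each rank persists its own shard of the state
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)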

Implementation Analysis

The testing approach implements a systematic verification of FSDP checkpoint mechanics:
  • Uses ResNet18 as the test model architecture
  • Validates state consistency using nested dictionary comparisons (distilled in the sketch after this list)
  • Implements distributed testing with 2-process spawning
  • Verifies output consistency before and after checkpoint operations
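
The verification logic reduces to a save/perturb/restore round trip. A distilled sketch, assuming the names defined in the test below (optimizer, run_model, booster, optim_ckpt_path, compare_nested_dict):

import copy

state_before = copy.deepcopy(optimizer.state_dict())  # snapshot at save time
run_model()                                           # one more training step
assert not compare_nested_dict(state_before, optimizer.state_dict())  # diverged

booster.load_optimizer(optimizer, optim_ckpt_path)    # restore the checkpoint
assert compare_nested_dict(state_before, optimizer.state_dict())      # matches again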

Technical Details

Test infrastructure utilizes:
  • PyTorch version 1.12.0+ requirement
  • ColossalAI Booster with TorchFSDPPlugin
  • Temporary directory management for checkpoint storage (a hypothetical equivalent is sketched after this list)
  • Custom nested dictionary comparison utility
  • Distributed environment simulation
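
The shared_tempdir helper is imported from a local utils module and is not shown in this file. A hypothetical stand-in, assuming it simply shares rank 0's temporary directory with every process:

import tempfile
from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def shared_tempdir():
    # Rank 0 owns the directory; its path is broadcast so that all ranks
    # read and write checkpoints in the same location.
    tmp = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else None
    holder = [tmp.name if tmp is not None else None]
    dist.broadcast_object_list(holder, src=0)
    try:
        yield holder[0]
    finally:
        dist.barrier()  # wait for every rank before rank 0 cleans up
        if tmp is not None:
            tmp.cleanup()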

Best Practices Demonstrated

The test implementation showcases several testing best practices:
  • Proper test isolation using temporary directories
  • Comprehensive state verification at multiple levels
  • Version compatibility checking
  • Proper cleanup of distributed resources
  • Rerun capability for addressing port conflicts (annotated after this list)
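
The version check and the rerun decorator appear together at the bottom of the file; annotated here for emphasis:

@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse("1.12.0"),
    reason="requires torch1.12 or higher",
)  # version compatibility gate: FSDP requires torch >= 1.12
@rerun_if_address_is_in_use()  # retry if the rendezvous port is already taken
def test_torch_fsdp_ckpt():
    spawn(run_dist, 2)  # spawn two processes to simulate a 2-rank job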

hpcaitech/colossalai

tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py

import copy

import pytest
import torch
from packaging import version
from torch.optim import SGD
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from colossalai.booster.plugin import TorchFSDPPlugin
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.testing import rerun_if_address_is_in_use, spawn


def compare_nested_dict(dict1, dict2):
    """Return True if every key in dict1 exists in dict2 with an equal value.

    Nested dicts are compared recursively and tensors with torch.equal.
    The check is one-directional: keys present only in dict2 are ignored.
    """
    for key in dict1:
        if key in dict2:
            if type(dict1[key]) is dict:
                assert type(dict2[key]) is dict
                diff = compare_nested_dict(dict1[key], dict2[key])
                if not diff:
                    return diff
            elif type(dict1[key]) is list:
                assert type(dict2[key]) is list
                for i, val in enumerate(dict1[key]):
                    if isinstance(val, torch.Tensor):
                        if not torch.equal(dict1[key][i], dict2[key][i]):
                            return False
                    elif val != dict2[key][i]:
                        return False
            elif type(dict1[key]) is torch.Tensor:
                assert type(dict2[key]) is torch.Tensor
                if not torch.equal(dict1[key], dict2[key]):
                    return False
            else:
                if dict1[key] != dict2[key]:
                    return False
        else:
            return False
    return True


def check_torch_fsdp_ckpt():
    model = resnet18()
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = lambda x: x.mean()
    # boost() applies the TorchFSDPPlugin: the model is wrapped in FSDP and the optimizer is wrapped to match
    fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    inputs = torch.randn(4, 3, 224, 224)
    outputs = None

    def run_model():
        # one training step: forward pass, backward pass, optimizer update
        nonlocal outputs
        outputs = fsdp_model(inputs)
        optimizer.zero_grad()
        criterion(outputs).backward()
        optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optim_ckpt_path = f"{tempdir}/optimizer"

        run_model()

        # save consolidated (non-sharded) checkpoints
        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)

        full_msd = fsdp_model.state_dict()
        # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
        # deep-copy the optimizer snapshot so later training steps don't mutate it
        sharded_osd = copy.deepcopy(optimizer.state_dict())

        run_model()

        full_msd_updated = fsdp_model.state_dict()
        # full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
        sharded_osd_updated = optimizer.state_dict()

        # after another training step, the live state must differ from the snapshots
        assert not compare_nested_dict(sharded_osd, sharded_osd_updated)
        assert not compare_nested_dict(full_msd_updated, full_msd)
        outputs_first = fsdp_model(inputs)
        assert criterion(outputs_first) != criterion(outputs)

        booster.load_model(fsdp_model, model_ckpt_path)
        booster.load_optimizer(optimizer, optim_ckpt_path)

        full_msd_restore = fsdp_model.state_dict()
        # full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
        sharded_osd_restore = optimizer.state_dict()

        # loading the checkpoint must restore the snapshotted state exactly
        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
        assert compare_nested_dict(full_msd_restore, full_msd)
        outputs_sec = fsdp_model(inputs)
        assert criterion(outputs_sec) == criterion(outputs)

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optim_ckpt_path = f"{tempdir}/optimizer"

        run_model()

        # save sharded checkpoints: each rank writes its own shard
        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)

        full_msd = fsdp_model.unwrap().state_dict()
        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        sharded_osd = copy.deepcopy(full_osd)  # snapshot of the full optimizer state at save time

        run_model()

        full_msd_updated = fsdp_model.unwrap().state_dict()
        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        # comparing the full dicts here costs too much time and can lead to a timeout
        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
        # assert not compare_nested_dict(full_msd_updated, full_msd)
        outputs_first = fsdp_model(inputs)
        assert criterion(outputs_first) != criterion(outputs)

        booster.load_model(fsdp_model, model_ckpt_path)
        booster.load_optimizer(optimizer, optim_ckpt_path)

        full_msd_restore = fsdp_model.unwrap().state_dict()
        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
        assert compare_nested_dict(full_msd_restore, full_msd)
        outputs_sec = fsdp_model(inputs)
        assert criterion(outputs_sec) == criterion(outputs)


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_torch_fsdp_ckpt()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_ckpt():
    spawn(run_dist, 2)