
Validating Checkpoint Solver Consistency Implementation in ColossalAI

This test suite validates the consistency between the Python and C implementations of the rotor solver in ColossalAI’s activation checkpoint optimization system. For each model, it runs both implementations under the same memory budget and checks that they produce identical optimization tables and operation sequences.

Test Coverage Overview

The test suite checks solver consistency on two model architectures, ResNet50 and DenseNet121. Each is traced with a batch of 128 inputs of shape 3×224×224 on the meta device and solved under its own memory budget.

Key areas tested include:
  • Optimization table generation and comparison
  • Sequence operation matching between implementations
  • Memory budget handling for different model sizes
  • Meta tensor operations and graph module processing
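
Sequence matching in the test reduces to comparing two lists of operations by length and by repr(). The sketch below reproduces that check with a hypothetical Op dataclass standing in for ColossalAI’s operation classes, which the source does not show; the operation labels are illustrative only.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class Op:
    """Hypothetical stand-in for a solver operation (not a ColossalAI class)."""

    kind: str   # illustrative label, e.g. "Forward" or "Backward"
    index: int  # position in the layer chain


def sequences_match(seq_a: Sequence[object], seq_b: Sequence[object]) -> bool:
    """True when both sequences have equal length and pairwise-equal repr()."""
    return len(seq_a) == len(seq_b) and all(
        repr(a) == repr(b) for a, b in zip(seq_a, seq_b)
    )


python_ops: List[Op] = [Op("Forward", 0), Op("Forward", 1), Op("Backward", 1)]
c_ops: List[Op] = [Op("Forward", 0), Op("Forward", 1), Op("Backward", 1)]

print(sequences_match(python_ops, c_ops))       # True
print(sequences_match(python_ops, c_ops[:-1]))  # False: lengths differ
```

The explicit length check matters because zip() stops at the shorter list, so a truncated sequence would otherwise pass.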

Implementation Analysis

The test uses a dual-execution strategy: it runs the Python and C solvers on the same traced graph module and meta input, then requires their outputs to match exactly. It builds on ColossalAI’s graph tracing (ColoTracer) and activation checkpoint code generation (ActivationCheckpointCodeGen).
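
A toy sketch of the dual-execution pattern, assuming (hypothetically) that each solver writes its result into the same object in place. FakeModule, solve_reference and solve_buggy are illustrative names, not ColossalAI APIs, and the source does not say whether solver_rotor mutates in place; the sketch only shows why snapshotting with a deep copy before the second run is the safe choice.

```python
import copy
from typing import Callable, List


class FakeModule:
    """Hypothetical object that a solver annotates with its result."""

    def __init__(self) -> None:
        self.table: List[int] = []


Solver = Callable[[FakeModule, int], FakeModule]


def solve_reference(mod: FakeModule, budget: int) -> FakeModule:
    # Rewrites the existing list in place instead of assigning a new one.
    mod.table.clear()
    mod.table.extend(min(i * i, budget) for i in range(4))
    return mod


def solve_buggy(mod: FakeModule, budget: int) -> FakeModule:
    # Same in-place update, but with a deliberate off-by-one in the cap.
    mod.table.clear()
    mod.table.extend(min(i * i, budget - 1) for i in range(4))
    return mod


def outputs_match(first_solver: Solver, second_solver: Solver, budget: int,
                  snapshot: bool = True) -> bool:
    mod = FakeModule()
    mod = first_solver(mod, budget)
    # A deep copy freezes the first result; without it we keep only an alias.
    first = copy.deepcopy(mod.table) if snapshot else mod.table
    mod = second_solver(mod, budget)
    second = copy.deepcopy(mod.table) if snapshot else mod.table
    return first == second


print(outputs_match(solve_reference, solve_buggy, budget=5))                  # False
print(outputs_match(solve_reference, solve_buggy, budget=5, snapshot=False))  # True (bug hidden)
```

With snapshot=False both names refer to the same list, so the buggy solver appears to agree; the deep copy is what exposes the mismatch.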

Technical patterns include:
  • Graph module transformation and codegen setup
  • Meta tensor propagation
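  • Deep copies of solver outputs (__sequence__ and __opttable__), so a later solver run on the same graph module cannot alter the captured results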
  • Structured assertion checks for optimization tables
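
The structured assertion on the optimization tables visits every entry opt[m][i][j] with j > i, in order of increasing distance d = j - i. A minimal sketch of the same traversal on plain nested lists follows; first_mismatch, assert_tables_equal and make_table are hypothetical helpers, and because the source does not spell out what m indexes, it is treated here as an opaque outer index.

```python
from typing import Any, List, Optional, Tuple

Table3D = List[List[List[Any]]]


def first_mismatch(opt_a: Table3D, opt_b: Table3D) -> Optional[Tuple[int, int, int]]:
    """Return the first (m, i, j) with j > i where the tables differ, else None.

    Loop order matches the test: outer index m, then distance d = j - i,
    then start position i. Entries with j <= i are never compared.
    """
    n = len(opt_a[0])
    for m in range(len(opt_a)):
        for d in range(1, n):
            for i in range(n - d):
                j = i + d
                if opt_a[m][i][j] != opt_b[m][i][j]:
                    return (m, i, j)
    return None


def assert_tables_equal(opt_python: Table3D, opt_c: Table3D) -> None:
    idx = first_mismatch(opt_python, opt_c)
    if idx is not None:
        m, i, j = idx
        raise AssertionError(
            f"item ({m}, {i}, {j}) is not consistent with python version!\n"
            f"python version: {opt_python[m][i][j]}\n"
            f"C version: {opt_c[m][i][j]}"
        )


def make_table(outer: int, n: int) -> Table3D:
    # Arbitrary toy values; only the upper triangle is meaningful.
    return [[[m * 100 + i * 10 + j for j in range(n)] for i in range(n)] for m in range(outer)]


a, b = make_table(2, 3), make_table(2, 3)
b[0][2][0] = -1              # lower triangle: ignored
assert_tables_equal(a, b)    # passes
b[1][0][2] = -1              # upper triangle: detected
print(first_mismatch(a, b))  # (1, 0, 2)
```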

Technical Details

Testing tools and configuration:
  • PyTest framework for test orchestration
  • NCCL backend for colossalai.launch (the test spawns a single process)
  • Meta device tensor operations
  • ColoTracer for graph generation
  • ActivationCheckpointCodeGen for code generation
  • Memory budgets: 4000 MB for ResNet50 and 8080 MB for DenseNet121, passed to the solver in bytes (mem_budget * 1024 * 1024)
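
Worked arithmetic for the budgets, assuming the conversion used in the test (mem_budget * 1024 * 1024, so the "MB" figures are really MiB). The input_bytes helper is an illustrative addition that sizes the 128×3×224×224 input assuming the default float32 dtype; on the meta device that tensor is never actually allocated.

```python
BYTES_PER_MIB = 1024 * 1024  # the test multiplies mem_budget by 1024 * 1024


def budget_in_bytes(mem_budget: int) -> int:
    """Convert a budget given in MiB (what the test calls MB) to bytes."""
    return mem_budget * BYTES_PER_MIB


def input_bytes(*shape: int, bytes_per_element: int = 4) -> int:
    """Illustrative helper: size of a dense tensor, assuming float32 elements."""
    total = bytes_per_element
    for dim in shape:
        total *= dim
    return total


for name, mb in [("resnet50", 4000), ("densenet121", 8080)]:
    print(name, budget_in_bytes(mb))
# resnet50 4194304000
# densenet121 8472494080

print(input_bytes(128, 3, 224, 224) / BYTES_PER_MIB)  # 73.5
```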

Best Practices Demonstrated

The test checks solver agreement entry by entry: every cell of the optimization table and every operation in the final sequence must match between the two implementations.

Notable practices include:
  • Comprehensive assertion checks for solver consistency
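  • Cleanup of the global context with gpc.destroy() after all models are processed
  • Conditional execution: the test is skipped when ActivationCheckpointCodeGen cannot be imported (torch older than 1.12.0); it is also currently skipped unconditionally with the marker "TODO(lyl): refactor all tests."
  • Process launching through spawn, with rerun_if_address_is_in_use rerunning the test if the port is already in use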
  • Clear separation of test setup and execution logic
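
The source gates the test by probing whether ActivationCheckpointCodeGen imports. An alternative is to compare the version string directly, sketched here with hypothetical helpers parse_version and at_least; the result could then feed a marker such as pytest.mark.skipif(not at_least(torch.__version__, (1, 12)), reason=...).

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Leading numeric components, e.g. '1.12.1+cu113' -> (1, 12, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(version: str, minimum: Tuple[int, ...]) -> bool:
    # Tuple comparison is lexicographic on integers: (1, 12, 1) >= (1, 12) is True.
    return parse_version(version) >= minimum


for v in ["1.11.0", "1.12.0", "1.12.1+cu113", "2.1.0"]:
    print(v, at_least(v, (1, 12)))
```

Parsing into integers avoids the string-comparison trap where "1.10" sorts before "1.9". Probing the import, as the original does, avoids hard-coding a version number at all and tracks exactly what the test needs.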

hpcaitech/colossalai

tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py

            
import copy

import pytest
import torch
import torch.fx
import torchvision.models as tm

import colossalai
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta

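# NOTE: the imports below are disabled, so solver_rotor and Sequence are
# undefined when the test body runs; the test is skipped (see the skip marker
# on test_C_solver_consistency) until it is refactored.
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence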
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

    withcodegen = True
except ImportError:
    # ActivationCheckpointCodeGen is unavailable on older torch versions
    withcodegen = False


def _run_C_solver_consistency_test(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
        model = M()
        data = torch.rand(128, 3, 224, 224, device="meta")

        tracer = ColoTracer()
        graph = tracer.trace(model, meta_args={"x": data})
        graph.set_codegen(ActivationCheckpointCodeGen())
        gm = ColoGraphModule(model, graph, model.__class__.__name__)
        if is_compatible_with_meta():
            data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
        else:
            # MetaTensor is unavailable; fall back to the plain meta-device tensor
            data_meta = data
        MetaInfoProp(gm).run(data_meta)

        # python solver
        gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
        sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
        opt_python = copy.deepcopy(gm.__opttable__)

        # C solver
        gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
        sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
        opt_C = copy.deepcopy(gm.__opttable__)

        # make sure the opt_tables agree on every entry (m, i, j) with j > i
        for m in range(len(opt_python)):
            for d in range(1, len(opt_python[0])):
                for i in range(len(opt_python[0]) - d):
                    assert opt_python[m][i][i + d] == opt_C[m][i][i + d], (
                        f"item ({m}, {i}, {i + d}) is not consistent with python version!\n"
                        f"python version: {opt_python[m][i][i + d]}\n"
                        f"C version: {opt_C[m][i][i + d]}"
                    )

        sequence_python = sequence_python.list_operations()
        sequence_C = sequence_C.list_operations()

        # make sure the sequences are the same
        assert len(sequence_python) == len(sequence_C) and all(
            python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)
        )

    gpc.destroy()


@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
@rerun_if_address_is_in_use()
def test_C_solver_consistency():
    spawn(_run_C_solver_consistency_test, 1)


if __name__ == "__main__":
    # spawn() supplies rank, world_size and port to the worker function
    test_C_solver_consistency()