Testing Memory Optimization Performance in ColossalAI
This test suite benchmarks ColossalAI's memory-optimization and offloading strategies, comparing the asynchronous offload solver against a Gemini (ZeRO stage 3) baseline on a GPT-2 test model. It reports per-step execution time and peak GPU memory (allocated and reserved), along with model initialization time and solver planning time.
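Both benchmarks in the file follow the same measurement pattern: reset CUDA's peak-memory counters, bracket the timed region with explicit synchronization, and read back the allocated and reserved high-water marks afterwards. The minimal sketch below illustrates that pattern in isolation; the benchmark helper and the train_step callable are illustrative placeholders, not part of the test itself.

import time

import torch


def benchmark(train_step, num_steps: int = 10):
    """Time a GPU training step and report peak memory, mirroring the test's pattern."""
    # Clear cached blocks and reset peak-memory counters before measuring.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    times = []
    for _ in range(num_steps):
        torch.cuda.synchronize()
        start = time.time()
        train_step()                  # one forward + backward pass (placeholder callable)
        torch.cuda.synchronize()      # wait for queued kernels before stopping the clock
        times.append(time.time() - start)

    k = min(5, len(times))            # the test averages the 5 fastest of its 10 steps
    exec_time = sum(sorted(times)[:k]) / k
    peak_alloc_mb = torch.cuda.max_memory_allocated() / 1024**2
    peak_reserved_mb = torch.cuda.max_memory_reserved() / 1024**2
    return exec_time, peak_alloc_mb, peak_reserved_mb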
Source: hpcaitech/colossalai, tests/test_auto_parallel/test_offload/test_perf.py
import time
import pytest
import torch
from torch.utils._pytree import tree_map
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.legacy.zero.gemini.colo_init_context import ColoInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
# from tests.test_tensor.common_utils import set_seed
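

# parameterize runs the benchmark once per listed value: a single GPT-2 test model ("gpt2_"),
# a 5000 MB memory budget, and the asynchronous ("asyn") offload solver.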
@parameterize("model_name", ["gpt2_"])
@parameterize("memory_budget", [5000])
@parameterize("solver_name", ["asyn"])
def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
    # build model
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    label = torch.randint(
        low=0,
        high=128,
        size=(
            64,
            8,
        ),
        device=get_accelerator().get_current_device(),
    )
    criterion = LMLoss()
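    # Random integer targets in [0, 128) on the current accelerator device serve as the
    # language-model labels; LMLoss and non_distributed_component_funcs come from the
    # model_utils star import above.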

    set_seed(42)
    start_time = time.time()
    model = model_builder()
    model.train()
    param_size = parameter_size(model) / 1024**2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
start_time = time.time()
model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
solver_time = time.time() - start_time
print(f"solver_time={solver_time:.3f} s")

    hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(hybrid_optimizer, model)
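    # The optimized wrapper exposes the original module as model.model; HybridAdam (ColossalAI's
    # CPU/GPU hybrid Adam) is wrapped in AMPOptimizer for mixed-precision updates.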

    with ColoInitContext(device=torch.device("cpu")):
        gemini_model = model_builder()
    gemini_model.train()

    hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_config = dict(
        strict_ddp_mode=False,
        device=torch.device("cpu"),
        placement_policy="cpu",
        pin_memory=True,
        hidden_dim=8192,
        search_range_m=128,
    )
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
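    # Gemini baseline: the same model is wrapped as ZeRO stage 3 with CPU placement and pinned
    # memory, and its optimizer is wrapped with a 12 MB reduce bucket and overlapped communication.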

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # test gemini
    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    for step in range(10):
        gemini_optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        gemini_out = gemini_model(**data_args)
        gemini_loss = criterion(gemini_out, label)
        gemini_optim.backward(gemini_loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        gemini_optim.step()
    torch.cuda.synchronize()
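
    # Report the mean of the 5 fastest iterations and the CUDA peak-memory high-water marks in MB.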
    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
    print(f"gemini | model_name: {model_name}")
    print(
        f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB "
        f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|"
    )
    print(time_list)

    del data_args
    del gemini_model
    del gemini_optim
    del gemini_out
    del gemini_loss
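    # Drop the Gemini model, optimizer, and tensors so their CUDA memory can be released
    # before the asynchronous-offload run is measured.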

    # test asyn offload
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    data_args = tree_map(wrap_fn, data_args)
    for step in range(10):
        optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        loss = criterion(model(**data_args), label)
        optim.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        optim.step()
    torch.cuda.synchronize()

    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
    print(f"solver_name: {solver_name} | model_name: {model_name}")
    print(
        f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB "
        f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|"
    )
    print(time_list)
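

# run_dist is the per-process entry point: it brings up a single-process NCCL group via
# colossalai.launch and then executes the parameterized benchmark.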
def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_fwd_bwd()

@pytest.mark.skip("this test failed")
@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed")
@rerun_if_address_is_in_use()
def test_perf():
    spawn(run_dist, 1)

if __name__ == "__main__":
    test_perf()
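Note: under pytest the whole test is currently skipped (the marker cites a known failure) and is additionally gated on pynvml being installed. The __main__ guard lets the benchmark be launched directly as a script instead, which still requires a CUDA device (the NCCL backend) and pynvml for the asynchronous offload solver.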