Testing Model Training and Distillation Workflows in PaddleOCR
This test suite implements training and evaluation routines for PaddleOCR classification models. It covers plain training, knowledge distillation (including a two-optimizer variant), and model optimization, with support for distributed training, automatic mixed precision (AMP), PACT quantization-aware training, and FPGM pruning. The configuration keys the script consumes are sketched just before the listing below.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
paddlepaddle/paddleocr
test_tipc/supplementary/train.py
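Every training entry point in this script reads its settings from a configuration dictionary returned by config.preprocess(). The real schema lives in the supplementary config files; the sketch below is only an illustration assembled from the keys the code actually reads, and the concrete values shown are assumptions.

# Hypothetical configuration sketch; the keys mirror what train.py reads below,
# but the values and defaults are assumptions, not the shipped configuration.
config = {
    "model_type": "cls",           # "cls", "cls_distill", or "cls_distill_multiopt"
    "epoch": 100,
    "topk": 5,
    "save_model_dir": "./output/cls",
    "TRAIN": {"batch_size": 128, "num_workers": 4},
    "VALID": {"batch_size": 64, "num_workers": 4},
    "AMP": {"use_amp": False, "scale_loss": 1024.0, "use_dynamic_loss_scaling": True},
    "quant_train": False,          # PACT quantization-aware training via PaddleSlim QAT
    "prune_train": False,          # FPGM channel pruning via prune_model
}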
import paddle
import numpy as np
import os
import errno  # needed by _mkdir_if_not_exist
import paddle.nn as nn
import paddle.distributed as dist
# set up the distributed/parallel environment at import time so that
# paddle.DataParallel can be used for multi-card training below
dist.get_world_size()
dist.init_parallel_env()
from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess
import time
from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model
def _mkdir_if_not_exist(path, logger):
    """
    Create the directory if it does not exist; ignore the error raised when
    multiple processes try to create it at the same time.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    "{} has already been created by another process".format(path)
                )
            else:
                raise OSError("Failed to mkdir {}".format(path))
def save_model(
model, optimizer, model_path, logger, is_best=False, prefix="ppocr", **kwargs
):
"""
save model to the target path
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(model.state_dict(), model_prefix + ".pdparams")
    if isinstance(optimizer, list):
        # two optimizers are passed in the multi-optimizer distillation setup
        paddle.save(optimizer[0].state_dict(), model_prefix + ".pdopt")
        paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + ".pdopt")
    else:
        paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
# # save metric and config
# with open(model_prefix + '.states', 'wb') as f:
# pickle.dump(kwargs, f, protocol=2)
    if is_best:
        logger.info("best model saved to {}".format(model_prefix))
    else:
        logger.info("model saved to {}".format(model_prefix))
def amp_scaler(config):
if "AMP" in config and config["AMP"]["use_amp"] is True:
AMP_RELATED_FLAGS_SETTING = {
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
"FLAGS_max_inplace_grad_add": 8,
}
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["AMP"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling", False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
)
return scaler
else:
return None
def set_seed(seed):
paddle.seed(seed)
np.random.seed(seed)
def train(config, scaler=None):
EPOCH = config["epoch"]
topk = config["topk"]
batch_size = config["TRAIN"]["batch_size"]
num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
"train", batch_size=batch_size, num_workers=num_workers
)
# build metric
metric_func = create_metric
# build model
# model = MobileNetV3_large_x0_5(class_dim=100)
model = build_model(config)
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.parameters()
)
# load model
pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "Metrics of the loaded model: {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)
# about slim prune and quant
if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
# distribution
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func = build_loss(config)
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs, label)
# cal loss
avg_loss = loss_func(outs, label)
if scaler is None:
# backward
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
                # clear gradients explicitly so they do not accumulate under AMP
                optimizer.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_top1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model)
if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
            logger.info(
                f"The best acc: acc_top1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
save_model(
model,
optimizer,
config["save_model_dir"],
logger,
is_best,
prefix="cls",
)
def train_distill(config, scaler=None):
EPOCH = config["epoch"]
topk = config["topk"]
batch_size = config["TRAIN"]["batch_size"]
num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
"train", batch_size=batch_size, num_workers=num_workers
)
# build metric
metric_func = create_metric
# model = distillmv3_large_x0_5(class_dim=100)
model = build_model(config)
# pact quant train
if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.parameters()
)
# load model
pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "Metrics of the loaded model: {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func_distill = LossDistill(model_name_list=["student", "student1"])
loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
loss_func_js = KLJSLoss(mode="js")
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs["student"], label)
# cal loss
avg_loss = (
loss_func_distill(outs, label)["student"]
+ loss_func_distill(outs, label)["student1"]
+ loss_func_dml(outs, label)["student_student1"]
)
# backward
if scaler is None:
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
                # clear gradients explicitly so they do not accumulate under AMP
                optimizer.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_top1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
            logger.info(
                f"The best acc: acc_top1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
save_model(
model,
optimizer,
config["save_model_dir"],
logger,
is_best,
prefix="cls_distill",
)
def train_distill_multiopt(config, scaler=None):
EPOCH = config["epoch"]
topk = config["topk"]
batch_size = config["TRAIN"]["batch_size"]
num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
"train", batch_size=batch_size, num_workers=num_workers
)
# build metric
metric_func = create_metric
# model = distillmv3_large_x0_5(class_dim=100)
model = build_model(config)
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.student.parameters()
)
optimizer1, lr_scheduler1 = create_optimizer(
config, parameter_list=model.student1.parameters()
)
# load model
pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "Metrics of the loaded model: {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)
# quant train
if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func_distill = LossDistill(model_name_list=["student", "student1"])
loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
loss_func_js = KLJSLoss(mode="js")
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs["student"], label)
# cal loss
avg_loss = (
loss_func_distill(outs, label)["student"]
+ loss_func_dml(outs, label)["student_student1"]
)
avg_loss1 = (
loss_func_distill(outs, label)["student1"]
+ loss_func_dml(outs, label)["student_student1"]
)
            if scaler is None:
                # backward; keep the graph so avg_loss1 can backpropagate
                # through the same forward pass afterwards
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()
                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                # retain the graph here as well so the second backward pass can reuse it
                scaled_avg_loss.backward(retain_graph=True)
                scaler.minimize(optimizer, scaled_avg_loss)
                optimizer.clear_grad()
                scaled_avg_loss = scaler.scale(avg_loss1)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer1, scaled_avg_loss)
                optimizer1.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if not isinstance(lr_scheduler1, float):
lr_scheduler1.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}, loss1: {float(avg_loss1)}"
                strs += (
                    f", acc_top1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
            logger.info(
                f"The best acc: acc_top1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
save_model(
model,
[optimizer, optimizer1],
config["save_model_dir"],
logger,
is_best,
prefix="cls_distill_multiopt",
)
def eval(config, model):
batch_size = config["VALID"]["batch_size"]
num_workers = config["VALID"]["num_workers"]
valid_loader = build_dataloader(
"test", batch_size=batch_size, num_workers=num_workers
)
# build metric
metric_func = create_metric
outs = []
labels = []
for idx, data in enumerate(valid_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
out = model(img_batch)
outs.append(out)
labels.append(label)
outs = paddle.concat(outs, axis=0)
labels = paddle.concat(labels, axis=0)
acc = metric_func(outs, labels)
strs = f"The metric are as follows: acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
logger.info(strs)
return acc
if __name__ == "__main__":
config, logger = preprocess(is_train=False)
# AMP scaler
scaler = amp_scaler(config)
    model_type = config["model_type"]
    # forward the AMP scaler so mixed-precision training takes effect when enabled
    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']"
        )
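The distillation losses combined above (LossDistill, DMLLoss, KLJSLoss) are defined in loss.py and are not shown in this file. As a rough orientation only, a JS-divergence-style mutual-distillation term between the two student heads might be sketched as follows; the names js_divergence and mutual_distill_loss are illustrative, and the reductions and epsilon handling are assumptions rather than the actual loss.py implementation.

import paddle
import paddle.nn.functional as F

def js_divergence(logits_a, logits_b, eps=1e-10):
    # Symmetric KL (Jensen-Shannon style) divergence between two students'
    # predictions. Sketch only; the real KLJSLoss may differ in reduction
    # and smoothing.
    p = F.softmax(logits_a, axis=-1)
    q = F.softmax(logits_b, axis=-1)
    m = 0.5 * (p + q)
    kl_pm = paddle.sum(p * paddle.log((p + eps) / (m + eps)), axis=-1)
    kl_qm = paddle.sum(q * paddle.log((q + eps) / (m + eps)), axis=-1)
    return paddle.mean(0.5 * (kl_pm + kl_qm))

def mutual_distill_loss(outs, label):
    # Combine per-student cross-entropy with a mutual (DML-style) term,
    # mirroring how train_distill sums its loss components.
    ce_s0 = F.cross_entropy(outs["student"], label)
    ce_s1 = F.cross_entropy(outs["student1"], label)
    dml = js_divergence(outs["student"], outs["student1"])
    return ce_s0 + ce_s1 + dml

In the script itself, train_distill sums the LossDistill terms for both students with the DMLLoss "student_student1" term under a single optimizer, while train_distill_multiopt splits the two students' losses across separate optimizers and learning-rate schedulers.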