
Testing Model Training and Distillation Workflows in PaddleOCR

This test suite exercises training and evaluation functionality for PaddleOCR classification models, covering standard training, knowledge distillation, quantization-aware training, and pruning, with optional multi-GPU distributed execution and automatic mixed precision.

Test Coverage Overview

The test suite covers the following model training scenarios (a sample configuration sketch follows the list):
  • Standard classification model training
  • Model distillation with single/multiple optimizers
  • Quantization-aware training (QAT)
  • Model pruning
  • Distributed training capabilities
  • Model evaluation and metric tracking
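
The scenarios above are selected through a single configuration dictionary. The sketch below is a minimal, hypothetical example: the key names mirror what train.py reads, but the values are illustrative assumptions rather than the repository's defaults.

# Hypothetical configuration sketch; key names follow train.py, values are assumptions.
config = {
    "model_type": "cls_distill",   # "cls", "cls_distill", or "cls_distill_multiopt"
    "epoch": 100,
    "topk": 5,
    "save_model_dir": "./output/cls_distill",
    "quant_train": False,          # wrap the model with PaddleSlim QAT + PACT
    "prune_train": False,          # apply FPGM-style channel pruning
    "TRAIN": {"batch_size": 128, "num_workers": 4},
    "VALID": {"batch_size": 128, "num_workers": 4},
    "AMP": {"use_amp": True, "scale_loss": 1024.0, "use_dynamic_loss_scaling": True},
}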

Implementation Analysis

The implementation uses PaddlePaddle's distributed training framework with support for the following (see the AMP training-step sketch after this list):
  • Automatic mixed precision (AMP) training
  • Multiple loss functions (Distillation, DML, KL-JS)
  • Dynamic learning rate scheduling
  • Model state saving/loading
  • Multi-GPU data parallel training
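
The AMP path wraps the forward pass in paddle.amp.auto_cast() and scales the loss before stepping the optimizer. A condensed sketch of that pattern, assuming model, optimizer, loss_func, img_batch, and label have already been built:

import paddle

# Minimal AMP training-step sketch; `model`, `optimizer`, `loss_func`,
# `img_batch`, and `label` are assumed to exist already.
scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0,
                               use_dynamic_loss_scaling=True)

with paddle.amp.auto_cast():
    outs = model(img_batch)              # forward pass runs in mixed precision
avg_loss = loss_func(outs, label)

scaled_loss = scaler.scale(avg_loss)     # scale to avoid fp16 gradient underflow
scaled_loss.backward()
scaler.minimize(optimizer, scaled_loss)  # unscale gradients and apply the update
optimizer.clear_grad()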

Technical Details

Key technical components include the following (a model-compression sketch appears after the list):
  • PaddleSlim for model compression
  • Custom loss functions for knowledge distillation
  • Gradient scaling for AMP training
  • DataLoader with multi-worker support
  • Configurable batch sizes and worker counts
  • Top-k accuracy metrics
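
Model compression is applied through PaddleSlim before the model is wrapped for data-parallel training. The sketch below mirrors how the script enables quantization-aware training (QAT with a PACT activation preprocessor) or FPGM pruning from config flags; maybe_compress is a hypothetical wrapper used only for illustration, and the slim helpers are the repository's own modules, assumed importable:

from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config   # repository helper modules
from slim.slim_fpgm import prune_model

def maybe_compress(model, config):
    """Apply QAT or pruning the same way train.py does, based on config flags."""
    if config.get("quant_train"):
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)                  # insert fake-quant ops in place
    elif config.get("prune_train"):
        # the input shape [1, 3, 32, 32] and ratio 0.1 follow the script's call
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    return model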

Best Practices Demonstrated

The test implementation showcases several best practices (see the checkpointing sketch after this list):
  • Modular code organization with separate model/loss/optimizer components
  • Proper error handling for distributed operations
  • Regular model checkpointing and best model tracking
  • Configurable training parameters
  • Comprehensive logging and metrics tracking
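
Checkpointing and best-model tracking follow a simple pattern: evaluate every tenth epoch, keep the best top-5 accuracy seen so far, and write a checkpoint either way. A trimmed sketch, assuming the eval and save_model helpers defined in the script:

# Trimmed sketch of the periodic-evaluation / best-checkpoint pattern in train();
# `eval`, `save_model`, `model`, `optimizer`, `config`, and `logger` come from the script.
best_acc = {}
for epoch in range(config["epoch"]):
    # ... one epoch of training ...
    if epoch % 10 == 0:
        acc = eval(config, model)
        is_best = len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]
        if is_best:
            best_acc = acc
            best_acc["epoch"] = epoch
        save_model(model, optimizer, config["save_model_dir"], logger,
                   is_best, prefix="cls")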

paddlepaddle/paddleocr

test_tipc/supplementary/train.py

import paddle
import numpy as np
import os
import errno  # used by _mkdir_if_not_exist to tolerate concurrent mkdir
import paddle.nn as nn
import paddle.distributed as dist

dist.get_world_size()
dist.init_parallel_env()

from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess
import time

from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model


def _mkdir_if_not_exist(path, logger):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    "be happy if some process has already created {}".format(path)
                )
            else:
                raise OSError("Failed to mkdir {}".format(path))


def save_model(
    model, optimizer, model_path, logger, is_best=False, prefix="ppocr", **kwargs
):
    """
    save model to the target path
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + ".pdparams")
    if type(optimizer) is list:
        paddle.save(optimizer[0].state_dict(), model_prefix + ".pdopt")
        paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + ".pdopt")

    else:
        paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")

    # # save metric and config
    # with open(model_prefix + '.states', 'wb') as f:
    #     pickle.dump(kwargs, f, protocol=2)
    if is_best:
        logger.info("save best model is to {}".format(model_prefix))
    else:
        logger.info("save model in {}".format(model_prefix))


def amp_scaler(config):
    if "AMP" in config and config["AMP"]["use_amp"] is True:
        AMP_RELATED_FLAGS_SETTING = {
            "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
            "FLAGS_max_inplace_grad_add": 8,
        }
        paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["AMP"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        )
        return scaler
    else:
        return None


def set_seed(seed):
    paddle.seed(seed)
    np.random.seed(seed)


def train(config, scaler=None):
    EPOCH = config["epoch"]
    topk = config["topk"]

    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )

    # build metric
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build_optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    # about slim prune and quant
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distribution
    model.train()
    model = paddle.DataParallel(model)
    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
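            # dataloader yields NHWC batches; transpose to NCHW for the model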
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric
            acc = metric_func(outs, label)

            # cal loss
            avg_loss = loss_func(outs, label)

            if scaler is None:
                # backward
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model)
            if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
                best_acc = acc
                best_acc["epoch"] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config["save_model_dir"],
                logger,
                is_best,
                prefix="cls",
            )


def train_distill(config, scaler=None):
    EPOCH = config["epoch"]
    topk = config["topk"]

    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # pact quant train
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build_optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)

    # build loss function
    loss_func_distill = LossDistill(model_name_list=["student", "student1"])
    loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
    loss_func_js = KLJSLoss(mode="js")

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric
            acc = metric_func(outs["student"], label)

            # cal loss
            avg_loss = (
                loss_func_distill(outs, label)["student"]
                + loss_func_distill(outs, label)["student1"]
                + loss_func_dml(outs, label)["student_student1"]
            )

            # backward
            if scaler is None:
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
                best_acc = acc
                best_acc["epoch"] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )

            save_model(
                model,
                optimizer,
                config["save_model_dir"],
                logger,
                is_best,
                prefix="cls_distill",
            )


def train_distill_multiopt(config, scaler=None):
    EPOCH = config["epoch"]
    topk = config["topk"]

    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build_optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.student.parameters()
    )
    optimizer1, lr_scheduler1 = create_optimizer(
        config, parameter_list=model.student1.parameters()
    )

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    # quant train
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    model.train()

    model = paddle.DataParallel(model)

    # build loss function
    loss_func_distill = LossDistill(model_name_list=["student", "student1"])
    loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
    loss_func_js = KLJSLoss(mode="js")

    data_num = len(train_loader)
    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric
            acc = metric_func(outs["student"], label)

            # cal loss
            avg_loss = (
                loss_func_distill(outs, label)["student"]
                + loss_func_dml(outs, label)["student_student1"]
            )
            avg_loss1 = (
                loss_func_distill(outs, label)["student1"]
                + loss_func_dml(outs, label)["student_student1"]
            )

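            # avg_loss and avg_loss1 share the forward graph through `outs`,
            # so the first backward pass must retain it for the second one.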
            if scaler is None:
                # backward
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()

                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                # retain the graph: avg_loss1's backward below reuses it
                scaled_avg_loss.backward(retain_graph=True)
                scaler.minimize(optimizer, scaled_avg_loss)

                scaled_avg_loss = scaler.scale(avg_loss1)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer1, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()
            if not isinstance(lr_scheduler1, float):
                lr_scheduler1.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}, loss1: {float(avg_loss1)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
                best_acc = acc
                best_acc["epoch"] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                [optimizer, optimizer1],
                config["save_model_dir"],
                logger,
                is_best,
                prefix="cls_distill_multiopt",
            )


def eval(config, model):
    batch_size = config["VALID"]["batch_size"]
    num_workers = config["VALID"]["num_workers"]
    valid_loader = build_dataloader(
        "test", batch_size=batch_size, num_workers=num_workers
    )

    # build metric
    metric_func = create_metric

    outs = []
    labels = []
    for idx, data in enumerate(valid_loader):
        img_batch, label = data
        img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
        label = paddle.unsqueeze(label, -1)
        out = model(img_batch)

        outs.append(out)
        labels.append(label)

    outs = paddle.concat(outs, axis=0)
    labels = paddle.concat(labels, axis=0)
    acc = metric_func(outs, labels)

    strs = f"The metric are as follows: acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
    logger.info(strs)
    return acc


if __name__ == "__main__":
    config, logger = preprocess(is_train=False)

    # AMP scaler
    scaler = amp_scaler(config)

    model_type = config["model_type"]

    # pass the AMP scaler so mixed-precision training takes effect when enabled
    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']"
        )