
Testing Custom ReLU Operator Integration in PaddleOCR

This test file implements a custom ReLU operator integration test for PaddleOCR, built on the LeNet architecture and the MNIST dataset. It validates that a JIT-compiled custom operator functions correctly inside a deep learning training pipeline.

Test Coverage Overview

The test suite covers the full custom operator integration path in PaddlePaddle, from JIT compilation of the C++/CUDA sources to gradient-based training with the compiled kernel.

Key areas tested include:
  • Custom ReLU operator JIT compilation (a parity-check sketch follows this list)
  • Integration with LeNet CNN architecture
  • MNIST dataset training pipeline
  • GPU device execution
  • Gradient computation and optimization
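
The sketch below is a hypothetical sanity check, not part of the test file: it JIT-compiles the same sources under a throwaway module name and asserts that the custom kernel matches paddle.nn.functional.relu in both the forward pass and the gradient. The module name "custom_relu_check" is illustrative, and the sketch assumes the .cc/.cu sources from the test are on hand.

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.utils.cpp_extension import load

# Hypothetical module name; reuses the sources referenced by the test file.
ops = load(name="custom_relu_check", sources=["custom_relu_op.cc", "custom_relu_op.cu"])

x = paddle.randn([8, 16], dtype="float32")
x.stop_gradient = False
y = ops.custom_relu(x)

# Forward parity with the built-in ReLU.
np.testing.assert_allclose(y.numpy(), F.relu(x).numpy(), rtol=1e-5)

# Backward parity: the ReLU gradient is 1 where x > 0, else 0
# (randn makes an exact zero, where the two kernels could differ, vanishingly unlikely).
y.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), (x.numpy() > 0).astype("float32"))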

Implementation Analysis

The test implements an end-to-end training loop in which LeNet's ReLU activations are supplied by the JIT-compiled custom_relu kernel rather than paddle.nn.ReLU.

Notable implementation patterns include:
  • C++/CUDA custom operator compilation
  • Neural network layer definition using paddle.nn
  • Data normalization and loading pipeline (a normalization sketch follows this list)
  • Cross-entropy loss computation
  • Adam optimizer integration
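
As a small illustration of the normalization step (a hedged sketch, not taken from the file): Normalize with mean and std of 127.5 rescales raw MNIST pixels from [0, 255] to [-1, 1].

import numpy as np
from paddle.vision.transforms import Normalize

# out = (pixel - 127.5) / 127.5, applied per channel in CHW layout
norm = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
img = np.array([[[0.0, 127.5, 255.0]]], dtype="float32")  # shape (1, 1, 3)
print(norm(img))  # [[[-1.  0.  1.]]]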

Technical Details

Testing infrastructure includes:
  • PaddlePaddle framework and CUDA support
  • Custom operator JIT compilation tools
  • Vision transforms and data loading utilities
  • Paddle inference Config and predictor imports (unused in the training excerpt; a sketch follows this list)
  • GPU device setup
  • Batch processing with worker threads
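
The file imports Config and create_predictor but the training excerpt below never calls them, which suggests the full test includes an inference stage. The following is a minimal hedged sketch of how those APIs are typically used; the model paths are hypothetical and would come from a prior paddle.jit.save export of the trained network.

import numpy as np
from paddle.inference import Config, create_predictor

# Hypothetical paths; a real run would first export the model with paddle.jit.save.
config = Config("custom_relu_infer/net.pdmodel", "custom_relu_infer/net.pdiparams")
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool on device 0

predictor = create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 1, 28, 28).astype("float32"))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)  # expected (1, 10) class logits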

Best Practices Demonstrated

The test implementation showcases several testing best practices (a complementary seeding sketch follows the list):

  • Modular network architecture definition
  • Explicit GPU device selection via paddle.set_device
  • Efficient data loading with worker threads
  • Periodic loss reporting every 300 batches (an evaluation sketch follows the code listing)
  • Clean separation of model, data, and training components
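
One practice the excerpt does not include is deterministic seeding. A minimal sketch of what that could look like, placed before model construction; this is an assumption on my part, not something the file does.

import random
import numpy as np
import paddle

# Seeding makes loss traces reproducible across runs, which simplifies
# regression comparisons between the custom and built-in ReLU kernels.
paddle.seed(42)
np.random.seed(42)
random.seed(42)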

paddlepaddle/paddleocr

test_tipc/supplementary/custom_op/test.py

import paddle
import paddle.nn as nn
from paddle.vision.transforms import Compose, Normalize
from paddle.utils.cpp_extension import load
# inference APIs; imported here but not exercised in this training excerpt
from paddle.inference import Config
from paddle.inference import create_predictor
import numpy as np

EPOCH_NUM = 4
BATCH_SIZE = 64

# jit compile custom op
custom_ops = load(
    name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"]
)


class LeNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2
        )
        self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.linear2 = nn.Linear(in_features=120, out_features=84)
        self.linear3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = custom_ops.custom_relu(x)  # JIT-compiled kernel in place of nn.ReLU
        x = self.max_pool1(x)
        x = custom_ops.custom_relu(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = custom_ops.custom_relu(x)
        x = self.linear2(x)
        x = custom_ops.custom_relu(x)
        x = self.linear3(x)
        return x


# set device
paddle.set_device("gpu")

# model
net = LeNet()
loss_fn = nn.CrossEntropyLoss()
opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())

# data loader
transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format="CHW")])
train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)
train_loader = paddle.io.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2
)

# train
for epoch_id in range(EPOCH_NUM):
    for batch_id, (image, label) in enumerate(train_loader()):
        out = net(image)
        loss = loss_fn(out, label)
        loss.backward()

        if batch_id % 300 == 0:
            print(
                "Epoch {} batch {}: loss = {}".format(
                    epoch_id, batch_id, np.mean(loss.numpy())
                )
            )

        opt.step()
        opt.clear_grad()
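
A hedged follow-up that the excerpt stops short of (reusing net, transform, and BATCH_SIZE from the listing above): an evaluation pass over the MNIST test split to confirm the custom operator trains the network to a sensible accuracy.

# Not part of the original file: top-1 accuracy on the test split.
test_dataset = paddle.vision.datasets.MNIST(mode="test", transform=transform)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=BATCH_SIZE)

net.eval()
accs = []
with paddle.no_grad():
    for image, label in test_loader():
        accs.append(float(paddle.metric.accuracy(net(image), label)))
net.train()
print("test accuracy: {:.4f}".format(np.mean(accs)))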