
Testing Distributed Data Loading Components in PaddleOCR

This test suite validates the data loading utilities in the PaddleOCR project, focusing on CIFAR-100 dataset handling and distributed batch processing. The module under test builds data loaders with configurable batch sizes, worker counts, and device placement.

Test Coverage Overview

The test coverage encompasses critical data loading components including:
  • CIFAR-100 dataset initialization and normalization
  • Distributed batch sampling configuration
  • Multi-mode support (train/test/valid/eval)
  • Process termination handling
Key edge cases include rejection of invalid mode strings and graceful termination of worker processes; a sketch of the former follows.
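For illustration, a minimal pytest-style check of the invalid-mode path might look like this (the test name and the assumption that the module is importable as data_loader are mine, not taken from the suite):

import pytest

from data_loader import build_dataloader  # hypothetical import path


def test_invalid_mode_raises():
    # Any mode outside train/test/valid/eval should raise ValueError.
    with pytest.raises(ValueError):
        build_dataloader(mode="predict")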

Implementation Analysis

The testing approach uses PaddlePaddle’s vision datasets and IO modules for systematic validation. Implementation patterns include:
  • Configurable batch sampling with distributed processing support (sketched after this list)
  • Signal handling for clean process termination
  • Normalized data transformation pipeline
  • Flexible device placement options
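A minimal sketch of the sampler/loader pairing follows; the batch size and mode are illustrative values, not ones taken from the test file:

from paddle.io import DataLoader, DistributedBatchSampler
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize

normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
dataset = Cifar100(mode="test", transform=normalize)

# Each rank in a distributed job receives a disjoint shard of the dataset;
# drop_last=True keeps the number of batches identical across ranks.
sampler = DistributedBatchSampler(
    dataset=dataset, batch_size=8, shuffle=False, drop_last=True
)
loader = DataLoader(dataset=dataset, batch_sampler=sampler, return_list=True)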

Technical Details

The testing infrastructure relies on:
  • PaddlePaddle vision.datasets and vision.transforms
  • Custom DataLoader configuration with DistributedBatchSampler
  • Signal handling (SIGINT, SIGTERM)
  • Numpy for array operations
  • Configurable worker processes and device placement
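The SIGINT/SIGTERM handling pattern, taken in isolation, is roughly the following; the handler name _kill_group is hypothetical, and this is a simplified sketch of what the file does rather than its exact handler:

import os
import signal


def _kill_group(sig_num, frame):
    # Kill the whole process group so DataLoader worker processes
    # do not outlive the main process.
    os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)


signal.signal(signal.SIGINT, _kill_group)
signal.signal(signal.SIGTERM, _kill_group)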

Best Practices Demonstrated

The test implementation showcases several quality practices:
  • Robust error handling for invalid modes
  • Clean process termination handling
  • Configurable batch processing parameters
  • Standardized data normalization
  • Flexible dataset mode selection
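As a worked example of the normalization step, Normalize computes (x - mean) / std per channel; the synthetic image below is illustrative, assuming the transform accepts a NumPy array:

import numpy as np
from paddle.vision.transforms import Normalize

normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
img = np.full((32, 32, 3), 0.75, dtype="float32")  # synthetic HWC image
out = normalize(img)
print(out.mean())  # (0.75 - 0.5) / 0.5 == 0.5 for every pixel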

paddlepaddle/paddleocr

test_tipc/supplementary/data_loader.py

import numpy as np
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize
import signal
import os
from paddle.io import DataLoader, DistributedBatchSampler


def term_mp(sig_num, frame):
    """Kill the whole process group, including all DataLoader child processes."""
    pid = os.getpid()
    pgid = os.getpgid(pid)
    print("main proc {} exit, kill process group {}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
    # Map raw pixels with (x - mean) / std on HWC-format images.
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")

    if mode.lower() == "train":
        dataset = Cifar100(mode="train", transform=normalize)
    elif mode.lower() in ["test", "valid", "eval"]:
        # CIFAR-100 ships only train/test splits, so valid/eval map to test.
        dataset = Cifar100(mode="test", transform=normalize)
    else:
        raise ValueError(f"{mode} should be one of ['train', 'test', 'valid', 'eval']")

    # Shard batches across distributed ranks; drop_last keeps batch counts equal.
    batch_sampler = DistributedBatchSampler(
        dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
    )

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=False,
    )

    # Register handlers so Ctrl+C (SIGINT) or kill (SIGTERM) tears down the group.
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader


if __name__ == "__main__":
    # Smoke test: inspect a single normalized sample, then a few batches.
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
    cifar100 = Cifar100(mode="train", transform=normalize)
    image, label = cifar100[0]

    reader = build_dataloader("train")
    for idx, data in enumerate(reader):
        print(idx, data[0].shape, data[1].shape)
        if idx >= 10:
            break
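With a single process, DistributedBatchSampler behaves like an ordinary batch sampler over the full dataset; under a multi-process launch (for example, python -m paddle.distributed.launch), each rank iterates only its own shard. This is a property of the sampler itself, not something the test file asserts.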