
Testing Data Parallel Plugin Dataloader Sharding in ColossalAI

This test suite validates the data parallel (DP) plugin functionality in ColossalAI, focusing on dataloader sharding and distributed training capabilities. The tests ensure that data is partitioned across multiple processes without overlap and verify that the plugin’s dataloader preparation behaves correctly in a multi-process environment.

Test Coverage Overview

The test suite covers essential aspects of the DP plugin implementation with a focus on dataloader sharding.

  • Validates dataloader preparation and batch distribution
  • Tests rank-specific data handling
  • Verifies distributed communication between processes
  • Ensures data uniqueness across different ranks (illustrated in the sketch below)
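
The uniqueness requirement is easiest to see in isolation. The sketch below assumes the plugin’s prepare_dataloader wraps the dataset in a torch.utils.data.DistributedSampler; the explicit sampler arguments and shuffle=False are chosen here purely for illustration. Each simulated rank then draws a disjoint slice of the same ten-element dataset.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(0, 10))

# Simulate two ranks locally; each sampler selects a disjoint set of indices.
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    first_batch = next(iter(loader))[0]
    print(f"rank {rank}: first batch = {first_batch.tolist()}")

# rank 0: first batch = [0, 2]
# rank 1: first batch = [1, 3]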

Implementation Analysis

The testing approach implements a wrapper class (DPPluginWrapper) extending DPPluginBase to validate plugin initialization and dataloader functionality.

Key implementation patterns include:
  • Custom dataset creation using TensorDataset
  • Distributed process spawning with rank coordination
  • Broadcast operations for cross-rank data comparison (distilled in the sketch after this list)
  • Batch-level validation across distributed processes
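
The broadcast-based comparison boils down to the pattern below, a distilled sketch of the check performed in check_dataloader_sharding in the test file further down: every rank must enter the collective, and the receiving rank clones its local batch first so the incoming data does not overwrite it.

import torch
import torch.distributed as dist


def assert_batches_differ(batch: torch.Tensor) -> None:
    """Assert that rank 0 and rank 1 received different data for the same step."""
    is_rank_0 = dist.get_rank() == 0
    # Rank 0 clones its local batch so the broadcast fills the clone rather than
    # the original tensor; rank 1 passes its own batch as the broadcast source.
    received = batch.clone() if is_rank_0 else batch
    dist.broadcast(received, src=1)  # both ranks participate in the collective
    if is_rank_0:
        assert not torch.equal(batch, received), "ranks received identical batches"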

Technical Details

Testing infrastructure utilizes:

  • PyTorch’s distributed communication (torch.distributed)
  • ColossalAI’s spawn and rerun decorators
  • Custom TensorDataset with sequential data
  • Two-process distributed environment setup (approximated below with plain PyTorch)
  • Configurable port and host settings for distributed testing
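
The two-process environment that the ColossalAI helpers provide can be approximated with plain PyTorch. The sketch below is only a rough stand-in for what spawn and colossalai.launch do under the hood; the port 29500 and the gloo backend are illustrative choices, and the real test runs on CUDA tensors with host and port managed by the helpers.

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, port: int) -> None:
    # Roughly what colossalai.launch does: join a process group identified by host and port.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
    )
    # ... per-rank checks would run here ...
    dist.destroy_process_group()


if __name__ == "__main__":
    # Roughly what spawn(run_dist, 2) does: start two workers with consecutive ranks.
    mp.spawn(_worker, args=(2, 29500), nprocs=2)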

Best Practices Demonstrated

The test implementation showcases robust distributed testing practices with clear separation of concerns.

  • Proper distributed environment initialization
  • Explicit rank-based validation checks
  • Clean teardown of distributed processes
  • Effective use of PyTorch’s distributed primitives
  • Modular test structure with clear validation points

hpcaitech/colossalai

tests/test_booster/test_plugin/test_dp_plugin_base.py

from typing import Callable, Dict, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader, TensorDataset

import colossalai
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn


class DPPluginWrapper(DPPluginBase):
    """This is a wrapper class for testing DP plugin initialization and dataloader creation."""

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
        pass

    def control_checkpoint_io(self) -> bool:
        pass

    def control_device(self) -> bool:
        pass

    def control_precision(self) -> bool:
        pass

    def get_checkpoint_io(self) -> CheckpointIO:
        pass

    def support_no_sync(self) -> bool:
        pass

    def supported_devices(self) -> List[str]:
        pass

    def supported_precisions(self) -> List[str]:
        pass

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        pass

    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
        pass

    def support_lora(self) -> bool:
        pass


def check_dataloader_sharding():
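    # Each rank builds its own dataloader over the same dataset; with proper
    # sharding, the first batch seen on rank 0 must differ from the one on rank 1.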
    plugin = DPPluginWrapper()

    # create a custom dataset with values 0 to 9
    dataset = TensorDataset(torch.arange(0, 10))
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)

    # get the first batch of data
    batch = next(iter(train_dataloader))[0].cuda()
    is_rank_0 = dist.get_rank() == 0

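    # rank 0 clones its batch so the broadcast below does not overwrite it;
    # rank 1 uses its batch directly as the broadcast source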
    if is_rank_0:
        batch_to_compare = batch.clone()
    else:
        batch_to_compare = batch
    # broadcast the batch from rank 1 to rank 0
    dist.broadcast(batch_to_compare, src=1)

    # compare on rank 0
    if is_rank_0:
        assert not torch.equal(
            batch, batch_to_compare
        ), "Same number was found across ranks but expected it to be different"


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_dataloader_sharding()


@rerun_if_address_is_in_use()
def test_dp_plugin_dataloader():
    spawn(run_dist, 2)