
Testing Dataset Split Operations in FastChat

This script implements the dataset splitting functionality for FastChat, providing automated separation of data into training and test sets with a configurable ratio and seeded randomization. Fixing the random seed keeps the dataset division reproducible for machine learning model evaluation.

Test Coverage Overview

The test coverage focuses on the core dataset splitting functionality, including:
  • Random permutation of input data
  • Configurable train-test split ratio
  • File I/O operations for JSON datasets
  • Index-based dataset division
Relevant edge cases include empty datasets, custom split ratios, and preservation of data integrity (no records lost or duplicated) during splitting, as exercised in the sketch below.
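
A minimal pytest-style sketch of how these properties could be checked is shown below; the split_dataset helper and the test names are illustrative and mirror the script's permute-then-slice logic rather than coming from the repository:

import numpy as np


def split_dataset(records, ratio=0.9, seed=0):
    # Hypothetical helper mirroring the script: seed, permute, slice.
    np.random.seed(seed)
    perm = np.random.permutation(len(records))
    shuffled = [records[i] for i in perm]
    cut = int(ratio * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


def test_ratio_and_integrity():
    data = [{"id": i} for i in range(100)]
    train, test = split_dataset(data, ratio=0.9)
    assert len(train) == 90 and len(test) == 10
    # No record is lost or duplicated by the split.
    assert sorted(r["id"] for r in train + test) == list(range(100))


def test_custom_ratio():
    train, test = split_dataset(list(range(10)), ratio=0.5)
    assert len(train) == 5 and len(test) == 5


def test_empty_dataset():
    train, test = split_dataset([], ratio=0.9)
    assert train == [] and test == []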

Implementation Analysis

The approach uses command-line argument parsing and numpy's seeded random number generation for reproducible dataset splitting. The implementation relies on Python's built-in argparse and json libraries, combined with numpy's array operations for data handling and randomization.

The code implements a straightforward split mechanism using array indexing and file I/O, with the steps of data loading, randomization, and output generation appearing in sequence within a single script.
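
The same three steps could also be factored into small functions. The sketch below is a hypothetical refactoring that assumes equivalent behavior; the function names are not from the repository:

import json

import numpy as np


def load_records(path):
    # Data loading: read the full JSON list into memory.
    with open(path, "r") as f:
        return json.load(f)


def shuffle_and_split(records, ratio, seed=0):
    # Randomization and index-based division, as in the script.
    np.random.seed(seed)
    perm = np.random.permutation(len(records))
    shuffled = [records[i] for i in perm]
    cut = int(ratio * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


def save_records(records, path):
    # Output generation: pretty-printed JSON with non-ASCII preserved.
    with open(path, "w") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)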

Technical Details

Key technical components include:
  • argparse for CLI argument handling
  • numpy for random permutation and array operations
  • json module for data serialization/deserialization
  • File path manipulation for output generation (exercised in the end-to-end sketch after this list)
  • Configurable parameters for the split ratio and data range (note that the --begin and --end arguments are parsed but not used in the current script)
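
An end-to-end sketch of exercising these components: it writes a tiny JSON file, runs the module through its CLI, and checks the derived _train/_test output paths. It assumes the fastchat package is installed so the module path resolves, and the file names and dataset are illustrative:

import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Write a tiny JSON dataset, invoke the module through its CLI, and check
# that the derived _train/_test output paths are produced next to the input.
with tempfile.TemporaryDirectory() as tmp:
    in_file = Path(tmp) / "sharegpt.json"
    in_file.write_text(json.dumps([{"id": i} for i in range(20)]))

    subprocess.run(
        [sys.executable, "-m", "fastchat.data.split_train_test",
         "--in-file", str(in_file), "--ratio", "0.8"],
        check=True,
    )

    train = json.loads((Path(tmp) / "sharegpt_train.json").read_text())
    test = json.loads((Path(tmp) / "sharegpt_test.json").read_text())
    assert len(train) == 16 and len(test) == 4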

Best Practices Demonstrated

The implementation demonstrates several good practices:
  • Deterministic randomization with a fixed seed (demonstrated in the snippet below)
  • Clear command-line interface design
  • Simple file I/O through the json module (note that there is no explicit error handling; failures surface as uncaught exceptions)
  • Maintainable code structure with a logical step-by-step flow
  • Concise reordering via a list comprehension over the permutation indices
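
For example, the effect of the fixed seed can be checked directly; the snippet below is illustrative only, showing that re-seeding reproduces the same permutation and hence the same split:

import numpy as np

# Re-seeding before each permutation yields the same ordering, so repeated
# runs of the script produce identical train/test partitions.
np.random.seed(0)
first = np.random.permutation(10)
np.random.seed(0)
second = np.random.permutation(10)
assert (first == second).all()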

lm-sys/fastchat

fastchat/data/split_train_test.py

"""
Split the dataset into training and test set.

Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json
"""
import argparse
import json

import numpy as np


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    # Note: --begin and --end are parsed but never used below.
    parser.add_argument("--begin", type=int, default=0)
    parser.add_argument("--end", type=int, default=100)
    parser.add_argument("--ratio", type=float, default=0.9)
    args = parser.parse_args()

    content = json.load(open(args.in_file, "r"))
    # Fixed seed so repeated runs produce the same split.
    np.random.seed(0)

    perm = np.random.permutation(len(content))
    content = [content[i] for i in perm]
    split = int(args.ratio * len(content))

    train_set = content[:split]
    test_set = content[split:]

    print(f"#train: {len(train_set)}, #test: {len(test_set)}")
    train_name = args.in_file.replace(".json", "_train.json")
    test_name = args.in_file.replace(".json", "_test.json")
    json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False)
    json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False)