Testing RankingDataCollator Implementation in LAION-AI/Open-Assistant

This test suite validates the RankingDataCollator functionality in the LAION-AI/Open-Assistant project, focusing on proper handling of tokenization, batch processing, and message formatting for ranking model training.

Test Coverage Overview

The test suite provides comprehensive coverage of the RankingDataCollator class functionality, including:
  • System tag handling and message formatting
  • Empty message handling
  • Local dataset processing with both tuple and DatasetEntryRm inputs (see the sketch after this list)
  • Dataset loading and validation
  • Batch processing verification
  • Token padding and attention mask generation
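
The collator accepts two equivalent input representations, which several of the tests below compare directly: plain (messages, replies) tuples and structured DatasetEntryRm objects. A minimal sketch of both forms, assuming the project's model_training package is importable:

from model_training.custom_datasets.formatting import DatasetEntryRm, Role, Utterance

# 1) plain tuple of (messages, replies)
tuple_example = (["First Instruction."], ["Response A", "Response B"])

# 2) structured entry with optional per-reply metadata (lang, quality, ...)
entry_example = DatasetEntryRm(
    messages=[Utterance(text="First Instruction.", role=Role.prompter)],
    replies=[
        Utterance(text="Response A", role=Role.assistant),
        Utterance(text="Response B", role=Role.assistant),
    ],
)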

Implementation Analysis

The implementation follows a systematic approach built on pytest fixtures, with each scenario exercised over equivalent input representations (raw tuples and DatasetEntryRm objects). It validates data collation across multiple scenarios, employing PyTorch tensors and the Transformers library for tokenization. The tests verify both standard and edge cases for message formatting and token generation.

Key patterns include fixture-based tokenizer setup, comprehensive assertion chains, and detailed output validation.
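
Before the full file, here is a condensed sketch of that invocation pattern and the output contract the tests check (a sketch only, assuming the project's model_training package is on the import path):

from model_training.custom_datasets.ranking_collator import RankingDataCollator

def collate_and_check(tokenizer, examples):
    # build the collator exactly as the tests do, then verify the output contract
    collator = RankingDataCollator(tokenizer=tokenizer, padding=True)
    batch, cu_lens = collator(examples=examples)
    # one padded row per reply; input_ids and attention_mask always share a shape
    assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
    return batch, cu_lens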

Technical Details

  • Testing Framework: pytest
  • Key Dependencies: torch, transformers
  • Custom Components: RankingDataCollator, DatasetEntryRm (see the cu_lens sketch after this list)
  • Tokenizer: AutoTokenizer with Pythia configuration
  • Test Data: Synthetic examples and real dataset samples
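
The collator's second return value, cu_lens, is a list of cumulative reply counts (e.g. [0, 3, 5] for the 3-reply plus 2-reply batch in the tests below). A hypothetical helper, not part of the project, showing how those offsets recover per-example groups from the flattened batch:

def split_by_example(input_ids, cu_lens):
    # pair consecutive offsets, e.g. [0, 3, 5] -> (0, 3), (3, 5),
    # then slice the flattened reply batch back into per-example groups
    return [input_ids[start:end] for start, end in zip(cu_lens[:-1], cu_lens[1:])]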

Best Practices Demonstrated

The test suite exemplifies several testing best practices:
  • Comprehensive fixture usage for test setup
  • Detailed input/output validation
  • Edge case coverage
  • Clear test organization and naming
  • Thorough assertions on shapes, example boundaries, and decoded text
  • Modular test structure

laion-ai/open-assistant

model/model_training/tests/test_ranking_collator.py

from argparse import Namespace

import pytest
import torch
from model_training.custom_datasets import get_one_dataset
from model_training.custom_datasets.formatting import (
    QA_SPECIAL_TOKENS,
    DatasetEntryRm,
    Role,
    Utterance,
    create_dataset_entry_qa,
)
from model_training.custom_datasets.ranking_collator import RankingDataCollator
from model_training.utils.utils import get_tokenizer, match_tokenizer_name
from torch.utils.data import DataLoader
from transformers.models.auto.tokenization_auto import AutoTokenizer


@pytest.fixture
def pythia_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("tests/resources/data_collator", local_files_only=True)
    # for this test we use the pythia special tokens but note that this test is model agnostic
    tokenizer_config = match_tokenizer_name("pythia")

    tokenizer.add_special_tokens(
        {
            "pad_token": tokenizer_config.special_tokens.pad_token,
            "eos_token": tokenizer_config.special_tokens.eos_token,
            "sep_token": tokenizer_config.special_tokens.sep_token,
        }
    )

    additional_special_tokens = list(QA_SPECIAL_TOKENS.values())

    tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
    return tokenizer


def test_ranking_collator_system_tag(pythia_tokenizer):
    first_example = DatasetEntryRm(
        messages=[Utterance(text="First instruction.", role=Role.prompter, lang="en")],
        replies=[
            Utterance(text="Answer to first instruction.", role=Role.assistant, lang="en", quality=0.7),
            Utterance(text="Answer to first instruction.", role=Role.assistant, lang="de", quality=0.8),
        ],
    )
    second_example = DatasetEntryRm(
        messages=[Utterance(text="Second instruction.", role=Role.prompter)],
        replies=[
            Utterance(text="Answer to second instruction.", role=Role.assistant, humor=0.1, creativity=0.2),
            Utterance(text="Answer to second instruction.", role=Role.assistant, humor=0.4, creativity=0.3),
        ],
    )
    examples = [first_example, second_example]

    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    batch, cu_lens = rdc(examples=examples)

    assert len(batch) == 2
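    # cu_lens holds cumulative reply counts ([0, 2, 4] here), marking each example's slice of the flattened batch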
    assert cu_lens == [0, len(first_example.replies), len(first_example.replies) + len(second_example.replies)]
    assert batch.data["attention_mask"].shape[0] == 4  # we have 4 replies in total
    assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
    eos = pythia_tokenizer.eos_token

    # check each decoded sequence for its prompt, reply, and system-tag contents
    first_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][0])
    assert f"{QA_SPECIAL_TOKENS['Question']}{first_example.messages[0].text}{eos}" in first_example_first_answer_decoded
    assert f"{QA_SPECIAL_TOKENS['Answer']}{first_example.replies[0].text}{eos}" in first_example_first_answer_decoded
    assert "lang: en" in first_example_first_answer_decoded
    assert "quality: 0.7" in first_example_first_answer_decoded

    first_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][1])
    assert f"{QA_SPECIAL_TOKENS['Question']}{first_example.messages[0].text}{eos}" in first_example_second_answer_decoded
    assert f"{QA_SPECIAL_TOKENS['Answer']}{first_example.replies[1].text}{eos}" in first_example_second_answer_decoded
    assert "lang: de" in first_example_second_answer_decoded
    assert "quality: 0.8" in first_example_second_answer_decoded

    second_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][2])
    assert f"{QA_SPECIAL_TOKENS['Question']}{second_example.messages[0].text}{eos}" in second_example_first_answer_decoded
    assert f"{QA_SPECIAL_TOKENS['Answer']}{second_example.replies[0].text}{eos}" in second_example_first_answer_decoded
    assert "humor: 0.1" in second_example_first_answer_decoded
    assert "creativity: 0.2" in second_example_first_answer_decoded

    second_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][3])
    assert f"{QA_SPECIAL_TOKENS['Question']}{second_example.messages[0].text}{eos}" in second_example_second_answer_decoded
    assert f"{QA_SPECIAL_TOKENS['Answer']}{second_example.replies[1].text}{eos}" in second_example_second_answer_decoded
    assert "humor: 0.4" in second_example_second_answer_decoded
    assert "creativity: 0.3" in second_example_second_answer_decoded


def test_ranking_collator_no_messages(pythia_tokenizer):
    first_messages = None
    first_replies = [
        "Response A to None",
        "Response B to None",
        "Response C to None",
    ]
    examples = [(first_messages, first_replies)]
    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    eos = pythia_tokenizer.eos_token
    examples_ds = [
        DatasetEntryRm(messages=None, replies=[Utterance(text=r, role=Role.assistant) for r in first_replies])
    ]
    # make sure that formatting via dataset entry and lists is the same
    for ex in [examples, examples_ds]:
        batch, cu_lens = rdc(examples=ex)
        assert len(batch) == 2
        assert cu_lens == [0, len(first_replies)]
        assert batch.data["attention_mask"].shape[0] == 3  # we have 5 replies in total
        assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape

        # check each instruction
        assert pythia_tokenizer.decode(batch.data["input_ids"][0]) == f"{first_replies[0]}{eos}"
        assert pythia_tokenizer.decode(batch.data["input_ids"][1]) == f"{first_replies[1]}{eos}"
        assert pythia_tokenizer.decode(batch.data["input_ids"][2]) == f"{first_replies[2]}{eos}"
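        # token id 1 is the pad token here, so the attention mask must be 0 exactly at padded positions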
        assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()


def test_ranking_collator_local(pythia_tokenizer):
    first_messages = ["First Instruction."]
    first_replies = [
        "Response A to First Instruction",
        "Response B to First Instruction",
        "First Response C to First Instruction",
    ]
    second_messages = ["Second Instruction."]
    second_replies = ["Response A to Second Instruction", "Response B to Second Instruction"]
    examples = [(first_messages, first_replies), (second_messages, second_replies)]
    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    eos = pythia_tokenizer.eos_token
    pad = pythia_tokenizer.pad_token

    examples_ds = [
        create_dataset_entry_qa(mode="rm", questions=first_messages, answers=first_replies),
        create_dataset_entry_qa(mode="rm", questions=second_messages, answers=second_replies),
    ]
    # make sure that formatting via dataset entry and lists is the same
    for ex in [examples, examples_ds]:
        batch, cu_lens = rdc(examples=ex)

        assert len(batch) == 2
        assert cu_lens == [0, len(first_replies), len(first_replies) + len(second_replies)]
        assert batch.data["attention_mask"].shape[0] == 5  # we have 5 replies in total
        assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
        # check each instruction
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][0])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[0]}{eos}"
            + 5 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][1])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[1]}{eos}"
            + 5 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][2])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[2]}{eos}"
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][3])
            == f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[0]}{eos}"
            + 4 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][4])
            == f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[1]}{eos}"
            + 4 * pad
        )

        assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()


@pytest.mark.skip(reason="manual")
def test_rm_datasets():
    # dummy configuration
    config = Namespace(cache_dir=".cache", model_name="EleutherAI/pythia-70m-deduped")

    dataset_names = ["anthropic_rlhf", "hf_summary_pairs", "webgpt", "hellaswag", "shp", "hf_summary"]
    for name in dataset_names:
        train, val = get_one_dataset(conf=config, dataset_name=name, mode="rm")
        print(f"dataset: '{name}' (train ({type(train)}): {len(train)}, val({type(val)}): {len(val)})")

        avg_number_continuations = sum(len(x[1]) for x in train) / len(train)
        num_more_than_two = sum(1 if len(x[1]) > 2 else 0 for x in train)
        print(f"Average number of continuations: {avg_number_continuations} (with >2: {num_more_than_two})")

        for i in range(10):
            item = train[i + 100]
            print(f"[{i}] Prefix: {item[0]}")
            continuations = item[1]
            print(f"[{i}] Continuations ({len(continuations)}):")
            for j, c in enumerate(continuations):
                print(f"[{i}.{j}]: {c}")


@pytest.mark.skip(reason="manual")
def test_ranking_collator():
    # dummy configuration
    config = Namespace(cache_dir=".cache", model_name="EleutherAI/pythia-70m-deduped")

    # get a tokenizer
    tokenizer = get_tokenizer(config)
    print(type(tokenizer))

    # load oasst dataset
    kwargs = {"lang": "en,es,de,fr", "input_file_path": "2023-03-13_oasst_ready_labels.jsonl.gz", "mode": "rm"}
    train, val = get_one_dataset(conf=config, dataset_name="oasst_export", **kwargs)
    print(len(train))
    a = train[0]

    print(type(a))
    print(len(a))
    print("prefix", a[0])
    print("continuations", a[1])

    # create RankingCollator
    ranking_collator = RankingDataCollator(tokenizer=tokenizer)

    dl = DataLoader(
        train,
        batch_size=4,
        collate_fn=ranking_collator,
        num_workers=1,
        pin_memory=False,
    )

    data_iter = iter(dl)
    b = next(data_iter)
    x, y = b

    input_ids = x.input_ids
    attention_mask = x.attention_mask
    print("input_ids", input_ids.shape)
    print("attention_mask", attention_mask.shape)
    print("input_ids[0, :200]", input_ids[0, :200])
    print("decoded input_ids[0, :200]:", tokenizer.decode(input_ids[0, :200]))
    print("decoded non masked input_ids[0]:", tokenizer.decode(input_ids[0][x.attention_mask[0] == 1]))

    print(y)


if __name__ == "__main__":
    test_rm_datasets()
    # test_ranking_collator()