Testing RankingDataCollator Implementation in LAION-AI/Open-Assistant
This test suite validates the RankingDataCollator functionality in the LAION-AI/Open-Assistant project, focusing on proper handling of tokenization, batch processing, and message formatting for ranking model training.
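Before the file itself, here is a minimal usage sketch of the collator's interface as these tests exercise it: examples may be plain (messages, replies) tuples or DatasetEntryRm objects, and the collator returns a flat token batch together with cumulative reply offsets. The snippet assumes a tokenizer prepared as in the pythia_tokenizer fixture below.

# Interface sketch, inferred from the assertions in the test suite (not authoritative).
from model_training.custom_datasets.ranking_collator import RankingDataCollator

rdc = RankingDataCollator(tokenizer=tokenizer, padding=True)  # tokenizer: see fixture below
batch, cu_lens = rdc(examples=[(["A question?"], ["Reply 1", "Reply 2"])])
# batch.input_ids / batch.attention_mask: shape (total_replies, max_seq_len)
# cu_lens == [0, 2]: cumulative reply counts, one boundary per example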
Source: model/model_training/tests/test_ranking_collator.py in LAION-AI/Open-Assistant
from argparse import Namespace
import pytest
import torch
from model_training.custom_datasets import get_one_dataset
from model_training.custom_datasets.formatting import (
    QA_SPECIAL_TOKENS,
    DatasetEntryRm,
    Role,
    Utterance,
    create_dataset_entry_qa,
)
from model_training.custom_datasets.ranking_collator import RankingDataCollator
from model_training.utils.utils import get_tokenizer, match_tokenizer_name
from torch.utils.data import DataLoader
from transformers.models.auto.tokenization_auto import AutoTokenizer
@pytest.fixture
def pythia_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("tests/resources/data_collator", local_files_only=True)
    # for this test we use the pythia special tokens, but note that the test itself is model agnostic
    tokenizer_config = match_tokenizer_name("pythia")
    tokenizer.add_special_tokens(
        {
            "pad_token": tokenizer_config.special_tokens.pad_token,
            "eos_token": tokenizer_config.special_tokens.eos_token,
            "sep_token": tokenizer_config.special_tokens.sep_token,
        }
    )
    additional_special_tokens = list(QA_SPECIAL_TOKENS.values())
    tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
    return tokenizer
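
# For reference, QA_SPECIAL_TOKENS is defined in model_training/custom_datasets/formatting.py;
# at the time of writing it maps dialogue roles to sentinel tokens roughly along the lines of
# {"Question": "<|prompter|>", "Answer": "<|assistant|>", "System": "<|system|>", ...}.
# The exact values are an assumption here -- consult formatting.py for the authoritative set.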

def test_ranking_collator_system_tag(pythia_tokenizer):
    first_example = DatasetEntryRm(
        messages=[Utterance(text="First instruction.", role=Role.prompter, lang="en")],
        replies=[
            Utterance(text="Answer to first instruction.", role=Role.assistant, lang="en", quality=0.7),
            Utterance(text="Answer to first instruction.", role=Role.assistant, lang="de", quality=0.8),
        ],
    )
    second_example = DatasetEntryRm(
        messages=[Utterance(text="Second instruction.", role=Role.prompter)],
        replies=[
            Utterance(text="Answer to second instruction.", role=Role.assistant, humor=0.1, creativity=0.2),
            Utterance(text="Answer to second instruction.", role=Role.assistant, humor=0.4, creativity=0.3),
        ],
    )
    examples = [first_example, second_example]
    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    batch, cu_lens = rdc(examples=examples)

    assert len(batch) == 2
    assert cu_lens == [0, len(first_example.replies), len(first_example.replies) + len(second_example.replies)]
    assert batch.data["attention_mask"].shape[0] == 4  # we have 4 replies in total
    assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape

    eos = pythia_tokenizer.eos_token
    # check the decoded text of each reply row
    first_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][0])
f"{QA_SPECIAL_TOKENS['Question']}{first_example.messages[0].text}{eos}" in first_example_first_answer_decoded
f"{QA_SPECIAL_TOKENS['Answer']}{first_example.replies[0].text}{eos}" in first_example_first_answer_decoded
"lang: en" in first_example_first_answer_decoded
"quality: 0.7" in first_example_first_answer_decoded
first_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][1])
f"{QA_SPECIAL_TOKENS['Question']}{first_example.messages[0].text}{eos}" in first_example_second_answer_decoded
f"{QA_SPECIAL_TOKENS['Answer']}{first_example.replies[1].text}{eos}" in first_example_second_answer_decoded
"lang: de" in first_example_second_answer_decoded
"quality: 0.8" in first_example_second_answer_decoded
second_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][2])
f"{QA_SPECIAL_TOKENS['Question']}{second_example.messages[0].text}{eos}" in second_example_first_answer_decoded
f"{QA_SPECIAL_TOKENS['Answer']}{second_example.replies[0].text}{eos}" in second_example_first_answer_decoded
"humor: 0.1" in second_example_first_answer_decoded
"creativity: 0.2" in second_example_first_answer_decoded
second_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][2])
f"{QA_SPECIAL_TOKENS['Question']}{second_example.messages[0].text}{eos}" in second_example_second_answer_decoded
f"{QA_SPECIAL_TOKENS['Answer']}{second_example.replies[1].text}{eos}" in second_example_second_answer_decoded
"humor: 0.4" in second_example_second_answer_decoded
"creativity: 0.3" in second_example_second_answer_decoded

def test_ranking_collator_no_messages(pythia_tokenizer):
    first_messages = None
    first_replies = [
        "Response A to None",
        "Response B to None",
        "Response C to None",
    ]
    examples = [(first_messages, first_replies)]
    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    eos = pythia_tokenizer.eos_token
    examples_ds = [
        DatasetEntryRm(messages=None, replies=[Utterance(text=r, role=Role.assistant) for r in first_replies])
    ]
    # make sure that formatting via dataset entries and via plain lists is the same
    for ex in [examples, examples_ds]:
        batch, cu_lens = rdc(examples=ex)
        assert len(batch) == 2
        assert cu_lens == [0, len(first_replies)]
        assert batch.data["attention_mask"].shape[0] == 3  # we have 3 replies in total
        assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
        # without messages, each row is just the bare reply followed by eos
        assert pythia_tokenizer.decode(batch.data["input_ids"][0]) == f"{first_replies[0]}{eos}"
        assert pythia_tokenizer.decode(batch.data["input_ids"][1]) == f"{first_replies[1]}{eos}"
        assert pythia_tokenizer.decode(batch.data["input_ids"][2]) == f"{first_replies[2]}{eos}"
        # the attention mask must be 0 exactly where the pad token (id 1 in this tokenizer setup) appears
        assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()

def test_ranking_collator_local(pythia_tokenizer):
    first_messages = ["First Instruction."]
    first_replies = [
        "Response A to First Instruction",
        "Response B to First Instruction",
        "First Response C to First Instruction",
    ]
    second_messages = ["Second Instruction."]
    second_replies = ["Response A to Second Instruction", "Response B to Second Instruction"]
    examples = [(first_messages, first_replies), (second_messages, second_replies)]
    rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
    eos = pythia_tokenizer.eos_token
    pad = pythia_tokenizer.pad_token
    examples_ds = [
        create_dataset_entry_qa(mode="rm", questions=first_messages, answers=first_replies),
        create_dataset_entry_qa(mode="rm", questions=second_messages, answers=second_replies),
    ]
    # make sure that formatting via dataset entries and via plain lists is the same
    for ex in [examples, examples_ds]:
        batch, cu_lens = rdc(examples=ex)
        assert len(batch) == 2
        assert cu_lens == [0, len(first_replies), len(first_replies) + len(second_replies)]
        assert batch.data["attention_mask"].shape[0] == 5  # we have 5 replies in total
        assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
        # check each decoded row; shorter rows are right-padded up to the longest sequence in the batch
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][0])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[0]}{eos}"
            + 5 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][1])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[1]}{eos}"
            + 5 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][2])
            == f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[2]}{eos}"
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][3])
            == f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[0]}{eos}"
            + 4 * pad
        )
        assert (
            pythia_tokenizer.decode(batch.data["input_ids"][4])
            == f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[1]}{eos}"
            + 4 * pad
        )
        assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()
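
# Illustration (not part of the test suite): downstream, a ranking loss can use
# cu_lens to recover each example's reply group from the flat batch, e.g.
#
#     for start, end in zip(cu_lens[:-1], cu_lens[1:]):
#         group_logits = logits[start:end]  # all ranked replies of one example
#
# which is why the collator returns cumulative offsets instead of a nested batch.
# (Sketch only; the actual loss implementation lives in the training code.)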
@pytest.mark.skip(reason="manual")
def test_rm_datasets():
# dummy configuration
config = Namespace(cache_dir=".cache", model_name="EleutherAI/pythia-70m-deduped")
dataset_names = ["anthropic_rlhf", "hf_summary_pairs", "webgpt", "hellaswag", "shp", "hf_summary"]
for name in dataset_names:
train, val = get_one_dataset(conf=config, dataset_name=name, mode="rm")
print(f"dataset: '{name}' (train ({type(train)}): {len(train)}, val({type(val)}): {len(val)})")
avg_number_continuations = sum(len(x[1]) for x in train) / len(train)
num_more_than_two = sum(1 if len(x[1]) > 2 else 0 for x in train)
print(f"Average number of continuations: {avg_number_continuations} (with >2: {num_more_than_two})")
for i in range(10):
item = train[i + 100]
print(f"[{i}] Prefix: {item[0]}")
continuations = item[1]
print(f"[{i}] Continuations ({len(continuations)}):")
for j, c in enumerate(continuations):
print(f"[{i}.{j}]: {c}")

@pytest.mark.skip(reason="manual")
def test_ranking_collator():
    # dummy configuration
    config = Namespace(cache_dir=".cache", model_name="EleutherAI/pythia-70m-deduped")
    # get a tokenizer
    tokenizer = get_tokenizer(config)
    print(type(tokenizer))
    # load oasst dataset
    kwargs = {"lang": "en,es,de,fr", "input_file_path": "2023-03-13_oasst_ready_labels.jsonl.gz", "mode": "rm"}
    train, val = get_one_dataset(conf=config, dataset_name="oasst_export", **kwargs)
    print(len(train))
    a = train[0]
    print(type(a))
    print(len(a))
    print("prefix", a[0])
    print("continuations", a[1])
    # create the RankingDataCollator and feed it through a DataLoader, as in training
    ranking_collator = RankingDataCollator(tokenizer=tokenizer)
    dl = DataLoader(
        train,
        batch_size=4,
        collate_fn=ranking_collator,
        num_workers=1,
        pin_memory=False,
    )
    data_iter = iter(dl)
    b = next(data_iter)
    x, y = b
    input_ids = x.input_ids
    attention_mask = x.attention_mask
    print("input_ids", input_ids.shape)
    print("attention_mask", attention_mask.shape)
    print("input_ids[0, :200]", input_ids[0, :200])
    print("decoded input_ids[0, :200]:", tokenizer.decode(input_ids[0, :200]))
    print("decoded non-masked input_ids[0]:", tokenizer.decode(input_ids[0][x.attention_mask[0] == 1]))
    print(y)
if __name__ == "__main__":
test_rm_datasets()
# test_ranking_collator()
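
The automated tests are run with pytest from the model/model_training directory, so that the relative tests/resources/data_collator path used by the pythia_tokenizer fixture resolves; the two tests marked skip(reason="manual") are excluded by default and are intended for interactive dataset inspection.

For readers who want a mental model of the collator without opening ranking_collator.py, the following is a minimal behavioral sketch consistent with the assertions above. It is an illustration written for this walkthrough, not the project's implementation: the real RankingDataCollator also handles DatasetEntryRm inputs, truncation, and max-length configuration, and builds its prompts from QA_SPECIAL_TOKENS rather than the hard-coded sentinels used here.

from dataclasses import dataclass

from transformers import BatchEncoding, PreTrainedTokenizerBase


@dataclass
class SimpleRankingCollator:
    """Behavioral sketch of RankingDataCollator (illustration only; tuple inputs only)."""

    tokenizer: PreTrainedTokenizerBase
    padding: bool = True

    def __call__(self, examples):
        flat_texts, cu_lens = [], [0]
        eos = self.tokenizer.eos_token
        for messages, replies in examples:
            # hypothetical prompt format; the real code derives this from QA_SPECIAL_TOKENS
            prefix = "".join(f"<|prompter|>{m}{eos}" for m in messages) if messages else ""
            for reply in replies:
                body = f"<|assistant|>{reply}" if messages else reply
                flat_texts.append(f"{prefix}{body}{eos}")
            cu_lens.append(cu_lens[-1] + len(replies))
        # one flat, padded batch over all replies of all examples
        batch: BatchEncoding = self.tokenizer(
            flat_texts, padding=self.padding, add_special_tokens=False, return_tensors="pt"
        )
        return batch, cu_lens

Flattening every reply into one padded batch keeps the model's forward pass simple (a single (total_replies, max_len) tensor), while cu_lens preserves the example boundaries that the ranking loss needs.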