Commit
Feat: Drop long samples and shuffle rl samples (#2040) [skip ci]

* feat: LOG warn if samples are dropped due to seq length

* feat: add drop long samples for RL

* feat: add ipo

* fix: remove num_proc for map as subprocesses are prone to die

* feat: shuffle rl dataset

* fix: support preprocess for kto

* chore: use set instead of list

* feat: add simpo
NanoCode012 authored Nov 19, 2024
1 parent d9b71ed commit f007c38
Showing 3 changed files with 121 additions and 17 deletions.
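
In short: RL preference datasets are now length-filtered against `cfg.sequence_len` and shuffled with `cfg.seed`. A hedged sketch of the config fields the new code paths read (field names are taken from the diffs below; `DictDefault` is axolotl's config container, and the values are illustrative):

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "dpo",             # one of: dpo, ipo, orpo, kto, simpo
        "sequence_len": 2048,    # pairs whose prompt+response exceed this are dropped
        "seed": 42,              # seeds the deterministic dataset shuffle
        "dataset_processes": 4,  # num_proc for the new filter step
        "is_preprocess": False,  # when True, the filter recomputes instead of using the cache
    }
)
```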
71 changes: 67 additions & 4 deletions src/axolotl/utils/data/rl.py
@@ -64,15 +64,57 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
     tokenizer = load_tokenizer(cfg)
     ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
 
+    if isinstance(data_set, DatasetDict):
+        data_set = data_set["train"]
+
     data_set = data_set.map(
         ds_transform_fn,
         desc="Mapping RL Dataset",
     )
-    if isinstance(data_set, DatasetDict):
-        data_set = data_set["train"]
 
     return data_set
 
+
+def drop_long_rl_seq(
+    sample, rl, tokenizer, sequence_len  # pylint: disable=invalid-name
+):
+    if rl in ("dpo", "ipo", "orpo", "simpo"):
+        if not (
+            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
+        ):
+            raise ValueError(
+                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
+            )
+
+        prompt = sample["prompt"]
+        chosen = sample["chosen"]
+        rejected = sample["rejected"]
+
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
+        len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+
+        return (len_prompt + len_chosen) <= sequence_len and (
+            len_prompt + len_rejected
+        ) <= sequence_len
+
+    if rl == "kto":
+        if not (sample.get("prompt") and sample.get("completion")):
+            raise ValueError("Prompt and completion keys are required for KTO datasets")
+
+        prompt = sample["prompt"]
+        completion = sample["completion"]
+
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_completion = len(
+            tokenizer(completion, add_special_tokens=False)["input_ids"]
+        )
+
+        return (len_prompt + len_completion) <= sequence_len
+
+    raise ValueError("Unknown RL type")
+
+
 def load_prepare_dpo_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
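
A minimal usage sketch (not part of the diff): `drop_long_rl_seq` is a predicate for `datasets.Dataset.filter`, with the static arguments bound via `functools.partial`. The tokenizer name and the tiny in-memory dataset are illustrative only.

```python
from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.utils.data.rl import drop_long_rl_seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer will do for the sketch

pairs = Dataset.from_list(
    [
        {"prompt": "2+2=", "chosen": "4", "rejected": "5"},
        {"prompt": "2+2=", "chosen": "four " * 100, "rejected": "5"},  # over budget
    ]
)

drop_long = partial(drop_long_rl_seq, rl="dpo", tokenizer=tokenizer, sequence_len=32)
kept = pairs.filter(drop_long)  # row 2 exceeds sequence_len and is dropped
assert len(kept) == 1
```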
@@ -94,7 +136,7 @@ def load_split(dataset_cfgs, _cfg):
                 )
                 split_datasets.insert(i, ds)
 
-        tokenizer = None
+        tokenizer = load_tokenizer(cfg)
 
         for i, data_set in enumerate(split_datasets):
             _type = dataset_cfgs[i]["type"]
@@ -121,7 +163,28 @@ def load_split(dataset_cfgs, _cfg):
                 # "prompt", "chosen" and "rejected" already preprocessed
                 split_datasets[i] = data_set
 
-        return concatenate_datasets(split_datasets)
+            drop_long = partial(
+                drop_long_rl_seq,
+                rl=_cfg.rl,
+                tokenizer=tokenizer,
+                sequence_len=cfg.sequence_len,
+            )
+
+            prior_len = len(split_datasets[i])
+            split_datasets[i] = split_datasets[i].filter(
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+                desc="Dropping Long Sequences",
+            )
+            dropped = prior_len - len(split_datasets[i])
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
+
+        combined_datasets = concatenate_datasets(split_datasets)
+        combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
+
+        return combined_datasets
 
     with zero_first(is_main_process()):
         train_is_preprocessed = False
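
Two properties of the new tail of `load_split`, shown with a self-contained sketch (toy datasets, illustrative values): `concatenate_datasets` merges the per-config splits, and `shuffle(seed=cfg.seed)` makes the combined order reproducible for a fixed seed.

```python
from datasets import Dataset, concatenate_datasets

ds_a = Dataset.from_dict({"prompt": ["a1", "a2"]})
ds_b = Dataset.from_dict({"prompt": ["b1", "b2"]})

combined = concatenate_datasets([ds_a, ds_b])  # 4 rows, ds_a rows first

# Deterministic given the seed: two shuffles with the same seed agree.
assert combined.shuffle(seed=42)["prompt"] == combined.shuffle(seed=42)["prompt"]
# All rows survive the shuffle; only the order changes.
assert sorted(combined.shuffle(seed=42)["prompt"]) == ["a1", "a2", "b1", "b2"]
```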
43 changes: 31 additions & 12 deletions src/axolotl/utils/tokenization.py
@@ -66,28 +66,47 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
 
 
 def check_rl_example_labels(example, tokenizer, text_only=False):
-    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+    field_prompt, field_chosen, field_rejected, field_completion = (
+        "prompt",
+        "chosen",
+        "rejected",
+        "completion",
+    )
 
     input_tokens = example[field_prompt]
-    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    labels_chosen = example.get(field_chosen)
+    labels_rejected = example.get(field_rejected)
+    labels_completion = example.get(field_completion)
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
 
     # Process and color each type of token
     colored_tokens = process_tokens_for_rl_debug(
         input_tokens, "yellow", tokenizer, text_only
     )
-    colored_chosens = process_tokens_for_rl_debug(
-        labels_chosen, "green", tokenizer, text_only
-    )
-    colored_rejecteds = process_tokens_for_rl_debug(
-        labels_rejected, "red", tokenizer, text_only
-    )
 
-    # Create a delimiter based on text_only flag
-    delimiter = "" if text_only else " "
+    # Process tokens
+    if labels_completion is None:
+        colored_chosens = process_tokens_for_rl_debug(
+            labels_chosen, "green", tokenizer, text_only
+        )
+        colored_rejecteds = process_tokens_for_rl_debug(
+            labels_rejected, "red", tokenizer, text_only
+        )
+    else:
+        colored_completion = process_tokens_for_rl_debug(
+            labels_completion, "green", tokenizer, text_only
+        )
 
     # Logging information
     LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
-    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
-    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    if labels_completion is None:
+        LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+        LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+    else:
+        LOG.info(f"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\n\n\n")
 
     return delimiter.join(colored_tokens)
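
A standalone illustration of the new dispatch (a simplification, not the library function): rows carrying a `completion` field take the KTO path, otherwise the chosen/rejected preference path is used.

```python
def classify_rl_example(example: dict) -> str:
    """Simplified mirror of the branching added to check_rl_example_labels."""
    if example.get("completion") is not None:
        return "kto"        # single completion, logged as COMPLETION RESPONSE
    return "preference"     # chosen/rejected pair, logged as CHOSEN/REJECTED

assert classify_rl_example({"prompt": "p", "completion": "c"}) == "kto"
assert classify_rl_example({"prompt": "p", "chosen": "a", "rejected": "b"}) == "preference"
```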
24 changes: 23 additions & 1 deletion src/axolotl/utils/trainer.py
@@ -203,37 +203,59 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
+    prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_long,
         num_proc=cfg.dataset_processes,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Dropping Long Sequences",
     )
+    dropped = prior_len - len(train_dataset)
+    if dropped:
+        LOG.warning(f"Dropped {dropped} long samples from train dataset")
 
     if eval_dataset:
+        prior_len = len(eval_dataset)
         eval_dataset = eval_dataset.filter(
             drop_long,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Dropping Long Sequences",
         )
+        dropped = prior_len - len(eval_dataset)
+        if dropped:
+            LOG.warning(f"Dropped {dropped} long samples from eval dataset")
+
     # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
         return np.sum(np.array(sample["labels"]) != -100) > 0
 
+    prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_no_trainable_tokens,
         num_proc=cfg.dataset_processes,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Drop Samples with Zero Trainable Tokens",
     )
+    dropped = prior_len - len(train_dataset)
+    if dropped:
+        LOG.warning(
+            f"Dropped {dropped} samples with no trainable tokens from train dataset"
+        )
+
     if eval_dataset:
+        prior_len = len(eval_dataset)
         eval_dataset = eval_dataset.filter(
             drop_no_trainable_tokens,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Drop Samples with Zero Trainable Tokens",
         )
+        dropped = prior_len - len(eval_dataset)
+        if dropped:
+            LOG.warning(
+                f"Dropped {dropped} samples with no trainable tokens from eval dataset"
+            )
 
     if cfg.group_by_length:
         train_dataset = train_dataset.map(
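
A standalone sketch of the zero-trainable-token check warned about above: a sample survives only if at least one label differs from the `-100` ignore index used by Hugging Face loss functions.

```python
import numpy as np

def has_trainable_tokens(sample: dict) -> bool:
    # Mirrors drop_no_trainable_tokens: keep rows with at least one supervised label.
    return np.sum(np.array(sample["labels"]) != -100) > 0

assert not has_trainable_tokens({"labels": [-100, -100, -100]})  # would be dropped
assert has_trainable_tokens({"labels": [-100, 42, 7]})           # kept
```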
@@ -493,7 +515,7 @@ def prepare_opinionated_env(cfg):
 def setup_trainer(
     cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
-    if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
+    if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
