Adding generic preference dataset builder (#1623)
SalmanMohammadi authored Sep 19, 2024
1 parent c5db813 commit cd573f9
Showing 6 changed files with 215 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -21,6 +21,7 @@ torchtune supports several widely used text-only datasets to help quickly bootst
alpaca_dataset
alpaca_cleaned_dataset
grammar_dataset
hh_rlhf_helpful_dataset
samsum_dataset
slimorca_dataset
stack_exchange_paired_dataset
@@ -51,6 +52,7 @@ These are especially useful for specifying from a YAML config.

instruct_dataset
chat_dataset
preference_dataset
text_completion_dataset

Generic dataset classes
1 change: 1 addition & 0 deletions tests/assets/hh_rlhf_tiny.json
@@ -0,0 +1 @@
[{"chosen":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Fix the hole.","role":"assistant"}],"rejected":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Take them off.","role":"assistant"}]}]
46 changes: 45 additions & 1 deletion tests/torchtune/datasets/test_preference_dataset.py
@@ -8,10 +8,11 @@
from unittest import mock

import pytest
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._preference import PreferenceDataset
from torchtune.datasets._preference import preference_dataset, PreferenceDataset
from torchtune.modules.transforms import Transform


@@ -111,3 +112,46 @@ def test_get_item(self, mock_load_dataset, dialogue, expected):
prompt, label = ds[0]["rejected_input_ids"], ds[0]["rejected_labels"]
assert prompt == expected_rejected_tokens
assert label == expected_rejected_labels

def test_load_local_json(self):
expected_tokenized_chosen_prompts = [
[0, 4, 2, 1, 2, 4, 1, 4, 1, 4, 2, 2, 9, 3, 3, 5, -1]
]
expected_tokenized_rejected_prompts = [
[0, 4, 2, 1, 2, 4, 1, 4, 1, 4, 2, 2, 9, 4, 4, 4, -1]
]

# prompt length is number of tokens shared between
# the tokenized rejected and chosen messages
prompt_length = 13
expected_chosen_labels = [
[CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [3, 3, 5, -1]
]
expected_rejected_labels = [
[CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [4, 4, 4, -1]
]

ds = preference_dataset(
tokenizer=DummyTokenizer(),
source="json",
data_files=str(ASSETS / "hh_rlhf_tiny.json"),
train_on_input=False,
split="train",
)

assert len(ds) == 1

expected_keys = [
"chosen_input_ids",
"chosen_labels",
"rejected_input_ids",
"rejected_labels",
]
assert set(ds[0].keys()) == set(expected_keys)
assert len(ds[0].keys()) == 4

assert expected_tokenized_chosen_prompts[0] == ds[0]["chosen_input_ids"]
assert expected_tokenized_rejected_prompts[0] == ds[0]["rejected_input_ids"]

assert expected_chosen_labels[0] == ds[0]["chosen_labels"]
assert expected_rejected_labels[0] == ds[0]["rejected_labels"]
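
As a side note, the ``prompt_length`` of 13 asserted above is the length of the longest common prefix of the two tokenized sequences; a hypothetical helper (not part of the test) makes this concrete:

from typing import List

def shared_prefix_length(chosen_ids: List[int], rejected_ids: List[int]) -> int:
    # Count leading token positions identical in both sequences; these are the
    # shared prompt tokens that get masked out in the labels.
    length = 0
    for c, r in zip(chosen_ids, rejected_ids):
        if c != r:
            break
        length += 1
    return length

chosen = [0, 4, 2, 1, 2, 4, 1, 4, 1, 4, 2, 2, 9, 3, 3, 5, -1]
rejected = [0, 4, 2, 1, 2, 4, 1, 4, 1, 4, 2, 2, 9, 4, 4, 4, -1]
assert shared_prefix_length(chosen, rejected) == 13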
3 changes: 2 additions & 1 deletion torchtune/datasets/__init__.py
@@ -12,7 +12,7 @@
from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
from torchtune.datasets._instruct import instruct_dataset, InstructDataset
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._preference import PreferenceDataset
from torchtune.datasets._preference import preference_dataset, PreferenceDataset
from torchtune.datasets._samsum import samsum_dataset
from torchtune.datasets._sft import SFTDataset
from torchtune.datasets._slimorca import slimorca_dataset
@@ -35,6 +35,7 @@
"slimorca_dataset",
"ChatDataset",
"instruct_dataset",
"preference_dataset",
"chat_dataset",
"text_completion_dataset",
"TextCompletionDataset",
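
With this export in place, the builder is importable from the public datasets namespace, for example:

# The new builder is exported alongside the existing PreferenceDataset class.
from torchtune.datasets import preference_dataset, PreferenceDataset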
7 changes: 4 additions & 3 deletions torchtune/datasets/_chat.py
@@ -158,10 +158,11 @@ def chat_dataset(
towards the column with the conversations.
Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
set to ``False`` by default
set to ``False`` by default.
- If ``train_on_input`` is True, the prompt is used during training and
contributes to the loss.
- If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100)
contributes to the loss.
- If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100).
Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
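
The masking behaviour this docstring describes can be illustrated with a small sketch (assumptions: a flat list of token ids and a known prompt length; torchtune's real implementation lives in its dataset classes):

from typing import List

CROSS_ENTROPY_IGNORE_IDX = -100  # value torchtune uses for ignored label positions

def build_labels(input_ids: List[int], prompt_len: int, train_on_input: bool) -> List[int]:
    if train_on_input:
        # Prompt tokens are kept, so they contribute to the loss.
        return list(input_ids)
    # Prompt tokens are replaced with -100, so the loss ignores them.
    return [CROSS_ENTROPY_IGNORE_IDX] * prompt_len + list(input_ids[prompt_len:])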
165 changes: 161 additions & 4 deletions torchtune/datasets/_preference.py
@@ -4,13 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping
from typing import Any, Dict, List, Mapping, Optional

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset

from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data import ChosenRejectedToMessages, CROSS_ENTROPY_IGNORE_IDX

from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
@@ -35,8 +35,11 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
|----------|----------|------------|
| Q1 | A1 | A2 |
At a high level, this class will load the data from source and apply the following pre-processing steps
when a sample is retrieved:
In the above case when the format is prompt-chosen-rejected, only single-turn interactions are supported.
At a high level, this class will load the data from source and apply the following pre-processing steps when a
sample is retrieved:
1. Dataset-specific transform. This is typically unique to each dataset and extracts
the necessary prompt and chosen/rejected columns into torchtune's :class:`~torchtune.data.Message`
@@ -137,3 +140,157 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
)

return tokenized_dict


def preference_dataset(
tokenizer: ModelTokenizer,
*,
source: str,
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> PreferenceDataset:
"""
Configures a custom preference dataset comprising interactions between user and
model assistant.
This builder function can be used to configure a custom preference dataset directly from the yaml config
as an alternative to :class:`~torchtune.datasets.PreferenceDataset`, as it is made to be config friendly.
This function requires the dataset to have "chosen" and "rejected" columns. A single sample will share an
identical system and/or user prompt between both "chosen" and "rejected" columns, followed by one or multiple
turns of user and assistant messages::
| chosen | rejected |
|----------------------------------------|----------------------------------------|
| [{"role": "user", "content": Q1}, | [{"role": "user", "content": Q1}, |
| {"role": "assistant", "content": C1}] | {"role": "assistant", "content": R1}] |
This example will be converted to:
.. code-block:: python
chosen_messages = [
Message(role="user", content="Q1"),
Message(role="assistant", content="C1"),
]
rejected_messages = [
Message(role="user", content="Q1"),
Message(role="assistant", content="R1"),
]
These lists of messages are then tokenized for model training. Currently, this function only supports
conversations identical to :class:`~torchtune.data.JSONToMessages`, and does not support custom
message formats.
If your dataset does not follow this format, we recommend creating a custom message transform similar to
:class:`~torchtune.data.ChosenRejectedToMessages` and using it in a custom dataset builder function similar
to :class:`~torchtune.datasets.preference_dataset`.
Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
set to ``False`` by default.
- If ``train_on_input`` is True, the prompt is used during training and
contributes to the loss.
- If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100).
Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text"), pass
in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
``load_dataset`` for more details.
column_map (Optional[Dict[str, str]]): a mapping from the expected columns "chosen" and "rejected"
in the message transform :class:`~torchtune.data.ChosenRejectedToMessages` to the new column names in
the dataset. Keys should be "chosen" and "rejected" and values should be the actual column names.
If None, keep the default columns "chosen" and "rejected".
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
any system messages already present in the dataset. Default is None.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Examples:
::
my_preference_dataset.json
[
{
"chosen_conversations": [
{
"content": "What do I do when I have a hole in my trousers?",
"role": "user"
},
{ "content": "Fix the hole.", "role": "assistant" }
],
"rejected_conversations": [
{
"content": "What do I do when I have a hole in my trousers?",
"role": "user"
},
{ "content": "Take them off.", "role": "assistant" }
]
}
]
::
>>> from torchtune.datasets import preference_dataset
>>> column_map = {
... "chosen": "chosen_conversations",
... "rejected": "rejected_conversations"
... }
>>> dataset = preference_dataset(
... tokenizer=tokenizer,
... source="json",
... column_map=column_map,
... data_files="my_preference_dataset.json",
... train_on_input=False,
... split="train",
... )
>>> tokenizer.decode(dataset[0]["chosen_input_ids"], skip_special_tokens=True)
What do I do when I have a hole in my trousers?Fix the hole.
>>> tokenizer.decode(dataset[0]["rejected_input_ids"], skip_special_tokens=True)
What do I do when I have a hole in my trousers?Take them off.
This can also be accomplished via the yaml config:
.. code-block:: yaml
dataset:
_component_: torchtune.datasets.preference_dataset
source: json
data_files: my_preference_dataset.json
column_map:
chosen: chosen_conversations
rejected: rejected_conversations
train_on_input: False
split: train
Returns:
PreferenceDataset: The preference dataset built from source paired data.
"""

message_transform = ChosenRejectedToMessages(
train_on_input=train_on_input,
column_map=column_map,
new_system_prompt=new_system_prompt,
)

return PreferenceDataset(
source=source,
message_transform=message_transform,
tokenizer=tokenizer,
split=split,
**load_dataset_kwargs,
)
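
Beyond the builder above, the docstring recommends a custom message transform for datasets that do not match the expected "chosen"/"rejected" schema. Here is a hedged sketch of that route, assuming hypothetical "prompt", "good_reply", and "bad_reply" columns and reusing the same PreferenceDataset class:

from typing import Any, Mapping

from torchtune.data import Message
from torchtune.datasets import PreferenceDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class PromptGoodBadToMessages(Transform):
    # Hypothetical transform: maps a sample with "prompt", "good_reply", and
    # "bad_reply" columns to the chosen/rejected Message lists that
    # PreferenceDataset expects from its message transform.
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # masked=True mirrors train_on_input=False: the prompt is excluded from the loss.
        user = Message(role="user", content=sample["prompt"], masked=True)
        return {
            "chosen": [user, Message(role="assistant", content=sample["good_reply"])],
            "rejected": [user, Message(role="assistant", content=sample["bad_reply"])],
        }


def custom_preference_dataset(
    tokenizer: ModelTokenizer, **load_dataset_kwargs: Any
) -> PreferenceDataset:
    # Wire the custom transform into PreferenceDataset, mirroring preference_dataset above.
    return PreferenceDataset(
        source="json",
        message_transform=PromptGoodBadToMessages(),
        tokenizer=tokenizer,
        split="train",
        **load_dataset_kwargs,
    )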
