rishiraj committed on
Commit
8d1ee8d
1 Parent(s): 4a002be

Upload data.py

Browse files
Files changed (1) hide show
  1. data.py +187 -0
data.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import re
18
+ from typing import List, Literal, Optional
19
+
20
+ from datasets import DatasetDict, concatenate_datasets, load_dataset
21
+
22
+ from .configs import DataArguments
23
+
24
+
25
# Jinja chat template used when a tokenizer does not ship its own. Each turn is rendered as
# `<|role|>\n` + content + `eos_token`; when `add_generation_prompt` is set, a bare
# `<|assistant|>` tag is appended after the last message.
# The literal is split into adjacent pieces (one per template line) purely for readability;
# the pieces concatenate to a single string.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)
26
+
27
+
28
def apply_chat_template(
    example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
    """
    Renders the conversation(s) in `example` to plain text using the tokenizer's chat template.

    Args:
        example (`dict`):
            A single dataset row. Must contain `messages` for the `sft` and `generation` tasks, or both `chosen`
            and `rejected` for the `rm` and `dpo` tasks. Each value is a list of `{"role": ..., "content": ...}`
            dictionaries.
        tokenizer:
            Any object exposing `apply_chat_template(messages, tokenize=..., add_generation_prompt=...)`.
        task (`str`, *optional*, defaults to `"sft"`):
            Which formatting to apply:
              - `sft`: writes `text` (full dialogue).
              - `generation`: writes `text` (full dialogue followed by the generation prompt).
              - `rm`: writes `text_chosen` and `text_rejected` (full dialogues).
              - `dpo`: writes `text_prompt` (system + first user message + generation prompt) plus
                `text_chosen` and `text_rejected` (everything after the first user message).
        assistant_prefix (`str`, *optional*, defaults to `"<|assistant|>\\n"`):
            Prefix removed from the start of `text_chosen` / `text_rejected` in the `dpo` task.

    Returns:
        `dict`: The same `example` object, with the text fields described above added. For the `sft`,
        `generation` and `rm` tasks an empty system message is inserted **in place** at the front of the
        message lists that lack one.

    Raises:
        ValueError: If the `rm` / `dpo` task is requested without `chosen` and `rejected` keys, or if `task` is
            not one of the supported values.
    """

    def _strip_prefix(s, pattern):
        # Use re.escape to escape any special characters in the pattern
        return re.sub(f"^{re.escape(pattern)}", "", s)

    def _ensure_system(messages):
        # We add an empty system message (in place) if there is none
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": ""})
        return messages

    def _completion(messages):
        # Everything after the first user turn. Slicing relative to the user turn (rather than a fixed
        # `[1:]`) keeps the prompt out of the completion when a system message precedes it.
        for idx, msg in enumerate(messages):
            if msg["role"] == "user":
                return messages[idx + 1 :]
        # No user turn found: fall back to dropping only the first message
        return messages[1:]

    if task in ("sft", "generation"):
        messages = _ensure_system(example["messages"])
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(task == "generation")
        )
    elif task == "rm":
        if not all(k in example.keys() for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
        chosen_messages = _ensure_system(example["chosen"])
        rejected_messages = _ensure_system(example["rejected"])
        example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
        example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
    elif task == "dpo":
        if not all(k in example.keys() for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
        chosen, rejected = example["chosen"], example["rejected"]
        # Compared to reward modeling, we filter out the prompt, so the text is everything after the first
        # user message. The prompt itself is the (possibly empty) system message plus that user message.
        first_user_message = [msg for msg in chosen if msg["role"] == "user"][0]
        if chosen[0]["role"] == "system":
            system_message = chosen[0]
        else:
            system_message = {"role": "system", "content": ""}
        prompt_messages = [system_message, first_user_message]

        example["text_chosen"] = tokenizer.apply_chat_template(_completion(chosen), tokenize=False)
        example["text_rejected"] = tokenizer.apply_chat_template(_completion(rejected), tokenize=False)
        example["text_prompt"] = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )

        # The generation prompt at the end of `text_prompt` already opens the assistant turn, so drop the
        # duplicate tag from the start of each completion.
        example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
        example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
    else:
        # Fail loudly instead of silently returning an unformatted example
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
        )
    return example
83
+
84
+
85
def get_datasets(
    data_config: DataArguments | dict,
    splits: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions. Either an object whose `dataset_mixer` attribute maps
            dataset names to training proportions, or that mapping itself.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets; each name must contain
            `train` or `test`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the mixed data. Forwarded unchanged to `mix_datasets`.

    Returns
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is neither a `DataArguments` nor a `dict`.
    """
    # A `None` sentinel avoids sharing one mutable default list across calls
    if splits is None:
        splits = ["train", "test"]

    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # dataset_mixer:
        #     - 'dataset1': 0.5
        #     - 'dataset2': 0.3
        #     - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #             "dataset1": 0.5,
        #             "dataset2": 0.3,
        #             "dataset3": 0.2,
        #         }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    return mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
125
+
126
+
127
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> "DatasetDict":
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions. Only train splits are
            subsampled: the first `int(frac * len(dataset))` rows are kept. Test splits are always used in full.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets. A split whose name contains
            `train` goes into the training mix; otherwise one containing `test` goes into the test mix.
            `None` means `['train', 'test']`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the mixed training and test data (fixed seed of 42).

    Returns:
        [`DatasetDict`]: Contains a `train` and/or `test` entry, depending on which kinds of split were loaded.

    Raises:
        ValueError: If any fraction is negative, if a split name contains neither `train` nor `test`, or if
            nothing was loaded at all.
    """
    splits = ["train", "test"] if splits is None else splits

    # Validate the cheap inputs up front so that a bad config fails before any dataset is downloaded
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")
    for split in splits:
        if "train" not in split and "test" not in split:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    train_subsets = []
    raw_val_datasets = []
    for ds, frac in dataset_mixer.items():
        for split in splits:
            dataset = load_dataset(ds, split=split)
            if "train" in split:
                # Subsample here, while `frac` is still paired with its own dataset; this stays correct even
                # when several train splits are requested per dataset.
                train_subsets.append(dataset.select(range(int(frac * len(dataset)))))
            else:
                # No subsampling for test datasets to enable fair comparison across models
                raw_val_datasets.append(dataset)

    if not train_subsets and not raw_val_datasets:
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with split {splits}. Check the dataset has been correctly formatted."
        )

    raw_datasets = DatasetDict()
    for name, parts in (("train", train_subsets), ("test", raw_val_datasets)):
        if parts:
            combined = concatenate_datasets(parts)
            raw_datasets[name] = combined.shuffle(seed=42) if shuffle else combined

    return raw_datasets