ACMC
committed on
Commit
•
bd73a7b
1
Parent(s):
bf9e30f
Bugfix
Browse files
- app.py +99 -46
- utils.py +52 -49
- validation.py +40 -27
app.py
CHANGED
@@ -1,33 +1,41 @@
|
|
1 |
# %%
|
|
|
|
|
|
|
2 |
from uuid import uuid4
|
3 |
-
|
4 |
import datasets
|
5 |
-
import
|
6 |
-
import io
|
7 |
-
from utils import (
|
8 |
-
process_chat_file,
|
9 |
-
transform_conversations_dataset_into_training_examples,
|
10 |
-
)
|
11 |
-
from validation import (
|
12 |
-
check_format_errors,
|
13 |
-
estimate_cost,
|
14 |
-
get_distributions,
|
15 |
-
)
|
16 |
import matplotlib.pyplot as plt
|
17 |
|
|
|
|
|
|
|
18 |
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
modified_dataset = None
|
21 |
for file in progress.tqdm(files, desc="Processing files"):
|
22 |
if modified_dataset is None:
|
23 |
# First file
|
24 |
modified_dataset = process_chat_file(
|
25 |
-
file,
|
|
|
|
|
|
|
|
|
26 |
)
|
27 |
else:
|
28 |
# Concatenate the datasets
|
29 |
this_file_dataset = process_chat_file(
|
30 |
-
file,
|
|
|
|
|
|
|
|
|
31 |
)
|
32 |
modified_dataset = datasets.concatenate_datasets(
|
33 |
[modified_dataset, this_file_dataset]
|
@@ -43,25 +51,41 @@ def file_upload_callback(
|
|
43 |
user_role,
|
44 |
model_role,
|
45 |
whatsapp_name,
|
|
|
|
|
46 |
progress=gr.Progress(),
|
47 |
):
|
48 |
-
|
49 |
-
full_system_prompt = f"""
|
50 |
-
|
51 |
The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].
|
52 |
# Information about me
|
53 |
-
You should use the following information about me to answer:
|
54 |
{system_prompt}"""
|
55 |
# Example
|
56 |
# [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
|
57 |
# Response:
|
58 |
# [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
|
61 |
# full_system_prompt = system_prompt
|
62 |
dataset = convert_to_dataset(
|
63 |
-
files=files,
|
|
|
|
|
|
|
|
|
|
|
64 |
)
|
|
|
|
|
65 |
training_examples_ds = transform_conversations_dataset_into_training_examples(
|
66 |
conversations_ds=dataset,
|
67 |
system_prompt=full_system_prompt,
|
@@ -69,6 +93,7 @@ You should use the following information about me to answer:
|
|
69 |
model_role=model_role,
|
70 |
whatsapp_name=whatsapp_name,
|
71 |
)
|
|
|
72 |
|
73 |
# Split into training and validation datasets (80% and 20%)
|
74 |
training_examples_ds = training_examples_ds.train_test_split(
|
@@ -78,9 +103,9 @@ You should use the following information about me to answer:
|
|
78 |
training_examples_ds["train"],
|
79 |
training_examples_ds["test"],
|
80 |
)
|
81 |
-
training_examples_ds = training_examples_ds
|
82 |
# range(min(250, len(training_examples_ds)))
|
83 |
-
#)
|
84 |
validation_examples_ds = validation_examples_ds.select(
|
85 |
range(min(200, len(validation_examples_ds)))
|
86 |
)
|
@@ -124,6 +149,12 @@ You should use the following information about me to answer:
|
|
124 |
file_path_validation = f"validation_examples_{uuid}.jsonl"
|
125 |
validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
return (
|
128 |
file_path,
|
129 |
gr.update(visible=True),
|
@@ -142,7 +173,7 @@ def remove_file_and_hide_button(file_path):
|
|
142 |
try:
|
143 |
os.remove(file_path)
|
144 |
except Exception as e:
|
145 |
-
|
146 |
|
147 |
return gr.update(visible=False)
|
148 |
|
@@ -190,32 +221,52 @@ with gr.Blocks(theme=theme) as demo:
|
|
190 |
info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
|
191 |
)
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
submit = gr.Button(value="Submit", variant="primary")
|
221 |
|
@@ -253,6 +304,8 @@ with gr.Blocks(theme=theme) as demo:
|
|
253 |
user_role,
|
254 |
model_role,
|
255 |
whatsapp_name,
|
|
|
|
|
256 |
],
|
257 |
outputs=[
|
258 |
output_file,
|
|
|
1 |
# %%
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
from uuid import uuid4
|
6 |
+
|
7 |
import datasets
|
8 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
|
11 |
+
from utils import (process_chat_file,
|
12 |
+
transform_conversations_dataset_into_training_examples)
|
13 |
+
from validation import check_format_errors, estimate_cost, get_distributions
|
14 |
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
logger.setLevel(logging.INFO)
|
17 |
+
|
18 |
+
|
19 |
+
def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, datetime_dayfirst, message_line_format):
|
20 |
modified_dataset = None
|
21 |
for file in progress.tqdm(files, desc="Processing files"):
|
22 |
if modified_dataset is None:
|
23 |
# First file
|
24 |
modified_dataset = process_chat_file(
|
25 |
+
file,
|
26 |
+
do_spelling_correction=do_spelling_correction,
|
27 |
+
whatsapp_name=whatsapp_name,
|
28 |
+
datetime_dayfirst=datetime_dayfirst,
|
29 |
+
message_line_format=message_line_format,
|
30 |
)
|
31 |
else:
|
32 |
# Concatenate the datasets
|
33 |
this_file_dataset = process_chat_file(
|
34 |
+
file,
|
35 |
+
do_spelling_correction=do_spelling_correction,
|
36 |
+
whatsapp_name=whatsapp_name,
|
37 |
+
datetime_dayfirst=datetime_dayfirst,
|
38 |
+
message_line_format=message_line_format,
|
39 |
)
|
40 |
modified_dataset = datasets.concatenate_datasets(
|
41 |
[modified_dataset, this_file_dataset]
|
|
|
51 |
user_role,
|
52 |
model_role,
|
53 |
whatsapp_name,
|
54 |
+
datetime_dayfirst,
|
55 |
+
message_line_format,
|
56 |
progress=gr.Progress(),
|
57 |
):
|
58 |
+
logger.info(f"Processing {files}")
|
59 |
+
full_system_prompt = f"""# Task
|
60 |
+
You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
|
61 |
The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].
|
62 |
# Information about me
|
|
|
63 |
{system_prompt}"""
|
64 |
# Example
|
65 |
# [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
|
66 |
# Response:
|
67 |
# [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
|
68 |
|
69 |
+
# Check if the user has not chosen any files
|
70 |
+
if not files or len(files) == 0:
|
71 |
+
raise gr.Error("Please upload at least one file.")
|
72 |
+
|
73 |
+
# Check if the user has not entered their whatsapp name
|
74 |
+
if not whatsapp_name or len(whatsapp_name) == 0:
|
75 |
+
raise gr.Error("Please enter your WhatsApp name.")
|
76 |
+
|
77 |
# # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
|
78 |
# full_system_prompt = system_prompt
|
79 |
dataset = convert_to_dataset(
|
80 |
+
files=files,
|
81 |
+
progress=progress,
|
82 |
+
do_spelling_correction=do_spelling_correction,
|
83 |
+
whatsapp_name=whatsapp_name,
|
84 |
+
datetime_dayfirst=datetime_dayfirst,
|
85 |
+
message_line_format=message_line_format,
|
86 |
)
|
87 |
+
logger.info(f"Number of conversations of dataset before being transformed: {len(dataset)}")
|
88 |
+
|
89 |
training_examples_ds = transform_conversations_dataset_into_training_examples(
|
90 |
conversations_ds=dataset,
|
91 |
system_prompt=full_system_prompt,
|
|
|
93 |
model_role=model_role,
|
94 |
whatsapp_name=whatsapp_name,
|
95 |
)
|
96 |
+
logger.info(f"Number of training examples: {len(training_examples_ds)}")
|
97 |
|
98 |
# Split into training and validation datasets (80% and 20%)
|
99 |
training_examples_ds = training_examples_ds.train_test_split(
|
|
|
103 |
training_examples_ds["train"],
|
104 |
training_examples_ds["test"],
|
105 |
)
|
106 |
+
training_examples_ds = training_examples_ds # .select(
|
107 |
# range(min(250, len(training_examples_ds)))
|
108 |
+
# )
|
109 |
validation_examples_ds = validation_examples_ds.select(
|
110 |
range(min(200, len(validation_examples_ds)))
|
111 |
)
|
|
|
149 |
file_path_validation = f"validation_examples_{uuid}.jsonl"
|
150 |
validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
|
151 |
|
152 |
+
# If there's less than 50 training examples, show a warning message
|
153 |
+
if len(training_examples_ds) < 50:
|
154 |
+
gr.Warning(
|
155 |
+
"Warning: There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
|
156 |
+
)
|
157 |
+
|
158 |
return (
|
159 |
file_path,
|
160 |
gr.update(visible=True),
|
|
|
173 |
try:
|
174 |
os.remove(file_path)
|
175 |
except Exception as e:
|
176 |
+
logger.info(f"Error removing file {file_path}: {e}")
|
177 |
|
178 |
return gr.update(visible=False)
|
179 |
|
|
|
221 |
info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
|
222 |
)
|
223 |
|
224 |
+
# Advanced parameters section, collapsed by default
|
225 |
+
with gr.Accordion(label="Advanced Parameters", open=False):
|
226 |
+
gr.Markdown(
|
227 |
+
"""
|
228 |
+
These are advanced parameters that you can change if you know what you're doing. If you're unsure, you can leave them as they are.
|
229 |
+
"""
|
230 |
+
)
|
231 |
|
232 |
+
user_role = gr.Textbox(
|
233 |
+
label="Role for User",
|
234 |
+
info="This is a technical parameter. If you don't know what to write, just type 'user'.",
|
235 |
+
value="user",
|
236 |
+
)
|
237 |
|
238 |
+
model_role = gr.Textbox(
|
239 |
+
label="Role for Model",
|
240 |
+
info="This is a technical parameter. Usual values are 'model' or 'assistant'.",
|
241 |
+
value="model",
|
242 |
+
)
|
243 |
|
244 |
+
message_line_format = gr.Textbox(
|
245 |
+
label="Message Line Format",
|
246 |
+
info="Format of each message line in the chat file, as a regular expression. The default value should work for most cases.",
|
247 |
+
value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)",
|
248 |
+
)
|
249 |
+
|
250 |
+
datetime_dayfirst = gr.Checkbox(
|
251 |
+
label="Date format: Day first",
|
252 |
+
info="Check this box if the date time format in the chat messages is in the format 'DD/MM/YYYY'. You can check your phone settings to see the date format. Otherwise, it will be assumed that the date time format is 'MM/DD/YYYY'.",
|
253 |
+
value=True,
|
254 |
+
)
|
255 |
+
|
256 |
+
do_spelling_correction = gr.Checkbox(
|
257 |
+
label="Do Spelling Correction (English)",
|
258 |
+
info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
|
259 |
+
)
|
260 |
+
|
261 |
+
# Allow the user to choose the validation split size
|
262 |
+
validation_split = gr.Slider(
|
263 |
+
minimum=0.0,
|
264 |
+
maximum=0.5,
|
265 |
+
value=0.2,
|
266 |
+
interactive=True,
|
267 |
+
label="Validation Split",
|
268 |
+
info="Choose the percentage of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
|
269 |
+
)
|
270 |
|
271 |
submit = gr.Button(value="Submit", variant="primary")
|
272 |
|
|
|
304 |
user_role,
|
305 |
model_role,
|
306 |
whatsapp_name,
|
307 |
+
datetime_dayfirst,
|
308 |
+
message_line_format,
|
309 |
],
|
310 |
outputs=[
|
311 |
output_file,
|
utils.py
CHANGED
@@ -1,36 +1,13 @@
|
|
1 |
-
import datasets
|
2 |
import datetime
|
3 |
-
import os
|
4 |
import json
|
5 |
-
|
|
|
6 |
import re
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
)
|
11 |
-
|
12 |
-
|
13 |
-
def process_line(example):
|
14 |
-
# The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
|
15 |
-
try:
|
16 |
-
groups = exp.match(example["text"]).groupdict()
|
17 |
-
timestamp = datetime.datetime(
|
18 |
-
int(groups["year"]),
|
19 |
-
int(groups["month"]),
|
20 |
-
int(groups["day"]),
|
21 |
-
int(groups["hour"]),
|
22 |
-
int(groups["minute"]),
|
23 |
-
).timestamp()
|
24 |
-
return {
|
25 |
-
"message": groups["message"],
|
26 |
-
"contact_name": groups["contact_name"],
|
27 |
-
"timestamp": timestamp,
|
28 |
-
}
|
29 |
-
except Exception as e:
|
30 |
-
print(e)
|
31 |
-
print(example["text"])
|
32 |
-
raise e
|
33 |
-
|
34 |
|
35 |
# %%
|
36 |
# Now, create message groups ('conversations')
|
@@ -63,10 +40,11 @@ def printable_conversation(conversation):
|
|
63 |
)
|
64 |
|
65 |
|
|
|
|
|
66 |
# %%
|
67 |
# Use spacy to spell check the messages
|
68 |
import spacy
|
69 |
-
import contextualSpellCheck
|
70 |
from spellchecker import SpellChecker
|
71 |
|
72 |
spell = SpellChecker()
|
@@ -78,17 +56,17 @@ def spell_check_conversation(conversation):
|
|
78 |
for i, message in enumerate(conversation["conversations"]):
|
79 |
# Use SpaCy to get the words
|
80 |
words = spell.split_words(message["message"])
|
81 |
-
|
82 |
corrected_message = []
|
83 |
for word in words:
|
84 |
correction = spell.correction(word)
|
85 |
if (correction != None) and (correction != word):
|
86 |
-
|
87 |
corrected_message.append(correction)
|
88 |
else:
|
89 |
corrected_message.append(word)
|
90 |
|
91 |
-
|
92 |
joined_message = " ".join(corrected_message)
|
93 |
conversation["conversations"][i]["message"] = joined_message
|
94 |
|
@@ -107,7 +85,7 @@ def spell_check_conversation_spacy(conversation):
|
|
107 |
docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
|
108 |
for i, doc in enumerate(docs):
|
109 |
if doc._.performed_spellCheck:
|
110 |
-
|
111 |
conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
|
112 |
|
113 |
return conversation
|
@@ -144,8 +122,8 @@ A: I'm fine too
|
|
144 |
To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages.
|
145 |
"""
|
146 |
|
147 |
-
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction
|
148 |
import torch
|
|
|
149 |
|
150 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
151 |
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
@@ -186,10 +164,12 @@ def swap_messages_if_needed(message1, message2):
|
|
186 |
swap = logits[0, 0] - logits[1, 0] < -0.2
|
187 |
if swap:
|
188 |
# Swap the messages
|
189 |
-
|
|
|
|
|
190 |
return message2, message1
|
191 |
else:
|
192 |
-
#
|
193 |
return message1, message2
|
194 |
|
195 |
|
@@ -208,8 +188,8 @@ def swap_messages_if_needed_in_conversation(conversation):
|
|
208 |
new_conversation[-1] = message1
|
209 |
new_conversation.append(message2)
|
210 |
|
211 |
-
#
|
212 |
-
#
|
213 |
return new_conversation
|
214 |
|
215 |
|
@@ -226,26 +206,38 @@ test_conversation = [
|
|
226 |
"timestamp": 3,
|
227 |
},
|
228 |
]
|
229 |
-
#
|
230 |
|
231 |
# %%
|
232 |
# Now, we'll train an mT5 model to generate the next message in a conversation
|
233 |
import os
|
234 |
|
235 |
|
236 |
-
# For the contact_name, rewrite everything that is not 'Aldi' to 'Other'
|
237 |
-
def rewrite_contact_name(conversation):
|
238 |
-
for message in conversation["conversations"]:
|
239 |
-
if message["contact_name"] != "Aldi":
|
240 |
-
message["contact_name"] = "Other"
|
241 |
-
return conversation
|
242 |
-
|
243 |
-
|
244 |
# %%
|
245 |
-
def process_chat_file(file, do_spelling_correction, do_reordering=False):
|
246 |
"""
|
247 |
Process a chat file and return a dataset with the conversations.
|
248 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
ds = (
|
250 |
datasets.load_dataset("text", data_files=[file])["train"]
|
251 |
.filter(
|
@@ -288,6 +280,13 @@ def process_chat_file(file, do_spelling_correction, do_reordering=False):
|
|
288 |
else:
|
289 |
reordered_conversations_ds = spell_checked_conversations_ds
|
290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
changed_contact_name_ds = reordered_conversations_ds.map(
|
292 |
rewrite_contact_name
|
293 |
) # , num_proc=os.cpu_count() - 1)
|
@@ -372,6 +371,10 @@ def transform_conversations_dataset_into_training_examples(
|
|
372 |
]
|
373 |
}
|
374 |
)
|
|
|
|
|
|
|
|
|
375 |
# Before returning, flatten the list of dictionaries into a dictionary of lists
|
376 |
flattened_examples = {}
|
377 |
for key in processed_examples[0].keys():
|
|
|
|
|
1 |
import datetime
|
|
|
2 |
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
import re
|
6 |
+
import datasets
|
7 |
+
import dateutil.parser
|
8 |
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# %%
|
13 |
# Now, create message groups ('conversations')
|
|
|
40 |
)
|
41 |
|
42 |
|
43 |
+
import contextualSpellCheck
|
44 |
+
|
45 |
# %%
|
46 |
# Use spacy to spell check the messages
|
47 |
import spacy
|
|
|
48 |
from spellchecker import SpellChecker
|
49 |
|
50 |
spell = SpellChecker()
|
|
|
56 |
for i, message in enumerate(conversation["conversations"]):
|
57 |
# Use SpaCy to get the words
|
58 |
words = spell.split_words(message["message"])
|
59 |
+
logger.info(f"Words: {words}")
|
60 |
corrected_message = []
|
61 |
for word in words:
|
62 |
correction = spell.correction(word)
|
63 |
if (correction != None) and (correction != word):
|
64 |
+
logger.info(f"Spell check: {word} -> {correction}")
|
65 |
corrected_message.append(correction)
|
66 |
else:
|
67 |
corrected_message.append(word)
|
68 |
|
69 |
+
logger.info(f"Corrected message: {corrected_message}")
|
70 |
joined_message = " ".join(corrected_message)
|
71 |
conversation["conversations"][i]["message"] = joined_message
|
72 |
|
|
|
85 |
docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
|
86 |
for i, doc in enumerate(docs):
|
87 |
if doc._.performed_spellCheck:
|
88 |
+
logger.info(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
|
89 |
conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
|
90 |
|
91 |
return conversation
|
|
|
122 |
To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages.
|
123 |
"""
|
124 |
|
|
|
125 |
import torch
|
126 |
+
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer
|
127 |
|
128 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
129 |
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
|
|
164 |
swap = logits[0, 0] - logits[1, 0] < -0.2
|
165 |
if swap:
|
166 |
# Swap the messages
|
167 |
+
logger.info(
|
168 |
+
f"Swapping messages: {message1['message']} <-> {message2['message']}"
|
169 |
+
)
|
170 |
return message2, message1
|
171 |
else:
|
172 |
+
# logger.info(f"NOT swapping messages: {message1['message']} <-> {message2['message']}")
|
173 |
return message1, message2
|
174 |
|
175 |
|
|
|
188 |
new_conversation[-1] = message1
|
189 |
new_conversation.append(message2)
|
190 |
|
191 |
+
# logger.info(f"\nOriginal conversation:\n{printable_conversation(conversation)}")
|
192 |
+
# logger.info(f"\nNew conversation:\n{printable_conversation(new_conversation)}")
|
193 |
return new_conversation
|
194 |
|
195 |
|
|
|
206 |
"timestamp": 3,
|
207 |
},
|
208 |
]
|
209 |
+
# logger.info(swap_messages_if_needed_in_conversation(test_conversation))
|
210 |
|
211 |
# %%
|
212 |
# Now, we'll train an mT5 model to generate the next message in a conversation
|
213 |
import os
|
214 |
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# %%
|
217 |
+
def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False):
|
218 |
"""
|
219 |
Process a chat file and return a dataset with the conversations.
|
220 |
"""
|
221 |
+
exp = re.compile(
|
222 |
+
# r"(?P<msg_datetime>.+?) - (?P<contact_name>.+): (?P<message>.+)"
|
223 |
+
# r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)"
|
224 |
+
message_line_format
|
225 |
+
)
|
226 |
+
|
227 |
+
def process_line(example):
|
228 |
+
# The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
|
229 |
+
try:
|
230 |
+
groups = exp.match(example["text"]).groupdict()
|
231 |
+
timestamp = dateutil.parser.parse(groups['msg_datetime'], dayfirst=datetime_dayfirst).timestamp()
|
232 |
+
return {
|
233 |
+
"message": groups["message"],
|
234 |
+
"contact_name": groups["contact_name"],
|
235 |
+
"timestamp": timestamp,
|
236 |
+
}
|
237 |
+
except Exception as e:
|
238 |
+
logger.exception(example["text"])
|
239 |
+
raise e
|
240 |
+
|
241 |
ds = (
|
242 |
datasets.load_dataset("text", data_files=[file])["train"]
|
243 |
.filter(
|
|
|
280 |
else:
|
281 |
reordered_conversations_ds = spell_checked_conversations_ds
|
282 |
|
283 |
+
# For the contact_name, rewrite everything that is not 'my_whatsapp_name' to 'Other'
|
284 |
+
def rewrite_contact_name(conversation):
|
285 |
+
for message in conversation["conversations"]:
|
286 |
+
if message["contact_name"] != whatsapp_name:
|
287 |
+
message["contact_name"] = "Other"
|
288 |
+
return conversation
|
289 |
+
|
290 |
changed_contact_name_ds = reordered_conversations_ds.map(
|
291 |
rewrite_contact_name
|
292 |
) # , num_proc=os.cpu_count() - 1)
|
|
|
371 |
]
|
372 |
}
|
373 |
)
|
374 |
+
else:
|
375 |
+
logger.warning(
|
376 |
+
f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}"
|
377 |
+
)
|
378 |
# Before returning, flatten the list of dictionaries into a dictionary of lists
|
379 |
flattened_examples = {}
|
380 |
for key in processed_examples[0].keys():
|
validation.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
-
import
|
2 |
from collections import defaultdict
|
|
|
|
|
3 |
import tiktoken
|
4 |
|
|
|
|
|
|
|
5 |
|
6 |
def check_format_errors(train_dataset, user_role, model_role):
|
7 |
"""
|
@@ -24,7 +29,10 @@ def check_format_errors(train_dataset, user_role, model_role):
|
|
24 |
if "role" not in message or "content" not in message:
|
25 |
format_errors["message_missing_key"] += 1
|
26 |
|
27 |
-
if any(
|
|
|
|
|
|
|
28 |
format_errors["message_unrecognized_key"] += 1
|
29 |
|
30 |
if message.get("role", None) not in ["system", user_role, model_role]:
|
@@ -40,14 +48,15 @@ def check_format_errors(train_dataset, user_role, model_role):
|
|
40 |
format_errors["example_missing_assistant_message"] += 1
|
41 |
|
42 |
if format_errors:
|
43 |
-
|
44 |
for k, v in format_errors.items():
|
45 |
-
|
46 |
else:
|
47 |
-
|
48 |
|
49 |
return format_errors if format_errors else {}
|
50 |
|
|
|
51 |
def get_distributions(train_dataset, user_role, model_role):
|
52 |
"""
|
53 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
@@ -76,7 +85,6 @@ def get_distributions(train_dataset, user_role, model_role):
|
|
76 |
num_tokens += len(encoding.encode(message["content"]))
|
77 |
return num_tokens
|
78 |
|
79 |
-
|
80 |
n_missing_system = 0
|
81 |
n_missing_user = 0
|
82 |
n_messages = []
|
@@ -92,13 +100,13 @@ def get_distributions(train_dataset, user_role, model_role):
|
|
92 |
n_messages.append(len(messages))
|
93 |
convo_lens.append(num_tokens_from_messages(messages))
|
94 |
assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
|
95 |
-
|
96 |
return {
|
97 |
"n_missing_system": n_missing_system,
|
98 |
"n_missing_user": n_missing_user,
|
99 |
"n_messages": n_messages,
|
100 |
"convo_lens": convo_lens,
|
101 |
-
"assistant_message_lens": assistant_message_lens
|
102 |
}
|
103 |
|
104 |
|
@@ -106,48 +114,49 @@ def check_token_counts(train_dataset, user_role, model_role):
|
|
106 |
"""
|
107 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
108 |
"""
|
109 |
-
def print_distribution(values, name):
|
110 |
-
print(f"\n#### Distribution of {name}:")
|
111 |
-
print(f"min / max: {min(values)}, {max(values)}")
|
112 |
-
print(f"mean / median: {np.mean(values)}, {np.median(values)}")
|
113 |
-
print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")
|
114 |
-
|
115 |
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
# Warnings and tokens counts
|
118 |
-
distributions = get_distributions(
|
|
|
|
|
119 |
n_missing_system = distributions["n_missing_system"]
|
120 |
n_missing_user = distributions["n_missing_user"]
|
121 |
n_messages = distributions["n_messages"]
|
122 |
convo_lens = distributions["convo_lens"]
|
123 |
assistant_message_lens = distributions["assistant_message_lens"]
|
124 |
|
125 |
-
|
126 |
-
|
127 |
print_distribution(n_messages, "num_messages_per_example")
|
128 |
print_distribution(convo_lens, "num_total_tokens_per_example")
|
129 |
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
|
130 |
n_too_long = sum(l > 4096 for l in convo_lens)
|
131 |
-
|
132 |
f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
|
133 |
)
|
134 |
|
135 |
-
return
|
136 |
|
137 |
|
138 |
def estimate_cost(train_dataset, user_role, model_role):
|
139 |
"""
|
140 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
141 |
"""
|
142 |
-
distributions = get_distributions(
|
|
|
|
|
143 |
n_missing_system = distributions["n_missing_system"]
|
144 |
n_missing_user = distributions["n_missing_user"]
|
145 |
n_messages = distributions["n_messages"]
|
146 |
convo_lens = distributions["convo_lens"]
|
147 |
assistant_message_lens = distributions["assistant_message_lens"]
|
148 |
|
149 |
-
|
150 |
-
|
151 |
# Pricing and default n_epochs estimate
|
152 |
MAX_TOKENS_PER_EXAMPLE = 4096
|
153 |
|
@@ -159,10 +168,13 @@ def estimate_cost(train_dataset, user_role, model_role):
|
|
159 |
|
160 |
n_epochs = TARGET_EPOCHS
|
161 |
n_train_examples = len(train_dataset)
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
166 |
|
167 |
n_billing_tokens_in_dataset = sum(
|
168 |
min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
|
@@ -170,5 +182,6 @@ def estimate_cost(train_dataset, user_role, model_role):
|
|
170 |
|
171 |
return {
|
172 |
"Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
|
173 |
-
f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs
|
|
|
174 |
}
|
|
|
1 |
+
import logging
|
2 |
from collections import defaultdict
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
import tiktoken
|
6 |
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
logger.setLevel(logging.INFO)
|
9 |
+
|
10 |
|
11 |
def check_format_errors(train_dataset, user_role, model_role):
|
12 |
"""
|
|
|
29 |
if "role" not in message or "content" not in message:
|
30 |
format_errors["message_missing_key"] += 1
|
31 |
|
32 |
+
if any(
|
33 |
+
k not in ("role", "content", "name", "function_call", "weight")
|
34 |
+
for k in message
|
35 |
+
):
|
36 |
format_errors["message_unrecognized_key"] += 1
|
37 |
|
38 |
if message.get("role", None) not in ["system", user_role, model_role]:
|
|
|
48 |
format_errors["example_missing_assistant_message"] += 1
|
49 |
|
50 |
if format_errors:
|
51 |
+
logger.warning("Found errors:")
|
52 |
for k, v in format_errors.items():
|
53 |
+
logger.warning(f"{k}: {v}")
|
54 |
else:
|
55 |
+
logger.info("No errors found")
|
56 |
|
57 |
return format_errors if format_errors else {}
|
58 |
|
59 |
+
|
60 |
def get_distributions(train_dataset, user_role, model_role):
|
61 |
"""
|
62 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
|
|
85 |
num_tokens += len(encoding.encode(message["content"]))
|
86 |
return num_tokens
|
87 |
|
|
|
88 |
n_missing_system = 0
|
89 |
n_missing_user = 0
|
90 |
n_messages = []
|
|
|
100 |
n_messages.append(len(messages))
|
101 |
convo_lens.append(num_tokens_from_messages(messages))
|
102 |
assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
|
103 |
+
|
104 |
return {
|
105 |
"n_missing_system": n_missing_system,
|
106 |
"n_missing_user": n_missing_user,
|
107 |
"n_messages": n_messages,
|
108 |
"convo_lens": convo_lens,
|
109 |
+
"assistant_message_lens": assistant_message_lens,
|
110 |
}
|
111 |
|
112 |
|
|
|
114 |
"""
|
115 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
116 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
+
def print_distribution(values, name):
|
119 |
+
logger.info(f"\n#### Distribution of {name}:")
|
120 |
+
logger.info(f"min / max: {min(values)}, {max(values)}")
|
121 |
+
logger.info(f"mean / median: {np.mean(values)}, {np.median(values)}")
|
122 |
+
logger.info(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")
|
123 |
|
124 |
# Warnings and tokens counts
|
125 |
+
distributions = get_distributions(
|
126 |
+
train_dataset, user_role=user_role, model_role=model_role
|
127 |
+
)
|
128 |
n_missing_system = distributions["n_missing_system"]
|
129 |
n_missing_user = distributions["n_missing_user"]
|
130 |
n_messages = distributions["n_messages"]
|
131 |
convo_lens = distributions["convo_lens"]
|
132 |
assistant_message_lens = distributions["assistant_message_lens"]
|
133 |
|
134 |
+
logger.info("Num examples missing system message:", n_missing_system)
|
135 |
+
logger.info("Num examples missing user message:", n_missing_user)
|
136 |
print_distribution(n_messages, "num_messages_per_example")
|
137 |
print_distribution(convo_lens, "num_total_tokens_per_example")
|
138 |
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
|
139 |
n_too_long = sum(l > 4096 for l in convo_lens)
|
140 |
+
logger.info(
|
141 |
f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
|
142 |
)
|
143 |
|
144 |
+
return
|
145 |
|
146 |
|
147 |
def estimate_cost(train_dataset, user_role, model_role):
|
148 |
"""
|
149 |
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
150 |
"""
|
151 |
+
distributions = get_distributions(
|
152 |
+
train_dataset, user_role=user_role, model_role=model_role
|
153 |
+
)
|
154 |
n_missing_system = distributions["n_missing_system"]
|
155 |
n_missing_user = distributions["n_missing_user"]
|
156 |
n_messages = distributions["n_messages"]
|
157 |
convo_lens = distributions["convo_lens"]
|
158 |
assistant_message_lens = distributions["assistant_message_lens"]
|
159 |
|
|
|
|
|
160 |
# Pricing and default n_epochs estimate
|
161 |
MAX_TOKENS_PER_EXAMPLE = 4096
|
162 |
|
|
|
168 |
|
169 |
n_epochs = TARGET_EPOCHS
|
170 |
n_train_examples = len(train_dataset)
|
171 |
+
try:
|
172 |
+
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
|
173 |
+
n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
|
174 |
+
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
|
175 |
+
n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
|
176 |
+
except:
|
177 |
+
n_epochs = TARGET_EPOCHS
|
178 |
|
179 |
n_billing_tokens_in_dataset = sum(
|
180 |
min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
|
|
|
182 |
|
183 |
return {
|
184 |
"Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
|
185 |
+
f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs
|
186 |
+
* n_billing_tokens_in_dataset,
|
187 |
}
|