ACMC
Bugfix
a8dfddd
raw
history blame
20.8 kB
import datetime
import json
import logging
import os
import re
import datasets
import dateutil.parser
# Module-level logger for this file. The level is set to INFO explicitly so the
# logger.info(...) progress messages below (spell-check corrections, swaps) are
# not filtered out at this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# %%
def group_messages(messages_iterable, minutes_threshold):
    """
    Split a time-ordered stream of messages into conversations.

    Two consecutive messages stay in the same group when their ``timestamp``
    values differ by less than ``minutes_threshold`` minutes; a larger gap
    starts a new group.

    Args:
        messages_iterable: Any iterable (list, generator, ``iter(dataset)``) of
            dicts with a numeric ``"timestamp"`` key expressed in seconds.
        minutes_threshold: Maximum gap, in minutes, between two consecutive
            messages for them to be kept in the same group.

    Returns:
        A list of groups. Each group is a non-empty list of the original message
        dicts, in input order. An empty input yields ``[]``.
    """
    # iter() makes lists work too; it is a no-op for objects that are already iterators.
    messages = iter(messages_iterable)
    try:
        first_message = next(messages)
    except StopIteration:
        # An empty chat is an expected situation, not an error worth a traceback.
        logger.warning("No messages in the conversation")
        return []

    threshold_seconds = minutes_threshold * 60
    groups = []
    current_group = [first_message]
    for message in messages:
        gap = message["timestamp"] - current_group[-1]["timestamp"]
        if gap < threshold_seconds:
            current_group.append(message)
        else:
            groups.append(current_group)
            current_group = [message]
    groups.append(current_group)
    return groups
def printable_conversation(conversation):
    """
    Render a conversation as text, one ``"<contact_name>: <message>"`` line per message.

    Args:
        conversation: Sequence of dicts with ``contact_name`` and ``message`` keys.

    Returns:
        The lines joined with newlines (an empty string for an empty conversation).
    """
    rendered_lines = []
    for entry in conversation:
        rendered_lines.append(f"{entry['contact_name']}: {entry['message']}")
    return "\n".join(rendered_lines)
import contextualSpellCheck
# %%
# Spell checking of messages: word-level with pyspellchecker, or contextual (BERT-based) with spaCy + contextualSpellCheck
import spacy
from spellchecker import SpellChecker
# Word-level spell checker used by spell_check_conversation.
spell = SpellChecker()
# spaCy pipeline used by spell_check_conversation_spacy; loaded once, at import time.
# Alternative model (presumably the Spanish one) kept commented out:
# nlp = spacy.load("es_core_news_sm")
nlp = spacy.load("en_core_web_sm")
def spell_check_conversation(conversation):
    """
    Spell-correct, in place, every message of a conversation row using pyspellchecker.

    Each message is tokenized with ``spell.split_words``. Every word for which
    ``spell.correction`` returns a non-None candidate different from the word
    is replaced by that candidate. The resulting words are re-joined with single
    spaces, so the original spacing is lost, as is anything the tokenizer
    discards (presumably punctuation; confirm against the installed
    pyspellchecker version).

    Args:
        conversation: Row dict with a ``"conversations"`` list of message dicts
            (each having a ``"message"`` string), as passed by ``Dataset.map``.

    Returns:
        The same row dict, with the ``"message"`` fields rewritten.
    """
    for message in conversation["conversations"]:
        # Tokenization is done by pyspellchecker, not spaCy.
        words = spell.split_words(message["message"])
        logger.info(f"Words: {words}")
        corrected_message = []
        for word in words:
            correction = spell.correction(word)
            # correction() returns None when it has no candidate for the word.
            if correction is not None and correction != word:
                logger.info(f"Spell check: {word} -> {correction}")
                corrected_message.append(correction)
            else:
                corrected_message.append(word)
        logger.info(f"Corrected message: {corrected_message}")
        message["message"] = " ".join(corrected_message)
    return conversation
def spell_check_conversation_spacy(conversation):
    """
    Spell-correct, in place, every message of a conversation row with contextualSpellCheck.

    On the first call, the ``"contextual spellchecker"`` component is added to the
    module-level ``nlp`` pipeline, configured with the
    ``bert-base-multilingual-uncased`` model and a maximum edit distance of 2.
    The component then stays in the shared pipeline for every later call.

    Args:
        conversation: Row dict with a ``"conversations"`` list of message dicts
            (each having a ``"message"`` string), as passed by ``Dataset.map``.

    Returns:
        The same row dict. A message is replaced only when the component
        reports that it performed a spell check.
    """
    # nlp is a module-level global shared by all calls. Adding the component again
    # would fail because a component with this name is already in the pipeline.
    if "contextual spellchecker" not in nlp.pipe_names:
        nlp.add_pipe(
            "contextual spellchecker",
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )
    messages = conversation["conversations"]
    docs = list(nlp.pipe([msg["message"] for msg in messages]))
    for message, doc in zip(messages, docs):
        if doc._.performed_spellCheck:
            logger.info(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
            message["message"] = doc._.outcome_spellCheck
    return conversation
def remove_whatsapp_annotations(conversation):
    """
    Strip annotations that WhatsApp inserts into message text.

    Currently only the ``<This message was edited>`` marker is removed. The
    message dicts are modified in place, and the same row dict is returned so
    the function can be used directly with ``Dataset.map``.
    """
    edited_marker = r"<This message was edited>"
    for entry in conversation["conversations"]:
        entry["message"] = re.sub(edited_marker, "", entry["message"])
    return conversation
# %%
"""
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages.
For example, if we have a conversation like this:
A: Hi
A: How are you?
B: Hi
B: I'm fine, thanks
A: I'm fine too
We'll reorder it to:
A: Hi
B: Hi
A: How are you?
B: I'm fine, thanks
A: I'm fine too
To do it, we'll use BERT (`bert-base-uncased`) with the next sentence prediction head. We score both orders of a pair of consecutive messages (first message followed by the second, and second followed by the first). If the model finds the reversed order more plausible by a clear margin, we swap the messages.
"""
import torch
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer
# BERT tokenizer and next-sentence-prediction model used by swap_messages_if_needed.
# Both are loaded once, at import time.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
# Put the model on the GPU when one is available; model inputs must then be moved
# to the same device before each forward pass.
if torch.cuda.is_available():
    model.cuda()
def swap_messages_if_needed(
    message1,
    message2,
    max_seconds_apart=2 * 60,
    min_words=3,
    swap_margin=0.2,
):
    """
    Swap two consecutive messages when BERT's NSP head prefers the reversed order.

    The model is only consulted when all of these hold:
    - the messages come from different contacts;
    - message2 was sent at most ``max_seconds_apart`` seconds after message1;
    - both messages have at least ``min_words`` whitespace-separated words.

    Args:
        message1: Earlier message dict (``contact_name``, ``message``, ``timestamp`` in seconds).
        message2: Later message dict, same shape.
        max_seconds_apart: Largest timestamp gap for which a swap is considered.
        min_words: Minimum word count each message needs to be scored.
        swap_margin: How much more likely (in NSP "is next" probability) the
            reversed order must be than the original order to trigger a swap.

    Returns:
        ``(message2, message1)`` if a swap is warranted, otherwise ``(message1, message2)``.
    """
    if message1["contact_name"] == message2["contact_name"]:
        return message1, message2
    # Compare raw timestamps. Converting to naive local datetimes first can skew
    # the gap by an hour when the two messages straddle a DST change.
    if message2["timestamp"] - message1["timestamp"] > max_seconds_apart:
        return message1, message2
    if (
        len(message1["message"].split()) < min_words
        or len(message2["message"].split()) < min_words
    ):
        return message1, message2

    text1, text2 = message1["message"], message2["message"]
    # One batch: row 0 scores (1 -> 2) and row 1 scores (2 -> 1). All tokenizer
    # outputs are forwarded. token_type_ids matter because they tell the NSP head
    # which tokens belong to sentence A and which to sentence B.
    encoded = tokenizer([text1, text2], [text2, text1], return_tensors="pt", padding=True)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.softmax(logits, dim=1)
    # Column 0 of the NSP head is the "sentence B follows sentence A" class.
    forward_probability = probabilities[0, 0]
    backward_probability = probabilities[1, 0]
    if forward_probability - backward_probability < -swap_margin:
        logger.info(
            f"Swapping messages: {message1['message']} <-> {message2['message']}"
        )
        return message2, message1
    return message1, message2
def swap_messages_if_needed_in_conversation(conversation):
    """
    Reorder interleaved messages of one conversation with ``swap_messages_if_needed``.

    The first message never moves. Every later message is compared with the
    message currently placed just before it, and the two may be swapped.
    Conversations with two or fewer messages are returned unchanged.

    Args:
        conversation: Either a list of message dicts, or a row mapping of the
            form ``{"conversations": [...]}`` as passed by ``Dataset.map``.

    Returns:
        For a list, the reordered list (or the same list when it has at most two
        messages). For a row mapping, a dict copy of the row whose
        ``"conversations"`` entry is reordered.
    """
    from collections.abc import Mapping

    if isinstance(conversation, Mapping):
        # Dataset.map hands over a row, not the message list. Measuring len(row)
        # (the number of columns) used to make every row look "too short", so
        # reordering silently never ran.
        row = dict(conversation)
        row["conversations"] = swap_messages_if_needed_in_conversation(
            row["conversations"]
        )
        return row

    if len(conversation) <= 2:
        return conversation
    new_conversation = [conversation[0], conversation[1]]
    for message in conversation[2:]:
        previous, current = swap_messages_if_needed(new_conversation[-1], message)
        new_conversation[-1] = previous
        new_conversation.append(current)
    return new_conversation
# Small hand-written sample (two speakers, one second apart) used to try out
# swap_messages_if_needed_in_conversation manually.
test_conversation = [
    {"message": text, "contact_name": speaker, "timestamp": second}
    for second, (speaker, text) in enumerate(
        [
            ("A", "Hola!"),
            ("B", "Está todo bien, gracias por preguntar!"),
            ("A", "Hola, qué tal estás? Espero que vaya todo bien por España."),
        ],
        start=1,
    )
]
# logger.info(swap_messages_if_needed_in_conversation(test_conversation))
# %%
# Now, we'll train an mT5 model to generate the next message in a conversation
import os
# %%
def _run_step(error_message, step):
    """
    Run one processing step; on failure, log and re-raise with a readable message.

    Args:
        error_message: Text used both for the log entry and for the raised exception.
        step: Zero-argument callable performing the step.

    Returns:
        Whatever ``step()`` returns.

    Raises:
        Exception: With ``error_message``, chained (``from``) to the original error.
    """
    try:
        return step()
    except Exception as e:
        logger.exception(error_message)
        raise Exception(error_message) from e


def _map_num_proc():
    """Worker count for parallel ``Dataset.map`` calls: all CPUs but one, never below 1."""
    # os.cpu_count() may return None, and cpu_count() - 1 is 0 on single-core
    # machines, which is not a valid number of processes.
    return max(1, (os.cpu_count() or 1) - 1)


def process_chat_file(
    file,
    do_spelling_correction,
    whatsapp_name,
    datetime_dayfirst,
    message_line_format,
    minutes_threshold,
    min_messages_per_conversation,
    do_reordering=False,
):
    """
    Process a WhatsApp chat export and return a dataset of conversations.

    Steps:
    1. Read the file line by line and parse each line with ``message_line_format``.
    2. Check that ``whatsapp_name`` appears among the parsed contact names.
    3. Drop "<Media omitted>" messages.
    4. Group messages into conversations by time gap.
    5. Drop conversations that are too short.
    6. Strip WhatsApp annotations.
    7. Optionally spell-check and reorder interleaved messages.
    8. Rename every contact other than ``whatsapp_name`` to "Other".
    9. Drop conversations in which only one side speaks.

    Args:
        file: Path of the exported chat text file.
        do_spelling_correction: Apply ``spell_check_conversation`` to every conversation.
        whatsapp_name: The user's own contact name; it must appear in the file.
        datetime_dayfirst: Forwarded to ``dateutil.parser.parse(dayfirst=...)``.
        message_line_format: Regex with named groups ``msg_datetime``,
            ``contact_name`` and ``message``, matched at the start of each line.
        minutes_threshold: Maximum gap, in minutes, between two messages of the
            same conversation.
        min_messages_per_conversation: Conversations with fewer messages are dropped.
        do_reordering: Reorder interleaved messages with the BERT-based
            ``swap_messages_if_needed_in_conversation``.

    Returns:
        A ``datasets.Dataset`` with a single ``conversations`` column. Each row is
        a list of dicts with ``message``, ``contact_name`` and ``timestamp`` keys.

    Raises:
        Exception: If any step fails (chained to the original error), or if
            ``whatsapp_name`` is not among the parsed contact names.
    """
    # Compiled once and shared by every batch handled by process_line.
    exp = re.compile(message_line_format)

    def process_line(examples):
        """
        Parse a batch of raw lines into message / contact_name / timestamp columns.

        Lines that do not match ``exp``, or whose fields cannot be parsed, are
        skipped, so the three returned lists always have the same length.
        """
        messages = []
        contact_names = []
        timestamps = []
        for line_text in examples["text"]:
            match = exp.match(line_text)
            if match is None:
                # Not the start of a message (presumably a continuation line of a
                # multi-line message or a system notice; confirm). Such lines are
                # dropped. No traceback is logged: this is an expected situation.
                logger.warning(
                    f"Skipping line that does not match the expected format: {line_text}"
                )
                continue
            try:
                groups = match.groupdict()
                timestamp = dateutil.parser.parse(
                    groups["msg_datetime"], dayfirst=datetime_dayfirst
                ).timestamp()
                message = groups["message"]
                contact_name = groups["contact_name"]
            except Exception:
                logger.exception(f"Error while processing line {line_text}")
                continue
            # Append only once every field has been parsed, keeping the columns aligned.
            messages.append(message)
            contact_names.append(contact_name)
            timestamps.append(timestamp)
        return {
            "message": messages,
            "contact_name": contact_names,
            "timestamp": timestamps,
        }

    lines_ds = _run_step(
        f"Error while loading file {file}",
        lambda: datasets.load_dataset("text", data_files=[file])["train"],
    )
    messages_ds = _run_step(
        f"Error mapping the lines in file {file} to the expected format",
        lambda: lines_ds.map(
            process_line, remove_columns=["text"], batched=True, batch_size=10
        ),
    )

    # The user's own name must be present, otherwise roles cannot be assigned later.
    set_of_contact_names = messages_ds.unique("contact_name")
    if whatsapp_name not in set_of_contact_names:
        raise Exception(
            f"Your WhatsApp name ({whatsapp_name}) is not in the messages of at least one uploaded file. Please check that you wrote your name correctly. These were the participants found: {set_of_contact_names}"
        )

    media_filtered_ds = _run_step(
        f"Error filtering out messages that say '<Media omitted>' in file {file}",
        lambda: messages_ds.filter(lambda x: x["message"] != "<Media omitted>"),
    )
    conversations_ds = _run_step(
        f"Error grouping the messages in file {file}",
        lambda: datasets.Dataset.from_dict(
            {
                "conversations": group_messages(
                    iter(media_filtered_ds), minutes_threshold=minutes_threshold
                )
            }
        ),
    )
    # Keep only conversations with at least min_messages_per_conversation messages.
    long_conversations_ds = _run_step(
        f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}",
        lambda: conversations_ds.filter(
            lambda x: len(x["conversations"]) >= min_messages_per_conversation
        ),
    )
    cleaned_ds = _run_step(
        f"Error removing WhatsApp annotations in file {file}",
        lambda: long_conversations_ds.map(
            remove_whatsapp_annotations,
            num_proc=_map_num_proc(),
        ),
    )

    if do_spelling_correction:
        spell_checked_ds = _run_step(
            f"Error spell checking the conversations in file {file}",
            lambda: cleaned_ds.map(spell_check_conversation),
        )
    else:
        spell_checked_ds = cleaned_ds

    if do_reordering:
        # Unwrap the row so the reordering helper receives the message list itself.
        # Passing the whole row made it measure the number of columns, not messages,
        # so nothing was ever reordered.
        reordered_ds = _run_step(
            f"Error reordering the conversations in file {file}",
            lambda: spell_checked_ds.map(
                lambda row: {
                    "conversations": swap_messages_if_needed_in_conversation(
                        row["conversations"]
                    )
                }
            ),
        )
    else:
        reordered_ds = spell_checked_ds

    def rewrite_contact_name(conversation):
        """Replace, in place, every contact name other than ``whatsapp_name`` with "Other"."""
        for message in conversation["conversations"]:
            if message["contact_name"] != whatsapp_name:
                message["contact_name"] = "Other"
        return conversation

    changed_contact_name_ds = _run_step(
        f"Error changing your other contact's names in file {file}",
        lambda: reordered_ds.map(rewrite_contact_name),
    )
    # After renaming, more than one distinct name means both the user and "Other" speak.
    return _run_step(
        f"Error filtering out conversations with only one contact in file {file}",
        lambda: changed_contact_name_ds.filter(
            lambda x: len({msg["contact_name"] for msg in x["conversations"]}) > 1
        ),
    )
def _serialize_turns(turns):
    """
    Build one training example from a list of accumulated turns.

    Each turn is ``{"role": ..., "content": [message, ...]}``. Its message list is
    stored as a JSON array string, with ``ensure_ascii=False`` so non-ASCII
    characters (e.g. accents) stay unescaped; for example ``'["Hola!", "qué tal?"]'``.
    """
    return {
        "messages": [
            {
                "role": turn["role"],
                "content": json.dumps(turn["content"], ensure_ascii=False),
            }
            for turn in turns
        ]
    }


def transform_conversations_dataset_into_training_examples(
    conversations_ds,
    system_prompt,
    user_role,
    model_role,
    whatsapp_name,
    minutes_threshold,
    min_messages_per_conversation,
    split_conversation_threshold,
    max_characters_per_message,
):
    """
    Turn a dataset of conversations into chat-format training examples.

    Input: a dataset with a single ``conversations`` column. Each row is a list of
    messages with this format:
    ```
    [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
    ```

    Role assignment and turn merging:
    - Messages written by ``whatsapp_name`` get ``model_role``; all others get ``user_role``.
    - Consecutive messages with the same role are merged into one turn.
    - Every example starts with a ``system`` turn holding ``system_prompt``.
    - Each turn's ``content`` is a JSON array string of its messages.

    An output row therefore looks like:
    ```
    {"messages": [{"role": "system", "content": '["<system_prompt>"]'}, {"role": <user_role>, "content": '["Hi", "How are you?"]'}, {"role": <model_role>, "content": '["Fine, thanks!"]'}]}
    ```

    Splitting and filtering:
    - A conversation is cut into several examples once more than
      ``split_conversation_threshold`` messages have been added to the current
      example and the next message is from the user side.
    - The final chunk of each conversation is kept only if it has at least
      ``min_messages_per_conversation`` turns, counting the system turn.
    - Examples in which any serialized turn is ``max_characters_per_message``
      characters or longer are dropped.

    Args:
        conversations_ds: ``datasets.Dataset`` with a ``conversations`` column.
        system_prompt: Text of the leading system turn.
        user_role: Role name for messages not written by ``whatsapp_name``.
        model_role: Role name for messages written by ``whatsapp_name``.
        whatsapp_name: The user's own contact name.
        minutes_threshold: Not used by this function; accepted for interface compatibility.
        min_messages_per_conversation: Minimum number of turns of a final chunk.
        split_conversation_threshold: Message count after which a conversation is split.
        max_characters_per_message: Exclusive upper bound on the length of each
            serialized turn (the JSON string, brackets and quotes included).

    Returns:
        A ``datasets.Dataset`` with a single ``messages`` column.

    Raises:
        Exception: If the transformation or the length filter fails (chained to
            the original error).
    """

    def process_examples(examples):
        """Batched map function: conversations column in, messages column out."""
        processed_examples = []
        for conversation in examples["conversations"]:
            messages = [{"role": "system", "content": [system_prompt]}]
            counter = 0
            for msg in conversation:
                converted_role = (
                    model_role if msg["contact_name"] == whatsapp_name else user_role
                )
                if counter > split_conversation_threshold and converted_role == user_role:
                    # NOTE(review): this can split in the middle of a run of user
                    # messages, so the emitted chunk may end on a user turn. Chunks
                    # emitted here also skip the min_messages_per_conversation check.
                    # Confirm this is intended.
                    processed_examples.append(_serialize_turns(messages))
                    messages = [{"role": "system", "content": [system_prompt]}]
                    counter = 0
                if converted_role == messages[-1]["role"]:
                    # Same speaker as the previous turn: merge into that turn.
                    messages[-1]["content"].append(msg["message"])
                else:
                    messages.append({"role": converted_role, "content": [msg["message"]]})
                counter += 1
            if len(messages) >= min_messages_per_conversation:
                processed_examples.append(_serialize_turns(messages))
            else:
                logger.warning(
                    f"Discarding conversation because the length is not at least {min_messages_per_conversation}: {messages}"
                )
        if not processed_examples:
            logger.warning(
                f"Discarding all conversations in this batch because none of them have at least {min_messages_per_conversation} messages"
            )
            # Returning {} would leave this batch without a "messages" column and
            # can break the output schema. An empty column yields zero rows cleanly.
            return {"messages": []}
        # Batched map expects a dict of columns, not a list of rows.
        return {"messages": [example["messages"] for example in processed_examples]}

    try:
        processed_examples = conversations_ds.map(
            process_examples,
            remove_columns=["conversations"],
            batched=True,
        )
    except Exception as e:
        logger.exception(
            "Error transforming the conversations dataset into training examples"
        )
        raise Exception(
            "Error transforming the conversations dataset into training examples"
        ) from e
    try:
        # Lengths are measured on the serialized (JSON string) turn content.
        examples_filtered_by_length = processed_examples.filter(
            lambda x: all(
                len(m["content"]) < max_characters_per_message for m in x["messages"]
            )
        )
    except Exception as e:
        logger.exception(
            "Error filtering out examples with messages longer than the maximum allowed"
        )
        raise Exception(
            "Error filtering out examples with messages longer than the maximum allowed"
        ) from e
    return examples_filtered_by_length