# Repository page header (not part of the program): ACMC · initial commit · 7e73556 · raw · history blame · 12.8 kB
import datasets
import datetime
import os
import json
import re
# One exported chat line looks like:
#   <month>/<day>/<year>, <hour>:<minute> - <contact name>: <message>
# The contact name is matched lazily so that a message which itself contains
# ": " is split at the FIRST separator instead of the last one.
# NOTE(review): an earlier comment described the date as dd/mm/yy, but the
# pattern has always parsed the first field as the month - confirm against
# the export locale.
exp = re.compile(
    r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+), (?P<hour>\d+):(?P<minute>\d+) - (?P<contact_name>.+?): (?P<message>.+)"
)


def process_line(example):
    """Parse one chat-export line into its message, sender and timestamp.

    Args:
        example: Mapping with a "text" key holding the raw line.

    Returns:
        A dict with keys "message", "contact_name" and "timestamp". The
        timestamp is a POSIX timestamp (float seconds) computed from the
        naive local date/time found in the line.

    Raises:
        ValueError: If the line does not match the expected format or the
            date/time fields are out of range. The error and the offending
            line are printed before the exception propagates.
    """
    text = example["text"]
    try:
        match = exp.match(text)
        if match is None:
            raise ValueError(
                f"Line does not match the expected chat format: {text!r}"
            )
        groups = match.groupdict()

        year = int(groups["year"])
        if year < 100:
            # Exports abbreviate the year to two digits; without this the
            # timestamp would be computed for year 23 AD instead of 2023.
            year += 2000

        timestamp = datetime.datetime(
            year,
            int(groups["month"]),
            int(groups["day"]),
            int(groups["hour"]),
            int(groups["minute"]),
        ).timestamp()
    except Exception as e:
        # Show which line broke the parse, then let the error propagate.
        print(e)
        print(text)
        raise

    return {
        "message": groups["message"],
        "contact_name": groups["contact_name"],
        "timestamp": timestamp,
    }
# %%
# Now, create message groups ('conversations')
# The idea is to group messages that are close in time
# We'll use a 240 minute threshold
MINUTES_THRESHOLD = 240


def group_messages(messages_iterable, minutes_threshold=None):
    """Split a chronological stream of messages into conversations.

    A message joins the current conversation when it arrives less than the
    threshold after the previous message; otherwise it starts a new one.

    Args:
        messages_iterable: Any iterable (list, iterator, dataset) of mappings
            that have a numeric "timestamp" key expressed in seconds.
        minutes_threshold: Maximum gap, in minutes, between two consecutive
            messages of the same conversation. Defaults to MINUTES_THRESHOLD.

    Returns:
        A list of conversations, each a non-empty list of messages in their
        original order. An empty input yields an empty list.
    """
    if minutes_threshold is None:
        minutes_threshold = MINUTES_THRESHOLD
    max_gap_seconds = minutes_threshold * 60

    groups = []
    current_group = []
    for message in messages_iterable:
        gap_too_large = (
            current_group
            and message["timestamp"] - current_group[-1]["timestamp"]
            >= max_gap_seconds
        )
        if gap_too_large:
            groups.append(current_group)
            current_group = [message]
        else:
            current_group.append(message)

    # Flush the trailing conversation; skipped when the input was empty.
    if current_group:
        groups.append(current_group)
    return groups
def printable_conversation(conversation):
    """Render a conversation as text, one "<contact_name>: <message>" line per message."""
    rendered_lines = []
    for entry in conversation:
        rendered_lines.append(f"{entry['contact_name']}: {entry['message']}")
    return "\n".join(rendered_lines)
# %%
# Use spacy to spell check the messages
import spacy
import contextualSpellCheck
from spellchecker import SpellChecker
# Shared SpellChecker instance (from the `spellchecker` package); it is used by
# spell_check_conversation for word splitting and for corrections.
spell = SpellChecker()
# Alternative Spanish spaCy pipeline, currently disabled:
#nlp = spacy.load("es_core_news_sm")
# spaCy pipeline used by spell_check_conversation_spacy. Loaded once at import
# time, so importing this module requires the "en_core_web_sm" model.
nlp = spacy.load("en_core_web_sm")
def spell_check_conversation(conversation):
    """Spell-check every message of a conversation, in place.

    Each message is split into words by the module-level `spell` checker, each
    word is replaced by the checker's suggested correction when one exists,
    and the words are re-joined with single spaces.

    Args:
        conversation: Mapping with a "conversations" key holding a list of
            message dicts, each with a "message" string.

    Returns:
        The same `conversation` object, with every "message" rewritten.

    NOTE(review): the message is rebuilt from the split words only, so
    anything the splitter drops does not survive - confirm this is intended.
    """
    for i, message in enumerate(conversation["conversations"]):
        # Tokenise with the spell checker's own splitter (spaCy is not used here).
        words = spell.split_words(message["message"])
        print(f"Words: {words}")
        corrected_message = []
        for word in words:
            correction = spell.correction(word)
            # Keep the original word when no correction is offered (None)
            # or when the suggestion is the word itself.
            if correction is not None and correction != word:
                print(f"Spell check: {word} -> {correction}")
                corrected_message.append(correction)
            else:
                corrected_message.append(word)
        print(f"Corrected message: {corrected_message}")
        joined_message = " ".join(corrected_message)
        conversation["conversations"][i]["message"] = joined_message
    return conversation
def spell_check_conversation_spacy(conversation):
    """Spell-check a conversation with the contextual spell checker, in place.

    Args:
        conversation: Mapping with a "conversations" key holding a list of
            message dicts, each with a "message" string.

    Returns:
        The same `conversation` object; messages for which a spell check was
        performed are replaced by the checker's outcome.
    """
    pipe_name = "contextual spellchecker"
    # Register the component only once: adding a component whose name is
    # already in the pipeline raises, which made every call after the first
    # one fail.
    if not nlp.has_pipe(pipe_name):
        nlp.add_pipe(
            pipe_name,
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )

    texts = [msg["message"] for msg in conversation["conversations"]]
    for i, doc in enumerate(nlp.pipe(texts)):
        if doc._.performed_spellCheck:
            print(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
            conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
    return conversation
def remove_whatapp_annotations(conversation):
    """
    Strip WhatsApp-inserted markers from every message of a conversation.

    The messages are modified in place and the same object is returned.
    Markers currently removed:
    - <This message was edited>
    """
    edited_marker = re.compile(r"<This message was edited>")
    for entry in conversation["conversations"]:
        cleaned_text = edited_marker.sub("", entry["message"])
        entry["message"] = cleaned_text
    return conversation
# %%
"""
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages.
For example, if we have a conversation like this:
A: Hi
A: How are you?
B: Hi
B: I'm fine, thanks
A: I'm fine too
We'll reorder it to:
A: Hi
B: Hi
A: How are you?
B: I'm fine, thanks
A: I'm fine too
To do it, we'll use BERT (`bert-base-uncased`) with its next-sentence-prediction head. For each pair of adjacent messages from different contacts, we score both orderings (first -> second and second -> first). If the model rates the reversed ordering as clearly more likely than the original one, we swap the two messages.
"""
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction
import torch
# Tokenizer and next-sentence-prediction model are loaded once, at import
# time, from the same pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
# Move the model to the GPU whenever torch can see one.
if torch.cuda.is_available():
    model = model.cuda()
def swap_messages_if_needed(message1, message2):
    """Decide whether two consecutive messages read better in reverse order.

    The pair is only considered for swapping when the messages come from
    different contacts, were sent at most two minutes apart, and both contain
    at least three whitespace-separated words. Both orderings are then scored
    with the module-level next-sentence-prediction `model`.

    Args:
        message1: Earlier message dict ("contact_name", "message", "timestamp").
        message2: Later message dict with the same keys.

    Returns:
        A `(first, second)` tuple: `(message2, message1)` when the reversed
        order is judged clearly more likely, otherwise `(message1, message2)`.
    """
    # Messages from the same contact are never reordered.
    if message1["contact_name"] == message2["contact_name"]:
        return message1, message2

    # Timestamps are in seconds, so the gap can be compared directly; no
    # round-trip through datetime objects is needed.
    if message2["timestamp"] - message1["timestamp"] > 2 * 60:
        return message1, message2

    # Very short messages carry too little signal for the model.
    if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3:
        return message1, message2

    # Encode both orderings as one batch: row 0 is (1 -> 2), row 1 is (2 -> 1).
    # The whole encoding is forwarded (not just input_ids) so the model also
    # receives token_type_ids, which mark where the second sentence starts,
    # and the attention mask. Over-long pairs are truncated to the model limit.
    batch = tokenizer(
        [message1["message"], message2["message"]],
        [message2["message"], message1["message"]],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)

    # Column 0 of the logits is the "second sentence follows the first" class.
    probabilities = torch.softmax(outputs.logits, dim=1)

    # Swap only when the reversed order wins by a margin of more than 0.2.
    swap = (probabilities[0, 0] - probabilities[1, 0]).item() < -0.2
    if swap:
        print(f"YES Swapping messages: {message1['message']} <-> {message2['message']}")
        return message2, message1
    return message1, message2
def swap_messages_if_needed_in_conversation(conversation):
    """Reorder a conversation by swapping adjacent messages where appropriate.

    `conversation` is a sequence of message dicts. Conversations with two or
    fewer messages are returned unchanged (the same object). Otherwise a new
    list is built: the first two messages seed it, and each later message is
    compared with the current last message of the new list through
    swap_messages_if_needed, which may exchange the two.
    """
    # Nothing to reorder in very short conversations.
    if len(conversation) <= 2:
        return conversation

    # The opening pair is copied as-is, so the first message never moves.
    reordered = [conversation[0], conversation[1]]
    for index in range(2, len(conversation)):
        previous, following = swap_messages_if_needed(
            reordered[-1], conversation[index]
        )
        reordered[-1] = previous
        reordered.append(following)
    return reordered
# Small hand-written conversation for trying out the reordering by hand:
# B's answer appears before the question from A that it replies to.
test_conversation = [
    {"message": "Hola!", "contact_name": "A", "timestamp": 1},
    {
        "message": "Está todo bien, gracias por preguntar!",
        "contact_name": "B",
        "timestamp": 2,
    },
    {
        "message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
        "contact_name": "A",
        "timestamp": 3,
    },
]
# print(swap_messages_if_needed_in_conversation(test_conversation))
# %%
# Now, we'll train an mT5 model to generate the next message in a conversation
import os
# For the contact_name, rewrite everything that is not the owner's name
# ('Aldi' by default) to a single generic label ('Other' by default)
def rewrite_contact_name(conversation, own_name="Aldi", other_label="Other"):
    """Collapse every contact except the owner into one generic label, in place.

    Args:
        conversation: Mapping with a "conversations" key holding a list of
            message dicts, each with a "contact_name".
        own_name: Contact name that is kept untouched.
        other_label: Replacement written for every other contact name.

    Returns:
        The same `conversation` object, with contact names rewritten.
    """
    for message in conversation["conversations"]:
        if message["contact_name"] != own_name:
            message["contact_name"] = other_label
    return conversation
# %%
def process_chat_file(file, do_spelling_correction, do_reordering=False):
    """
    Process a chat file and return a dataset with the conversations.

    Steps: load the file line by line, keep only lines that look like chat
    messages, parse them, drop media placeholders, group messages into
    conversations, keep conversations with at least 10 messages, strip
    WhatsApp annotations, optionally spell-check and reorder, anonymise the
    contact names, and finally drop conversations with a single participant.

    Args:
        file: Path of the exported chat text file.
        do_spelling_correction: Whether to run spell_check_conversation.
        do_reordering: Whether to reorder near-simultaneous messages.

    Returns:
        A dataset with one "conversations" column; each row is a list of
        message dicts.
    """
    # Leave one core free, but never ask for fewer than one worker:
    # os.cpu_count() can be 1 (giving 0) or None (making the subtraction fail).
    worker_count = max(1, (os.cpu_count() or 1) - 1)

    ds = (
        datasets.load_dataset("text", data_files=[file])["train"]
        .filter(
            # Has to begin by date, time, contact name, and contain at least a ':' symbol
            lambda x: bool(
                re.match(
                    r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
                )
            )
        )
        .map(process_line, remove_columns=["text"])
    )

    # Filter out messages that just say '<Media omitted>'
    ds = ds.filter(lambda x: x["message"] != "<Media omitted>")

    groups = group_messages(iter(ds))

    # Generate the dataset
    conversations_ds = datasets.Dataset.from_dict({"conversations": groups})

    # Filter out conversations with less than 10 messages
    conversations_ds = conversations_ds.filter(lambda x: len(x["conversations"]) >= 10)

    conversations_ds_without_whatsapp_annotations = conversations_ds.map(
        remove_whatapp_annotations,
        num_proc=worker_count,
    )

    if do_spelling_correction:
        spell_checked_conversations_ds = (
            conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
        )
    else:
        spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations

    if do_reordering:
        # The reordering helper works on a plain list of messages, whereas
        # map() hands over the whole row. Passing the row directly made the
        # helper see a one-element mapping and return it untouched, so no
        # reordering ever took place.
        reordered_conversations_ds = spell_checked_conversations_ds.map(
            lambda x: {
                "conversations": swap_messages_if_needed_in_conversation(
                    x["conversations"]
                )
            }
        )
    else:
        reordered_conversations_ds = spell_checked_conversations_ds

    changed_contact_name_ds = reordered_conversations_ds.map(rewrite_contact_name)

    # Filter out conversations with only one contact
    changed_contact_name_ds = changed_contact_name_ds.filter(
        lambda x: len({msg["contact_name"] for msg in x["conversations"]}) > 1
    )
    return changed_contact_name_ds
def transform_conversations_dataset_into_training_examples(
    conversations_ds, system_prompt
):
    """
    Takes in a dataset with conversations and returns a dataset with training examples.

    The input dataset contains a single column (conversations), with each row being a list of messages with this format:
    ```
    [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
    ```

    Each row will be converted to fit the format of the training examples.

    The training examples have the following format:
    ```
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "William Shakespeare"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "384,400 kilometers"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
    ```

    Messages from 'Aldi' get the "assistant" role and all others the "user"
    role. Consecutive messages with the same role are merged into one turn,
    and each turn's "content" is the JSON-encoded list of its message strings
    (the system prompt is encoded the same way, as a one-element list).

    Args:
        conversations_ds: Dataset with a "conversations" column as described above.
        system_prompt: Text of the system turn that opens every example.

    Returns:
        A dataset with a single "messages" column.
    """

    def process_one_example(example):
        # Turn contents are accumulated as lists and serialised at the end.
        messages = [{"role": "system", "content": [system_prompt]}]
        for msg in example["conversations"]:
            converted_role = "assistant" if msg["contact_name"] == "Aldi" else "user"
            if converted_role == messages[-1]["role"]:
                # Same speaker as the previous turn: extend that turn.
                messages[-1]["content"].append(msg["message"])
            else:
                messages.append({"role": converted_role, "content": [msg["message"]]})
        return {
            "messages": [
                {
                    "role": m["role"],
                    "content": json.dumps(m["content"], ensure_ascii=False),
                }
                for m in messages
            ]
        }

    # Leave one core free, but never ask for fewer than one worker:
    # os.cpu_count() can be 1 (giving 0) or None (making the subtraction fail).
    worker_count = max(1, (os.cpu_count() or 1) - 1)

    return conversations_ds.map(
        process_one_example,
        remove_columns=["conversations"],
        num_proc=worker_count,
    )