|
import datasets |
|
import datetime |
|
import os |
|
import json |
|
|
|
import re |
|
|
|
# Matches one exported chat line of the form
#   "<month>/<day>/<year>, <hour>:<minute> - <contact name>: <message>"
# The contact name is matched lazily so that a message which itself contains
# ": " is split at the FIRST separator rather than the last one.
exp = re.compile(
    r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+), (?P<hour>\d+):(?P<minute>\d+)"
    r" - (?P<contact_name>.+?): (?P<message>.+)"
)


def process_line(example):
    """Parse one raw chat line into its message, sender and timestamp.

    Args:
        example: Mapping with a "text" entry holding a single chat line.

    Returns:
        A dict with the keys "message", "contact_name" and "timestamp".
        The timestamp is the value of ``datetime.timestamp()`` for the naive
        date/time written in the line.

    Raises:
        ValueError: If the line does not match ``exp`` or its date/time
            fields do not form a valid date. The error and the offending
            line are printed before the exception propagates.
    """
    text = example["text"]
    try:
        match = exp.match(text)
        if match is None:
            raise ValueError("line does not match the expected chat format")
        groups = match.groupdict()

        year = int(groups["year"])
        # Two-digit years ("23") would otherwise be read as year 23 AD;
        # they are interpreted as 20xx.
        if year < 100:
            year += 2000

        timestamp = datetime.datetime(
            year,
            int(groups["month"]),
            int(groups["day"]),
            int(groups["hour"]),
            int(groups["minute"]),
        ).timestamp()
    except Exception as e:
        print(e)
        print(text)
        raise

    return {
        "message": groups["message"],
        "contact_name": groups["contact_name"],
        "timestamp": timestamp,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default maximum gap, in minutes, between two consecutive messages that still
# belong to the same conversation.
MINUTES_THRESHOLD = 240


def group_messages(messages_iterable, minutes_threshold=MINUTES_THRESHOLD):
    """Split a chronological stream of messages into conversations.

    A new conversation starts whenever the gap between a message and the
    previous one is at least ``minutes_threshold`` minutes.

    Args:
        messages_iterable: Any iterable (iterator, list, dataset, ...) of
            mappings that have a numeric "timestamp" entry in seconds.
        minutes_threshold: Gap, in minutes, at which a new conversation
            begins. Defaults to ``MINUTES_THRESHOLD``.

    Returns:
        A list of conversations, each a non-empty list of messages in their
        original order. An empty input yields an empty list.
    """
    max_gap_seconds = minutes_threshold * 60
    groups = []
    current_group = []
    for message in messages_iterable:
        if (
            current_group
            and message["timestamp"] - current_group[-1]["timestamp"]
            >= max_gap_seconds
        ):
            # Gap too large: close the running conversation and start anew.
            groups.append(current_group)
            current_group = []
        current_group.append(message)
    if current_group:
        groups.append(current_group)
    return groups
|
|
|
|
|
def printable_conversation(conversation):
    """Render a conversation as text, one "<contact name>: <message>" line per message."""
    rendered_lines = []
    for entry in conversation:
        speaker = entry["contact_name"]
        text = entry["message"]
        rendered_lines.append(f"{speaker}: {text}")
    return "\n".join(rendered_lines)
|
|
|
|
|
|
|
|
|
import spacy |
|
import contextualSpellCheck |
|
from spellchecker import SpellChecker |
|
# Shared SpellChecker instance; spell_check_conversation uses it to split
# messages into words and to look up a correction for each word.
spell = SpellChecker()



# Shared spaCy pipeline loaded from the "en_core_web_sm" package;
# spell_check_conversation_spacy adds a pipe to it and runs messages through it.
nlp = spacy.load("en_core_web_sm")
|
|
|
|
|
def spell_check_conversation(conversation):
    """Spell-check every message of a conversation word by word, in place.

    Each message is split into words with the module-level ``spell`` checker,
    every word is replaced by the checker's correction when one exists, and
    the message is rebuilt by joining the words with single spaces. Anything
    that ``spell.split_words`` does not return is therefore not part of the
    rewritten message.

    Args:
        conversation: Mapping whose "conversations" entry is a list of
            message mappings, each with a "message" string.

    Returns:
        The same ``conversation`` object, with its messages rewritten.
    """
    for message in conversation["conversations"]:
        words = spell.split_words(message["message"])
        print(f"Words: {words}")

        corrected_message = []
        for word in words:
            correction = spell.correction(word)
            # The original word is kept when the checker offers no correction.
            if correction is not None and correction != word:
                print(f"Spell check: {word} -> {correction}")
                corrected_message.append(correction)
            else:
                corrected_message.append(word)

        print(f"Corrected message: {corrected_message}")
        message["message"] = " ".join(corrected_message)

    return conversation
|
|
|
|
|
def spell_check_conversation_spacy(conversation):
    """Spell-check every message of a conversation with the spaCy pipeline, in place.

    The "contextual spellchecker" component is added to the module-level
    ``nlp`` pipeline the first time it is needed. Messages for which the
    component reports a performed spell check are replaced by its outcome.

    Args:
        conversation: Mapping whose "conversations" entry is a list of
            message mappings, each with a "message" string.

    Returns:
        The same ``conversation`` object, with corrected messages.
    """
    # Adding a component under a name that is already in the pipeline is an
    # error, so register it only once even though this runs per conversation.
    if not nlp.has_pipe("contextual spellchecker"):
        nlp.add_pipe(
            "contextual spellchecker",
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )

    messages = conversation["conversations"]
    docs = list(nlp.pipe([msg["message"] for msg in messages]))
    for message, doc in zip(messages, docs):
        if doc._.performed_spellCheck:
            print(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
            message["message"] = doc._.outcome_spellCheck

    return conversation
|
|
|
|
|
def remove_whatapp_annotations(conversation):
    """
    Strip WhatsApp annotations from every message of a conversation, in place.

    The following annotation is removed:
    - <This message was edited>

    Returns the same conversation object that was passed in.
    """
    edited_marker = r"<This message was edited>"
    for entry in conversation["conversations"]:
        original_text = entry["message"]
        entry["message"] = re.sub(edited_marker, "", original_text)
    return conversation
|
|
|
|
|
|
|
""" |
|
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages. |
|
For example, if we have a conversation like this: |
|
A: Hi |
|
A: How are you? |
|
B: Hi |
|
B: I'm fine, thanks |
|
A: I'm fine too |
|
We'll reorder it to: |
|
A: Hi |
|
B: Hi |
|
A: How are you? |
|
B: I'm fine, thanks |
|
A: I'm fine too |
|
|
|
To do it, we'll use BERT (`bert-base-uncased`) with the next sentence prediction head. For each pair of consecutive messages written by different people, we score both orderings: the first message followed by the second, and the second followed by the first. If the model finds the reversed ordering clearly more likely to be a real continuation, we swap the two messages.
|
""" |
|
|
|
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction |
|
import torch |
|
|
|
# Tokenizer and next-sentence-prediction model used by swap_messages_if_needed;
# both are loaded from the "bert-base-uncased" checkpoint at import time.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")

# Run the model on the GPU when CUDA is available.
if torch.cuda.is_available():
    model.cuda()
|
|
|
|
|
def swap_messages_if_needed(message1, message2):
    """Decide whether two consecutive messages should trade places.

    The pair is left untouched when both messages come from the same contact,
    when they are more than two minutes apart, or when either has fewer than
    three whitespace-separated words. Otherwise the module-level next sentence
    prediction ``model`` scores both orderings, and the messages are swapped
    when the reversed ordering's "is next" probability exceeds the original
    ordering's by more than 0.2.

    Args:
        message1: The earlier message; a mapping with "contact_name",
            "message" and "timestamp" (seconds) entries.
        message2: The later message, same layout.

    Returns:
        A ``(first, second)`` tuple holding the two messages in the order
        they should appear.
    """
    # Consecutive messages from the same person are never reordered.
    if message1["contact_name"] == message2["contact_name"]:
        return message1, message2

    # Only messages sent within two minutes of each other are candidates.
    # The timestamps are compared directly, without a local-time round trip.
    if message2["timestamp"] - message1["timestamp"] > 2 * 60:
        return message1, message2

    # Very short messages are skipped.
    if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3:
        return message1, message2

    # Encode both orderings as one batch of sentence pairs. Passing the whole
    # encoding (not only input_ids) keeps the segment ids that tell the model
    # where the first message ends and the second begins.
    encoded = tokenizer(
        [message1["message"], message2["message"]],
        [message2["message"], message1["message"]],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    device = model.device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits

    # Row 0 is (message1, message2), row 1 is (message2, message1); column 0
    # is the "second sentence continues the first" class.
    probabilities = torch.softmax(logits, dim=1)
    swap = (probabilities[0, 0] - probabilities[1, 0]).item() < -0.2

    if swap:
        print(f"YES Swapping messages: {message1['message']} <-> {message2['message']}")
        return message2, message1
    return message1, message2
|
|
|
|
|
def swap_messages_if_needed_in_conversation(conversation):
    """Run one left-to-right reordering pass over a conversation.

    Conversations with at most two messages are returned as-is. For longer
    ones the first two messages are copied unchanged; every later message is
    then compared with the current last message of the result through
    swap_messages_if_needed, which decides the order of that pair.

    Takes an indexable sequence of messages and returns either that same
    object (short input) or a new list with the reordered messages.
    """
    message_count = len(conversation)
    if message_count > 2:
        reordered = [conversation[0], conversation[1]]
        for position in range(2, message_count):
            earlier, later = swap_messages_if_needed(
                reordered[-1], conversation[position]
            )
            reordered[-1] = earlier
            reordered.append(later)
        return reordered
    return conversation
|
|
|
|
|
# Small hand-written sample conversation (plain list of message dicts) in
# which B's reply precedes A's question.
# NOTE(review): presumably meant for trying out
# swap_messages_if_needed_in_conversation by hand; it is not referenced
# anywhere else in the visible source.
test_conversation = [
    {"message": "Hola!", "contact_name": "A", "timestamp": 1},
    {
        "message": "Está todo bien, gracias por preguntar!",
        "contact_name": "B",
        "timestamp": 2,
    },
    {
        "message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
        "contact_name": "A",
        "timestamp": 3,
    },
]
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
def rewrite_contact_name(conversation, own_name="Aldi", other_name="Other"):
    """Collapse every contact except one into a single generic name, in place.

    Args:
        conversation: Mapping whose "conversations" entry is a list of
            message mappings, each with a "contact_name" entry.
        own_name: Contact name that is left untouched. Defaults to "Aldi".
        other_name: Name given to every other contact. Defaults to "Other".

    Returns:
        The same ``conversation`` object, with contact names rewritten.
    """
    for message in conversation["conversations"]:
        if message["contact_name"] != own_name:
            message["contact_name"] = other_name
    return conversation
|
|
|
|
|
|
|
def process_chat_file(file, do_spelling_correction, do_reordering=False):
    """
    Process a chat file and return a dataset with the conversations.

    Steps, in order:
      1. Load the file line by line and keep only lines that start with a
         "<m>/<d>/<y>, <hh>:<mm> - <name>:" header; other lines are dropped.
      2. Parse each kept line with process_line and drop "<Media omitted>"
         messages.
      3. Group the messages into conversations with group_messages and keep
         conversations of at least 10 messages.
      4. Remove WhatsApp annotations, then optionally spell-check and
         optionally reorder the messages.
      5. Rewrite contact names with rewrite_contact_name and keep only
         conversations in which more than one contact name remains.

    Args:
        file: Path of the chat export, passed to ``datasets.load_dataset``.
        do_spelling_correction: Whether to run spell_check_conversation.
        do_reordering: Whether to run swap_messages_if_needed_in_conversation
            on each conversation's message list.

    Returns:
        A dataset with a single "conversations" column.
    """
    # Leave one core free, but never ask for fewer than one worker:
    # os.cpu_count() may be 1 or None.
    num_proc = max(1, (os.cpu_count() or 1) - 1)

    header_pattern = re.compile(r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:")

    ds = (
        datasets.load_dataset("text", data_files=[file])["train"]
        .filter(lambda x: header_pattern.match(x["text"]) is not None)
        .map(process_line, remove_columns=["text"])
    )

    ds = ds.filter(lambda x: x["message"] != "<Media omitted>")

    groups = group_messages(iter(ds))

    conversations_ds = datasets.Dataset.from_dict({"conversations": groups})

    conversations_ds = conversations_ds.filter(lambda x: len(x["conversations"]) >= 10)

    conversations_ds_without_whatsapp_annotations = conversations_ds.map(
        remove_whatapp_annotations,
        num_proc=num_proc,
    )

    if do_spelling_correction:
        spell_checked_conversations_ds = (
            conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
        )
    else:
        spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations

    if do_reordering:
        # The reordering helper works on a plain list of messages, while map
        # hands over the whole row, so unwrap and re-wrap the column here.
        reordered_conversations_ds = spell_checked_conversations_ds.map(
            lambda x: {
                "conversations": swap_messages_if_needed_in_conversation(
                    x["conversations"]
                )
            }
        )
    else:
        reordered_conversations_ds = spell_checked_conversations_ds

    changed_contact_name_ds = reordered_conversations_ds.map(rewrite_contact_name)

    # A conversation with a single participant is dropped.
    changed_contact_name_ds = changed_contact_name_ds.filter(
        lambda x: len({msg["contact_name"] for msg in x["conversations"]}) > 1
    )

    return changed_contact_name_ds
|
|
|
|
|
def transform_conversations_dataset_into_training_examples(
    conversations_ds, system_prompt
):
    """
    Takes in a dataset with conversations and returns a dataset with training examples.

    The input dataset contains a single column (conversations), with each row being a list of messages with this format:
    ```
    [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
    ```

    Each row is converted into a chat-style training example with a single
    "messages" column. Messages from 'Aldi' get the role "assistant", all
    others the role "user". Consecutive messages with the same role are merged
    into one turn. The first turn is always the system prompt.

    The "content" of every turn is a JSON-encoded list of strings (one string
    per original message), not a plain string:
    ```
    {"messages": [{"role": "system", "content": "[\"<system prompt>\"]"}, {"role": "user", "content": "[\"Hi\", \"How are you?\"]"}, {"role": "assistant", "content": "[\"Fine, thanks\"]"}]}
    ```

    Args:
        conversations_ds: Dataset with a "conversations" column as described above.
        system_prompt: Text placed in the leading system turn of every example.

    Returns:
        A dataset in which the "conversations" column is replaced by "messages".
    """

    def process_one_example(example):
        # Each turn collects its messages in a list until the role changes.
        messages = [{"role": "system", "content": [system_prompt]}]
        for msg in example["conversations"]:
            converted_role = "assistant" if msg["contact_name"] == "Aldi" else "user"
            if converted_role == messages[-1]["role"]:
                messages[-1]["content"].append(msg["message"])
            else:
                messages.append({"role": converted_role, "content": [msg["message"]]})
        return {
            "messages": [
                {
                    "role": m["role"],
                    "content": json.dumps(m["content"], ensure_ascii=False),
                }
                for m in messages
            ]
        }

    # Leave one core free, but never ask for fewer than one worker:
    # os.cpu_count() may be 1 or None.
    num_proc = max(1, (os.cpu_count() or 1) - 1)

    return conversations_ds.map(
        process_one_example,
        remove_columns=["conversations"],
        num_proc=num_proc,
    )
|
|