|
import datetime |
|
import json |
|
import logging |
|
import os |
|
import re |
|
import datasets |
|
import dateutil.parser |
|
|
|
logger = logging.getLogger(__name__) |
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
|
|
|
|
|
|
MINUTES_THRESHOLD = 180 |
|
MIN_MESSAGES_THRESHOLD = 5 |
|
|
|
|
|
def group_messages(messages_iterable):
    groups = []
    try:
        current_group = [next(messages_iterable)]
    except StopIteration:
        # No messages at all: nothing to group.
        return groups
|
for message in messages_iterable: |
|
assert len(current_group) > 0 |
|
if ( |
|
message["timestamp"] - current_group[-1]["timestamp"] |
|
< MINUTES_THRESHOLD * 60 |
|
): |
|
current_group.append(message) |
|
else: |
|
groups.append(current_group) |
|
current_group = [message] |
|
groups.append(current_group) |
|
return groups |
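
# Illustrative sketch of the grouping behaviour (not executed on import):
# messages less than MINUTES_THRESHOLD minutes apart stay in the same group,
# while a larger gap starts a new conversation. For example:
#
#     msgs = [
#         {"message": "hi", "contact_name": "A", "timestamp": 0},
#         {"message": "hey", "contact_name": "B", "timestamp": 60},
#         {"message": "later!", "contact_name": "A", "timestamp": 12 * 3600},
#     ]
#     [len(g) for g in group_messages(iter(msgs))]  # -> [2, 1]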
|
|
|
|
|
def printable_conversation(conversation): |
|
return "\n".join( |
|
[f"{message['contact_name']}: {message['message']}" for message in conversation] |
|
) |
|
|
|
|
|
# Importing contextualSpellCheck registers the "contextual spellchecker"
# component factory with spaCy, even though the module is not used directly.
import contextualSpellCheck
|
|
|
|
|
|
|
import spacy |
|
from spellchecker import SpellChecker |
|
|
|
spell = SpellChecker() |
|
|
|
# Requires the small English model: python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")
|
|
|
|
|
def spell_check_conversation(conversation): |
|
for i, message in enumerate(conversation["conversations"]): |
|
|
|
words = spell.split_words(message["message"]) |
|
logger.info(f"Words: {words}") |
|
corrected_message = [] |
|
for word in words: |
|
correction = spell.correction(word) |
|
            if (correction is not None) and (correction != word):
|
logger.info(f"Spell check: {word} -> {correction}") |
|
corrected_message.append(correction) |
|
else: |
|
corrected_message.append(word) |
|
|
|
logger.info(f"Corrected message: {corrected_message}") |
|
joined_message = " ".join(corrected_message) |
|
conversation["conversations"][i]["message"] = joined_message |
|
|
|
return conversation |
|
|
|
|
|
def spell_check_conversation_spacy(conversation): |
|
|
|
    # Register the contextual spellchecker only once: calling add_pipe again
    # with the same component name raises an error on subsequent calls.
    if "contextual spellchecker" not in nlp.pipe_names:
        nlp.add_pipe(
            "contextual spellchecker",
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )
|
docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]])) |
|
for i, doc in enumerate(docs): |
|
if doc._.performed_spellCheck: |
|
logger.info(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}") |
|
conversation["conversations"][i]["message"] = doc._.outcome_spellCheck |
|
|
|
return conversation |
|
|
|
|
|
def remove_whatsapp_annotations(conversation): |
|
""" |
|
Removes the following annotations from the messages: |
|
- <This message was edited> |
|
""" |
|
for message in conversation["conversations"]: |
|
message["message"] = re.sub( |
|
r"<This message was edited>", "", message["message"] |
|
) |
|
return conversation |
|
|
|
|
|
|
|
""" |
|
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages. |
|
For example, if we have a conversation like this: |
|
A: Hi |
|
A: How are you? |
|
B: Hi |
|
B: I'm fine, thanks |
|
A: I'm fine too |
|
We'll reorder it to: |
|
A: Hi |
|
B: Hi |
|
A: How are you? |
|
B: I'm fine, thanks |
|
A: I'm fine too |
|
|
|
To do this, we use BERT (bert-base-uncased) with its next sentence prediction head. We score the two messages in their original order and in reversed order; if the model finds the reversed order clearly more plausible, we swap the messages.
|
""" |
|
|
|
import torch |
|
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") |
|
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased") |
|
if torch.cuda.is_available(): |
|
model.cuda() |
|
|
|
|
|
def swap_messages_if_needed(message1, message2): |
|
|
|
if message1["contact_name"] == message2["contact_name"]: |
|
return message1, message2 |
|
|
|
datetime1 = datetime.datetime.fromtimestamp(message1["timestamp"]) |
|
datetime2 = datetime.datetime.fromtimestamp(message2["timestamp"]) |
|
if (datetime2 - datetime1).total_seconds() > 2 * 60: |
|
return message1, message2 |
|
|
|
if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3: |
|
return message1, message2 |
|
|
|
inputs = tokenizer(message1["message"], message2["message"], return_tensors="pt") |
|
reverse_inputs = tokenizer( |
|
message2["message"], message1["message"], return_tensors="pt" |
|
) |
|
|
|
joined_inputs = torch.cat([inputs["input_ids"], reverse_inputs["input_ids"]], dim=0) |
|
if torch.cuda.is_available(): |
|
joined_inputs = joined_inputs.cuda() |
|
with torch.no_grad(): |
|
outputs = model(input_ids=joined_inputs) |
|
|
|
|
|
    # outputs[0] holds the next-sentence-prediction logits; after the softmax,
    # column 0 is the probability that the second sentence follows the first.
    logits = outputs[0]
    logits = torch.softmax(logits, dim=1)

    # Row 0 scores the original order, row 1 the reversed order. Swap only when
    # the reversed order is clearly more plausible (margin of 0.2).
    swap = logits[0, 0] - logits[1, 0] < -0.2
|
if swap: |
|
|
|
logger.info( |
|
f"Swapping messages: {message1['message']} <-> {message2['message']}" |
|
) |
|
return message2, message1 |
|
else: |
|
|
|
return message1, message2 |
|
|
|
|
|
def swap_messages_if_needed_in_conversation(conversation): |
|
|
|
if len(conversation) <= 2: |
|
return conversation |
|
new_conversation = [ |
|
conversation[0], |
|
conversation[1], |
|
] |
|
for i in range(2, len(conversation)): |
|
message1 = new_conversation[-1] |
|
message2 = conversation[i] |
|
message1, message2 = swap_messages_if_needed(message1, message2) |
|
new_conversation[-1] = message1 |
|
new_conversation.append(message2) |
|
|
|
|
|
|
|
return new_conversation |
|
|
|
|
|
test_conversation = [ |
|
{"message": "Hola!", "contact_name": "A", "timestamp": 1}, |
|
{ |
|
"message": "Está todo bien, gracias por preguntar!", |
|
"contact_name": "B", |
|
"timestamp": 2, |
|
}, |
|
{ |
|
"message": "Hola, qué tal estás? Espero que vaya todo bien por España.", |
|
"contact_name": "A", |
|
"timestamp": 3, |
|
}, |
|
] |
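
# Illustrative usage of the reordering step on the test conversation above
# (a sketch, not executed on import): message 2 arrived before message 3, but
# "Hola, qué tal estás? ..." more plausibly precedes "Está todo bien, gracias
# por preguntar!", so the next-sentence-prediction check is expected to swap them.
#
#     reordered = swap_messages_if_needed_in_conversation(test_conversation)
#     print(printable_conversation(reordered))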
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False): |
|
""" |
|
Process a chat file and return a dataset with the conversations. |
|
""" |
|
    exp = re.compile(message_line_format)
|
|
|
def process_line(example): |
|
|
|
try: |
|
groups = exp.match(example["text"]).groupdict() |
|
timestamp = dateutil.parser.parse(groups['msg_datetime'], dayfirst=datetime_dayfirst).timestamp() |
|
return { |
|
"message": groups["message"], |
|
"contact_name": groups["contact_name"], |
|
"timestamp": timestamp, |
|
} |
|
except Exception as e: |
|
logger.exception(example["text"]) |
|
raise e |
|
|
|
try: |
|
ds = datasets.load_dataset("text", data_files=[file])["train"] |
|
except Exception as e: |
|
logger.exception(f"Error while loading file {file}") |
|
raise Exception(f"Error while loading file {file}") from e |
|
try: |
|
        # Keep only lines that start a new message in the standard WhatsApp
        # export layout ("date, time - contact: message"); system notices and
        # continuation lines of multi-line messages are dropped.
        ds = ds.filter(
            lambda x: re.match(
                r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
            )
        )
|
except Exception as e: |
|
logger.exception(f"Error filtering the lines in file {file} so they match the expected format") |
|
raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e |
|
try: |
|
ds = ds.map(process_line, remove_columns=["text"]) |
|
except Exception as e: |
|
logger.exception(f"Error mapping the lines in file {file} to the expected format") |
|
raise Exception(f"Error mapping the lines in file {file} to the expected format") from e |
|
|
|
try: |
|
|
|
ds = ds.filter(lambda x: x["message"] != "<Media omitted>") |
|
except Exception as e: |
|
logger.exception(f"Error filtering out messages that say '<Media omitted>' in file {file}") |
|
raise Exception(f"Error filtering out messages that say '<Media omitted>' in file {file}") from e |
|
|
|
try: |
|
groups = group_messages(iter(ds)) |
|
|
|
conversations_ds = datasets.Dataset.from_dict({"conversations": groups}) |
|
except Exception as e: |
|
logger.exception(f"Error grouping the messages in file {file}") |
|
raise Exception(f"Error grouping the messages in file {file}") from e |
|
|
|
try: |
|
|
|
conversations_ds = conversations_ds.filter( |
|
lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD |
|
) |
|
except Exception as e: |
|
logger.exception(f"Error filtering out conversations with less than {MIN_MESSAGES_THRESHOLD} messages in file {file}") |
|
raise Exception(f"Error filtering out conversations with less than {MIN_MESSAGES_THRESHOLD} messages in file {file}") from e |
|
|
|
try: |
|
conversations_ds_without_whatsapp_annotations = conversations_ds.map( |
|
remove_whatsapp_annotations, |
|
            num_proc=max(1, os.cpu_count() - 1),  # keep at least one worker on single-core machines
|
) |
|
except Exception as e: |
|
logger.exception(f"Error removing WhatsApp annotations in file {file}") |
|
raise Exception(f"Error removing WhatsApp annotations in file {file}") from e |
|
|
|
if do_spelling_correction: |
|
try: |
|
spell_checked_conversations_ds = ( |
|
conversations_ds_without_whatsapp_annotations.map(spell_check_conversation) |
|
) |
|
except Exception as e: |
|
logger.exception(f"Error spell checking the conversations in file {file}") |
|
raise Exception(f"Error spell checking the conversations in file {file}") from e |
|
else: |
|
spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations |
|
|
|
if do_reordering: |
|
try: |
|
reordered_conversations_ds = spell_checked_conversations_ds.map( |
|
swap_messages_if_needed_in_conversation |
|
) |
|
except Exception as e: |
|
logger.exception(f"Error reordering the conversations in file {file}") |
|
raise Exception(f"Error reordering the conversations in file {file}") from e |
|
else: |
|
reordered_conversations_ds = spell_checked_conversations_ds |
|
|
|
|
|
def rewrite_contact_name(conversation): |
|
for message in conversation["conversations"]: |
|
if message["contact_name"] != whatsapp_name: |
|
message["contact_name"] = "Other" |
|
return conversation |
|
|
|
try: |
|
changed_contact_name_ds = reordered_conversations_ds.map( |
|
rewrite_contact_name |
|
) |
|
except Exception as e: |
|
logger.exception(f"Error changing your other contact's names in file {file}") |
|
raise Exception(f"Error changing your other contact's names in file {file}") from e |
|
|
|
try: |
|
|
|
changed_contact_name_ds = changed_contact_name_ds.filter( |
|
lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1 |
|
) |
|
except Exception as e: |
|
logger.exception(f"Error filtering out conversations with only one contact in file {file}") |
|
raise Exception(f"Error filtering out conversations with only one contact in file {file}") from e |
|
|
|
return changed_contact_name_ds |
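
# Illustrative example of a `message_line_format` value (an assumption, not part
# of the original pipeline): `process_line` above expects a regex with the named
# groups `msg_datetime`, `contact_name` and `message`. A pattern along these
# lines matches default WhatsApp export lines such as
# "12/01/23, 21:05 - Alice: see you tomorrow", although the exact layout varies
# by locale and app version:
EXAMPLE_MESSAGE_LINE_FORMAT = (
    r"^(?P<msg_datetime>\d{1,2}/\d{1,2}/\d{2,4},\s\d{2}:\d{2})\s-\s"
    r"(?P<contact_name>[^:]+):\s(?P<message>.+)$"
)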
|
|
|
|
|
SPLIT_CONVERSATION_THRESHOLD = 40 |
|
MAX_CHARACTERS_PER_MESSAGE = 10000 |
|
|
|
|
|
def transform_conversations_dataset_into_training_examples( |
|
conversations_ds, system_prompt, user_role, model_role, whatsapp_name |
|
): |
|
""" |
|
Takes in a dataset with conversations and returns a dataset with training examples. |
|
|
|
The input dataset contains a single column (conversations), with each row being a list of messages with this format: |
|
``` |
|
[{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ] |
|
``` |
|
|
|
Each row will be converted to fit the format of the training examples. |
|
|
|
The training examples have the following format: |
|
``` |
|
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} |
|
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "William Shakespeare"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} |
|
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "384,400 kilometers"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} |
|
``` |
|
""" |
|
|
|
def process_examples(examples): |
|
processed_examples = [] |
|
for conversation in examples["conversations"]: |
|
messages = [{"role": "system", "content": [system_prompt]}] |
|
counter = 0 |
|
for msg in conversation: |
|
converted_role = ( |
|
model_role if msg["contact_name"] == whatsapp_name else user_role |
|
) |
|
if ( |
|
counter > SPLIT_CONVERSATION_THRESHOLD |
|
and converted_role == user_role |
|
): |
|
processed_examples.append( |
|
{ |
|
"messages": [ |
|
{ |
|
"role": m["role"], |
|
"content": json.dumps( |
|
m["content"], ensure_ascii=False |
|
), |
|
} |
|
for m in messages |
|
] |
|
} |
|
) |
|
messages = [{"role": "system", "content": [system_prompt]}] |
|
counter = 0 |
|
                # Consecutive messages from the same speaker are merged into one
                # turn: each turn's content is a list of messages, serialized
                # later with json.dumps.
                if converted_role == messages[-1]["role"]:
                    messages[-1]["content"] += [msg["message"]]
|
else: |
|
messages.append( |
|
{"role": converted_role, "content": [msg["message"]]} |
|
) |
|
counter += 1 |
|
if len(messages) >= MIN_MESSAGES_THRESHOLD: |
|
processed_examples.append( |
|
{ |
|
"messages": [ |
|
{ |
|
"role": m["role"], |
|
"content": json.dumps(m["content"], ensure_ascii=False), |
|
} |
|
for m in messages |
|
] |
|
} |
|
) |
|
else: |
|
logger.warning( |
|
f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}" |
|
) |
|
|
|
        # Flatten the list of dicts into a dict of lists, the shape expected by
        # `Dataset.map` with `batched=True`. Guard against batches in which every
        # conversation was discarded.
        if not processed_examples:
            return {"messages": []}
        flattened_examples = {}
        for key in processed_examples[0].keys():
            flattened_examples[key] = [d[key] for d in processed_examples]
        return flattened_examples
|
|
|
try: |
|
processed_examples = conversations_ds.map( |
|
process_examples, |
|
remove_columns=["conversations"], |
|
|
|
batched=True, |
|
) |
|
except Exception as e: |
|
logger.exception("Error transforming the conversations dataset into training examples") |
|
raise Exception("Error transforming the conversations dataset into training examples") from e |
|
|
|
try: |
|
examples_filtered_by_length = processed_examples.filter( |
|
lambda x: all( |
|
[len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]] |
|
) |
|
) |
|
except Exception as e: |
|
logger.exception("Error filtering out examples with messages longer than the maximum allowed") |
|
raise Exception("Error filtering out examples with messages longer than the maximum allowed") from e |
|
|
|
return examples_filtered_by_length |
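

# Illustrative end-to-end usage (a sketch; the chat file path, the system prompt
# and the WhatsApp name are placeholders, and EXAMPLE_MESSAGE_LINE_FORMAT is the
# assumed pattern defined above, not part of the original pipeline):
if __name__ == "__main__":
    chat_ds = process_chat_file(
        "chat.txt",  # placeholder path to a WhatsApp export
        do_spelling_correction=False,
        whatsapp_name="Aldi",  # placeholder: your name as it appears in the export
        datetime_dayfirst=True,
        message_line_format=EXAMPLE_MESSAGE_LINE_FORMAT,
    )
    examples_ds = transform_conversations_dataset_into_training_examples(
        chat_ds,
        system_prompt="You are Aldi, chatting informally on WhatsApp.",
        user_role="user",
        model_role="assistant",
        whatsapp_name="Aldi",
    )
    # Write the training examples as JSON lines.
    examples_ds.to_json("training_examples.jsonl")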
|
|