Recherche sémantique avec FAISS
Dans section 5, nous avons créé un jeu de données de problèmes et de commentaires GitHub à partir du dépôt 🤗 Datasets. Dans cette section, nous utilisons ces informations pour créer un moteur de recherche qui peut nous aider à trouver des réponses à nos questions les plus urgentes sur la bibliothèque !
Utilisation des enchâssements pour la recherche sémantique
Comme nous l’avons vu dans le chapitre 1, les modèles de langage basés sur les transformers représentent chaque token dans une étendue de texte sous la forme d’un enchâssement. Il s’avère que l’on peut regrouper les enchâssements individuels pour créer une représentation vectorielle pour des phrases entières, des paragraphes ou (dans certains cas) des documents. Ces enchâssements peuvent ensuite être utilisés pour trouver des documents similaires dans le corpus en calculant la similarité du produit scalaire (ou une autre métrique de similarité) entre chaque enchâssement et en renvoyant les documents avec le plus grand chevauchement.
Dans cette section, nous utilisons les enchâssements pour développer un moteur de recherche sémantique. Ces moteurs de recherche offrent plusieurs avantages par rapport aux approches conventionnelles basées sur la correspondance des mots-clés dans une requête avec les documents.
Chargement et préparation du jeu de données
La première chose que nous devons faire est de télécharger notre jeu de données de problèmes GitHub. Utilisons la bibliothèque 🤗 Hub pour résoudre l’URL où notre fichier est stocké sur le Hub d’Hugging Face :
from huggingface_hub import hf_hub_url
data_files = hf_hub_url(
repo_id="lewtun/github-issues",
filename="datasets-issues-with-comments.jsonl",
repo_type="dataset",
)
Avec l’URL stocké dans data_files
, nous pouvons ensuite charger le jeu de données distant en utilisant la méthode introduite dans section 2 :
from datasets import load_dataset
issues_dataset = load_dataset("json", data_files=data_files, split="train")
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 2855
})
Ici, nous avons spécifié l’échantillon train
par défaut dans load_dataset()
, de sorte que cela renvoie un Dataset
au lieu d’un DatasetDict
. La première chose à faire est de filtrer les pull requests car celles-ci ont tendance à être rarement utilisées pour répondre aux requêtes des utilisateurs et introduiront du bruit dans notre moteur de recherche. Comme cela devrait être familier maintenant, nous pouvons utiliser la fonction Dataset.filter()
pour exclure ces lignes de notre jeu de données. Pendant que nous y sommes, filtrons également les lignes sans commentaires, car celles-ci ne fournissent aucune réponse aux requêtes des utilisateurs :
issues_dataset = issues_dataset.filter(
lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 771
})
Nous pouvons voir qu’il y a beaucoup de colonnes dans notre jeu de données, dont la plupart n’ont pas besoin de construire notre moteur de recherche. Du point de vue de la recherche, les colonnes les plus informatives sont title
, body
et comments
, tandis que html_url
nous fournit un lien vers le problème source. Utilisons la fonction Dataset.remove_columns()
pour supprimer le reste :
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 771
})
Pour créer nos enchâssements, nous ajoutons à chaque commentaire le titre et le corps du problème, car ces champs contiennent des informations contextuelles utiles. Étant donné que notre colonne comments
est actuellement une liste de commentaires pour chaque problème, nous devons « éclater » la colonne afin que chaque ligne se compose d’un tuple (html_url, title, body, comment)
. Dans Pandas, nous pouvons le faire avec la fonction DataFrame.explode()
, qui crée une nouvelle ligne pour chaque élément dans une colonne de type liste, tout en répliquant toutes les autres valeurs de colonne. Pour voir cela en action, passons d’abord au format DataFrame
de Pandas :
issues_dataset.set_format("pandas")
df = issues_dataset[:]
Si nous inspectons la première ligne de ce DataFrame
, nous pouvons voir qu’il y a quatre commentaires associés à ce problème :
df["comments"][0].tolist()
['the bug code locate in :\r\n if data_args.task_name is not None:\r\n # Downloading and loading a dataset from the hub.\r\n datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
'cannot connect,even by Web browser,please check that there is some problems。',
'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']
Lorsque nous décomposons df
, nous nous attendons à obtenir une ligne pour chacun de ces commentaires. Vérifions si c’est le cas :
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url | title | comments | body | |
---|---|---|---|---|
0 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | the bug code locate in :\r\n if data_args.task_name is not None... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
1 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
2 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | cannot connect,even by Web browser,please check that there is some problems。 | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
3 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
Génial, nous pouvons voir que les lignes ont été répliquées, avec la colonne comments
contenant les commentaires individuels ! Maintenant que nous en avons fini avec Pandas, nous pouvons rapidement revenir à un Dataset
en chargeant le DataFrame
en mémoire :
from datasets import Dataset
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 2842
})
D’accord, cela nous a donné quelques milliers de commentaires avec lesquels travailler !
✏️ Essayez ! Voyez si vous pouvez utiliser Dataset.map()
pour exploser la colonne comments
de issues_dataset
sans recourir à l’utilisation de Pandas. C’est un peu délicat. La section « Batch mapping » de la documentation 🤗 Datasets peut être utile pour cette tâche.
Maintenant que nous avons un commentaire par ligne, créons une nouvelle colonne comments_length
contenant le nombre de mots par commentaire :
comments_dataset = comments_dataset.map(
lambda x: {"comment_length": len(x["comments"].split())}
)
Nous pouvons utiliser cette nouvelle colonne pour filtrer les commentaires courts incluant généralement des éléments tels que « cc @lewtun » ou « Merci ! » qui ne sont pas pertinents pour notre moteur de recherche. Il n’y a pas de nombre précis à sélectionner pour le filtre mais 15 mots semblent être un bon début :
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
num_rows: 2098
})
Après avoir un peu nettoyé notre jeu de données, concaténons le titre, la description et les commentaires du problème dans une nouvelle colonne text
. Comme d’habitude, nous allons écrire une fonction simple que nous pouvons passer à Dataset.map()
:
def concatenate_text(examples):
return {
"text": examples["title"]
+ " \n "
+ examples["body"]
+ " \n "
+ examples["comments"]
}
comments_dataset = comments_dataset.map(concatenate_text)
Nous sommes enfin prêts à créer des enchâssements ! Jetons un coup d’œil.
Création d’enchâssements pour les textes
Nous avons vu dans chapitre 2 que nous pouvons obtenir des enchâssements de tokens en utilisant la classe AutoModel
. Tout ce que nous avons à faire est de choisir un checkpoint approprié à partir duquel charger le modèle. Heureusement, il existe une bibliothèque appelée sentence-transformers
dédiée à la création d’enchâssements. Comme décrit dans la documentation de la bibliothèque, notre cas d’utilisation est un exemple de recherche sémantique asymétrique. En effet, nous avons une requête courte dont nous aimerions trouver la réponse dans un document plus long, par exemple un commentaire à un problème. Le tableau de présentation des modèles de la documentation indique que le checkpoint multi-qa-mpnet-base-dot-v1
a les meilleures performances pour la recherche sémantique. Utilisons donc le pour notre application. Nous allons également charger le tokenizer en utilisant le même checkpoint :
from transformers import AutoTokenizer, AutoModel
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
Pour accélérer le processus, il est utile de placer le modèle et les entrées sur un périphérique GPU, alors faisons-le maintenant :
import torch
device = torch.device("cuda")
model.to(device)
Comme nous l’avons mentionné précédemment, nous aimerions représenter chaque entrée dans notre corpus de problèmes GitHub comme un vecteur unique. Nous devons donc regrouper ou faire la moyenne de nos enchâssements de tokens d’une manière ou d’une autre. Une approche populaire consiste à effectuer un regroupement CLS sur les sorties de notre modèle, où nous collectons simplement le dernier état caché pour le token spécial [CLS]
. La fonction suivante fait ça pour nous :
def cls_pooling(model_output):
return model_output.last_hidden_state[:, 0]
Ensuite, nous allons créer une fonction utile qui va tokeniser une liste de documents, placer les tenseurs dans le GPU, les donner au modèle et enfin appliquer le regroupement CLS aux sorties :
def get_embeddings(text_list):
encoded_input = tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
model_output = model(**encoded_input)
return cls_pooling(model_output)
Nous pouvons tester le fonctionnement de la fonction en lui donnant la première entrée textuelle de notre corpus et en inspectant la forme de sortie :
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])
Super ! Nous avons converti la première entrée de notre corpus en un vecteur à 768 dimensions. Nous pouvons utiliser Dataset.map()
pour appliquer notre fonction get_embeddings()
à chaque ligne de notre corpus. Créons donc une nouvelle colonne embeddings
comme suit :
embeddings_dataset = comments_dataset.map(
lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)
Notez que nous avons converti les enchâssements en tableaux NumPy. C’est parce que 🤗 Datasets nécessite ce format lorsque nous essayons de les indexer avec FAISS, ce que nous ferons ensuite.
Utilisation de FAISS pour une recherche de similarité efficace
Maintenant que nous avons un jeu de données d’incorporations, nous avons besoin d’un moyen de les rechercher. Pour ce faire, nous utiliserons une structure de données spéciale dans 🤗 Datasets appelée FAISS index. FAISS (abréviation de Facebook AI Similarity Search) est une bibliothèque qui fournit des algorithmes efficaces pour rechercher et regrouper rapidement des vecteurs d’intégration.
L’idée de base derrière FAISS est de créer une structure de données spéciale appelée un index qui permet de trouver quels plongements sont similaires à un plongement d’entrée. Créer un index FAISS dans 🤗 Datasets est simple — nous utilisons la fonction Dataset.add_faiss_index()
et spécifions quelle colonne de notre jeu de données nous aimerions indexer :
embeddings_dataset.add_faiss_index(column="embeddings")
Nous pouvons maintenant effectuer des requêtes sur cet index en effectuant une recherche des voisins les plus proches avec la fonction Dataset.get_nearest_examples()
. Testons cela en enchâssant d’abord une question comme suit :
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])
Tout comme avec les documents, nous avons maintenant un vecteur de 768 dimensions représentant la requête. Nous pouvons le comparer à l’ensemble du corpus pour trouver les enchâssements les plus similaires :
scores, samples = embeddings_dataset.get_nearest_examples(
"embeddings", question_embedding, k=5
)
La fonction Dataset.get_nearest_examples()
renvoie un tuple de scores qui classent le chevauchement entre la requête et le document, et un jeu correspondant d’échantillons (ici, les 5 meilleures correspondances). Collectons-les dans un pandas.DataFrame
afin de pouvoir les trier facilement :
import pandas as pd
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
Nous pouvons maintenant parcourir les premières lignes pour voir dans quelle mesure notre requête correspond aux commentaires disponibles :
for _, row in samples_df.iterrows():
print(f"COMMENT: {row.comments}")
print(f"SCORE: {row.scores}")
print(f"TITLE: {row.title}")
print(f"URL: {row.html_url}")
print("=" * 50)
print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.
@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`
We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.
Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)
I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.
----------
> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.
----------
About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.
SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`
HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""
Pas mal ! Notre deuxième résultat semble correspondre à la requête.
✏️ Essayez ! Créez votre propre requête et voyez si vous pouvez trouver une réponse dans les documents récupérés. Vous devrez peut-être augmenter le paramètre k
dans Dataset.get_nearest_examples()
pour élargir la recherche.