abrar0503 commited on
Commit
727fd6d
1 Parent(s): 3f7e564

Upload 9 files

Browse files
README.md ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: sentence-transformers
5
+ tags:
6
+ - sentence-transformers
7
+ - feature-extraction
8
+ - sentence-similarity
9
+ - transformers
10
+ datasets:
11
+ - s2orc
12
+ - flax-sentence-embeddings/stackexchange_xml
13
+ - ms_marco
14
+ - gooaq
15
+ - yahoo_answers_topics
16
+ - code_search_net
17
+ - search_qa
18
+ - eli5
19
+ - snli
20
+ - multi_nli
21
+ - wikihow
22
+ - natural_questions
23
+ - trivia_qa
24
+ - embedding-data/sentence-compression
25
+ - embedding-data/flickr30k-captions
26
+ - embedding-data/altlex
27
+ - embedding-data/simple-wiki
28
+ - embedding-data/QQP
29
+ - embedding-data/SPECTER
30
+ - embedding-data/PAQ_pairs
31
+ - embedding-data/WikiAnswers
32
+ pipeline_tag: sentence-similarity
33
+ ---
34
+
35
+
36
+ # all-MiniLM-L6-v2
37
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.
38
+
39
+ ## Usage (Sentence-Transformers)
40
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
41
+
42
+ ```
43
+ pip install -U sentence-transformers
44
+ ```
45
+
46
+ Then you can use the model like this:
47
+ ```python
48
+ from sentence_transformers import SentenceTransformer
49
+ sentences = ["This is an example sentence", "Each sentence is converted"]
50
+
51
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
52
+ embeddings = model.encode(sentences)
53
+ print(embeddings)
54
+ ```
55
+
56
+ ## Usage (HuggingFace Transformers)
57
+ Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
58
+
59
+ ```python
60
+ from transformers import AutoTokenizer, AutoModel
61
+ import torch
62
+ import torch.nn.functional as F
63
+
64
+ #Mean Pooling - Take attention mask into account for correct averaging
65
+ def mean_pooling(model_output, attention_mask):
66
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
67
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
68
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
69
+
70
+
71
+ # Sentences we want sentence embeddings for
72
+ sentences = ['This is an example sentence', 'Each sentence is converted']
73
+
74
+ # Load model from HuggingFace Hub
75
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
76
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
77
+
78
+ # Tokenize sentences
79
+ encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
80
+
81
+ # Compute token embeddings
82
+ with torch.no_grad():
83
+ model_output = model(**encoded_input)
84
+
85
+ # Perform pooling
86
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
87
+
88
+ # Normalize embeddings
89
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
90
+
91
+ print("Sentence embeddings:")
92
+ print(sentence_embeddings)
93
+ ```
94
+
95
+ ## Evaluation Results
96
+
97
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/all-MiniLM-L6-v2)
98
+
99
+ ------
100
+
101
+ ## Background
102
+
103
+ The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
104
+ contrastive learning objective. We used the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model and fine-tuned in on a
105
+ 1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
106
+
107
+ We developed this model during the
108
+ [Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
109
+ organized by Hugging Face. We developed this model as part of the project:
110
+ [Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Googles Flax, JAX, and Cloud team member about efficient deep learning frameworks.
111
+
112
+ ## Intended uses
113
+
114
+ Our model is intended to be used as a sentence and short paragraph encoder. Given an input text, it outputs a vector which captures
115
+ the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
116
+
117
+ By default, input text longer than 256 word pieces is truncated.
118
+
119
+
120
+ ## Training procedure
121
+
122
+ ### Pre-training
123
+
124
+ We use the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model. Please refer to the model card for more detailed information about the pre-training procedure.
125
+
126
+ ### Fine-tuning
127
+
128
+ We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
129
+ We then apply the cross entropy loss by comparing with true pairs.
130
+
131
+ #### Hyper parameters
132
+
133
+ We trained our model on a TPU v3-8. We train the model during 100k steps using a batch size of 1024 (128 per TPU core).
134
+ We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
135
+ a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
136
+
137
+ #### Training data
138
+
139
+ We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
140
+ We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
141
+
142
+
143
+ | Dataset | Paper | Number of training tuples |
144
+ |--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
145
+ | [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
146
+ | [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
147
+ | [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
148
+ | [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
149
+ | [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
150
+ | [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
151
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
152
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
153
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
154
+ | [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
155
+ | [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
156
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
157
+ | [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
158
+ | [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
159
+ | [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
160
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
161
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
162
+ | [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
163
+ | [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
164
+ | [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
165
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
166
+ | AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
167
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
168
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
169
+ | [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
170
+ | [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
171
+ | [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
172
+ | [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
173
+ | [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
174
+ | [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
175
+ | [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
176
+ | [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
177
+ | **Total** | | **1,170,060,424** |
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nreimers/MiniLM-L6-H384-uncased",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 384,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1536,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 6,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "transformers_version": "4.8.2",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 30522
24
+ }
data_config.json ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, normalize=True):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ self.model = AutoModel.from_pretrained(model_name)
45
+ self.normalize = normalize
46
+ self.tokenizer = tokenizer
47
+
48
+ def forward(self, **kwargs):
49
+ model_output = self.model(**kwargs)
50
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
51
+ if self.normalize:
52
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53
+
54
+ return embeddings
55
+
56
+ def mean_pooling(self, model_output, attention_mask):
57
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
58
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
59
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
60
+
61
+ def save_pretrained(self, output_path):
62
+ if xm.is_master_ordinal():
63
+ self.tokenizer.save_pretrained(output_path)
64
+ self.model.config.save_pretrained(output_path)
65
+
66
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
67
+
68
+
69
+
70
+
71
+ def train_function(index, args, queue):
72
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
73
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
74
+
75
+
76
+ ### Train Loop
77
+ device = xm.xla_device()
78
+ model = model.to(device)
79
+
80
+ # Instantiate optimizer
81
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
82
+
83
+ lr_scheduler = get_linear_schedule_with_warmup(
84
+ optimizer=optimizer,
85
+ num_warmup_steps=500,
86
+ num_training_steps=args.steps,
87
+ )
88
+
89
+ # Now we train the model
90
+ cross_entropy_loss = nn.CrossEntropyLoss()
91
+ max_grad_norm = 1
92
+
93
+ model.train()
94
+
95
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
96
+ #### Get the batch data
97
+ batch = queue.get()
98
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
99
+
100
+
101
+ if len(batch[0]) == 2: #(anchor, positive)
102
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
103
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
104
+
105
+ ### Compute embeddings
106
+ embeddings_a = model(**text1.to(device))
107
+ embeddings_b = model(**text2.to(device))
108
+
109
+ ### Gather all embedings
110
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
111
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
112
+
113
+ ### Compute similarity scores 512 x 512
114
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
115
+
116
+ ### Compute cross-entropy loss
117
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
118
+
119
+ ## Symmetric loss as in CLIP
120
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
121
+
122
+ else: #(anchor, positive, negative)
123
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
124
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
125
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
126
+
127
+ embeddings_a = model(**text1.to(device))
128
+ embeddings_b1 = model(**text2.to(device))
129
+ embeddings_b2 = model(**text3.to(device))
130
+
131
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
132
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
133
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
134
+
135
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
136
+
137
+ ### Compute similarity scores 512 x 1024
138
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
139
+
140
+ ### Compute cross-entropy loss
141
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
142
+
143
+ ## One-way loss
144
+ loss = cross_entropy_loss(scores, labels)
145
+
146
+
147
+ # Backward pass
148
+ optimizer.zero_grad()
149
+ loss.backward()
150
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
151
+
152
+ xm.optimizer_step(optimizer, barrier=True)
153
+ lr_scheduler.step()
154
+
155
+
156
+ #Save model
157
+ if (global_step+1) % args.save_steps == 0:
158
+ output_path = os.path.join(args.output, str(global_step+1))
159
+ xm.master_print("save model: "+output_path)
160
+ model.save_pretrained(output_path)
161
+
162
+
163
+ output_path = os.path.join(args.output, "final")
164
+ xm.master_print("save model final: "+ output_path)
165
+ model.save_pretrained(output_path)
166
+
167
+
168
+ def produce_data(args, queue, filepaths, dataset_indices):
169
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
170
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
171
+ num_same_dataset = int(size_per_dataset / args.batch_size)
172
+ print("producer", "global_batch_size", global_batch_size)
173
+ print("producer", "size_per_dataset", size_per_dataset)
174
+ print("producer", "num_same_dataset", num_same_dataset)
175
+
176
+ datasets = []
177
+ for filepath in filepaths:
178
+ if "reddit_" in filepath: #Special dataset class for Reddit files
179
+ data_obj = RedditDataset(filepath)
180
+ else:
181
+ data_obj = Dataset(filepath)
182
+ datasets.append(iter(data_obj))
183
+
184
+ # Store if dataset is in a 2 col or 3 col format
185
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
186
+
187
+ while True:
188
+ texts_in_batch = set()
189
+ batch_format = None #2 vs 3 col format for this batch
190
+
191
+ #Add data from several sub datasets
192
+ for _ in range(args.datasets_per_batch):
193
+ valid_dataset = False #Check that datasets have the same 2/3 col format
194
+ while not valid_dataset:
195
+ data_idx = random.choice(dataset_indices)
196
+ if batch_format is None:
197
+ batch_format = num_cols[data_idx]
198
+ valid_dataset = True
199
+ else: #Check that this dataset has the same format
200
+ valid_dataset = (batch_format == num_cols[data_idx])
201
+
202
+ #Get data from this dataset
203
+ dataset = datasets[data_idx]
204
+ for _ in range(num_same_dataset):
205
+ for _ in range(args.nprocs):
206
+ batch_device = [] #A batch for one device
207
+ while len(batch_device) < args.batch_size:
208
+ sample = next(dataset)
209
+ in_batch = False
210
+ for text in sample:
211
+ if text in texts_in_batch:
212
+ in_batch = True
213
+ break
214
+
215
+ if not in_batch:
216
+ for text in sample:
217
+ texts_in_batch.add(text)
218
+ batch_device.append(sample)
219
+
220
+ queue.put(batch_device)
221
+
222
+
223
+ class RedditDataset:
224
+ """
225
+ A class that handles the reddit data files
226
+ """
227
+ def __init__(self, filepath):
228
+ self.filepath = filepath
229
+
230
+ def __iter__(self):
231
+ while True:
232
+ with gzip.open(self.filepath, "rt") as fIn:
233
+ for line in fIn:
234
+ data = json.loads(line)
235
+
236
+ if "response" in data and "context" in data:
237
+ yield [data["response"], data["context"]]
238
+
239
+ class Dataset:
240
+ """
241
+ A class that handles one dataset
242
+ """
243
+ def __init__(self, filepath):
244
+ self.filepath = filepath
245
+
246
+ def __iter__(self):
247
+ max_dataset_size = 10*1000*1000 #Cache small datasets in memory
248
+ dataset = []
249
+ data_format = None
250
+
251
+ while dataset is None or len(dataset) == 0:
252
+ with gzip.open(self.filepath, "rt") as fIn:
253
+ for line in fIn:
254
+ data = json.loads(line)
255
+ if isinstance(data, dict):
256
+ data = data['texts']
257
+
258
+ if data_format is None:
259
+ data_format = len(data)
260
+
261
+ #Ensure that all entries are of the same 2/3 col format
262
+ assert len(data) == data_format
263
+
264
+ if dataset is not None:
265
+ dataset.append(data)
266
+ if len(dataset) >= max_dataset_size:
267
+ dataset = None
268
+
269
+ yield data
270
+
271
+ # Data loaded. Now stream to the queue
272
+ # Shuffle for each epoch
273
+ while True:
274
+ random.shuffle(dataset)
275
+ for data in dataset:
276
+ yield data
277
+
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
283
+ parser.add_argument('--steps', type=int, default=2000)
284
+ parser.add_argument('--save_steps', type=int, default=10000)
285
+ parser.add_argument('--batch_size', type=int, default=64)
286
+ parser.add_argument('--max_length', type=int, default=128)
287
+ parser.add_argument('--nprocs', type=int, default=8)
288
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
289
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
290
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
291
+ parser.add_argument('data_config', help="A data_config.json file")
292
+ parser.add_argument('output')
293
+ args = parser.parse_args()
294
+
295
+ # Ensure global batch size is divisble by data_sample_size
296
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
297
+
298
+ logging.info("Output: "+args.output)
299
+ if os.path.exists(args.output):
300
+ print("Output folder already exists.")
301
+ input("Continue?")
302
+
303
+ # Write train script to output path
304
+ os.makedirs(args.output, exist_ok=True)
305
+
306
+ data_config_path = os.path.join(args.output, 'data_config.json')
307
+ copyfile(args.data_config, data_config_path)
308
+
309
+ train_script_path = os.path.join(args.output, 'train_script.py')
310
+ copyfile(__file__, train_script_path)
311
+ with open(train_script_path, 'a') as fOut:
312
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
313
+
314
+
315
+
316
+ #Load data config
317
+ with open(args.data_config) as fIn:
318
+ data_config = json.load(fIn)
319
+
320
+ queue = mp.Queue(maxsize=100*args.nprocs)
321
+
322
+ filepaths = []
323
+ dataset_indices = []
324
+ for idx, data in enumerate(data_config):
325
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
326
+ dataset_indices.extend([idx]*data['weight'])
327
+
328
+ # Start producer
329
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
330
+ p.start()
331
+
332
+ # Run training
333
+ print("Start processes:", args.nprocs)
334
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
335
+ print("Training done")
336
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
337
+ print("With 'pkill python' you can kill all remaining python processes")
338
+ p.kill()
339
+ exit()
340
+
341
+
342
+
343
+ # Script was called via:
344
+ #python train_many_data_files_v2.py --steps 1000000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased train_data_configs/all_datasets_v4.json output/all_datasets_v4_MiniLM-L6-H384-uncased-batch128
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 256,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "nreimers/MiniLM-L6-H384-uncased", "do_basic_tokenize": true, "never_split": null, "tokenizer_class": "BertTokenizer", "model_max_length": 512}
train_script.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, normalize=True):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ self.model = AutoModel.from_pretrained(model_name)
45
+ self.normalize = normalize
46
+ self.tokenizer = tokenizer
47
+
48
+ def forward(self, **kwargs):
49
+ model_output = self.model(**kwargs)
50
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
51
+ if self.normalize:
52
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53
+
54
+ return embeddings
55
+
56
+ def mean_pooling(self, model_output, attention_mask):
57
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
58
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
59
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
60
+
61
+ def save_pretrained(self, output_path):
62
+ if xm.is_master_ordinal():
63
+ self.tokenizer.save_pretrained(output_path)
64
+ self.model.config.save_pretrained(output_path)
65
+
66
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
67
+
68
+
69
+
70
+
71
+ def train_function(index, args, queue):
72
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
73
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
74
+
75
+
76
+ ### Train Loop
77
+ device = xm.xla_device()
78
+ model = model.to(device)
79
+
80
+ # Instantiate optimizer
81
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
82
+
83
+ lr_scheduler = get_linear_schedule_with_warmup(
84
+ optimizer=optimizer,
85
+ num_warmup_steps=500,
86
+ num_training_steps=args.steps,
87
+ )
88
+
89
+ # Now we train the model
90
+ cross_entropy_loss = nn.CrossEntropyLoss()
91
+ max_grad_norm = 1
92
+
93
+ model.train()
94
+
95
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
96
+ #### Get the batch data
97
+ batch = queue.get()
98
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
99
+
100
+
101
+ if len(batch[0]) == 2: #(anchor, positive)
102
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
103
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
104
+
105
+ ### Compute embeddings
106
+ embeddings_a = model(**text1.to(device))
107
+ embeddings_b = model(**text2.to(device))
108
+
109
+ ### Gather all embedings
110
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
111
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
112
+
113
+ ### Compute similarity scores 512 x 512
114
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
115
+
116
+ ### Compute cross-entropy loss
117
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
118
+
119
+ ## Symmetric loss as in CLIP
120
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
121
+
122
+ else: #(anchor, positive, negative)
123
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
124
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
125
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
126
+
127
+ embeddings_a = model(**text1.to(device))
128
+ embeddings_b1 = model(**text2.to(device))
129
+ embeddings_b2 = model(**text3.to(device))
130
+
131
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
132
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
133
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
134
+
135
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
136
+
137
+ ### Compute similarity scores 512 x 1024
138
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
139
+
140
+ ### Compute cross-entropy loss
141
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
142
+
143
+ ## One-way loss
144
+ loss = cross_entropy_loss(scores, labels)
145
+
146
+
147
+ # Backward pass
148
+ optimizer.zero_grad()
149
+ loss.backward()
150
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
151
+
152
+ xm.optimizer_step(optimizer, barrier=True)
153
+ lr_scheduler.step()
154
+
155
+
156
+ #Save model
157
+ if (global_step+1) % args.save_steps == 0:
158
+ output_path = os.path.join(args.output, str(global_step+1))
159
+ xm.master_print("save model: "+output_path)
160
+ model.save_pretrained(output_path)
161
+
162
+
163
+ output_path = os.path.join(args.output, "final")
164
+ xm.master_print("save model final: "+ output_path)
165
+ model.save_pretrained(output_path)
166
+
167
+
168
+ def produce_data(args, queue, filepaths, dataset_indices):
169
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
170
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
171
+ num_same_dataset = int(size_per_dataset / args.batch_size)
172
+ print("producer", "global_batch_size", global_batch_size)
173
+ print("producer", "size_per_dataset", size_per_dataset)
174
+ print("producer", "num_same_dataset", num_same_dataset)
175
+
176
+ datasets = []
177
+ for filepath in filepaths:
178
+ if "reddit_" in filepath: #Special dataset class for Reddit files
179
+ data_obj = RedditDataset(filepath)
180
+ else:
181
+ data_obj = Dataset(filepath)
182
+ datasets.append(iter(data_obj))
183
+
184
+ # Store if dataset is in a 2 col or 3 col format
185
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
186
+
187
+ while True:
188
+ texts_in_batch = set()
189
+ batch_format = None #2 vs 3 col format for this batch
190
+
191
+ #Add data from several sub datasets
192
+ for _ in range(args.datasets_per_batch):
193
+ valid_dataset = False #Check that datasets have the same 2/3 col format
194
+ while not valid_dataset:
195
+ data_idx = random.choice(dataset_indices)
196
+ if batch_format is None:
197
+ batch_format = num_cols[data_idx]
198
+ valid_dataset = True
199
+ else: #Check that this dataset has the same format
200
+ valid_dataset = (batch_format == num_cols[data_idx])
201
+
202
+ #Get data from this dataset
203
+ dataset = datasets[data_idx]
204
+ for _ in range(num_same_dataset):
205
+ for _ in range(args.nprocs):
206
+ batch_device = [] #A batch for one device
207
+ while len(batch_device) < args.batch_size:
208
+ sample = next(dataset)
209
+ in_batch = False
210
+ for text in sample:
211
+ if text in texts_in_batch:
212
+ in_batch = True
213
+ break
214
+
215
+ if not in_batch:
216
+ for text in sample:
217
+ texts_in_batch.add(text)
218
+ batch_device.append(sample)
219
+
220
+ queue.put(batch_device)
221
+
222
+
223
+ class RedditDataset:
224
+ """
225
+ A class that handles the reddit data files
226
+ """
227
+ def __init__(self, filepath):
228
+ self.filepath = filepath
229
+
230
+ def __iter__(self):
231
+ while True:
232
+ with gzip.open(self.filepath, "rt") as fIn:
233
+ for line in fIn:
234
+ data = json.loads(line)
235
+
236
+ if "response" in data and "context" in data:
237
+ yield [data["response"], data["context"]]
238
+
239
+ class Dataset:
240
+ """
241
+ A class that handles one dataset
242
+ """
243
+ def __init__(self, filepath):
244
+ self.filepath = filepath
245
+
246
+ def __iter__(self):
247
+ max_dataset_size = 10*1000*1000 #Cache small datasets in memory
248
+ dataset = []
249
+ data_format = None
250
+
251
+ while dataset is None or len(dataset) == 0:
252
+ with gzip.open(self.filepath, "rt") as fIn:
253
+ for line in fIn:
254
+ data = json.loads(line)
255
+ if isinstance(data, dict):
256
+ data = data['texts']
257
+
258
+ if data_format is None:
259
+ data_format = len(data)
260
+
261
+ #Ensure that all entries are of the same 2/3 col format
262
+ assert len(data) == data_format
263
+
264
+ if dataset is not None:
265
+ dataset.append(data)
266
+ if len(dataset) >= max_dataset_size:
267
+ dataset = None
268
+
269
+ yield data
270
+
271
+ # Data loaded. Now stream to the queue
272
+ # Shuffle for each epoch
273
+ while True:
274
+ random.shuffle(dataset)
275
+ for data in dataset:
276
+ yield data
277
+
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
283
+ parser.add_argument('--steps', type=int, default=2000)
284
+ parser.add_argument('--save_steps', type=int, default=10000)
285
+ parser.add_argument('--batch_size', type=int, default=64)
286
+ parser.add_argument('--max_length', type=int, default=128)
287
+ parser.add_argument('--nprocs', type=int, default=8)
288
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
289
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
290
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
291
+ parser.add_argument('data_config', help="A data_config.json file")
292
+ parser.add_argument('output')
293
+ args = parser.parse_args()
294
+
295
+ # Ensure global batch size is divisble by data_sample_size
296
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
297
+
298
+ logging.info("Output: "+args.output)
299
+ if os.path.exists(args.output):
300
+ print("Output folder already exists.")
301
+ input("Continue?")
302
+
303
+ # Write train script to output path
304
+ os.makedirs(args.output, exist_ok=True)
305
+
306
+ data_config_path = os.path.join(args.output, 'data_config.json')
307
+ copyfile(args.data_config, data_config_path)
308
+
309
+ train_script_path = os.path.join(args.output, 'train_script.py')
310
+ copyfile(__file__, train_script_path)
311
+ with open(train_script_path, 'a') as fOut:
312
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
313
+
314
+
315
+
316
+ #Load data config
317
+ with open(args.data_config) as fIn:
318
+ data_config = json.load(fIn)
319
+
320
+ queue = mp.Queue(maxsize=100*args.nprocs)
321
+
322
+ filepaths = []
323
+ dataset_indices = []
324
+ for idx, data in enumerate(data_config):
325
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
326
+ dataset_indices.extend([idx]*data['weight'])
327
+
328
+ # Start producer
329
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
330
+ p.start()
331
+
332
+ # Run training
333
+ print("Start processes:", args.nprocs)
334
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
335
+ print("Training done")
336
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
337
+ print("With 'pkill python' you can kill all remaining python processes")
338
+ p.kill()
339
+ exit()
340
+
341
+
342
+
343
+ # Script was called via:
344
+ #python train_many_data_files_v2.py --steps 1000000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased train_data_configs/all_datasets_v4.json output/all_datasets_v4_MiniLM-L6-H384-uncased-batch128
vocab.txt ADDED
The diff for this file is too large to render. See raw diff