Upload 4 files
Browse files- README.md +80 -0
- config.json +35 -0
- modeling_vmsst.py +19 -0
- pytorch_model.bin +3 -0
README.md
CHANGED
@@ -1,3 +1,83 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
pipeline_tag: sentence-similarity
|
4 |
+
tags:
|
5 |
+
- cross-lingual
|
6 |
+
- multilingual
|
7 |
+
- question-answering
|
8 |
+
- retrieval
|
9 |
+
- sentence-similarity
|
10 |
+
- variational
|
11 |
---
|
12 |
+
|
13 |
+
# VMSST
|
14 |
+
|
15 |
+
Published as a long paper at ACL 2023.
|
16 |
+
|
17 |
+
Contrastive learning has been successfully used for retrieval of semantically aligned sentences, but it often requires large batch sizes and carefully engineered heuristics to work well. In this paper, we instead propose a generative model for learning multilingual text embeddings which can be used to retrieve or score sentence pairs. Our model operates on parallel data in N languages and, through an approximation we introduce, efficiently encourages source separation in this multilingual setting, separating semantic information that is shared between translations from stylistic or language-specific variation. We show careful large-scale comparisons between contrastive and generation-based approaches for learning multilingual text embeddings, a comparison that has not been done to the best of our knowledge despite the popularity of these approaches. We evaluate this method on a suite of tasks including semantic similarity, bitext mining, and cross-lingual question retrieval––the last of which we introduce in this paper. Overall, our Variational Multilingual Source-Separation Transformer (VMSST) model outperforms both a strong contrastive and generative baseline on these tasks.
|
18 |
+
|
19 |
+
## Checkpoints
|
20 |
+
|
21 |
+
T5X (Jax): https://storage.googleapis.com/gresearch/vmsst/vmsst-large-2048-t5x.zip
|
22 |
+
|
23 |
+
PyTorch: https://storage.googleapis.com/gresearch/vmsst/vmsst-large-2048-pytorch.zip
|
24 |
+
|
25 |
+
## Usage
|
26 |
+
|
27 |
+
### Installation
|
28 |
+
|
29 |
+
1. Clone the following repository from Google Research.
|
30 |
+
|
31 |
+
```
|
32 |
+
git clone -b master --single-branch https://github.com/google-research/google-research.git
|
33 |
+
```
|
34 |
+
|
35 |
+
2. Make sure `google-research` is the current directory:
|
36 |
+
|
37 |
+
```
|
38 |
+
cd google-research/vmsst
|
39 |
+
```
|
40 |
+
|
41 |
+
3. Create and activate a new virtualenv:
|
42 |
+
|
43 |
+
```
|
44 |
+
python -m venv vmsst
|
45 |
+
source vmsst/bin/activate
|
46 |
+
```
|
47 |
+
|
48 |
+
4. This repository is tested on Python 3.10+. Install required packages:
|
49 |
+
|
50 |
+
```
|
51 |
+
pip install -r requirements.txt
|
52 |
+
```
|
53 |
+
|
54 |
+
### Test
|
55 |
+
|
56 |
+
To test that the checkpoint and installation are working as intended, run:
|
57 |
+
|
58 |
+
bash run.sh
|
59 |
+
|
60 |
+
The expected cosine similarity scores for the three sentences pairs are:
|
61 |
+
|
62 |
+
0.2573888301849365, 0.1563197821378708, and 0.28531330823898315.
|
63 |
+
|
64 |
+
### Inference
|
65 |
+
|
66 |
+
To embed a list of sentences:
|
67 |
+
|
68 |
+
python score_sentence_pairs.py --sentence_pair_file test_data/test_sentence_pairs.tsv
|
69 |
+
|
70 |
+
To score a list of sentence pairs:
|
71 |
+
|
72 |
+
python embed_sentences.py --sentence_file test_data/test_sentences.txt
|
73 |
+
|
74 |
+
## Citation
|
75 |
+
|
76 |
+
If you use our code or models your work please cite:
|
77 |
+
|
78 |
+
@article{wieting2022beyond,
|
79 |
+
title={Beyond Contrastive Learning: A Variational Generative Model for Multilingual Retrieval},
|
80 |
+
author={Wieting, John and Clark, Jonathan H and Cohen, William W and Neubig, Graham and Berg-Kirkpatrick, Taylor},
|
81 |
+
journal={arXiv preprint arXiv:2212.10726},
|
82 |
+
year={2022}
|
83 |
+
}
|
config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "jwieting/vmsst",
|
3 |
+
"architectures": [
|
4 |
+
"MT5EncoderWithProjection"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoModel": "jwieting/vmsst--modeling_vmsst.MT5EncoderWithProjection"
|
8 |
+
},
|
9 |
+
"d_ff": 2816,
|
10 |
+
"d_kv": 64,
|
11 |
+
"d_model": 1024,
|
12 |
+
"decoder_start_token_id": 0,
|
13 |
+
"dense_act_fn": "gelu_new",
|
14 |
+
"dropout_rate": 0.1,
|
15 |
+
"eos_token_id": 1,
|
16 |
+
"feed_forward_proj": "gated-gelu",
|
17 |
+
"initializer_factor": 1.0,
|
18 |
+
"is_encoder_decoder": true,
|
19 |
+
"is_gated_act": true,
|
20 |
+
"layer_norm_epsilon": 1e-06,
|
21 |
+
"model_type": "mt5",
|
22 |
+
"num_decoder_layers": 24,
|
23 |
+
"num_heads": 16,
|
24 |
+
"num_layers": 24,
|
25 |
+
"output_past": true,
|
26 |
+
"pad_token_id": 0,
|
27 |
+
"relative_attention_max_distance": 128,
|
28 |
+
"relative_attention_num_buckets": 32,
|
29 |
+
"tie_word_embeddings": false,
|
30 |
+
"tokenizer_class": "T5Tokenizer",
|
31 |
+
"torch_dtype": "float32",
|
32 |
+
"transformers_version": "4.30.2",
|
33 |
+
"use_cache": true,
|
34 |
+
"vocab_size": 250112
|
35 |
+
}
|
modeling_vmsst.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import tqdm
|
3 |
+
from torch import nn
|
4 |
+
from transformers import MT5EncoderModel, MT5PreTrainedModel
|
5 |
+
|
6 |
+
class MT5EncoderWithProjection(MT5PreTrainedModel):
|
7 |
+
def __init__(self, config):
|
8 |
+
super().__init__(config)
|
9 |
+
self.config = config
|
10 |
+
self.mt5_encoder = MT5EncoderModel(config)
|
11 |
+
self.projection = nn.Linear(config.d_model, config.d_model, bias=False)
|
12 |
+
self.post_init()
|
13 |
+
|
14 |
+
def forward(self, **input_args):
|
15 |
+
hidden_states = self.mt5_encoder(**input_args).last_hidden_state
|
16 |
+
mask = input_args['attention_mask']
|
17 |
+
batch_embeddings = torch.sum(hidden_states * mask[:, :, None], dim=1) / torch.sum(mask, dim=1)[:, None]
|
18 |
+
batch_embeddings = self.projection(batch_embeddings)
|
19 |
+
return batch_embeddings
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed70bb240affcc90a945b5905dc643778806ecf9e3c1ff6542de24fa70056228
|
3 |
+
size 2262056637
|