---
license: cc-by-nc-sa-4.0
---

Our best attempt at reproducing [RankT5 Enc-Softmax](https://arxiv.org/pdf/2210.10634.pdf), with a few important differences:

1. We use a SPLADE first stage to mine negatives, whereas the paper uses GTR.
2. We train with PyTorch, whereas the paper uses Flax.
3. ~~We use the original t5-3b, whereas the paper uses Flan-T5-3b~~ Correction: the paper also uses t5-3b.
4. The head is not exactly the same. The original paper uses a single dense layer, while we use Linear->LayerNorm->Linear and, by mistake, omit a nonlinearity. Fixing this should improve our performance, since we currently add extra layers without using them correctly.
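
To make point 4 concrete, here is a sketch of the three head variants as standalone modules that map the encoder's first-token embedding (size `d_model`) to one score. `ReproHead` mirrors the head in the loading code of this card, with the same attribute names. `PaperHead` and `FixedHead` are hypothetical illustrations: the paper's head is only described as a dense layer, and the activation in `FixedHead` (GELU, placed before the LayerNorm as in BERT-style prediction heads) is our assumption, not something we trained.

```python
import torch
from torch import nn
import torch.nn.functional as F


class PaperHead(nn.Module):
    """Hypothetical: a single dense layer, as the RankT5 paper describes."""

    def __init__(self, d_model: int):
        super().__init__()
        self.dense = nn.Linear(d_model, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.dense(h)  # (batch, d_model) -> (batch, 1)


class ReproHead(nn.Module):
    """Head of this checkpoint: Linear -> LayerNorm -> Linear, no activation."""

    def __init__(self, d_model: int):
        super().__init__()
        self.first_transform = nn.Linear(d_model, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-12)
        self.linear = nn.Linear(d_model, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.linear(self.layer_norm(self.first_transform(h)))


class FixedHead(ReproHead):
    """Hypothetical fix: same layers, plus a GELU after the first Linear.

    The choice of GELU and its position are assumptions; this head would
    need to be retrained, not loaded from this checkpoint.
    """

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.linear(self.layer_norm(F.gelu(self.first_transform(h))))
```

Because `ReproHead` uses the same attribute names as `T5EncoderRerank`, its parameter names (`first_transform.*`, `layer_norm.*`, `linear.*`) match the head entries of this checkpoint's state dict.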

This leads to slightly worse performance than the paper (42.8 vs 43.?) and seemingly slightly worse results on BEIR as well.

To use this model, first clone the Hugging Face repository (Git LFS is required to download the weights):

```
git lfs install
git clone https://huggingface.co/naver/trecdl22-crossencoder-rankT53b-repro
```

Then we suggest loading it as follows:

```
import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderRerank(torch.nn.Module):
    """T5 encoder with a scoring head on the first token's hidden state."""

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        # Head used in this checkpoint: Linear -> LayerNorm -> Linear.
        # There is no nonlinearity between the layers (see point 4 above).
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Hidden state of the first input token, shape (batch, d_model)
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        # One relevance score per input, shape (batch, 1)
        logits = self.linear(layer_normed)
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
model = T5EncoderRerank(original_model)
# Load the fine-tuned weights on CPU first, then move the model to the available device
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()  # disable dropout for inference
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("Loaded")
```
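
The snippet above only loads the model. A minimal sketch of scoring passages for a query might look like the following. `rerank` is a hypothetical helper, not part of the repository, and the `Query: ... Document: ...` input template is an assumption taken from the RankT5 paper's input format; we have not confirmed it matches how this checkpoint was trained, so it is exposed as a parameter.

```python
from typing import List, Tuple

import torch


def rerank(model, tokenizer, query: str, passages: List[str], device,
           template: str = "Query: {q} Document: {d}",
           max_length: int = 512) -> List[Tuple[int, float]]:
    """Hypothetical helper: score each (query, passage) pair and sort by score.

    `template` is an assumed input format; adjust it if the checkpoint
    was trained with a different one.
    """
    texts = [template.format(q=query, d=p) for p in passages]
    batch = tokenizer(texts, padding=True, truncation=True,
                      max_length=max_length, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        # logits has shape (num_passages, 1); drop the last dimension
        scores = model(**batch).logits.squeeze(-1).tolist()
    # (passage index, score) pairs, highest score first
    return sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
```

With the objects loaded above, a call would look like `rerank(model, tokenizer, "what is splade", ["passage one", "passage two"], device)`. For long candidate lists, splitting `passages` into smaller batches keeps memory usage of the 3B encoder manageable.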