#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import torch
# Import autocast unconditionally: the import itself does not require a CUDA
# device, while a conditional import would raise NameError when fp16=True on
# a CPU-only host.
from torch.cuda.amp import autocast
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from onnxruntime import SessionOptions, InferenceSession

from pyserini.encode import DocumentEncoder, QueryEncoder


class TctColBertDocumentEncoder(DocumentEncoder):
    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
        self.device = device
        self.onnx = False
        if model_name.endswith('onnx'):
            # Run an exported ONNX model with ONNX Runtime; strip the '.onnx'
            # suffix to recover the tokenizer name when none is given.
            options = SessionOptions()
            self.session = InferenceSession(model_name, options)
            self.onnx = True
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name[:-5])
        else:
            self.model = BertModel.from_pretrained(model_name)
            self.model.to(self.device)
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
        # TCT-ColBERT prepends a '[CLS] [D] ' marker to each passage; special
        # tokens are inserted manually, hence add_special_tokens=False below.
        if titles is not None:
            texts = [f'[CLS] [D] {title} {text}' for title, text in zip(titles, texts)]
        else:
            texts = ['[CLS] [D] ' + text for text in texts]
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding="longest",
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        if self.onnx:
            # ONNX Runtime expects numpy inputs, so convert before moving the
            # torch tensors to the target device for pooling.
            inputs_onnx = {name: np.atleast_2d(value) for name, value in inputs.items()}
            inputs.to(self.device)
            outputs, _ = self.session.run(None, inputs_onnx)
            outputs = torch.from_numpy(outputs).to(self.device)
            # Mean-pool over tokens, skipping the first four ('[CLS] [D]' marker).
            embeddings = self._mean_pooling(outputs[:, 4:, :], inputs['attention_mask'][:, 4:])
        else:
            inputs.to(self.device)
            # Inference only: no_grad avoids building the autograd graph.
            with torch.no_grad():
                if fp16:
                    with autocast():
                        outputs = self.model(**inputs)
                else:
                    outputs = self.model(**inputs)
            embeddings = self._mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
        return embeddings.detach().cpu().numpy()
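
# A minimal usage sketch (not part of the original module), assuming the
# 'castorini/tct_colbert-msmarco' checkpoint referenced in the Pyserini docs;
# any compatible TCT-ColBERT checkpoint or ONNX export can be substituted.
#
#   encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cpu')
#   vectors = encoder.encode(
#       texts=['The quick brown fox jumps over the lazy dog.'],
#       titles=['Example title'])
#   print(vectors.shape)  # (1, 768) for a BERT-base checkpoint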


class TctColBertQueryEncoder(QueryEncoder):
    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        self.device = device
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, query: str, **kwargs):
        max_length = 36  # hardcode for now
        # TCT-ColBERT pads each query with [MASK] tokens; truncation then keeps
        # exactly max_length tokens. Special tokens are inserted manually.
        inputs = self.tokenizer(
            '[CLS] [Q] ' + query + '[MASK]' * max_length,
            max_length=max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        inputs.to(self.device)
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state.detach().cpu().numpy()
        # Average over tokens, skipping the first four ('[CLS] [Q]' marker),
        # to produce a single query vector.
        return np.average(embeddings[:, 4:, :], axis=-2).flatten()
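
# A minimal usage sketch (not part of the original module): encoding a single
# query with the same assumed checkpoint as above.
#
#   query_encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
#   q_vec = query_encoder.encode('what is the capital of france')
#   print(q_vec.shape)  # (768,) for a BERT-base checkpoint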