File size: 6,326 Bytes
af8c4b6
 
 
0119b51
 
af8c4b6
0119b51
af8c4b6
 
 
 
0119b51
 
af8c4b6
0119b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af8c4b6
0119b51
af8c4b6
 
 
 
 
0119b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af8c4b6
0119b51
 
 
 
 
 
 
 
af8c4b6
0119b51
 
 
af8c4b6
0119b51
 
 
 
af8c4b6
0119b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af8c4b6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# coding=utf-8
# Copyright 2024 The GTE Team Authors and Alibaba Group.
# Licensed under the Apache License, Version 2.0 (the "License");

from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers.utils import is_torch_npu_available


class GTEEmbeddidng(torch.nn.Module):
    def __init__(self,
                 model_name: str = None,
                 normalized: bool = True,
                 use_fp16: bool = True,
                 device: str = None
                ):
        super().__init__()
        self.normalized = normalized
        if device:
            self.device = torch.device(device)
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        self.use_fp16 = use_fp16
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None
        )
        self.vocab_size = self.model.config.vocab_size
        self.model.to(self.device)

    def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
        # conver to dict
        result = defaultdict(int)
        unused_tokens = set([self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
                             self.tokenizer.unk_token_id])
        # token_weights = np.ceil(token_weights * 100)
        for w, idx in zip(token_weights, input_ids):
            if idx not in unused_tokens and w > 0:
                token = self.tokenizer.decode([int(idx)])
                if w > result[token]:
                    result[token] = w
        return result

    @torch.no_grad()
    def encode(self,
               texts: None,
               dimension: int = None,
               max_length: int = 8192,
               batch_size: int = 16,
               return_dense: bool = True,
               return_sparse: bool = False):
        if dimension is None:
            dimension = self.model.config.hidden_size
        if isinstance(texts, str):
            texts = [texts]
        num_texts = len(texts)
        all_dense_vecs = []
        all_token_weights = []
        for n, i in enumerate(range(0, num_texts, batch_size)):
            batch = texts[i: i + batch_size]
            resulst = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse)
            if return_dense:
                all_dense_vecs.append(resulst['dense_embeddings'])
            if return_sparse:
                all_token_weights.extend(resulst['token_weights'])
        all_dense_vecs = torch.cat(all_dense_vecs, dim=0)
        return {
            "dense_embeddings": all_dense_vecs,
            "token_weights": all_token_weights 
        }

    @torch.no_grad()
    def _encode(self,
                texts: Dict[str, torch.Tensor] = None,
                dimension: int = None,
                max_length: int = 1024,
                batch_size: int = 16,
                return_dense: bool = True,
                return_sparse: bool = False):

        text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
        text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
        model_out = self.model(**text_input, return_dict=True)

        output = {}
        if return_dense:
            dense_vecs = model_out.last_hidden_state[:, 0, :dimension]
            if self.normalized:
                dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
            output['dense_embeddings'] = dense_vecs
        if return_sparse:
            token_weights = torch.relu(model_out.logits).squeeze(-1)
            token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
                                                    text_input['input_ids'].cpu().numpy().tolist()))
            output['token_weights'] = token_weights

        return output

    def _compute_sparse_scores(self, embs1, embs2):
        scores = 0
        for token, weight in embs1.items():
            if token in embs2:
                scores += weight * embs2[token]
        return scores

    def compute_sparse_scores(self, embs1, embs2):
        scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
        return np.array(scores)

    def compute_dense_scores(self, embs1, embs2):
        scores = torch.sum(embs1*embs2, dim=-1).cpu().detach().numpy()
        return scores

    @torch.no_grad()
    def compute_scores(self, 
        text_pairs: List[Tuple[str, str]], 
        dimension: int = None,
        max_length: int = 1024,
        batch_size: int = 16,
        dense_weight=1.0,
        sparse_weight=0.1):
        text1_list = [text_pair[0] for text_pair in text_pairs]
        text2_list = [text_pair[1] for text_pair in text_pairs]
        embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings']) * dense_weight + \
            self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
        scores = scores.tolist()
        return scores


if __name__ == '__main__':
    gte = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
    docs =  [
        "黑龙江离俄罗斯很近",
        "哈尔滨是中国黑龙江省的省会,位于中国东北",
        "you are the hero"
    ]
    print('docs', docs)
    embs = gte.encode(docs, return_dense=True,return_sparse=True)
    print('dense vecs', embs['dense_embeddings'])
    print('sparse vecs', embs['token_weights'])