zyznull committed on
Commit
0119b51
1 Parent(s): 7ee2ea3

Create scripts/gte_embedding.py

Files changed (1)
  1. scripts/gte_embedding.py +190 -0
scripts/gte_embedding.py ADDED
@@ -0,0 +1,190 @@
+ import logging
+ import os
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from huggingface_hub import snapshot_download
+ from torch import nn
+ from transformers import AutoModel, AutoTokenizer, is_torch_npu_available
+
+ logger = logging.getLogger(__name__)
+
+
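+ # GTEEmbedding wraps a GTE-style Transformer encoder and exposes three things:
+ # dense embeddings (CLS or mean pooling, optionally truncated and normalized),
+ # sparse per-token weights from a small linear head, and a hybrid score that
+ # mixes the two.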
+ class GTEEmbedding(nn.Module):
+     def __init__(self,
+                  model_name: Optional[str] = None,
+                  normalized: bool = True,
+                  pooling_method: str = 'cls',
+                  use_fp16: bool = True,
+                  device: Optional[str] = None
+                  ):
+         super().__init__()
+         self.load_model(model_name)
+         self.vocab_size = self.model.config.vocab_size
+         self.normalized = normalized
+         self.pooling_method = pooling_method
+         if device:
+             self.device = torch.device(device)
+         else:
+             if torch.cuda.is_available():
+                 self.device = torch.device("cuda")
+             elif torch.backends.mps.is_available():
+                 self.device = torch.device("mps")
+             elif is_torch_npu_available():
+                 self.device = torch.device("npu")
+             else:
+                 self.device = torch.device("cpu")
+                 use_fp16 = False  # half precision is not useful on CPU
+         self.model.to(self.device)
+         self.sparse_linear.to(self.device)
+         if use_fp16:
+             self.model.half()
+             self.sparse_linear.half()
+
+     def load_model(self, model_name):
+         if not os.path.exists(model_name):
+             # resolve a hub repo id to a local snapshot
+             cache_folder = os.getenv('HF_HUB_CACHE')
+             model_name = snapshot_download(repo_id=model_name,
+                                            cache_dir=cache_folder,
+                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+
+         self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model.eval()
+         if os.path.exists(os.path.join(model_name, 'sparse_linear.pt')):
+             logger.info('loading existing sparse_linear')
+             self.load_pooler(model_dir=model_name)
+         else:
+             logger.warning('The parameters of sparse_linear were not found; sparse weights will come from a randomly initialized layer')
+
+     def dense_embedding(self, hidden_state, mask):
+         if self.pooling_method == 'cls':
+             # use the [CLS] token representation
+             return hidden_state[:, 0]
+         elif self.pooling_method == 'mean':
+             # mean over non-padding tokens only
+             s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
+             d = mask.sum(dim=1, keepdim=True).float()
+             return s / d
+         raise ValueError(f'Unsupported pooling method: {self.pooling_method}')
+
+     def sparse_embedding(self, hidden_state):
+         # one non-negative weight per token: ReLU over a linear projection
+         token_weights = torch.relu(self.sparse_linear(hidden_state))
+         return token_weights
+
+     def _process_token_weights(self, token_weights: List[float], input_ids: List[int]) -> Dict[str, float]:
+         # convert per-token weights to a {token: max_weight} dict,
+         # skipping special tokens and non-positive weights
+         result = defaultdict(float)
+         unused_tokens = {self.tokenizer.cls_token_id, self.tokenizer.eos_token_id,
+                          self.tokenizer.pad_token_id, self.tokenizer.unk_token_id}
+         for w, idx in zip(token_weights, input_ids):
+             if idx not in unused_tokens and w > 0:
+                 token = self.tokenizer.decode([int(idx)])
+                 if w > result[token]:
+                     result[token] = w
+         return result
+
+     @torch.no_grad()
+     def encode(self,
+                texts: Union[str, List[str]] = None,
+                dimension: int = None,
+                max_length: int = 8192,
+                batch_size: int = 16,
+                return_dense: bool = True,
+                return_sparse: bool = False):
+         if dimension is None:
+             dimension = self.model.config.hidden_size
+         if isinstance(texts, str):
+             texts = [texts]
+         num_texts = len(texts)
+         all_dense_vecs = []
+         all_token_weights = []
+         for i in range(0, num_texts, batch_size):
+             batch = texts[i: i + batch_size]
+             results = self._encode(batch, dimension, max_length, return_dense, return_sparse)
+             if return_dense:
+                 all_dense_vecs.append(results['dense_embeddings'])
+             if return_sparse:
+                 all_token_weights.extend(results['token_weights'])
+         # torch.cat on an empty list raises, so only concatenate when dense
+         # vectors were collected
+         all_dense_vecs = torch.cat(all_dense_vecs, dim=0) if return_dense else None
+         return {
+             "dense_embeddings": all_dense_vecs,
+             "token_weights": all_token_weights
+         }
+
+     @torch.no_grad()
+     def _encode(self,
+                 texts: List[str] = None,
+                 dimension: int = None,
+                 max_length: int = 1024,
+                 return_dense: bool = True,
+                 return_sparse: bool = False):
+
+         text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
+         text_input = {k: v.to(self.model.device) for k, v in text_input.items()}
+         last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state
+
+         output = {}
+         if return_dense:
+             dense_vecs = self.dense_embedding(last_hidden_state, text_input['attention_mask'])
+             # truncate to the requested dimension before (optionally) normalizing
+             dense_vecs = dense_vecs[:, :dimension]
+             if self.normalized:
+                 dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
+             output['dense_embeddings'] = dense_vecs
+         if return_sparse:
+             token_weights = self.sparse_embedding(last_hidden_state).squeeze(-1)
+             token_weights = list(map(self._process_token_weights,
+                                      token_weights.detach().cpu().numpy().tolist(),
+                                      text_input['input_ids'].cpu().numpy().tolist()))
+             output['token_weights'] = token_weights
+
+         return output
+
+     def load_pooler(self, model_dir):
+         sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
+         self.sparse_linear.load_state_dict(sparse_state_dict)
+
+     def _compute_sparse_scores(self, embs1, embs2):
+         # dot product over the tokens the two sparse embeddings share
+         scores = 0.0
+         for token, weight in embs1.items():
+             if token in embs2:
+                 scores += weight * embs2[token]
+         return scores
+
+     def compute_sparse_scores(self, embs1, embs2):
+         scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
+         return np.array(scores)
+
+     def compute_dense_scores(self, embs1, embs2):
+         # dot product; equals cosine similarity when embeddings are normalized
+         scores = torch.sum(embs1 * embs2, dim=-1).cpu().detach().numpy()
+         return scores
+
+     @torch.no_grad()
+     def compute_scores(self,
+                        text_pairs: List[Tuple[str, str]],
+                        dimension: int = None,
+                        max_length: int = 1024,
+                        batch_size: int = 16,
+                        dense_weight=1.0,
+                        sparse_weight=0.1):
+         # hybrid score: weighted sum of dense and sparse similarities
+         text1_list = [text_pair[0] for text_pair in text_pairs]
+         text2_list = [text_pair[1] for text_pair in text_pairs]
+         embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
+         embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
+         scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings']) * dense_weight + \
+                  self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
+         scores = scores.tolist()
+         return scores
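+
+
+ if __name__ == '__main__':
+     # Minimal usage sketch. The checkpoint id below is an assumption; any
+     # GTE-style repo that ships this script and a sparse_linear.pt should work.
+     model = GTEEmbedding('Alibaba-NLP/gte-multilingual-base')
+     docs = ['what is the capital of China?',
+             'Beijing is the capital of China.']
+     embs = model.encode(docs, return_dense=True, return_sparse=True)
+     print(embs['dense_embeddings'].shape, embs['token_weights'])
+     # hybrid relevance score for a (query, passage) pair
+     print(model.compute_scores([(docs[0], docs[1])]))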