# gte-multilingual-base / scripts/gte_embedding.py
import logging
import os
from collections import defaultdict
from typing import List, Tuple, Union

import numpy as np
import torch
from huggingface_hub import snapshot_download
from torch import nn, Tensor
from transformers import AutoModel, AutoTokenizer
from transformers.utils import is_torch_npu_available

logger = logging.getLogger(__name__)
class GTEEmbeddidng(nn.Module):
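    """GTE hybrid embedding model producing dense vectors and sparse token weights.

    Dense embeddings are pooled from the encoder's last hidden states ('cls' or 'mean');
    sparse (lexical) token weights come from a single linear head on the same hidden states.
    """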
def __init__(self,
model_name: str = None,
normalized: bool = True,
pooling_method: str = 'cls',
use_fp16: bool = True,
device: str = None
):
super().__init__()
self.load_model(model_name)
self.vocab_size = self.model.config.vocab_size
self.normalized = normalized
self.pooling_method = pooling_method
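        # Resolve the device: honour an explicit argument, otherwise prefer CUDA,
        # then MPS, then NPU, falling back to CPU (where fp16 is disabled).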
if device:
self.device = torch.device(device)
else:
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
elif is_torch_npu_available():
self.device = torch.device("npu")
else:
self.device = torch.device("cpu")
use_fp16 = False
self.model.to(self.device)
self.sparse_linear.to(self.device)
if use_fp16:
self.model.half()
self.sparse_linear.half()
def load_model(self, model_name):
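        """Load the encoder, tokenizer and sparse head from a local path or Hugging Face repo id."""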
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model.eval()
if os.path.exists(os.path.join(model_name, 'sparse_linear.pt')):
            logger.info('Loading existing sparse_linear weights')
self.load_pooler(model_dir=model_name)
else:
            logger.warning('The parameters of the sparse linear layer were not found')
def dense_embedding(self, hidden_state, mask):
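        """Pool the last hidden states into one dense vector per sequence."""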
if self.pooling_method == 'cls':
return hidden_state[:, 0]
elif self.pooling_method == 'mean':
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(f'Unsupported pooling method: {self.pooling_method}')
def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
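        """Map each token's hidden state to a non-negative scalar relevance weight."""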
token_weights = torch.relu(self.sparse_linear(hidden_state))
return token_weights
def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
        # convert token weights into a {token: max_weight} dict, skipping special tokens
result = defaultdict(int)
unused_tokens = set([self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
self.tokenizer.unk_token_id])
# token_weights = np.ceil(token_weights * 100)
for w, idx in zip(token_weights, input_ids):
if idx not in unused_tokens and w > 0:
token = self.tokenizer.decode([int(idx)])
if w > result[token]:
result[token] = w
return result
@torch.no_grad()
    def encode(self,
               texts: Union[str, List[str]] = None,
               dimension: int = None,
               max_length: int = 8192,
               batch_size: int = 16,
               return_dense: bool = True,
               return_sparse: bool = False):
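        """Encode texts in batches; return dense embeddings and/or per-token sparse weights."""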
if dimension is None:
dimension = self.model.config.hidden_size
if isinstance(texts, str):
texts = [texts]
num_texts = len(texts)
all_dense_vecs = []
all_token_weights = []
        for i in range(0, num_texts, batch_size):
            batch = texts[i: i + batch_size]
            results = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse)
            if return_dense:
                all_dense_vecs.append(results['dense_embeddings'])
            if return_sparse:
                all_token_weights.extend(results['token_weights'])
        all_dense_vecs = torch.cat(all_dense_vecs, dim=0) if return_dense else None
return {
"dense_embeddings": all_dense_vecs,
"token_weights": all_token_weights
}
@torch.no_grad()
    def _encode(self,
                texts: List[str] = None,
                dimension: int = None,
                max_length: int = 1024,
                batch_size: int = 16,
                return_dense: bool = True,
                return_sparse: bool = False):
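        """Tokenize a single batch, run the encoder and build the requested dense/sparse outputs."""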
text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state
output = {}
if return_dense:
dense_vecs = self.dense_embedding(last_hidden_state, text_input['attention_mask'])
dense_vecs = dense_vecs[:, :dimension]
if self.normalized:
dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
output['dense_embeddings'] = dense_vecs
if return_sparse:
token_weights = self.sparse_embedding(last_hidden_state, text_input['input_ids']).squeeze(-1)
token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
text_input['input_ids'].cpu().numpy().tolist()))
output['token_weights'] = token_weights
return output
def load_pooler(self, model_dir):
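        """Load the sparse token-weight head from `sparse_linear.pt` in the given directory."""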
sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
self.sparse_linear.load_state_dict(sparse_state_dict)
def _compute_sparse_scores(self, embs1, embs2):
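        """Lexical overlap score: sum of weight products over tokens shared by both texts."""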
scores = 0
for token, weight in embs1.items():
if token in embs2:
scores += weight * embs2[token]
return scores
def compute_sparse_scores(self, embs1, embs2):
scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
return np.array(scores)
def compute_dense_scores(self, embs1, embs2):
scores = torch.sum(embs1*embs2, dim=-1).cpu().detach().numpy()
return scores
@torch.no_grad()
def compute_scores(self,
text_pairs: List[Tuple[str, str]],
dimension: int = None,
max_length: int = 1024,
batch_size: int = 16,
dense_weight=1.0,
sparse_weight=0.1):
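        """Score each pair as dense_weight * dense similarity + sparse_weight * sparse overlap."""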
text1_list = [text_pair[0] for text_pair in text_pairs]
text2_list = [text_pair[1] for text_pair in text_pairs]
embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings']) * dense_weight + \
self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
scores = scores.tolist()
return scores
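

# Example usage sketch. The repo id below is an assumption; any local path or
# Hugging Face repo that ships these weights together with `sparse_linear.pt` works.
if __name__ == '__main__':
    model = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
    docs = [
        "What is the capital of China?",
        "The capital of China is Beijing.",
    ]
    # Dense vectors plus per-token sparse weights for a batch of texts.
    embs = model.encode(docs, return_dense=True, return_sparse=True)
    print(embs['dense_embeddings'].shape)
    print(embs['token_weights'][0])
    # Hybrid relevance score for (query, document) pairs.
    print(model.compute_scores([(docs[0], docs[1])]))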