import numpy as np
import scipy.sparse
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from pyserini.encode import QueryEncoder


class SlimQueryEncoder(QueryEncoder):
    """Encodes queries for SLIM (Sparsified Late Interaction) retrieval by expanding
    each query token into a sparse set of vocabulary "experts" via the MLM head."""

    def __init__(self, model_name_or_path, tokenizer_name=None, fusion_weight=.99, device='cpu'):
        self.device = device
        self.fusion_weight = fusion_weight
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Map token ids back to vocabulary strings for building term-weight dicts.
        self.reverse_vocab = {v: k for k, v in self.tokenizer.vocab.items()}

    def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
        inputs = self.tokenizer(
            [text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        ).to(self.device)  # move input tensors onto the same device as the model
        outputs = self.model(**inputs, return_dict=True)
        # Drop the leading [CLS] position from both the attention mask and the logits.
        attention_mask = inputs["attention_mask"][:, 1:]
        logits = outputs.logits[:, 1:, :]

        # Log-saturated activation over the MLM logits, zeroed out at padded positions.
        full_router_repr = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
        # Keep only the top-k "experts" (vocabulary entries) per token position.
        expert_weights, expert_ids = torch.topk(full_router_repr, dim=2, k=topk)
        min_expert_weight = torch.min(expert_weights, -1, True)[0]
        sparse_expert_weights = torch.where(full_router_repr >= min_expert_weight, full_router_repr, 0)
        return self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(),
                                            sparse_expert_weights.cpu(), attention_mask.cpu(),
                                            return_sparse)[0]

    def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_sparse_expert_weights,
                                batch_attention, return_sparse):
        to_return = []
        for batch_id, sparse_expert_weights in enumerate(batch_sparse_expert_weights):
            # Full (token x vocabulary) sparse matrix, returned only when return_sparse is requested.
            tok_vector = scipy.sparse.csr_matrix(sparse_expert_weights.detach().numpy())
            # upper_vector accumulates all top-k experts per position; lower_vector keeps
            # only the single strongest expert per position.
            upper_vector, lower_vector = {}, {}
            for position, (expert_topk_ids, expert_topk_weights, attention_score) in enumerate(
                    zip(batch_expert_ids[batch_id], batch_expert_weights[batch_id], batch_attention[batch_id])):
                if attention_score > 0:
                    max_term, max_weight = None, 0
                    for expert_id, expert_weight in zip(expert_topk_ids, expert_topk_weights):
                        if expert_weight > 0:
                            term, weight = self.reverse_vocab[expert_id.item()], expert_weight.item()
                            upper_vector[term] = upper_vector.get(term, 0) + weight
                            if weight > max_weight:
                                max_term, max_weight = term, weight
                    if max_term is not None:
                        lower_vector[max_term] = lower_vector.get(max_term, 0) + max_weight
            # Fuse the two bounds into a single term-weight dict.
            fusion_vector = {}
            for term, weight in upper_vector.items():
                fusion_vector[term] = self.fusion_weight * weight + (1 - self.fusion_weight) * lower_vector.get(term, 0)
            if return_sparse:
                to_return.append((fusion_vector, tok_vector))
            else:
                to_return.append(fusion_vector)
        return to_return
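

# Minimal usage sketch (not part of the encoder itself): it shows how the class above
# might be exercised end to end. The checkpoint name below is an assumption for
# illustration only; substitute any SLIM-style checkpoint with an MLM head.
if __name__ == "__main__":
    encoder = SlimQueryEncoder("castorini/slim-msmarco-passage", device="cpu")  # hypothetical checkpoint name
    # encode() returns a {term: weight} dict; pass return_sparse=True to also get the
    # per-token CSR matrix used for late interaction.
    fusion_vector = encoder.encode("what is a lobster roll", topk=20)
    print(sorted(fusion_vector.items(), key=lambda kv: -kv[1])[:10])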