import json

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import TopPLogitsWarper, LogitsProcessorList

# Hypothetical placeholder -- point this at your local AquilaChat checkpoint.
model_path = 'BAAI/AquilaChat-7B'

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = 'left'  # left padding keeps every prompt's last token aligned
tokenizer.pad_token = tokenizer.unk_token

# Assumption: Aquila's modeling code ships with the checkpoint and is loaded via
# trust_remote_code, since transformers does not export an AquilaForCausalLM class.
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
device = torch.device('cuda')
model.to(device)

# cyg_conversation supplies Aquila's chat prompt template.
from cyg_conversation import default_conversation

conv = default_conversation.copy()
contexts = json.load(open('code_text_2.json', encoding='utf-8'))

question = "请解释这段程序的功能:"  # "Please explain what this program does:"
batch = []
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
batch.append(conv.get_prompt())  # batch[0] is the context-free (unconditional) prompt

# batch[1:]: one prompt per context, each carrying the same question.
for context in contexts:
    conv1 = default_conversation.copy()
    conv1.append_message(conv1.roles[0], context + question)
    conv1.append_message(conv1.roles[1], None)
    batch.append(conv1.get_prompt())

print('Context length distribution:', [len(text) for text in batch])
print('Total context length:', sum(len(text) for text in batch))


processors = LogitsProcessorList()
processors.append(TopPLogitsWarper(0.95))


@torch.inference_mode()
def generate(max_tokens):
    """Naive Bayes-based Context Extension (NBCE) demo code."""
    inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    print('input_ids', input_ids.shape)
    past_key_values = None
    n = input_ids.shape[0]  # 1 unconditional prompt + len(contexts) conditional prompts

    for i in range(max_tokens):
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                        use_cache=True,
                        past_key_values=past_key_values)
        past_key_values = outputs.past_key_values

        # ===== NBCE core =====
        beta, eta = 0.25, 0.1
        logits = outputs.logits[:, -1]
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)  # renormalize to log-probabilities
        logits = processors(input_ids, logits)
        # Pick the most confident (lowest-entropy) context; eta is a small
        # bonus for the previously chosen context, damping flip-flopping.
        entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
        if i > 0:
            entropy[k] -= eta
        k = entropy[1:].argmin() + 1  # row 0 is the unconditional prompt
        logits_max = logits[k]
        logits_uncond = logits[0]
        logits_merged = (1 + beta) * logits_max - beta * logits_uncond
        logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
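
        # Sketch of the reasoning behind the merge (assuming the contexts are
        # conditionally independent given the generation, i.e. Naive Bayes):
        #   log p(T|S_1,...,S_n) = sum_k log p(T|S_k) - (n-1) * log p(T) + const.
        # Replacing the sum with a single pooled estimate (here the
        # lowest-entropy context) and making the prior weight tunable gives
        #   (1 + beta) * logits_pooled - beta * logits_uncond,
        # which is logits_merged above; beta = 0 recovers plain conditioning.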

        # With tau this small, softmax(logits / tau) is nearly one-hot, so
        # sampling is effectively an argmax over the merged distribution.
        tau = 0.01
        probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
        next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
        if next_tokens[0] == tokenizer.eos_token_id:
            break

        ret = tokenizer.batch_decode(next_tokens)
        print(ret[0], flush=True, end='')

        # Feed the single merged token back to every stream; thanks to the
        # KV cache, only this new token is passed on the next step.
        input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
        attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
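

# Illustration only (a minimal sketch, not part of the original script): the
# NBCE merge above, factored into a standalone function over an (n, vocab)
# tensor of log-probabilities whose row 0 is the unconditional prompt. The
# eta hysteresis term is omitted here, since it needs cross-step state.
def nbce_merge(logits, beta=0.25):
    entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
    k = entropy[1:].argmin() + 1  # most confident context
    merged = (1 + beta) * logits[k] - beta * logits[0]
    return torch.where(logits[0] > -100, merged, logits[k])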


if __name__ == '__main__':
    generate(1000)


"""
========= Sample output =========

1. What share does China hold in the National Grid Corporation of the Philippines?
A: State Grid Corporation of China holds a 40% stake in the National Grid Corporation of the Philippines.

2. How many people does LinkedIn plan to lay off?
A: LinkedIn plans to lay off 716 people.

3. At what price did Gilead acquire Pharmasset?
A: Gilead acquired Pharmasset for 11 billion US dollars.

4. In which year was the blockbuster hepatitis C drug Sovaldi launched?
A: Sovaldi was launched in 2013.

5. Where will the China-Central Asia Summit be held, and who will chair it?
A: The summit will be held in Xi'an, Shaanxi Province, chaired by President Xi Jinping.

6. Which performer was placed under investigation for insulting the People's Army?
A: Li Haoshi was placed under investigation over remarks in his performance that insulted the People's Army.

7. Which project touted a floating road that "can carry tanks"?
A: The floating road in Enshi, Hubei, that was touted as able to "carry tanks".

8. If you were the CEO of MSD, what would your top priority be?
A: If I were the CEO of MSD, my top priority would be to strengthen the core business and achieve better growth through drug combinations.
"""