BAAI
/

AquilaChat-7B / chat_test_NBCE.py
shunxing1234's picture
Upload 2 files
811643a
raw
history blame
4.78 kB
#! -*- coding: utf-8 -*-
# Naive Bayes-based Context Extension (NBCE)
# 使用朴素贝叶斯增加LLM的Context处理长度
# 链接:https://kexue.fm/archives/9617
# Torch 2.0 测试通过
import json
import torch
from transformers import AutoTokenizer
from transformers import AquilaForCausalLM
from transformers import TopPLogitsWarper, LogitsProcessorList
import pdb
# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.unk_token
# 加载Aquila模型
model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
device = torch.device('cuda')
model.to(device)
# 加载示例Context
from cyg_conversation import default_conversation
conv = default_conversation.copy()
contexts = json.load(open('code_text_2.json'))
question = "请解释这段程序的功能:"
batch = []
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
batch.append(conv.get_prompt())
# 拼接context和question
for ci,context in enumerate(contexts):
conv1 = default_conversation.copy()
conv1.append_message(conv.roles[0], context+question)
conv1.append_message(conv.roles[1], None)
batch.append(conv1.get_prompt())
print('Context长度分布:', [len(text) for text in batch])
print('Context总长度:', sum([len(text) for text in batch]))
# Top-P截断
processors = LogitsProcessorList()
processors.append(TopPLogitsWarper(0.95))
# Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
@torch.inference_mode()
def generate(max_tokens):
"""Naive Bayes-based Context Extension 演示代码
"""
inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
print('input_ids', input_ids.shape)
past_key_values = None
n = input_ids.shape[0]
for i in range(max_tokens):
# 模型输出
outputs = model(input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
use_cache=True,
past_key_values=past_key_values
)
past_key_values = outputs.past_key_values
# ===== 核心代码开始 =====
beta, eta = 0.25, 0.1
logits = outputs.logits[:, -1]
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
logits = processors(input_ids, logits)
entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
if i > 0:
entropy[k] -= eta
k = entropy[1:].argmin() + 1
logits_max = logits[k]
logits_uncond = logits[0]
logits_merged = (1 + beta) * logits_max - beta * logits_uncond
logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
# ===== 核心代码结束 =====
# 构建分布,采样
# tau = 1是标准的随机采样,tau->0则是贪心搜索
# 简单起见,这里没有实现topk、topp截断
tau = 0.01
probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
if next_tokens[0] == tokenizer.eos_token_id:
break
ret = tokenizer.batch_decode(next_tokens)
print(ret[0], flush=True, end='')
# prepare for next iteration
input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
if __name__ == '__main__':
generate(1000)
"""
========= 输出结果参考 =========
1.菲律宾国家电网公司,中国占股多少?
答:中国国家电网公司持有菲律宾国家电网公司40%的股份。
2.领英计划裁员多少人?
答:领英计划裁员716人。
3.吉利德收购Pharmasset的价格是多少?
答:吉利德收购Pharmasset的价格为110亿美元。
4.丙肝神药Sovaldi在哪一年上市?
答:丙肝神药Sovaldi于2013年上市。
5.中亚峰会将在哪里举行?由谁主持?
答:中亚峰会将在陕西省西安市举行,由国家主席习近平主持。
6.哪个演员由于侮辱人民军队而被立案调查?
答:李昊石因在表演中存在侮辱人民军队的言论而被立案调查。
7.哪个项目宣称“能过坦克”的水上道路?
答:湖北恩施宣称的“能过坦克”水上道路。
8.如果你是默沙东的CEO,你的首要任务是什么?
答:如果我是默沙东的CEO,我的首要任务是如何让基本盘更加坚固,并通过药物联用获得更好的增长。
"""