File size: 4,780 Bytes
811643a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
#! -*- coding: utf-8 -*-
# Naive Bayes-based Context Extension (NBCE)
# 使用朴素贝叶斯增加LLM的Context处理长度
# 链接:https://kexue.fm/archives/9617
# Torch 2.0 测试通过
import json
import torch
from transformers import AutoTokenizer
from transformers import AquilaForCausalLM
from transformers import TopPLogitsWarper, LogitsProcessorList
import pdb
# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.unk_token
# 加载Aquila模型
model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
device = torch.device('cuda')
model.to(device)
# 加载示例Context
from cyg_conversation import default_conversation
conv = default_conversation.copy()
contexts = json.load(open('code_text_2.json'))
question = "请解释这段程序的功能:"
batch = []
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
batch.append(conv.get_prompt())
# 拼接context和question
for ci,context in enumerate(contexts):
conv1 = default_conversation.copy()
conv1.append_message(conv.roles[0], context+question)
conv1.append_message(conv.roles[1], None)
batch.append(conv1.get_prompt())
print('Context长度分布:', [len(text) for text in batch])
print('Context总长度:', sum([len(text) for text in batch]))
# Top-P截断
processors = LogitsProcessorList()
processors.append(TopPLogitsWarper(0.95))
# Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
@torch.inference_mode()
def generate(max_tokens):
"""Naive Bayes-based Context Extension 演示代码
"""
inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
print('input_ids', input_ids.shape)
past_key_values = None
n = input_ids.shape[0]
for i in range(max_tokens):
# 模型输出
outputs = model(input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
use_cache=True,
past_key_values=past_key_values
)
past_key_values = outputs.past_key_values
# ===== 核心代码开始 =====
beta, eta = 0.25, 0.1
logits = outputs.logits[:, -1]
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
logits = processors(input_ids, logits)
entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
if i > 0:
entropy[k] -= eta
k = entropy[1:].argmin() + 1
logits_max = logits[k]
logits_uncond = logits[0]
logits_merged = (1 + beta) * logits_max - beta * logits_uncond
logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
# ===== 核心代码结束 =====
# 构建分布,采样
# tau = 1是标准的随机采样,tau->0则是贪心搜索
# 简单起见,这里没有实现topk、topp截断
tau = 0.01
probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
if next_tokens[0] == tokenizer.eos_token_id:
break
ret = tokenizer.batch_decode(next_tokens)
print(ret[0], flush=True, end='')
# prepare for next iteration
input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
if __name__ == '__main__':
generate(1000)
"""
========= 输出结果参考 =========
1.菲律宾国家电网公司,中国占股多少?
答:中国国家电网公司持有菲律宾国家电网公司40%的股份。
2.领英计划裁员多少人?
答:领英计划裁员716人。
3.吉利德收购Pharmasset的价格是多少?
答:吉利德收购Pharmasset的价格为110亿美元。
4.丙肝神药Sovaldi在哪一年上市?
答:丙肝神药Sovaldi于2013年上市。
5.中亚峰会将在哪里举行?由谁主持?
答:中亚峰会将在陕西省西安市举行,由国家主席习近平主持。
6.哪个演员由于侮辱人民军队而被立案调查?
答:李昊石因在表演中存在侮辱人民军队的言论而被立案调查。
7.哪个项目宣称“能过坦克”的水上道路?
答:湖北恩施宣称的“能过坦克”水上道路。
8.如果你是默沙东的CEO,你的首要任务是什么?
答:如果我是默沙东的CEO,我的首要任务是如何让基本盘更加坚固,并通过药物联用获得更好的增长。
"""
|