Using `past_key_values` for context retention
I am running phi3-mini-4k-instruct-fp16 to build a chatbot. To give the LLM context about the previous conversation, I used `pretrained_tokenizer.apply_chat_template(input)` to create the `input_ids`, where `input` is an array of all previous messages, and reset `past_key_values` to empty on every new `llm.run(...)`. This allowed the model to retain information about the previous messages.
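A rough sketch of that working setup, where `messages` stands in for the accumulated chat history:

```typescript
// Full-history approach (sketch): re-tokenize every message on each call,
// and start each run from an empty KV cache.
const prompt = this.tokenizer.apply_chat_template(messages, { tokenize: false });
const tokens = await this.tokenizer.encode(prompt);
const input_ids = new Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
// past_key_values.{i}.key / .value are reset to zero-length tensors before each run
```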
However, I was wondering if I could get the same effect by passing only the user's most recent prompt in `input_ids`, together with the `past_key_values` pairs returned by the previous response.
The first call to `llm.run(...)`, where `past_key_values` is initially empty, works fine.
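By "empty" I mean zero-length KV tensors along the sequence axis, something like this (phi3-mini: 32 layers, 32 heads, head dim 96; fp16 tensors are backed by `Uint16Array` in onnxruntime-web):

```typescript
// Initialise an empty cache: [batch, heads, seq_len = 0, head_dim] per layer.
for (let i = 0; i < 32; i++) {
    this.feed[`past_key_values.${i}.key`] = new Tensor('float16', new Uint16Array(0), [1, 32, 0, 96]);
    this.feed[`past_key_values.${i}.value`] = new Tensor('float16', new Uint16Array(0), [1, 32, 0, 96]);
}
```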
On the second call, the tensors in the feed have these dimensions:

```
input_ids        dims: [1, 8]
position_ids     dims: [1, 8]
attention_mask   dims: [1, 44]
past_key_values  dims: [1, 32, 36, 96]
```
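As far as I can tell, these shapes are self-consistent: 8 new tokens plus the 36 cached positions account for the 44 mask entries. A quick check along those lines passes:

```typescript
// Sanity check before the failing call: the mask should cover past + new tokens.
const past = this.feed['past_key_values.0.key'].dims[2];    // 36
const fresh = this.feed['input_ids'].dims[1];               // 8
console.assert(this.feed['attention_mask'].dims[1] === past + fresh); // 44 === 44
```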
Nevertheless, this error is thrown:

```
Error: [WebGPU] Kernel "[Expand] /model/attn_mask_reformat/input_ids_subgraph/Expand" failed. Error: Expand requires shape to be broadcastable to input
    at Object._OrtRun (:8888/node_modules/onnxruntime-web/lib/wasm/binding/ort-wasm-simd.jsep.js:9:401)
    at zd (:8888/node_modules/onnxruntime-web/lib/wasm/wasm-core-impl.ts:562:19)
    at fi.run (:8888/node_modules/onnxruntime-web/lib/wasm/session-handler-inference.ts:109:21)
    at e.run (:8888/node_modules/common/lib/inference-session-impl.ts:110:21)
```
I believe the error is caused by a shape incompatibility, but I can't work out which tensor is at fault. Can I even use `past_key_values` for more efficient context retention like this?
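For debugging, I dump every feed shape right before the failing call with a small helper along these lines (just a sketch, not part of the app):

```typescript
// Print the dims of every tensor in the feed, to spot the mismatched shape.
function logFeedDims(feed: Record<string, Tensor>): void {
    for (const [name, tensor] of Object.entries(feed)) {
        console.log(`${name}: [${tensor.dims.join(', ')}]`);
    }
}
```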
Here is my `llm.run`:
```typescript
async run(input: any, callback?: (output: string) => void) {
    // Tokenize only the most recent message; earlier context should come
    // from the KV cache already held in this.feed.
    const prompt = this.tokenizer.apply_chat_template([input[input.length - 1]], { tokenize: false });
    const tokens = await this.tokenizer.encode(prompt);
    const input_ids = new Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
    const output_tokens = [...input_ids.data];
    let last_token = 0n; // `let`, not `const`: it is reassigned in the loop

    this.feed['input_ids'] = input_ids;
    let seqlen = output_tokens.length;

    // Offset the new positions by the cached sequence length, and make the
    // attention mask cover cached + new tokens.
    const extra = this.feed['past_key_values.0.key'].dims[2];
    this.feed['position_ids'] = new Tensor(
        'int64',
        BigInt64Array.from({ length: seqlen }, (_, i) => BigInt(extra + i)),
        [1, seqlen]
    );
    this.feed['attention_mask'] = new Tensor(
        'int64',
        BigInt64Array.from({ length: seqlen + extra }, () => 1n),
        [1, seqlen + extra]
    );

    // 32007 is Phi-3's <|end|> token.
    while (last_token !== this.eos && last_token !== 32007n && seqlen < this.max_seq_len) {
        seqlen = output_tokens.length;
        const outputs = await runSession(this.feed, this.onnxSession);
        last_token = BigInt(argmax(outputs.logits));
        output_tokens.push(last_token);
        if (callback) {
            // Stream the text generated so far, minus the prompt tokens.
            const endIndex = last_token === this.eos ? -1 : output_tokens.length;
            const newString = this.tokenizer.decode(output_tokens.slice(tokens.length, endIndex).map(t => Number(t)));
            callback(newString);
        }
        // Carry the grown KV cache into the next decode step.
        update_kv_cache(this.feed, outputs);
        this.feed['input_ids'] = new Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
        this.feed['position_ids'] = new Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
        this.feed['attention_mask'] = new Tensor('int64', BigInt64Array.from({ length: seqlen + 1 }, () => 1n), [1, seqlen + 1]);
    }
    const output = this.tokenizer.decode(output_tokens.slice(tokens.length, -1).map(t => Number(t)));
    return output;
}
```
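For completeness, `update_kv_cache` follows the usual present-to-past copy pattern, roughly:

```typescript
// Sketch of update_kv_cache: move each present.* output into the matching
// past_key_values.* feed entry so the next step sees the grown cache.
function update_kv_cache(feed: Record<string, Tensor>, outputs: Record<string, Tensor>): void {
    for (const name in outputs) {
        if (name.startsWith('present.')) {
            feed[name.replace('present.', 'past_key_values.')] = outputs[name];
        }
    }
}
```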
[email protected]
@xenova
/[email protected]
onnx model, tokenizer config from microsoft/Phi-3-mini-4k-instruct-onnx-web