ldwang committed
Commit a3e49d8
1 Parent(s): 7b1d55e

Upload predict.py

Files changed (1)
  1. predict.py +9 -5
predict.py CHANGED
@@ -310,6 +310,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
 
     example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
 
+    if history is None or not isinstance(history, list):
+        history = []
+
     while(len(history) > 0 and (len(example) < max_token)):
         tmp = history.pop()
         if tmp[0] == 'ASSISTANT':
@@ -333,7 +336,7 @@ def predict(model, text, tokenizer=None,
            sft=True, convo_template = "",
            device = "cuda",
            model_name="AquilaChat2-7B",
-           history=[],
+           history=None,
            **kwargs):
 
     vocab = tokenizer.get_vocab()
@@ -344,7 +347,7 @@ def predict(model, text, tokenizer=None,
     template_map = {"AquilaChat2-7B": "aquila-v1",
                     "AquilaChat2-34B": "aquila-legacy",
                     "AquilaChat2-7B-16K": "aquila",
-                    "AquilaChat2-34B-16K": "aquila-v1"}
+                    "AquilaChat2-34B-16K": "aquila"}
     if not convo_template:
         convo_template=template_map.get(model_name, "aquila-chat")
 
@@ -435,8 +438,9 @@ def predict(model, text, tokenizer=None,
     convert_tokens = convert_tokens[1:]
     probs = probs[1:]
 
-    # Update history
-    history.insert(0, ('ASSISTANT', out))
-    history.insert(0, ('USER', text))
+    if isinstance(history, list):
+        # Update history
+        history.insert(0, ('ASSISTANT', out))
+        history.insert(0, ('USER', text))
 
     return out
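
The core behavioral change in this commit is replacing the mutable default argument history=[] with history=None plus an explicit guard. In Python, a default list is created once at function definition time and shared by every call that omits the argument, so the old signature silently accumulated conversation turns across unrelated predict() calls. The sketch below is illustrative only and not part of predict.py; the function names chat_buggy and chat_fixed are toy names used to show the pitfall and the None-plus-guard pattern the commit adopts.

# Illustrative only: toy functions, not the predict.py API.

def chat_buggy(text, history=[]):
    # The default list is created once and reused, so turns leak
    # across calls that do not pass their own history.
    history.insert(0, ("USER", text))
    return history

def chat_fixed(text, history=None):
    # Pattern adopted by this commit: a fresh list per call
    # unless the caller supplies one.
    if history is None or not isinstance(history, list):
        history = []
    history.insert(0, ("USER", text))
    return history

print(chat_buggy("hi"))       # [('USER', 'hi')]
print(chat_buggy("again"))    # [('USER', 'again'), ('USER', 'hi')]  <- leaked state
print(chat_fixed("hi"))       # [('USER', 'hi')]
print(chat_fixed("again"))    # [('USER', 'again')]

With the new signature, a caller who wants multi-turn behavior creates one list up front (for example, history = []) and passes it as history= on every predict(...) call, while single-turn callers simply omit it; the isinstance(history, list) check around the history update then keeps those single-turn calls, which now receive history=None, from raising.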