vilarin committed
Commit 36e78de
Parent: 45b2636

Update app.py

Files changed (1)
  1. app.py +3 -4
app.py CHANGED
@@ -48,16 +48,15 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces': False,})
+    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
-        max_new_tokens=max_new_tokens,
+        max_length=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         repetition_penalty=1.2,
-        eos_token_id=model.config.eos_token_id,
     )
 
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -65,7 +64,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     buffer = ""
     for new_text in streamer:
-        buffer += new_text
+        buffer[-1][1] += new_text
         yield buffer
 
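For reference, the changed lines follow the usual transformers streaming pattern: generate() runs in a background thread while the main thread consumes the TextIteratorStreamer and yields partial text. Below is a minimal, self-contained sketch of that pattern; the model name, prompt, and generation settings are placeholders for illustration only and are not taken from this repository.

# Minimal sketch of the thread + TextIteratorStreamer pattern used in stream_chat.
# "gpt2" and the prompt below are placeholders, not this Space's model.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# skip_prompt=True keeps the input prompt out of the streamed output;
# timeout bounds how long the consuming loop waits for the next token.
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)

generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=32, do_sample=True)

# generate() blocks until decoding finishes, so it runs in a background
# thread while the main thread iterates over the streamer.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer)
thread.join()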