yuchenlin committed on
Commit
9b8eb72
1 Parent(s): 77cf82d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -40,23 +40,19 @@ def respond(
40
  )
41
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
42
 
43
- generated_ids = model.generate(
 
 
44
  model_inputs.input_ids,
45
  max_new_tokens = max_tokens,
46
  temperature = temperature,
47
  top_p = top_p,
48
  repetition_penalty=repetition_penalty,
 
49
  )
50
- generated_ids = [
51
- output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
52
- ]
53
 
54
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
55
- return response
56
 
57
- """
58
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
59
- """
60
  demo = gr.ChatInterface(
61
  respond,
62
  additional_inputs=[
@@ -74,6 +70,5 @@ demo = gr.ChatInterface(
74
  ],
75
  )
76
 
77
-
78
  if __name__ == "__main__":
79
- demo.launch(share=True)
 
40
  )
41
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
42
 
43
+ streamer = gr.utils.StreamingTextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
44
+
45
+ _ = model.generate(
46
  model_inputs.input_ids,
47
  max_new_tokens = max_tokens,
48
  temperature = temperature,
49
  top_p = top_p,
50
  repetition_penalty=repetition_penalty,
51
+ streamer=streamer
52
  )
 
 
 
53
 
54
+ return streamer
 
55
 
 
 
 
56
  demo = gr.ChatInterface(
57
  respond,
58
  additional_inputs=[
 
70
  ],
71
  )
72
 
 
73
  if __name__ == "__main__":
74
+ demo.launch(share=True)