saffr0n commited on
Commit
a610a17
1 Parent(s): ca32be5

update from gr.ChatInterface to gr.Chatbot to enable chat_history to be actually passed through

Browse files
Files changed (1) hide show
  1. app.py +95 -52
app.py CHANGED
@@ -86,61 +86,104 @@ def generate(
86
  outputs.append(text)
87
  yield "".join(outputs)
88
 
89
-
90
- chat_interface = gr.ChatInterface(
91
- fn=generate,
92
- additional_inputs=[
93
- gr.Textbox(label="System prompt", lines=6),
94
- gr.Slider(
95
- label="Max new tokens",
96
- minimum=1,
97
- maximum=MAX_MAX_NEW_TOKENS,
98
- step=1,
99
- value=DEFAULT_MAX_NEW_TOKENS,
100
- ),
101
- gr.Slider(
102
- label="Temperature",
103
- minimum=0.1,
104
- maximum=4.0,
105
- step=0.1,
106
- value=0.6,
107
- ),
108
- gr.Slider(
109
- label="Top-p (nucleus sampling)",
110
- minimum=0.05,
111
- maximum=1.0,
112
- step=0.05,
113
- value=0.9,
114
- ),
115
- gr.Slider(
116
- label="Top-k",
117
- minimum=1,
118
- maximum=1000,
119
- step=1,
120
- value=50,
121
- ),
122
- gr.Slider(
123
- label="Repetition penalty",
124
- minimum=1.0,
125
- maximum=2.0,
126
- step=0.05,
127
- value=1.2,
128
- ),
129
- ],
130
- stop_btn=None,
131
- examples=[
132
- ["Ninawezaje kupata usingizi haraka?"],
133
- ["Bosi wangu anadhibiti sana, nifanye nini?"],
134
- ["Je, ni vipindi gani muhimu katika historia vya kujua kuvihusu?"],
135
- ["Ni kazi gani nzuri ikiwa ninataka kupata pesa lakini pia kufurahiya?"],
136
- ["Nivae nini kwenye harusi?"],
137
- ],
138
- )
 
 
 
 
 
 
 
139
 
140
  with gr.Blocks(css="style.css") as demo:
141
  gr.Markdown(DESCRIPTION)
142
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
143
- chat_interface.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  gr.Markdown(LICENSE)
145
 
146
  if __name__ == "__main__":
 
86
  outputs.append(text)
87
  yield "".join(outputs)
88
 
89
+ examples = [
90
+ ["Ninawezaje kupata usingizi haraka?"],
91
+ ["Bosi wangu anadhibiti sana, nifanye nini?"],
92
+ ["Je, ni vipindi gani muhimu katika historia vya kujua kuvihusu?"],
93
+ ["Ni kazi gani nzuri ikiwa ninataka kupata pesa lakini pia kufurahiya?"],
94
+ ["Nivae nini kwenye harusi?"],
95
+ ]
96
+
97
+ # chat_interface = gr.ChatInterface(
98
+ # fn=generate,
99
+ # additional_inputs=[
100
+ # gr.Textbox(label="System prompt", lines=6),
101
+ # gr.Slider(
102
+ # label="Max new tokens",
103
+ # minimum=1,
104
+ # maximum=MAX_MAX_NEW_TOKENS,
105
+ # step=1,
106
+ # value=DEFAULT_MAX_NEW_TOKENS,
107
+ # ),
108
+ # gr.Slider(
109
+ # label="Temperature",
110
+ # minimum=0.1,
111
+ # maximum=4.0,
112
+ # step=0.1,
113
+ # value=0.6,
114
+ # ),
115
+ # gr.Slider(
116
+ # label="Top-p (nucleus sampling)",
117
+ # minimum=0.05,
118
+ # maximum=1.0,
119
+ # step=0.05,
120
+ # value=0.9,
121
+ # ),
122
+ # gr.Slider(
123
+ # label="Top-k",
124
+ # minimum=1,
125
+ # maximum=1000,
126
+ # step=1,
127
+ # value=50,
128
+ # ),
129
+ # gr.Slider(
130
+ # label="Repetition penalty",
131
+ # minimum=1.0,
132
+ # maximum=2.0,
133
+ # step=0.05,
134
+ # value=1.2,
135
+ # ),
136
+ # ],
137
+ # stop_btn=None,
138
+ # examples=[
139
+ # ["Ninawezaje kupata usingizi haraka?"],
140
+ # ["Bosi wangu anadhibiti sana, nifanye nini?"],
141
+ # ["Je, ni vipindi gani muhimu katika historia vya kujua kuvihusu?"],
142
+ # ["Ni kazi gani nzuri ikiwa ninataka kupata pesa lakini pia kufurahiya?"],
143
+ # ["Nivae nini kwenye harusi?"],
144
+ # ],
145
+ # )
146
 
147
  with gr.Blocks(css="style.css") as demo:
148
  gr.Markdown(DESCRIPTION)
149
+ chatbot = gr.Chatbot()
150
+ msg = gr.Textbox(label="Enter your message")
151
+ submit_btn = gr.Button("Submit")
152
+ clear = gr.Button("Clear")
153
+
154
+ def user(user_message, history):
155
+ return "", history + [[user_message, None]]
156
+
157
+ def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
158
+ user_message = history[-1][0]
159
+ chat_history = [(msg[0], msg[1]) for msg in history[:-1]]
160
+ bot_message = ""
161
+ for response in generate(user_message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
162
+ bot_message = response
163
+ history[-1][1] = bot_message
164
+ yield history
165
+
166
+ gr.Examples(examples=examples, inputs=[msg], label="Examples")
167
+
168
+ with gr.Accordion("Advanced Options", open=False):
169
+ max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
170
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
171
+ top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
172
+ top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
173
+ repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
174
+
175
+ submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
176
+ bot,
177
+ [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
178
+ chatbot,
179
+ )
180
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
181
+ bot,
182
+ [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
183
+ chatbot,
184
+ )
185
+ clear.click(lambda: None, None, chatbot, queue=False)
186
+
187
  gr.Markdown(LICENSE)
188
 
189
  if __name__ == "__main__":