YC-Chen commited on
Commit
9901656
1 Parent(s): 8a6cc8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -4
app.py CHANGED
@@ -1,19 +1,346 @@
1
-
 
 
 
2
 
3
  import gradio as gr
 
 
 
4
 
 
 
 
 
 
5
 
6
 
7
  DESCRIPTION = """
8
- # Demo: Breeze-7B-Instruct-v0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- Breeze-7B-Instruct already updated to v1.0. Please check the new demo [here](https://huggingface.co/spaces/MediaTek-Research/Demo-MR-Breeze-7B).
11
  """
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  with gr.Blocks() as demo:
15
  gr.Markdown(DESCRIPTION)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- demo.launch()
 
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import time
5
 
6
  import gradio as gr
7
+ from transformers import AutoTokenizer
8
+ import psycopg2
9
+
10
 
11
+ import socket
12
+ hostname=socket.gethostname()
13
+ IPAddr=socket.gethostbyname(hostname)
14
+ print("Your Computer Name is:" + hostname)
15
+ print("Your Computer IP Address is:" + IPAddr)
16
 
17
 
18
  DESCRIPTION = """
19
+ # MediaTek Research Breeze-7B
20
+
21
+ MediaTek Research Breeze-7B (hereinafter referred to as Breeze-7B) is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.
22
+
23
+ [Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v1_0) is the base model for the Breeze-7B series.
24
+ It is suitable for use if you have substantial fine-tuning data to tune it for your specific use case.
25
+
26
+ [Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v1_0) derives from the base model Breeze-7B-Base, making the resulting model amenable to be used as-is for commonly seen tasks.
27
+
28
+
29
+ The current release version of Breeze-7B is v1.0.
30
+
31
+ *A project by the members (in alphabetical order): Chan-Jan Hsu 許湛然, Chang-Le Liu 劉昶樂, Feng-Ting Liao 廖峰挺, Po-Chun Hsu 許博竣, Yi-Chang Chen 陳宜昌, and the supervisor Da-Shan Shiu 許大山.*
32
+
33
+ **免責聲明: MediaTek Research Breeze-7B 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。**
34
+ """
35
+
36
+ LICENSE = """
37
 
 
38
  """
39
 
40
+ DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
41
+
42
+ API_URL = os.environ.get("API_URL")
43
+ TOKEN = os.environ.get("TOKEN")
44
+
45
+ HEADERS = {
46
+ "Authorization": f"Bearer {TOKEN}",
47
+ "Content-Type": "application/json",
48
+ "accept": "application/json"
49
+ }
50
+
51
+
52
+ MAX_SEC = 30
53
+ MAX_INPUT_LENGTH = 5000
54
+
55
+ tokenizer = AutoTokenizer.from_pretrained("MediaTek-Research/Breeze-7B-Instruct-v0_1")
56
+
57
+ def insert_to_db(prompt, response, temperature, top_p):
58
+ try:
59
+ #Establishing the connection
60
+ conn = psycopg2.connect(
61
+ database=os.environ.get("DB"), user=os.environ.get("USER"), password=os.environ.get("DB_PASS"), host=os.environ.get("DB_HOST"), port= '5432'
62
+ )
63
+ #Setting auto commit false
64
+ conn.autocommit = True
65
+
66
+ #Creating a cursor object using the cursor() method
67
+ cursor = conn.cursor()
68
+
69
+ # Preparing SQL queries to INSERT a record into the database.
70
+ cursor.execute(f"INSERT INTO breezedata(prompt, response, temperature, top_p) VALUES ('{prompt}', '{response}', {temperature}, {top_p})")
71
+
72
+ # Commit your changes in the database
73
+ conn.commit()
74
+
75
+ # Closing the connection
76
+ conn.close()
77
+ except:
78
+ pass
79
+
80
+
81
+ def refusal_condition(query):
82
+ # 不要再問這些問題啦!
83
+
84
+ query_remove_space = query.replace(' ', '').lower()
85
+ is_including_tw = False
86
+ for x in ['台灣', '台湾', 'taiwan', 'tw', '中華民國', '中华民国']:
87
+ if x in query_remove_space:
88
+ is_including_tw = True
89
+ is_including_cn = False
90
+ for x in ['中國', '中国', 'cn', 'china', '大陸', '內地', '大陆', '内地', '中華人民共和國', '中华人民共和国']:
91
+ if x in query_remove_space:
92
+ is_including_cn = True
93
+ if is_including_tw and is_including_cn:
94
+ return True
95
+
96
+ for x in ['一個中國', '兩岸', '一中原則', '一中政策', '一个中国', '两岸', '一中原则']:
97
+ if x in query_remove_space:
98
+ return True
99
+
100
+ return False
101
 
102
  with gr.Blocks() as demo:
103
  gr.Markdown(DESCRIPTION)
104
 
105
+ system_prompt = gr.Textbox(label='System prompt',
106
+ value=DEFAULT_SYSTEM_PROMPT,
107
+ lines=1)
108
+
109
+ with gr.Accordion(label='Advanced options', open=False):
110
+ max_new_tokens = gr.Slider(
111
+ label='Max new tokens',
112
+ minimum=32,
113
+ maximum=1024,
114
+ step=1,
115
+ value=512,
116
+ )
117
+ temperature = gr.Slider(
118
+ label='Temperature',
119
+ minimum=0.01,
120
+ maximum=1.0,
121
+ step=0.01,
122
+ value=0.01,
123
+ )
124
+ top_p = gr.Slider(
125
+ label='Top-p (nucleus sampling)',
126
+ minimum=0.01,
127
+ maximum=0.99,
128
+ step=0.01,
129
+ value=0.01,
130
+ )
131
+ repetition_penalty = gr.Slider(
132
+ label='Repetition Penalty',
133
+ minimum=0.1,
134
+ maximum=2,
135
+ step=0.01,
136
+ value=1.1,
137
+ )
138
+
139
+ chatbot = gr.Chatbot()
140
+ with gr.Row():
141
+ msg = gr.Textbox(
142
+ container=False,
143
+ show_label=False,
144
+ placeholder='Type a message...',
145
+ scale=10,
146
+ lines=6
147
+ )
148
+ submit_button = gr.Button('Submit',
149
+ variant='primary',
150
+ scale=1,
151
+ min_width=0)
152
+
153
+ with gr.Row():
154
+ retry_button = gr.Button('🔄 Retry', variant='secondary')
155
+ undo_button = gr.Button('↩️ Undo', variant='secondary')
156
+ clear = gr.Button('🗑️ Clear', variant='secondary')
157
+
158
+ saved_input = gr.State()
159
+
160
+
161
+ def user(user_message, history):
162
+ return "", history + [[user_message, None]]
163
+
164
+
165
+ def connect_server(data):
166
+ for _ in range(3):
167
+ s = requests.Session()
168
+ r = s.post(API_URL, headers=HEADERS, json=data, stream=True, timeout=30)
169
+ time.sleep(1)
170
+ if r.status_code == 200:
171
+ return r
172
+ return None
173
+
174
+
175
+ def stream_response_from_server(r):
176
+ # start_time = time.time()
177
+ keep_streaming = True
178
+ for line in r.iter_lines():
179
+ # if time.time() - start_time > MAX_SEC:
180
+ # keep_streaming = False
181
+ # break
182
+
183
+ if line and keep_streaming:
184
+ if r.status_code != 200:
185
+ continue
186
+ json_response = json.loads(line)
187
+
188
+ if "fragment" not in json_response["result"]:
189
+ keep_streaming = False
190
+ break
191
+
192
+ delta = json_response["result"]["fragment"]["data"]["text"]
193
+ yield delta
194
+
195
+ # start_time = time.time()
196
+
197
 
198
+ def bot(history, max_new_tokens, temperature, top_p, system_prompt, repetition_penalty):
199
+ chat_data = []
200
+ system_prompt = system_prompt.strip()
201
+ if system_prompt:
202
+ chat_data.append({"role": "system", "content": system_prompt})
203
+ for user_msg, assistant_msg in history:
204
+ chat_data.append({"role": "user", "content": user_msg if user_msg is not None else ''})
205
+ chat_data.append({"role": "assistant", "content": assistant_msg if assistant_msg is not None else ''})
206
+
207
+ message = tokenizer.apply_chat_template(chat_data, tokenize=False)
208
+ message = message[3:] # remove SOT token
209
+
210
+ if len(message) > MAX_INPUT_LENGTH:
211
+ raise Exception()
212
+
213
+ response = '[ERROR]'
214
+ if refusal_condition(history[-1][0]):
215
+ history = [['[安全拒答啟動]', '[安全拒答啟動] 請清除再開啟對話']]
216
+ response = '[REFUSAL]'
217
+ yield history
218
+ else:
219
+ data = {
220
+ "model_type": "breeze-7b-instruct-v10",
221
+ "prompt": str(message),
222
+ "parameters": {
223
+ "temperature": float(temperature),
224
+ "top_p": float(top_p),
225
+ "max_new_tokens": int(max_new_tokens),
226
+ "repetition_penalty": float(repetition_penalty),
227
+
228
+ "num_beams":1, # w/o beam search
229
+ "typical_p":0.99,
230
+ "top_k":61952, # w/o top_k
231
+ "do_sample": True,
232
+ "min_length":1,
233
+ }
234
+ }
235
+
236
+ r = connect_server(data)
237
+ if r is not None:
238
+ for delta in stream_response_from_server(r):
239
+ if history[-1][1] is None:
240
+ history[-1][1] = ''
241
+ history[-1][1] += delta
242
+ yield history
243
+
244
+ if history[-1][1].endswith('</s>'):
245
+ history[-1][1] = history[-1][1][:-4]
246
+ yield history
247
+
248
+ response = history[-1][1]
249
+
250
+ if refusal_condition(history[-1][1]):
251
+ history[-1][1] = history[-1][1] + '\n\n**[免責聲明: Breeze-7B-Instruct 和 Breeze-7B-Instruct-64k 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**'
252
+ yield history
253
+ else:
254
+ del history[-1]
255
+ yield history
256
+
257
+ print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(message), response=repr(history[-1][1])))
258
+ insert_to_db(message, response, float(temperature), float(top_p))
259
+
260
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
261
+ fn=bot,
262
+ inputs=[
263
+ chatbot,
264
+ max_new_tokens,
265
+ temperature,
266
+ top_p,
267
+ system_prompt,
268
+ repetition_penalty,
269
+ ],
270
+ outputs=chatbot
271
+ )
272
+ submit_button.click(
273
+ user, [msg, chatbot], [msg, chatbot], queue=False
274
+ ).then(
275
+ fn=bot,
276
+ inputs=[
277
+ chatbot,
278
+ max_new_tokens,
279
+ temperature,
280
+ top_p,
281
+ system_prompt,
282
+ repetition_penalty,
283
+ ],
284
+ outputs=chatbot
285
+ )
286
+
287
+
288
+ def delete_prev_fn(
289
+ history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
290
+ try:
291
+ message, _ = history.pop()
292
+ except IndexError:
293
+ message = ''
294
+ return history, message or ''
295
+
296
+
297
+ def display_input(message: str,
298
+ history: list[tuple[str, str]]) -> list[tuple[str, str]]:
299
+ history.append((message, ''))
300
+ return history
301
+
302
+ retry_button.click(
303
+ fn=delete_prev_fn,
304
+ inputs=chatbot,
305
+ outputs=[chatbot, saved_input],
306
+ api_name=False,
307
+ queue=False,
308
+ ).then(
309
+ fn=display_input,
310
+ inputs=[saved_input, chatbot],
311
+ outputs=chatbot,
312
+ api_name=False,
313
+ queue=False,
314
+ ).then(
315
+ fn=bot,
316
+ inputs=[
317
+ chatbot,
318
+ max_new_tokens,
319
+ temperature,
320
+ top_p,
321
+ system_prompt,
322
+ repetition_penalty,
323
+ ],
324
+ outputs=chatbot,
325
+ )
326
+
327
+ undo_button.click(
328
+ fn=delete_prev_fn,
329
+ inputs=chatbot,
330
+ outputs=[chatbot, saved_input],
331
+ api_name=False,
332
+ queue=False,
333
+ ).then(
334
+ fn=lambda x: x,
335
+ inputs=[saved_input],
336
+ outputs=msg,
337
+ api_name=False,
338
+ queue=False,
339
+ )
340
+
341
+ clear.click(lambda: None, None, chatbot, queue=False)
342
+
343
+ gr.Markdown(LICENSE)
344
 
345
+ demo.queue(concurrency_count=2, max_size=128)
346
+ demo.launch()