yuchenlin committed
Commit e55bd08
1 Parent(s): 131a07a

Update app.py

Files changed (1)
  1. app.py +69 -29
app.py CHANGED
@@ -1,22 +1,42 @@
- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-     conversation = []
-     if system_prompt:
-         conversation.append({"role": "system", "content": system_prompt})
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import spaces
+ from threading import Thread
+
+ # Load model and tokenizer
+ model_name = "Magpie-Align/MagpieLM-4B-Chat-v0.1"
+
+ device = "cuda" # the device to load the model onto
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto"
+ )
+ model.to(device)
+
+ MAX_INPUT_TOKEN_LENGTH = 4096 # You may need to adjust this value
+
+ @spaces.GPU(enable_queue=True)
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens=2048,
+     temperature=0.6,
+     top_p=0.9,
+     repetition_penalty=1.0,
+ ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     messages.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -24,20 +44,40 @@ def generate(

      streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
-         {"input_ids": input_ids},
+         input_ids=input_ids,
          streamer=streamer,
-         max_new_tokens=max_new_tokens,
+         max_new_tokens=max_tokens,
          do_sample=True,
          top_p=top_p,
-         top_k=top_k,
          temperature=temperature,
-         num_beams=1,
          repetition_penalty=repetition_penalty,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
+
+     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
+
+     def stream():
+         for text in streamer:
+             yield text
+
+     return stream()
+
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are Magpie, a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.9,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+         gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Repetition Penalty"),
+     ],
+ )

-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-     yield "".join(outputs)
+ if __name__ == "__main__":
+     demo.launch(share=True)
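
A note for anyone reusing this pattern: the committed code moves `model` to `cuda` at startup, but `respond` never moves `input_ids` off the CPU before calling `model.generate`, which can raise a device-mismatch error when a GPU is actually present. Below is a minimal, self-contained sketch of the same `TextIteratorStreamer` plus background-`Thread` streaming idiom with that device hop included; the tiny `sshleifer/tiny-gpt2` checkpoint and the prompt are stand-ins for illustration, not part of this Space.

# Minimal sketch of the streaming idiom used in app.py above, with the
# inputs moved onto the model's device before generation starts.
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

name = "sshleifer/tiny-gpt2"  # stand-in checkpoint, small enough to run on CPU
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.to("cuda" if torch.cuda.is_available() else "cpu")

input_ids = tokenizer("Hello, Magpie!", return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)  # the step the committed respond() skips

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=32, do_sample=True),
).start()

for chunk in streamer:  # blocks until the generation thread produces text
    print(chunk, end="", flush=True)

Also worth knowing: `gr.ChatInterface` treats each value yielded by the callback as the complete response so far, so `stream()` above, which yields raw deltas, will display only the newest chunk in the UI. Accumulating chunks into a growing string, as the removed `generate` did with `"".join(outputs)`, restores the expected streaming display.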