jon-tow committed on
Commit 3708810
1 Parent(s): 277f540

feat(app): add actual init demo app

Files changed (2)
  1. app.py +82 -36
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,22 +1,28 @@
-"""
-Model by @duyphung for @carperai
-Dumb Simple Gradio by @jon-tow
-"""
+import os
 from string import Template
+from threading import Thread
 
 import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
 
-tokenizer = AutoTokenizer.from_pretrained("CarperAI/vicuna-13b-fine-tuned-rlhf")
-model = AutoModelForCausalLM.from_pretrained(
+auth_token = os.environ.get("HUGGINGFACE_TOKEN")
+tokenizer = AutoTokenizer.from_pretrained(
     "CarperAI/vicuna-13b-fine-tuned-rlhf",
-    torch_dtype=torch.bfloat16,
+    use_auth_token=auth_token if auth_token else True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "CarperAI/vicuna-13b-fine-tuned-rlhf-fp16",
+    torch_dtype=torch.float16,
+    device_map="auto",
+    offload_folder="./offload",
+    low_cpu_mem_usage=True,  # Not required for demo but leave for now
+    use_auth_token=auth_token if auth_token else True,
 )
 model.cuda()
 max_context_length = model.config.max_position_embeddings
-max_new_tokens = 256
+max_new_tokens = 500
 
 
 prompt_template = Template("""\
@@ -26,44 +32,70 @@ prompt_template = Template("""\
 
 
 def bot(history):
+    # print(f"History:\n`{history}`")
     history = history or []
-
     # Hack to inject prompt formatting into the history
     prompt_history = []
     for human, bot in history:
+        if bot is not None:
+            bot = bot.replace("<br>", "\n")
+            bot = bot.rstrip()
         prompt_history.append(
             prompt_template.substitute(
                 human=human, bot=bot if bot is not None else "")
         )
 
-    prompt = "\n\n".join(prompt_history)
-    prompt = prompt.rstrip()
-    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
-    # Use only the most recent context up to the maximum context length with room left over
+    messages = "\n\n".join(prompt_history)
+    messages = messages.rstrip()
+    # print(f"Messages:\n{messages}")
+
+    # Use only the most recent context up to the maximum context length with room left over
     # for the max new tokens
-    inputs = {k: v[:, -max_context_length + max_new_tokens:] for k, v in inputs.items()}
-    inputs_length = inputs['input_ids'].shape[1]
+    inputs = tokenizer(messages, return_tensors='pt').to('cuda')
+    inputs = {k: v[:, -max_context_length + max_new_tokens:]
              for k, v in inputs.items()}
+    if inputs.get("token_type_ids", None) is not None:
+        inputs.pop("token_type_ids")
+    # print(f"Inputs: {inputs}")
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
 
     # Generate the response
-    tokens = model.generate(
-        **inputs,
-        # Only allow the model to generate up to 512 tokens
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
-        num_return_sequences=1,
         do_sample=True,
+        top_p=0.95,
+        top_k=1000,
        temperature=1.0,
-        top_p=1.0,
+        num_beams=1,
     )
-    # Strip the initial prompt
-    tokens = tokens[:, inputs_length:]
-
-    # Process response
-    response = tokenizer.decode(tokens[0], skip_special_tokens=True)
-    response = response.split("###")[0].strip()
 
-    # Add the response to the history
-    history[-1][1] = response
-    return history
+    # print(f"Generating with kwargs: {generate_kwargs}")
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    partial_text = ""
+    for new_text in streamer:
+        # Process out the prompt separator. NOTE: we should tune with special tokens for this
+        new_text = new_text.replace("<br>", "\n")
+        # print(f"New text: `{new_text}`")
+        if "###" in new_text:
+            new_text = new_text.split("###")[0]
+            partial_text += new_text.strip()
+            history[-1][1] = partial_text
+            break
+        else:
+            # Filter empty trailing whitespaces
+            if new_text.isspace():
+                new_text = new_text.strip()
+            partial_text += new_text
+            history[-1][1] = partial_text
+            yield history
+
+    return partial_text
 
 
 def user(user_message, history):
@@ -71,14 +103,28 @@ def user(user_message, history):
 
 
 with gr.Blocks() as demo:
-    gr.Markdown("""Vicuna-13B RLHF Chatbot""")
+    gr.Markdown("Chat-RLHF by CarperAI")
+    gr.HTML("<a href='https://huggingface.co/CarperAI/vicuna-13b-fine-tuned-rlhf'><code>CarperAI/vicuna-13b-fine-tuned-rlhf</a>")
+    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/chat-rlhf?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
+
     chatbot = gr.Chatbot([], elem_id="chatbot").style(height=512)
-    msg = gr.Textbox()
-    clear = gr.Button("Clear")
     state = gr.State([])
-
-    msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+    with gr.Row():
+        with gr.Column():
+            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
+                             show_label=False).style(container=False)
+        with gr.Column():
+            with gr.Row():
+                submit = gr.Button("Submit")
+                stop = gr.Button("Stop")
+                clear = gr.Button("Clear")
+    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
+        bot, chatbot, chatbot)
+    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
         bot, chatbot, chatbot)
-    clear.click(lambda: None, None, chatbot, queue=False)
+    stop.click(fn=None, inputs=None, outputs=None, cancels=[
        submit_event, submit_click_event], queue=False)
+    clear.click(lambda: None, None, chatbot, queue=True)
 
+demo.queue(max_size=32, concurrency_count=2)
 demo.launch(share=True)
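
For reference, the new bot() streams its reply with the TextIteratorStreamer pattern introduced in transformers 4.28: model.generate() runs on a background thread while the main thread iterates the streamer for decoded chunks and yields partial history updates to the Gradio Chatbot. Below is a minimal, self-contained sketch of that pattern, using the small placeholder model "gpt2" instead of the 13B checkpoint and a plain print() in place of the Gradio yield:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("### Human: Hello!\n### Assistant:", return_tensors="pt")
# skip_prompt=True keeps the echoed prompt out of the streamed text
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread while the
# main thread consumes decoded chunks from the streamer as they are produced
thread = Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=64, do_sample=True),
)
thread.start()

partial_text = ""
for new_text in streamer:
    partial_text += new_text
    print(partial_text)  # the app yields this into the Chatbot history instead
thread.join()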
requirements.txt CHANGED
@@ -1,2 +1,3 @@
+accelerate
 torch
-transformers @ git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda
+transformers>=4.28.0,<4.29.0
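
Note on the dependency change: TextIteratorStreamer first shipped in transformers 4.28, which is presumably why the git-commit pin is replaced by the >=4.28.0,<4.29.0 range, and accelerate backs the device_map="auto" / low_cpu_mem_usage / offload_folder loading path used in the new app.py.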