djstrong committed
Commit 358ae35
1 Parent(s): 626f638
Files changed (1):
  1. app.py +33 -28
app.py CHANGED
@@ -74,9 +74,37 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 @spaces.GPU()
+def generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
+    input_ids, attention_mask = enc.input_ids, enc.attention_mask
+
+    if input_ids.shape[1] > CONTEXT_LENGTH:
+        input_ids = input_ids[:, -CONTEXT_LENGTH:]
+
+    generate_kwargs = dict(
+        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
+        streamer=streamer,
+        do_sample=True if temperature else False,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        top_p=top_p
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for new_token in streamer:
+        outputs.append(new_token)
+        if new_token in stop_tokens:
+            break
+        yield "".join(outputs)
+
+
 def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
     repetition_penalty=float(repetition_penalty)
-    print('LLL', message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p)
+    print('LLL', [message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p])
     # Format history with a given chat template
     if CHAT_TEMPLATE == "ChatML":
         stop_tokens = ["<|endoftext|>", "<|im_end|>"]
@@ -103,33 +131,10 @@ def predict(message, history, system_prompt, temperature, max_new_tokens, top_k,
         raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
     print(instruction)
 
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
-    input_ids, attention_mask = enc.input_ids, enc.attention_mask
-
-    if input_ids.shape[1] > CONTEXT_LENGTH:
-        input_ids = input_ids[:, -CONTEXT_LENGTH:]
-
-    generate_kwargs = dict(
-        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
-        streamer=streamer,
-        do_sample=True if temperature else False,
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        top_p=top_p
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-    outputs = []
-    for new_token in streamer:
-        outputs.append(new_token)
-        if new_token in stop_tokens:
-            break
-        yield "".join(outputs)
+    for output_text in generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
+        yield output_text
 
-    send_discord(instruction, "".join(outputs))
+    send_discord(instruction, output_text)
 
 
 hfapi = HfApi()
@@ -145,7 +150,7 @@ def predict(message, history, system_prompt, temperature, max_new_tokens, top_k,
         'repetition_penalty':repetition_penalty,
         'top_p':top_p,
         'instruction':instruction,
-        'output':"".join(outputs),
+        'output':output_text,
         'precision': 'auto '+str(model.dtype),
     }
     hfapi.upload_file(
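For reference, the extracted generate() is built on the TextIteratorStreamer pattern from transformers: model.generate() runs on a background thread and pushes decoded text into the streamer, while the generator drains it and yields the growing output. Below is a minimal, self-contained sketch of that pattern; the gpt2 checkpoint, the stream_generate name, and the prompt are placeholders for illustration only, not part of this Space's app.py.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration only; the Space's real model, tokenizer,
# device, and CONTEXT_LENGTH are defined earlier in app.py.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_generate(prompt, max_new_tokens=32):
    # model.generate() blocks until done, so it runs on a worker thread and
    # feeds decoded text fragments into the streamer as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer([prompt], return_tensors="pt")
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer,
                       max_new_tokens=max_new_tokens, do_sample=False)).start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)  # growing prefix, as generate() does above

for text in stream_generate("The streamer yields"):
    print(text)
```

One design note the diff implies: predict() keeps an explicit for loop rather than yield from generate(...) because the last yielded string is reused after the loop; both send_discord(instruction, output_text) and the uploaded log's 'output' field read output_text, the loop variable left over from the final iteration.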