Miaoran000 committed
Commit 3cf286c
1 Parent(s): 2aa9a75
minor fix

src/backend/model_operations.py CHANGED (+30, −13)
@@ -13,7 +13,7 @@ from sentence_transformers import CrossEncoder
 import litellm
 # from litellm import completion
 from tqdm import tqdm
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig, pipeline
 # from accelerate import PartialState
 # from accelerate.inference import prepare_pippy
 import torch
@@ -272,9 +272,8 @@ class SummaryGenerator:
         # Using HF API or download checkpoints
         elif self.local_model is None:
             try: # try use HuggingFace API
-
                 response = litellm.completion(
-                    model='command-r-plus' if 'command' in self.
+                    model='command-r-plus' if 'command' in self.model_id else self.model_id,
                     messages=[{"role": "system", "content": system_prompt},
                               {"role": "user", "content": user_prompt}],
                     temperature=0.0,
@@ -286,7 +285,7 @@ class SummaryGenerator:
             except: # fail to call api. run it locally.
                 self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
                 print("Tokenizer loaded")
-                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto", cache_dir='/home/paperspace/cache')
                 print("Local model loaded")
 
         # Using local model
@@ -294,15 +293,33 @@ class SummaryGenerator:
             messages=[
                 {"role": "system", "content": system_prompt}, # gemma-1.1 does not accept system role
                 {"role": "user", "content": user_prompt}
-            ]
-
-
-
-
-
-
-
-
+            ]
+            try: # some models support pipeline
+                pipe = pipeline(
+                    "text-generation",
+                    model=self.local_model,
+                    tokenizer=self.tokenizer,
+                )
+
+                generation_args = {
+                    "max_new_tokens": 250,
+                    "return_full_text": False,
+                    "temperature": 0.0,
+                    "do_sample": False,
+                }
+
+                output = pipe(messages, **generation_args)
+                result = output[0]['generated_text']
+                print(result)
+            except:
+                prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
+                print(prompt)
+                input_ids = self.tokenizer(prompt, return_tensors="pt").to('cuda')
+                with torch.no_grad():
+                    outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
+                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                result = result.replace(prompt[0], '')
+                print(result)
         return result
 
     def _compute_avg_length(self):
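For reference, below is a minimal standalone sketch of the generation flow this commit introduces: try the transformers text-generation pipeline on chat-style messages first, and fall back to applying the chat template and calling generate() directly if the pipeline path fails. It is illustrative only, not the Space's code: the model id, prompts, and variable names are assumptions, and unlike the commit it uses greedy decoding (do_sample=False) and slices off the prompt tokens instead of a string replace, purely to keep the example deterministic and self-contained.

# Illustrative sketch of the commit's try-pipeline / fallback-generate pattern.
# Assumptions: a recent transformers release whose text-generation pipeline accepts
# chat messages; "Qwen/Qwen2-0.5B-Instruct" is only an example checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "Qwen/Qwen2-0.5B-Instruct"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto"
)

messages = [
    {"role": "system", "content": "You are a concise summarizer."},
    {"role": "user", "content": "Summarize: The cat sat on the mat."},
]

try:
    # Preferred path: the pipeline applies the model's chat template internally.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    output = pipe(messages, max_new_tokens=250, do_sample=False, return_full_text=False)
    result = output[0]["generated_text"]
except Exception:
    # Fallback: build the prompt manually and call generate() directly,
    # then decode only the newly generated tokens.
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(result)

The two-path design matters because some checkpoints (for example ones that reject a system role or lack pipeline chat support) raise at the pipeline stage; catching that and templating the prompt by hand keeps the same code path usable across models.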