from contextlib import nullcontext
from random import shuffle
import torch
from llama_cpp import Llama
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase


def generate(
    model: PreTrainedModel | Llama,
    tokenizer: PreTrainedTokenizerBase,
    prompt="",
    temperature=0.5,
    top_p=0.95,
    top_k=45,
    repetition_penalty=1.17,
    max_new_tokens=128,
    autocast_gen=lambda: torch.autocast("cpu", enabled=False),
    **kwargs,
):
    # llama.cpp models expose their own completion API, so handle them first.
    if isinstance(model, Llama):
        result = model.create_completion(
            prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            repeat_penalty=repetition_penalty or 1,
        )
        return prompt + result["choices"][0]["text"]

    # transformers path: move the prompt to the model's device and sample.
    torch.cuda.empty_cache()
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        **kwargs,
    )
    with torch.no_grad(), autocast_gen():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    torch.cuda.empty_cache()
    return output
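

# Usage sketch (illustrative, not part of the original file): generate() accepts
# either a transformers causal LM or a llama.cpp Llama. The checkpoint and GGUF
# file names below are assumed examples, not values defined in this module.
#
#     tokenizer = AutoTokenizer.from_pretrained("KBlueLeaf/DanTagGen-beta")
#     model = AutoModelForCausalLM.from_pretrained("KBlueLeaf/DanTagGen-beta")
#     text = generate(model, tokenizer, prompt="1girl, solo", max_new_tokens=64)
#
#     gguf_model = Llama("DanTagGen-beta.Q8_0.gguf", n_ctx=384)
#     text = generate(gguf_model, tokenizer=None, prompt="1girl, solo",
#                     max_new_tokens=64)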


def tag_gen(
    text_model,
    tokenizer,
    prompt,
    prompt_tags,
    len_target,
    black_list,
    temperature=0.5,
    top_p=0.95,
    top_k=100,
    max_new_tokens=256,
    max_retry=5,
):
    prev_len = 0
    retry = max_retry
    llm_gen = ""

    while True:
        llm_gen = generate(
            model=text_model,
            tokenizer=tokenizer,
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=None,
            max_new_tokens=max_new_tokens,
            stream_output=False,
            autocast_gen=lambda: (
                torch.autocast("cuda") if torch.cuda.is_available() else nullcontext()
            ),
            prompt_lookup_num_tokens=10,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        llm_gen = llm_gen.replace("</s>", "").replace("<s>", "")
        # Everything after the <|input_end|> marker is the newly generated tag list.
        extra = llm_gen.split("<|input_end|>")[-1].strip().strip(",")
        # Deduplicate the new tags and drop anything on the black list.
        extra_tokens = list(
            set(
                tok.strip()
                for tok in extra.split(",")
                if tok.strip() not in black_list
            )
        )
        llm_gen = llm_gen.replace(extra, ", ".join(extra_tokens))
        yield llm_gen, extra_tokens

        if len(prompt_tags) + len(extra_tokens) < len_target:
            # Still short of len_target. If the model produced no new tags
            # compared with the previous pass, spend a retry and reshuffle;
            # otherwise reset the retry budget.
            if len(extra_tokens) == prev_len and prev_len > 0:
                if retry < 0:
                    break
                retry -= 1
                shuffle(extra_tokens)
            else:
                retry = max_retry
            prev_len = len(extra_tokens)
            # Continue from the current output on the next pass; collapse the
            # double space left in front of special tokens after replacement.
            prompt = llm_gen.strip().replace("  <|", " <|")
        else:
            break

    yield llm_gen, extra_tokens
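

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the original module): drives tag_gen
# with a transformers checkpoint. The model name, prompt template, tag list,
# and black list below are illustrative assumptions; the real Space builds a
# richer prompt before the <|input_end|> marker that tag_gen splits on.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "KBlueLeaf/DanTagGen-beta"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_model = AutoModelForCausalLM.from_pretrained(model_name)
    if torch.cuda.is_available():
        text_model = text_model.half().cuda()

    prompt_tags = ["1girl", "solo", "blue hair"]
    prompt = ", ".join(prompt_tags) + "<|input_end|>"

    # tag_gen is a generator: it yields the running output and the new tags
    # after every pass until len_target is reached or retries run out.
    for llm_gen, extra_tokens in tag_gen(
        text_model=text_model,
        tokenizer=tokenizer,
        prompt=prompt,
        prompt_tags=prompt_tags,
        len_target=20,
        black_list=["watermark"],
        max_retry=5,
    ):
        print(extra_tokens)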