Spaces:
Running
Running
import time | |
import torch | |
from typing import Callable | |
from pathlib import Path | |
from dartrs.v2 import ( | |
V2Model, | |
MixtralModel, | |
MistralModel, | |
compose_prompt, | |
LengthTag, | |
AspectRatioTag, | |
RatingTag, | |
IdentityTag, | |
) | |
from dartrs.dartrs import DartTokenizer | |
from dartrs.utils import get_generation_config | |
import gradio as gr | |
from gradio.components import Component | |
try: | |
from output import UpsamplingOutput | |
except: | |
from .output import UpsamplingOutput | |
# Registry of the selectable models, keyed by the name passed around as
# ``model_name``. Each entry holds:
#   "repo"  - identifier handed to ``from_pretrained`` for both the tokenizer
#             and the model (see prepare_models),
#   "type"  - a label for the checkpoint (only "sft" is present here),
#   "class" - the dartrs model class whose ``from_pretrained`` loads it.
V2_ALL_MODELS = {
    "dart-v2-moe-sft": {
        "repo": "p1atdev/dart-v2-moe-sft",
        "type": "sft",
        "class": MixtralModel,
    },
    "dart-v2-sft": {
        "repo": "p1atdev/dart-v2-sft",
        "type": "sft",
        "class": MistralModel,
    },
}
def prepare_models(model_config: dict):
    """Load the tokenizer and the model described by one registry entry.

    Args:
        model_config: mapping with a ``"repo"`` identifier and a ``"class"``
            object exposing ``from_pretrained``.

    Returns:
        A dict with the loaded objects under ``"tokenizer"`` and ``"model"``.
    """
    repo_id = model_config["repo"]
    # The dict literal is evaluated top to bottom, so the tokenizer is loaded
    # before the model class is looked up and loaded.
    return {
        "tokenizer": DartTokenizer.from_pretrained(repo_id),
        "model": model_config["class"].from_pretrained(repo_id),
    }
def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Tokenize *tags* and re-join the tokens, dropping unknown-token markers."""
    known_tokens = filter(
        lambda token: token != "<|unk|>",
        tokenizer.tokenize(tags),
    )
    return ", ".join(known_tokens)
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    """Run one generation pass of *model* on *prompt*.

    The sampling settings are fixed here: temperature 1, top_p 0.9,
    top_k 100 and at most 256 new tokens. ``ban_token_ids`` is forwarded
    unchanged to the generation config.

    Returns:
        Whatever ``model.generate`` returns for the built config.
    """
    generation_config = get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
        ban_token_ids=ban_token_ids,
    )
    return model.generate(generation_config)
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    """Build the count tags for *noun*.

    The result starts with ``1<noun>``, continues with ``<n><noun>s`` for
    every n in ``minimum + 1 .. maximum`` and ends with the open-ended
    ``<maximum + 1>+<noun>s`` tag.
    """
    tags = [f"1{noun}"]
    for count in range(minimum + 1, maximum + 1):
        tags.append(f"{count}{noun}s")
    tags.append(f"{maximum + 1}+{noun}s")
    return tags


# Every tag that counts as a "people" tag when ordering the final prompt.
PEOPLE_TAGS = [
    *_people_tag("girl"),
    *_people_tag("boy"),
    *_people_tag("other"),
    "no humans",
]
def gen_prompt_text(output: UpsamplingOutput):
    """Assemble the final comma-separated prompt from an upsampling result.

    Order of the parts: people tags found in the general tags, character
    tags, copyright tags, the remaining general tags, the upsampled tags and
    finally the rating tag. Each part is stripped and empty parts are left
    out.
    """
    general = [tag.strip() for tag in output.general_tags.split(",")]

    # Split the general tags into people tags (e.g. 1girl) and everything else,
    # keeping the original relative order inside each group.
    people_tags = [tag for tag in general if tag in PEOPLE_TAGS]
    other_general_tags = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered_parts = [
        *people_tags,
        output.character_tags,
        output.copyright_tags,
        *other_general_tags,
        output.upsampled_tags,
        output.rating_tag,
    ]
    stripped_parts = (part.strip() for part in ordered_parts)
    return ", ".join(part for part in stripped_parts if part != "")
def elapsed_time_format(elapsed_time: float) -> str:
    """Render a duration in seconds as ``Elapsed: <x.xx> seconds``."""
    seconds_text = format(elapsed_time, ".2f")
    return "Elapsed: " + seconds_text + " seconds"
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    """Wrap *upsampler* so that its result is converted into UI output values.

    Args:
        upsampler: callable returning an ``UpsamplingOutput``; it receives
            the wrapper's positional arguments unchanged.

    Returns:
        A function that calls *upsampler* and returns a 4-tuple:
        the prompt text, the formatted elapsed time, and two
        ``gr.update(interactive=True)`` values (presumably for two
        components that are re-enabled after generation; the receiving
        components are wired up outside this function).
    """

    # The annotation lists four elements because four values are returned.
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)
        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output
class V2UI:
    """Holds the currently loaded model/tokenizer and runs tag upsampling."""

    # Name of the model currently loaded; None until the first generation.
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer
    # NOTE: class-level list, shared by every instance that does not assign
    # its own; rebind it per instance instead of mutating it in place.
    input_components: list[Component] = []
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        """Generate upsampled tags for the given inputs.

        The model and tokenizer are (re)loaded only when *model_name* differs
        from the one currently loaded.

        Args:
            model_name: key into ``V2_ALL_MODELS``.
            copyright_tags, character_tags, general_tags: user-supplied tag
                strings, passed to ``compose_prompt`` as given.
            rating_tag, aspect_ratio_tag, length_tag, identity_tag: condition
                tags passed to ``compose_prompt``.
            ban_tags: text whose tokenizer encoding is used as the list of
                banned token ids.
            *args: accepted and ignored.

        Returns:
            An ``UpsamplingOutput`` carrying the generated tags, the inputs
            and the time spent in generation (seconds).

        Raises:
            KeyError: if *model_name* is not a key of ``V2_ALL_MODELS``.
        """
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(V2_ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            # Recorded last so a failed load does not mark the model as ready.
            self.model_name = model_name

        # NOTE: normalize_tags() is not applied to the user-supplied tags;
        # they reach compose_prompt exactly as entered.

        ban_token_ids = self.tokenizer.encode(ban_tags.strip())

        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted (or go negative) when the system clock is adjusted.
        start = time.perf_counter()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.perf_counter() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )
def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
    """Return only the prompt text for an upsampling result.

    Despite its name, ``upsampler`` is the ``UpsamplingOutput`` itself, not a
    callable.
    """
    prompt_text = gen_prompt_text(upsampler)
    return prompt_text
# Module-level V2UI instance used by v2_upsampling_prompt. No model is loaded
# here; on_generate loads one the first time it is called.
v2 = V2UI()
def v2_upsampling_prompt(
    model: str = "dart-v2-moe-sft",
    copyright: str = "",
    character: str = "",
    general_tags: str = "",
    rating: str = "nsfw",
    aspect_ratio: str = "square",
    length: str = "very_long",
    identity: str = "lax",
    ban_tags: str = "censored",
):
    """Upsample the given tags with the shared ``v2`` instance.

    Returns:
        The assembled prompt text for the generation result.
    """
    output = v2.on_generate(
        model,
        copyright,
        character,
        general_tags,
        rating,
        aspect_ratio,
        length,
        identity,
        ban_tags,
    )
    return parse_upsampling_output_simple(output)
def load_dict_from_csv(filename):
    """Load a two-column comma-separated file into a dict.

    The file is looked up at *filename* first and, if absent, under
    ``./tagger/``. Each line is stripped and split on ``,``; the first field
    becomes the key and the second the value (further fields are ignored).
    Blank lines and lines without a comma are skipped.

    Args:
        filename: path of the file to read.

    Returns:
        The mapping read from the file, or an empty dict when the file is
        not found or cannot be read/decoded as UTF-8.
    """
    mapping = {}

    if not Path(filename).exists():
        fallback = Path('./tagger/', filename)
        if not fallback.exists():
            return mapping
        filename = str(fallback)

    try:
        with open(filename, 'r', encoding="utf-8") as f:
            lines = f.readlines()
    except (OSError, UnicodeError):
        # Best effort: a missing/unreadable dictionary must not break import.
        print(f"Failed to open dictionary file: {filename}")
        return mapping

    for line in lines:
        parts = line.strip().split(',')
        if len(parts) < 2:
            # Blank or malformed line (e.g. a trailing newline at end of file).
            continue
        mapping[parts[0]] = parts[1]
    return mapping
# Lookup table loaded once at import time: keys are used as character names
# and values as their series (see select_random_character). It is empty when
# the CSV file cannot be found or opened.
anime_series_dict = load_dict_from_csv('character_series_dict.csv')
def select_random_character(series: str, character: str):
    """Pick a random character from ``anime_series_dict``.

    Args:
        series: fallback series, returned only when the dictionary is empty.
        character: fallback character, returned only when the dictionary is
            empty.

    Returns:
        A ``(series, character)`` tuple. The series is looked up from the
        chosen character and is ``""`` when the lookup finds nothing.
    """
    from random import choice, seed

    # Re-seed from OS entropy on every call so the pick stays random even
    # if the global generator was seeded elsewhere.
    seed()

    # An empty dictionary (e.g. the CSV was not found) has nothing to pick
    # from; hand the inputs back instead of raising.
    if not anime_series_dict:
        return series, character

    # choice() draws uniformly from every entry, including the last one and
    # the single entry of a one-element dictionary.
    character = choice(list(anime_series_dict))
    series = anime_series_dict.get(character.split(",")[0].strip(), "")
    return series, character
def v2_random_prompt(
    general_tags: str = "",
    copyright: str = "",
    character: str = "",
    rating: str = "nsfw",
    aspect_ratio: str = "square",
    length: str = "very_long",
    identity: str = "lax",
    ban_tags: str = "censored",
    model: str = "dart-v2-moe-sft",
):
    """Upsample a prompt, choosing a random character when none is given.

    A random character/series pair is selected only when both *copyright*
    and *character* are empty strings.

    Returns:
        A ``(prompt_text, copyright, character)`` tuple, where the last two
        are the values actually used for generation.
    """
    nothing_specified = copyright == "" and character == ""
    if nothing_specified:
        copyright, character = select_random_character("", "")

    prompt_text = v2_upsampling_prompt(
        model,
        copyright,
        character,
        general_tags,
        rating,
        aspect_ratio,
        length,
        identity,
        ban_tags,
    )
    return prompt_text, copyright, character