Kohaku-Blueleaf committed
Commit a4db55a
1 Parent(s): 5140369

first commit

app.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ from time import time_ns
+
+ import gradio as gr
+ import torch
+ import requests as rq
+ from llama_cpp import Llama
+ from transformers import LlamaForCausalLM, LlamaTokenizer
+
+ from kgen.generate import tag_gen
+ from kgen.metainfo import SPECIAL, TARGET
+
+
+ MODEL_PATH = "KBlueLeaf/DanTagGen"
+
+
+ @torch.no_grad()
+ def get_result(
+     text_model: LlamaForCausalLM | Llama,
+     tokenizer: LlamaTokenizer,
+     rating: str = "",
+     artist: str = "",
+     characters: str = "",
+     copyrights: str = "",
+     target: str = "long",
+     special_tags: list[str] = ["1girl"],
+     general: str = "",
+     aspect_ratio: float = 0.0,
+     blacklist: str = "",
+     escape_bracket: bool = False,
+     temperature: float = 1.35,
+ ):
+     start = time_ns()
+     print("=" * 50, "\n")
+     # Use the LLM to predict a possible completion of the tag list.
+     # This prompt format lets the model itself make the request longer based
+     # on what it learned, which works better for preference-sim and
+     # pref-sum contrastive scorers.
+     prompt = f"""
+ rating: {rating or '<|empty|>'}
+ artist: {artist.strip() or '<|empty|>'}
+ characters: {characters.strip() or '<|empty|>'}
+ copyrights: {copyrights.strip() or '<|empty|>'}
+ aspect ratio: {f"{aspect_ratio:.1f}" if aspect_ratio else '<|empty|>'}
+ target: {'<|' + target + '|>' if target else '<|long|>'}
+ general: {", ".join(special_tags)}, {general.strip().strip(",")}<|input_end|>
+ """.strip()
+
+     artist = artist.strip().strip(",").replace("_", " ")
+     characters = characters.strip().strip(",").replace("_", " ")
+     copyrights = copyrights.strip().strip(",").replace("_", " ")
+     special_tags = [tag.strip().replace("_", " ") for tag in special_tags]
+     general = general.strip().strip(",")
+     black_list = set(
+         [tag.strip().replace("_", " ") for tag in blacklist.strip().split(",")]
+     )
+
+     prompt_tags = special_tags + general.strip().strip(",").split(",")
+     len_target = TARGET[target]
+     llm_gen = ""
+
+     for llm_gen, extra_tokens in tag_gen(
+         text_model,
+         tokenizer,
+         prompt,
+         prompt_tags,
+         len_target,
+         black_list,
+         temperature=temperature,
+         top_p=0.95,
+         top_k=100,
+         max_new_tokens=256,
+         max_retry=5,
+     ):
+         yield "", llm_gen, f"Total elapsed time: {(time_ns() - start) / 1e9:.2f}s"
+     print()
+     print("-" * 50)
+
+     general = f"{general.strip().strip(',')}, {','.join(extra_tokens)}"
+     tags = general.strip().split(",")
+     tags = [tag.strip() for tag in tags if tag.strip()]
+     special = special_tags + [tag for tag in tags if tag in SPECIAL]
+     tags = [tag for tag in tags if tag not in special]
+
+     final_prompt = ", ".join(special)
+     if characters:
+         final_prompt += f", \n\n{characters}"
+     if copyrights:
+         final_prompt += ", "
+         if not characters:
+             final_prompt += "\n\n"
+         final_prompt += copyrights
+     if artist:
+         final_prompt += f", \n\n{artist}"
+     final_prompt += f""", \n\n{', '.join(tags)},
+
+ masterpiece, newest, absurdres, {rating}"""
+
+     print(final_prompt)
+     print("=" * 50)
+
+     if escape_bracket:
+         final_prompt = (
+             final_prompt.replace("[", "\\[")
+             .replace("]", "\\]")
+             .replace("(", "\\(")
+             .replace(")", "\\)")
+         )
+
+     yield final_prompt, llm_gen, (
+         f"Total elapsed time: {(time_ns() - start) / 1e9:.2f}s"
+         f" | Total general tags: {len(special + tags)}"
+     )
+
+
+ if __name__ == "__main__":
+     tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
+     # Download the quantized GGUF weights on first launch
+     if not os.path.isfile("./model.gguf"):
+         data = rq.get(
+             "https://huggingface.co/KBlueLeaf/DanTagGen/resolve/main/ggml-model-Q6_K.gguf"
+         ).content
+         with open("./model.gguf", "wb") as f:
+             f.write(data)
+     text_model = Llama(
+         "./model.gguf",
+         n_ctx=384,
+         verbose=False,
+     )
+
+     def wrapper(
+         rating: str,
+         artist: str,
+         characters: str,
+         copyrights: str,
+         target: str,
+         special_tags: list[str],
+         general: str,
+         width: float,
+         height: float,
+         blacklist: str,
+         escape_bracket: bool,
+         temperature: float = 1.35,
+     ):
+         yield from get_result(
+             text_model,
+             tokenizer,
+             rating,
+             artist,
+             characters,
+             copyrights,
+             target,
+             special_tags,
+             general,
+             width / height,
+             blacklist,
+             escape_bracket,
+             temperature,
+         )
+
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         with gr.Row():
+             with gr.Column(scale=4):
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         rating = gr.Radio(
+                             ["safe", "sensitive", "nsfw", "nsfw, explicit"],
+                             label="Rating",
+                         )
+                         special_tags = gr.Dropdown(
+                             SPECIAL,
+                             value=["1girl"],
+                             label="Special tags",
+                             multiselect=True,
+                         )
+                         characters = gr.Textbox(label="Characters")
+                         copyrights = gr.Textbox(label="Copyrights (Series)")
+                         artist = gr.Textbox(label="Artist")
+                         target = gr.Radio(
+                             ["very_short", "short", "long", "very_long"],
+                             label="Target length",
+                         )
+                     with gr.Column(scale=2):
+                         general = gr.TextArea(label="Input your general tags")
+                         black_list = gr.TextArea(
+                             label="Tag blacklist (comma separated)"
+                         )
+                         with gr.Row():
+                             width = gr.Slider(
+                                 value=1024,
+                                 minimum=256,
+                                 maximum=4096,
+                                 step=32,
+                                 label="Width",
+                             )
+                             height = gr.Slider(
+                                 value=1024,
+                                 minimum=256,
+                                 maximum=4096,
+                                 step=32,
+                                 label="Height",
+                             )
+                         with gr.Row():
+                             temperature = gr.Slider(
+                                 value=1.35,
+                                 minimum=0.1,
+                                 maximum=2,
+                                 step=0.05,
+                                 label="Temperature",
+                             )
+                             escape_bracket = gr.Checkbox(
+                                 value=False,
+                                 label="Escape bracket",
+                             )
+                 submit = gr.Button("Submit")
+             with gr.Column(scale=3):
+                 formatted_result = gr.TextArea(
+                     label="Final output", lines=14, show_copy_button=True
+                 )
+                 llm_result = gr.TextArea(label="LLM output", lines=10)
+                 cost_time = gr.Markdown()
+         submit.click(
+             wrapper,
+             inputs=[
+                 rating,
+                 artist,
+                 characters,
+                 copyrights,
+                 target,
+                 special_tags,
+                 general,
+                 width,
+                 height,
+                 black_list,
+                 # escape_bracket must come before temperature here to match
+                 # wrapper's positional parameter order
+                 escape_bracket,
+                 temperature,
+             ],
+             outputs=[
+                 formatted_result,
+                 llm_result,
+                 cost_time,
+             ],
+             show_progress=True,
+         )
+
+     demo.launch()
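
Note: for reference, with a hypothetical call such as get_result(..., rating="safe", target="long", special_tags=["1girl"], general="solo, smile", aspect_ratio=1.0) and the remaining fields left empty, the prompt f-string above renders to:

rating: safe
artist: <|empty|>
characters: <|empty|>
copyrights: <|empty|>
aspect ratio: 1.0
target: <|long|>
general: 1girl, solo, smile<|input_end|>

This is the completion format the model is prompted with; the sample values are illustrative, not output from a real run.
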
kgen/__init__.py ADDED
File without changes
kgen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (146 Bytes).
kgen/__pycache__/generate.cpython-311.pyc ADDED
Binary file (4.93 kB).
kgen/__pycache__/metainfo.cpython-311.pyc ADDED
Binary file (483 Bytes).
kgen/formatter.py ADDED
File without changes
kgen/generate.py ADDED
@@ -0,0 +1,117 @@
+ from contextlib import nullcontext
+ from random import shuffle
+
+ import torch
+ from llama_cpp import Llama
+ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
+
+
+ def generate(
+     model: PreTrainedModel | Llama,
+     tokenizer: PreTrainedTokenizerBase,
+     prompt="",
+     temperature=0.5,
+     top_p=0.95,
+     top_k=45,
+     repetition_penalty=1.17,
+     max_new_tokens=128,
+     autocast_gen=lambda: torch.autocast("cpu", enabled=False),
+     **kwargs,
+ ):
+     # llama.cpp path: delegate sampling entirely to llama-cpp-python
+     if isinstance(model, Llama):
+         result = model.create_completion(
+             prompt,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             max_tokens=max_new_tokens,
+             repeat_penalty=repetition_penalty or 1,
+         )
+         return prompt + result["choices"][0]["text"]
+
+     # transformers path
+     torch.cuda.empty_cache()
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(next(model.parameters()).device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         **kwargs,
+     )
+     with torch.no_grad(), autocast_gen():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s)
+
+     torch.cuda.empty_cache()
+     return output
+
+
+ def tag_gen(
+     text_model,
+     tokenizer,
+     prompt,
+     prompt_tags,
+     len_target,
+     black_list,
+     temperature=0.5,
+     top_p=0.95,
+     top_k=100,
+     max_new_tokens=256,
+     max_retry=5,
+ ):
+     prev_len = 0
+     retry = max_retry
+     llm_gen = ""
+
+     while True:
+         llm_gen = generate(
+             model=text_model,
+             tokenizer=tokenizer,
+             prompt=prompt,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=None,
+             max_new_tokens=max_new_tokens,
+             autocast_gen=nullcontext,
+             prompt_lookup_num_tokens=10,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+         llm_gen = llm_gen.replace("</s>", "").replace("<s>", "")
+         extra = llm_gen.split("<|input_end|>")[-1].strip().strip(",")
+         # De-duplicate the newly generated tags and drop blacklisted ones
+         extra_tokens = list(
+             set(
+                 [
+                     tok.strip()
+                     for tok in extra.split(",")
+                     if tok.strip() not in black_list
+                 ]
+             )
+         )
+         llm_gen = llm_gen.replace(extra, ", ".join(extra_tokens))
+
+         yield llm_gen, extra_tokens
+
+         if len(prompt_tags) + len(extra_tokens) < len_target:
+             # Not at the target tag count yet. If the model made no progress
+             # this round, spend one retry and shuffle the tags to nudge it;
+             # otherwise reset the retry budget.
+             if len(extra_tokens) == prev_len and prev_len > 0:
+                 if retry < 0:
+                     break
+                 retry -= 1
+                 shuffle(extra_tokens)
+             else:
+                 retry = max_retry
+             prev_len = len(extra_tokens)
+             # Feed the result back in as the next prompt, collapsing doubled
+             # spaces before special tokens
+             prompt = llm_gen.strip().replace("  <|", " <|")
+         else:
+             break
+     yield llm_gen, extra_tokens
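
Note: tag_gen is a generator — it streams each intermediate LLM output and keeps re-prompting until the prompt tags plus newly generated tags reach len_target, or the retry budget runs out. A minimal sketch of driving it outside the Gradio app (assumes model.gguf has already been downloaded as in app.py, and a prompt built in the format get_result uses; the sample tags are illustrative):

from llama_cpp import Llama
from transformers import LlamaTokenizer

from kgen.generate import tag_gen
from kgen.metainfo import TARGET

tokenizer = LlamaTokenizer.from_pretrained("KBlueLeaf/DanTagGen")
text_model = Llama("./model.gguf", n_ctx=384, verbose=False)

prompt = (
    "rating: safe\n"
    "artist: <|empty|>\n"
    "characters: <|empty|>\n"
    "copyrights: <|empty|>\n"
    "aspect ratio: 1.0\n"
    "target: <|long|>\n"
    "general: 1girl, solo, smile<|input_end|>"
)

for llm_gen, extra_tokens in tag_gen(
    text_model,
    tokenizer,
    prompt,
    prompt_tags=["1girl", "solo", "smile"],
    len_target=TARGET["long"],
    black_list=set(),
    temperature=1.35,
):
    # Each iteration yields the raw LLM output plus the de-duplicated new tags
    print(f"{len(extra_tokens)} new tags so far")
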
kgen/metainfo.py ADDED
@@ -0,0 +1,30 @@
+ SPECIAL = [
+     "1girl",
+     "2girls",
+     "3girls",
+     "4girls",
+     "5girls",
+     "6+girls",
+     "multiple_girls",
+     "1boy",
+     "2boys",
+     "3boys",
+     "4boys",
+     "5boys",
+     "6+boys",
+     "multiple_boys",
+     "male_focus",
+     "1other",
+     "2others",
+     "3others",
+     "4others",
+     "5others",
+     "6+others",
+     "multiple_others",
+ ]
+ TARGET = {
+     "very_short": 10,
+     "short": 20,
+     "long": 40,
+     "very_long": 60,
+ }
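
Note: TARGET maps each target-length label to the total tag count tag_gen aims for — generation continues until len(prompt_tags) + len(extra_tokens) reaches this number (e.g. "long" targets roughly 40 tags) or the retry budget is exhausted.
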
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ llama-cpp-python
+ gradio
+ requests
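
Note: to run the demo locally, pip install -r requirements.txt followed by python app.py should suffice; on first launch app.py downloads the Q6_K GGUF weights from the repo URL above before starting the Gradio UI via demo.launch().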