Spaces:
Sleeping
Sleeping
migrating to zero gpu
Browse files- .gitattributes +35 -0
- README.md +1 -3
- app.py +181 -625
- config.py +105 -0
- lora.toml +0 -28
- lora_diffusers.py +0 -478
- requirements.txt +6 -7
- style.css +4 -30
- utils.py +173 -1
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -4,13 +4,11 @@ emoji: 🌍
|
|
4 |
colorFrom: gray
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
license: mit
|
10 |
pinned: false
|
11 |
suggested_hardware: a10g-small
|
12 |
-
duplicated_from: hysts/SD-XL
|
13 |
-
hf_oauth: true
|
14 |
---
|
15 |
|
16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
4 |
colorFrom: gray
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.20.0
|
8 |
app_file: app.py
|
9 |
license: mit
|
10 |
pinned: false
|
11 |
suggested_hardware: a10g-small
|
|
|
|
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,244 +1,71 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
|
3 |
-
from __future__ import annotations
|
4 |
-
|
5 |
import os
|
6 |
-
import random
|
7 |
import gc
|
8 |
-
import toml
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
-
import utils
|
12 |
import torch
|
13 |
import json
|
14 |
-
import
|
15 |
-
import
|
16 |
-
import
|
17 |
-
|
18 |
-
from
|
19 |
from datetime import datetime
|
20 |
-
from PIL import PngImagePlugin
|
21 |
-
import gradio_user_history as gr_user_history
|
22 |
-
from huggingface_hub import hf_hub_download
|
23 |
-
from safetensors.torch import load_file
|
24 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
25 |
-
from lora_diffusers import LoRANetwork, create_network_from_weights
|
26 |
from diffusers.models import AutoencoderKL
|
27 |
-
from diffusers import
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
DPMSolverSinglestepScheduler,
|
32 |
-
KDPM2DiscreteScheduler,
|
33 |
-
EulerDiscreteScheduler,
|
34 |
-
EulerAncestralDiscreteScheduler,
|
35 |
-
HeunDiscreteScheduler,
|
36 |
-
LMSDiscreteScheduler,
|
37 |
-
DDIMScheduler,
|
38 |
-
DEISMultistepScheduler,
|
39 |
-
UniPCMultistepScheduler,
|
40 |
-
)
|
41 |
|
42 |
DESCRIPTION = "Animagine XL 3.0"
|
43 |
if not torch.cuda.is_available():
|
44 |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
|
45 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
46 |
-
MAX_SEED = np.iinfo(np.int32).max
|
47 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
48 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
49 |
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
|
50 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
|
51 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
|
52 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
|
|
53 |
|
54 |
-
MODEL = os.getenv(
|
|
|
|
|
|
|
55 |
|
56 |
torch.backends.cudnn.deterministic = True
|
57 |
torch.backends.cudnn.benchmark = False
|
58 |
|
59 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
60 |
|
61 |
-
|
|
|
62 |
vae = AutoencoderKL.from_pretrained(
|
63 |
"madebyollin/sdxl-vae-fp16-fix",
|
64 |
torch_dtype=torch.float16,
|
65 |
)
|
66 |
-
pipeline =
|
67 |
-
|
|
|
|
|
|
|
|
|
68 |
pipe = pipeline(
|
69 |
-
|
70 |
vae=vae,
|
71 |
torch_dtype=torch.float16,
|
72 |
custom_pipeline="lpw_stable_diffusion_xl",
|
73 |
use_safetensors=True,
|
|
|
74 |
use_auth_token=HF_TOKEN,
|
75 |
variant="fp16",
|
76 |
)
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
else:
|
81 |
-
pipe.to(device)
|
82 |
-
if USE_TORCH_COMPILE:
|
83 |
-
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
84 |
-
else:
|
85 |
-
pipe = None
|
86 |
-
|
87 |
-
|
88 |
-
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
89 |
-
if randomize_seed:
|
90 |
-
seed = random.randint(0, MAX_SEED)
|
91 |
-
return seed
|
92 |
-
|
93 |
-
|
94 |
-
def seed_everything(seed):
|
95 |
-
torch.manual_seed(seed)
|
96 |
-
torch.cuda.manual_seed_all(seed)
|
97 |
-
np.random.seed(seed)
|
98 |
-
generator = torch.Generator()
|
99 |
-
generator.manual_seed(seed)
|
100 |
-
return generator
|
101 |
-
|
102 |
-
|
103 |
-
def get_image_path(base_path: str):
|
104 |
-
extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
|
105 |
-
for ext in extensions:
|
106 |
-
image_path = base_path + ext
|
107 |
-
if os.path.exists(image_path):
|
108 |
-
return image_path
|
109 |
-
return None
|
110 |
-
|
111 |
-
|
112 |
-
def update_selection(selected_state: gr.SelectData):
|
113 |
-
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
114 |
-
lora_weight = sdxl_loras[selected_state.index]["multiplier"]
|
115 |
-
updated_selected_info = f"{lora_repo}"
|
116 |
-
|
117 |
-
return (
|
118 |
-
updated_selected_info,
|
119 |
-
selected_state,
|
120 |
-
lora_weight,
|
121 |
-
)
|
122 |
-
|
123 |
-
|
124 |
-
def parse_aspect_ratio(aspect_ratio):
|
125 |
-
if aspect_ratio == "Custom":
|
126 |
-
return None, None
|
127 |
-
width, height = aspect_ratio.split(" x ")
|
128 |
-
return int(width), int(height)
|
129 |
-
|
130 |
-
|
131 |
-
def aspect_ratio_handler(aspect_ratio, custom_width, custom_height):
|
132 |
-
if aspect_ratio == "Custom":
|
133 |
-
return custom_width, custom_height
|
134 |
-
else:
|
135 |
-
width, height = parse_aspect_ratio(aspect_ratio)
|
136 |
-
return width, height
|
137 |
-
|
138 |
-
|
139 |
-
def create_network(text_encoders, unet, state_dict, multiplier, device):
|
140 |
-
network = create_network_from_weights(
|
141 |
-
text_encoders,
|
142 |
-
unet,
|
143 |
-
state_dict,
|
144 |
-
multiplier,
|
145 |
-
)
|
146 |
-
network.load_state_dict(state_dict)
|
147 |
-
network.to(device, dtype=unet.dtype)
|
148 |
-
network.apply_to(multiplier=multiplier)
|
149 |
-
|
150 |
-
return network
|
151 |
-
|
152 |
-
|
153 |
-
def get_scheduler(scheduler_config, name):
|
154 |
-
scheduler_map = {
|
155 |
-
"DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
|
156 |
-
scheduler_config, use_karras_sigmas=True
|
157 |
-
),
|
158 |
-
"DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
|
159 |
-
scheduler_config, use_karras_sigmas=True
|
160 |
-
),
|
161 |
-
"DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
|
162 |
-
scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
|
163 |
-
),
|
164 |
-
"Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
|
165 |
-
"Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
|
166 |
-
scheduler_config
|
167 |
-
),
|
168 |
-
"DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
|
169 |
-
}
|
170 |
-
return scheduler_map.get(name, lambda: None)()
|
171 |
-
|
172 |
-
|
173 |
-
def free_memory():
|
174 |
-
torch.cuda.empty_cache()
|
175 |
-
gc.collect()
|
176 |
-
|
177 |
-
|
178 |
-
def preprocess_prompt(
|
179 |
-
style_dict,
|
180 |
-
style_name: str,
|
181 |
-
positive: str,
|
182 |
-
negative: str = "",
|
183 |
-
add_style: bool = True,
|
184 |
-
) -> Tuple[str, str]:
|
185 |
-
p, n = style_dict.get(style_name, style_dict["(None)"])
|
186 |
-
|
187 |
-
if add_style and positive.strip():
|
188 |
-
formatted_positive = p.format(prompt=positive)
|
189 |
-
else:
|
190 |
-
formatted_positive = positive
|
191 |
-
|
192 |
-
combined_negative = n + negative
|
193 |
-
return formatted_positive, combined_negative
|
194 |
-
|
195 |
-
|
196 |
-
def common_upscale(samples, width, height, upscale_method):
|
197 |
-
return torch.nn.functional.interpolate(
|
198 |
-
samples, size=(height, width), mode=upscale_method
|
199 |
-
)
|
200 |
-
|
201 |
-
|
202 |
-
def upscale(samples, upscale_method, scale_by):
|
203 |
-
width = round(samples.shape[3] * scale_by)
|
204 |
-
height = round(samples.shape[2] * scale_by)
|
205 |
-
s = common_upscale(samples, width, height, upscale_method)
|
206 |
-
return s
|
207 |
-
|
208 |
|
209 |
-
def load_and_convert_thumbnail(model_path: str):
|
210 |
-
with safetensors.safe_open(model_path, framework="pt") as f:
|
211 |
-
metadata = f.metadata()
|
212 |
-
if "modelspec.thumbnail" in metadata:
|
213 |
-
base64_data = metadata["modelspec.thumbnail"]
|
214 |
-
prefix, encoded = base64_data.split(",", 1)
|
215 |
-
image_data = base64.b64decode(encoded)
|
216 |
-
image = PIL.Image.open(BytesIO(image_data))
|
217 |
-
return image
|
218 |
-
return None
|
219 |
-
|
220 |
-
def load_wildcard_files(wildcard_dir):
|
221 |
-
wildcard_files = {}
|
222 |
-
for file in os.listdir(wildcard_dir):
|
223 |
-
if file.endswith(".txt"):
|
224 |
-
key = f"__{file.split('.')[0]}__" # Create a key like __character__
|
225 |
-
wildcard_files[key] = os.path.join(wildcard_dir, file)
|
226 |
-
return wildcard_files
|
227 |
-
|
228 |
-
def get_random_line_from_file(file_path):
|
229 |
-
with open(file_path, 'r') as file:
|
230 |
-
lines = file.readlines()
|
231 |
-
if not lines:
|
232 |
-
return ""
|
233 |
-
return random.choice(lines).strip()
|
234 |
-
|
235 |
-
def add_wildcard(prompt, wildcard_files):
|
236 |
-
for key, file_path in wildcard_files.items():
|
237 |
-
if key in prompt:
|
238 |
-
wildcard_line = get_random_line_from_file(file_path)
|
239 |
-
prompt = prompt.replace(key, wildcard_line)
|
240 |
-
return prompt
|
241 |
|
|
|
242 |
def generate(
|
243 |
prompt: str,
|
244 |
negative_prompt: str = "",
|
@@ -247,90 +74,40 @@ def generate(
|
|
247 |
custom_height: int = 1024,
|
248 |
guidance_scale: float = 7.0,
|
249 |
num_inference_steps: int = 28,
|
250 |
-
use_lora: bool = False,
|
251 |
-
lora_weight: float = 1.0,
|
252 |
-
selected_state: str = "",
|
253 |
sampler: str = "Euler a",
|
254 |
aspect_ratio_selector: str = "896 x 1152",
|
255 |
style_selector: str = "(None)",
|
256 |
quality_selector: str = "Standard",
|
257 |
use_upscaler: bool = False,
|
258 |
-
upscaler_strength: float = 0.
|
259 |
upscale_by: float = 1.5,
|
260 |
add_quality_tags: bool = True,
|
261 |
-
profile: gr.OAuthProfile | None = None,
|
262 |
progress=gr.Progress(track_tqdm=True),
|
263 |
-
) ->
|
264 |
-
generator = seed_everything(seed)
|
265 |
|
266 |
-
|
267 |
-
network_state = {"current_lora": None, "multiplier": None}
|
268 |
-
|
269 |
-
width, height = aspect_ratio_handler(
|
270 |
aspect_ratio_selector,
|
271 |
custom_width,
|
272 |
custom_height,
|
273 |
)
|
274 |
|
275 |
-
prompt = add_wildcard(prompt, wildcard_files)
|
276 |
|
277 |
-
|
278 |
-
prompt, negative_prompt = preprocess_prompt(
|
279 |
quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
|
280 |
)
|
281 |
-
prompt, negative_prompt = preprocess_prompt(
|
282 |
styles, style_selector, prompt, negative_prompt
|
283 |
)
|
284 |
|
285 |
-
|
286 |
-
width = width - (width % 8)
|
287 |
-
if height % 8 != 0:
|
288 |
-
height = height - (height % 8)
|
289 |
-
|
290 |
-
if use_lora:
|
291 |
-
if not selected_state:
|
292 |
-
raise Exception("You must Select a LoRA")
|
293 |
-
repo_name = sdxl_loras[selected_state.index]["repo"]
|
294 |
-
full_path_lora = saved_names[selected_state.index]
|
295 |
-
weight_name = sdxl_loras[selected_state.index]["weights"]
|
296 |
-
|
297 |
-
lora_sd = load_file(full_path_lora)
|
298 |
-
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
299 |
-
|
300 |
-
if network_state["current_lora"] != repo_name:
|
301 |
-
network = create_network(
|
302 |
-
text_encoders,
|
303 |
-
pipe.unet,
|
304 |
-
lora_sd,
|
305 |
-
lora_weight,
|
306 |
-
device,
|
307 |
-
)
|
308 |
-
network_state["current_lora"] = repo_name
|
309 |
-
network_state["multiplier"] = lora_weight
|
310 |
-
elif network_state["multiplier"] != lora_weight:
|
311 |
-
network = create_network(
|
312 |
-
text_encoders,
|
313 |
-
pipe.unet,
|
314 |
-
lora_sd,
|
315 |
-
lora_weight,
|
316 |
-
device,
|
317 |
-
)
|
318 |
-
network_state["multiplier"] = lora_weight
|
319 |
-
else:
|
320 |
-
if network:
|
321 |
-
network.unapply_to()
|
322 |
-
network = None
|
323 |
-
network_state = {
|
324 |
-
"current_lora": None,
|
325 |
-
"multiplier": None,
|
326 |
-
}
|
327 |
|
328 |
backup_scheduler = pipe.scheduler
|
329 |
-
pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
|
330 |
|
331 |
if use_upscaler:
|
332 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
333 |
-
|
334 |
metadata = {
|
335 |
"prompt": prompt,
|
336 |
"negative_prompt": negative_prompt,
|
@@ -344,11 +121,6 @@ def generate(
|
|
344 |
"quality_tags": quality_selector,
|
345 |
}
|
346 |
|
347 |
-
if use_lora:
|
348 |
-
metadata["use_lora"] = {"selected_lora": repo_name, "multiplier": lora_weight}
|
349 |
-
else:
|
350 |
-
metadata["use_lora"] = None
|
351 |
-
|
352 |
if use_upscaler:
|
353 |
new_width = int(width * upscale_by)
|
354 |
new_height = int(height * upscale_by)
|
@@ -360,8 +132,7 @@ def generate(
|
|
360 |
}
|
361 |
else:
|
362 |
metadata["use_upscaler"] = None
|
363 |
-
|
364 |
-
print(json.dumps(metadata, indent=4))
|
365 |
|
366 |
try:
|
367 |
if use_upscaler:
|
@@ -375,8 +146,8 @@ def generate(
|
|
375 |
generator=generator,
|
376 |
output_type="latent",
|
377 |
).images
|
378 |
-
upscaled_latents = upscale(latents, "nearest-exact", upscale_by)
|
379 |
-
|
380 |
prompt=prompt,
|
381 |
negative_prompt=negative_prompt,
|
382 |
image=upscaled_latents,
|
@@ -385,9 +156,9 @@ def generate(
|
|
385 |
strength=upscaler_strength,
|
386 |
generator=generator,
|
387 |
output_type="pil",
|
388 |
-
).images
|
389 |
else:
|
390 |
-
|
391 |
prompt=prompt,
|
392 |
negative_prompt=negative_prompt,
|
393 |
width=width,
|
@@ -396,194 +167,38 @@ def generate(
|
|
396 |
num_inference_steps=num_inference_steps,
|
397 |
generator=generator,
|
398 |
output_type="pil",
|
399 |
-
).images
|
400 |
-
if network:
|
401 |
-
network.unapply_to()
|
402 |
-
network = None
|
403 |
-
if profile is not None:
|
404 |
-
gr_user_history.save_image(
|
405 |
-
label=prompt,
|
406 |
-
image=image,
|
407 |
-
profile=profile,
|
408 |
-
metadata=metadata,
|
409 |
-
)
|
410 |
-
if image and IS_COLAB:
|
411 |
-
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
412 |
-
output_directory = "./outputs"
|
413 |
-
os.makedirs(output_directory, exist_ok=True)
|
414 |
-
filename = f"image_{current_time}.png"
|
415 |
-
filepath = os.path.join(output_directory, filename)
|
416 |
-
|
417 |
-
# Convert metadata to a string and save as a text chunk in the PNG
|
418 |
-
metadata_str = json.dumps(metadata)
|
419 |
-
info = PngImagePlugin.PngInfo()
|
420 |
-
info.add_text("metadata", metadata_str)
|
421 |
-
image.save(filepath, "PNG", pnginfo=info)
|
422 |
-
print(f"Image saved as {filepath} with metadata")
|
423 |
|
424 |
-
|
|
|
|
|
|
|
425 |
|
|
|
426 |
except Exception as e:
|
427 |
-
|
428 |
raise
|
429 |
finally:
|
430 |
-
if network:
|
431 |
-
network.unapply_to()
|
432 |
-
network = None
|
433 |
-
if use_lora:
|
434 |
-
del lora_sd, text_encoders
|
435 |
if use_upscaler:
|
436 |
del upscaler_pipe
|
437 |
pipe.scheduler = backup_scheduler
|
438 |
-
free_memory()
|
439 |
-
|
440 |
-
|
441 |
-
examples = [
|
442 |
-
"1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
|
443 |
-
"1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
|
444 |
-
"1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
|
445 |
-
"1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
|
446 |
-
"A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
|
447 |
-
]
|
448 |
-
|
449 |
-
quality_prompt_list = [
|
450 |
-
{
|
451 |
-
"name": "(None)",
|
452 |
-
"prompt": "{prompt}",
|
453 |
-
"negative_prompt": "nsfw, lowres, ",
|
454 |
-
},
|
455 |
-
{
|
456 |
-
"name": "Standard",
|
457 |
-
"prompt": "{prompt}, masterpiece, best quality",
|
458 |
-
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
|
459 |
-
},
|
460 |
-
{
|
461 |
-
"name": "Light",
|
462 |
-
"prompt": "{prompt}, (masterpiece), best quality, perfect face",
|
463 |
-
"negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
|
464 |
-
},
|
465 |
-
{
|
466 |
-
"name": "Heavy",
|
467 |
-
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
|
468 |
-
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
|
469 |
-
},
|
470 |
-
]
|
471 |
-
|
472 |
-
sampler_list = [
|
473 |
-
"DPM++ 2M Karras",
|
474 |
-
"DPM++ SDE Karras",
|
475 |
-
"DPM++ 2M SDE Karras",
|
476 |
-
"Euler",
|
477 |
-
"Euler a",
|
478 |
-
"DDIM",
|
479 |
-
]
|
480 |
-
|
481 |
-
aspect_ratios = [
|
482 |
-
"1024 x 1024",
|
483 |
-
"1152 x 896",
|
484 |
-
"896 x 1152",
|
485 |
-
"1216 x 832",
|
486 |
-
"832 x 1216",
|
487 |
-
"1344 x 768",
|
488 |
-
"768 x 1344",
|
489 |
-
"1536 x 640",
|
490 |
-
"640 x 1536",
|
491 |
-
"Custom",
|
492 |
-
]
|
493 |
-
|
494 |
-
style_list = [
|
495 |
-
{
|
496 |
-
"name": "(None)",
|
497 |
-
"prompt": "{prompt}",
|
498 |
-
"negative_prompt": "",
|
499 |
-
},
|
500 |
-
{
|
501 |
-
"name": "Cinematic",
|
502 |
-
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
503 |
-
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
504 |
-
},
|
505 |
-
{
|
506 |
-
"name": "Photographic",
|
507 |
-
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
508 |
-
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
509 |
-
},
|
510 |
-
{
|
511 |
-
"name": "Anime",
|
512 |
-
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
513 |
-
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
514 |
-
},
|
515 |
-
{
|
516 |
-
"name": "Manga",
|
517 |
-
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
518 |
-
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
519 |
-
},
|
520 |
-
{
|
521 |
-
"name": "Digital Art",
|
522 |
-
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
523 |
-
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
524 |
-
},
|
525 |
-
{
|
526 |
-
"name": "Pixel art",
|
527 |
-
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
528 |
-
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
529 |
-
},
|
530 |
-
{
|
531 |
-
"name": "Fantasy art",
|
532 |
-
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
533 |
-
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
534 |
-
},
|
535 |
-
{
|
536 |
-
"name": "Neonpunk",
|
537 |
-
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
538 |
-
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
539 |
-
},
|
540 |
-
{
|
541 |
-
"name": "3D Model",
|
542 |
-
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
543 |
-
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
544 |
-
},
|
545 |
-
]
|
546 |
-
|
547 |
-
thumbnail_cache = {}
|
548 |
|
549 |
-
with open("lora.toml", "r") as file:
|
550 |
-
data = toml.load(file)
|
551 |
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
if model_path not in thumbnail_cache:
|
559 |
-
thumbnail_image = load_and_convert_thumbnail(model_path)
|
560 |
-
thumbnail_cache[model_path] = thumbnail_image
|
561 |
-
else:
|
562 |
-
thumbnail_image = thumbnail_cache[model_path]
|
563 |
-
|
564 |
-
sdxl_loras.append(
|
565 |
-
{
|
566 |
-
"image": thumbnail_image, # Storing the PIL image object
|
567 |
-
"title": item["title"],
|
568 |
-
"repo": item["repo"],
|
569 |
-
"weights": item["weights"],
|
570 |
-
"multiplier": item.get("multiplier", "1.0"),
|
571 |
-
}
|
572 |
-
)
|
573 |
|
574 |
-
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
575 |
quality_prompt = {
|
576 |
-
k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list
|
577 |
}
|
578 |
|
579 |
-
|
580 |
-
# hf_hub_download(item["repo"], item["weights"], token=HF_TOKEN)
|
581 |
-
# for item in sdxl_loras
|
582 |
-
# ]
|
583 |
-
|
584 |
-
wildcard_files = load_wildcard_files("wildcard")
|
585 |
|
586 |
-
with gr.Blocks(css="style.css"
|
587 |
title = gr.HTML(
|
588 |
f"""<h1><span>{DESCRIPTION}</span></h1>""",
|
589 |
elem_id="title",
|
@@ -592,187 +207,131 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
592 |
f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
|
593 |
elem_id="subtitle",
|
594 |
)
|
595 |
-
gr.Markdown(
|
596 |
-
f"""Prompting is a bit different in this iteration, we train the model like this:
|
597 |
-
```
|
598 |
-
1girl/1boy, character name, from what series, everything else in any order.
|
599 |
-
```
|
600 |
-
Prompting Tips
|
601 |
-
```
|
602 |
-
1. Quality Tags: `masterpiece, best quality, high quality, normal quality, worst quality, low quality`
|
603 |
-
2. Year Tags: `oldest, early, mid, late, newest`
|
604 |
-
3. Rating tags: `rating: general, rating: sensitive, rating: questionable, rating: explicit, nsfw`
|
605 |
-
4. Escape character: `character name \(series\)`
|
606 |
-
5. Recommended settings: `Euler a, cfg 5-7, 25-28 steps`
|
607 |
-
6. It's recommended to use the exact danbooru tags for more accurate result
|
608 |
-
7. To use character wildcard, add this syntax to the prompt `__character__`.
|
609 |
-
```
|
610 |
-
""",
|
611 |
-
elem_id="subtitle",
|
612 |
-
)
|
613 |
gr.DuplicateButton(
|
614 |
value="Duplicate Space for private use",
|
615 |
elem_id="duplicate-button",
|
616 |
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
617 |
)
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
label="Width",
|
702 |
-
minimum=MIN_IMAGE_SIZE,
|
703 |
-
maximum=MAX_IMAGE_SIZE,
|
704 |
-
step=8,
|
705 |
-
value=1024,
|
706 |
-
)
|
707 |
-
custom_height = gr.Slider(
|
708 |
-
label="Height",
|
709 |
-
minimum=MIN_IMAGE_SIZE,
|
710 |
-
maximum=MAX_IMAGE_SIZE,
|
711 |
-
step=8,
|
712 |
-
value=1024,
|
713 |
-
)
|
714 |
-
with gr.Group():
|
715 |
-
sampler = gr.Dropdown(
|
716 |
-
label="Sampler",
|
717 |
-
choices=sampler_list,
|
718 |
-
interactive=True,
|
719 |
-
value="Euler a",
|
720 |
-
)
|
721 |
-
with gr.Group():
|
722 |
-
seed = gr.Slider(
|
723 |
-
label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
|
724 |
-
)
|
725 |
-
|
726 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
727 |
-
with gr.Group():
|
728 |
-
with gr.Row():
|
729 |
-
guidance_scale = gr.Slider(
|
730 |
-
label="Guidance scale",
|
731 |
-
minimum=1,
|
732 |
-
maximum=12,
|
733 |
-
step=0.1,
|
734 |
-
value=7.0,
|
735 |
-
)
|
736 |
-
num_inference_steps = gr.Slider(
|
737 |
-
label="Number of inference steps",
|
738 |
-
minimum=1,
|
739 |
-
maximum=50,
|
740 |
-
step=1,
|
741 |
-
value=28,
|
742 |
-
)
|
743 |
-
|
744 |
-
with gr.Tab("Past Generation"):
|
745 |
-
gr_user_history.render()
|
746 |
-
with gr.Column(scale=3):
|
747 |
-
with gr.Blocks():
|
748 |
-
run_button = gr.Button("Generate", variant="primary")
|
749 |
-
result = gr.Image(label="Result", show_label=False)
|
750 |
-
with gr.Accordion(label="Generation Parameters", open=False):
|
751 |
-
gr_metadata = gr.JSON(label="Metadata", show_label=False)
|
752 |
-
gr.Examples(
|
753 |
-
examples=examples,
|
754 |
-
inputs=prompt,
|
755 |
-
outputs=[result, gr_metadata],
|
756 |
-
fn=generate,
|
757 |
-
cache_examples=CACHE_EXAMPLES,
|
758 |
)
|
759 |
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
776 |
)
|
777 |
use_upscaler.change(
|
778 |
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
|
@@ -797,9 +356,6 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
797 |
custom_height,
|
798 |
guidance_scale,
|
799 |
num_inference_steps,
|
800 |
-
use_lora,
|
801 |
-
lora_weight,
|
802 |
-
selected_state,
|
803 |
sampler,
|
804 |
aspect_ratio_selector,
|
805 |
style_selector,
|
@@ -807,11 +363,11 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
807 |
use_upscaler,
|
808 |
upscaler_strength,
|
809 |
upscale_by,
|
810 |
-
add_quality_tags
|
811 |
]
|
812 |
|
813 |
prompt.submit(
|
814 |
-
fn=randomize_seed_fn,
|
815 |
inputs=[seed, randomize_seed],
|
816 |
outputs=seed,
|
817 |
queue=False,
|
@@ -823,7 +379,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
823 |
api_name="run",
|
824 |
)
|
825 |
negative_prompt.submit(
|
826 |
-
fn=randomize_seed_fn,
|
827 |
inputs=[seed, randomize_seed],
|
828 |
outputs=seed,
|
829 |
queue=False,
|
@@ -835,7 +391,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
835 |
api_name=False,
|
836 |
)
|
837 |
run_button.click(
|
838 |
-
fn=randomize_seed_fn,
|
839 |
inputs=[seed, randomize_seed],
|
840 |
outputs=seed,
|
841 |
queue=False,
|
@@ -846,4 +402,4 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
846 |
outputs=[result, gr_metadata],
|
847 |
api_name=False,
|
848 |
)
|
849 |
-
demo.queue(max_size=
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import gc
|
|
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
|
|
5 |
import torch
|
6 |
import json
|
7 |
+
import spaces
|
8 |
+
import config
|
9 |
+
import utils
|
10 |
+
import logging
|
11 |
+
from PIL import Image, PngImagePlugin
|
12 |
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from diffusers.models import AutoencoderKL
|
14 |
+
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
15 |
+
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
DESCRIPTION = "Animagine XL 3.0"
|
20 |
if not torch.cuda.is_available():
|
21 |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
|
22 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
|
|
23 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
24 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
25 |
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
|
26 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
|
27 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
|
28 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
29 |
+
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
|
30 |
|
31 |
+
MODEL = os.getenv(
|
32 |
+
"MODEL",
|
33 |
+
"https://huggingface.co/cagliostrolab/animagine-xl-3.0/blob/main/animagine-xl-3.0.safetensors",
|
34 |
+
)
|
35 |
|
36 |
torch.backends.cudnn.deterministic = True
|
37 |
torch.backends.cudnn.benchmark = False
|
38 |
|
39 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
40 |
|
41 |
+
|
42 |
+
def load_pipeline(model_name):
|
43 |
vae = AutoencoderKL.from_pretrained(
|
44 |
"madebyollin/sdxl-vae-fp16-fix",
|
45 |
torch_dtype=torch.float16,
|
46 |
)
|
47 |
+
pipeline = (
|
48 |
+
StableDiffusionXLPipeline.from_single_file
|
49 |
+
if MODEL.endswith(".safetensors")
|
50 |
+
else StableDiffusionXLPipeline.from_pretrained
|
51 |
+
)
|
52 |
+
|
53 |
pipe = pipeline(
|
54 |
+
model_name,
|
55 |
vae=vae,
|
56 |
torch_dtype=torch.float16,
|
57 |
custom_pipeline="lpw_stable_diffusion_xl",
|
58 |
use_safetensors=True,
|
59 |
+
add_watermarker=False,
|
60 |
use_auth_token=HF_TOKEN,
|
61 |
variant="fp16",
|
62 |
)
|
63 |
|
64 |
+
pipe.to(device)
|
65 |
+
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
@spaces.GPU
|
69 |
def generate(
|
70 |
prompt: str,
|
71 |
negative_prompt: str = "",
|
|
|
74 |
custom_height: int = 1024,
|
75 |
guidance_scale: float = 7.0,
|
76 |
num_inference_steps: int = 28,
|
|
|
|
|
|
|
77 |
sampler: str = "Euler a",
|
78 |
aspect_ratio_selector: str = "896 x 1152",
|
79 |
style_selector: str = "(None)",
|
80 |
quality_selector: str = "Standard",
|
81 |
use_upscaler: bool = False,
|
82 |
+
upscaler_strength: float = 0.55,
|
83 |
upscale_by: float = 1.5,
|
84 |
add_quality_tags: bool = True,
|
|
|
85 |
progress=gr.Progress(track_tqdm=True),
|
86 |
+
) -> Image:
|
87 |
+
generator = utils.seed_everything(seed)
|
88 |
|
89 |
+
width, height = utils.aspect_ratio_handler(
|
|
|
|
|
|
|
90 |
aspect_ratio_selector,
|
91 |
custom_width,
|
92 |
custom_height,
|
93 |
)
|
94 |
|
95 |
+
prompt = utils.add_wildcard(prompt, wildcard_files)
|
96 |
|
97 |
+
prompt, negative_prompt = utils.preprocess_prompt(
|
|
|
98 |
quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
|
99 |
)
|
100 |
+
prompt, negative_prompt = utils.preprocess_prompt(
|
101 |
styles, style_selector, prompt, negative_prompt
|
102 |
)
|
103 |
|
104 |
+
width, height = utils.preprocess_image_dimensions(width, height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
backup_scheduler = pipe.scheduler
|
107 |
+
pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
|
108 |
|
109 |
if use_upscaler:
|
110 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
|
|
111 |
metadata = {
|
112 |
"prompt": prompt,
|
113 |
"negative_prompt": negative_prompt,
|
|
|
121 |
"quality_tags": quality_selector,
|
122 |
}
|
123 |
|
|
|
|
|
|
|
|
|
|
|
124 |
if use_upscaler:
|
125 |
new_width = int(width * upscale_by)
|
126 |
new_height = int(height * upscale_by)
|
|
|
132 |
}
|
133 |
else:
|
134 |
metadata["use_upscaler"] = None
|
135 |
+
logger.info(json.dumps(metadata, indent=4))
|
|
|
136 |
|
137 |
try:
|
138 |
if use_upscaler:
|
|
|
146 |
generator=generator,
|
147 |
output_type="latent",
|
148 |
).images
|
149 |
+
upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
|
150 |
+
images = upscaler_pipe(
|
151 |
prompt=prompt,
|
152 |
negative_prompt=negative_prompt,
|
153 |
image=upscaled_latents,
|
|
|
156 |
strength=upscaler_strength,
|
157 |
generator=generator,
|
158 |
output_type="pil",
|
159 |
+
).images
|
160 |
else:
|
161 |
+
images = pipe(
|
162 |
prompt=prompt,
|
163 |
negative_prompt=negative_prompt,
|
164 |
width=width,
|
|
|
167 |
num_inference_steps=num_inference_steps,
|
168 |
generator=generator,
|
169 |
output_type="pil",
|
170 |
+
).images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
+
if images and IS_COLAB:
|
173 |
+
for image in images:
|
174 |
+
filepath = utils.save_image(image, metadata, OUTPUT_DIR)
|
175 |
+
logger.info(f"Image saved as {filepath} with metadata")
|
176 |
|
177 |
+
return images, metadata
|
178 |
except Exception as e:
|
179 |
+
logger.exception(f"An error occurred: {e}")
|
180 |
raise
|
181 |
finally:
|
|
|
|
|
|
|
|
|
|
|
182 |
if use_upscaler:
|
183 |
del upscaler_pipe
|
184 |
pipe.scheduler = backup_scheduler
|
185 |
+
utils.free_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
|
|
|
|
187 |
|
188 |
+
if torch.cuda.is_available():
|
189 |
+
pipe = load_pipeline(MODEL)
|
190 |
+
logger.info("Loaded on Device!")
|
191 |
+
else:
|
192 |
+
pipe = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
|
195 |
quality_prompt = {
|
196 |
+
k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.quality_prompt_list
|
197 |
}
|
198 |
|
199 |
+
wildcard_files = utils.load_wildcard_files("wildcard")
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
+
with gr.Blocks(css="style.css") as demo:
|
202 |
title = gr.HTML(
|
203 |
f"""<h1><span>{DESCRIPTION}</span></h1>""",
|
204 |
elem_id="title",
|
|
|
207 |
f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
|
208 |
elem_id="subtitle",
|
209 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
gr.DuplicateButton(
|
211 |
value="Duplicate Space for private use",
|
212 |
elem_id="duplicate-button",
|
213 |
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
214 |
)
|
215 |
+
with gr.Group():
|
216 |
+
with gr.Row():
|
217 |
+
prompt = gr.Text(
|
218 |
+
label="Prompt",
|
219 |
+
show_label=False,
|
220 |
+
max_lines=5,
|
221 |
+
placeholder="Enter your prompt",
|
222 |
+
container=False,
|
223 |
+
)
|
224 |
+
run_button = gr.Button(
|
225 |
+
"Generate",
|
226 |
+
variant="primary",
|
227 |
+
scale=0
|
228 |
+
)
|
229 |
+
result = gr.Gallery(
|
230 |
+
label="Result",
|
231 |
+
columns=1,
|
232 |
+
preview=True,
|
233 |
+
show_label=False
|
234 |
+
)
|
235 |
+
with gr.Accordion(label="Advanced Settings", open=False):
|
236 |
+
negative_prompt = gr.Text(
|
237 |
+
label="Negative Prompt",
|
238 |
+
max_lines=5,
|
239 |
+
placeholder="Enter a negative prompt",
|
240 |
+
)
|
241 |
+
with gr.Row():
|
242 |
+
add_quality_tags = gr.Checkbox(
|
243 |
+
label="Add Quality Tags",
|
244 |
+
value=True
|
245 |
+
)
|
246 |
+
quality_selector = gr.Dropdown(
|
247 |
+
label="Quality Tags Presets",
|
248 |
+
interactive=True,
|
249 |
+
choices=list(quality_prompt.keys()),
|
250 |
+
value="Standard",
|
251 |
+
)
|
252 |
+
style_selector = gr.Radio(
|
253 |
+
label="Style Preset",
|
254 |
+
container=True,
|
255 |
+
interactive=True,
|
256 |
+
choices=list(styles.keys()),
|
257 |
+
value="(None)",
|
258 |
+
)
|
259 |
+
aspect_ratio_selector = gr.Radio(
|
260 |
+
label="Aspect Ratio",
|
261 |
+
choices=config.aspect_ratios,
|
262 |
+
value="896 x 1152",
|
263 |
+
container=True,
|
264 |
+
)
|
265 |
+
with gr.Group(visible=False) as custom_resolution:
|
266 |
+
with gr.Row():
|
267 |
+
custom_width = gr.Slider(
|
268 |
+
label="Width",
|
269 |
+
minimum=MIN_IMAGE_SIZE,
|
270 |
+
maximum=MAX_IMAGE_SIZE,
|
271 |
+
step=8,
|
272 |
+
value=1024,
|
273 |
+
)
|
274 |
+
custom_height = gr.Slider(
|
275 |
+
label="Height",
|
276 |
+
minimum=MIN_IMAGE_SIZE,
|
277 |
+
maximum=MAX_IMAGE_SIZE,
|
278 |
+
step=8,
|
279 |
+
value=1024,
|
280 |
+
)
|
281 |
+
use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
|
282 |
+
with gr.Row() as upscaler_row:
|
283 |
+
upscaler_strength = gr.Slider(
|
284 |
+
label="Strength",
|
285 |
+
minimum=0,
|
286 |
+
maximum=1,
|
287 |
+
step=0.05,
|
288 |
+
value=0.55,
|
289 |
+
visible=False,
|
290 |
+
)
|
291 |
+
upscale_by = gr.Slider(
|
292 |
+
label="Upscale by",
|
293 |
+
minimum=1,
|
294 |
+
maximum=1.5,
|
295 |
+
step=0.1,
|
296 |
+
value=1.5,
|
297 |
+
visible=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
)
|
299 |
|
300 |
+
sampler = gr.Dropdown(
|
301 |
+
label="Sampler",
|
302 |
+
choices=config.sampler_list,
|
303 |
+
interactive=True,
|
304 |
+
value="Euler a",
|
305 |
+
)
|
306 |
+
with gr.Row():
|
307 |
+
seed = gr.Slider(
|
308 |
+
label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
|
309 |
+
)
|
310 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
311 |
+
with gr.Group():
|
312 |
+
with gr.Row():
|
313 |
+
guidance_scale = gr.Slider(
|
314 |
+
label="Guidance scale",
|
315 |
+
minimum=1,
|
316 |
+
maximum=12,
|
317 |
+
step=0.1,
|
318 |
+
value=7.0,
|
319 |
+
)
|
320 |
+
num_inference_steps = gr.Slider(
|
321 |
+
label="Number of inference steps",
|
322 |
+
minimum=1,
|
323 |
+
maximum=50,
|
324 |
+
step=1,
|
325 |
+
value=28,
|
326 |
+
)
|
327 |
+
with gr.Accordion(label="Generation Parameters", open=False):
|
328 |
+
gr_metadata = gr.JSON(label="Metadata", show_label=False)
|
329 |
+
gr.Examples(
|
330 |
+
examples=config.examples,
|
331 |
+
inputs=prompt,
|
332 |
+
outputs=[result, gr_metadata],
|
333 |
+
fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
|
334 |
+
cache_examples=CACHE_EXAMPLES,
|
335 |
)
|
336 |
use_upscaler.change(
|
337 |
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
|
|
|
356 |
custom_height,
|
357 |
guidance_scale,
|
358 |
num_inference_steps,
|
|
|
|
|
|
|
359 |
sampler,
|
360 |
aspect_ratio_selector,
|
361 |
style_selector,
|
|
|
363 |
use_upscaler,
|
364 |
upscaler_strength,
|
365 |
upscale_by,
|
366 |
+
add_quality_tags,
|
367 |
]
|
368 |
|
369 |
prompt.submit(
|
370 |
+
fn=utils.randomize_seed_fn,
|
371 |
inputs=[seed, randomize_seed],
|
372 |
outputs=seed,
|
373 |
queue=False,
|
|
|
379 |
api_name="run",
|
380 |
)
|
381 |
negative_prompt.submit(
|
382 |
+
fn=utils.randomize_seed_fn,
|
383 |
inputs=[seed, randomize_seed],
|
384 |
outputs=seed,
|
385 |
queue=False,
|
|
|
391 |
api_name=False,
|
392 |
)
|
393 |
run_button.click(
|
394 |
+
fn=utils.randomize_seed_fn,
|
395 |
inputs=[seed, randomize_seed],
|
396 |
outputs=seed,
|
397 |
queue=False,
|
|
|
402 |
outputs=[result, gr_metadata],
|
403 |
api_name=False,
|
404 |
)
|
405 |
+
demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
|
config.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
examples = [
|
2 |
+
"1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
|
3 |
+
"1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
|
4 |
+
"1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
|
5 |
+
"1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
|
6 |
+
"A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
|
7 |
+
]
|
8 |
+
|
9 |
+
quality_prompt_list = [
|
10 |
+
{
|
11 |
+
"name": "(None)",
|
12 |
+
"prompt": "{prompt}",
|
13 |
+
"negative_prompt": "nsfw, lowres, ",
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "Standard",
|
17 |
+
"prompt": "{prompt}, masterpiece, best quality",
|
18 |
+
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"name": "Light",
|
22 |
+
"prompt": "{prompt}, (masterpiece), best quality, perfect face",
|
23 |
+
"negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "Heavy",
|
27 |
+
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
|
28 |
+
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
|
29 |
+
},
|
30 |
+
]
|
31 |
+
|
32 |
+
sampler_list = [
|
33 |
+
"DPM++ 2M Karras",
|
34 |
+
"DPM++ SDE Karras",
|
35 |
+
"DPM++ 2M SDE Karras",
|
36 |
+
"Euler",
|
37 |
+
"Euler a",
|
38 |
+
"DDIM",
|
39 |
+
]
|
40 |
+
|
41 |
+
aspect_ratios = [
|
42 |
+
"1024 x 1024",
|
43 |
+
"1152 x 896",
|
44 |
+
"896 x 1152",
|
45 |
+
"1216 x 832",
|
46 |
+
"832 x 1216",
|
47 |
+
"1344 x 768",
|
48 |
+
"768 x 1344",
|
49 |
+
"1536 x 640",
|
50 |
+
"640 x 1536",
|
51 |
+
"Custom",
|
52 |
+
]
|
53 |
+
|
54 |
+
style_list = [
|
55 |
+
{
|
56 |
+
"name": "(None)",
|
57 |
+
"prompt": "{prompt}",
|
58 |
+
"negative_prompt": "",
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"name": "Cinematic",
|
62 |
+
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
63 |
+
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"name": "Photographic",
|
67 |
+
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
68 |
+
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"name": "Anime",
|
72 |
+
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
73 |
+
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"name": "Manga",
|
77 |
+
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
78 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"name": "Digital Art",
|
82 |
+
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
83 |
+
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"name": "Pixel art",
|
87 |
+
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
88 |
+
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"name": "Fantasy art",
|
92 |
+
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
93 |
+
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"name": "Neonpunk",
|
97 |
+
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
98 |
+
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"name": "3D Model",
|
102 |
+
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
103 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
104 |
+
},
|
105 |
+
]
|
lora.toml
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
[[data]]
|
2 |
-
title = "Style Enhancer XL"
|
3 |
-
repo = "Linaqruf/style-enhancer-xl-lora"
|
4 |
-
weights = "style-enhancer-xl.safetensors"
|
5 |
-
multiplier = 0.6
|
6 |
-
[[data]]
|
7 |
-
title = "Anime Detailer XL"
|
8 |
-
repo = "Linaqruf/anime-detailer-xl-lora"
|
9 |
-
weights = "anime-detailer-xl.safetensors"
|
10 |
-
multiplier = 2.0
|
11 |
-
|
12 |
-
[[data]]
|
13 |
-
title = "Sketch Style XL"
|
14 |
-
repo = "Linaqruf/sketch-style-xl-lora"
|
15 |
-
weights = "sketch-style-xl.safetensors"
|
16 |
-
multiplier = 0.6
|
17 |
-
|
18 |
-
[[data]]
|
19 |
-
title = "Pastel Style XL 2.0"
|
20 |
-
repo = "Linaqruf/pastel-style-xl-lora"
|
21 |
-
weights = "pastel-style-xl-v2.safetensors"
|
22 |
-
multiplier = 0.6
|
23 |
-
|
24 |
-
[[data]]
|
25 |
-
title = "Anime Nouveau XL"
|
26 |
-
repo = "Linaqruf/anime-nouveau-xl-lora"
|
27 |
-
weights = "anime-nouveau-xl.safetensors"
|
28 |
-
multiplier = 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lora_diffusers.py
DELETED
@@ -1,478 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
LoRA module for Diffusers
|
3 |
-
==========================
|
4 |
-
|
5 |
-
This file works independently and is designed to operate with Diffusers.
|
6 |
-
|
7 |
-
Credits
|
8 |
-
-------
|
9 |
-
- Modified from: https://github.com/vladmandic/automatic/blob/master/modules/lora_diffusers.py
|
10 |
-
- Originally from: https://github.com/kohya-ss/sd-scripts/blob/sdxl/networks/lora_diffusers.py
|
11 |
-
"""
|
12 |
-
|
13 |
-
import bisect
|
14 |
-
import math
|
15 |
-
import random
|
16 |
-
from typing import Any, Dict, List, Mapping, Optional, Union
|
17 |
-
from diffusers import UNet2DConditionModel
|
18 |
-
import numpy as np
|
19 |
-
from tqdm import tqdm
|
20 |
-
from transformers import CLIPTextModel
|
21 |
-
import torch
|
22 |
-
|
23 |
-
|
24 |
-
def make_unet_conversion_map() -> Dict[str, str]:
|
25 |
-
unet_conversion_map_layer = []
|
26 |
-
|
27 |
-
for i in range(3): # num_blocks is 3 in sdxl
|
28 |
-
# loop over downblocks/upblocks
|
29 |
-
for j in range(2):
|
30 |
-
# loop over resnets/attentions for downblocks
|
31 |
-
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
32 |
-
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
33 |
-
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
34 |
-
|
35 |
-
if i < 3:
|
36 |
-
# no attention layers in down_blocks.3
|
37 |
-
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
38 |
-
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
39 |
-
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
40 |
-
|
41 |
-
for j in range(3):
|
42 |
-
# loop over resnets/attentions for upblocks
|
43 |
-
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
44 |
-
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
45 |
-
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
46 |
-
|
47 |
-
# if i > 0: commentout for sdxl
|
48 |
-
# no attention layers in up_blocks.0
|
49 |
-
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
50 |
-
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
51 |
-
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
52 |
-
|
53 |
-
if i < 3:
|
54 |
-
# no downsample in down_blocks.3
|
55 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
56 |
-
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
57 |
-
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
58 |
-
|
59 |
-
# no upsample in up_blocks.3
|
60 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
61 |
-
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
62 |
-
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
63 |
-
|
64 |
-
hf_mid_atn_prefix = "mid_block.attentions.0."
|
65 |
-
sd_mid_atn_prefix = "middle_block.1."
|
66 |
-
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
67 |
-
|
68 |
-
for j in range(2):
|
69 |
-
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
70 |
-
sd_mid_res_prefix = f"middle_block.{2*j}."
|
71 |
-
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
72 |
-
|
73 |
-
unet_conversion_map_resnet = [
|
74 |
-
# (stable-diffusion, HF Diffusers)
|
75 |
-
("in_layers.0.", "norm1."),
|
76 |
-
("in_layers.2.", "conv1."),
|
77 |
-
("out_layers.0.", "norm2."),
|
78 |
-
("out_layers.3.", "conv2."),
|
79 |
-
("emb_layers.1.", "time_emb_proj."),
|
80 |
-
("skip_connection.", "conv_shortcut."),
|
81 |
-
]
|
82 |
-
|
83 |
-
unet_conversion_map = []
|
84 |
-
for sd, hf in unet_conversion_map_layer:
|
85 |
-
if "resnets" in hf:
|
86 |
-
for sd_res, hf_res in unet_conversion_map_resnet:
|
87 |
-
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
88 |
-
else:
|
89 |
-
unet_conversion_map.append((sd, hf))
|
90 |
-
|
91 |
-
for j in range(2):
|
92 |
-
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
93 |
-
sd_time_embed_prefix = f"time_embed.{j*2}."
|
94 |
-
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
95 |
-
|
96 |
-
for j in range(2):
|
97 |
-
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
98 |
-
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
99 |
-
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
100 |
-
|
101 |
-
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
102 |
-
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
103 |
-
unet_conversion_map.append(("out.2.", "conv_out."))
|
104 |
-
|
105 |
-
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
106 |
-
return sd_hf_conversion_map
|
107 |
-
|
108 |
-
|
109 |
-
UNET_CONVERSION_MAP = make_unet_conversion_map()
|
110 |
-
|
111 |
-
|
112 |
-
class LoRAModule(torch.nn.Module):
|
113 |
-
"""
|
114 |
-
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
115 |
-
"""
|
116 |
-
|
117 |
-
def __init__(
|
118 |
-
self,
|
119 |
-
lora_name,
|
120 |
-
org_module: torch.nn.Module,
|
121 |
-
multiplier=1.0,
|
122 |
-
lora_dim=4,
|
123 |
-
alpha=1,
|
124 |
-
):
|
125 |
-
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
126 |
-
super().__init__()
|
127 |
-
self.lora_name = lora_name
|
128 |
-
|
129 |
-
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
130 |
-
in_dim = org_module.in_channels
|
131 |
-
out_dim = org_module.out_channels
|
132 |
-
else:
|
133 |
-
in_dim = org_module.in_features
|
134 |
-
out_dim = org_module.out_features
|
135 |
-
|
136 |
-
self.lora_dim = lora_dim
|
137 |
-
|
138 |
-
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
139 |
-
kernel_size = org_module.kernel_size
|
140 |
-
stride = org_module.stride
|
141 |
-
padding = org_module.padding
|
142 |
-
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
143 |
-
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
144 |
-
else:
|
145 |
-
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
146 |
-
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
147 |
-
|
148 |
-
if type(alpha) == torch.Tensor:
|
149 |
-
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
150 |
-
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
151 |
-
self.scale = alpha / self.lora_dim
|
152 |
-
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
|
153 |
-
|
154 |
-
# same as microsoft's
|
155 |
-
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
156 |
-
torch.nn.init.zeros_(self.lora_up.weight)
|
157 |
-
|
158 |
-
self.multiplier = multiplier
|
159 |
-
self.org_module = [org_module]
|
160 |
-
self.enabled = True
|
161 |
-
self.network: LoRANetwork = None
|
162 |
-
self.org_forward = None
|
163 |
-
|
164 |
-
# override org_module's forward method
|
165 |
-
def apply_to(self, multiplier=None):
|
166 |
-
if multiplier is not None:
|
167 |
-
self.multiplier = multiplier
|
168 |
-
if self.org_forward is None:
|
169 |
-
self.org_forward = self.org_module[0].forward
|
170 |
-
self.org_module[0].forward = self.forward
|
171 |
-
|
172 |
-
# restore org_module's forward method
|
173 |
-
def unapply_to(self):
|
174 |
-
if self.org_forward is not None:
|
175 |
-
self.org_module[0].forward = self.org_forward
|
176 |
-
|
177 |
-
# forward with lora
|
178 |
-
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
|
179 |
-
def forward(self, x, scale=1.0):
|
180 |
-
if not self.enabled:
|
181 |
-
return self.org_forward(x)
|
182 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
183 |
-
|
184 |
-
def set_network(self, network):
|
185 |
-
self.network = network
|
186 |
-
|
187 |
-
# merge lora weight to org weight
|
188 |
-
def merge_to(self, multiplier=1.0):
|
189 |
-
# get lora weight
|
190 |
-
lora_weight = self.get_weight(multiplier)
|
191 |
-
|
192 |
-
# get org weight
|
193 |
-
org_sd = self.org_module[0].state_dict()
|
194 |
-
org_weight = org_sd["weight"]
|
195 |
-
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
196 |
-
|
197 |
-
# set weight to org_module
|
198 |
-
org_sd["weight"] = weight
|
199 |
-
self.org_module[0].load_state_dict(org_sd)
|
200 |
-
|
201 |
-
# restore org weight from lora weight
|
202 |
-
def restore_from(self, multiplier=1.0):
|
203 |
-
# get lora weight
|
204 |
-
lora_weight = self.get_weight(multiplier)
|
205 |
-
|
206 |
-
# get org weight
|
207 |
-
org_sd = self.org_module[0].state_dict()
|
208 |
-
org_weight = org_sd["weight"]
|
209 |
-
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
210 |
-
|
211 |
-
# set weight to org_module
|
212 |
-
org_sd["weight"] = weight
|
213 |
-
self.org_module[0].load_state_dict(org_sd)
|
214 |
-
|
215 |
-
# return lora weight
|
216 |
-
def get_weight(self, multiplier=None):
|
217 |
-
if multiplier is None:
|
218 |
-
multiplier = self.multiplier
|
219 |
-
|
220 |
-
# get up/down weight from module
|
221 |
-
up_weight = self.lora_up.weight.to(torch.float)
|
222 |
-
down_weight = self.lora_down.weight.to(torch.float)
|
223 |
-
|
224 |
-
# pre-calculated weight
|
225 |
-
if len(down_weight.size()) == 2:
|
226 |
-
# linear
|
227 |
-
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
228 |
-
elif down_weight.size()[2:4] == (1, 1):
|
229 |
-
# conv2d 1x1
|
230 |
-
weight = (
|
231 |
-
self.multiplier
|
232 |
-
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
233 |
-
* self.scale
|
234 |
-
)
|
235 |
-
else:
|
236 |
-
# conv2d 3x3
|
237 |
-
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
238 |
-
weight = self.multiplier * conved * self.scale
|
239 |
-
|
240 |
-
return weight
|
241 |
-
|
242 |
-
|
243 |
-
# Create network from weights for inference, weights are not loaded here
|
244 |
-
def create_network_from_weights(
|
245 |
-
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
|
246 |
-
):
|
247 |
-
# get dim/alpha mapping
|
248 |
-
modules_dim = {}
|
249 |
-
modules_alpha = {}
|
250 |
-
for key, value in weights_sd.items():
|
251 |
-
if "." not in key:
|
252 |
-
continue
|
253 |
-
|
254 |
-
lora_name = key.split(".")[0]
|
255 |
-
if "alpha" in key:
|
256 |
-
modules_alpha[lora_name] = value
|
257 |
-
elif "lora_down" in key:
|
258 |
-
dim = value.size()[0]
|
259 |
-
modules_dim[lora_name] = dim
|
260 |
-
# print(lora_name, value.size(), dim)
|
261 |
-
|
262 |
-
# support old LoRA without alpha
|
263 |
-
for key in modules_dim.keys():
|
264 |
-
if key not in modules_alpha:
|
265 |
-
modules_alpha[key] = modules_dim[key]
|
266 |
-
|
267 |
-
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
268 |
-
|
269 |
-
|
270 |
-
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
271 |
-
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
272 |
-
unet = pipe.unet
|
273 |
-
|
274 |
-
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
275 |
-
lora_network.load_state_dict(weights_sd)
|
276 |
-
lora_network.merge_to(multiplier=multiplier)
|
277 |
-
|
278 |
-
|
279 |
-
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
280 |
-
class LoRANetwork(torch.nn.Module):
    """Collection of LoRA modules built for a U-Net and its text encoder(s).

    A ``LoRAModule`` is created for every ``Linear`` / ``Conv2d`` (or
    ``LoRACompatible*``) child found inside the target module classes listed
    below, but only when its generated name is a key of ``modules_dim``.
    Each created module is registered as a submodule under its LoRA name so
    that ``load_state_dict`` can address it.
    """

    # Class names (compared against ``module.__class__.__name__``) whose
    # children receive LoRA modules.
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet: UNet2DConditionModel,
        multiplier: float = 1.0,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        varbose: Optional[bool] = False,
    ) -> None:
        """Build the LoRA modules described by ``modules_dim``/``modules_alpha``.

        Args:
            text_encoder: A single text encoder or a list of them. With more
                than one, the ``lora_te1`` / ``lora_te2`` prefixes are used.
            unet: The U-Net to scan for target modules.
            multiplier: Multiplier handed to every created ``LoRAModule``.
            modules_dim: LoRA name -> dim. Keys in the Stability AI SDXL
                layout are renamed IN PLACE to the Diffusers layout.
                ``None`` is treated as an empty mapping.
            modules_alpha: LoRA name -> alpha, renamed alongside
                ``modules_dim``. ``None`` is treated as an empty mapping.
            varbose: Unused; kept (with its original spelling) so existing
                keyword callers keep working.
        """
        super().__init__()
        self.multiplier = multiplier

        # The defaults are None; fall back to empty mappings instead of
        # failing on the first attribute access below.
        if modules_dim is None:
            modules_dim = {}
        if modules_alpha is None:
            modules_alpha = {}

        print("create LoRA network from weights")

        # convert SDXL Stability AI's U-Net modules to Diffusers
        converted = self.convert_unet_modules(modules_dim, modules_alpha)
        if converted:
            print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")

        # create module instances
        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],  # None, 1, 2
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ):
            """Return ``(loras, skipped_names)`` for one root module."""
            if is_unet:
                prefix = self.LORA_PREFIX_UNET
            elif text_encoder_idx is None:
                prefix = self.LORA_PREFIX_TEXT_ENCODER
            elif text_encoder_idx == 1:
                prefix = self.LORA_PREFIX_TEXT_ENCODER1
            else:
                prefix = self.LORA_PREFIX_TEXT_ENCODER2

            supported_children = ("Linear", "LoRACompatibleLinear", "Conv2d", "LoRACompatibleConv")
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in supported_children:
                        continue

                    # Dotted module path -> underscore-separated LoRA name.
                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")

                    if lora_name not in modules_dim:
                        # No weights were supplied for this module.
                        skipped.append(lora_name)
                        continue

                    loras.append(
                        LoRAModule(
                            lora_name,
                            child_module,
                            self.multiplier,
                            modules_dim[lora_name],
                            modules_alpha[lora_name],
                        )
                    )
            return loras, skipped

        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

        # create LoRA for text encoder
        # Creating every module on each call is wasteful; worth revisiting.
        self.text_encoder_loras: List[LoRAModule] = []
        skipped_te = []
        for i, encoder in enumerate(text_encoders):
            # Encoders are numbered from 1 only when there are several.
            index = i + 1 if len(text_encoders) > 1 else None

            encoder_loras, skipped = create_modules(
                False, index, encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            )
            self.text_encoder_loras.extend(encoder_loras)
            skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
        if len(skipped_te) > 0:
            print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")

        # extend U-Net target modules to include Conv2d 3x3
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras: List[LoRAModule]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
        if len(skipped_un) > 0:
            print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")

        # Every supplied weight name must have produced a module.
        names = {lora.lora_name for lora in self.text_encoder_loras + self.unet_loras}
        for lora_name in modules_dim:
            assert lora_name in names, f"{lora_name} is not found in created LoRA modules."

        # Register the modules so that load_state_dict can find them.
        for lora in self.text_encoder_loras + self.unet_loras:
            self.add_module(lora.lora_name, lora)

    @staticmethod
    def _convert_unet_key(key, map_keys):
        """Translate one U-Net LoRA key via ``UNET_CONVERSION_MAP``.

        ``map_keys`` must be the sorted keys of ``UNET_CONVERSION_MAP``.
        Returns the converted key, or ``None`` when ``key`` is not a U-Net
        key or none of the map's prefixes matches it. Only the leading
        occurrence of the matched prefix is rewritten.
        """
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"
        if not key.startswith(unet_prefix):
            return None
        search_key = key[len(unet_prefix):]

        # The candidate is the greatest map key that is <= search_key.
        position = bisect.bisect_right(map_keys, search_key)
        if position == 0:
            return None
        map_key = map_keys[position - 1]
        if not search_key.startswith(map_key):
            return None
        return unet_prefix + UNET_CONVERSION_MAP[map_key] + search_key[len(map_key):]

    # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
    def convert_unet_modules(self, modules_dim, modules_alpha):
        """Rename convertible U-Net keys of both dicts in place.

        Returns the number of converted keys. Asserts that the U-Net keys
        are either all convertible or all left untouched.
        """
        converted_count = 0
        not_converted_count = 0

        map_keys = sorted(UNET_CONVERSION_MAP.keys())
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"

        for key in list(modules_dim.keys()):
            if not key.startswith(unet_prefix):
                continue
            new_key = self._convert_unet_key(key, map_keys)
            if new_key is None:
                not_converted_count += 1
                continue
            # pop-then-assign stays correct even if new_key == key.
            modules_dim[new_key] = modules_dim.pop(key)
            modules_alpha[new_key] = modules_alpha.pop(key)
            converted_count += 1
        assert (
            converted_count == 0 or not_converted_count == 0
        ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
        return converted_count

    def set_multiplier(self, multiplier):
        """Store ``multiplier`` on the network and on every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
        """Call ``apply_to(multiplier)`` on the selected groups of modules."""
        if apply_text_encoder:
            print("enable LoRA for text encoder")
            for lora in self.text_encoder_loras:
                lora.apply_to(multiplier)
        if apply_unet:
            print("enable LoRA for U-Net")
            for lora in self.unet_loras:
                lora.apply_to(multiplier)

    def unapply_to(self):
        """Call ``unapply_to()`` on every LoRA module."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.unapply_to()

    def merge_to(self, multiplier=1.0):
        """Call ``merge_to(multiplier)`` on every LoRA module."""
        print("merge LoRA weights to original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.merge_to(multiplier)
        print("weights are merged")

    def restore_from(self, multiplier=1.0):
        """Call ``restore_from(multiplier)`` on every LoRA module."""
        print("restore LoRA weights from original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.restore_from(multiplier)
        print("weights are restored")

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load LoRA weights, converting key names and shapes where needed.

        U-Net keys in the Stability AI SDXL layout are renamed to the
        Diffusers layout, and tensors whose size differs from the matching
        entry of this network are viewed to that size. The work is done on a
        copy, so the caller's ``state_dict`` is left unmodified. Keys this
        network does not know are passed through untouched so the parent
        implementation can report or ignore them according to ``strict``.
        """
        # convert SDXL Stability AI's state dict to Diffusers' based state dict
        map_keys = sorted(UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
        converted_sd = {}
        for key, value in state_dict.items():
            new_key = self._convert_unet_key(key, map_keys)
            converted_sd[key if new_key is None else new_key] = value

        # in case of V2, some weights have different shape, so we need to convert them
        # because V2 LoRA is based on U-Net created by use_linear_projection=False
        my_state_dict = self.state_dict()
        for key, value in list(converted_sd.items()):
            expected = my_state_dict.get(key)
            if expected is not None and value.size() != expected.size():
                converted_sd[key] = value.view(expected.size())

        return super().load_state_dict(converted_sd, strict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
-
accelerate==0.
|
2 |
-
diffusers==0.
|
3 |
-
gradio==4.
|
4 |
invisible-watermark==0.2.0
|
5 |
-
Pillow==10.
|
|
|
6 |
torch==2.0.1
|
7 |
-
transformers==4.
|
8 |
-
toml==0.10.2
|
9 |
omegaconf==2.3.0
|
10 |
timm==0.9.10
|
11 |
-
git+https://huggingface.co/spaces/Wauplin/gradio-user-history
|
|
|
1 |
+
accelerate==0.27.2
|
2 |
+
diffusers==0.26.3
|
3 |
+
gradio==4.20.0
|
4 |
invisible-watermark==0.2.0
|
5 |
+
Pillow==10.2.0
|
6 |
+
spaces==0.24.0
|
7 |
torch==2.0.1
|
8 |
+
transformers==4.38.1
|
|
|
9 |
omegaconf==2.3.0
|
10 |
timm==0.9.10
|
|
style.css
CHANGED
@@ -1,11 +1,6 @@
|
|
1 |
h1 {
|
2 |
text-align: center;
|
3 |
-
|
4 |
-
}
|
5 |
-
|
6 |
-
h2 {
|
7 |
-
text-align: center;
|
8 |
-
font-size: 10vw; /* relative to the viewport width */
|
9 |
}
|
10 |
|
11 |
#duplicate-button {
|
@@ -15,24 +10,12 @@ h2 {
|
|
15 |
border-radius: 100vh;
|
16 |
}
|
17 |
|
18 |
-
|
19 |
-
max-width:
|
20 |
margin: auto;
|
21 |
padding-top: 1.5rem;
|
22 |
}
|
23 |
|
24 |
-
/* You can also use media queries to adjust your style for different screen sizes */
|
25 |
-
@media (max-width: 600px) {
|
26 |
-
#component-0 {
|
27 |
-
max-width: 90%;
|
28 |
-
padding-top: 1rem;
|
29 |
-
}
|
30 |
-
}
|
31 |
-
|
32 |
-
#gallery .grid-wrap{
|
33 |
-
min-height: 25%;
|
34 |
-
}
|
35 |
-
|
36 |
#title-container {
|
37 |
display: flex;
|
38 |
justify-content: center;
|
@@ -43,18 +26,9 @@ h2 {
|
|
43 |
#title {
|
44 |
font-size: 3em;
|
45 |
text-align: center;
|
46 |
-
color: #333;
|
47 |
-
font-family: 'Helvetica Neue', sans-serif;
|
48 |
-
text-transform: uppercase;
|
49 |
background: transparent;
|
50 |
}
|
51 |
|
52 |
-
#title span {
|
53 |
-
background: -webkit-linear-gradient(45deg, #4EACEF, #28b485);
|
54 |
-
-webkit-background-clip: text;
|
55 |
-
-webkit-text-fill-color: transparent;
|
56 |
-
}
|
57 |
-
|
58 |
#subtitle {
|
59 |
text-align: center;
|
60 |
-
}
|
|
|
1 |
h1 {
|
2 |
text-align: center;
|
3 |
+
display: block;
|
|
|
|
|
|
|
|
|
|
|
4 |
}
|
5 |
|
6 |
#duplicate-button {
|
|
|
10 |
border-radius: 100vh;
|
11 |
}
|
12 |
|
13 |
+
.gradio-container {
|
14 |
+
max-width: 730px !important;
|
15 |
margin: auto;
|
16 |
padding-top: 1.5rem;
|
17 |
}
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
#title-container {
|
20 |
display: flex;
|
21 |
justify-content: center;
|
|
|
26 |
#title {
|
27 |
font-size: 3em;
|
28 |
text-align: center;
|
|
|
|
|
|
|
29 |
background: transparent;
|
30 |
}
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
#subtitle {
|
33 |
text-align: center;
|
34 |
+
}
|
utils.py
CHANGED
@@ -1,7 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
def is_google_colab():
    """Return True when the ``google.colab`` module can be imported."""
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Narrower than a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed by the probe.
        return False
    return True
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
from PIL import Image, PngImagePlugin
|
8 |
+
from datetime import datetime
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Callable, Dict, Optional, Tuple
|
11 |
+
from diffusers import (
|
12 |
+
DDIMScheduler,
|
13 |
+
DPMSolverMultistepScheduler,
|
14 |
+
DPMSolverSinglestepScheduler,
|
15 |
+
EulerAncestralDiscreteScheduler,
|
16 |
+
EulerDiscreteScheduler,
|
17 |
+
)
|
18 |
+
|
19 |
+
# Largest value representable by a signed 32-bit integer; used as the
# inclusive upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
class StyleConfig:
    """Pair of prompt strings describing one style preset."""

    # Positive prompt text of the style.
    prompt: str
    # Negative prompt text of the style.
    negative_prompt: str
|
26 |
+
|
27 |
+
|
28 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a request.

    When ``randomize_seed`` is truthy, ``seed`` is discarded and a new value
    is drawn from ``[0, MAX_SEED]`` (both ends included); otherwise ``seed``
    is returned unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
32 |
+
|
33 |
+
|
34 |
+
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices) and NumPy, then return a generator.

    Returns:
        A newly created ``torch.Generator`` seeded with ``seed``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Generator.manual_seed returns the generator itself.
    return torch.Generator().manual_seed(seed)
|
41 |
+
|
42 |
+
|
43 |
+
def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    """Parse a ``"WIDTH x HEIGHT"`` label into ``(width, height)``.

    Whitespace around the separator is optional and the separator may be
    upper- or lower-case, so ``"1024 x 768"``, ``"1024x768"`` and
    ``"1024 X 768"`` are all accepted.

    Returns:
        ``None`` for the special label ``"Custom"``, otherwise the two
        integers.

    Raises:
        ValueError: If the label is not two integers separated by ``x``.
    """
    if aspect_ratio == "Custom":
        return None

    parts = aspect_ratio.lower().split("x")
    if len(parts) != 2:
        raise ValueError(
            f"Invalid aspect ratio {aspect_ratio!r}: expected 'WIDTH x HEIGHT'"
        )
    try:
        # int() ignores surrounding whitespace, so no explicit strip is needed.
        return int(parts[0]), int(parts[1])
    except ValueError as err:
        raise ValueError(
            f"Invalid aspect ratio {aspect_ratio!r}: expected 'WIDTH x HEIGHT'"
        ) from err
|
48 |
+
|
49 |
+
|
50 |
+
def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    """Resolve an aspect-ratio label to a ``(width, height)`` pair.

    The label ``"Custom"`` selects the caller-supplied dimensions; any other
    label is handed to ``parse_aspect_ratio``.
    """
    if aspect_ratio != "Custom":
        parsed_width, parsed_height = parse_aspect_ratio(aspect_ratio)
        return parsed_width, parsed_height
    return custom_width, custom_height
|
58 |
+
|
59 |
+
|
60 |
+
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Create the scheduler registered under the display name ``name``.

    Args:
        scheduler_config: Configuration forwarded to the scheduler class's
            ``from_config``.
        name: One of "DPM++ 2M Karras", "DPM++ SDE Karras",
            "DPM++ 2M SDE Karras", "Euler", "Euler a" or "DDIM".

    Returns:
        Whatever ``from_config`` returns for the selected scheduler class, or
        ``None`` when ``name`` is not one of the names above.
    """
    if name == "DPM++ 2M Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ SDE Karras":
        return DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ 2M SDE Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        )
    if name == "Euler":
        return EulerDiscreteScheduler.from_config(scheduler_config)
    if name == "Euler a":
        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    if name == "DDIM":
        return DDIMScheduler.from_config(scheduler_config)
    # Unknown name: no scheduler.
    return None
|
78 |
+
|
79 |
+
|
80 |
+
def free_memory() -> None:
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    The collection happens first so that memory held only by unreachable
    objects is handed back to the caching allocator before the cache is
    emptied; in the opposite order that memory would stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
83 |
+
|
84 |
+
|
85 |
+
def preprocess_prompt(
    style_dict,
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    """Combine the user's prompts with a style preset.

    Args:
        style_dict: Mapping of style name to a two-item
            ``(positive_template, negative_text)`` value. The template is
            filled in with ``str.format(prompt=...)``.
        style_name: Style to apply. Unknown names fall back to the
            ``"(None)"`` entry, which is only looked up in that case.
        positive: User's positive prompt.
        negative: User's negative prompt.
        add_style: When false, ``positive`` is returned as given. The style's
            negative text is still prepended to ``negative``.

    Returns:
        ``(positive_prompt, negative_prompt)``.

    Raises:
        KeyError: If ``style_name`` is unknown and ``style_dict`` has no
            ``"(None)"`` entry.
    """
    # Resolve the fallback lazily: a dict without "(None)" must still work
    # when the requested style exists.
    style_key = style_name if style_name in style_dict else "(None)"
    p, n = style_dict[style_key]

    # A blank positive prompt is left alone rather than wrapped in the template.
    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    combined_negative = n
    if negative.strip():
        combined_negative = f"{n}, {negative}" if n else negative

    return formatted_positive, combined_negative
|
107 |
+
|
108 |
+
|
109 |
+
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize ``samples`` to ``height`` x ``width`` with ``interpolate``.

    ``upscale_method`` is passed through as the interpolation ``mode``.
    """
    # interpolate expects the target size as (height, width).
    target_size = (height, width)
    resized = torch.nn.functional.interpolate(
        samples, size=target_size, mode=upscale_method
    )
    return resized
|
118 |
+
|
119 |
+
|
120 |
+
def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Resize ``samples`` by the factor ``scale_by``.

    The target size is each of ``samples.shape[3]`` (used as width) and
    ``samples.shape[2]`` (used as height) multiplied by ``scale_by`` and
    rounded with ``round``; the resize itself is done by ``common_upscale``.
    """
    source_height, source_width = samples.shape[2], samples.shape[3]
    target_width = round(source_width * scale_by)
    target_height = round(source_height * scale_by)
    return common_upscale(samples, target_width, target_height, upscale_method)
|
126 |
+
|
127 |
+
|
128 |
+
def load_wildcard_files(wildcard_dir: str) -> Dict[str, str]:
    """Map wildcard placeholders to the ``.txt`` files in ``wildcard_dir``.

    Each ``NAME.txt`` entry yields the key ``__NAME__`` mapped to the file's
    path. Only the trailing ``.txt`` is removed, so ``hair.color.txt`` becomes
    ``__hair.color__`` instead of being cut at the first dot. Entries are
    visited in sorted order so the resulting dict has a stable order.

    Raises:
        OSError: If ``wildcard_dir`` cannot be listed.
    """
    suffix = ".txt"
    wildcard_files = {}
    for file_name in sorted(os.listdir(wildcard_dir)):
        if not file_name.endswith(suffix):
            continue
        stem = file_name[: -len(suffix)]
        key = f"__{stem}__"  # e.g. character.txt -> __character__
        wildcard_files[key] = os.path.join(wildcard_dir, file_name)
    return wildcard_files
|
135 |
+
|
136 |
+
|
137 |
+
def get_random_line_from_file(file_path: str) -> str:
    """Return one randomly chosen non-blank line of ``file_path``, stripped.

    The file is read as UTF-8 so the result does not depend on the platform's
    default encoding. Blank and whitespace-only lines are never chosen.

    Returns:
        The chosen line without surrounding whitespace, or ``""`` when the
        file has no non-blank line.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        candidates = [stripped for stripped in (line.strip() for line in file) if stripped]
    if not candidates:
        return ""
    return random.choice(candidates)
|
143 |
+
|
144 |
+
|
145 |
+
def add_wildcard(prompt: str, wildcard_files: Dict[str, str]) -> str:
    """Substitute wildcard placeholders found in ``prompt``.

    For every placeholder of ``wildcard_files`` that occurs in the prompt,
    one line is drawn from the associated file and all occurrences of that
    placeholder are replaced by it.
    """
    result = prompt
    for placeholder in wildcard_files:
        if placeholder not in result:
            continue
        replacement = get_random_line_from_file(wildcard_files[placeholder])
        result = result.replace(placeholder, replacement)
    return result
|
151 |
+
|
152 |
+
|
153 |
+
def preprocess_image_dimensions(width, height):
    """Round ``width`` and ``height`` down to multiples of 8.

    Returns:
        The adjusted ``(width, height)`` pair.
    """
    # Subtracting the remainder leaves exact multiples of 8 untouched.
    aligned_width = width - width % 8
    aligned_height = height - height % 8
    return aligned_width, aligned_height
|
159 |
+
|
160 |
+
|
161 |
+
def save_image(image, metadata, output_dir):
    """Save ``image`` as a PNG with ``metadata`` embedded, and return its path.

    Args:
        image: Object whose ``save(path, "PNG", pnginfo=...)`` method writes
            the file.
        metadata: JSON-serialisable value stored in a PNG text chunk under
            the key ``"metadata"``.
        output_dir: Destination directory; created when missing.

    Returns:
        Path of the written file. The name is ``image_<YYYYmmdd_HHMMSS>.png``;
        if that path already exists, ``_1``, ``_2``, ... is appended so an
        image saved earlier in the same second is not overwritten.
    """
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    base_name = f"image_{current_time}"
    filepath = os.path.join(output_dir, f"{base_name}.png")
    # The timestamp only has one-second resolution; pick a free name.
    counter = 1
    while os.path.exists(filepath):
        filepath = os.path.join(output_dir, f"{base_name}_{counter}.png")
        counter += 1

    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath
|
172 |
+
|
173 |
+
|
174 |
def is_google_colab():
    """Return True when the ``google.colab`` module can be imported."""
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Narrower than a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed by the probe.
        return False
    return True
|