|
import gradio as gr
|
|
import torch
|
|
import spaces
|
|
|
|
from pathlib import Path
|
|
import gc
|
|
import subprocess
|
|
from PIL import Image
|
|
|
|
|
|
import os

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD presumably
# tells flash-attn's setup to skip compiling CUDA kernels (TODO confirm).
# The variable is MERGED into the current environment: passing only this one
# variable as ``env`` replaces the whole environment (PATH, HOME, CUDA
# settings, ...), so the shell could resolve a different ``pip`` than this
# interpreter's, or none at all.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
subprocess.run('pip cache purge', shell=True)

# Run on the GPU when one is visible to torch, else fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# This module only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)
|
|
|
|
|
|
# Hugging Face Hub repo ids of the FLUX checkpoints offered by this module.
models = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "sayakpaul/FLUX.1-merged",
    "John6666/hyper-flux1-dev-fp8-flux",
    "John6666/flux1-dev-minus-v1-fp8-flux",
    "John6666/blue-pencil-flux1-v001-fp8-flux",
    "John6666/copycat-flux-test-fp8-v11-fp8-flux",
    "John6666/nepotism-fuxdevschnell-v3aio-fp8-flux",
    "John6666/niji-style-flux-devfp8-fp8-flux",
    "John6666/niji56-style-v3-fp8-flux",
    "John6666/lyh-dalle-anime-v12dalle-fp8-flux",
    "John6666/xe-hentai-flux-01-fp8-flux",
    "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux",
    "John6666/fastflux-unchained-t5f16-fp8-flux",
    "John6666/iniverse-mix-xl-sfwnsfw-fluxdfp16nsfwv11-fp8-flux",
    "John6666/nsfw-master-flux-lora-merged-with-flux1-dev-fp16-v10-fp8-flux",
    "John6666/the-araminta-flux1a1-fp8-flux",
    "John6666/acorn-is-spinning-flux-v11-fp8-flux",
    "John6666/real-horny-v2-v2unet-fp8-flux",
    "John6666/centerfold-flux-v20fp8e5m2-fp8-flux",
    "John6666/jib-mix-flux-v208stephyper-fp8-flux",
    "John6666/fluxescore-dev-v10fp16-fp8-flux",
]

num_loras = 3  # presumably the number of LoRA slots in the UI (not used in this chunk)
num_cns = 2    # number of ControlNet slots; sizes the per-slot lists below

# Per-ControlNet-slot state, indexed by slot number and mutated in place by
# the set_control_union_* helpers. A mode of -1 means the slot is disabled.
control_images = [None for _ in range(num_cns)]
control_modes = [-1 for _ in range(num_cns)]
control_scales = [0 for _ in range(num_cns)]
|
|
|
|
|
|
def is_repo_name(s):
    """Check whether ``s`` has the shape of an ``owner/name`` Hub repo id.

    Both halves must be non-empty and contain no '/', ',', whitespace or
    quote characters.

    Returns:
        The ``re.Match`` object for the whole string, or None.
    """
    import re
    # fullmatch already anchors both ends, so no ^/$ are needed.
    repo_pattern = r"[^/,\s\"']+/[^/,\s\"']+"
    return re.fullmatch(repo_pattern, s)
|
|
|
|
|
|
def is_repo_exists(repo_id):
    """Ask the Hugging Face Hub whether ``repo_id`` exists.

    Returns:
        The Hub's answer as a bool. If the lookup raises (e.g. a network
        error), the error is printed and True is returned. Presumably this is
        deliberate, so callers still attempt the load rather than reject a
        repo they could not check (confirm).
    """
    from huggingface_hub import HfApi
    hub = HfApi()
    try:
        return bool(hub.repo_exists(repo_id=repo_id))
    except Exception as err:
        print(f"Error: Failed to connect {repo_id}. ")
        print(err)
        return True
|
|
|
|
|
|
def clear_cache():
    """Free unreferenced Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()`` releases the caching
    allocator's unused blocks. In the opposite order, memory freed by the
    collection would stay held in the allocator's cache.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
|
def deselect_lora():
    """Reset values after a LoRA is deselected.

    Returns:
        A 5-tuple: a component update restoring the "Type a prompt"
        placeholder, an empty string, None (no selected index), and the
        default width and height (1024 each).
    """
    default_side = 1024
    placeholder_update = gr.update(placeholder="Type a prompt")
    return placeholder_update, "", None, default_side, default_side
|
|
|
|
|
|
def get_repo_safetensors(repo_id: str):
    """Build a dropdown update listing a Hub repo's ``.safetensors`` files.

    Args:
        repo_id: Candidate ``owner/name`` repo id.

    Returns:
        ``gr.update(value=<first file>, choices=<all .safetensors files>)``
        on success. Otherwise an empty dropdown (``value=""``, ``choices=[]``)
        when the id is malformed, the repo does not exist, listing fails, or
        the repo has no ``.safetensors`` files.
    """
    from huggingface_hub import HfApi
    try:
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = HfApi().list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        # Clear the value as the other failure paths do, so the dropdown is
        # not left holding a file from a previously selected repo.
        return gr.update(value="", choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if not files:
        return gr.update(value="", choices=[])
    return gr.update(value=files[0], choices=files)
|
|
|
|
|
|
def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
    """Pad ``pil_img`` to a square canvas, centering it on the shorter axis.

    Args:
        pil_img: Source image. It is returned as-is when already square.
        background_color: Fill color for the padding.

    Returns:
        A square image whose side equals the longer side of ``pil_img``.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of these offsets is non-zero: the padded (shorter) axis.
    canvas.paste(pil_img, ((side - w) // 2, (side - h) // 2))
    return canvas
|
|
|
|
|
|
|
|
def resize_image(image, target_width, target_height, crop=True):
    """Resize ``image`` to exactly ``target_width`` x ``target_height``.

    With ``crop=False`` the image is stretched to the target size without
    keeping its aspect ratio. With ``crop=True`` it is first passed through
    ``c_crop`` (from ``image_datasets.canny_dataset``), then scaled so it
    covers the target box with its aspect ratio kept, then center-cropped.

    Args:
        image: PIL image to resize.
        target_width: Output width in pixels.
        target_height: Output height in pixels.
        crop: Cover-and-center-crop (True) or stretch (False).

    Returns:
        A new image of size ``(target_width, target_height)``.
    """
    if not crop:
        return image.resize((target_width, target_height), Image.LANCZOS)

    # Imported lazily so the crop=False path (used by preprocess_image) does
    # not require the optional image_datasets module.
    from image_datasets.canny_dataset import c_crop

    image = c_crop(image)
    original_width, original_height = image.size
    scale = max(target_width / original_width, target_height / original_height)
    # Round and clamp instead of truncating with int(): floating-point error
    # can put the scaled size just below the target. The crop below would
    # then reach past the image edge and pad it with black.
    resized_width = max(target_width, round(scale * original_width))
    resized_height = max(target_height, round(scale * original_height))
    image = image.resize((resized_width, resized_height), Image.LANCZOS)

    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    return image.crop((left, top, left + target_width, top + target_height))
|
|
|
|
|
|
|
|
|
|
# Control-mode label -> integer mode index collected by get_control_params().
# -1 marks "no control". "mlsd" shares index 0 with "canny". Insertion order
# is the order get_control_union_mode() returns the labels in.
controlnet_union_modes = dict([
    ("None", -1),
    ("canny", 0),
    ("mlsd", 0),
    ("tile", 1),
    ("depth_midas", 2),
    ("blur", 3),
    ("openpose", 4),
    ("gray", 5),
    ("low_quality", 6),
])
|
|
|
|
|
|
|
|
def get_control_params():
    """Gather the inputs of every active ControlNet slot.

    A slot is active when its mode is not -1 and it holds an image. Images
    are passed through diffusers' ``load_image``.

    Returns:
        Three parallel lists ``(modes, images, scales)`` covering only the
        active slots, in slot order.
    """
    from diffusers.utils import load_image
    active_slots = [
        (slot_mode, load_image(slot_image), slot_scale)
        for slot_mode, slot_image, slot_scale in zip(control_modes, control_images, control_scales)
        if slot_mode != -1 and slot_image is not None
    ]
    modes = [entry[0] for entry in active_slots]
    images = [entry[1] for entry in active_slots]
    scales = [entry[2] for entry in active_slots]
    return modes, images, scales
|
|
|
|
|
|
from preprocessor import Preprocessor
|
|
# Control modes whose control image comes from a Preprocessor detector:
# mode label -> (detector name for Preprocessor.load, extra call kwargs).
_DETECTOR_MODES = {
    "depth_midas": ("Midas", {}),
    "openpose": ("Openpose", {"hand_and_face": True}),
    "canny": ("Canny", {}),
    "mlsd": ("MLSD", {}),
    "scribble_hed": ("HED", {}),
}
# Control modes that use the resized source image itself as the control image.
_PASSTHROUGH_MODES = frozenset({"low_quality", "gray", "blur", "tile"})


def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
                     preprocess_resolution: int):
    """Turn ``image`` into a control image for ``control_mode``.

    The image is converted to RGB, padded to a square (``expand2square``) and
    resized to ``max(width, height)`` on each side. Detector modes run the
    matching ``Preprocessor`` on it; passthrough modes use it unchanged. The
    result is then stretched to ``width`` x ``height``.

    Args:
        image: Source PIL image.
        control_mode: A mode label (e.g. "canny", "depth_midas", "tile").
            "None" returns ``image`` unchanged.
        height: Target height in pixels.
        width: Target width in pixels.
        preprocess_resolution: ``detect_resolution`` passed to the detector.

    Returns:
        The control image, or ``image`` itself for ``control_mode == "None"``.

    Raises:
        ValueError: If ``control_mode`` is not a supported mode.
    """
    if control_mode == "None": return image
    if control_mode not in _PASSTHROUGH_MODES and control_mode not in _DETECTOR_MODES:
        raise ValueError(f"Unsupported control mode: {control_mode!r}")

    image_resolution = max(width, height)
    image_before = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
    print("start to generate control image")

    if control_mode in _PASSTHROUGH_MODES:
        control_image = image_before
    else:
        detector, extra_kwargs = _DETECTOR_MODES[control_mode]
        # Only build the preprocessor when a detector actually has to run.
        preprocessor = Preprocessor()
        preprocessor.load(detector)
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            **extra_kwargs,
        )

    # Log the real control-image size for every mode.
    image_width, image_height = control_image.size
    image_after = resize_image(control_image, width, height, False)
    ref_width, ref_height = image.size
    print(f"generate control image success: {ref_width}x{ref_height} => {image_width}x{image_height}")
    return image_after
|
|
|
|
|
|
def get_control_union_mode():
    """Return the control-mode labels of ``controlnet_union_modes``, in order."""
    return [label for label in controlnet_union_modes]
|
|
|
|
|
|
def set_control_union_mode(i: int, mode: str, scale: str):
    """Record the mode index and scale for ControlNet slot ``i``.

    Labels missing from ``controlnet_union_modes`` fall back to index 0.

    Returns:
        True when ``mode`` is a real mode, or ``gr.update(visible=True)`` when
        ``mode`` is "None". NOTE(review): the asymmetric return looks odd;
        confirm it matches the UI wiring.
    """
    # Item assignment mutates the module-level lists in place, so no
    # ``global`` declaration is required.
    control_modes[i] = controlnet_union_modes.get(mode, 0)
    control_scales[i] = scale
    return gr.update(visible=True) if mode == "None" else True
|
|
|
|
|
|
def set_control_union_image(i: int, mode: str, image: Image.Image | None, height: int, width: int, preprocess_resolution: int):
    """Preprocess ``image`` for ControlNet slot ``i`` and store the result.

    Args:
        i: Slot index into ``control_images``.
        mode: Control-mode label forwarded to ``preprocess_image``.
        image: Source image, or None when the input was cleared.
        height: Target height in pixels.
        width: Target width in pixels.
        preprocess_resolution: Detection resolution for the preprocessor.

    Returns:
        The stored control image, or None when ``image`` is None (the slot is
        cleared).
    """
    if image is None:
        # Also forget the previous control image. Otherwise
        # get_control_params() would keep using a stale image after the input
        # was cleared.
        control_images[i] = None
        return None
    control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
    return control_images[i]
|
|
|
|
|
|
def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
    """Fill LoRA slot ``i`` of ``lorajson`` in place and return the same list.

    Args:
        lorajson: List of slot dicts; ``lorajson[i]`` is updated.
        i: Slot index.
        name: LoRA repo id or path. "None" or None is stored as "" (no LoRA).
        scale: Adapter weight, stored as float.
        filename: Weight file name. None is stored as "".
        trigger: Trigger word(s). None is stored as "".

    Returns:
        ``lorajson`` (the same object, mutated).
    """
    def _text(value) -> str:
        # str(None) would give the literal "None". For example, a missing
        # trigger would otherwise reach get_trigger_word() as the word "None".
        return "" if value is None else str(value)

    lorajson[i]["name"] = "" if name == "None" else _text(name)
    lorajson[i]["scale"] = float(scale)
    lorajson[i]["filename"] = _text(filename)
    lorajson[i]["trigger"] = _text(trigger)
    return lorajson
|
|
|
|
|
|
def is_valid_lora(lorajson: list[dict]):
    """Return True if at least one slot names a LoRA.

    A slot names a LoRA when its "name" is present, non-empty and not the
    placeholder "None". Every slot is inspected, with no early exit.
    """
    slot_flags = [
        bool("name" in slot and slot["name"] and slot["name"] != "None")
        for slot in lorajson
    ]
    return any(slot_flags)
|
|
|
|
|
|
def get_trigger_word(lorajson: list[dict]):
    """Concatenate the trigger words of all named LoRA slots.

    A slot contributes when its "name" is present, non-empty and not "None",
    and its "trigger" is non-empty. Each contribution is prefixed with ", ",
    so a non-empty result starts with ", ".
    """
    return "".join(
        ", " + slot["trigger"]
        for slot in lorajson
        if "name" in slot and slot["name"] and slot["name"] != "None" and slot["trigger"]
    )
|
|
|
|
|
|
|
|
|
|
def _unique_adapter_name(base: str, taken: list) -> str:
    """Derive an adapter name from ``base`` that is '.'-free and not in ``taken``."""
    # torch module names (used for PEFT adapter keys) may not contain '.'.
    name = base.replace(".", "_") or "lora"
    candidate, n = name, 1
    while candidate in taken:  # two slots resolving to the same name
        candidate = f"{name}_{n}"
        n += 1
    return candidate


def fuse_loras(pipe, lorajson: list[dict]):
    """Load the LoRAs listed in ``lorajson`` into ``pipe`` and fuse them.

    Each slot dict may hold "name" (Hub repo id or local file path),
    "filename" (weight file inside the repo) and "scale" (adapter weight).
    Slots that are empty, lack a name, or are named "None" are skipped.
    Names that are neither an existing Hub repo nor an existing local path are
    reported and skipped. When at least one adapter loaded, all loaded
    adapters are activated with their scales and fused into the weights.

    Args:
        pipe: A diffusers pipeline supporting ``load_lora_weights``,
            ``set_adapters`` and ``fuse_lora``.
        lorajson: List of slot dicts. Anything else is ignored.
    """
    if not lorajson or not isinstance(lorajson, list): return
    a_list = []
    w_list = []
    for d in lorajson:
        # .get(): a slot without a "name" key counts as empty, not a KeyError.
        if not d or not isinstance(d, dict) or not d.get("name") or d["name"] == "None": continue
        k = d["name"]
        if is_repo_name(k) and is_repo_exists(k):
            a_name = _unique_adapter_name(Path(k).name, a_list)
            # An empty/missing filename means "not chosen": pass None so
            # diffusers picks the weight file itself.
            pipe.load_lora_weights(k, weight_name=d.get("filename") or None, adapter_name=a_name)
        elif not Path(k).exists():
            print(f"LoRA not found: {k}")
            continue
        else:
            a_name = _unique_adapter_name(Path(k).stem, a_list)
            pipe.load_lora_weights(k, weight_name=Path(k).name, adapter_name=a_name)
        a_list.append(a_name)
        w_list.append(d["scale"])
    if not a_list: return
    pipe.set_adapters(a_list, adapter_weights=w_list)
    pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
|
|
|
|
|
|
|
|
def description_ui():
    """Render the credits Markdown naming the Spaces this app is a mod of."""
    # Kept as one contiguous bullet. Blank lines inside the text would end the
    # list item and render each remaining link as a stray paragraph.
    gr.Markdown(
        "\n"
        "- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),\n"
        "[jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union),\n"
        "[DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),\n"
        "[gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).\n"
    )
|
|
|
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
|
def load_prompt_enhancer():
    """Build the Flux prompt-enhancer pipeline, or return None on failure.

    Loads the ``gokaygokay/Flux-Prompt-Enhance`` tokenizer and seq2seq model,
    moves the model to ``device``, and wraps both in a
    ``text2text-generation`` pipeline with ``repetition_penalty=1.5``. Any
    exception is printed and None is returned.
    """
    checkpoint = "gokaygokay/Flux-Prompt-Enhance"
    try:
        tok = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).eval().to(device=device)
        return pipeline(
            'text2text-generation',
            model=seq2seq,
            tokenizer=tok,
            repetition_penalty=1.5,
            device=device,
        )
    except Exception as e:
        print(e)
        return None
|
|
|
|
|
|
# Built once at import time. None if loading failed (load_prompt_enhancer
# prints the error and returns None).
enhancer_flux = load_prompt_enhancer()


@spaces.GPU(duration=30)
def enhance_prompt(input_prompt):
    """Rewrite ``input_prompt`` with the Flux prompt enhancer.

    Args:
        input_prompt: The user's prompt text.

    Returns:
        The enhancer's first generated text (``max_length=256``). If the
        enhancer failed to load, ``input_prompt`` is returned unchanged rather
        than raising TypeError from calling None.
    """
    if enhancer_flux is None:
        print("Prompt enhancer is not available; returning the prompt unchanged.")
        return input_prompt
    result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
    return result[0]['generated_text']
|
|
|
|
|
|
# Tag these functions with a ``zerogpu = True`` attribute.
# NOTE(review): nothing in this chunk reads the attribute; presumably it is
# consumed elsewhere in the app. Confirm before relying on it.
for _fn in (load_prompt_enhancer, fuse_loras, preprocess_image, get_control_params):
    _fn.zerogpu = True
del _fn
|
|
|