Spaces:
Runtime error
Runtime error
File size: 6,407 Bytes
d0a7bb3 1920e68 d0a7bb3 1920e68 d0a7bb3 1920e68 d0a7bb3 1920e68 d0a7bb3 1920e68 d0a7bb3 1920e68 1783fac 1920e68 1783fac 1920e68 d0a7bb3 1920e68 7902b52 1920e68 d0a7bb3 1920e68 d0a7bb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
import gradio as gr
from functools import partial
from diffusers_patch import OMSPipeline
def create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path):
    """Build an ``OMSPipeline`` around an fp16 SDXL pipeline carrying LoRA weights.

    Steps, in order:
      1. Load ``StableDiffusionXLPipeline`` from ``sd_pipe_name_or_path``
         (``torch.float16`` / ``"fp16"`` variant, watermarker disabled) and
         move it to CUDA.
      2. Derive an ``LCMScheduler`` from that pipeline's scheduler config.
      3. Load the LoRA weights from ``lora_name_or_path`` into the SDXL pipeline.
      4. Load ``OMSPipeline`` from ``oms_name_or_path``, handing it the SDXL
         pipeline and the LCM scheduler, then move it to CUDA.

    Args:
        sd_pipe_name_or_path: Hub id or local path of the SDXL base model.
        oms_name_or_path: Hub id or local path of the OMS module.
        lora_name_or_path: Hub id or local path of the LoRA weights.

    Returns:
        The CUDA-resident ``OMSPipeline``.
    """
    # Imported lazily so that importing this module does not pull in diffusers.
    from diffusers import StableDiffusionXLPipeline, LCMScheduler

    # Both from_pretrained calls load half-precision weights.
    half_precision = dict(torch_dtype=torch.float16, variant="fp16")

    base_pipe = StableDiffusionXLPipeline.from_pretrained(
        sd_pipe_name_or_path,
        **half_precision,
        add_watermarker=False,
    ).to('cuda')
    print('successfully load pipe')

    # The LCM scheduler is not installed on base_pipe itself; it is passed to
    # OMSPipeline below as ``sd_scheduler``.
    lcm_scheduler = LCMScheduler.from_config(base_pipe.scheduler.config)
    base_pipe.load_lora_weights(lora_name_or_path, variant="fp16")

    oms_pipe = OMSPipeline.from_pretrained(
        oms_name_or_path,
        sd_pipeline=base_pipe,
        **half_precision,
        trust_remote_code=True,
        sd_scheduler=lcm_scheduler,
    )
    oms_pipe.to('cuda')
    return oms_pipe
class GradioDemo:
    """Gradio app comparing SDXL + LCM-LoRA images with and without OMS.

    Constructing the object builds the pipeline via ``create_sdxl_lcm_lora_pipe``;
    ``mainloop`` builds the Blocks UI and launches it.
    """

    def __init__(
        self,
        sd_pipe_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
        oms_name_or_path='h1t/oms_b_openclip_xl',
        lora_name_or_path='latent-consistency/lcm-lora-sdxl',
    ):
        """Load the OMS pipeline from the given model ids/paths into ``self.pipe``."""
        self.pipe = create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path)

    def _generate(self, pipe_kwargs, seed, description):
        """Run one pipeline pass and return its first image.

        The generator is re-created from ``seed`` for every pass, so each variant
        draws from the same random stream and differs only in ``pipe_kwargs``.

        Args:
            pipe_kwargs: Keyword arguments forwarded to ``self.pipe``.
            seed: Integer seed for a fresh ``torch.Generator`` on the pipe's device.
            description: Label printed in front of the kwargs for logging.
        """
        print(f'{description} kwargs: {pipe_kwargs}')
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
        return self.pipe(**pipe_kwargs, generator=generator)['images'][0]

    def _inference(
        self,
        prompt=None,
        oms_prompt=None,
        oms_guidance_scale=1.0,
        num_inference_steps=4,
        sd_pipe_guidance_scale=1.0,
        seed=1024,
        oms_prompt_flag=True,
    ):
        """Generate the images shown side by side in the UI.

        Args:
            prompt: Base text prompt.
            oms_prompt: Separate OMS prompt; used only when ``oms_prompt_flag`` is true.
            oms_guidance_scale: OMS guidance scale; 1.0 skips the guided pass.
            num_inference_steps: Step count forwarded to the pipeline.
            sd_pipe_guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
            seed: Seed used to re-seed the generator before every pass.
            oms_prompt_flag: Whether to run the pass that uses ``oms_prompt``.

        Returns:
            ``(image_raw, image_oms_cp, image_oms_icp, image_oms_cfg,
            gr.update(visible=oms_prompt_flag), gr.update(visible=oms_guidance_flag))``.
            Images of skipped passes are ``None``.
        """
        # Slider values may arrive as floats; torch.Generator.manual_seed only
        # accepts ints, and the step count is expected to be an integer.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        pipe_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=sd_pipe_guidance_scale,
        )

        # Pass 1: OMS disabled (baseline).
        pipe_kwargs.update(oms_flag=False)
        image_raw = self._generate(pipe_kwargs, seed, 'raw')

        # Pass 2: OMS enabled with the base prompt, no OMS guidance.
        pipe_kwargs.update(oms_flag=True, oms_prompt=prompt, oms_guidance_scale=1.0)
        image_oms_cp = self._generate(pipe_kwargs, seed, 'w/ oms wo/ cfg (consistent)')

        # Pass 3 (optional): OMS with the separate OMS prompt, no OMS guidance.
        image_oms_icp = None
        if oms_prompt_flag:
            pipe_kwargs.update(oms_prompt=oms_prompt)
            image_oms_icp = self._generate(pipe_kwargs, seed, 'w/ oms wo/ cfg (inconsistent)')

        # Pass 4 (optional): OMS guidance, applied to whichever OMS prompt is
        # currently in pipe_kwargs (the base prompt when pass 3 was skipped).
        oms_guidance_flag = oms_guidance_scale != 1.0
        image_oms_cfg = None
        if oms_guidance_flag:
            pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
            consistency = 'inconsistent' if oms_prompt_flag else 'consistent'
            image_oms_cfg = self._generate(pipe_kwargs, seed, f'w/ oms +cfg ({consistency})')

        return (
            image_raw,
            image_oms_cp,
            image_oms_icp,
            image_oms_cfg,
            gr.update(visible=oms_prompt_flag),
            gr.update(visible=oms_guidance_flag),
        )

    def mainloop(self):
        """Build the Blocks layout, wire the event handlers, and launch the app."""
        with gr.Blocks() as demo:
            gr.Markdown("# One More Step for SDXL w/ LCM-LoRA")
            with gr.Group():
                prompt = gr.Textbox(label="Prompt", value="a cat against orange ground, studio")
                with gr.Accordion('OMS Prompt'):
                    oms_prompt_checkbox = gr.Checkbox(info="Inconsistent OMS prompt allows the additional control of low freq info, default is the same as Prompt. \n Tips:When there is a conflict between the OMS prompt and the base prompt in describing the same object, the model will respect the base prompt.", label="Adding OMS Prompt", value=True)
                    oms_prompt = gr.Textbox(label="OMS Prompt", value="a black cat", info='try "a black cat" and "a black room" for diverse control.')
                with gr.Accordion('OMS Guidance'):
                    oms_cfg_scale_checkbox = gr.Checkbox(info="OMS Guidance will enhance the OMS prompt, specially focus on color and brightness.", label="Adding OMS Guidance", value=True)
                    oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=2., step=0.1)
                run_button = gr.Button(value="Generate images")
                with gr.Accordion("Advanced options", open=False):
                    num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
                    sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=1, maximum=3, value=1.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
            with gr.Row():
                output_raw = gr.Image(label="SDXL w/ LCM-LoRA ")
                output_oms_cp = gr.Image(label="w/ OMS (consistent) w/o OMS CFG")
                output_oms_icp = gr.Image(label="w/ OMS (inconsistent) w/o OMS CFG")
                output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")

            def _toggle_oms_prompt(flag, base_text, oms_text):
                # Unchecked: copy the base prompt into the OMS prompt box and lock it.
                # Checked: keep the box's current text and make it editable again.
                return (oms_text if flag else base_text), gr.update(interactive=flag)

            def _toggle_oms_guidance(flag):
                # Unchecked: reset the scale to 1.0, which makes _inference skip the
                # guided pass, and lock the slider.
                # NOTE(review): re-checking sets 1.5 while the slider's initial value
                # is 2.0 — confirm which default is intended.
                return (1.5 if flag else 1.0), gr.update(interactive=flag)

            # Each handler targets one component twice: the first output sets its
            # value, the second applies the gr.update(...) returned alongside it.
            oms_prompt_checkbox.input(
                fn=_toggle_oms_prompt,
                inputs=[oms_prompt_checkbox, prompt, oms_prompt],
                outputs=[oms_prompt, oms_prompt],
            )
            oms_cfg_scale_checkbox.input(
                fn=_toggle_oms_guidance,
                inputs=[oms_cfg_scale_checkbox],
                outputs=[oms_guidance_scale, oms_guidance_scale],
            )

            # Order must match the _inference parameters after ``self``.
            ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed, oms_prompt_checkbox]
            # The trailing output_oms_icp/output_oms_cfg receive the two visibility
            # updates that _inference returns after the four images.
            run_button.click(
                fn=self._inference,
                inputs=ips,
                outputs=[output_raw, output_oms_cp, output_oms_icp, output_oms_cfg, output_oms_icp, output_oms_cfg],
            )
        demo.queue(max_size=20)
        demo.launch()
if __name__ == "__main__":
    # Constructing GradioDemo loads the pipeline onto CUDA;
    # mainloop() then builds the UI and launches it.
    app = GradioDemo()
    app.mainloop()
|