# Spaces:
# Runtime error
# Runtime error
import torch | |
import gradio as gr | |
from functools import partial | |
from diffusers_patch import OMSPipeline | |
def create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path, device='cuda'):
    """Build an OMS pipeline that wraps an fp16 SDXL pipeline carrying LCM-LoRA weights.

    Args:
        sd_pipe_name_or_path: Hub id or local path of the SDXL base pipeline.
        oms_name_or_path: Hub id or local path of the OMS module loaded via
            ``OMSPipeline.from_pretrained``.
        lora_name_or_path: Hub id or local path of the LoRA weights loaded into
            the SDXL pipeline.
        device: Device the SDXL pipeline and the OMS pipeline are moved to.
            Defaults to ``'cuda'``, which was previously hard-coded.

    Returns:
        The ``OMSPipeline`` instance, already moved to ``device``.
    """
    from diffusers import StableDiffusionXLPipeline, LCMScheduler

    sd_pipe = StableDiffusionXLPipeline.from_pretrained(
        sd_pipe_name_or_path,
        torch_dtype=torch.float16,
        variant="fp16",
        add_watermarker=False,
    ).to(device)
    print('successfully load pipe')

    # An LCMScheduler is derived from the base pipeline's scheduler config and
    # handed to OMSPipeline below; it is not assigned to sd_pipe here.
    sd_scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)

    # NOTE(review): whether load_lora_weights honours `variant` depends on the
    # installed diffusers version - confirm it is not silently ignored.
    sd_pipe.load_lora_weights(lora_name_or_path, variant="fp16")

    pipe = OMSPipeline.from_pretrained(
        oms_name_or_path,
        sd_pipeline=sd_pipe,
        torch_dtype=torch.float16,
        variant="fp16",
        trust_remote_code=True,
        sd_scheduler=sd_scheduler,
    )
    pipe.to(device)
    return pipe
class GradioDemo:
    """Gradio app comparing SDXL + LCM-LoRA images with and without One More Step (OMS).

    ``__init__`` builds the pipeline once; ``mainloop`` builds the UI, wires
    its events to ``_inference`` and launches the Gradio server.
    """

    def __init__(
        self,
        sd_pipe_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
        oms_name_or_path='h1t/oms_b_openclip_xl',
        lora_name_or_path='latent-consistency/lcm-lora-sdxl',
    ):
        """Create the OMS pipeline used by every inference call.

        Args:
            sd_pipe_name_or_path: Hub id or local path of the SDXL base pipeline.
            oms_name_or_path: Hub id or local path of the OMS module.
            lora_name_or_path: Hub id or local path of the LCM-LoRA weights.
        """
        self.pipe = create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path)

    def _generate(self, pipe_kwargs, seed):
        """Run the pipeline once with a freshly seeded generator; return the first image.

        A new generator seeded with the same ``seed`` is created for every call,
        so each comparison variant starts from an identically seeded generator.
        """
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
        return self.pipe(**pipe_kwargs, generator=generator)['images'][0]

    def _inference(
        self,
        prompt=None,
        oms_prompt=None,
        oms_guidance_scale=1.0,
        num_inference_steps=4,
        sd_pipe_guidance_scale=1.0,
        seed=1024,
        oms_prompt_flag=True,
    ):
        """Generate up to four comparison images for the UI.

        Args:
            prompt: Main text prompt for the SDXL pipeline.
            oms_prompt: Alternative prompt for the OMS step (used only when
                ``oms_prompt_flag`` is true).
            oms_guidance_scale: OMS guidance scale; exactly ``1.0`` disables the
                OMS-guidance variant.
            num_inference_steps: Number of denoising steps (coerced to ``int``).
            sd_pipe_guidance_scale: Guidance scale for the SDXL pipeline.
            seed: Seed for every run (coerced to ``int``).
            oms_prompt_flag: Whether to also render with ``oms_prompt``.

        Returns:
            ``(image_raw, image_oms_cp, image_oms_icp, image_oms_cfg,
            icp_visibility_update, cfg_visibility_update)``. The third and
            fourth images are ``None`` when their variant is disabled.
        """
        # Gradio sliders are float-typed; torch.Generator.manual_seed rejects
        # floats, and the step count presumably must be an integer for the
        # scheduler as well.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        # pipe_kwargs is updated cumulatively: each later variant inherits the
        # settings of the earlier ones (e.g. the guidance run reuses whichever
        # OMS prompt was set last).
        pipe_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=sd_pipe_guidance_scale,
        )

        # 1) Baseline: OMS disabled.
        pipe_kwargs.update(oms_flag=False)
        print(f'raw kwargs: {pipe_kwargs}')
        image_raw = self._generate(pipe_kwargs, seed)

        # 2) OMS with the same prompt and OMS guidance scale 1.0.
        pipe_kwargs.update(oms_flag=True, oms_prompt=prompt, oms_guidance_scale=1.0)
        print(f'w/ oms wo/ cfg (consistent) kwargs: {pipe_kwargs}')
        image_oms_cp = self._generate(pipe_kwargs, seed)

        # 3) OMS with the user-supplied OMS prompt, only when enabled.
        image_oms_icp = None
        if oms_prompt_flag:
            pipe_kwargs.update(oms_prompt=oms_prompt)
            print(f'w/ oms wo/ cfg (inconsistent) kwargs: {pipe_kwargs}')
            image_oms_icp = self._generate(pipe_kwargs, seed)

        # 4) OMS guidance, only when the scale differs from 1.0.
        oms_guidance_flag = oms_guidance_scale != 1.0
        image_oms_cfg = None
        if oms_guidance_flag:
            pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
            print(f'w/ oms +cfg (inconsistent) kwargs: {pipe_kwargs}')
            image_oms_cfg = self._generate(pipe_kwargs, seed)

        # The last two entries toggle visibility of the optional image slots.
        return (
            image_raw,
            image_oms_cp,
            image_oms_icp,
            image_oms_cfg,
            gr.update(visible=oms_prompt_flag),
            gr.update(visible=oms_guidance_flag),
        )

    def mainloop(self):
        """Build the Gradio UI, connect its events to ``_inference`` and launch it."""
        with gr.Blocks() as demo:
            gr.Markdown("# One More Step for SDXL w/ LCM-LoRA")
            with gr.Group():
                prompt = gr.Textbox(label="Prompt", value="a cat against orange ground, studio")
                with gr.Accordion('OMS Prompt'):
                    oms_prompt_checkbox = gr.Checkbox(
                        info="Inconsistent OMS prompt allows the additional control of low freq info, default is the same as Prompt.",
                        label="Adding OMS Prompt",
                        value=True,
                    )
                    oms_prompt = gr.Textbox(
                        label="OMS Prompt",
                        value="a black cat",
                        info='try "a black cat" and "a black room" for diverse control.',
                    )
                with gr.Accordion('OMS Guidance'):
                    oms_cfg_scale_checkbox = gr.Checkbox(
                        info="OMS Guidance will enhance the OMS prompt, specially focus on color and brightness. ",
                        label="Adding OMS Guidance",
                        value=True,
                    )
                    oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=2., step=0.1)
                run_button = gr.Button(value="Generate images")
                with gr.Accordion("Advanced options", open=False):
                    num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
                    sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=1, maximum=3, value=1.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
            with gr.Row():
                output_raw = gr.Image(label="SDXL w/ LCM-LoRA ")
                output_oms_cp = gr.Image(label="w/ OMS (consistent prompt) w/o OMS CFG")
                output_oms_icp = gr.Image(label="w/ OMS (inconsistent prompt) w/o OMS CFG")
                output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")

            # Unchecking copies the main prompt into the OMS prompt box and locks
            # it; re-checking restores the typed OMS prompt and unlocks it. The
            # textbox is listed twice: once for its value, once for the update.
            oms_prompt_checkbox.input(
                fn=lambda oms_prompt_flag, prompt, oms_prompt: (
                    oms_prompt if oms_prompt_flag else prompt,
                    gr.update(interactive=oms_prompt_flag),
                ),
                inputs=[oms_prompt_checkbox, prompt, oms_prompt],
                outputs=[oms_prompt, oms_prompt],
            )
            # Unchecking forces the scale to 1.0 (which disables the guidance
            # variant) and locks the slider.
            # NOTE(review): re-checking resets the scale to 1.5, while the slider
            # starts at 2.0 - confirm which default is intended.
            oms_cfg_scale_checkbox.input(
                fn=lambda oms_cfg_scale_flag: (
                    1.5 if oms_cfg_scale_flag else 1.0,
                    gr.update(interactive=oms_cfg_scale_flag),
                ),
                inputs=[oms_cfg_scale_checkbox],
                outputs=[oms_guidance_scale, oms_guidance_scale],
            )

            # Order must match the _inference parameters. The two optional images
            # appear twice in outputs: first for the image, then for the
            # visibility update returned by _inference.
            ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed, oms_prompt_checkbox]
            run_button.click(
                fn=self._inference,
                inputs=ips,
                outputs=[output_raw, output_oms_cp, output_oms_icp, output_oms_cfg, output_oms_icp, output_oms_cfg],
            )
        demo.queue(max_size=20)
        demo.launch()
if __name__ == "__main__":
    # Script entry point: construct the demo (this loads the pipelines up
    # front), then build the UI and launch the Gradio server.
    gradio_demo = GradioDemo()
    gradio_demo.mainloop()