"""Gradio demo comparing plain CFG sampling with Self-Attention Guidance (SAG).

For every request the same prompt and seed are rendered twice with
``StableDiffusionSAGPipeline``: once with ``sag_scale=0.0`` and once with the
user-selected SAG scale, so the two results can be compared side by side.
"""
from __future__ import annotations

import math
import random

import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionSAGPipeline

help_text = """ """

# Initial widget values. ``reset`` reuses them so the "Reset" button always
# restores exactly what the UI started with.
DEFAULT_STEPS = 50
DEFAULT_SEED = 90061
DEFAULT_CFG_SCALE = 3.0
DEFAULT_SAG_SCALE = 1.0
# Order matters: the radio is created with type="index", so callbacks receive
# 0 for "Fix Seed" and 1 for "Randomize Seed".
SEED_MODES = ["Fix Seed", "Randomize Seed"]
DEFAULT_SEED_MODE = "Fix Seed"

# (prompt, seed) pairs for the example gallery; every example shares the same
# step count, seed mode and guidance scales.
_EXAMPLE_PROMPTS_AND_SEEDS = [
    (' ', 8367),
    (' ', 65911),
    (' ', 98184),
    (' ', 33784),
    (' ', 74545),
    (' ', 8393),
    ('.', 24865),
    ('A poster', 37956),
    ('A high-quality living room', 78710),
    ('A Scottish Fold playing with a ball', 11511),
]

# Rows follow the input order: prompt, steps, seed type, seed, cfg, sag.
examples = [
    [example_prompt, 50, "Fix Seed", example_seed, 3.0, 1.0]
    for example_prompt, example_seed in _EXAMPLE_PROMPTS_AND_SEEDS
]

model_id = "runwayml/stable-diffusion-v1-5"


def main():
    """Load the SAG pipeline, build the Gradio UI and launch the demo."""
    pipe = StableDiffusionSAGPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to('cuda')

    def generate(
        prompt: str,
        steps: int,
        randomize_seed: int,
        seed: int,
        cfg_scale: float,
        sag_scale: float,
    ):
        """Render ``prompt`` without and with self-attention guidance.

        Args:
            prompt: Text prompt passed to the pipeline.
            steps: Number of inference steps.
            randomize_seed: Index of the selected "Seed Type" option
                (0 = "Fix Seed", 1 = "Randomize Seed"); any truthy value
                draws a fresh random seed.
            seed: Seed used when the seed is fixed.
            cfg_scale: Classifier-free guidance scale for both runs.
            sag_scale: Self-attention guidance scale for the second run.

        Returns:
            ``[ori_image, sag_image, seed]`` where ``seed`` is the seed that
            was actually used, so the UI can display it.
        """
        if randomize_seed:
            seed = random.randint(0, 100000)

        # Re-seed before each run so both images start from the same noise
        # and differ only in the SAG scale.
        generator = torch.manual_seed(seed)
        ori_image = pipe(
            prompt,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            sag_scale=0.0,
        ).images[0]

        generator = torch.manual_seed(seed)
        sag_image = pipe(
            prompt,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            sag_scale=sag_scale,
        ).images[0]

        return [ori_image, sag_image, seed]

    def reset():
        """Return the initial value of every resettable widget.

        The order matches the ``outputs`` list of ``reset_button.click``:
        steps, seed type, seed, cfg scale, sag scale, and the two (cleared)
        images.
        """
        return [
            DEFAULT_STEPS,
            DEFAULT_SEED_MODE,
            DEFAULT_SEED,
            DEFAULT_CFG_SCALE,
            DEFAULT_SAG_SCALE,
            None,
            None,
        ]

    with gr.Blocks() as demo:
        # The text lines stay flush-left so the string passed to gr.HTML is
        # unchanged.
        gr.HTML("""
Condition-agnostic diffusion guidance using the internal self-attention by Susung Hong et al. This space uses StableDiffusionSAGPipeline in Diffusers.
SAG also produces fine unconditional results. Just leave the prompt blank for the unconditional sampling of Stable Diffusion.
""")
        with gr.Row():
            with gr.Column(scale=5):
                prompt = gr.Textbox(lines=1, label="Enter your prompt", interactive=True)
            with gr.Column(scale=1, min_width=60):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=60):
                reset_button = gr.Button("Reset")

        with gr.Row():
            steps = gr.Number(
                value=DEFAULT_STEPS, precision=0, label="Steps", interactive=True
            )
            randomize_seed = gr.Radio(
                SEED_MODES,
                label="Seed Type",
                value=DEFAULT_SEED_MODE,
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(
                value=DEFAULT_SEED, precision=0, label="Seed", interactive=True
            )

        with gr.Row():
            cfg_scale = gr.Slider(
                label="Text Guidance Scale",
                minimum=0,
                maximum=10,
                value=DEFAULT_CFG_SCALE,
                step=0.1,
            )
            sag_scale = gr.Slider(
                label="Self-Attention Guidance Scale",
                minimum=0,
                maximum=1.0,
                value=DEFAULT_SAG_SCALE,
                step=0.05,
            )

        with gr.Row():
            ori_image = gr.Image(label="CFG", type="pil", interactive=False)
            sag_image = gr.Image(label="SAG + CFG", type="pil", interactive=False)
            ori_image.style(height=512, width=512)
            sag_image.style(height=512, width=512)

        # Shared by the example gallery and the "Generate" button.
        generate_inputs = [
            prompt,
            steps,
            randomize_seed,
            seed,
            cfg_scale,
            sag_scale,
        ]
        generate_outputs = [ori_image, sag_image, seed]

        ex = gr.Examples(
            examples=examples,
            fn=generate,
            inputs=generate_inputs,
            outputs=generate_outputs,
            cache_examples=False,
        )

        gr.Markdown(help_text)

        generate_button.click(
            fn=generate,
            inputs=generate_inputs,
            outputs=generate_outputs,
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[
                steps,
                randomize_seed,
                seed,
                cfg_scale,
                sag_scale,
                ori_image,
                sag_image,
            ],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False)


if __name__ == "__main__":
    main()