import spaces
import argparse
import os
import time
from os import path

from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# NOTE(review): `argparse` and `load_file` appear unused in this script.

# Directory next to this file intended to hold downloaded model weights.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
# NOTE(review): these variables are assigned *after* `huggingface_hub` has been
# imported above (and possibly by `spaces`). huggingface_hub resolves its cache
# location from the environment when it is imported, so `hf_hub_download` /
# `from_pretrained` below may still use the default cache rather than
# `cache_path`. Confirm; if this directory is meant to be used, pass
# `cache_dir=cache_path` explicitly to those calls.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

import gradio as gr
import torch
from diffusers import FluxPipeline

# Permit TF32 for CUDA matmuls (trades some precision for speed on GPUs that support it).
torch.backends.cuda.matmul.allow_tf32 = True

# Seed shown in the UI by default; also the fallback when the seed field is empty.
DEFAULT_SEED = 3413


class timer:
    """Context manager that prints when a step starts and how long it took.

    Args:
        method_name: Label used in the printed start/duration messages.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, unlike time.time(), so system clock
        # adjustments cannot produce skewed or negative durations.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Returning self allows `with timer(...) as t:`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        print(f"{self.method} took {round(end - self.start, 2)}s")


# exist_ok=True already makes this a no-op when the directory exists.
os.makedirs(cache_path, exist_ok=True)

# Load FLUX.1-dev in bfloat16, fuse the TDD LoRA into it with lora_scale=0.125,
# and move the pipeline to the GPU.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "TDD-FLUX.1-dev-lora-beta.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to("cuda")

css = """
h1 {
    text-align: center;
    display:block;
}
.gradio-container {
  max-width: 70.5rem !important;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
"""
# FLUX.1-dev(beta) distilled by ✨Target-Driven Distillation✨

Compared to Hyper-FLUX, the beta version of TDD has its parameters reduced by half(600MB), resulting in more realistic details. Due to limitations in machine resources, there are still many imperfections in the beta version. The official version is still being optimized and is expected to be released after the National Day holiday. Besides, TDD is also available for distilling video generation models. This space presents TDD-distilled [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).

[**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)

The codes of this space are built on [Hyper-FLUX](https://huggingface.co/spaces/ByteDance/Hyper-FLUX-8Steps-LoRA) and we acknowledge their contribution.
"""
    )

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    value="portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
                    lines=3,
                )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        with gr.Row():
                            height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                            width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)

                        with gr.Row():
                            steps = gr.Slider(label="Inference Steps", minimum=4, maximum=10, step=1, value=8)
                            scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=3.5, step=0.1, value=2.0)

                        seed = gr.Number(label="Seed", value=DEFAULT_SEED, precision=0)

                generate_btn = gr.Button("Generate Image", variant="primary", scale=1)

        with gr.Column(scale=4):
            output = gr.Image(label="Your Generated Image")

    gr.Markdown(
"""

How to Use

  1. Enter a detailed description of the image you want to create.
  2. Adjust advanced settings if desired (tap to expand).
  3. Tap "Generate Image" and wait for your creation!

Tip: Be specific in your description for best results!

"""
    )

    @spaces.GPU
    def process_image(height, width, steps, scales, prompt, seed):
        """Run the module-level `pipe` once and return the first generated image.

        Args:
            height: Output height in pixels (converted with int()).
            width: Output width in pixels (converted with int()).
            steps: Number of inference steps (converted with int()).
            scales: Guidance scale (converted with float()).
            prompt: Text prompt; passed to the pipeline as a one-element list.
            seed: Generator seed. ``None`` (e.g. an emptied Number field)
                falls back to ``DEFAULT_SEED`` instead of raising TypeError.

        Returns:
            ``images[0]`` of the pipeline output.
        """
        if seed is None:
            seed = DEFAULT_SEED
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            return pipe(
                prompt=[prompt],
                # Explicitly seeded generator so a given seed can be reused.
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256,
            ).images[0]

    generate_btn.click(
        process_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()