# Spaces: Running on Zero
import os | |
import gradio as gr | |
import torch | |
import numpy as np | |
import random | |
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, SD3Transformer2DModel, FlowMatchEulerDiscreteScheduler | |
import spaces | |
from PIL import Image | |
import requests | |
import transformers | |
from translatepy import Translator | |
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is first
# imported; diffusers is imported above, so this assignment may come too late to
# take effect — confirm before relying on hf_transfer-accelerated downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Shared translator used to turn user prompts into English before generation.
translator = Translator()

# Constants
# Public model page id; not referenced elsewhere in this file.
model = "stabilityai/stable-diffusion-3-medium"
# Diffusers-format repository all pipeline components are loaded from.
repo = "stabilityai/stable-diffusion-3-medium-diffusers"
# Upper bound for the random seed drawn per generation.
MAX_SEED = np.iinfo(np.int32).max
# Page styling: cap the app width and hide the Gradio footer.
CSS = """
.gradio-container {
    max-width: 690px !important;
}
footer {
    visibility: hidden;
}
"""

# Page-load hook: force Gradio's dark theme by making sure the URL carries
# __theme=dark, reloading once if it does not. Editing the query through the
# URL API (instead of appending "?__theme=dark" to the raw href) keeps existing
# query parameters and #fragments intact and overrides e.g. __theme=light.
JS = """function () {
    const url = new URL(window.location.href);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.replace(url.href);
    }
}"""
# Load the SD3-medium components individually in float16 so they can be passed
# into StableDiffusion3Pipeline.from_pretrained below.
vae = AutoencoderKL.from_pretrained(
    repo,
    subfolder="vae",
    torch_dtype=torch.float16,
)
transformer = SD3Transformer2DModel.from_pretrained(
    repo,
    subfolder="transformer",
    torch_dtype=torch.float16,
)
# T5EncoderModel lives in transformers; it was previously referenced without
# being imported, which raised NameError at startup. `transformers` is already
# imported at module level.
text_encoder_3 = transformers.T5EncoderModel.from_pretrained(
    repo,
    subfolder="text_encoder_3",
    torch_dtype=torch.float16,
)
# Assemble the SD3 pipeline from the components loaded above and place it on the
# GPU when CUDA is visible, otherwise on the CPU. `pipe` is now always defined;
# previously it only existed inside `if torch.cuda.is_available():`, so without
# CUDA every call to generate_image failed with NameError.
# NOTE(review): the components are float16; CPU inference in half precision is
# slow and may be unsupported for some ops — confirm if CPU use matters.
pipe = StableDiffusion3Pipeline.from_pretrained(
    repo,
    vae=vae,
    transformer=transformer,
    text_encoder_3=text_encoder_3,
    torch_dtype=torch.float16,
).to("cuda" if torch.cuda.is_available() else "cpu")
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Function
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function runs; without it the CUDA pipeline call has no device to run on.
@spaces.GPU
def generate_image(
    prompt,
    negative="low quality",
    width=1024,
    height=1024,
    scale=1.5,
    steps=28,
    clip=3,
):
    """Translate *prompt* to English and render one image with the SD3 pipeline.

    Args:
        prompt: Text prompt in any language; translated to English before use.
        negative: Negative prompt, passed to the pipeline unchanged.
        width: Output width in pixels; rounded down to a multiple of 16.
        height: Output height in pixels; rounded down to a multiple of 16.
        scale: Classifier-free guidance scale.
            NOTE(review): the 1.5 default is below the UI guidance slider's 3.5
            minimum; cached examples presumably run with these defaults — confirm.
        steps: Number of denoising steps.
        clip: Number of CLIP layers to skip (``clip_skip``).

    Returns:
        The first image of the pipeline output (``pipe(...).images[0]``).
    """
    # Fresh random seed per call; it is not reported back to the caller.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Users may type in any language; the model is prompted in English.
    prompt = str(translator.translate(prompt, 'English'))
    print(f'prompt:{prompt}')

    # SD3 needs both sides divisible by 16 (VAE downsampling 8 x transformer
    # patch size 2). A value such as 520 would otherwise make the pipeline
    # raise, so snap down to the nearest valid size. int() also accepts float
    # slider values.
    size_multiple = 16
    width = max(size_multiple, int(width) // size_multiple * size_multiple)
    height = max(size_multiple, int(height) // size_multiple * size_multiple)

    image = pipe(
        prompt,
        negative_prompt=negative,
        width=width,
        height=height,
        guidance_scale=scale,
        # Step counts and layer offsets must be integers.
        num_inference_steps=int(steps),
        clip_skip=int(clip),
        generator=generator,
    )
    return image.images[0]
# Sample prompts offered under the UI; each one fills the single prompt textbox.
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic",
    "Ironman VS Hulk, ultrarealistic",
    "a CUTE robot artist painting on an easel",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Gradio Interface
# Layout: title, prompt row with output image, collapsible advanced options,
# lazily cached examples, and two triggers (Enter in the textbox, the button)
# that both call generate_image with the same inputs.
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>SD3M🦄</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/stabilityai/stable-diffusion-3-medium'>sd3m</a> text-to-image generation</center><br><center>Multi-Languages. Adding default prompts to enhance.</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter Your Prompt', value="best quality, HD, aesthetic", scale=6)
            submit = gr.Button(scale=1, variant='primary')
        img = gr.Image(label='SD3M Generated Image')
    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            negative = gr.Textbox(label="Negative prompt", value="low quality")
        with gr.Row():
            # Step by 16 rather than 8: SD3 only accepts sizes divisible by 16,
            # and the minimum, maximum and default below are all multiples of 16.
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=1280,
                step=16,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=1280,
                step=16,
                value=1024,
            )
        with gr.Row():
            scale = gr.Slider(
                label="Guidance",
                minimum=3.5,
                maximum=7,
                step=0.1,
                value=5,
            )
            steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=28,
            )
            clip = gr.Slider(
                label="Clip Skip",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
            )
    # Only the prompt is an example input, so cached example outputs presumably
    # use generate_image's parameter defaults rather than the slider values.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=img,
        fn=generate_image,
        cache_examples="lazy",
    )
    # Both triggers pass the same inputs, in generate_image's parameter order.
    generate_inputs = [prompt, negative, width, height, scale, steps, clip]
    prompt.submit(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=img,
    )
    submit.click(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=img,
    )
demo.queue().launch()