""" This module provides an interface to generate images using the SDXL-Turbo model. The interface allows users to enter a text prompt or an initial image and a prompt. The user will receive a generated image. """ import gradio as gr import spaces import torch import math from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image from PIL import Image # Determine the device and dtype for model computation torch_dtype = torch.float16 # Check for CUDA availability, fallback to MPS or CPU if torch.cuda.is_available(): device = torch.device("cuda") elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): device = torch.device("mps") torch_dtype = torch.float32 # MPS requires float32 precision else: device = torch.device("cpu") # Log the selected device and dtype print(f"Using device: {device}, dtype: {torch_dtype}") # Load the text-to-image pipeline with appropriate dtype and variant settings if device.type == "cuda": torch_dtype = torch.float16 variant = "fp16" else: torch_dtype = None variant = None # Load the pipeline and move it to the specified device pipeline_text2image = AutoPipelineForText2Image.from_pretrained( "stabilityai/sdxl-turbo", torch_dtype=torch_dtype, variant=variant ).to(device) # Create image-to-image pipeline from the text-to-image pipeline try: pipeline_image2image = AutoPipelineForImage2Image.from_pipe(pipeline_text2image) pipeline_image2image = pipeline_image2image.to(device) except Exception as e: print(f"Error creating image-to-image pipeline: {e}") @spaces.GPU async def generate_image(init_image, prompt, strength, steps, seed=123): """ Generates an image from a text prompt or an initial image and a prompt. Args: init_image (PIL.Image or None): The initial image for image-to-image generation. prompt (str): The text prompt to guide the generation. strength (float): The strength of the image transformation (0.0 to 1.0). steps (int): The number of inference steps for the generation process. seed (int, optional): Random seed for reproducibility. Defaults to 123. Returns: PIL.Image: The generated image, or a blank image if NSFW content is detected. 
""" try: # Set seed for reproducibility generator = torch.manual_seed(seed) # Image-to-image mode if an initial image is provided if init_image: # Resize image to 512x512 for the pipeline init_image = init_image.resize(size=(512, 512)) # Ensure steps are adjusted for strength steps = max(math.ceil(1 / max(0.1, strength)), steps) # Generate image using image-to-image pipeline result = pipeline_image2image( prompt=prompt, image=init_image, generator=generator, num_inference_steps=steps, guidance_scale=0.0, strength=strength, width=512, height=512, output_type="pil" ) else: # Text-to-image mode if no initial image is provided result = pipeline_text2image( prompt=prompt, generator=generator, num_inference_steps=steps, guidance_scale=0.0, width=512, height=512, output_type="pil" ) # Check for NSFW content in the result nsfw_content_detected = result.get("nsfw_content_detected", [False])[0] if nsfw_content_detected: gr.Warning("NSFW content detected") return Image.new("RGB", (512, 512)) # Return a blank image if NSFW detected # Return the generated image return result.images[0] except Exception as error: print(f"An error occurred during image generation: {error}") return Image.new("RGB", (512, 512)) # Return a blank image in case of an error # Create the Gradio interface with gr.Blocks() as demo: # Initial image state for image-to-image generation init_image_state = gr.State() # Title and description for the interface gr.Markdown( """ # SDXL Turbo Text-to-Image and Image-to-Image Generator """ ) # Row with text prompt and generate button with gr.Row(): prompt = gr.Textbox( label="Text prompt", placeholder="Enter a text prompt", scale=5 ) generate_button = gr.Button("Generate", scale=1) # Row for image input, options, and output image with gr.Row(): with gr.Column(): # Image input field image_input = gr.Image( label="Initial image", sources=["upload", "webcam", "clipboard"], type="pil" ) # Accordion for additional generation options with gr.Accordion("Options", open=False): strength = gr.Slider( label="Strength", value=0.5, # Default value for strength minimum=0.0, maximum=1.0, step=0.001 ) steps = gr.Slider( label="Steps", value=4, # Default number of steps minimum=1, maximum=10, step=1 ) seed = gr.Slider( label="Seed", randomize=True, # Randomize the seed minimum=0, maximum=4294967295, step=1 ) with gr.Column(): # Image output field output_image = gr.Image( label="Generated image", type="filepath" ) # Input and output variables for the generation function inputs = [image_input, prompt, strength, steps, seed] # Click event: bind to generate_image function generate_button.click( fn=generate_image, inputs=inputs, outputs=output_image, show_progress=False ) # Change events: re-trigger image generation on changes to input fields for component in [prompt, steps, seed, strength]: component.change( fn=generate_image, inputs=inputs, outputs=output_image, show_progress=False ) # Image input change: update the init_image_state when the image input changes image_input.change( fn=lambda x: x, inputs=image_input, outputs=init_image_state, show_progress=False, queue=False ) # Launch the Gradio interface demo.launch()