File size: 3,749 Bytes
34b628f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from diffusers.utils import load_image
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw

# Load models.
# The ControlNet, the FLUX transformer and the combined inpainting pipeline are
# all loaded once, at import time, in bfloat16; the pipeline is then moved to CUDA.
BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL_ID = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"
MODEL_DTYPE = torch.bfloat16

controlnet = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL_ID, torch_dtype=MODEL_DTYPE)
transformer = FluxTransformer2DModel.from_pretrained(
    BASE_MODEL_ID,
    subfolder="transformer",
    torch_dtype=MODEL_DTYPE,
)

pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    BASE_MODEL_ID,
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=MODEL_DTYPE,
)
pipe = pipe.to("cuda")

# Explicitly (re)cast the two large sub-models to bfloat16: transformer first,
# then controlnet.
for _submodel in (pipe.transformer, pipe.controlnet):
    _submodel.to(MODEL_DTYPE)

def prepare_image_and_mask(image, width, height, overlap_percentage):
    """Center ``image`` on a white canvas and build the matching outpainting mask.

    Args:
        image: Source PIL image. It is copied before resizing, so the caller's
            image object is left untouched.
        width: Target canvas width in pixels (cast to ``int``).
        height: Target canvas height in pixels (cast to ``int``).
        overlap_percentage: Percentage of the resized image's width/height
            trimmed from each edge of the preserved region, so a band along the
            border of the original image is regenerated as well.

    Returns:
        A ``(background, mask)`` tuple. ``background`` is an RGB canvas of size
        ``(width, height)`` with the resized image pasted in the center.
        ``mask`` is an "L" image of the same size: 255 (white) everywhere except
        a 0 (black) rectangle over the kept interior of the pasted image. In
        this outpainting demo the black rectangle is the part that is preserved
        and the white area is what gets generated.
    """
    width, height = int(width), int(height)

    # Work on a copy: ``thumbnail`` resizes in place and would otherwise shrink
    # the image object owned by the caller.
    image = image.copy()
    # Shrink (never enlarge) the image so it fits inside the canvas, keeping
    # its aspect ratio.
    image.thumbnail((width, height), Image.LANCZOS)

    # White canvas of the target size with the resized image centered on it.
    background = Image.new('RGB', (width, height), (255, 255, 255))
    offset = ((width - image.width) // 2, (height - image.height) // 2)
    background.paste(image, offset)

    # Start fully white, then blacken the region of the original image to keep.
    mask = Image.new('L', (width, height), 255)

    overlap_x = int(image.width * overlap_percentage / 100)
    overlap_y = int(image.height * overlap_percentage / 100)

    # ImageDraw.rectangle includes both corner pixels, so right/bottom must be
    # the last pixel *inside* the kept region (hence the -1). Without it the
    # kept area spills one pixel onto the white background.
    left = offset[0] + overlap_x
    top = offset[1] + overlap_y
    right = offset[0] + image.width - overlap_x - 1
    bottom = offset[1] + image.height - overlap_y - 1

    # A large overlap can leave nothing to keep; skip drawing rather than pass
    # inverted coordinates to PIL.
    if right >= left and bottom >= top:
        ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)

    return background, mask

def inpaint(image, prompt, width, height, overlap_percentage, num_inference_steps, guidance_scale, seed=42):
    """Outpaint ``image`` onto a ``width`` x ``height`` canvas with the FLUX ControlNet pipeline.

    Args:
        image: Input PIL image from the UI, or ``None`` if nothing was uploaded.
        prompt: Text prompt describing the content to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        overlap_percentage: Passed through to ``prepare_image_and_mask``; the
            share of the image border that is regenerated.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Used both as ``guidance_scale`` and as
            ``true_guidance_scale`` for the pipeline call.
        seed: Seed for the CUDA generator. Defaults to 42, the previously
            hard-coded value, so results stay reproducible.

    Returns:
        The first image of the pipeline's output.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an input image.")

    # Defensive casts: image sizes and step counts must be integers even if the
    # UI delivers slider values as floats.
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)

    # Build the padded canvas and the outpainting mask.
    image, mask = prepare_image_and_mask(image, width, height, overlap_percentage)

    # Seeded generator for reproducibility.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=image,
        control_mask=mask,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=guidance_scale
    ).images[0]

    return result

# Gradio interface: controls in the left column, generated image in the right.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX Outpainting Demo")

    with gr.Row():
        # Left column: source image, prompt and generation settings.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt_input = gr.Textbox(label="Prompt")
            width_slider = gr.Slider(
                minimum=256, maximum=1024, step=64, value=768, label="Width"
            )
            height_slider = gr.Slider(
                minimum=256, maximum=1024, step=64, value=768, label="Height"
            )
            overlap_slider = gr.Slider(
                minimum=0, maximum=50, step=1, value=10, label="Overlap Percentage"
            )
            steps_slider = gr.Slider(
                minimum=1, maximum=100, step=1, value=28, label="Inference Steps"
            )
            guidance_slider = gr.Slider(
                minimum=0.1, maximum=10.0, step=0.1, value=3.5, label="Guidance Scale"
            )
            run_button = gr.Button("Generate")

        # Right column: the result.
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # The order of these components must match inpaint's positional parameters.
    generation_inputs = [
        input_image,
        prompt_input,
        width_slider,
        height_slider,
        overlap_slider,
        steps_slider,
        guidance_slider,
    ]
    run_button.click(inpaint, inputs=generation_inputs, outputs=output_image)

demo.launch()