multimodalart (HF staff) committed
Commit 7ac68e0
1 Parent(s): b1b84dc

Update app.py

Files changed (1): app.py +9 -4
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import torch
+from diffusers import AutoencoderKL
 from diffusers.utils import load_image
 from controlnet_flux import FluxControlNetModel
 from transformer_flux import FluxTransformer2DModel
@@ -14,10 +15,12 @@ controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Co
 transformer = FluxTransformer2DModel.from_pretrained(
     "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
 )
+vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to("cuda")
 pipe = FluxControlNetInpaintingPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     controlnet=controlnet,
     transformer=transformer,
+    vae=vae,
     torch_dtype=torch.bfloat16
 )
 repo_name = "ByteDance/Hyper-SD"
@@ -26,7 +29,7 @@ pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
 pipe.fuse_lora(lora_scale=0.125)
 pipe.transformer.to(torch.bfloat16)
 pipe.controlnet.to(torch.bfloat16)
-
+pipe.to("cuda")
 def can_expand(source_width, source_height, target_width, target_height, alignment):
     if alignment in ("Left", "Right") and source_width >= target_width:
         return False
@@ -133,7 +136,6 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
 
 @spaces.GPU
 def inpaint(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom, progress=gr.Progress(track_tqdm=True)):
-    pipe.enable_model_cpu_offload()
 
     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
 
@@ -158,9 +160,12 @@ def inpaint(image, width, height, overlap_percentage, num_inference_steps, resiz
         controlnet_conditioning_scale=0.9,
         guidance_scale=3.5,
         negative_prompt="",
-        true_guidance_scale=3.5
+        true_guidance_scale=3.5,
+        output_type="latent"
     ).images[0]
-
+    pipe.to("cpu")
+    vae.to("cuda")
+    result = vae.decode(latent_image).sample
     result = result.convert("RGBA")
     cnet_image.paste(result, (0, 0), mask)
 
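Besides wiring in a standalone VAE, the commit swaps diffusers' automatic CPU offload for explicit device placement: `pipe.enable_model_cpu_offload()` is removed from `inpaint`, and the whole pipeline is instead moved to CUDA once at load time. A minimal sketch of the trade-off between the two strategies, assuming a stock diffusers pipeline (the pipeline class here is illustrative, not the app's custom one):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Old behavior: shuttle each submodule to the GPU only while it runs.
# Peak VRAM stays low, but every call pays transfer overhead, and
# re-enabling it inside the request handler repeats the setup work.
# pipe.enable_model_cpu_offload()

# This commit: keep everything resident on the GPU for the process
# lifetime. Faster per request, but needs VRAM for all modules at once.
pipe.to("cuda")
```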
 
 
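The last hunk defers VAE decoding: the pipeline now returns latents (`output_type="latent"`), the denoising modules are pushed back to the CPU, and the standalone VAE decodes on the GPU. Note that the added line references `latent_image`, whose assignment sits outside the hunk shown. A hedged sketch of the pattern, assuming the pipeline call is rebound to `latent_image`, that the latents use the standard Flux scaling, and that the decoded tensor is converted to PIL with diffusers' usual `VaeImageProcessor` (`pipe.image_processor`):

```python
import torch

# Denoise on the GPU but stop before decoding; with output_type="latent",
# .images holds latents rather than PIL images.
latent_image = pipe(
    prompt=prompt_input,  # assumption: the raw prompt is passed through
    # ... ControlNet image/mask and guidance arguments as in the diff ...
    output_type="latent",
).images[0]

# Free the VRAM held by the transformer/controlnet before the VAE runs.
pipe.to("cpu")
vae.to("cuda")

with torch.no_grad():
    lat = latent_image.unsqueeze(0)  # restore the batch dim dropped by [0]
    # Assumption: standard Flux latent (un)scaling before decode.
    lat = lat / vae.config.scaling_factor + vae.config.shift_factor
    # AutoencoderKL.decode returns a DecoderOutput; .sample is the tensor.
    decoded = vae.decode(lat).sample

# Map the [-1, 1] tensor back to a PIL image.
result = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
```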