multimodalart HF staff committed on
Commit
257e1a1
1 Parent(s): e99cabf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -21,7 +21,7 @@ from optimum.quanto import freeze, qfloat8, quantize
21
  #controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
22
  #quantize(controlnet, weights=qfloat8)
23
  #freeze(controlnet)
24
-
25
  transformer = FluxTransformer2DModel.from_pretrained(
26
  "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
27
  )
@@ -34,6 +34,7 @@ pipe = FluxControlNetInpaintingPipeline.from_pretrained(
34
  "black-forest-labs/FLUX.1-dev",
35
  text_encoder_2=None,
36
  transformer=transformer,
 
37
  torch_dtype=torch.bfloat16
38
  )
39
  pipe.text_encoder_2 = text_encoder_2
@@ -42,6 +43,8 @@ repo_name = "ByteDance/Hyper-SD"
42
  ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
43
  pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
44
  pipe.fuse_lora(lora_scale=0.125)
 
 
45
  pipe.to("cuda")
46
  def can_expand(source_width, source_height, target_width, target_height, alignment):
47
  if alignment in ("Left", "Right") and source_width >= target_width:
 
21
  #controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
22
  #quantize(controlnet, weights=qfloat8)
23
  #freeze(controlnet)
24
+ controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
25
  transformer = FluxTransformer2DModel.from_pretrained(
26
  "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
27
  )
 
34
  "black-forest-labs/FLUX.1-dev",
35
  text_encoder_2=None,
36
  transformer=transformer,
37
+ controlnet=controlnet,
38
  torch_dtype=torch.bfloat16
39
  )
40
  pipe.text_encoder_2 = text_encoder_2
 
43
  ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
44
  pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
45
  pipe.fuse_lora(lora_scale=0.125)
46
+ pipe.transformer.to(torch.bfloat16)
47
+ pipe.controlnet.to(torch.bfloat16)
48
  pipe.to("cuda")
49
  def can_expand(source_width, source_height, target_width, target_height, alignment):
50
  if alignment in ("Left", "Right") and source_width >= target_width: