Spaces:
Sleeping
Sleeping
patrickligardes
commited on
Commit
β’
c7cdb88
1
Parent(s):
5da56f7
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
-
import sys
|
2 |
-
sys.path.append('./')
|
3 |
-
from PIL import Image
|
4 |
import gradio as gr
|
|
|
5 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
6 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
7 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
@@ -17,6 +15,7 @@ from typing import List
|
|
17 |
import torch
|
18 |
import os
|
19 |
from transformers import AutoTokenizer
|
|
|
20 |
import numpy as np
|
21 |
from utils_mask import get_mask_location
|
22 |
from torchvision import transforms
|
@@ -26,7 +25,6 @@ from preprocess.openpose.run_openpose import OpenPose
|
|
26 |
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
|
27 |
from torchvision.transforms.functional import to_pil_image
|
28 |
|
29 |
-
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
30 |
|
31 |
def pil_to_binary_mask(pil_image, threshold=0):
|
32 |
np_image = np.array(pil_image)
|
@@ -123,15 +121,10 @@ pipe = TryonPipeline.from_pretrained(
|
|
123 |
)
|
124 |
pipe.unet_encoder = UNet_Encoder
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
category='upper_body'
|
130 |
|
131 |
-
elif category==1:
|
132 |
-
category='lower_body'
|
133 |
-
else:
|
134 |
-
category='dresses'
|
135 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
136 |
pipe.to(device)
|
137 |
pipe.unet_encoder.to(device)
|
@@ -157,7 +150,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
|
|
157 |
if is_checked:
|
158 |
keypoints = openpose_model(human_img.resize((384,512)))
|
159 |
model_parse, _ = parsing_model(human_img.resize((384,512)))
|
160 |
-
mask, mask_gray = get_mask_location('hd',
|
161 |
mask = mask.resize((768,1024))
|
162 |
else:
|
163 |
mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
|
@@ -288,7 +281,6 @@ with image_blocks as demo:
|
|
288 |
with gr.Row(elem_id="prompt-container"):
|
289 |
with gr.Row():
|
290 |
prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
|
291 |
-
category = gr.Textbox(placeholder="0 = upper body, 1 = lower body, 2 = full body", show_label=True)
|
292 |
example = gr.Examples(
|
293 |
inputs=garm_img,
|
294 |
examples_per_page=8,
|
@@ -312,7 +304,7 @@ with image_blocks as demo:
|
|
312 |
|
313 |
|
314 |
|
315 |
-
try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed
|
316 |
|
317 |
|
318 |
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
4 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
5 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
|
|
15 |
import torch
|
16 |
import os
|
17 |
from transformers import AutoTokenizer
|
18 |
+
import spaces
|
19 |
import numpy as np
|
20 |
from utils_mask import get_mask_location
|
21 |
from torchvision import transforms
|
|
|
25 |
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
|
26 |
from torchvision.transforms.functional import to_pil_image
|
27 |
|
|
|
28 |
|
29 |
def pil_to_binary_mask(pil_image, threshold=0):
|
30 |
np_image = np.array(pil_image)
|
|
|
121 |
)
|
122 |
pipe.unet_encoder = UNet_Encoder
|
123 |
|
124 |
+
@spaces.GPU
|
125 |
+
def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
|
126 |
+
device = "cuda"
|
|
|
127 |
|
|
|
|
|
|
|
|
|
128 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
129 |
pipe.to(device)
|
130 |
pipe.unet_encoder.to(device)
|
|
|
150 |
if is_checked:
|
151 |
keypoints = openpose_model(human_img.resize((384,512)))
|
152 |
model_parse, _ = parsing_model(human_img.resize((384,512)))
|
153 |
+
mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
|
154 |
mask = mask.resize((768,1024))
|
155 |
else:
|
156 |
mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
|
|
|
281 |
with gr.Row(elem_id="prompt-container"):
|
282 |
with gr.Row():
|
283 |
prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
|
|
|
284 |
example = gr.Examples(
|
285 |
inputs=garm_img,
|
286 |
examples_per_page=8,
|
|
|
304 |
|
305 |
|
306 |
|
307 |
+
try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
|
308 |
|
309 |
|
310 |
|