patrickligardes commited on
Commit
5da56f7
β€’
1 Parent(s): ac141df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -1,5 +1,7 @@
1
- import gradio as gr
 
2
  from PIL import Image
 
3
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
4
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
5
  from src.unet_hacked_tryon import UNet2DConditionModel
@@ -15,7 +17,6 @@ from typing import List
15
  import torch
16
  import os
17
  from transformers import AutoTokenizer
18
- import spaces
19
  import numpy as np
20
  from utils_mask import get_mask_location
21
  from torchvision import transforms
@@ -25,6 +26,7 @@ from preprocess.openpose.run_openpose import OpenPose
25
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
26
  from torchvision.transforms.functional import to_pil_image
27
 
 
28
 
29
  def pil_to_binary_mask(pil_image, threshold=0):
30
  np_image = np.array(pil_image)
@@ -121,10 +123,15 @@ pipe = TryonPipeline.from_pretrained(
121
  )
122
  pipe.unet_encoder = UNet_Encoder
123
 
124
- @spaces.GPU
125
- def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
126
- device = "cuda"
 
127
 
 
 
 
 
128
  openpose_model.preprocessor.body_estimation.model.to(device)
129
  pipe.to(device)
130
  pipe.unet_encoder.to(device)
@@ -150,7 +157,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
150
  if is_checked:
151
  keypoints = openpose_model(human_img.resize((384,512)))
152
  model_parse, _ = parsing_model(human_img.resize((384,512)))
153
- mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
154
  mask = mask.resize((768,1024))
155
  else:
156
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
@@ -281,6 +288,7 @@ with image_blocks as demo:
281
  with gr.Row(elem_id="prompt-container"):
282
  with gr.Row():
283
  prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
 
284
  example = gr.Examples(
285
  inputs=garm_img,
286
  examples_per_page=8,
@@ -304,7 +312,7 @@ with image_blocks as demo:
304
 
305
 
306
 
307
- try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
308
 
309
 
310
 
 
1
+ import sys
2
+ sys.path.append('./')
3
  from PIL import Image
4
+ import gradio as gr
5
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
6
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
7
  from src.unet_hacked_tryon import UNet2DConditionModel
 
17
  import torch
18
  import os
19
  from transformers import AutoTokenizer
 
20
  import numpy as np
21
  from utils_mask import get_mask_location
22
  from torchvision import transforms
 
26
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
27
  from torchvision.transforms.functional import to_pil_image
28
 
29
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
30
 
31
  def pil_to_binary_mask(pil_image, threshold=0):
32
  np_image = np.array(pil_image)
 
123
  )
124
  pipe.unet_encoder = UNet_Encoder
125
 
126
+ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed, category):
127
+ category = int(category)
128
+ if category==0:
129
+ category='upper_body'
130
 
131
+ elif category==1:
132
+ category='lower_body'
133
+ else:
134
+ category='dresses'
135
  openpose_model.preprocessor.body_estimation.model.to(device)
136
  pipe.to(device)
137
  pipe.unet_encoder.to(device)
 
157
  if is_checked:
158
  keypoints = openpose_model(human_img.resize((384,512)))
159
  model_parse, _ = parsing_model(human_img.resize((384,512)))
160
+ mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
161
  mask = mask.resize((768,1024))
162
  else:
163
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
 
288
  with gr.Row(elem_id="prompt-container"):
289
  with gr.Row():
290
  prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
291
+ category = gr.Textbox(placeholder="0 = upper body, 1 = lower body, 2 = full body", show_label=True)
292
  example = gr.Examples(
293
  inputs=garm_img,
294
  examples_per_page=8,
 
312
 
313
 
314
 
315
+ try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed, category], outputs=[image_out,masked_img], api_name='tryon')
316
 
317
 
318