ML-Motivators committed
Commit 183cb58
1 Parent(s): 93168f6

uploads files

Files changed (6)
  1. environment.yaml +32 -0
  2. get-pip.py +0 -0
  3. inference.py +425 -0
  4. inference.sh +34 -0
  5. inference_dc.py +578 -0
  6. vitonhd_test_tagged.json +0 -0
environment.yaml ADDED
@@ -0,0 +1,32 @@
+name: idm
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+dependencies:
+  - python=3.10.0=h12debd9_5
+  - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0
+  - pytorch-cuda=11.8=h7e8668a_5
+  - torchaudio=2.0.2=py310_cu118
+  - torchtriton=2.0.0=py310
+  - torchvision=0.15.2=py310_cu118
+  - pip=23.3.1=py310h06a4308_0
+
+  - pip:
+      - accelerate==0.25.0
+      - torchmetrics==1.2.1
+      - tqdm==4.66.1
+      - transformers==4.36.2
+      - diffusers==0.25.0
+      - einops==0.7.0
+      - bitsandbytes==0.39.0
+      - scipy==1.11.1
+      - opencv-python
+      - gradio==4.24.0
+      - fvcore
+      - cloudpickle
+      - omegaconf
+      - pycocotools
+      - basicsr
+      - av
+      - onnxruntime==1.16.2
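
After creating this environment, a small script along the following lines can confirm that the interpreter sees the key versions pinned above. This is an illustrative sketch, not part of the commit; the version map simply mirrors the pins in environment.yaml.

import importlib.metadata as md

import torch

# Version pins copied from environment.yaml above; extend as needed.
EXPECTED = {
    "accelerate": "0.25.0",
    "diffusers": "0.25.0",
    "transformers": "4.36.2",
    "torchmetrics": "1.2.1",
}

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
for pkg, want in EXPECTED.items():
    have = md.version(pkg)
    print(f"{pkg} {have}" + ("" if have == want else f"  <-- expected {want}"))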
get-pip.py ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
+from ip_adapter.ip_adapter import Resampler
+
+import argparse
+import logging
+import os
+import torch.utils.data as data
+import torchvision
+import json
+import accelerate
+import numpy as np
+import torch
+from PIL import Image
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from packaging import version
+from torchvision import transforms
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline
+from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
+
+from diffusers.utils.import_utils import is_xformers_available
+
+from src.unet_hacked_tryon import UNet2DConditionModel
+from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+
+
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,)
+    parser.add_argument("--width",type=int,default=768,)
+    parser.add_argument("--height",type=int,default=1024,)
+    parser.add_argument("--num_inference_steps",type=int,default=30,)
+    parser.add_argument("--output_dir",type=str,default="result",)
+    parser.add_argument("--unpaired",action="store_true",)
+    parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando")
+    parser.add_argument("--seed", type=int, default=42,)
+    parser.add_argument("--test_batch_size", type=int, default=2,)
+    parser.add_argument("--guidance_scale",type=float,default=2.0,)
+    parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],)
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    args = parser.parse_args()
+
+
+    return args
+
+def pil_to_tensor(images):
+    images = np.array(images).astype(np.float32) / 255.0
+    images = torch.from_numpy(images.transpose(2, 0, 1))
+    return images
+
+
+class VitonHDTestDataset(data.Dataset):
+    def __init__(
+        self,
+        dataroot_path: str,
+        phase: Literal["train", "test"],
+        order: Literal["paired", "unpaired"] = "paired",
+        size: Tuple[int, int] = (512, 384),
+    ):
+        super(VitonHDTestDataset, self).__init__()
+        self.dataroot = dataroot_path
+        self.phase = phase
+        self.height = size[0]
+        self.width = size[1]
+        self.size = size
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        self.toTensor = transforms.ToTensor()
+
+        with open(
+            os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
+        ) as file1:
+            data1 = json.load(file1)
+
+        annotation_list = [
+            "sleeveLength",
+            "neckLine",
+            "item",
+        ]
+
+        self.annotation_pair = {}
+        for k, v in data1.items():
+            for elem in v:
+                annotation_str = ""
+                for template in annotation_list:
+                    for tag in elem["tag_info"]:
+                        if (
+                            tag["tag_name"] == template
+                            and tag["tag_category"] is not None
+                        ):
+                            annotation_str += tag["tag_category"]
+                            annotation_str += " "
+                self.annotation_pair[elem["file_name"]] = annotation_str
+
+        self.order = order
+        self.toTensor = transforms.ToTensor()
+
+        im_names = []
+        c_names = []
+        dataroot_names = []
+
+
+        if phase == "train":
+            filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
+        else:
+            filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
+
+        with open(filename, "r") as f:
+            for line in f.readlines():
+                if phase == "train":
+                    im_name, _ = line.strip().split()
+                    c_name = im_name
+                else:
+                    if order == "paired":
+                        im_name, _ = line.strip().split()
+                        c_name = im_name
+                    else:
+                        im_name, c_name = line.strip().split()
+
+                im_names.append(im_name)
+                c_names.append(c_name)
+                dataroot_names.append(dataroot_path)
+
+        self.im_names = im_names
+        self.c_names = c_names
+        self.dataroot_names = dataroot_names
+        self.clip_processor = CLIPImageProcessor()
+    def __getitem__(self, index):
+        c_name = self.c_names[index]
+        im_name = self.im_names[index]
+        if c_name in self.annotation_pair:
+            cloth_annotation = self.annotation_pair[c_name]
+        else:
+            cloth_annotation = "shirts"
+        cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))
+
+        im_pil_big = Image.open(
+            os.path.join(self.dataroot, self.phase, "image", im_name)
+        ).resize((self.width,self.height))
+        image = self.transform(im_pil_big)
+
+        mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height))
+        mask = self.toTensor(mask)
+        mask = mask[:1]
+        mask = 1-mask
+        im_mask = image * mask
+
+        pose_img = Image.open(
+            os.path.join(self.dataroot, self.phase, "image-densepose", im_name)
+        )
+        pose_img = self.transform(pose_img)  # [-1,1]
+
+        result = {}
+        result["c_name"] = c_name
+        result["im_name"] = im_name
+        result["image"] = image
+        result["cloth_pure"] = self.transform(cloth)
+        result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
+        result["inpaint_mask"] = 1-mask
+        result["im_mask"] = im_mask
+        result["caption_cloth"] = "a photo of " + cloth_annotation
+        result["caption"] = "model is wearing a " + cloth_annotation
+        result["pose_img"] = pose_img
+
+        return result
+
+    def __len__(self):
+        # model images + cloth image
+        return len(self.im_names)
+
+
+
+
+def main():
+    args = parse_args()
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        project_config=accelerator_project_config,
+    )
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    weight_dtype = torch.float16
+    # if accelerator.mixed_precision == "fp16":
+    #     weight_dtype = torch.float16
+    #     args.mixed_precision = accelerator.mixed_precision
+    # elif accelerator.mixed_precision == "bf16":
+    #     weight_dtype = torch.bfloat16
+    #     args.mixed_precision = accelerator.mixed_precision
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        torch_dtype=torch.float16,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        torch_dtype=torch.float16,
+    )
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="image_encoder",
+        torch_dtype=torch.float16,
+    )
+    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet_encoder",
+        torch_dtype=torch.float16,
+    )
+    text_encoder_one = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        torch_dtype=torch.float16,
+    )
+    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder_2",
+        torch_dtype=torch.float16,
+    )
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=None,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=None,
+        use_fast=False,
+    )
+
+
+    # Freeze vae and text_encoder and set unet to trainable
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+    image_encoder.requires_grad_(False)
+    UNet_Encoder.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+    UNet_Encoder.to(accelerator.device, weight_dtype)
+    unet.eval()
+    UNet_Encoder.eval()
+
+
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    test_dataset = VitonHDTestDataset(
+        dataroot_path=args.data_dir,
+        phase="test",
+        order="unpaired" if args.unpaired else "paired",
+        size=(args.height, args.width),
+    )
+    test_dataloader = torch.utils.data.DataLoader(
+        test_dataset,
+        shuffle=False,
+        batch_size=args.test_batch_size,
+        num_workers=4,
+    )
+
+    pipe = TryonPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        unet=unet,
+        vae=vae,
+        feature_extractor= CLIPImageProcessor(),
+        text_encoder = text_encoder_one,
+        text_encoder_2 = text_encoder_two,
+        tokenizer = tokenizer_one,
+        tokenizer_2 = tokenizer_two,
+        scheduler = noise_scheduler,
+        image_encoder=image_encoder,
+        torch_dtype=torch.float16,
+    ).to(accelerator.device)
+    pipe.unet_encoder = UNet_Encoder
+
+    # pipe.enable_sequential_cpu_offload()
+    # pipe.enable_model_cpu_offload()
+    # pipe.enable_vae_slicing()
+
+
+
+    with torch.no_grad():
+        # Extract the images
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                for sample in test_dataloader:
+                    img_emb_list = []
+                    for i in range(sample['cloth'].shape[0]):
+                        img_emb_list.append(sample['cloth'][i])
+
+                    prompt = sample["caption"]
+
+                    num_prompts = sample['cloth'].shape[0]
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * num_prompts
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * num_prompts
+
+                    image_embeds = torch.cat(img_emb_list,dim=0)
+
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds,
+                            negative_prompt_embeds,
+                            pooled_prompt_embeds,
+                            negative_pooled_prompt_embeds,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=True,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+                    prompt = sample["caption_cloth"]
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * num_prompts
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * num_prompts
+
+
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds_c,
+                            _,
+                            _,
+                            _,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=False,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+
+                    generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
+                    images = pipe(
+                        prompt_embeds=prompt_embeds,
+                        negative_prompt_embeds=negative_prompt_embeds,
+                        pooled_prompt_embeds=pooled_prompt_embeds,
+                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                        num_inference_steps=args.num_inference_steps,
+                        generator=generator,
+                        strength = 1.0,
+                        pose_img = sample['pose_img'],
+                        text_embeds_cloth=prompt_embeds_c,
+                        cloth = sample["cloth_pure"].to(accelerator.device),
+                        mask_image=sample['inpaint_mask'],
+                        image=(sample['image']+1.0)/2.0,
+                        height=args.height,
+                        width=args.width,
+                        guidance_scale=args.guidance_scale,
+                        ip_adapter_image = image_embeds,
+                    )[0]
+
+
+                    for i in range(len(images)):
+                        x_sample = pil_to_tensor(images[i])
+                        torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i]))
+
+
+
+
+if __name__ == "__main__":
+    main()
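
The VitonHDTestDataset class above can be exercised on its own before launching the full try-on pipeline. The sketch below is illustrative and not part of the commit; it assumes the repository's src/ and ip_adapter/ packages are importable (so that inference.py imports cleanly) and that the data root follows the VITON-HD layout the class reads (test_pairs.txt at the root, plus test/image, test/cloth, test/agnostic-mask, test/image-densepose, and test/vitonhd_test_tagged.json). The data path is a placeholder.

import torch

from inference import VitonHDTestDataset  # defined in this commit

# Placeholder path; point it at a real VITON-HD checkout.
dataset = VitonHDTestDataset(
    dataroot_path="/path/to/zalando",
    phase="test",
    order="paired",
    size=(1024, 768),  # (height, width), matching inference.sh
)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)

# Pull one batch and report tensor shapes to confirm the preprocessing works.
batch = next(iter(loader))
for key in ("image", "cloth", "cloth_pure", "inpaint_mask", "pose_img"):
    print(key, tuple(batch[key].shape))
print("caption:", batch["caption"][0])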
inference.sh ADDED
@@ -0,0 +1,34 @@
+#VITON-HD
+##paired setting
+accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
+    --width 768 --height 1024 --num_inference_steps 30 \
+    --output_dir "result" --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \
+    --seed 42 --test_batch_size 2 --guidance_scale 2.0
+
+
+##unpaired setting
+accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
+    --width 768 --height 1024 --num_inference_steps 30 \
+    --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \
+    --seed 42 --test_batch_size 2 --guidance_scale 2.0
+
+
+
+#DressCode
+##upper_body
+accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
+    --width 768 --height 1024 --num_inference_steps 30 \
+    --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
+    --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "upper_body"
+
+##lower_body
+accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
+    --width 768 --height 1024 --num_inference_steps 30 \
+    --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
+    --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "lower_body"
+
+##dresses
+accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
+    --width 768 --height 1024 --num_inference_steps 30 \
+    --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
+    --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "dresses"
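
The three DressCode invocations above differ only in --category. If preferred, they can be driven from a small Python loop like the sketch below (not part of the commit); it passes the same arguments, including the same hard-coded data path, which would need to point at a real DressCode checkout.

import subprocess

for category in ("upper_body", "lower_body", "dresses"):
    # Same command as in inference.sh, one category per run.
    subprocess.run(
        [
            "accelerate", "launch", "inference_dc.py",
            "--pretrained_model_name_or_path", "yisol/IDM-VTON",
            "--width", "768", "--height", "1024",
            "--num_inference_steps", "30",
            "--output_dir", "result", "--unpaired",
            "--data_dir", "/home/omnious/workspace/yisol/DressCode",
            "--seed", "42", "--test_batch_size", "2",
            "--guidance_scale", "2.0",
            "--category", category,
        ],
        check=True,  # stop on the first failing run
    )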
inference_dc.py ADDED
@@ -0,0 +1,578 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
+from ip_adapter.ip_adapter import Resampler
+
+import argparse
+import logging
+import os
+import torch.utils.data as data
+import torchvision
+import json
+import accelerate
+import numpy as np
+import torch
+from PIL import Image, ImageDraw
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from packaging import version
+from torchvision import transforms
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline
+from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
+import cv2
+from diffusers.utils.import_utils import is_xformers_available
+from numpy.linalg import lstsq
+
+from src.unet_hacked_tryon import UNet2DConditionModel
+from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+
+
+
+logger = get_logger(__name__, log_level="INFO")
+
+label_map={
+    "background": 0,
+    "hat": 1,
+    "hair": 2,
+    "sunglasses": 3,
+    "upper_clothes": 4,
+    "skirt": 5,
+    "pants": 6,
+    "dress": 7,
+    "belt": 8,
+    "left_shoe": 9,
+    "right_shoe": 10,
+    "head": 11,
+    "left_leg": 12,
+    "right_leg": 13,
+    "left_arm": 14,
+    "right_arm": 15,
+    "bag": 16,
+    "scarf": 17,
+}
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,)
+    parser.add_argument("--width",type=int,default=768,)
+    parser.add_argument("--height",type=int,default=1024,)
+    parser.add_argument("--num_inference_steps",type=int,default=30,)
+    parser.add_argument("--output_dir",type=str,default="result",)
+    parser.add_argument("--category",type=str,default="upper_body",choices=["upper_body", "lower_body", "dresses"])
+    parser.add_argument("--unpaired",action="store_true",)
+    parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando")
+    parser.add_argument("--seed", type=int, default=42,)
+    parser.add_argument("--test_batch_size", type=int, default=2,)
+    parser.add_argument("--guidance_scale",type=float,default=2.0,)
+    parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],)
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    args = parser.parse_args()
+
+
+    return args
+
+def pil_to_tensor(images):
+    images = np.array(images).astype(np.float32) / 255.0
+    images = torch.from_numpy(images.transpose(2, 0, 1))
+    return images
+
+
+class DresscodeTestDataset(data.Dataset):
+    def __init__(
+        self,
+        dataroot_path: str,
+        phase: Literal["train", "test"],
+        order: Literal["paired", "unpaired"] = "paired",
+        category = "upper_body",
+        size: Tuple[int, int] = (512, 384),
+    ):
+        super(DresscodeTestDataset, self).__init__()
+        self.dataroot = os.path.join(dataroot_path,category)
+        self.phase = phase
+        self.height = size[0]
+        self.width = size[1]
+        self.size = size
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        self.toTensor = transforms.ToTensor()
+        self.order = order
+        self.radius = 5
+        self.category = category
+        im_names = []
+        c_names = []
+
+
+        if phase == "train":
+            filename = os.path.join(dataroot_path,category, f"{phase}_pairs.txt")
+        else:
+            filename = os.path.join(dataroot_path,category, f"{phase}_pairs_{order}.txt")
+
+        with open(filename, "r") as f:
+            for line in f.readlines():
+                im_name, c_name = line.strip().split()
+
+                im_names.append(im_name)
+                c_names.append(c_name)
+
+
+        file_path = os.path.join(dataroot_path,category,"dc_caption.txt")
+
+        self.annotation_pair = {}
+        with open(file_path, "r") as file:
+            for line in file:
+                parts = line.strip().split(" ")
+                self.annotation_pair[parts[0]] = ' '.join(parts[1:])
+
+
+        self.im_names = im_names
+        self.c_names = c_names
+        self.clip_processor = CLIPImageProcessor()
+    def __getitem__(self, index):
+        c_name = self.c_names[index]
+        im_name = self.im_names[index]
+        if c_name in self.annotation_pair:
+            cloth_annotation = self.annotation_pair[c_name]
+        else:
+            cloth_annotation = self.category
+        cloth = Image.open(os.path.join(self.dataroot, "images", c_name))
+
+        im_pil_big = Image.open(
+            os.path.join(self.dataroot, "images", im_name)
+        ).resize((self.width,self.height))
+        image = self.transform(im_pil_big)
+
+
+
+
+        skeleton = Image.open(os.path.join(self.dataroot, 'skeletons', im_name.replace("_0", "_5")))
+        skeleton = skeleton.resize((self.width, self.height))
+        skeleton = self.transform(skeleton)
+
+        # Label Map
+        parse_name = im_name.replace('_0.jpg', '_4.png')
+        im_parse = Image.open(os.path.join(self.dataroot, 'label_maps', parse_name))
+        im_parse = im_parse.resize((self.width, self.height), Image.NEAREST)
+        parse_array = np.array(im_parse)
+
+        # Load pose points
+        pose_name = im_name.replace('_0.jpg', '_2.json')
+        with open(os.path.join(self.dataroot, 'keypoints', pose_name), 'r') as f:
+            pose_label = json.load(f)
+            pose_data = pose_label['keypoints']
+            pose_data = np.array(pose_data)
+            pose_data = pose_data.reshape((-1, 4))
+
+        point_num = pose_data.shape[0]
+        pose_map = torch.zeros(point_num, self.height, self.width)
+        r = self.radius * (self.height / 512.0)
+        for i in range(point_num):
+            one_map = Image.new('L', (self.width, self.height))
+            draw = ImageDraw.Draw(one_map)
+            point_x = np.multiply(pose_data[i, 0], self.width / 384.0)
+            point_y = np.multiply(pose_data[i, 1], self.height / 512.0)
+            if point_x > 1 and point_y > 1:
+                draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white')
+            one_map = self.toTensor(one_map)
+            pose_map[i] = one_map[0]
+
+        agnostic_mask = self.get_agnostic(parse_array, pose_data, self.category, (self.width,self.height))
+        # agnostic_mask = transforms.functional.resize(agnostic_mask, (self.height, self.width),
+        #                                              interpolation=transforms.InterpolationMode.NEAREST)
+
+        mask = 1 - agnostic_mask
+        im_mask = image * agnostic_mask
+
+        pose_img = Image.open(
+            os.path.join(self.dataroot, "image-densepose", im_name)
+        )
+        pose_img = self.transform(pose_img)  # [-1,1]
+
+        result = {}
+        result["c_name"] = c_name
+        result["im_name"] = im_name
+        result["image"] = image
+        result["cloth_pure"] = self.transform(cloth)
+        result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
+        result["inpaint_mask"] = mask
+        result["im_mask"] = im_mask
+        result["caption_cloth"] = "a photo of " + cloth_annotation
+        result["caption"] = "model is wearing a " + cloth_annotation
+        result["pose_img"] = pose_img
+
+        return result
+
+    def __len__(self):
+        # model images + cloth image
+        return len(self.im_names)
+
+
+
+
+    def get_agnostic(self,parse_array, pose_data, category, size):
+        parse_shape = (parse_array > 0).astype(np.float32)
+
+        parse_head = (parse_array == 1).astype(np.float32) + \
+                     (parse_array == 2).astype(np.float32) + \
+                     (parse_array == 3).astype(np.float32) + \
+                     (parse_array == 11).astype(np.float32)
+
+        parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \
+                            (parse_array == label_map["left_shoe"]).astype(np.float32) + \
+                            (parse_array == label_map["right_shoe"]).astype(np.float32) + \
+                            (parse_array == label_map["hat"]).astype(np.float32) + \
+                            (parse_array == label_map["sunglasses"]).astype(np.float32) + \
+                            (parse_array == label_map["scarf"]).astype(np.float32) + \
+                            (parse_array == label_map["bag"]).astype(np.float32)
+
+        parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
+
+        arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)
+
+        if category == 'dresses':
+            label_cat = 7
+            parse_mask = (parse_array == 7).astype(np.float32) + \
+                         (parse_array == 12).astype(np.float32) + \
+                         (parse_array == 13).astype(np.float32)
+            parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+
+        elif category == 'upper_body':
+            label_cat = 4
+            parse_mask = (parse_array == 4).astype(np.float32)
+
+            parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \
+                                 (parse_array == label_map["pants"]).astype(np.float32)
+
+            parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+        elif category == 'lower_body':
+            label_cat = 6
+            parse_mask = (parse_array == 6).astype(np.float32) + \
+                         (parse_array == 12).astype(np.float32) + \
+                         (parse_array == 13).astype(np.float32)
+
+            parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
+                                 (parse_array == 14).astype(np.float32) + \
+                                 (parse_array == 15).astype(np.float32)
+            parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+
+        parse_head = torch.from_numpy(parse_head)  # [0,1]
+        parse_mask = torch.from_numpy(parse_mask)  # [0,1]
+        parser_mask_fixed = torch.from_numpy(parser_mask_fixed)
+        parser_mask_changeable = torch.from_numpy(parser_mask_changeable)
+
+        # dilation
+        parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask))
+        parse_mask = parse_mask.cpu().numpy()
+
+        width = size[0]
+        height = size[1]
+
+        im_arms = Image.new('L', (width, height))
+        arms_draw = ImageDraw.Draw(im_arms)
+        if category == 'dresses' or category == 'upper_body':
+            shoulder_right = tuple(np.multiply(pose_data[2, :2], height / 512.0))
+            shoulder_left = tuple(np.multiply(pose_data[5, :2], height / 512.0))
+            elbow_right = tuple(np.multiply(pose_data[3, :2], height / 512.0))
+            elbow_left = tuple(np.multiply(pose_data[6, :2], height / 512.0))
+            wrist_right = tuple(np.multiply(pose_data[4, :2], height / 512.0))
+            wrist_left = tuple(np.multiply(pose_data[7, :2], height / 512.0))
+            if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
+                if elbow_right[0] <= 1. and elbow_right[1] <= 1.:
+                    arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right], 'white', 30, 'curve')
+                else:
+                    arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right], 'white', 30,
+                                   'curve')
+            elif wrist_left[0] <= 1. and wrist_left[1] <= 1.:
+                if elbow_left[0] <= 1. and elbow_left[1] <= 1.:
+                    arms_draw.line([shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30, 'curve')
+                else:
+                    arms_draw.line([elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30,
+                                   'curve')
+            else:
+                arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white',
+                               30, 'curve')
+
+            if height > 512:
+                im_arms = cv2.dilate(np.float32(im_arms), np.ones((10, 10), np.uint16), iterations=5)
+            elif height > 256:
+                im_arms = cv2.dilate(np.float32(im_arms), np.ones((5, 5), np.uint16), iterations=5)
+            hands = np.logical_and(np.logical_not(im_arms), arms)
+            parse_mask += im_arms
+            parser_mask_fixed += hands
+
+        # delete neck
+        parse_head_2 = torch.clone(parse_head)
+        if category == 'dresses' or category == 'upper_body':
+            points = []
+            points.append(np.multiply(pose_data[2, :2], height / 512.0))
+            points.append(np.multiply(pose_data[5, :2], height / 512.0))
+            x_coords, y_coords = zip(*points)
+            A = np.vstack([x_coords, np.ones(len(x_coords))]).T
+            m, c = lstsq(A, y_coords, rcond=None)[0]
+            for i in range(parse_array.shape[1]):
+                y = i * m + c
+                parse_head_2[int(y - 20 * (height / 512.0)):, i] = 0
+
+        parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16))
+        parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16),
+                                                               np.logical_not(np.array(parse_head_2, dtype=np.uint16))))
+
+        if height > 512:
+            parse_mask = cv2.dilate(parse_mask, np.ones((20, 20), np.uint16), iterations=5)
+        elif height > 256:
+            parse_mask = cv2.dilate(parse_mask, np.ones((10, 10), np.uint16), iterations=5)
+        else:
+            parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
+        parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
+        parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
+        agnostic_mask = parse_mask_total.unsqueeze(0)
+        return agnostic_mask
+
+
+
+
+def main():
+    args = parse_args()
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        project_config=accelerator_project_config,
+    )
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    weight_dtype = torch.float16
+    # if accelerator.mixed_precision == "fp16":
+    #     weight_dtype = torch.float16
+    #     args.mixed_precision = accelerator.mixed_precision
+    # elif accelerator.mixed_precision == "bf16":
+    #     weight_dtype = torch.bfloat16
+    #     args.mixed_precision = accelerator.mixed_precision
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        torch_dtype=torch.float16,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        "yisol/IDM-VTON-DC",
+        subfolder="unet",
+        torch_dtype=torch.float16,
+    )
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="image_encoder",
+        torch_dtype=torch.float16,
+    )
+    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet_encoder",
+        torch_dtype=torch.float16,
+    )
+    text_encoder_one = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        torch_dtype=torch.float16,
+    )
+    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder_2",
+        torch_dtype=torch.float16,
+    )
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=None,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=None,
+        use_fast=False,
+    )
+
+
+    # Freeze vae and text_encoder and set unet to trainable
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+    image_encoder.requires_grad_(False)
+    UNet_Encoder.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+    UNet_Encoder.to(accelerator.device, weight_dtype)
+    unet.eval()
+    UNet_Encoder.eval()
+
+
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    test_dataset = DresscodeTestDataset(
+        dataroot_path=args.data_dir,
+        phase="test",
+        order="unpaired" if args.unpaired else "paired",
+        category = args.category,
+        size=(args.height, args.width),
+    )
+    test_dataloader = torch.utils.data.DataLoader(
+        test_dataset,
+        shuffle=False,
+        batch_size=args.test_batch_size,
+        num_workers=4,
+    )
+
+    pipe = TryonPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        unet=unet,
+        vae=vae,
+        feature_extractor= CLIPImageProcessor(),
+        text_encoder = text_encoder_one,
+        text_encoder_2 = text_encoder_two,
+        tokenizer = tokenizer_one,
+        tokenizer_2 = tokenizer_two,
+        scheduler = noise_scheduler,
+        image_encoder=image_encoder,
+        torch_dtype=torch.float16,
+    ).to(accelerator.device)
+    pipe.unet_encoder = UNet_Encoder
+
+    # pipe.enable_sequential_cpu_offload()
+    # pipe.enable_model_cpu_offload()
+    # pipe.enable_vae_slicing()
+
+
+
+    with torch.no_grad():
+        # Extract the images
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                for sample in test_dataloader:
+                    img_emb_list = []
+                    for i in range(sample['cloth'].shape[0]):
+                        img_emb_list.append(sample['cloth'][i])
+
+                    prompt = sample["caption"]
+
+                    num_prompts = sample['cloth'].shape[0]
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * num_prompts
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * num_prompts
+
+                    image_embeds = torch.cat(img_emb_list,dim=0)
+
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds,
+                            negative_prompt_embeds,
+                            pooled_prompt_embeds,
+                            negative_pooled_prompt_embeds,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=True,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+                    prompt = sample["caption_cloth"]
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * num_prompts
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * num_prompts
+
+
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds_c,
+                            _,
+                            _,
+                            _,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=False,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+
+                    generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
+                    images = pipe(
+                        prompt_embeds=prompt_embeds,
+                        negative_prompt_embeds=negative_prompt_embeds,
+                        pooled_prompt_embeds=pooled_prompt_embeds,
+                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                        num_inference_steps=args.num_inference_steps,
+                        generator=generator,
+                        strength = 1.0,
+                        pose_img = sample['pose_img'],
+                        text_embeds_cloth=prompt_embeds_c,
+                        cloth = sample["cloth_pure"].to(accelerator.device),
+                        mask_image=sample['inpaint_mask'],
+                        image=(sample['image']+1.0)/2.0,
+                        height=args.height,
+                        width=args.width,
+                        guidance_scale=args.guidance_scale,
+                        ip_adapter_image = image_embeds,
+                    )[0]
+
+
+                    for i in range(len(images)):
+                        x_sample = pil_to_tensor(images[i])
+                        torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i]))
+
+
+
+
+if __name__ == "__main__":
+    main()
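
For reference, the label_map defined at the top of inference_dc.py encodes the human-parsing classes that get_agnostic combines into its masks. The toy example below is illustrative only and not part of the commit; it shows the selection pattern on a small hand-made array, whereas a real label map would be loaded from label_maps/*_4.png as in __getitem__.

import numpy as np

# Subset of the label_map values used by get_agnostic above.
label_map = {"background": 0, "upper_clothes": 4, "skirt": 5, "pants": 6, "dress": 7}

# Toy 3x3 parse array standing in for a real label map image.
parse_array = np.array([
    [0, 4, 4],
    [6, 6, 0],
    [7, 0, 5],
])

# Union of the classes forming the lower-body garment region, mirroring the
# per-class (parse_array == ...) comparisons summed inside get_agnostic.
lower_body = (parse_array == label_map["skirt"]) | (parse_array == label_map["pants"])
print(lower_body.astype(np.uint8))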
vitonhd_test_tagged.json ADDED
The diff for this file is too large to render. See raw diff