parokshsaxena commited on
Commit
2e6f7e8
β€’
1 Parent(s): 4aff89b

reverting enhanced garment net

Browse files
Files changed (1) hide show
  1. src/tryon_pipeline.py +6 -5
src/tryon_pipeline.py CHANGED
@@ -400,7 +400,7 @@ class StableDiffusionXLInpaintPipeline(
400
  force_zeros_for_empty_prompt: bool = True,
401
  ):
402
  super().__init__()
403
- self.garment_net = EnhancedGarmentNetWithTimestep()
404
 
405
  self.register_modules(
406
  vae=vae,
@@ -1783,11 +1783,12 @@ class StableDiffusionXLInpaintPipeline(
1783
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1784
  if ip_adapter_image is not None:
1785
  added_cond_kwargs["image_embeds"] = image_embeds
 
1786
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1787
- # down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1788
- garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
1789
- # print(type(reference_features))
1790
- # print(reference_features)
1791
  reference_features = list(reference_features)
1792
  # print(len(reference_features))
1793
  # for elem in reference_features:
 
400
  force_zeros_for_empty_prompt: bool = True,
401
  ):
402
  super().__init__()
403
+ # self.garment_net = EnhancedGarmentNetWithTimestep()
404
 
405
  self.register_modules(
406
  vae=vae,
 
1783
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1784
  if ip_adapter_image is not None:
1785
  added_cond_kwargs["image_embeds"] = image_embeds
1786
+ print("Calling unet encoder for garment feature extraction")
1787
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1788
+ down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1789
+ #garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
1790
+ print(type(reference_features))
1791
+ print(reference_features)
1792
  reference_features = list(reference_features)
1793
  # print(len(reference_features))
1794
  # for elem in reference_features: