parokshsaxena commited on
Commit
b0054c9
β€’
1 Parent(s): 034254b

setting same dtype in enhanced garment net

Browse files
src/enhanced_garment_net.py CHANGED
@@ -100,6 +100,11 @@ class EnhancedGarmentNetWithTimestep(nn.Module):
100
  ])
101
 
102
  def forward(self, x, t, text_embeds):
 
 
 
 
 
103
  # Get garment features
104
  garment_out, garment_features = self.garment_net(x)
105
 
 
100
  ])
101
 
102
  def forward(self, x, t, text_embeds):
103
+ # Ensure all inputs are of the same dtype
104
+ x = x.to(dtype=self.garment_net.initial[0].weight.dtype)
105
+ t = t.to(dtype=self.garment_net.initial[0].weight.dtype)
106
+ text_embeds = text_embeds.to(dtype=self.garment_net.initial[0].weight.dtype)
107
+
108
  # Get garment features
109
  garment_out, garment_features = self.garment_net(x)
110
 
src/tryon_pipeline.py CHANGED
@@ -401,7 +401,10 @@ class StableDiffusionXLInpaintPipeline(
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
- self.garment_net = EnhancedGarmentNetWithTimestep()
 
 
 
405
 
406
  self.register_modules(
407
  vae=vae,
 
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
+ #self.garment_net = EnhancedGarmentNetWithTimestep()
405
+ self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
406
+
407
+
408
 
409
  self.register_modules(
410
  vae=vae,