Spaces:
Running
on
Zero
Running
on
Zero
parokshsaxena
commited on
Commit
β’
b0054c9
1
Parent(s):
034254b
setting same dtype in enhanced garment net
Browse files- src/enhanced_garment_net.py +5 -0
- src/tryon_pipeline.py +4 -1
src/enhanced_garment_net.py
CHANGED
@@ -100,6 +100,11 @@ class EnhancedGarmentNetWithTimestep(nn.Module):
|
|
100 |
])
|
101 |
|
102 |
def forward(self, x, t, text_embeds):
|
|
|
|
|
|
|
|
|
|
|
103 |
# Get garment features
|
104 |
garment_out, garment_features = self.garment_net(x)
|
105 |
|
|
|
100 |
])
|
101 |
|
102 |
def forward(self, x, t, text_embeds):
|
103 |
+
# Ensure all inputs are of the same dtype
|
104 |
+
x = x.to(dtype=self.garment_net.initial[0].weight.dtype)
|
105 |
+
t = t.to(dtype=self.garment_net.initial[0].weight.dtype)
|
106 |
+
text_embeds = text_embeds.to(dtype=self.garment_net.initial[0].weight.dtype)
|
107 |
+
|
108 |
# Get garment features
|
109 |
garment_out, garment_features = self.garment_net(x)
|
110 |
|
src/tryon_pipeline.py
CHANGED
@@ -401,7 +401,10 @@ class StableDiffusionXLInpaintPipeline(
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
-
self.garment_net = EnhancedGarmentNetWithTimestep()
|
|
|
|
|
|
|
405 |
|
406 |
self.register_modules(
|
407 |
vae=vae,
|
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
+
#self.garment_net = EnhancedGarmentNetWithTimestep()
|
405 |
+
self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
|
406 |
+
|
407 |
+
|
408 |
|
409 |
self.register_modules(
|
410 |
vae=vae,
|