Bai-YT commited on
Commit
ddaa2a5
1 Parent(s): 83cac0f

Update consistencytta.py

Browse files
Files changed (1) hide show
  1. consistencytta.py +10 -2
consistencytta.py CHANGED
@@ -75,9 +75,17 @@ class ConsistencyTTA(nn.Module):
75
  [self.text_encoder, self.vae, self.fn_STFT, self.unet],
76
  ['text_encoder', 'vae', 'fn_STFT', 'unet']
77
  ):
78
- assert model.training == False, f"The {name} is not in eval mode."
 
 
 
 
79
  for param in model.parameters():
80
- assert param.requires_grad == False, f"The {name} is not frozen."
 
 
 
 
81
 
82
 
83
  @torch.no_grad()
 
75
  [self.text_encoder, self.vae, self.fn_STFT, self.unet],
76
  ['text_encoder', 'vae', 'fn_STFT', 'unet']
77
  ):
78
+ try:
79
+ assert model.training == False, f"The {name} is not in eval mode."
80
+ except:
81
+ model.eval()
82
+ assert model.training == False, f"The {name} is not in eval mode."
83
  for param in model.parameters():
84
+ try:
85
+ assert param.requires_grad == False, f"The {name} is not frozen."
86
+ except:
87
+ param.requires_grad_(False)
88
+ assert param.requires_grad == False, f"The {name} is not frozen."
89
 
90
 
91
  @torch.no_grad()