L0SG commited on
Commit
36de23b
1 Parent(s): c8638e0
Files changed (1) hide show
  1. bigvgan.py +19 -9
bigvgan.py CHANGED
@@ -257,14 +257,18 @@ class BigVGAN(
257
  return x
258
 
259
  def remove_weight_norm(self):
260
- print('Removing weight norm...')
261
- for l in self.ups:
262
- for l_i in l:
263
- remove_weight_norm(l_i)
264
- for l in self.resblocks:
265
- l.remove_weight_norm()
266
- remove_weight_norm(self.conv_pre)
267
- remove_weight_norm(self.conv_post)
 
 
 
 
268
 
269
  ##################################################################
270
  # additional methods for huggingface_hub support
@@ -351,6 +355,12 @@ class BigVGAN(
351
  )
352
 
353
  checkpoint_dict = torch.load(model_file, map_location=map_location)
354
- model.load_state_dict(checkpoint_dict['generator'])
 
 
 
 
 
 
355
 
356
  return model
 
257
  return x
258
 
259
  def remove_weight_norm(self):
260
+ try:
261
+ print('Removing weight norm...')
262
+ for l in self.ups:
263
+ for l_i in l:
264
+ remove_weight_norm(l_i)
265
+ for l in self.resblocks:
266
+ l.remove_weight_norm()
267
+ remove_weight_norm(self.conv_pre)
268
+ remove_weight_norm(self.conv_post)
269
+ except ValueError:
270
+ print('[INFO] Model already removed weight norm. Skipping!')
271
+ pass
272
 
273
  ##################################################################
274
  # additional methods for huggingface_hub support
 
355
  )
356
 
357
  checkpoint_dict = torch.load(model_file, map_location=map_location)
358
+
359
+ try:
360
+ model.load_state_dict(checkpoint_dict['generator'])
361
+ except RuntimeError:
362
+ print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!")
363
+ model.remove_weight_norm()
364
+ model.load_state_dict(checkpoint_dict['generator'])
365
 
366
  return model