Spaces:
Runtime error
Runtime error
remove deprecated averaged_model
Browse files- audiodiffusion/__init__.py +1 -1
- scripts/train_unet.py +4 -2
audiodiffusion/__init__.py
CHANGED
@@ -9,7 +9,7 @@ from tqdm.auto import tqdm
|
|
9 |
# from diffusers import AudioDiffusionPipeline
|
10 |
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
11 |
|
12 |
-
VERSION = "1.5.
|
13 |
|
14 |
|
15 |
class AudioDiffusion:
|
|
|
9 |
# from diffusers import AudioDiffusionPipeline
|
10 |
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
11 |
|
12 |
+
VERSION = "1.5.1"
|
13 |
|
14 |
|
15 |
class AudioDiffusion:
|
scripts/train_unet.py
CHANGED
@@ -304,10 +304,12 @@ def main(args):
|
|
304 |
if ((epoch + 1) % args.save_model_epochs == 0
|
305 |
or (epoch + 1) % args.save_images_epochs == 0
|
306 |
or epoch == args.num_epochs - 1):
|
|
|
|
|
|
|
307 |
pipeline = AudioDiffusionPipeline(
|
308 |
vqvae=vqvae,
|
309 |
-
unet=
|
310 |
-
ema_model.averaged_model if args.use_ema else model),
|
311 |
mel=mel,
|
312 |
scheduler=noise_scheduler,
|
313 |
)
|
|
|
304 |
if ((epoch + 1) % args.save_model_epochs == 0
|
305 |
or (epoch + 1) % args.save_images_epochs == 0
|
306 |
or epoch == args.num_epochs - 1):
|
307 |
+
unet = accelerator.unwrap_model(model)
|
308 |
+
if args.use_ema:
|
309 |
+
ema_model.copy_to(unet.parameters())
|
310 |
pipeline = AudioDiffusionPipeline(
|
311 |
vqvae=vqvae,
|
312 |
+
unet=unet,
|
|
|
313 |
mel=mel,
|
314 |
scheduler=noise_scheduler,
|
315 |
)
|