Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Utility to export a training checkpoint to a lightweight release checkpoint. | |
""" | |
from pathlib import Path | |
import typing as tp | |
from omegaconf import OmegaConf | |
import torch | |
from audiocraft import __version__ | |
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): | |
"""Export only the best state from the given EnCodec checkpoint. This | |
should be used if you trained your own EnCodec model. | |
""" | |
pkg = torch.load(checkpoint_path, 'cpu') | |
new_pkg = { | |
'best_state': pkg['best_state']['model'], | |
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), | |
'version': __version__, | |
'exported': True, | |
} | |
Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
torch.save(new_pkg, out_file) | |
return out_file | |
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): | |
"""Export a compression model (potentially EnCodec) from a pretrained model. | |
This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. | |
Do not include the //pretrained/ prefix. For instance if you trained a model | |
with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. | |
In that case, this will not actually include a copy of the model, simply the reference | |
to the model used. | |
""" | |
if Path(pretrained_encodec).exists(): | |
pkg = torch.load(pretrained_encodec) | |
assert 'best_state' in pkg | |
assert 'xp.cfg' in pkg | |
assert 'version' in pkg | |
assert 'exported' in pkg | |
else: | |
pkg = { | |
'pretrained': pretrained_encodec, | |
'exported': True, | |
'version': __version__, | |
} | |
Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
torch.save(pkg, out_file) | |
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): | |
"""Export only the best state from the given MusicGen or AudioGen checkpoint. | |
""" | |
pkg = torch.load(checkpoint_path, 'cpu') | |
if pkg['fsdp_best_state']: | |
best_state = pkg['fsdp_best_state']['model'] | |
else: | |
assert pkg['best_state'] | |
best_state = pkg['best_state']['model'] | |
new_pkg = { | |
'best_state': best_state, | |
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), | |
'version': __version__, | |
'exported': True, | |
} | |
Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
torch.save(new_pkg, out_file) | |
return out_file | |