Add details on fine-tuning
For this model to really be useful, it should be possible to locally fine-tune it, e.g. to produce different genres of music that were underrepresented in the dataset.
Seems like there are instructions already: https://github.com/Stability-AI/stable-audio-tools?tab=readme-ov-file#fine-tuning
That said, these are only for full fine-tunes. Being able to train and share something like LoRAs would be dope.
Hello @SkyyySi (and everyone else)
When I try to fine-tune the model, it throws a ValueError:

```
ValueError: Conditioner key prompt not found in batch metadata
```
I use `--pretrained-ckpt-path` with an unwrapped model.
EDIT: My fault, I had an invalid config file for my dataset. Sorry
@ImusingX Please share what you did, as I am stuck on that as well.
My dataset config had `custom_metadata_module` at the root of the JSON file, but it has to go inside the dataset entry in the `datasets` list.
As in, I had:
"dataset_type": "audio_dir",
"datasets": [
{
...
}
],
"custom_metadata_module": "/mnt/e/tts/ok/stable-audio-tools-main/mydata/metadata.py" <--- here
"random_crop": true
}
but it had to be:

```
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            ...
            "custom_metadata_module": "/mnt/e/tts/ok/stable-audio-tools-main/mydata/metadata.py"  <--- should be here
        }
    ],
    "random_crop": true
}
```
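For reference, the module that `custom_metadata_module` points to has to supply the missing conditioner keys (here `prompt`). A minimal sketch, assuming the expected entry point is a `get_custom_metadata(info, audio)` function that returns a metadata dict per file:

```python
# mydata/metadata.py -- minimal sketch; assumes stable-audio-tools calls
# get_custom_metadata(info, audio) per file and merges the returned dict into
# that file's metadata, making the "prompt" conditioner key available.
def get_custom_metadata(info, audio):
    # Derive a text prompt from the file path; replace this with your own captions.
    return {"prompt": info["relpath"]}
```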
@ImusingX Thank you, was able to get it running by adding a custom metadata module (I tried to just place .json files next to each audio file, but that didn't seem to work).
However, as it turns out, 24 GB of VRAM doesn't seem to be enough for a full fine-tune. I tried it on Arch Linux with the latest PyTorch-for-ROCm nightly build on an RX 7900 XTX, but I always get an out-of-memory error, even after stopping every other graphics process.
Unless someone develops a LoRA / embedding trainer for this, training on consumer hardware seems to be out of the question for now.
There's no official LoRA training in stable-audio-tools yet, but there is a community-made repo for training LoRAs on stable-audio-tools models. I haven't tried it myself, but I've heard of others getting it to work: https://github.com/NeuralNotW0rk/LoRAW
When fine-tuning I always just get noise in the demo audio files. I used these configs as a base with a pre-trained laion_clap checkpoint I found. But no joy.
> When fine-tuning I always just get noise in the demo audio files. I used these configs as a base with a pre-trained laion_clap checkpoint I found. But no joy.
Use LoRAW (https://github.com/NeuralNotW0rk/LoRAW); that works very well for me.
@benbowler did you have any success? Same here: the training process runs, but the demos are just silence.
I have made a LoRA using LoRAW and merged it into the checkpoint so it can be used without LoRAW. So it technically is possible if you merge the LoRA (I don't think LoRAW has a built-in merging feature, so I wrote that part myself).
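For context, merging is conceptually just folding the low-rank update back into each adapted weight (W ← W + α·B·A). A rough sketch, with illustrative key names rather than LoRAW's actual naming:

```python
def merge_lora_into_state_dict(base_sd, lora_sd, alpha=1.0):
    # Fold low-rank updates (W += alpha * B @ A) into the base weights.
    # Assumes LoRA tensors are stored as "<layer>.lora_down.weight" (A) and
    # "<layer>.lora_up.weight" (B) -- the real naming depends on the trainer.
    merged = dict(base_sd)
    for key, down in lora_sd.items():
        if not key.endswith(".lora_down.weight"):
            continue
        layer = key[: -len(".lora_down.weight")]
        up = lora_sd[layer + ".lora_up.weight"]
        target = layer + ".weight"
        if target in merged:
            merged[target] = merged[target] + alpha * (up @ down)
    return merged

# Usage (illustrative):
# from safetensors.torch import load_file, save_file
# merged = merge_lora_into_state_dict(load_file("model.safetensors"), load_file("lora.safetensors"))
# save_file(merged, "model_merged.safetensors")
```

The merged state dict can then be saved as a normal unwrapped checkpoint and used without any LoRA code at inference time.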
@benbowler
@cvillela
How did you load their checkpoint? Upon some investigation, I found that their checkpoint is not properly loaded by the given instructions, since it's wrapped slightly differently. If you pass the provided `model.safetensors` to `--pretrained-ckpt-path`, it throws no error, but none of the provided weights are actually loaded! If you inspect the `copy_state_dict` function in `training/utils.py` (below), every key is skipped because the names don't match.
```python
def copy_state_dict(model, state_dict):
    """Load state_dict to model, but only for keys that match exactly.

    Args:
        model (nn.Module): model to load state_dict.
        state_dict (OrderedDict): state_dict to load.
    """
    model_state_dict = model.state_dict()
    for key in state_dict:
        if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
            if isinstance(state_dict[key], torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                state_dict[key] = state_dict[key].data
            model_state_dict[key] = state_dict[key]
            print(f"Loaded {key} from checkpoint")
        else:
            print(f"Skipped {key} from checkpoint")
    model.load_state_dict(model_state_dict, strict=False)
```
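A quick, purely illustrative way to confirm this is to count how many checkpoint keys actually line up with the model before loading:

```python
# Illustrative sanity check inside train.py: how many checkpoint keys match the model?
ckpt_sd = load_ckpt_state_dict(args.pretrained_ckpt_path)
model_sd = model.state_dict()
matched = [k for k in ckpt_sd if k in model_sd and ckpt_sd[k].shape == model_sd[k].shape]
print(f"{len(matched)} / {len(ckpt_sd)} checkpoint keys match the model")
```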
The solution is simply to replace `copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path))` with `model.load_state_dict(load_ckpt_state_dict(args.pretrained_ckpt_path))`, which is how the inference code (which generates properly) loads the weights. I don't know why they do the copying in the first place. For the config you'll also need to use the provided `model_config.json` instead of `txt2audio/stable_audio_2_0.json`. With this, I get meaningful output from the first demo check, which means the checkpoint is properly loaded, and the fine-tuning is able to run. Hope this is helpful.