Add details on fine-tuning

#11
by SkyyySi - opened

For this model to really be useful, it should be possible to locally fine-tune it, e.g. to produce different genres of music that were underrepresented in the dataset.

Seems like there are instructions already: https://github.com/Stability-AI/stable-audio-tools?tab=readme-ov-file#fine-tuning

That said, this is only for full fine-tunes. Being able to train and share LoRAs or something like that would be dope.

Hello @SkyyySi (and everyone else)

When I try to finetune the model, it throws a ValueError

ValueError: Conditioner key prompt not found in batch metadata

I use --pretrained-ckpt-path with an unwrapped model.

EDIT: My fault, I had an invalid config file for my dataset. Sorry

@ImusingX what did you use as your model-config along with the checkpoint?

@ImusingX Please share what you did, as I am stuck on that as well.


My dataset config had custom_metadata_module at the root of the JSON file, but it has to go inside the dataset entry under the datasets key.

As in, I had:

    "dataset_type": "audio_dir",
    "datasets": [
        {
           ...
        }
    ],
    "custom_metadata_module": "/mnt/e/tts/ok/stable-audio-tools-main/mydata/metadata.py"  <--- here
    "random_crop": true
}

but it had to be:

    "dataset_type": "audio_dir",
    "datasets": [
        {
           ...
            "custom_metadata_module": "/mnt/e/tts/ok/stable-audio-tools-main/mydata/metadata.py" <--- should be here
        }
    ],
    "random_crop": true
}

@ImusingX Thank you, I was able to get it running by adding a custom metadata module (I tried just placing .json files next to each audio file, but that didn't seem to work).
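
For anyone else hitting the "Conditioner key prompt not found in batch metadata" error: if I'm reading the stable-audio-tools README right, the custom metadata module is just a Python file exposing a get_custom_metadata(info, audio) function that returns the conditioning keys. A minimal sketch that supplies the "prompt" key (the info keys and the way the prompt is derived here are only placeholders) could look like this:

# metadata.py -- minimal custom metadata module (sketch, adapt to your own data)
import os

def get_custom_metadata(info, audio):
    # 'info' carries per-file metadata from the dataset loader; the exact keys
    # (e.g. "relpath") can differ between versions, so check what yours provides.
    filename = os.path.basename(info.get("relpath", ""))
    # Placeholder: derive the prompt from the file name. In practice you would
    # look captions up in your own metadata (CSV, sidecar JSON, a database, ...).
    prompt = os.path.splitext(filename)[0].replace("_", " ")
    return {"prompt": prompt}

Then point the "custom_metadata_module" field inside the dataset entry at this file, as shown above.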

However, as it turns out, 24 GB of VRAM does not seem to be enough for a full fine-tune. I tried it on Arch Linux with the latest PyTorch ROCm nightly build on an RX 7900 XTX, but I always get an out-of-memory error, even after stopping every other graphics process.

Unless someone develops a LoRA / embedding trainer for this, training on consumer hardware seems to be out of the question for now.

Stability AI org

No official LoRA training in stable-audio-tools yet, but there is a repo made by a community member to train LoRAs for stable-audio-tools models. I haven't tried it myself, but I've heard of others getting it to work: https://github.com/NeuralNotW0rk/LoRAW

When finetuning I always just get noise in the demo audio files. I used these configs as a base with a pre-trained laion_clap checkpoint I found. But no joy.


Use LoRAW ( https://github.com/NeuralNotW0rk/LoRAW ); that works very well for me.


@benbowler did you have any success? Same here: the training process runs, but the demos are just silence.

I have made a LoRA using LoRAW and merged it into the checkpoint so it can be used without LoRAW. So it is technically possible if you merge a LoRA (I don't think LoRAW has a built-in merging feature, so I wrote that part myself).
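
The merge itself is conceptually simple: for every adapted weight, fold the low-rank update back into the base tensor, W' = W + scale * (B @ A). A rough sketch of the idea (the factor naming and scale handling here are made up for illustration and do not match LoRAW's actual checkpoint layout):

import torch

def merge_lora_into_base(base_state_dict, lora_factors, scale=1.0):
    """Fold LoRA updates into base weights: W' = W + scale * (B @ A).

    base_state_dict: the base model's state_dict
    lora_factors: hypothetical mapping from a base weight name to its (A, B) factors
    """
    merged = dict(base_state_dict)
    for name, (A, B) in lora_factors.items():
        merged[name] = base_state_dict[name] + scale * (B @ A)
    return merged

The merged state_dict can then be saved and loaded like a normal fine-tuned checkpoint, with no LoRA code needed at inference time.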

@benbowler @cvillela How did you load their checkpoint? After some investigation I found that the checkpoint is not actually loaded when following the given instructions, because it is wrapped slightly differently. If you load the given model.safetensors via --pretrained-ckpt-path, no error is thrown, but none of the provided weights are actually loaded. If you inspect the copy_state_dict function in training/utils.py, shown below, everything is skipped because the key names differ.

# copy_state_dict as found in training/utils.py
import torch

def copy_state_dict(model, state_dict):
    """Load state_dict into model, but only for keys that match exactly.

    Args:
        model (nn.Module): model to load the state_dict into.
        state_dict (OrderedDict): state_dict to load.
    """
    model_state_dict = model.state_dict()
    for key in state_dict:
        # Keys whose name or shape doesn't match the model are silently skipped,
        # which is why a differently wrapped checkpoint "loads" without error but stays empty.
        if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
            if isinstance(state_dict[key], torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                state_dict[key] = state_dict[key].data
            model_state_dict[key] = state_dict[key]
            print(f"Loaded {key} from checkpoint")
        else:
            print(f"Skipped {key} from checkpoint")

    model.load_state_dict(model_state_dict, strict=False)

The solution is simply to replace copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path)) with model.load_state_dict(load_ckpt_state_dict(args.pretrained_ckpt_path)), which is how the inference code (which does generate properly) loads the weights. I don't know why the copying is done in the first place. For the config, you also need to use the provided model_config.json instead of txt2audio/stable_audio_2_0.json. With these changes I get meaningful output at the first demo check, which means the weights are properly loaded, and the fine-tuning runs. Hope this is helpful.
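
In other words, the one-line change in train.py looks roughly like this (the exact spot may differ between versions):

# old: silently skips every weight because the key names don't match
# copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path))

# new: load the state dict directly, the same way the inference code does
model.load_state_dict(load_ckpt_state_dict(args.pretrained_ckpt_path))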
