acanivet's picture
Formatting
e72f4c2
raw
history blame
2.02 kB
from typing import Optional, Sequence

import streamlit as st
import torch
from lightning import LightningModule

from cvae import CVAE
def format_instruments(text: str) -> str:
    """Convert a display label into a lowercase, underscore-joined key.

    The first space-separated token of ``text`` is discarded (presumably an
    icon/emoji prefix on the UI label -- confirm against the caller). The
    remaining tokens are lowercased and joined with ``_``.

    Empty tokens, which appear when the label contains repeated or trailing
    spaces, are dropped so they cannot introduce stray underscores into the
    key.

    Args:
        text: Label such as ``"X Synth Lead"``.

    Returns:
        The normalised key, e.g. ``"synth_lead"``. An empty string is
        returned when nothing follows the first token.
    """
    # Splitting on a single space already removes every space from the
    # tokens, so no further space stripping is needed on each stem.
    stems = text.split(" ")[1:]
    return "_".join(stem.lower() for stem in stems if stem)
def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    """Map the selected labels to the index of the matching instrument.

    Each label is normalised with ``format_instruments`` and the results are
    joined with ``_`` to form a key, which is then looked up in the
    module-level ``instruments`` list.

    Args:
        choice: Labels making up one selection, in the order they appear in
            the instrument key.

    Returns:
        A zero-dimensional tensor holding the position of the key in
        ``instruments``.

    Raises:
        ValueError: If the assembled key is not present in ``instruments``.
    """
    key = "_".join(map(format_instruments, choice))
    class_index = instruments.index(key)
    return torch.tensor(class_index)
@st.cache_resource
def load_model(device: str) -> LightningModule:
    """Load the CVAE checkpoint and place the model on ``device``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded model
    instead of re-reading the checkpoint on every script rerun.

    Args:
        device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        The restored ``CVAE`` module, moved to ``device``.
    """
    return CVAE.load_from_checkpoint(
        "epoch=77-step=2819778.ckpt",
        # Remap the checkpoint's tensors onto the target device while
        # loading, so a checkpoint written from a GPU run can still be
        # opened on a CPU-only host.
        map_location=device,
        io_channels=1,
        # presumably 4 s of audio at 16 kHz -- confirm against training code
        io_features=16000 * 4,
        latent_features=5,
        channels=[32, 64, 128, 256, 512],
        # One class per entry of the module-level ``instruments`` list.
        num_classes=len(instruments),
        learning_rate=1e-5,
    ).to(device)
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Instrument class names. A name's position in this list is the class index
# that ``choice_to_tensor`` hands to the model, and the list length is used as
# ``num_classes`` in ``load_model``, so the order must stay exactly as is:
# every family for "acoustic", then for "synthetic", then for "electronic".
instruments = [
    f"{family}_{source}"
    for source in ("acoustic", "synthetic", "electronic")
    for family in (
        "bass",
        "brass",
        "flute",
        "guitar",
        "keyboard",
        "mallet",
        "organ",
        "reed",
        "string",
        "synth_lead",
        "vocal",
    )
]
# Load the model at import time; ``generate`` reads this module-level object.
model = load_model(device)
def generate(choice: Sequence[str], params: Optional[Sequence[float]] = None):
    """Sample one output from the model for the selected instrument.

    Args:
        choice: Labels identifying the instrument; converted to a class
            index by ``choice_to_tensor``.
        params: Optional latent values used as the noise vector. When
            ``None`` or empty, a random vector of 5 values is drawn instead
            (5 matches the ``latent_features=5`` passed in ``load_model``).

    Returns:
        The first item of the batch returned by ``model.sample``, moved to
        the CPU and converted to a NumPy array.

    Raises:
        ValueError: If ``choice`` does not correspond to a known instrument
            (raised by ``choice_to_tensor``).
    """
    # Explicit None/length check rather than truthiness: array-like inputs
    # raise on ``bool()``, while an empty sequence still falls back to
    # random noise as before.
    if params is not None and len(params) > 0:
        # Force float32 so integer inputs do not become an int64 tensor,
        # which would differ from the float32 noise of the random branch.
        noise = torch.tensor(params, dtype=torch.float32).unsqueeze(0)
    else:
        noise = torch.randn(1, 5)
    noise = noise.to(device)
    label = choice_to_tensor(choice).to(device)
    return model.sample(eps=noise, c=label).cpu().numpy()[0]