|
from typing import Optional, Sequence

import streamlit as st
import torch
from lightning import LightningModule

from cvae import CVAE
|
|
|
|
|
def format_instruments(text: str) -> str:
    """Turn a display label into a lower-case, underscore-joined key.

    The label is split on single spaces and the first token is discarded
    (presumably an icon/emoji prefix on the UI label -- confirm against the
    caller). The remaining tokens are lower-cased and joined with ``"_"``.

    Empty tokens produced by repeated or trailing spaces are dropped, so
    ``"x Synth  Lead "`` yields ``"synth_lead"`` rather than
    ``"synth__lead_"``.

    Args:
        text: Space-separated label whose first token is not part of the key.

    Returns:
        The normalised key, or ``""`` when nothing follows the first token.
    """
    # Tokens cannot contain spaces after split(" "), so no further space
    # stripping is needed; only empty tokens have to be filtered out.
    stems = [stem.lower() for stem in text.split(" ")[1:] if stem]
    return "_".join(stems)
|
|
|
|
|
def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    """Map a sequence of display labels to its class index as a tensor.

    Each label is normalised with ``format_instruments``; the results are
    joined with ``"_"`` and looked up in the module-level ``instruments``
    list.

    Args:
        choice: Display labels that together name one entry of
            ``instruments``.

    Returns:
        A zero-dimensional tensor holding the position of the combined label
        in ``instruments``.

    Raises:
        ValueError: If the combined label is not present in ``instruments``.
    """
    label = "_".join(format_instruments(part) for part in choice)
    class_index = instruments.index(label)
    return torch.tensor(class_index)
|
|
|
|
|
@st.cache_resource
def load_model(device: str) -> LightningModule:
    """Restore the CVAE from its checkpoint file and move it to ``device``.

    The function is wrapped in ``st.cache_resource``; see the Streamlit
    documentation for the caching semantics.

    Args:
        device: Target device string handed to ``.to()``.

    Returns:
        The model loaded from ``epoch=77-step=2819778.ckpt``, placed on
        ``device``.
    """
    # Constructor arguments forwarded to CVAE.load_from_checkpoint.
    hparams = dict(
        io_channels=1,
        io_features=16000 * 4,  # presumably 4 s of 16 kHz audio -- TODO confirm
        latent_features=5,
        channels=[32, 64, 128, 256, 512],
        num_classes=len(instruments),  # one class per entry of `instruments`
        learning_rate=1e-5,
    )
    restored = CVAE.load_from_checkpoint("epoch=77-step=2819778.ckpt", **hparams)
    return restored.to(device)
|
|
|
|
|
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
# Class labels are every "<family>_<source>" combination. The position of a
# label in `instruments` is what `choice_to_tensor` passes to the model, so
# the ordering (all families for one source, then the next source) must not
# change -- presumably it mirrors the label order used in training.
_FAMILIES = (
    "bass",
    "brass",
    "flute",
    "guitar",
    "keyboard",
    "mallet",
    "organ",
    "reed",
    "string",
    "synth_lead",
    "vocal",
)
_SOURCES = ("acoustic", "synthetic", "electronic")

instruments = [f"{family}_{source}" for source in _SOURCES for family in _FAMILIES]
|
|
|
|
|
# Module-level model instance used by `generate`. Must stay below the
# definitions of `device` and `instruments`, which `load_model` reads when it
# runs. `load_model` is decorated with `st.cache_resource`.
model = load_model(device)
|
|
|
|
|
def generate(choice: Sequence[str], params: Optional[Sequence[float]] = None):
    """Sample one output from the module-level model for the chosen class.

    Args:
        choice: Display labels identifying the class; converted to a class
            index tensor by ``choice_to_tensor``.
        params: Optional latent values passed to the model as ``eps``. When
            ``None`` or empty, a random ``(1, 5)`` standard-normal tensor is
            drawn instead.

    Returns:
        The first element of the batch returned by ``model.sample``, moved to
        the CPU and converted to a NumPy array.

    Raises:
        ValueError: If ``choice`` does not name an entry of ``instruments``
            (propagated from ``choice_to_tensor``).
    """
    if params is None or len(params) == 0:
        # 5 matches the latent_features value given to the model in load_model.
        noise = torch.randn(1, 5)
    else:
        # Force float32: integer inputs would otherwise become an int64
        # tensor, unlike the float32 noise of the random branch.
        noise = torch.tensor(params, dtype=torch.float32).unsqueeze(0)
    noise = noise.to(device)

    condition = choice_to_tensor(choice).to(device)
    return model.sample(eps=noise, c=condition).cpu().numpy()[0]
|
|