Formatting

Files changed:
- README.md  +1 -1
- app.py  +40 -17
- cvae/__init__.py  +1 -7
- cvae/blocks.py  +7 -7
- cvae/models.py  +84 -53
- model.py  +64 -17
README.md
CHANGED
@@ -13,7 +13,7 @@ pinned: false
 
 ![Screenshot of the app](app.png)
 
-This is a demo for the sound generation models built in `pytorch`. It relies on a simple `streamlit` app calling the model with the parameters given by the user.
+This is a demo for the sound generation models built in `pytorch`. It relies on a simple `streamlit` app calling the model with the parameters given by the user. Due to time and hardware constraints, the model isn't properly trained yet and cannot produce interesting sounds for now.
 
 ## Install :
 
app.py
CHANGED
@@ -4,42 +4,66 @@ import io
import numpy as np
from scipy.io.wavfile import write

# -----
# Utils
# -----


@st.cache_data
def np_to_wav(waveform: np.ndarray, sample_rate: int) -> bytes:
    bytes_wav = bytes()
    byte_io = io.BytesIO(bytes_wav)
    write(byte_io, sample_rate, waveform.T)
    return byte_io.read()


# ------------------
# App initialization
# ------------------

if "result" not in st.session_state:
    st.session_state["result"] = None

# ---
# App
# ---

st.title("Sound Exploration")

col1, col2 = st.columns(2)
with col1:
    instrument = st.selectbox(
        "Which instrument do you want?",
        (
            "🎸 Bass",
            "🎺 Brass",
            "🪈 Flute",
            "🪕 Guitar",
            "🎹 Keyboard",
            "🔨 Mallet",
            "🪗 Organ",
            "🎷 Reed",
            "🎻 String",
            "⚡ Synth lead",
            "🎤 Vocal",
        ),
    )
with col2:
    instrument_t = st.selectbox(
        "Which type of instrument do you want?",
        ("📯 Acoustic", "🎙️ Electronic", "🎛️ Synthetic"),
    )

with st.expander("Magical parameters 🪄"):
    col1, col2 = st.columns(2)
    with col1:
        p1 = st.slider("p1", 0.0, 1.0, step=0.001, label_visibility="collapsed")
        p2 = st.slider("p2", 0.0, 1.0, step=0.001, label_visibility="collapsed")
        p3 = st.slider("p3", 0.0, 1.0, step=0.001, label_visibility="collapsed")
    with col2:
        p4 = st.slider("p4", 0.0, 1.0, step=0.001, label_visibility="collapsed")
        p5 = st.slider("p5", 0.0, 1.0, step=0.001, label_visibility="collapsed")
    use_params = st.toggle("Use magical parameters?")
params = (p1, p2, p3, p4, p5) if use_params else None

if st.button("Generate ✨", type="primary"):

@@ -53,6 +77,5 @@ if st.session_state["result"] is not None:
    st.download_button(
        label="Download ⬇️",
        data=np_to_wav(st.session_state["result"], 16000),
        file_name="result.wav",
    )
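
For reference, the np_to_wav helper added above can be sanity-checked outside Streamlit. A minimal sketch (the 440 Hz sine and the round-trip assert are illustrative, not part of the app):

import io
import numpy as np
from scipy.io.wavfile import read, write

# Stand-in for a generated sound: one second of a 440 Hz sine at 16 kHz.
sample_rate = 16000
t = np.linspace(0, 1, sample_rate, endpoint=False)
waveform = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Same round trip as np_to_wav: scipy rewinds the in-memory buffer after writing,
# so read() returns the complete WAV file.
byte_io = io.BytesIO()
write(byte_io, sample_rate, waveform.T)
wav_bytes = byte_io.read()

rate, recovered = read(io.BytesIO(wav_bytes))
assert rate == sample_rate and recovered.shape == waveform.shape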
cvae/__init__.py
CHANGED
@@ -1,7 +1 @@
-from .models import (
-    Encoder, Decoder, VAE, CVAE
-)
-
-from .blocks import (
-    UpResConvBlock, DownResConvBlock
-)
+from .models import VAE, CVAE
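
With this change only the Lightning modules are re-exported at the package root; everything else now has to be imported from its submodule. A small illustration of what that means for downstream code:

from cvae import VAE, CVAE                                 # still available from the package root
from cvae.models import Encoder, Decoder                   # no longer re-exported, import from the submodule
from cvae.blocks import UpResConvBlock, DownResConvBlock   # same for the blocks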
cvae/blocks.py
CHANGED
@@ -1,7 +1,7 @@
-from torch import nn
+from torch import nn, Tensor
 
 class UpResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(UpResConvBlock, self).__init__()
 
         self.residual = nn.Sequential(
@@ -19,11 +19,11 @@ class UpResConvBlock(nn.Module):
             nn.GELU()
         )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
 
 class DownResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(DownResConvBlock, self).__init__()
 
         self.residual = nn.Conv1d(in_channels, out_channels, 1, 2, bias=False)
@@ -37,11 +37,11 @@ class DownResConvBlock(nn.Module):
             nn.GELU()
         )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
 
 class ResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(ResConvBlock, self).__init__()
 
         self.residual = nn.Identity() if in_channels == out_channels else nn.Conv1d(in_channels, out_channels, 1, bias=False)
@@ -55,5 +55,5 @@ class ResConvBlock(nn.Module):
             nn.GELU()
         )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
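
Only annotations change in this file, but the shape contract is worth keeping in mind when reading models.py below: each DownResConvBlock is expected to halve the temporal dimension (its residual is a stride-2 Conv1d) and each UpResConvBlock to double it, which is exactly what Encoder and Decoder assume when they compute n_features. A rough sketch of that behaviour (expected shapes inferred from those formulas, not verified against the main branches that are not shown in this diff):

import torch
from cvae.blocks import DownResConvBlock, UpResConvBlock

x = torch.randn(4, 32, 16000)        # (batch, channels, samples)

down = DownResConvBlock(32, 64, 1)   # same kernel_size=1 call as in Encoder
y = down(x)                          # expected: (4, 64, 8000)

up = UpResConvBlock(64, 32, 1)       # same call as in Decoder
z = up(y)                            # expected: (4, 32, 16000)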
cvae/models.py
CHANGED
@@ -4,164 +4,195 @@ from torch.optim import Optimizer
from .blocks import UpResConvBlock, DownResConvBlock
import lightning as L
from auraloss.freq import MultiResolutionSTFTLoss
from typing import Sequence


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        super(Encoder, self).__init__()

        assert (
            in_features % 2 ** len(channels) == 0
        ), f"in_features ({in_features}) must be a multiple of downscale factor ({2**len(channels)})"

        modules = [nn.Conv1d(in_channels, channels[0], 1), nn.GELU()]

        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules += [
                DownResConvBlock(in_channel, out_channel, 1),
            ]

        n_features = int(in_features * 0.5 ** len(channels))

        modules += [
            nn.Flatten(),
            nn.Linear(n_features * channels[-1], 2 * out_features),
        ]

        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar


class Decoder(nn.Module):
    def __init__(
        self,
        out_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        super(Decoder, self).__init__()

        n_features = int(out_features / 2 ** len(channels))

        modules = [
            nn.Linear(in_features, n_features * channels[0]),
            nn.Unflatten(-1, (channels[0], n_features)),
        ]

        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules += [
                UpResConvBlock(in_channel, out_channel, 1),
            ]

        modules += [nn.Conv1d(channels[-1], out_channels, 1), nn.GELU()]

        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        x = torch.tanh(self.net(x))
        return x


class VAE(L.LightningModule):
    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        learning_rate: float,
    ) -> None:
        super().__init__()

        self.encoder = Encoder(io_channels, io_features, latent_features, channels)

        channels.reverse()
        self.decoder = Decoder(io_channels, latent_features, io_features, channels)

        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, eps: Tensor = None) -> Tensor:
        if eps is None:
            eps = torch.rand((1, self.latent_features))
        return self.decoder(eps)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        audio_loss = self.audio_loss_func(x, x_hat)
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor) -> tuple[Tensor]:
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        x_hat, mean, logvar = self.forward(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer


class CVAE(L.LightningModule):
    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        num_classes: int,
        learning_rate: float,
    ):
        super().__init__()

        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)

        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)

        channels.reverse()
        self.decoder = Decoder(
            io_channels, latent_features + num_classes, io_features, channels
        )

        self.num_classes = num_classes
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, c, eps=None) -> Tensor:
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float().unsqueeze(0)
        if eps is None:
            eps = torch.rand((1, self.latent_features))
        z = torch.cat([eps, c], dim=1)
        return self.decoder(z)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        audio_loss = self.audio_loss_func(x, x_hat)
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor, c: Tensor) -> tuple[Tensor]:
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        c_embedding = self.class_embedder(c).unsqueeze(1)
        x_embedding = self.data_embedder(x)
        x = torch.cat([x_embedding, c_embedding], dim=1)
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        z = torch.cat([z, c], dim=1)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        x, c = batch
        x_hat, mean, logvar = self.forward(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
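
The training objective above is the standard VAE trade-off: a reconstruction term (here auraloss's multi-resolution STFT loss) plus the analytic KL divergence between the approximate posterior N(mean, exp(logvar)) and a standard normal prior, with the reparameterization trick keeping the sampling step differentiable. A minimal standalone sketch of just that machinery in plain torch, with made-up tensor names:

import torch

# Pretend encoder outputs for a batch of 8 latent vectors of size 5.
mean = torch.randn(8, 5, requires_grad=True)
logvar = torch.randn(8, 5, requires_grad=True)

# Reparameterization trick: z = mean + std * eps stays differentiable w.r.t. mean and logvar.
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mean + eps * std

# Analytic KL(q(z|x) || N(0, I)), the same formula as loss_function above.
kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

recon_proxy = z.pow(2).mean()   # stand-in for the STFT reconstruction term
loss = recon_proxy + kld
loss.backward()                 # gradients reach mean and logvar through both terms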
model.py
CHANGED
@@ -2,34 +2,81 @@ from cvae import CVAE
import torch
from typing import Sequence
import streamlit as st
from lightning import LightningModule


def format_instruments(text: str) -> str:
    stems = text.split(" ")[1:]
    stems = [stem.replace(" ", "").lower() for stem in stems]
    return "_".join(stems)


def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    choice = "_".join([format_instruments(i) for i in choice])
    return torch.tensor(instruments.index(choice))


@st.cache_resource
def load_model(device: str) -> LightningModule:
    return CVAE.load_from_checkpoint(
        "epoch=77-step=2819778.ckpt",
        io_channels=1,
        io_features=16000 * 4,
        latent_features=5,
        channels=[32, 64, 128, 256, 512],
        num_classes=len(instruments),
        learning_rate=1e-5,
    ).to(device)


device = "cuda" if torch.cuda.is_available() else "cpu"

instruments = [
    "bass_acoustic",
    "brass_acoustic",
    "flute_acoustic",
    "guitar_acoustic",
    "keyboard_acoustic",
    "mallet_acoustic",
    "organ_acoustic",
    "reed_acoustic",
    "string_acoustic",
    "synth_lead_acoustic",
    "vocal_acoustic",
    "bass_synthetic",
    "brass_synthetic",
    "flute_synthetic",
    "guitar_synthetic",
    "keyboard_synthetic",
    "mallet_synthetic",
    "organ_synthetic",
    "reed_synthetic",
    "string_synthetic",
    "synth_lead_synthetic",
    "vocal_synthetic",
    "bass_electronic",
    "brass_electronic",
    "flute_electronic",
    "guitar_electronic",
    "keyboard_electronic",
    "mallet_electronic",
    "organ_electronic",
    "reed_electronic",
    "string_electronic",
    "synth_lead_electronic",
    "vocal_electronic",
]


model = load_model(device)


def generate(choice: Sequence[str], params: Sequence[int] = None):
    noise = (
        torch.tensor(params).unsqueeze(0).to(device)
        if params
        else torch.randn(1, 5).to(device)
    )
    return (
        model.sample(eps=noise, c=choice_to_tensor(choice).to(device)).cpu().numpy()[0]
    )
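
Taken together, the app-side call chain is: the two selectbox strings are mapped to a class index via choice_to_tensor, generate() turns the five sliders (or random noise) into a latent vector and decodes it, and np_to_wav converts the resulting waveform into bytes for the download button. A hypothetical end-to-end call, assuming the checkpoint referenced in load_model is available:

waveform = generate(("🎸 Bass", "📯 Acoustic"))      # random latent, class "bass_acoustic"
waveform = generate(
    ("🎹 Keyboard", "🎛️ Synthetic"),                 # class "keyboard_synthetic"
    params=(0.1, 0.2, 0.3, 0.4, 0.5),                # latent driven by the sliders instead
)
wav_bytes = np_to_wav(waveform, 16000)               # helper from app.py, 16 kHz mono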