acanivet committed
Commit e72f4c2
1 Parent(s): 2237ddd

Formatting

Files changed (6)
  1. README.md +1 -1
  2. app.py +40 -17
  3. cvae/__init__.py +1 -7
  4. cvae/blocks.py +7 -7
  5. cvae/models.py +84 -53
  6. model.py +64 -17
README.md CHANGED
@@ -13,7 +13,7 @@ pinned: false
 
 ![Screenshot of the app](app.png)
 
-This is a demo for the sound generation models built in `pytorch`. It relies on a simple `streamlit` app calling the model with the parameters given by the user.
+This is a demo for the sound generation models built in `pytorch`. It relies on a simple `streamlit` app calling the model with the parameters given by the user. Due to time and material difficulties, the model isn't properly trained and isn't able to produce interesting sounds now.
 
 ## Install :
 
app.py CHANGED
@@ -4,42 +4,66 @@ import io
 import numpy as np
 from scipy.io.wavfile import write
 
+# -----
+# Utils
+# -----
+
+
 @st.cache_data
-def np_to_wav(waveform, sample_rate) -> bytes:
+def np_to_wav(waveform: np.Array, sample_rate: int) -> bytes:
     bytes_wav = bytes()
     byte_io = io.BytesIO(bytes_wav)
     write(byte_io, sample_rate, waveform.T)
     return byte_io.read()
 
+
+# ------------------
+# App initialization
+# ------------------
+
 if "result" not in st.session_state:
     st.session_state["result"] = None
 
+# ---
+# App
+# ---
+
 st.title("Sound Exploration")
 
 col1, col2 = st.columns(2)
-
 with col1:
     instrument = st.selectbox(
-        'Which intrument do you want ?',
-        ('🎸 Bass', '🎺 Brass', '🪈 Flute', '🪕 Guitar', '🎹 Keyboard', '🔨 Mallet', '🪗 Organ', '🎷 Reed', '🎻 String', '⚡ Synth lead', '🎤 Vocal')
+        "Which intrument do you want ?",
+        (
+            "🎸 Bass",
+            "🎺 Brass",
+            "🪈 Flute",
+            "🪕 Guitar",
+            "🎹 Keyboard",
+            "🔨 Mallet",
+            "🪗 Organ",
+            "🎷 Reed",
+            "🎻 String",
+            "⚡ Synth lead",
+            "🎤 Vocal",
+        ),
    )
-
 with col2:
     instrument_t = st.selectbox(
-        'Which type intrument do you want ?',
-        ('📯 Acoustic', '🎙️ Electronic', '🎛️ Synthetic')
-    )
-
+        "Which type intrument do you want ?",
+        ("📯 Acoustic", "🎙️ Electronic", "🎛️ Synthetic"),
+    )
+
 with st.expander("Magical parameters 🪄"):
     col1, col2 = st.columns(2)
     with col1:
-        p1 = st.slider('p1', 0., 1., step=0.001, label_visibility='collapsed')
-        p2 = st.slider('p2', 0., 1., step=0.001, label_visibility='collapsed')
-        p3 = st.slider('p3', 0., 1., step=0.001, label_visibility='collapsed')
+        p1 = st.slider("p1", 0.0, 1.0, step=0.001, label_visibility="collapsed")
+        p2 = st.slider("p2", 0.0, 1.0, step=0.001, label_visibility="collapsed")
+        p3 = st.slider("p3", 0.0, 1.0, step=0.001, label_visibility="collapsed")
     with col2:
-        p4 = st.slider('p4', 0., 1., step=0.001, label_visibility='collapsed')
-        p5 = st.slider('p5', 0., 1., step=0.001, label_visibility='collapsed')
-    use_params = st.toggle('Use magical parameters ?')
+        p4 = st.slider("p4", 0.0, 1.0, step=0.001, label_visibility="collapsed")
+        p5 = st.slider("p5", 0.0, 1.0, step=0.001, label_visibility="collapsed")
+    use_params = st.toggle("Use magical parameters ?")
 params = (p1, p2, p3, p4, p5) if use_params else None
 
 if st.button("Generate ✨", type="primary"):
@@ -53,6 +77,5 @@ if st.session_state["result"] is not None:
     st.download_button(
         label="Download ⬇️",
         data=np_to_wav(st.session_state["result"], 16000),
-        file_name='result.wav',
+        file_name="result.wav",
     )
-
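A standalone sketch of the `np_to_wav` helper above, runnable outside Streamlit. The 440 Hz test tone and 16 kHz rate are illustrative only, and the annotation is written as `np.ndarray` here because numpy exposes no `np.Array` type like the one used in the diff:

```python
import io

import numpy as np
from scipy.io.wavfile import write


def np_to_wav(waveform: np.ndarray, sample_rate: int) -> bytes:
    # Render the waveform into an in-memory WAV container; scipy seeks
    # file-like targets back to the start, so read() returns the full payload.
    byte_io = io.BytesIO()
    write(byte_io, sample_rate, waveform.T)
    return byte_io.read()


# Illustrative only: a 1-second 440 Hz tone at the app's 16 kHz sample rate.
t = np.linspace(0.0, 1.0, 16000, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
print(f"{len(np_to_wav(tone, 16000))} bytes of WAV data")
```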
 
cvae/__init__.py CHANGED
@@ -1,7 +1 @@
-from .models import (
-    Encoder, Decoder, VAE, CVAE
-)
-
-from .blocks import (
-    UpResConvBlock, DownResConvBlock
-)
+from .models import VAE, CVAE
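After this change the package only re-exports the two Lightning modules, so downstream code such as model.py keeps importing them the same way:

```python
# The only public re-exports left after this commit; Encoder, Decoder and the
# conv blocks are now reached through their modules (cvae.models / cvae.blocks).
from cvae import CVAE, VAE
```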
 
 
 
 
 
 
cvae/blocks.py CHANGED
@@ -1,7 +1,7 @@
-from torch import nn
+from torch import nn, Tensor
 
 class UpResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(UpResConvBlock, self).__init__()
 
         self.residual = nn.Sequential(
@@ -19,11 +19,11 @@ class UpResConvBlock(nn.Module):
             nn.GELU()
         )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
 
 class DownResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(DownResConvBlock, self).__init__()
 
         self.residual = nn.Conv1d(in_channels, out_channels, 1, 2, bias=False)
@@ -37,11 +37,11 @@ class DownResConvBlock(nn.Module):
             nn.GELU()
        )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
 
 class ResConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size) -> None:
         super(ResConvBlock, self).__init__()
 
         self.residual = nn.Identity() if in_channels == out_channels else nn.Conv1d(in_channels, out_channels, 1, bias=False)
@@ -55,5 +55,5 @@ class ResConvBlock(nn.Module):
             nn.GELU()
         )
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.main(x) + self.residual(x)
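The stride-2, kernel-1 shortcut convolution in DownResConvBlock halves the time axis each block, which is what the Encoder's divisibility assert in cvae/models.py below guards against. A minimal shape check under assumed sizes (32 to 64 channels, 4 s of 16 kHz audio):

```python
import torch
from torch import nn

# Stand-in for DownResConvBlock's residual path: Conv1d(in, out, kernel_size=1, stride=2).
shortcut = nn.Conv1d(32, 64, 1, 2, bias=False)

x = torch.randn(1, 32, 64000)  # (batch, channels, samples): 4 s at 16 kHz
print(shortcut(x).shape)       # torch.Size([1, 64, 32000]) -> time axis halved
```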
cvae/models.py CHANGED
@@ -4,164 +4,195 @@ from torch.optim import Optimizer
 from .blocks import UpResConvBlock, DownResConvBlock
 import lightning as L
 from auraloss.freq import MultiResolutionSTFTLoss
+from typing import Sequence
+
 
 class Encoder(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         in_channels: int,
         in_features: int,
         out_features: int,
-        channels: list = None,
-    ) -> None:
+        channels: Sequence[int],
+    ) -> None:
         super(Encoder, self).__init__()
 
-        assert in_features % 2**len(channels) == 0, f"in_features ({in_features}) must be a multiple of downscale factor ({2**len(channels)})"
+        assert (
+            in_features % 2 ** len(channels) == 0
+        ), f"in_features ({in_features}) must be a multiple of downscale factor ({2**len(channels)})"
 
-        modules = [
-            nn.Conv1d(in_channels, channels[0], 1),
-            nn.GELU()
-        ]
+        modules = [nn.Conv1d(in_channels, channels[0], 1), nn.GELU()]
 
-        for in_channel, out_channel in zip(channels, channels[1:]+[channels[-1]]):
+        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
             modules += [
                 DownResConvBlock(in_channel, out_channel, 1),
             ]
 
-        n_features = int(in_features*.5**len(channels))
+        n_features = int(in_features * 0.5 ** len(channels))
 
         modules += [
             nn.Flatten(),
-            nn.Linear(n_features*channels[-1], 2*out_features)
+            nn.Linear(n_features * channels[-1], 2 * out_features),
         ]
 
         self.net = nn.Sequential(*modules)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         mean, logvar = self.net(x).chunk(2, dim=1)
         return mean, logvar
 
+
 class Decoder(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         out_channels: int,
         in_features: int,
         out_features: int,
-        channels: list = None,
-    ) -> None:
+        channels: Sequence[int],
+    ) -> None:
         super(Decoder, self).__init__()
 
-        n_features = int(out_features/2**len(channels))
+        n_features = int(out_features / 2 ** len(channels))
 
         modules = [
-            nn.Linear(in_features, n_features*channels[0]),
-            nn.Unflatten(-1, (channels[0], n_features))
+            nn.Linear(in_features, n_features * channels[0]),
+            nn.Unflatten(-1, (channels[0], n_features)),
         ]
 
-        for in_channel, out_channel in zip(channels, channels[1:]+[channels[-1]]):
+        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
             modules += [
                 UpResConvBlock(in_channel, out_channel, 1),
             ]
 
-        modules += [
-            nn.Conv1d(channels[-1], out_channels, 1),
-            nn.GELU()
-        ]
+        modules += [nn.Conv1d(channels[-1], out_channels, 1), nn.GELU()]
 
         self.net = nn.Sequential(*modules)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = torch.tanh(self.net(x))
         return x
 
 
 class VAE(L.LightningModule):
-    def __init__(self, io_channels: int, io_features: int, latent_features: int, channels: list, learning_rate: float):
+    def __init__(
+        self,
+        io_channels: int,
+        io_features: int,
+        latent_features: int,
+        channels: Sequence[int],
+        learning_rate: float,
+    ) -> None:
         super().__init__()
+
         self.encoder = Encoder(io_channels, io_features, latent_features, channels)
+
         channels.reverse()
         self.decoder = Decoder(io_channels, latent_features, io_features, channels)
+
         self.latent_features = latent_features
         self.audio_loss_func = MultiResolutionSTFTLoss()
         self.learning_rate = learning_rate
 
     @torch.no_grad()
-    def sample(self, eps=None):
+    def sample(self, eps: Tensor = None) -> Tensor:
         if eps is None:
             eps = torch.rand((1, self.latent_features))
         return self.decoder(eps)
 
-    def loss_function(self, x, x_hat, mean, logvar):
+    def loss_function(
+        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
+    ) -> Tensor:
         audio_loss = self.audio_loss_func(x, x_hat)
         kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
         return audio_loss + kld_loss
 
-    def reparameterize(self, mean, logvar):
-        std= torch.exp(0.5 * logvar)
+    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
+        std = torch.exp(0.5 * logvar)
         eps = torch.randn_like(std)
         return eps * std + mean
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> tuple[Tensor]:
         mean, logvar = self.encoder(x)
         z = self.reparameterize(mean, logvar)
         return self.decoder(z), mean, logvar
 
     def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
         x_hat, mean, logvar = self.forward(batch)
         loss = self.loss_function(batch, x_hat, mean, logvar)
-        if log: self.log("train_loss", loss, prog_bar=True)
+        if log:
+            self.log("train_loss", loss, prog_bar=True)
         return loss
 
     def configure_optimizers(self) -> Optimizer:
         optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
         return optimizer
 
 
 class CVAE(L.LightningModule):
-    def __init__(self, io_channels: int, io_features: int, latent_features: int, channels: list, num_classes: int, learning_rate: float):
+    def __init__(
+        self,
+        io_channels: int,
+        io_features: int,
+        latent_features: int,
+        channels: Sequence[int],
+        num_classes: int,
+        learning_rate: float,
+    ):
         super().__init__()
+
         self.class_embedder = nn.Linear(num_classes, io_features)
         self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
-        self.encoder = Encoder(io_channels+1, io_features, latent_features, channels)
+
+        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
+
         channels.reverse()
-        self.decoder = Decoder(io_channels, latent_features+num_classes, io_features, channels)
+        self.decoder = Decoder(
+            io_channels, latent_features + num_classes, io_features, channels
+        )
+
         self.num_classes = num_classes
         self.latent_features = latent_features
         self.audio_loss_func = MultiResolutionSTFTLoss()
         self.learning_rate = learning_rate
 
     @torch.no_grad()
-    def sample(self, c, eps=None):
+    def sample(self, c, eps=None) -> Tensor:
         c = nn.functional.one_hot(c, num_classes=self.num_classes).float().unsqueeze(0)
         if eps is None:
             eps = torch.rand((1, self.latent_features))
         z = torch.cat([eps, c], dim=1)
         return self.decoder(z)
 
-    def loss_function(self, x, x_hat, mean, logvar):
+    def loss_function(
+        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
+    ) -> Tensor:
         audio_loss = self.audio_loss_func(x, x_hat)
         kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
         return audio_loss + kld_loss
 
-    def reparameterize(self, mean, logvar):
-        std= torch.exp(0.5 * logvar)
+    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
+        std = torch.exp(0.5 * logvar)
         eps = torch.randn_like(std)
         return eps * std + mean
 
-    def forward(self, x, c):
+    def forward(self, x: Tensor, c: Tensor) -> tuple[Tensor]:
         c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
         c_embedding = self.class_embedder(c).unsqueeze(1)
         x_embedding = self.data_embedder(x)
-        x = torch.cat([x_embedding, c_embedding], dim = 1)
+        x = torch.cat([x_embedding, c_embedding], dim=1)
         mean, logvar = self.encoder(x)
         z = self.reparameterize(mean, logvar)
-        z = torch.cat([z, c], dim = 1)
+        z = torch.cat([z, c], dim=1)
         return self.decoder(z), mean, logvar
 
     def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
         x, c = batch
         x_hat, mean, logvar = self.forward(x, c)
         loss = self.loss_function(x, x_hat, mean, logvar)
-        if log: self.log("train_loss", loss, prog_bar=True)
+        if log:
+            self.log("train_loss", loss, prog_bar=True)
         return loss
 
     def configure_optimizers(self) -> Optimizer:
         optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
-        return optimizer
+        return optimizer
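Both VAE and CVAE combine auraloss's multi-resolution STFT reconstruction term with the standard closed-form KL divergence between the encoder posterior N(mean, exp(logvar)) and a standard normal, and draw latents with the reparameterization trick. A self-contained check of those two pieces, with toy shapes that are assumptions rather than values from the commit:

```python
import torch


def reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mean + sigma * eps with eps ~ N(0, I); keeps the sampling step differentiable.
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(std)


def kld(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed-form KL( N(mean, exp(logvar)) || N(0, I) ), summed over the batch.
    return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())


mean = torch.zeros(2, 5)    # toy batch of 2, latent_features=5
logvar = torch.zeros(2, 5)  # unit variance -> KL is exactly zero
z = reparameterize(mean, logvar)
print(z.shape, kld(mean, logvar).item())  # torch.Size([2, 5]) 0.0
```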
model.py CHANGED
@@ -2,34 +2,81 @@ from cvae import CVAE
 import torch
 from typing import Sequence
 import streamlit as st
+from lightning import LightningModule
 
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
+def format_instruments(text: str) -> str:
+    stems = text.split(" ")[1:]
+    stems = [stem.replace(" ", "").lower() for stem in stems]
+    return "_".join(stems)
+
+
+def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
+    choice = "_".join([format_instruments(i) for i in choice])
+    return torch.tensor(instruments.index(choice))
+
 
 @st.cache_resource
-def load_model(device):
+def load_model(device: str) -> LightningModule:
     return CVAE.load_from_checkpoint(
-        'epoch=77-step=2819778.ckpt',
+        "epoch=77-step=2819778.ckpt",
         io_channels=1,
-        io_features=16000*4,
+        io_features=16000 * 4,
         latent_features=5,
         channels=[32, 64, 128, 256, 512],
         num_classes=len(instruments),
-        learning_rate=1e-5
+        learning_rate=1e-5,
     ).to(device)
 
-model = load_model(device)
 
-def format(text):
-    stems = text.split(' ')[1:]
-    stems = [stem.replace(" ", "").lower() for stem in stems]
-    return '_'.join(stems)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+instruments = [
+    "bass_acoustic",
+    "brass_acoustic",
+    "flute_acoustic",
+    "guitar_acoustic",
+    "keyboard_acoustic",
+    "mallet_acoustic",
+    "organ_acoustic",
+    "reed_acoustic",
+    "string_acoustic",
+    "synth_lead_acoustic",
+    "vocal_acoustic",
+    "bass_synthetic",
+    "brass_synthetic",
+    "flute_synthetic",
+    "guitar_synthetic",
+    "keyboard_synthetic",
+    "mallet_synthetic",
+    "organ_synthetic",
+    "reed_synthetic",
+    "string_synthetic",
+    "synth_lead_synthetic",
+    "vocal_synthetic",
+    "bass_electronic",
+    "brass_electronic",
+    "flute_electronic",
+    "guitar_electronic",
+    "keyboard_electronic",
+    "mallet_electronic",
+    "organ_electronic",
+    "reed_electronic",
+    "string_electronic",
+    "synth_lead_electronic",
+    "vocal_electronic",
+]
+
+
+model = load_model(device)
 
-def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
-    choice = '_'.join([format(i) for i in choice])
-    return torch.tensor(instruments.index(choice))
 
-def generate(choice: Sequence[str], params: Sequence[int]=None):
-    noise = torch.tensor(params).unsqueeze(0).to(device) if params else torch.randn(1, 5).to(device)
-    return model.sample(eps=noise, c = choice_to_tensor(choice).to(device)).cpu().numpy()[0]
+def generate(choice: Sequence[str], params: Sequence[int] = None):
+    noise = (
+        torch.tensor(params).unsqueeze(0).to(device)
+        if params
+        else torch.randn(1, 5).to(device)
+    )
+    return (
+        model.sample(eps=noise, c=choice_to_tensor(choice).to(device)).cpu().numpy()[0]
+    )
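Taken together with app.py, the flow is: emoji labels from the select boxes, formatted and joined into an entry of `instruments`, one-hot encoded as the condition, then decoded by `CVAE.sample` from either the slider values or random noise. A hypothetical driver, assuming the checkpoint named in `load_model` sits next to model.py (nothing in this snippet is part of the commit):

```python
# Hypothetical usage of the functions defined in model.py above.
from model import generate

# "🪕 Guitar" + "📯 Acoustic" formats to "guitar_acoustic", an entry of `instruments`.
waveform = generate(("🪕 Guitar", "📯 Acoustic"))  # params=None -> random latent noise
print(waveform.shape)  # expected (1, 64000): one channel, 4 s of audio at 16 kHz
```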