pcunwa committed
Commit e6daba8
1 Parent(s): 1964f60

Delete mel_band_roformer.py

Files changed (1)
  1. mel_band_roformer.py +0 -637
mel_band_roformer.py DELETED
@@ -1,637 +0,0 @@
-from functools import partial
-
-import torch
-from torch import nn, einsum, Tensor
-from torch.nn import Module, ModuleList
-import torch.nn.functional as F
-
-from models.bs_roformer.attend import Attend
-
-from beartype.typing import Tuple, Optional, List, Callable
-from beartype import beartype
-
-from rotary_embedding_torch import RotaryEmbedding
-
-from einops import rearrange, pack, unpack, reduce, repeat
-from einops.layers.torch import Rearrange
-
-from librosa import filters
-
-
-# helper functions
-
-def exists(val):
-    return val is not None
-
-
-def default(v, d):
-    return v if exists(v) else d
-
-
-def pack_one(t, pattern):
-    return pack([t], pattern)
-
-
-def unpack_one(t, ps, pattern):
-    return unpack(t, ps, pattern)[0]
-
-
-def pad_at_dim(t, pad, dim=-1, value=0.):
-    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    zeros = ((0, 0) * dims_from_right)
-    return F.pad(t, (*zeros, *pad), value=value)
-
-
-def l2norm(t):
-    return F.normalize(t, dim=-1, p=2)
-
-
-# norm
-
-class RMSNorm(Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.scale = dim ** 0.5
-        self.gamma = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return F.normalize(x, dim=-1) * self.scale * self.gamma
-
-
-# attention
-
-class FeedForward(Module):
-    def __init__(
-            self,
-            dim,
-            mult=4,
-            dropout=0.
-    ):
-        super().__init__()
-        dim_inner = int(dim * mult)
-        self.net = nn.Sequential(
-            RMSNorm(dim),
-            nn.Linear(dim, dim_inner),
-            nn.GELU(),
-            nn.Dropout(dropout),
-            nn.Linear(dim_inner, dim),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x):
-        return self.net(x)
-
-
-class Attention(Module):
-    def __init__(
-            self,
-            dim,
-            heads=8,
-            dim_head=64,
-            dropout=0.,
-            rotary_embed=None,
-            flash=True
-    ):
-        super().__init__()
-        self.heads = heads
-        self.scale = dim_head ** -0.5
-        dim_inner = heads * dim_head
-
-        self.rotary_embed = rotary_embed
-
-        self.attend = Attend(flash=flash, dropout=dropout)
-
-        self.norm = RMSNorm(dim)
-        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
-
-        self.to_gates = nn.Linear(dim, heads)
-
-        self.to_out = nn.Sequential(
-            nn.Linear(dim_inner, dim, bias=False),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x):
-        x = self.norm(x)
-
-        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
-
-        if exists(self.rotary_embed):
-            q = self.rotary_embed.rotate_queries_or_keys(q)
-            k = self.rotary_embed.rotate_queries_or_keys(k)
-
-        out = self.attend(q, k, v)
-
-        gates = self.to_gates(x)
-        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
-
-        out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
-
-
-class LinearAttention(Module):
-    """
-    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
-    """
-
-    @beartype
-    def __init__(
-            self,
-            *,
-            dim,
-            dim_head=32,
-            heads=8,
-            scale=8,
-            flash=False,
-            dropout=0.
-    ):
-        super().__init__()
-        dim_inner = dim_head * heads
-        self.norm = RMSNorm(dim)
-
-        self.to_qkv = nn.Sequential(
-            nn.Linear(dim, dim_inner * 3, bias=False),
-            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
-        )
-
-        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
-
-        self.attend = Attend(
-            scale=scale,
-            dropout=dropout,
-            flash=flash
-        )
-
-        self.to_out = nn.Sequential(
-            Rearrange('b h d n -> b n (h d)'),
-            nn.Linear(dim_inner, dim, bias=False)
-        )
-
-    def forward(
-            self,
-            x
-    ):
-        x = self.norm(x)
-
-        q, k, v = self.to_qkv(x)
-
-        q, k = map(l2norm, (q, k))
-        q = q * self.temperature.exp()
-
-        out = self.attend(q, k, v)
-
-        return self.to_out(out)
-
-
-class Transformer(Module):
-    def __init__(
-            self,
-            *,
-            dim,
-            depth,
-            dim_head=64,
-            heads=8,
-            attn_dropout=0.,
-            ff_dropout=0.,
-            ff_mult=4,
-            norm_output=True,
-            rotary_embed=None,
-            flash_attn=True,
-            linear_attn=False
-    ):
-        super().__init__()
-        self.layers = ModuleList([])
-
-        for _ in range(depth):
-            if linear_attn:
-                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
-            else:
-                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
-                                 rotary_embed=rotary_embed, flash=flash_attn)
-
-            self.layers.append(ModuleList([
-                attn,
-                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
-            ]))
-
-        self.norm = RMSNorm(dim) if norm_output else nn.Identity()
-
-    def forward(self, x):
-
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
-
-        return self.norm(x)
-
-
-# bandsplit module
-
-class BandSplit(Module):
-    @beartype
-    def __init__(
-            self,
-            dim,
-            dim_inputs: Tuple[int, ...]
-    ):
-        super().__init__()
-        self.dim_inputs = dim_inputs
-        self.to_features = ModuleList([])
-
-        for dim_in in dim_inputs:
-            net = nn.Sequential(
-                RMSNorm(dim_in),
-                nn.Linear(dim_in, dim)
-            )
-
-            self.to_features.append(net)
-
-    def forward(self, x):
-        x = x.split(self.dim_inputs, dim=-1)
-
-        outs = []
-        for split_input, to_feature in zip(x, self.to_features):
-            split_output = to_feature(split_input)
-            outs.append(split_output)
-
-        return torch.stack(outs, dim=-2)
-
-
-def MLP(
-        dim_in,
-        dim_out,
-        dim_hidden=None,
-        depth=1,
-        activation=nn.Tanh
-):
-    dim_hidden = default(dim_hidden, dim_in)
-
-    net = []
-    dims = (dim_in, *((dim_hidden,) * depth), dim_out)
-
-    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
-        is_last = ind == (len(dims) - 2)
-
-        net.append(nn.Linear(layer_dim_in, layer_dim_out))
-
-        if is_last:
-            continue
-
-        net.append(activation())
-
-    return nn.Sequential(*net)
-
-
-class MaskEstimator(Module):
-    @beartype
-    def __init__(
-            self,
-            dim,
-            dim_inputs: Tuple[int, ...],
-            depth,
-            mlp_expansion_factor=1
-    ):
-        super().__init__()
-        self.dim_inputs = dim_inputs
-        self.to_freqs = ModuleList([])
-        dim_hidden = dim * mlp_expansion_factor
-
-        for dim_in in dim_inputs:
-            net = []
-
-            mlp = nn.Sequential(
-                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
-                nn.GLU(dim=-1)
-            )
-
-            self.to_freqs.append(mlp)
-
-    def forward(self, x):
-        x = x.unbind(dim=-2)
-
-        outs = []
-
-        for band_features, mlp in zip(x, self.to_freqs):
-            freq_out = mlp(band_features)
-            outs.append(freq_out)
-
-        return torch.cat(outs, dim=-1)
-
-
-# main class
-
-class MelBandRoformer(Module):
-
-    @beartype
-    def __init__(
-            self,
-            dim,
-            *,
-            depth,
-            stereo=False,
-            num_stems=1,
-            time_transformer_depth=2,
-            freq_transformer_depth=2,
-            linear_transformer_depth=0,
-            num_bands=60,
-            dim_head=64,
-            heads=8,
-            attn_dropout=0.1,
-            ff_dropout=0.1,
-            flash_attn=True,
-            dim_freqs_in=1025,
-            sample_rate=44100,  # needed for mel filter bank from librosa
-            stft_n_fft=2048,
-            stft_hop_length=512,
-            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
-            stft_win_length=2048,
-            stft_normalized=False,
-            stft_window_fn: Optional[Callable] = None,
-            mask_estimator_depth=1,
-            multi_stft_resolution_loss_weight=1.,
-            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
-            multi_stft_hop_size=147,
-            multi_stft_normalized=False,
-            multi_stft_window_fn: Callable = torch.hann_window,
-            match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
-    ):
-        super().__init__()
-
-        self.stereo = stereo
-        self.audio_channels = 2 if stereo else 1
-        self.num_stems = num_stems
-
-        self.layers = ModuleList([])
-
-        transformer_kwargs = dict(
-            dim=dim,
-            heads=heads,
-            dim_head=dim_head,
-            attn_dropout=attn_dropout,
-            ff_dropout=ff_dropout,
-            flash_attn=flash_attn
-        )
-
-        time_rotary_embed = RotaryEmbedding(dim=dim_head)
-        freq_rotary_embed = RotaryEmbedding(dim=dim_head)
-
-        for _ in range(depth):
-            tran_modules = []
-            if linear_transformer_depth > 0:
-                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
-            tran_modules.append(
-                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
-            )
-            tran_modules.append(
-                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
-            )
-            self.layers.append(nn.ModuleList(tran_modules))
-
-        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
-
-        self.stft_kwargs = dict(
-            n_fft=stft_n_fft,
-            hop_length=stft_hop_length,
-            win_length=stft_win_length,
-            normalized=stft_normalized
-        )
-
-        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
-
-        # create mel filter bank
-        # with librosa.filters.mel as in section 2 of paper
-
-        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
-
-        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
-
-        # for some reason, it doesn't include the first freq? just force a value for now
-
-        mel_filter_bank[0][0] = 1.
-
-        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
-        # so let's force a positive value
-
-        mel_filter_bank[-1, -1] = 1.
-
-
-        # binary as in paper (then estimated masks are averaged for overlapping regions)
-        freqs_per_band = mel_filter_bank > 0
-        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by at least one band for now'
-
-        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
-        freq_indices = repeated_freq_indices[freqs_per_band]
-
-        if stereo:
-            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
-            freq_indices = freq_indices * 2 + torch.arange(2)
-            freq_indices = rearrange(freq_indices, 'f s -> (f s)')
-
-        self.register_buffer('freq_indices', freq_indices, persistent=False)
-        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
-
-        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
-        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
-
-        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
-        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
-
-        # band split and mask estimator
-
-        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
-
-        self.band_split = BandSplit(
-            dim=dim,
-            dim_inputs=freqs_per_bands_with_complex
-        )
-
-        self.mask_estimators = nn.ModuleList([])
-
-        for _ in range(num_stems):
-            mask_estimator = MaskEstimator(
-                dim=dim,
-                dim_inputs=freqs_per_bands_with_complex,
-                depth=mask_estimator_depth
-            )
-
-            self.mask_estimators.append(mask_estimator)
-
-        # for the multi-resolution stft loss
-
-        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
-        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
-        self.multi_stft_n_fft = stft_n_fft
-        self.multi_stft_window_fn = multi_stft_window_fn
-
-        self.multi_stft_kwargs = dict(
-            hop_length=multi_stft_hop_size,
-            normalized=multi_stft_normalized
-        )
-
-        self.match_input_audio_length = match_input_audio_length
-
-    def forward(
-            self,
-            raw_audio,
-            target=None,
-            return_loss_breakdown=False
-    ):
-        """
-        einops
-
-        b - batch
-        f - freq
-        t - time
-        s - audio channel (1 for mono, 2 for stereo)
-        n - number of 'stems'
-        c - complex (2)
-        d - feature dimension
-        """
-
-        device = raw_audio.device
-
-        if raw_audio.ndim == 2:
-            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
-
-        batch, channels, raw_audio_length = raw_audio.shape
-
-        istft_length = raw_audio_length if self.match_input_audio_length else None
-
-        assert (not self.stereo and channels == 1) or (
-                self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
-
-        # to stft
-
-        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
-
-        stft_window = self.stft_window_fn(device=device)
-
-        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
-        stft_repr = torch.view_as_real(stft_repr)
-
-        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
-        stft_repr = rearrange(stft_repr,
-                              'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
-
-        # index out all frequencies for all frequency ranges across bands ascending in one go
-
-        batch_arange = torch.arange(batch, device=device)[..., None]
-
-        # account for stereo
-
-        x = stft_repr[batch_arange, self.freq_indices]
-
-        # fold the complex (real and imag) into the frequencies dimension
-
-        x = rearrange(x, 'b f t c -> b t (f c)')
-
-        x = self.band_split(x)
-
-        # axial / hierarchical attention
-
-        for transformer_block in self.layers:
-
-            if len(transformer_block) == 3:
-                linear_transformer, time_transformer, freq_transformer = transformer_block
-
-                x, ft_ps = pack([x], 'b * d')
-                x = linear_transformer(x)
-                x, = unpack(x, ft_ps, 'b * d')
-            else:
-                time_transformer, freq_transformer = transformer_block
-
-            x = rearrange(x, 'b t f d -> b f t d')
-            x, ps = pack([x], '* t d')
-
-            x = time_transformer(x)
-
-            x, = unpack(x, ps, '* t d')
-            x = rearrange(x, 'b f t d -> b t f d')
-            x, ps = pack([x], '* f d')
-
-            x = freq_transformer(x)
-
-            x, = unpack(x, ps, '* f d')
-
-        num_stems = len(self.mask_estimators)
-
-        masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
-        masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
-
-        # modulate frequency representation
-
-        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
-
-        # complex number multiplication
-
-        stft_repr = torch.view_as_complex(stft_repr)
-        masks = torch.view_as_complex(masks)
-
-        masks = masks.type(stft_repr.dtype)
-
-        # need to average the estimated mask for the overlapped frequencies
-
-        scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
-
-        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
-        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
-
-        denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
-
-        masks_averaged = masks_summed / denom.clamp(min=1e-8)
-
-        # modulate stft repr with estimated mask
-
-        stft_repr = stft_repr * masks_averaged
-
-        # istft
-
-        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
-
-        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
-                                  length=istft_length)
-
-        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
-
-        if num_stems == 1:
-            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
-
-        # if a target is passed in, calculate loss for learning
-
-        if not exists(target):
-            return recon_audio
-
-        if self.num_stems > 1:
-            assert target.ndim == 4 and target.shape[1] == self.num_stems
-
-        if target.ndim == 2:
-            target = rearrange(target, '... t -> ... 1 t')
-
-        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
-
-        loss = F.l1_loss(recon_audio, target)
-
-        multi_stft_resolution_loss = 0.
-
-        for window_size in self.multi_stft_resolutions_window_sizes:
-            res_stft_kwargs = dict(
-                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
-                win_length=window_size,
-                return_complex=True,
-                window=self.multi_stft_window_fn(window_size, device=device),
-                **self.multi_stft_kwargs,
-            )
-
-            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
-            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
-
-            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
-
-        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
-
-        total_loss = loss + weighted_multi_resolution_loss
-
-        if not return_loss_breakdown:
-            return total_loss
-
-        return total_loss, (loss, multi_stft_resolution_loss)
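
For context, here is a minimal, hypothetical usage sketch of the module this commit removes. It relies only on the constructor and forward signatures visible in the diff above; the width, depth, and dummy tensors are illustrative, and it assumes the deleted file (plus its models/bs_roformer/attend dependency and the packages it imports) is still available on the import path.

# Hypothetical usage sketch of the deleted module (not part of this commit).
# Assumes mel_band_roformer.py and models/bs_roformer/attend.py are importable,
# along with torch, einops, librosa, beartype and rotary-embedding-torch.
import torch
from mel_band_roformer import MelBandRoformer

model = MelBandRoformer(
    dim=192,       # illustrative model width
    depth=6,       # number of axial (time, freq) transformer blocks
    stereo=False,  # mono input -> channel dimension of 1
    num_stems=1,   # a single estimated stem
)

audio = torch.randn(1, 44100 * 2)  # (batch, time): two seconds of mono audio at 44.1 kHz

with torch.no_grad():
    recon = model(audio)           # no target -> returns separated audio of shape (batch, channels, time)

# passing a target instead returns the L1 + weighted multi-resolution STFT loss
target = torch.randn_like(recon)
loss = model(audio, target=target)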