waidhoferj committed
Commit 557fb53 • 1 Parent(s): e82ec2b

Refactor config style and reorganize files

.gitignore CHANGED
@@ -9,3 +9,4 @@ lightning_logs
 .lr_find_*
 .cache
 .vscode
+models/weights/ast
TODO.md CHANGED
@@ -6,10 +6,13 @@
 - Create an attention-based network
 - ✅ Increase parameter count in network
 - Verify that labels really match what is on the music4dance site
-- Read the Medium series about audio DL
+- ✅ Read the Medium series about audio DL
 - double check \_rectify_duration
 - ✅ Filter out songs that have only one vote
+- ✅ Download songs from [Best Ballroom](https://www.youtube.com/channel/UC0bYSnzAFMwPiEjmVsrvmRg)
+
+- ✅ fix nan values

 ## Notes

-2xM60 insufficient memory.
+2xM60 insufficient memory for the AST.
environment.yml CHANGED
@@ -23,6 +23,11 @@ dependencies:
   - scikit-learn
   - tensorboard
   - transformers
+  - accelerate
+  - pytest
+
   - pip:
       - evaluate
       - wakepy
+      - soundfile
+      - youtube_dl
models/audio_spectrogram_transformer.py CHANGED
@@ -1,93 +1,138 @@
1
- from transformers import ASTModel, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from torch import nn
4
- from sklearn.utils.class_weight import compute_class_weight
5
- import evaluate
6
- import numpy as np
7
 
8
- accuracy = evaluate.load("accuracy")
 
 
 
 
 
 
9
 
 
 
10
 
11
- class MultiModalAST(nn.Module):
12
 
13
 
14
- def __init__(self, labels, sample_rate, *args, **kwargs) -> None:
 
15
  super().__init__(*args, **kwargs)
16
  id2label, label2id = get_id_label_mapping(labels)
17
- model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
18
- self.ast_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
19
-
20
- self.ast_model = ASTModel.from_pretrained(
21
- model_checkpoint,
22
- num_labels=len(label2id),
23
- label2id=label2id,
24
- id2label=id2label,
25
- ignore_mismatched_sizes=True
26
- )
27
- self.sample_rate = sample_rate
28
-
29
- self.bpm_model = nn.Sequential(
30
- nn.Linear(len(labels), 100),
31
- nn.Linear(100, 50)
32
- )
33
-
34
- out_dim = 50 # TODO: Calculate output dimension
35
- self.classifier = nn.Sequential(
36
- nn.Linear(out_dim, 100),
37
- nn.Linear(100, len(labels))
38
  )
39
-
40
- def vectorize_bpm(self, waveform):
41
- pass
42
-
43
-
44
- def forward(self, audio):
45
-
46
- bpm_vector = self.vectorize_bpm(audio)
47
- bpm_out = self.bpm_model(bpm_vector)
48
-
49
- spectrogram = self.ast_feature_extractor(audio)
50
- ast_out = self.ast_model(spectrogram)
51
-
52
- # Late fusion
53
- z = torch.cat([ast_out, bpm_out]) # Which dimension?
54
- return self.classifier(z)
55
 
 
 
56
 
57
- def compute_metrics(eval_pred):
58
- predictions = np.argmax(eval_pred.predictions, axis=1)
59
- return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
60
 
61
- def get_id_label_mapping(labels:list[str]) -> tuple[dict, dict]:
62
- id2label = {str(i) : label for i, label in enumerate(labels)}
63
- label2id = {label : str(i) for i, label in enumerate(labels)}
 
 
 
64
 
65
- return id2label, label2id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- def train(
68
- labels,
69
- train_ds,
70
- test_ds,
71
- output_dir="models/weights/ast",
72
- device="cpu",
73
- batch_size=128,
74
- epochs=10):
75
- id2label, label2id = get_id_label_mapping(labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
77
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
78
- preprocess_waveform = lambda wf : feature_extractor(wf, sampling_rate=train_ds.resample_frequency, padding="max_length", return_tensors="pt")
79
- train_ds.map(preprocess_waveform)
80
- test_ds.map(preprocess_waveform)
 
 
 
 
 
 
 
 
 
81
 
82
  model = AutoModelForAudioClassification.from_pretrained(
83
- model_checkpoint,
84
- num_labels=len(labels),
85
- label2id=label2id,
86
- id2label=id2label,
87
- ignore_mismatched_sizes=True
88
- ).to(device)
89
  training_args = TrainingArguments(
90
- output_dir=output_dir,
91
  evaluation_strategy="epoch",
92
  save_strategy="epoch",
93
  learning_rate=5e-5,
@@ -100,7 +145,7 @@ def train(
100
  load_best_model_at_end=True,
101
  metric_for_best_model="accuracy",
102
  push_to_hub=False,
103
- use_mps_device=device == "mps"
104
  )
105
 
106
  trainer = Trainer(
@@ -109,11 +154,7 @@ def train(
109
  train_dataset=train_ds,
110
  eval_dataset=test_ds,
111
  tokenizer=feature_extractor,
112
- compute_metrics=compute_metrics,
113
  )
114
  trainer.train()
115
  return model
116
-
117
-
118
-
119
-
 
1
+ from typing import Any
2
+ import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+ from transformers import (
5
+ AutoFeatureExtractor,
6
+ AutoModelForAudioClassification,
7
+ TrainingArguments,
8
+ Trainer,
9
+ ASTConfig,
10
+ ASTFeatureExtractor,
11
+ ASTForAudioClassification,
12
+ )
13
  import torch
14
  from torch import nn
15
+ from models.training_environment import TrainingEnvironment
16
+ from preprocessing.pipelines import WaveformTrainingPipeline
 
17
 
18
+ from preprocessing.dataset import (
19
+ DanceDataModule,
20
+ HuggingFaceDatasetWrapper,
21
+ get_datasets,
22
+ )
23
+ from preprocessing.dataset import get_music4dance_examples
24
+ from .utils import get_id_label_mapping, compute_hf_metrics
25
 
26
+ import pytorch_lightning as pl
27
+ from pytorch_lightning import callbacks as cb
28
 
29
+ MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
30
 
31
 
32
+ class AST(nn.Module):
33
+ def __init__(self, labels, *args, **kwargs) -> None:
34
  super().__init__(*args, **kwargs)
35
  id2label, label2id = get_id_label_mapping(labels)
36
+ config = ASTConfig(
37
+ hidden_size=300,
38
+ num_attention_heads=5,
39
+ num_hidden_layers=3,
40
+ id2label=id2label,
41
+ label2id=label2id,
42
+ num_labels=len(label2id),
43
+ ignore_mismatched_sizes=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
45
+ self.model = ASTForAudioClassification(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ def forward(self, x):
48
+ return self.model(x).logits
49
 
 
 
 
50
 
51
+ class ASTExtractorWrapper:
52
+ def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
53
+ self.extractor = ASTFeatureExtractor()
54
+ self.sampling_rate = sampling_rate
55
+ self.return_tensors = return_tensors
56
+ self.waveform_pipeline = WaveformTrainingPipeline() # TODO configure from yaml
57
 
58
+ def __call__(self, x) -> Any:
59
+ x = self.waveform_pipeline(x)
60
+ device = x.device
61
+ x = x.squeeze(0).numpy()
62
+ x = self.extractor(
63
+ x, return_tensors=self.return_tensors, sampling_rate=self.sampling_rate
64
+ )
65
+ return x["input_values"].squeeze(0).to(device)
66
+
67
+
68
+ def train_lightning_ast(config: dict):
69
+ """
70
+ work on integration between waveform dataset and environment. Should work for both HF and PTL.
71
+ """
72
+ TARGET_CLASSES = config["dance_ids"]
73
+ DEVICE = config["device"]
74
+ SEED = config["seed"]
75
+ pl.seed_everything(SEED, workers=True)
76
+ feature_extractor = ASTExtractorWrapper()
77
+ dataset = get_datasets(config["datasets"], feature_extractor)
78
+ data = DanceDataModule(
79
+ dataset,
80
+ target_classes=TARGET_CLASSES,
81
+ **config["data_module"],
82
+ )
83
 
84
+ model = AST(TARGET_CLASSES).to(DEVICE)
85
+ label_weights = data.get_label_weights().to(DEVICE)
86
+ criterion = nn.CrossEntropyLoss(
87
+ label_weights
88
+ ) # LabelWeightedBCELoss(label_weights)
89
+ train_env = TrainingEnvironment(model, criterion, config)
90
+ callbacks = [
91
+ # cb.LearningRateFinder(update_attr=True),
92
+ cb.EarlyStopping("val/loss", patience=5),
93
+ cb.RichProgressBar(),
94
+ ]
95
+ trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
96
+ trainer.fit(train_env, datamodule=data)
97
+ trainer.test(train_env, datamodule=data)
98
+
99
+
100
+ def train_huggingface_ast(config: dict):
101
+ TARGET_CLASSES = config["dance_ids"]
102
+ DEVICE = config["device"]
103
+ SEED = config["seed"]
104
+ OUTPUT_DIR = "models/weights/ast"
105
+ batch_size = config["data_module"]["batch_size"]
106
+ epochs = config["data_module"]["min_epochs"]
107
+ test_proportion = config["data_module"].get("test_proportion", 0.2)
108
+ pl.seed_everything(SEED, workers=True)
109
+ dataset = get_datasets(config["datasets"])
110
+ hf_dataset = HuggingFaceDatasetWrapper(dataset)
111
+ id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
112
  model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
113
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
114
+ preprocess_waveform = lambda wf: feature_extractor(
115
+ wf,
116
+ sampling_rate=train_ds.resample_frequency,
117
+ # padding="max_length",
118
+ # return_tensors="pt",
119
+ )
120
+ hf_dataset.append_to_pipeline(preprocess_waveform)
121
+ test_proportion = config["data_module"]["test_proportion"]
122
+ train_proportion = 1 - test_proportion
+ train_ds, test_ds = torch.utils.data.random_split(
+ hf_dataset, [train_proportion, test_proportion]
125
+ )
126
 
127
  model = AutoModelForAudioClassification.from_pretrained(
128
+ model_checkpoint,
129
+ num_labels=len(TARGET_CLASSES),
130
+ label2id=label2id,
131
+ id2label=id2label,
132
+ ignore_mismatched_sizes=True,
133
+ ).to(DEVICE)
134
  training_args = TrainingArguments(
135
+ output_dir=OUTPUT_DIR,
136
  evaluation_strategy="epoch",
137
  save_strategy="epoch",
138
  learning_rate=5e-5,
 
145
  load_best_model_at_end=True,
146
  metric_for_best_model="accuracy",
147
  push_to_hub=False,
148
+ use_mps_device=DEVICE == "mps",
149
  )
150
 
151
  trainer = Trainer(
 
154
  train_dataset=train_ds,
155
  eval_dataset=test_ds,
156
  tokenizer=feature_extractor,
157
+ compute_metrics=compute_hf_metrics,
158
  )
159
  trainer.train()
160
  return model
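The rewritten `AST` module builds a compact `ASTForAudioClassification` from an `ASTConfig` (3 hidden layers, hidden size 300), and `ASTExtractorWrapper` converts a raw waveform into the `input_values` tensor the transformer expects. A minimal usage sketch, assuming a 16 kHz mono clip, a hypothetical three-label list, and that `WaveformTrainingPipeline` accepts and returns a `(1, num_samples)` CPU tensor:

```python
# Sketch only: the labels and the random waveform are placeholders, not project data.
import torch

from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper

labels = ["CHA", "JIV", "SWZ"]             # hypothetical subset of dance ids
extractor = ASTExtractorWrapper()           # waveform -> AST input_values
model = AST(labels).eval()

waveform = torch.randn(1, 16000 * 6)        # 6 s of fake 16 kHz audio
features = extractor(waveform)              # (max_length, num_mel_bins) tensor
with torch.no_grad():
    logits = model(features.unsqueeze(0))   # add a batch dimension
print(labels[logits.argmax(-1).item()])
```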
 
 
 
 
models/config/decision_tree.yaml ADDED
@@ -0,0 +1,47 @@
+global:
+  id: decision_tree
+  device: mps
+  seed: 42
+  dance_ids:
+    - ATN
+    - BCH
+    - CHA
+    - ECS
+    - HST
+    - JIV
+    - QST
+    - RMB
+    - SFT
+    - SLS
+    - SMB
+    - SWZ
+    - TGO
+    - VWZ
+    - WCS
+data_module:
+  song_data_path: data/songs_cleaned.csv
+  song_audio_path: data/samples
+  batch_size: 32
+  num_workers: 7
+  min_votes: 1
+  dataset_kwargs:
+    audio_window_duration: 6
+    audio_window_jitter: 1.5
+    audio_pipeline_kwargs:
+      mask_count: 0 # Don't mask the data
+      snr_mean: 15.0 # Pretty much eliminate the noise
+      freq_mask_size: 10
+      time_mask_size: 80
+
+trainer:
+  log_every_n_steps: 15
+  accelerator: gpu
+  max_epochs: 50
+  min_epochs: 5
+  fast_dev_run: False
+  # gradient_clip_val: 0.5
+  # overfit_batches: 1
+training_environment:
+  learning_rate: 0.00053
+model:
+  n_channels: 128
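Note that this config keeps the old `global:` wrapper, matching what `train_decision_tree` in `models/decision_tree.py` reads (`config["global"]["dance_ids"]`, `config["data_module"]`, `config["trainer"]`). A loading sketch; the YAML path and loader call are assumptions:

```python
# Sketch: load the config above and hand it to the decision tree trainer.
import yaml

from models.decision_tree import train_decision_tree

with open("models/config/decision_tree.yaml") as f:
    config = yaml.safe_load(f)

train_decision_tree(config)  # uses config["global"], config["data_module"], config["trainer"]
```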
models/config/train.yaml CHANGED
@@ -27,11 +27,11 @@ data_module:
   dataset_kwargs:
     audio_window_duration: 6
     audio_window_jitter: 1.5
-    audio_pipeline_kwargs:
-      mask_count: 0 # Don't mask the data
-      snr_mean: 15.0 # Pretty much eliminate the noise
-      freq_mask_size: 10
-      time_mask_size: 80
+    # audio_pipeline_kwargs:
+    #   mask_count: 0 # Don't mask the data
+    #   snr_mean: 15.0 # Pretty much eliminate the noise
+    #   freq_mask_size: 10
+    #   time_mask_size: 80

 trainer:
   log_every_n_steps: 15
models/config/train_local.yaml CHANGED
@@ -1,47 +1,58 @@
-global:
-  id: ast_ptl # decision_tree
-  device: mps
-  seed: 42
-  dance_ids:
-    - ATN
-    - BCH
-    - CHA
-    - ECS
-    - HST
-    - JIV
-    - QST
-    - RMB
-    - SFT
-    - SLS
-    - SMB
-    - SWZ
-    - TGO
-    - VWZ
-    - WCS
+training_fn: audio_spectrogram_transformer.train_lightning_ast
+device: mps
+seed: 42
+dance_ids: &dance_ids
+  - BCH
+  - CHA
+  - JIV
+  - ECS
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
+
 data_module:
-  song_data_path: data/songs_cleaned.csv
-  song_audio_path: data/samples
-  batch_size: 32
-  num_workers: 7
-  min_votes: 1
-  dataset_kwargs:
-    audio_window_duration: 6
-    audio_window_jitter: 1.5
-    audio_pipeline_kwargs:
-      mask_count: 0 # Don't mask the data
-      snr_mean: 15.0 # Pretty much eliminate the noise
-      freq_mask_size: 10
-      time_mask_size: 80
+  batch_size: 64
+  num_workers: 10
+  test_proportion: 0.2
+
+datasets:
+  preprocessing.dataset.BestBallroomDataset:
+    audio_dir: data/ballroom-songs
+    class_list: *dance_ids
+    audio_window_jitter: 0.7
+
+  preprocessing.dataset.Music4DanceDataset:
+    song_data_path: data/songs_cleaned.csv
+    song_audio_path: data/samples # data/samples
+    class_list: *dance_ids
+    multi_label: False
+    min_votes: 1
+    audio_window_jitter: 0.7
+
+model:
+  n_channels: 128
+
+feature_extractor:
+  mask_count: 0 # Don't mask the data
+  snr_mean: 15.0 # Pretty much eliminate the noise
+  freq_mask_size: 10
+  time_mask_size: 80

 trainer:
   log_every_n_steps: 15
   accelerator: gpu
   max_epochs: 50
-  min_epochs: 5
+  min_epochs: 7
   fast_dev_run: False
   # gradient_clip_val: 0.5
   # overfit_batches: 1
+
 training_environment:
   learning_rate: 0.00053
-model:
-  n_channels: 128
+  log_spectrograms: False
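The refactored config drops the `global:` wrapper and instead names its entry point with `training_fn: audio_spectrogram_transformer.train_lightning_ast`, which a driver script can resolve at runtime. The driver itself is not part of this diff, so the sketch below (including the `models.` package prefix) is an assumption about how that dispatch might look:

```python
# Hypothetical dispatcher for the new config style; not shown in this commit.
import importlib

import yaml


def run_from_config(config_path: str = "models/config/train_local.yaml"):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    module_name, fn_name = config["training_fn"].rsplit(".", 1)
    module = importlib.import_module(f"models.{module_name}")  # assumed package layout
    train_fn = getattr(module, fn_name)  # e.g. train_lightning_ast
    return train_fn(config)  # each train_* entry point receives the full config dict
```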
 
models/decision_tree.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from sklearn.base import ClassifierMixin, BaseEstimator
2
  import pandas as pd
3
  from torch import nn
@@ -5,8 +6,14 @@ import torch
5
  from typing import Iterator
6
  import numpy as np
7
  import json
 
8
  from tqdm import tqdm
9
  import librosa
 
 
 
 
 
10
 
11
  DANCE_INFO_FILE = "data/dance_info.csv"
12
  dance_info_df = pd.read_csv(
@@ -24,9 +31,8 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
24
  - BPM
25
  """
26
 
27
- def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
28
  self.device = device
29
- self.epochs = epochs
30
  self.verbose = verbose
31
  self.lr = lr
32
  self.classifiers = {}
@@ -44,41 +50,40 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
44
  x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
45
  y: (batch_size, n_classes)
46
  """
47
- progress_bar = tqdm(range(self.epochs))
48
- for _ in progress_bar:
49
- # TODO: Introduce batches
50
- epoch_loss = 0
51
- pred_count = 0
52
- step = 0
53
- for (spec, bpm), label in zip(x, y):
54
- step += 1
55
- # find all models that are in the bpm range
56
- matching_dances = self.get_valid_dances_from_bpm(bpm)
57
- spec = torch.from_numpy(spec).to(self.device)
58
- for dance in matching_dances:
59
- if dance not in self.classifiers or dance not in self.optimizers:
60
- classifier = DanceCNN().to(self.device)
61
- self.classifiers[dance] = classifier
62
- self.optimizers[dance] = torch.optim.Adam(
63
- classifier.parameters(), lr=self.lr
64
- )
65
- models = [
66
- (dance, model, self.optimizers[dance])
67
- for dance, model in self.classifiers.items()
68
- if dance in matching_dances
69
- ]
70
- for model_i, (dance, model, opt) in enumerate(models):
71
- opt.zero_grad()
72
- output = model(spec)
73
- target = torch.tensor([float(dance == label)], device=self.device)
74
- loss = self.criterion(output, target)
75
- epoch_loss += loss.item()
76
- pred_count += 1
77
- loss.backward()
78
- opt.step()
79
- progress_bar.set_description(
80
- f"Loss: {epoch_loss / pred_count}, Step: {step}, Model: {model_i+1}/{len(models)}"
81
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def predict(self, x) -> list[str]:
84
  results = []
@@ -90,6 +95,52 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
90
  results.append(matching_dances[dance_i])
91
  return results
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  class DanceCNN(nn.Module):
95
  def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
@@ -136,7 +187,6 @@ def features_from_path(
136
  num_frames = audio_window_duration * sr
137
  tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
138
  spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
139
- mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
140
  spec_normalized = (spec - spec.mean()) / spec.std()
141
  spec_padded = librosa.util.fix_length(
142
  spec_normalized, size=sr * audio_duration, axis=1
@@ -145,3 +195,40 @@ def features_from_path(
145
  for i in range(audio_duration // audio_window_duration):
146
  spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
147
  yield (spec_window, tempo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
  from sklearn.base import ClassifierMixin, BaseEstimator
3
  import pandas as pd
4
  from torch import nn
 
6
  from typing import Iterator
7
  import numpy as np
8
  import json
9
+ from torch.utils.data import random_split
10
  from tqdm import tqdm
11
  import librosa
12
+ from joblib import dump, load
13
+ from os import path
14
+ import os
15
+
16
+ from preprocessing.dataset import get_music4dance_examples
17
 
18
  DANCE_INFO_FILE = "data/dance_info.csv"
19
  dance_info_df = pd.read_csv(
 
31
  - BPM
32
  """
33
 
34
+ def __init__(self, device="cpu", lr=1e-4, verbose=True) -> None:
35
  self.device = device
 
36
  self.verbose = verbose
37
  self.lr = lr
38
  self.classifiers = {}
 
50
  x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
51
  y: (batch_size, n_classes)
52
  """
53
+ epoch_loss = 0
54
+ pred_count = 0
55
+ data_loader = zip(x, y)
56
+ if self.verbose:
57
+ data_loader = tqdm(data_loader, total=len(y))
58
+ for (spec, bpm), label in data_loader:
59
+ # find all models that are in the bpm range
60
+ matching_dances = self.get_valid_dances_from_bpm(bpm)
61
+ spec = torch.from_numpy(spec).to(self.device)
62
+ for dance in matching_dances:
63
+ if dance not in self.classifiers or dance not in self.optimizers:
64
+ classifier = DanceCNN().to(self.device)
65
+ self.classifiers[dance] = classifier
66
+ self.optimizers[dance] = torch.optim.Adam(
67
+ classifier.parameters(), lr=self.lr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  )
69
+ models = [
70
+ (dance, model, self.optimizers[dance])
71
+ for dance, model in self.classifiers.items()
72
+ if dance in matching_dances
73
+ ]
74
+ for model_i, (dance, model, opt) in enumerate(models, start=1):
75
+ opt.zero_grad()
76
+ output = model(spec)
77
+ target = torch.tensor([float(dance == label)], device=self.device)
78
+ loss = self.criterion(output, target)
79
+ epoch_loss += loss.item()
80
+ pred_count += 1
81
+ loss.backward()
82
+ if self.verbose:
83
+ data_loader.set_description(
84
+ f"model: {model_i}/{len(models)}, loss: {loss.item()}"
85
+ )
86
+ opt.step()
87
 
88
  def predict(self, x) -> list[str]:
89
  results = []
 
95
  results.append(matching_dances[dance_i])
96
  return results
97
 
98
+ def save(self, folder: str):
99
+ # Create a folder
100
+ classifier_path = path.join(folder, "classifier")
101
+ os.makedirs(classifier_path, exist_ok=True)
102
+
103
+ # Swap out model reference
104
+ classifiers = self.classifiers
105
+ optimizers = self.optimizers
106
+ criterion = self.criterion
107
+
108
+ self.classifiers = None
109
+ self.optimizers = None
110
+ self.criterion = None
111
+
112
+ # Save the Pth models
113
+ for dance, classifier in classifiers.items():
114
+ torch.save(
115
+ classifier.state_dict(), path.join(classifier_path, dance + ".pth")
116
+ )
117
+
118
+ # Save the Sklearn model
119
+ dump(self, path.join(folder, "sklearn.joblib"))  # joblib.dump(value, filename)
120
+
121
+ # Reload values
122
+ self.classifiers = classifiers
123
+ self.optimizers = optimizers
124
+ self.criterion = criterion
125
+
126
+ @staticmethod
127
+ def from_config(folder: str, device="cpu") -> "DanceTreeClassifier":
128
+ # load in weights
129
+ model_paths = (
130
+ p for p in os.listdir(path.join(folder, "classifier")) if p.endswith("pth")
131
+ )
132
+ classifiers = {}
133
+ for model_path in model_paths:
134
+ dance = model_path.split(".")[0]
135
+ model = DanceCNN().to(device)
136
+ model.load_state_dict(
137
+ torch.load(path.join(folder, "classifier", model_path))
138
+ )
139
+ classifiers[dance] = model
140
+ wrapper = load(path.join(folder, "sklearn.joblib"))
141
+ wrapper.classifiers = classifiers
142
+ return wrapper
143
+
144
 
145
  class DanceCNN(nn.Module):
146
  def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
 
187
  num_frames = audio_window_duration * sr
188
  tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
189
  spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
 
190
  spec_normalized = (spec - spec.mean()) / spec.std()
191
  spec_padded = librosa.util.fix_length(
192
  spec_normalized, size=sr * audio_duration, axis=1
 
195
  for i in range(audio_duration // audio_window_duration):
196
  spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
197
  yield (spec_window, tempo)
198
+
199
+
200
+ def train_decision_tree(config: dict):
201
+ TARGET_CLASSES = config["global"]["dance_ids"]
202
+ DEVICE = config["global"]["device"]
203
+ SEED = config["global"]["seed"]
205
+ EPOCHS = config["trainer"]["min_epochs"]
206
+ song_data_path = config["data_module"]["song_data_path"]
207
+ song_audio_path = config["data_module"]["song_audio_path"]
208
+ pl.seed_everything(SEED, workers=True)
209
+
210
+ df = pd.read_csv(song_data_path)
211
+ x, y = get_music4dance_examples(
212
+ df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
213
+ )
214
+ # Convert y back to string classes
215
+ y = np.array(TARGET_CLASSES)[y.argmax(-1)]
216
+ train_i, test_i = random_split(
217
+ np.arange(len(x)), [0.1, 0.9]
218
+ ) # Temporary to test efficacy
219
+ train_paths, train_y = x[train_i], y[train_i]
220
+ model = DanceTreeClassifier(device=DEVICE)
221
+ for epoch in tqdm(range(1, EPOCHS + 1)):
222
+ # Shuffle the data
223
+ i = np.arange(len(train_paths))
224
+ np.random.shuffle(i)
225
+ train_paths = train_paths[i]
226
+ train_y = train_y[i]
227
+ train_x = features_from_path(train_paths)
228
+ model.fit(train_x, train_y)
229
+
230
+ # evaluate the model
231
+ preds = model.predict(x[test_i])
232
+ accuracy = (preds == y[test_i]).mean()
233
+ print(f"{accuracy=}")
234
+ model.save("models/weights/decision_tree")
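The new persistence API saves each per-dance `DanceCNN` as its own `.pth` state dict and pickles the sklearn wrapper with joblib after detaching the torch modules; `from_config` rebuilds the wrapper and reloads the weights. A round-trip sketch, assuming a fitted classifier and a hypothetical `test_x` iterable of `(spectrogram, bpm)` pairs such as `features_from_path` yields:

```python
# Sketch of the save/load round trip (the fit and predict calls are placeholders).
from models.decision_tree import DanceTreeClassifier

model = DanceTreeClassifier(device="cpu")
# model.fit(train_x, train_y)
model.save("models/weights/decision_tree")

restored = DanceTreeClassifier.from_config("models/weights/decision_tree", device="cpu")
# preds = restored.predict(test_x)
```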
models/residual.py CHANGED
@@ -1,18 +1,25 @@
 
 
1
  import torch
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- import pytorch_lightning as pl
6
  import numpy as np
7
  import torchaudio
8
  import yaml
9
- from .utils import calculate_metrics
10
- from preprocessing.pipelines import WaveformPreprocessing, AudioToSpectrogram
 
 
 
 
11
 
12
  # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
13
 
 
14
  class ResidualDancer(nn.Module):
15
- def __init__(self,n_channels=128, n_classes=50):
16
  super().__init__()
17
 
18
  self.n_channels = n_channels
@@ -25,17 +32,17 @@ class ResidualDancer(nn.Module):
25
  self.res_layers = nn.Sequential(
26
  ResBlock(1, n_channels, stride=2),
27
  ResBlock(n_channels, n_channels, stride=2),
28
- ResBlock(n_channels, n_channels*2, stride=2),
29
- ResBlock(n_channels*2, n_channels*2, stride=2),
30
- ResBlock(n_channels*2, n_channels*2, stride=2),
31
- ResBlock(n_channels*2, n_channels*2, stride=2),
32
- ResBlock(n_channels*2, n_channels*4, stride=2)
33
  )
34
 
35
  # Dense
36
- self.dense1 = nn.Linear(n_channels*4, n_channels*4)
37
- self.bn = nn.BatchNorm1d(n_channels*4)
38
- self.dense2 = nn.Linear(n_channels*4, n_classes)
39
  self.dropout = nn.Dropout(0.2)
40
 
41
  def forward(self, x):
@@ -56,24 +63,34 @@ class ResidualDancer(nn.Module):
56
  x = F.relu(x)
57
  x = self.dropout(x)
58
  x = self.dense2(x)
59
- x = nn.Sigmoid()(x)
60
 
61
  return x
62
-
63
 
64
  class ResBlock(nn.Module):
65
  def __init__(self, input_channels, output_channels, shape=3, stride=2):
66
  super().__init__()
67
  # convolution
68
- self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
 
 
69
  self.bn_1 = nn.BatchNorm2d(output_channels)
70
- self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
 
 
71
  self.bn_2 = nn.BatchNorm2d(output_channels)
72
 
73
  # residual
74
  self.diff = False
75
  if (stride != 1) or (input_channels != output_channels):
76
- self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
 
 
 
 
 
 
77
  self.bn_3 = nn.BatchNorm2d(output_channels)
78
  self.diff = True
79
  self.relu = nn.ReLU()
@@ -89,79 +106,31 @@ class ResBlock(nn.Module):
89
  out = self.relu(out)
90
  return out
91
 
92
- class TrainingEnvironment(pl.LightningModule):
93
-
94
- def __init__(self, model: nn.Module, criterion: nn.Module, config:dict, learning_rate=1e-4, *args, **kwargs):
95
- super().__init__(*args, **kwargs)
96
- self.model = model
97
- self.criterion = criterion
98
- self.learning_rate = learning_rate
99
- self.config=config
100
- self.save_hyperparameters({
101
- "model": type(model).__name__,
102
- "loss": type(criterion).__name__,
103
- "config": config,
104
- **kwargs
105
- })
106
-
107
- def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
108
- features, labels = batch
109
- outputs = self.model(features)
110
- loss = self.criterion(outputs, labels)
111
- metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
112
- self.log_dict(metrics, prog_bar=True)
113
- # Log spectrograms
114
- if batch_index % 100 == 0:
115
- tensorboard = self.logger.experiment
116
- img_index = torch.randint(0, len(features), (1,)).item()
117
- img = features[img_index][0]
118
- img = (img - img.min()) / (img.max() - img.min())
119
- tensorboard.add_image(f"batch: {batch_index}, element: {img_index}", img, 0, dataformats='HW')
120
- return loss
121
-
122
-
123
- def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
124
- x, y = batch
125
- preds = self.model(x)
126
- metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
127
- metrics["val/loss"] = self.criterion(preds, y)
128
- self.log_dict(metrics,prog_bar=True)
129
-
130
- def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
131
- x, y = batch
132
- preds = self.model(x)
133
- self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)
134
-
135
- def configure_optimizers(self):
136
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
137
- # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
138
- return [optimizer]
139
-
140
-
141
 
142
  class DancePredictor:
143
  def __init__(
144
- self,
145
- weight_path:str,
146
- labels:list[str],
147
- expected_duration=6,
148
  threshold=0.5,
149
  resample_frequency=16000,
150
- device="cpu"):
151
-
152
  super().__init__()
153
-
154
  self.expected_duration = expected_duration
155
  self.threshold = threshold
156
  self.resample_frequency = resample_frequency
157
- self.preprocess_waveform = WaveformPreprocessing(resample_frequency * expected_duration)
158
- self.audio_to_spectrogram = AudioToSpectrogram(resample_frequency)
 
 
159
  self.labels = np.array(labels)
160
  self.device = device
161
  self.model = self.get_model(weight_path)
162
 
163
-
164
- def get_model(self, weight_path:str) -> nn.Module:
165
  weights = torch.load(weight_path, map_location=self.device)["state_dict"]
166
  model = ResidualDancer(n_classes=len(self.labels))
167
  for key in list(weights):
@@ -170,21 +139,25 @@ class DancePredictor:
170
  return model.to(self.device).eval()
171
 
172
  @classmethod
173
- def from_config(cls, config_path:str) -> "DancePredictor":
174
  with open(config_path, "r") as f:
175
  config = yaml.safe_load(f)
176
  return DancePredictor(**config)
177
 
178
  @torch.no_grad()
179
- def __call__(self, waveform: np.ndarray, sample_rate:int) -> dict[str,float]:
180
  if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
181
- waveform = waveform.transpose(1,0)
182
  elif len(waveform.shape) == 1:
183
  waveform = np.expand_dims(waveform, 0)
184
  waveform = torch.from_numpy(waveform.astype("int16"))
185
- waveform = torchaudio.functional.apply_codec(waveform,sample_rate, "wav", channels_first=True)
 
 
186
 
187
- waveform = torchaudio.functional.resample(waveform, sample_rate,self.resample_frequency)
 
 
188
  waveform = self.preprocess_waveform(waveform)
189
  spectrogram = self.audio_to_spectrogram(waveform)
190
  spectrogram = spectrogram.unsqueeze(0).to(self.device)
@@ -194,8 +167,31 @@ class DancePredictor:
194
  result_mask = results > self.threshold
195
  probs = results[result_mask]
196
  dances = self.labels[result_mask]
197
-
198
- return {dance:float(prob) for dance, prob in zip(dances, probs)}
199
-
200
-
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning import callbacks as cb
3
  import torch
4
+ from torch import nn
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
 
 
8
  import numpy as np
9
  import torchaudio
10
  import yaml
11
+ from models.training_environment import TrainingEnvironment
12
+ from preprocessing.dataset import DanceDataModule, get_datasets
13
+ from preprocessing.pipelines import (
14
+ SpectrogramTrainingPipeline,
15
+ WaveformPreprocessing,
16
+ )
17
 
18
  # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
19
 
20
+
21
  class ResidualDancer(nn.Module):
22
+ def __init__(self, n_channels=128, n_classes=50):
23
  super().__init__()
24
 
25
  self.n_channels = n_channels
 
32
  self.res_layers = nn.Sequential(
33
  ResBlock(1, n_channels, stride=2),
34
  ResBlock(n_channels, n_channels, stride=2),
35
+ ResBlock(n_channels, n_channels * 2, stride=2),
36
+ ResBlock(n_channels * 2, n_channels * 2, stride=2),
37
+ ResBlock(n_channels * 2, n_channels * 2, stride=2),
38
+ ResBlock(n_channels * 2, n_channels * 2, stride=2),
39
+ ResBlock(n_channels * 2, n_channels * 4, stride=2),
40
  )
41
 
42
  # Dense
43
+ self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
44
+ self.bn = nn.BatchNorm1d(n_channels * 4)
45
+ self.dense2 = nn.Linear(n_channels * 4, n_classes)
46
  self.dropout = nn.Dropout(0.2)
47
 
48
  def forward(self, x):
 
63
  x = F.relu(x)
64
  x = self.dropout(x)
65
  x = self.dense2(x)
66
+ # x = nn.Sigmoid()(x)
67
 
68
  return x
69
+
70
 
71
  class ResBlock(nn.Module):
72
  def __init__(self, input_channels, output_channels, shape=3, stride=2):
73
  super().__init__()
74
  # convolution
75
+ self.conv_1 = nn.Conv2d(
76
+ input_channels, output_channels, shape, stride=stride, padding=shape // 2
77
+ )
78
  self.bn_1 = nn.BatchNorm2d(output_channels)
79
+ self.conv_2 = nn.Conv2d(
80
+ output_channels, output_channels, shape, padding=shape // 2
81
+ )
82
  self.bn_2 = nn.BatchNorm2d(output_channels)
83
 
84
  # residual
85
  self.diff = False
86
  if (stride != 1) or (input_channels != output_channels):
87
+ self.conv_3 = nn.Conv2d(
88
+ input_channels,
89
+ output_channels,
90
+ shape,
91
+ stride=stride,
92
+ padding=shape // 2,
93
+ )
94
  self.bn_3 = nn.BatchNorm2d(output_channels)
95
  self.diff = True
96
  self.relu = nn.ReLU()
 
106
  out = self.relu(out)
107
  return out
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  class DancePredictor:
111
  def __init__(
112
+ self,
113
+ weight_path: str,
114
+ labels: list[str],
115
+ expected_duration=6,
116
  threshold=0.5,
117
  resample_frequency=16000,
118
+ device="cpu",
119
+ ):
120
  super().__init__()
121
+
122
  self.expected_duration = expected_duration
123
  self.threshold = threshold
124
  self.resample_frequency = resample_frequency
125
+ self.preprocess_waveform = WaveformPreprocessing(
126
+ resample_frequency * expected_duration
127
+ )
128
+ self.audio_to_spectrogram = lambda x: x # TODO: Fix
129
  self.labels = np.array(labels)
130
  self.device = device
131
  self.model = self.get_model(weight_path)
132
 
133
+ def get_model(self, weight_path: str) -> nn.Module:
 
134
  weights = torch.load(weight_path, map_location=self.device)["state_dict"]
135
  model = ResidualDancer(n_classes=len(self.labels))
136
  for key in list(weights):
 
139
  return model.to(self.device).eval()
140
 
141
  @classmethod
142
+ def from_config(cls, config_path: str) -> "DancePredictor":
143
  with open(config_path, "r") as f:
144
  config = yaml.safe_load(f)
145
  return DancePredictor(**config)
146
 
147
  @torch.no_grad()
148
+ def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
149
  if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
150
+ waveform = waveform.transpose(1, 0)
151
  elif len(waveform.shape) == 1:
152
  waveform = np.expand_dims(waveform, 0)
153
  waveform = torch.from_numpy(waveform.astype("int16"))
154
+ waveform = torchaudio.functional.apply_codec(
155
+ waveform, sample_rate, "wav", channels_first=True
156
+ )
157
 
158
+ waveform = torchaudio.functional.resample(
159
+ waveform, sample_rate, self.resample_frequency
160
+ )
161
  waveform = self.preprocess_waveform(waveform)
162
  spectrogram = self.audio_to_spectrogram(waveform)
163
  spectrogram = spectrogram.unsqueeze(0).to(self.device)
 
167
  result_mask = results > self.threshold
168
  probs = results[result_mask]
169
  dances = self.labels[result_mask]
 
 
 
 
170
 
171
+ return {dance: float(prob) for dance, prob in zip(dances, probs)}
172
+
173
+
174
+ def train_residual_dancer(config: dict):
175
+ TARGET_CLASSES = config["dance_ids"]
176
+ DEVICE = config["device"]
177
+ SEED = config["seed"]
178
+ pl.seed_everything(SEED, workers=True)
179
+ feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
180
+ dataset = get_datasets(config["datasets"], feature_extractor)
181
+
182
+ data = DanceDataModule(dataset, **config["data_module"])
183
+ model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
184
+ label_weights = data.get_label_weights().to(DEVICE)
185
+ criterion = nn.CrossEntropyLoss(label_weights)
186
+
187
+ train_env = TrainingEnvironment(model, criterion, config)
188
+ callbacks = [
189
+ # cb.LearningRateFinder(update_attr=True),
190
+ cb.EarlyStopping("val/loss", patience=5),
191
+ cb.StochasticWeightAveraging(1e-2),
192
+ cb.RichProgressBar(),
193
+ cb.DeviceStatsMonitor(),
194
+ ]
195
+ trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
196
+ trainer.fit(train_env, datamodule=data)
197
+ trainer.test(train_env, datamodule=data)
models/training_environment.py ADDED
@@ -0,0 +1,90 @@
+from models.utils import calculate_metrics
+
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+
+class TrainingEnvironment(pl.LightningModule):
+    def __init__(
+        self,
+        model: nn.Module,
+        criterion: nn.Module,
+        config: dict,
+        learning_rate=1e-4,
+        log_spectrograms=False,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.model = model
+        self.criterion = criterion
+        self.learning_rate = learning_rate
+        self.log_spectrograms = log_spectrograms
+        self.config = config
+        self.has_multi_label_predictions = (
+            not type(criterion).__name__ == "CrossEntropyLoss"
+        )
+        self.save_hyperparameters(
+            {
+                "model": type(model).__name__,
+                "loss": type(criterion).__name__,
+                "config": config,
+                **kwargs,
+            }
+        )
+
+    def training_step(
+        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
+    ) -> torch.Tensor:
+        features, labels = batch
+        outputs = self.model(features)
+        loss = self.criterion(outputs, labels)
+        metrics = calculate_metrics(
+            outputs,
+            labels,
+            prefix="train/",
+            multi_label=self.has_multi_label_predictions,
+        )
+        self.log_dict(metrics, prog_bar=True)
+        # Log spectrograms
+        if self.log_spectrograms and batch_index % 100 == 0:
+            tensorboard = self.logger.experiment
+            img_index = torch.randint(0, len(features), (1,)).item()
+            img = features[img_index][0]
+            img = (img - img.min()) / (img.max() - img.min())
+            tensorboard.add_image(
+                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
+            )
+        return loss
+
+    def validation_step(
+        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
+    ):
+        x, y = batch
+        preds = self.model(x)
+        metrics = calculate_metrics(
+            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
+        )
+        metrics["val/loss"] = self.criterion(preds, y)
+        self.log_dict(metrics, prog_bar=True)
+
+    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
+        x, y = batch
+        preds = self.model(x)
+        self.log_dict(
+            calculate_metrics(
+                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
+            ),
+            prog_bar=True,
+        )
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "val/loss",
+        }
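`TrainingEnvironment` is now a standalone module shared by the CNN and transformer trainers: it switches its metrics to single-label mode when the criterion is `CrossEntropyLoss`, optionally logs spectrogram images, and pairs Adam with a `ReduceLROnPlateau` scheduler monitored on `val/loss`. A usage sketch with placeholder config values, wrapping `ResidualDancer`:

```python
# Usage sketch; the config dict, class count, and learning rate are placeholders.
import pytorch_lightning as pl
from torch import nn

from models.residual import ResidualDancer
from models.training_environment import TrainingEnvironment

config = {"training_environment": {"learning_rate": 0.00053}}
model = ResidualDancer(n_channels=128, n_classes=14)
criterion = nn.CrossEntropyLoss()  # single-label, so multi_label metrics are disabled
env = TrainingEnvironment(model, criterion, config, learning_rate=0.00053)

trainer = pl.Trainer(fast_dev_run=True)
# trainer.fit(env, datamodule=data)  # data: a DanceDataModule instance
```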
models/utils.py CHANGED
@@ -1,14 +1,20 @@
1
  import torch.nn as nn
2
  import torch
3
  import numpy as np
 
4
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
5
 
 
 
 
 
6
  class LabelWeightedBCELoss(nn.Module):
7
  """
8
  Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
9
  Allows for the weighing of each probability distribution wrt loss.
10
  """
11
- def __init__(self, label_weights:torch.Tensor, reduction="mean"):
 
12
  super().__init__()
13
  self.label_weights = label_weights
14
 
@@ -17,46 +23,67 @@ class LabelWeightedBCELoss(nn.Module):
17
  self.reduction = torch.mean
18
  case "sum":
19
  self.reduction = torch.sum
20
-
21
- def _log(self,x:torch.Tensor) -> torch.Tensor:
22
  return torch.clamp_min(torch.log(x), -100)
23
 
24
  def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
25
- losses = -self.label_weights * (target * self._log(input) + (1-target) * self._log(1-input))
 
 
26
  return self.reduction(losses)
27
 
28
 
29
  # TODO: Code a onehot
30
 
31
 
32
- def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
 
 
33
  target = target.detach().cpu().numpy()
34
  pred = pred.detach().cpu().numpy()
35
  params = {
36
- "y_true": target if multi_label else target.argmax(1) ,
37
- "y_pred": np.array(pred > threshold, dtype=float) if multi_label else pred.argmax(1),
38
- "zero_division": 0,
39
- "average":"macro"
40
- }
41
- metrics= {
42
- 'precision': precision_score(**params),
43
- 'recall': recall_score(**params),
44
- 'f1': f1_score(**params),
45
- 'accuracy': accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
46
- }
47
- return {prefix + k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
 
 
 
 
 
48
 
49
  class EarlyStopping:
50
  def __init__(self, patience=0):
51
  self.patience = patience
52
  self.last_measure = np.inf
53
  self.consecutive_increase = 0
54
-
55
  def step(self, val) -> bool:
56
  if self.last_measure <= val:
57
- self.consecutive_increase +=1
58
  else:
59
  self.consecutive_increase = 0
60
  self.last_measure = val
61
 
62
- return self.patience < self.consecutive_increase
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch.nn as nn
2
  import torch
3
  import numpy as np
4
+ import evaluate
5
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
6
 
7
+
8
+ accuracy = evaluate.load("accuracy")
9
+
10
+
11
  class LabelWeightedBCELoss(nn.Module):
12
  """
13
  Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
14
  Allows for the weighing of each probability distribution wrt loss.
15
  """
16
+
17
+ def __init__(self, label_weights: torch.Tensor, reduction="mean"):
18
  super().__init__()
19
  self.label_weights = label_weights
20
 
 
23
  self.reduction = torch.mean
24
  case "sum":
25
  self.reduction = torch.sum
26
+
27
+ def _log(self, x: torch.Tensor) -> torch.Tensor:
28
  return torch.clamp_min(torch.log(x), -100)
29
 
30
  def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
31
+ losses = -self.label_weights * (
32
+ target * self._log(input) + (1 - target) * self._log(1 - input)
33
+ )
34
  return self.reduction(losses)
35
 
36
 
37
  # TODO: Code a onehot
38
 
39
 
40
+ def calculate_metrics(
41
+ pred, target, threshold=0.5, prefix="", multi_label=True
42
+ ) -> dict[str, torch.Tensor]:
43
  target = target.detach().cpu().numpy()
44
  pred = pred.detach().cpu().numpy()
45
  params = {
46
+ "y_true": target if multi_label else target.argmax(1),
47
+ "y_pred": np.array(pred > threshold, dtype=float)
48
+ if multi_label
49
+ else pred.argmax(1),
50
+ "zero_division": 0,
51
+ "average": "macro",
52
+ }
53
+ metrics = {
54
+ "precision": precision_score(**params),
55
+ "recall": recall_score(**params),
56
+ "f1": f1_score(**params),
57
+ "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
58
+ }
59
+ return {
60
+ prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
61
+ }
62
+
63
 
64
  class EarlyStopping:
65
  def __init__(self, patience=0):
66
  self.patience = patience
67
  self.last_measure = np.inf
68
  self.consecutive_increase = 0
69
+
70
  def step(self, val) -> bool:
71
  if self.last_measure <= val:
72
+ self.consecutive_increase += 1
73
  else:
74
  self.consecutive_increase = 0
75
  self.last_measure = val
76
 
77
+ return self.patience < self.consecutive_increase
78
+
79
+
80
+ def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
81
+ id2label = {str(i): label for i, label in enumerate(labels)}
82
+ label2id = {label: str(i) for i, label in enumerate(labels)}
83
+
84
+ return id2label, label2id
85
+
86
+
87
+ def compute_hf_metrics(eval_pred):
88
+ predictions = np.argmax(eval_pred.predictions, axis=1)
89
+ return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
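`get_id_label_mapping` produces the string-keyed dictionaries that Hugging Face model configs expect, and `compute_hf_metrics` adapts the `evaluate` accuracy metric to the `Trainer`'s eval predictions. A small example of the mapping helper with a hypothetical label list:

```python
from models.utils import get_id_label_mapping

id2label, label2id = get_id_label_mapping(["CHA", "JIV", "SWZ"])
assert id2label == {"0": "CHA", "1": "JIV", "2": "SWZ"}
assert label2id == {"CHA": "0", "JIV": "1", "SWZ": "2"}
```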
models/wav2vec2.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any
3
+ import pytorch_lightning as pl
4
+ from torch.utils.data import random_split
5
+ from transformers import AutoFeatureExtractor
6
+ from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
7
+
8
+ from preprocessing.dataset import (
9
+ HuggingFaceDatasetWrapper,
10
+ BestBallroomDataset,
11
+ get_datasets,
12
+ )
13
+ from preprocessing.pipelines import WaveformTrainingPipeline
14
+
15
+ from .utils import get_id_label_mapping, compute_hf_metrics
16
+
17
+ MODEL_CHECKPOINT = "facebook/wav2vec2-base"
18
+
19
+
20
+ class Wav2VecFeatureExtractor:
21
+ def __init__(self) -> None:
22
+ self.waveform_pipeline = WaveformTrainingPipeline()
23
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
24
+ MODEL_CHECKPOINT,
25
+ )
26
+
27
+ def __call__(self, waveform) -> Any:
28
+ waveform = self.waveform_pipeline(waveform)
29
+ return self.feature_extractor(
30
+ waveform, sampling_rate=self.feature_extractor.sampling_rate
31
+ )
32
+
33
+ def __getattr__(self, attr):
34
+ return getattr(self.feature_extractor, attr)
35
+
36
+
37
+ def train_wav_model(config: dict):
38
+ TARGET_CLASSES = config["dance_ids"]
39
+ DEVICE = config["device"]
40
+ SEED = config["seed"]
41
+ OUTPUT_DIR = "models/weights/wav2vec2"
42
+ batch_size = config["data_module"]["batch_size"]
43
+ epochs = config["trainer"]["min_epochs"]
44
+ test_proportion = config["data_module"].get("test_proportion", 0.2)
45
+ pl.seed_everything(SEED, workers=True)
46
+ dataset = get_datasets(config["datasets"])
47
+ id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
48
+ test_proportion = config["data_module"]["test_proportion"]
49
+ train_proporition = 1 - test_proportion
50
+ train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
51
+ feature_extractor = Wav2VecFeatureExtractor()
52
+ model = AutoModelForAudioClassification.from_pretrained(
53
+ MODEL_CHECKPOINT,
54
+ num_labels=len(TARGET_CLASSES),
55
+ label2id=label2id,
56
+ id2label=id2label,
57
+ ignore_mismatched_sizes=True,
58
+ ).to(DEVICE)
59
+ training_args = TrainingArguments(
60
+ output_dir=OUTPUT_DIR,
61
+ evaluation_strategy="epoch",
62
+ save_strategy="epoch",
63
+ learning_rate=3e-5,
64
+ per_device_train_batch_size=batch_size,
65
+ gradient_accumulation_steps=5,
66
+ per_device_eval_batch_size=batch_size,
67
+ num_train_epochs=epochs,
68
+ warmup_ratio=0.1,
69
+ logging_steps=10,
70
+ load_best_model_at_end=True,
71
+ metric_for_best_model="accuracy",
72
+ push_to_hub=False,
73
+ use_mps_device=DEVICE == "mps",
74
+ )
75
+ trainer = Trainer(
76
+ model=model,
77
+ args=training_args,
78
+ train_dataset=train_ds,
79
+ eval_dataset=test_ds,
80
+ tokenizer=feature_extractor,
81
+ compute_metrics=compute_hf_metrics,
82
+ )
83
+ trainer.train()
84
+ return model
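`Wav2VecFeatureExtractor` chains the project's waveform pipeline with the pretrained `facebook/wav2vec2-base` extractor and forwards unknown attribute lookups to it through `__getattr__`, which is what lets the Hugging Face `Trainer` treat the wrapper as its tokenizer. A small sketch of that delegation (instantiating it downloads the extractor config, so network access is assumed):

```python
from models.wav2vec2 import Wav2VecFeatureExtractor

fe = Wav2VecFeatureExtractor()
# sampling_rate is not defined on the wrapper itself, so __getattr__ forwards
# the lookup to the underlying Wav2Vec2 feature extractor (16000 by default).
print(fe.sampling_rate)
```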
preprocessing/dataset.py CHANGED
@@ -1,15 +1,21 @@
 
 
 
1
  import torch
2
- from torch.utils.data import Dataset, DataLoader, random_split
3
  import numpy as np
4
  import pandas as pd
5
  import torchaudio as ta
6
- from .pipelines import AudioTrainingPipeline
7
  import pytorch_lightning as pl
8
- from .preprocess import get_examples
9
- from sklearn.model_selection import train_test_split
10
- from torchaudio import transforms as taT
11
- from torch import nn
12
- from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 
 
 
 
13
 
14
 
15
  class SongDataset(Dataset):
@@ -17,60 +23,67 @@ class SongDataset(Dataset):
17
  self,
18
  audio_paths: list[str],
19
  dance_labels: list[np.ndarray],
20
- audio_duration=30, # seconds
21
  audio_window_duration=6, # seconds
22
- audio_window_jitter=0.0, # seconds
23
- audio_pipeline_kwargs={},
24
- resample_frequency=16000,
25
  ):
26
- assert (
27
- audio_duration % audio_window_duration == 0
28
- ), "Audio window should divide duration evenly."
29
  assert (
30
  audio_window_duration > audio_window_jitter
31
  ), "Jitter should be a small fraction of the audio window duration."
32
 
33
  self.audio_paths = audio_paths
34
  self.dance_labels = dance_labels
35
- audio_info = ta.info(audio_paths[0])
36
- self.sample_rate = audio_info.sample_rate
 
 
 
37
  self.audio_window_duration = int(audio_window_duration)
 
38
  self.audio_window_jitter = audio_window_jitter
39
- self.audio_duration = int(audio_duration)
40
-
41
- self.audio_pipeline = AudioTrainingPipeline(
42
- self.sample_rate,
43
- resample_frequency,
44
- audio_window_duration,
45
- **audio_pipeline_kwargs,
46
- )
47
 
48
  def __len__(self):
49
- return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
 
 
 
 
 
50
 
51
  def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
52
- waveform = self._waveform_from_index(idx)
53
- assert (
54
- waveform.shape[1] > 10
55
- ), f"No data found: {self._backtrace_audio_path(idx)}"
56
- spectrogram = self.audio_pipeline(waveform)
57
 
 
58
  dance_labels = self._label_from_index(idx)
 
59
 
60
- example_is_valid = self._validate_output(spectrogram, dance_labels)
61
- if example_is_valid:
62
- return spectrogram, dance_labels
63
- else:
64
- # Try the previous one
65
- # This happens when some of the audio recordings are really quiet
66
- # This WILL NOT leak into other data partitions because songs belong entirely to a partition
67
- return self[idx - 1]
 
 
 
 
 
 
 
 
 
68
 
69
- def _convert_idx(self, idx: int) -> int:
70
- return idx * self.audio_window_duration // self.audio_duration
 
71
 
72
  def _backtrace_audio_path(self, index: int) -> str:
73
- return self.audio_paths[self._convert_idx(index)]
74
 
75
  def _validate_output(self, x, y):
76
  is_finite = not torch.any(torch.isinf(x))
@@ -80,16 +93,18 @@ class SongDataset(Dataset):
80
  return all((is_finite, is_numerical, has_data, is_binary))
81
 
82
  def _waveform_from_index(self, idx: int) -> torch.Tensor:
83
- audio_filepath = self.audio_paths[self._convert_idx(idx)]
84
- num_windows = self.audio_duration // self.audio_window_duration
85
- frame_index = idx % num_windows
86
  jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
87
  jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
88
  jitter = int(
89
  torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
90
  )
91
- frame_offset = (
92
- frame_index * self.audio_window_duration * self.sample_rate + jitter
 
 
93
  )
94
  num_frames = self.sample_rate * self.audio_window_duration
95
  waveform, sample_rate = ta.load(
@@ -101,41 +116,21 @@ class SongDataset(Dataset):
101
  return waveform
102
 
103
  def _label_from_index(self, idx: int) -> torch.Tensor:
104
- return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
105
 
106
 
107
- class WaveformSongDataset(SongDataset):
108
  """
109
- Outputs raw waveforms of the data instead of a spectrogram.
110
  """
111
 
112
- def __init__(self, *args, resample_frequency=16000, **kwargs):
113
  super().__init__(*args, **kwargs)
114
- self.resample_frequency = resample_frequency
115
- self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
116
  self.pipeline = []
117
 
118
  def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
119
- waveform = self._waveform_from_index(idx)
120
- assert (
121
- waveform.shape[1] > 10
122
- ), f"No data found: {self._backtrace_audio_path(idx)}"
123
- # resample the waveform
124
- waveform = self.resampler(waveform)
125
-
126
- waveform = waveform.mean(0)
127
-
128
- dance_labels = self._label_from_index(idx)
129
- return waveform, dance_labels
130
-
131
-
132
- class HuggingFaceWaveformSongDataset(WaveformSongDataset):
133
- def __init__(self, *args, **kwargs):
134
- super().__init__(*args, **kwargs)
135
- self.pipeline = []
136
-
137
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
138
- x, y = super().__getitem__(idx)
139
  if len(self.pipeline) > 0:
140
  for fn in self.pipeline:
141
  x = fn(x)
@@ -146,59 +141,158 @@ class HuggingFaceWaveformSongDataset(WaveformSongDataset):
146
  "label": dance_labels,
147
  }
148
 
149
- def map(self, fn):
 
 
 
150
  """
151
- NOTE this mutates the original, doesn't return a copy like normal maps.
152
  """
153
  self.pipeline.append(fn)
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  class DanceDataModule(pl.LightningDataModule):
157
  def __init__(
158
  self,
159
- song_data_path="data/songs_cleaned.csv",
160
- song_audio_path="data/samples",
161
  test_proportion=0.15,
162
  val_proportion=0.1,
163
  target_classes: list[str] = None,
164
- min_votes=1,
165
  batch_size: int = 64,
166
  num_workers=10,
167
- dataset_cls=None,
168
- dataset_kwargs={},
169
  ):
170
  super().__init__()
171
- self.song_data_path = song_data_path
172
- self.song_audio_path = song_audio_path
173
  self.val_proportion = val_proportion
174
  self.test_proportion = test_proportion
175
  self.train_proportion = 1.0 - test_proportion - val_proportion
176
  self.target_classes = target_classes
177
  self.batch_size = batch_size
178
  self.num_workers = num_workers
179
- self.dataset_kwargs = dataset_kwargs
180
- self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
181
-
182
- df = pd.read_csv(song_data_path)
183
- self.x, self.y = get_examples(
184
- df,
185
- self.song_audio_path,
186
- class_list=self.target_classes,
187
- multi_label=True,
188
- min_votes=min_votes,
189
- )
190
 
191
  def setup(self, stage: str):
192
- train_i, val_i, test_i = random_split(
193
- np.arange(len(self.x)),
194
  [self.train_proportion, self.val_proportion, self.test_proportion],
195
  )
196
- self.train_ds = self._dataset_from_indices(train_i)
197
- self.val_ds = self._dataset_from_indices(val_i)
198
- self.test_ds = self._dataset_from_indices(test_i)
199
-
200
- def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
201
- return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
202
 
203
  def train_dataloader(self):
204
  return DataLoader(
@@ -210,110 +304,48 @@ class DanceDataModule(pl.LightningDataModule):
210
 
211
  def val_dataloader(self):
212
  return DataLoader(
213
- self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers
 
 
214
  )
215
 
216
  def test_dataloader(self):
217
  return DataLoader(
218
- self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers
 
 
219
  )
220
 
221
  def get_label_weights(self):
222
- n_examples, n_classes = self.y.shape
223
- return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
 
 
224
 
225
 
226
- class WaveformTrainingEnvironment(pl.LightningModule):
227
- def __init__(
228
- self,
229
- model: nn.Module,
230
- criterion: nn.Module,
231
- feature_extractor,
232
- config: dict,
233
- learning_rate=1e-4,
234
- *args,
235
- **kwargs,
236
- ):
237
- super().__init__(*args, **kwargs)
238
- self.model = model
239
- self.criterion = criterion
240
- self.learning_rate = learning_rate
241
- self.config = config
242
- self.feature_extractor = feature_extractor
243
- self.save_hyperparameters(
244
- {
245
- "model": type(model).__name__,
246
- "loss": type(criterion).__name__,
247
- "config": config,
248
- **kwargs,
249
- }
250
- )
251
-
252
- def preprocess_inputs(self, x):
253
- device = x.device
254
- x = list(x.squeeze(1).cpu().numpy())
255
- x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
256
- return x["input_values"].to(device)
257
-
258
- def training_step(
259
- self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
260
- ) -> torch.Tensor:
261
- features, labels = batch
262
- features = self.preprocess_inputs(features)
263
- outputs = self.model(features).logits
264
- outputs = nn.Sigmoid()(
265
- outputs
266
- ) # good for multi label classification, should be softmax otherwise
267
- loss = self.criterion(outputs, labels)
268
- metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
269
- self.log_dict(metrics, prog_bar=True)
270
- return loss
271
-
272
- def validation_step(
273
- self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
274
- ):
275
- x, y = batch
276
- x = self.preprocess_inputs(x)
277
- preds = self.model(x).logits
278
- preds = nn.Sigmoid()(preds)
279
- metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
280
- metrics["val/loss"] = self.criterion(preds, y)
281
- self.log_dict(metrics, prog_bar=True)
282
-
283
- def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
284
- x, y = batch
285
- x = self.preprocess_inputs(x)
286
- preds = self.model(x).logits
287
- preds = nn.Sigmoid()(preds)
288
- self.log_dict(
289
- calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
290
- )
291
-
292
- def configure_optimizers(self):
293
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
294
- # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
295
- return [optimizer]
296
-
297
-
298
- def calculate_metrics(
299
- pred, target, threshold=0.5, prefix="", multi_label=True
300
- ) -> dict[str, torch.Tensor]:
301
- target = target.detach().cpu().numpy()
302
- pred = pred.detach().cpu().numpy()
303
- params = {
304
- "y_true": target if multi_label else target.argmax(1),
305
- "y_pred": np.array(pred > threshold, dtype=float)
306
- if multi_label
307
- else pred.argmax(1),
308
- "zero_division": 0,
309
- "average": "macro",
310
- }
311
- metrics = {
312
- "precision": precision_score(**params),
313
- "recall": recall_score(**params),
314
- "f1": f1_score(**params),
315
- "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
316
- }
317
- return {
318
- prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
319
- }
 
1
+ import importlib
2
+ import os
3
+ from typing import Any
4
  import torch
5
+ from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
6
  import numpy as np
7
  import pandas as pd
8
  import torchaudio as ta
 
9
  import pytorch_lightning as pl
10
+
11
+ from preprocessing.preprocess import (
12
+ fix_dance_rating_counts,
13
+ get_unique_labels,
14
+ has_valid_audio,
15
+ url_to_filename,
16
+ vectorize_label_probs,
17
+ vectorize_multi_label,
18
+ )
19
 
20
 
21
  class SongDataset(Dataset):
22
  def __init__(
23
  self,
24
  audio_paths: list[str],
25
  dance_labels: list[np.ndarray],
26
+ audio_start_offset=6, # seconds
27
  audio_window_duration=6, # seconds
28
+ audio_window_jitter=1.0, # seconds
 
 
29
  ):
 
 
 
30
  assert (
31
  audio_window_duration > audio_window_jitter
32
  ), "Jitter should be a small fraction of the audio window duration."
33
 
34
  self.audio_paths = audio_paths
35
  self.dance_labels = dance_labels
36
+ audio_metadata = [ta.info(audio) for audio in audio_paths]
37
+ self.audio_durations = [
38
+ meta.num_frames / meta.sample_rate for meta in audio_metadata
39
+ ]
40
+ self.sample_rate = audio_metadata[0].sample_rate # assuming same sample rate
41
  self.audio_window_duration = int(audio_window_duration)
42
+ self.audio_start_offset = audio_start_offset
43
  self.audio_window_jitter = audio_window_jitter
44
 
45
  def __len__(self):
46
+ return int(
47
+ sum(
48
+ max(duration - self.audio_start_offset, 0) // self.audio_window_duration
49
+ for duration in self.audio_durations
50
+ )
51
+ )
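As a quick sanity check of the window count above (the durations are hypothetical; the offset and window size are the defaults from __init__):

    # (186 - 6) // 6 + (31.5 - 6) // 6 + 0 = 30 + 4 + 0 = 34 windows
    durations = [186.0, 31.5, 4.0]
    windows = sum(max(d - 6, 0) // 6 for d in durations)
    assert int(windows) == 34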
52
 
53
  def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
54
+ if isinstance(idx, list):
55
+ return [
56
+ (self._waveform_from_index(i), self._label_from_index(i)) for i in idx
57
+ ]
 
58
 
59
+ waveform = self._waveform_from_index(idx)
60
  dance_labels = self._label_from_index(idx)
61
+ return waveform, dance_labels
62
 
63
+ def _idx2audio_idx(self, idx: int) -> int:
64
+ return self._get_audio_loc_from_idx(idx)[0]
65
+
66
+ def _get_audio_loc_from_idx(self, idx: int) -> tuple[int, int]:
67
+ """
68
+ Converts dataset index to the indices that reference the target audio path
69
+ and window offset.
70
+ """
71
+ total_slices = 0
72
+ for audio_index, duration in enumerate(self.audio_durations):
73
+ audio_slices = max(
74
+ (duration - self.audio_start_offset) // self.audio_window_duration, 1
75
+ )
76
+ if total_slices + audio_slices > idx:
77
+ frame_index = idx - total_slices
78
+ return audio_index, frame_index
79
+ total_slices += audio_slices
80
 
81
+ def get_label_weights(self):
82
+ n_examples, n_classes = self.dance_labels.shape
83
+ return torch.from_numpy(n_examples / (n_classes * sum(self.dance_labels)))
84
 
85
  def _backtrace_audio_path(self, index: int) -> str:
86
+ return self.audio_paths[self._idx2audio_idx(index)]
87
 
88
  def _validate_output(self, x, y):
89
  is_finite = not torch.any(torch.isinf(x))
 
93
  return all((is_finite, is_numerical, has_data, is_binary))
94
 
95
  def _waveform_from_index(self, idx: int) -> torch.Tensor:
96
+ audio_index, frame_index = self._get_audio_loc_from_idx(idx)
97
+ audio_filepath = self.audio_paths[audio_index]
98
+ num_windows = self.audio_durations[audio_index] // self.audio_window_duration
99
  jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
100
  jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
101
  jitter = int(
102
  torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
103
  )
104
+ frame_offset = int(
105
+ frame_index * self.audio_window_duration * self.sample_rate
106
+ + jitter
107
+ + self.audio_start_offset * self.sample_rate
108
  )
109
  num_frames = self.sample_rate * self.audio_window_duration
110
  waveform, sample_rate = ta.load(
 
116
  return waveform
117
 
118
  def _label_from_index(self, idx: int) -> torch.Tensor:
119
+ return torch.from_numpy(self.dance_labels[self._idx2audio_idx(idx)])
120
 
121
 
122
+ class HuggingFaceDatasetWrapper(Dataset):
123
  """
124
+ Makes a standard PyTorch Dataset compatible with a HuggingFace Trainer.
125
  """
126
 
127
+ def __init__(self, dataset, *args, **kwargs):
128
  super().__init__(*args, **kwargs)
129
+ self.dataset = dataset
 
130
  self.pipeline = []
131
 
132
  def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
133
+ x, y = self.dataset[idx]
134
  if len(self.pipeline) > 0:
135
  for fn in self.pipeline:
136
  x = fn(x)
 
141
  "label": dance_labels,
142
  }
143
 
144
+ def __len__(self):
145
+ return len(self.dataset)
146
+
147
+ def append_to_pipeline(self, fn):
148
  """
149
+ Adds a preprocessing step to the dataset.
150
  """
151
  self.pipeline.append(fn)
152
 
153
 
154
+ class BestBallroomDataset(Dataset):
155
+ def __init__(
156
+ self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
157
+ ) -> None:
158
+ super().__init__()
159
+ song_paths, labels = self.get_examples(audio_dir, class_list)
160
+ self.song_dataset = SongDataset(song_paths, labels, **kwargs)
161
+
162
+ def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
163
+ return self.song_dataset[index]
164
+
165
+ def __len__(self):
166
+ return len(self.song_dataset)
167
+
168
+ def get_examples(self, audio_dir, class_list=None):
169
+ dances = set(
170
+ f
171
+ for f in os.listdir(audio_dir)
172
+ if os.path.isdir(os.path.join(audio_dir, f))
173
+ )
174
+ common_dances = dances
175
+ if class_list is not None:
176
+ common_dances = dances & set(class_list)
177
+ dances = class_list
178
+ dances = np.array(sorted(dances))
179
+ song_paths = []
180
+ labels = []
181
+ for dance in common_dances:
182
+ dance_label = (dances == dance).astype("float32")
183
+ folder_path = os.path.join(audio_dir, dance)
184
+ folder_contents = [f for f in os.listdir(folder_path) if f.endswith(".wav")]
185
+ song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
186
+ labels.extend([dance_label] * len(folder_contents))
187
+
188
+ return np.array(song_paths), np.stack(labels)
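For reference, BestBallroomDataset derives labels purely from the folder layout: one subdirectory per dance, each holding .wav files. A hypothetical layout (the uppercase ids mirror what sort_yt_files in preprocessing/preprocess.py writes out; the specific folder and file names here are illustrative):

    data/ballroom-songs/
        SWZ/  waltz_01.wav, waltz_02.wav, ...
        VWZ/  viennese_01.wav, ...
        TGO/  tango_01.wav, ...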
189
+
190
+
191
+ class Music4DanceDataset(Dataset):
192
+ def __init__(
193
+ self,
194
+ song_data_path,
195
+ song_audio_path,
196
+ class_list=None,
197
+ multi_label=True,
198
+ min_votes=1,
199
+ **kwargs,
200
+ ) -> None:
201
+ super().__init__()
202
+ df = pd.read_csv(song_data_path)
203
+ song_paths, labels = get_music4dance_examples(
204
+ df,
205
+ song_audio_path,
206
+ class_list=class_list,
207
+ multi_label=multi_label,
208
+ min_votes=min_votes,
209
+ )
210
+ self.song_dataset = SongDataset(song_paths, labels, **kwargs)
211
+
212
+ def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
213
+ return self.song_dataset[index]
214
+
215
+ def __len__(self):
216
+ return len(self.song_dataset)
217
+
218
+
219
+ def get_music4dance_examples(
220
+ df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
221
+ ) -> tuple[np.ndarray, np.ndarray]:
222
+ sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
223
+ sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
224
+ if class_list is not None:
225
+ class_list = set(class_list)
226
+ sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
227
+ lambda labels: {k: v for k, v in labels.items() if k in class_list}
228
+ if not pd.isna(labels)
229
+ and any(label in class_list and amt > 0 for label, amt in labels.items())
230
+ else np.nan
231
+ )
232
+ sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
233
+ vote_mask = sampled_songs["DanceRating"].apply(
234
+ lambda dances: any(votes >= min_votes for votes in dances.values())
235
+ )
236
+ sampled_songs = sampled_songs[vote_mask]
237
+ labels = sampled_songs["DanceRating"].apply(
238
+ lambda dances: {
239
+ dance: votes for dance, votes in dances.items() if votes >= min_votes
240
+ }
241
+ )
242
+ unique_labels = np.array(get_unique_labels(labels))
243
+ vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
244
+ labels = labels.apply(lambda i: vectorizer(i, unique_labels))
245
+
246
+ audio_paths = [
247
+ os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
248
+ ]
249
+
250
+ return np.array(audio_paths), np.stack(labels)
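A minimal sketch of the vote filtering above, applied to a single parsed DanceRating entry (the dance keys and counts are made up; real entries come out of fix_dance_rating_counts):

    min_votes = 2
    parsed = {"SWZ": 3, "CHA": 1, "TGO": 0}  # hypothetical dance -> vote count dict
    keeps_row = any(votes >= min_votes for votes in parsed.values())  # True, so the row passes vote_mask
    kept = {dance: votes for dance, votes in parsed.items() if votes >= min_votes}
    # kept == {"SWZ": 3}, which is what gets vectorized into the label vector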
251
+
252
+
253
+ class PipelinedDataset(Dataset):
254
+ """
255
+ Adds a feature extractor preprocessing step to a dataset.
256
+ """
257
+
258
+ def __init__(self, dataset, feature_extractor):
259
+ self._data = dataset
260
+ self.feature_extractor = feature_extractor
261
+
262
+ def __len__(self):
263
+ return len(self._data)
264
+
265
+ def __getitem__(self, index):
266
+ sample, label = self._data[index]
267
+
268
+ features = self.feature_extractor(sample)
269
+ return features, label
270
+
271
+
272
  class DanceDataModule(pl.LightningDataModule):
273
  def __init__(
274
  self,
275
+ dataset: Dataset,
 
276
  test_proportion=0.15,
277
  val_proportion=0.1,
278
  target_classes: list[str] = None,
 
279
  batch_size: int = 64,
280
  num_workers=10,
 
 
281
  ):
282
  super().__init__()
 
 
283
  self.val_proportion = val_proportion
284
  self.test_proportion = test_proportion
285
  self.train_proportion = 1.0 - test_proportion - val_proportion
286
  self.target_classes = target_classes
287
  self.batch_size = batch_size
288
  self.num_workers = num_workers
289
+ self.dataset = dataset
290
 
291
  def setup(self, stage: str):
292
+ self.train_ds, self.val_ds, self.test_ds = random_split(
293
+ self.dataset,
294
  [self.train_proportion, self.val_proportion, self.test_proportion],
295
  )
296
 
297
  def train_dataloader(self):
298
  return DataLoader(
 
304
 
305
  def val_dataloader(self):
306
  return DataLoader(
307
+ self.val_ds,
308
+ batch_size=self.batch_size,
309
+ num_workers=self.num_workers,
310
  )
311
 
312
  def test_dataloader(self):
313
  return DataLoader(
314
+ self.test_ds,
315
+ batch_size=self.batch_size,
316
+ num_workers=self.num_workers,
317
  )
318
 
319
  def get_label_weights(self):
320
+ weights = [
321
+ ds.song_dataset.get_label_weights() for ds in self.dataset._data.datasets
322
+ ]
323
+ return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
324
 
325
 
326
+ def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
327
+ """
328
+ Estimates and prints the mean and standard deviation of a dataset's features, based on a random sample.
329
+ """
330
+ sample_size = int(np.ceil((zscore**2 * p * (1 - p)) / (moe**2)))
331
+ sample_indices = np.random.choice(
332
+ np.arange(len(dataset)), size=sample_size, replace=False
333
+ )
334
+ mean = 0
335
+ std = 0
336
+ for i in sample_indices:
337
+ features = dataset[i][0]
338
+ mean += features.mean().item()
339
+ std += features.std().item()
340
+ print("std", std / sample_size)
341
+ print("mean", mean / sample_size)
342
+
343
+
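The sample size above is the standard proportion-based estimate n = z^2 * p * (1 - p) / moe^2. A sketch with the default arguments:

    import numpy as np

    zscore, moe, p = 1.96, 0.02, 0.5
    sample_size = int(np.ceil((zscore**2 * p * (1 - p)) / (moe**2)))
    # 3.8416 * 0.25 / 0.0004 = 2401 in exact arithmetic, so roughly 2.4k
    # windows get sampled (floating-point rounding may nudge the ceil by one).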
344
+ def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
345
+ datasets = []
346
+ for dataset_path, kwargs in dataset_config.items():
347
+ module_name, class_name = dataset_path.rsplit(".", 1)
348
+ module = importlib.import_module(module_name)
349
+ ProvidedDataset = getattr(module, class_name)
350
+ datasets.append(ProvidedDataset(**kwargs))
351
+ return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
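A sketch of how the pieces above might be wired together; the dataset_config keys and kwargs are illustrative, not taken from the repo's YAML configs:

    from preprocessing.dataset import DanceDataModule, get_datasets
    from preprocessing.pipelines import SpectrogramTrainingPipeline

    dataset_config = {  # import path -> constructor kwargs (hypothetical values)
        "preprocessing.dataset.BestBallroomDataset": {"audio_dir": "data/ballroom-songs"},
    }
    feature_extractor = SpectrogramTrainingPipeline()
    dataset = get_datasets(dataset_config, feature_extractor)

    data = DanceDataModule(dataset, batch_size=32, num_workers=4)
    data.setup("fit")
    train_loader = data.train_dataloader()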
preprocessing/pipelines.py CHANGED
@@ -3,29 +3,26 @@ import torchaudio
3
  from torchaudio import transforms as taT, functional as taF
4
  import torch.nn as nn
5
 
6
- class AudioTrainingPipeline(torch.nn.Module):
7
- def __init__(self,
8
- input_freq=16000,
9
- resample_freq=16000,
10
- expected_duration=6,
11
- freq_mask_size=10,
12
- time_mask_size=80,
13
- mask_count = 2,
14
- snr_mean=6.0,
15
- noise_path=None):
16
  super().__init__()
17
  self.input_freq = input_freq
18
  self.snr_mean = snr_mean
19
- self.mask_count = mask_count
20
  self.noise = self.get_noise(noise_path)
21
- self.resample = taT.Resample(input_freq,resample_freq)
22
- self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
23
- self.audio_to_spectrogram = AudioToSpectrogram(
24
- sample_rate=resample_freq,
 
25
  )
26
- self.freq_mask = taT.FrequencyMasking(freq_mask_size)
27
- self.time_mask = taT.TimeMasking(time_mask_size)
28
-
29
 
30
  def get_noise(self, path) -> torch.Tensor:
31
  if path is None:
@@ -34,13 +31,15 @@ class AudioTrainingPipeline(torch.nn.Module):
34
  if noise.shape[0] > 1:
35
  noise = noise.mean(0, keepdim=True)
36
  if sr != self.input_freq:
37
- noise = taF.resample(noise,sr, self.input_freq)
38
  return noise
39
 
40
- def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
41
- assert self.noise is not None, "Cannot add noise because a noise file was not provided."
 
 
42
  num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
43
- noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
44
  noise_power = noise.norm(p=2)
45
  signal_power = waveform.norm(p=2)
46
  snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
@@ -49,14 +48,28 @@ class AudioTrainingPipeline(torch.nn.Module):
49
  noisy_waveform = (scale * waveform + noise) / 2
50
  return noisy_waveform
51
 
52
- def forward(self, waveform:torch.Tensor) -> torch.Tensor:
53
- try:
54
- waveform = self.resample(waveform)
55
- except:
56
- print("oops")
57
  waveform = self.preprocess_waveform(waveform)
58
  if self.noise is not None:
59
  waveform = self.add_noise(waveform)
60
  spec = self.audio_to_spectrogram(waveform)
61
 
62
  # Spectrogram augmentation
@@ -67,14 +80,11 @@ class AudioTrainingPipeline(torch.nn.Module):
67
 
68
 
69
  class WaveformPreprocessing(torch.nn.Module):
70
-
71
- def __init__(self, expected_sample_length:int):
72
  super().__init__()
73
  self.expected_sample_length = expected_sample_length
74
-
75
 
76
-
77
- def forward(self, waveform:torch.Tensor) -> torch.Tensor:
78
  # Take out extra channels
79
  if waveform.shape[0] > 1:
80
  waveform = waveform.mean(0, keepdim=True)
@@ -83,30 +93,34 @@ class WaveformPreprocessing(torch.nn.Module):
83
  waveform = self._rectify_duration(waveform)
84
  return waveform
85
 
86
-
87
- def _rectify_duration(self,waveform:torch.Tensor):
88
  expected_samples = self.expected_sample_length
89
  sample_count = waveform.shape[1]
90
  if expected_samples == sample_count:
91
  return waveform
92
  elif expected_samples > sample_count:
93
  pad_amount = expected_samples - sample_count
94
- return torch.nn.functional.pad(waveform, (0, pad_amount),mode="constant", value=0.0)
 
 
95
  else:
96
- return waveform[:,:expected_samples]
97
 
98
 
99
- class AudioToSpectrogram(torch.nn.Module):
100
  def __init__(
101
  self,
102
  sample_rate=16000,
103
  ):
104
- super().__init__()
105
-
106
- self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024) # TODO: Change mels to 64
107
  self.to_db = taT.AmplitudeToDB()
108
 
109
- def forward(self, waveform: torch.Tensor) -> torch.Tensor:
110
  spectrogram = self.spec(waveform)
111
  spectrogram = self.to_db(spectrogram)
112
- return spectrogram
 
3
  from torchaudio import transforms as taT, functional as taF
4
  import torch.nn as nn
5
 
6
+
7
+ class WaveformTrainingPipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ input_freq=16000,
11
+ resample_freq=16000,
12
+ expected_duration=6,
13
+ snr_mean=6.0,
14
+ noise_path=None,
15
+ ):
16
  super().__init__()
17
  self.input_freq = input_freq
18
  self.snr_mean = snr_mean
 
19
  self.noise = self.get_noise(noise_path)
20
+ self.resample_frequency = resample_freq
21
+ self.resample = taT.Resample(input_freq, resample_freq)
22
+
23
+ self.preprocess_waveform = WaveformPreprocessing(
24
+ resample_freq * expected_duration
25
  )
 
 
 
26
 
27
  def get_noise(self, path) -> torch.Tensor:
28
  if path is None:
 
31
  if noise.shape[0] > 1:
32
  noise = noise.mean(0, keepdim=True)
33
  if sr != self.input_freq:
34
+ noise = taF.resample(noise, sr, self.input_freq)
35
  return noise
36
 
37
+ def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
38
+ assert (
39
+ self.noise is not None
40
+ ), "Cannot add noise because a noise file was not provided."
41
  num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
42
+ noise = self.noise.repeat(1, num_repeats)[:, : waveform.shape[1]]
43
  noise_power = noise.norm(p=2)
44
  signal_power = waveform.norm(p=2)
45
  snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
 
48
  noisy_waveform = (scale * waveform + noise) / 2
49
  return noisy_waveform
50
 
51
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
52
+ waveform = self.resample(waveform)
 
 
 
53
  waveform = self.preprocess_waveform(waveform)
54
  if self.noise is not None:
55
  waveform = self.add_noise(waveform)
56
+ return waveform
57
+
58
+
59
+ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
60
+ def __init__(
61
+ self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
62
+ ):
63
+ super().__init__(*args, **kwargs)
64
+ self.mask_count = mask_count
65
+ self.audio_to_spectrogram = AudioToSpectrogram(
66
+ sample_rate=self.resample_frequency,
67
+ )
68
+ self.freq_mask = taT.FrequencyMasking(freq_mask_size)
69
+ self.time_mask = taT.TimeMasking(time_mask_size)
70
+
71
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
72
+ waveform = super().forward(waveform)
73
  spec = self.audio_to_spectrogram(waveform)
74
 
75
  # Spectrogram augmentation
 
80
 
81
 
82
  class WaveformPreprocessing(torch.nn.Module):
83
+ def __init__(self, expected_sample_length: int):
 
84
  super().__init__()
85
  self.expected_sample_length = expected_sample_length
 
86
 
87
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
 
88
  # Take out extra channels
89
  if waveform.shape[0] > 1:
90
  waveform = waveform.mean(0, keepdim=True)
 
93
  waveform = self._rectify_duration(waveform)
94
  return waveform
95
 
96
+ def _rectify_duration(self, waveform: torch.Tensor):
 
97
  expected_samples = self.expected_sample_length
98
  sample_count = waveform.shape[1]
99
  if expected_samples == sample_count:
100
  return waveform
101
  elif expected_samples > sample_count:
102
  pad_amount = expected_samples - sample_count
103
+ return torch.nn.functional.pad(
104
+ waveform, (0, pad_amount), mode="constant", value=0.0
105
+ )
106
  else:
107
+ return waveform[:, :expected_samples]
108
 
109
 
110
+ class AudioToSpectrogram:
111
  def __init__(
112
  self,
113
  sample_rate=16000,
114
  ):
115
+ self.spec = taT.MelSpectrogram(
116
+ sample_rate=sample_rate, n_mels=128, n_fft=1024
117
+ ) # Note: this doesn't work on mps right now.
118
  self.to_db = taT.AmplitudeToDB()
119
 
120
+ def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
121
  spectrogram = self.spec(waveform)
122
  spectrogram = self.to_db(spectrogram)
123
+
124
+ # Normalize
125
+ spectrogram = (spectrogram - spectrogram.mean()) / (2 * spectrogram.std())
126
+ return spectrogram
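A quick end-to-end check of the refactored pipeline, mirroring tests/test_pipelines.py but with a synthetic 6-second, 16 kHz mono waveform:

    import torch
    from preprocessing.pipelines import SpectrogramTrainingPipeline

    pipeline = SpectrogramTrainingPipeline()   # resample -> trim/pad -> (optional noise) -> mel spectrogram + masking
    waveform = torch.randn(1, 16000 * 6)       # synthetic audio standing in for a SongDataset window
    spec = pipeline(waveform)
    print(spec.shape)                          # (1, 128, time_frames) with the default n_mels=128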
preprocessing/preprocess.py CHANGED
@@ -3,7 +3,9 @@ import numpy as np
3
  import re
4
  import json
5
  from pathlib import Path
 
6
  import os
 
7
  import torchaudio
8
  import torch
9
  from tqdm import tqdm
@@ -95,7 +97,6 @@ def vectorize_label_probs(
95
  for k, v in labels.items():
96
  item_vec = (unique_labels == k) * v
97
  label_vec += item_vec
98
- lv_cache = label_vec.copy()
99
  label_vec[label_vec < 0] = 0
100
  label_vec /= label_vec.sum()
101
  assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
@@ -113,49 +114,70 @@ def vectorize_multi_label(
113
  return probs
114
 
115
 
116
- def get_examples(
117
- df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
118
- ) -> tuple[np.ndarray, np.ndarray]:
119
- sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
120
- sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
121
- if class_list is not None:
122
- class_list = set(class_list)
123
- sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
124
- lambda labels: {k: v for k, v in labels.items() if k in class_list}
125
- if not pd.isna(labels)
126
- and any(label in class_list and amt > 0 for label, amt in labels.items())
127
- else np.nan
128
- )
129
- sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
130
- vote_mask = sampled_songs["DanceRating"].apply(
131
- lambda dances: any(votes >= min_votes for votes in dances.values())
132
- )
133
- sampled_songs = sampled_songs[vote_mask]
134
- labels = sampled_songs["DanceRating"].apply(
135
- lambda dances: {
136
- dance: votes for dance, votes in dances.items() if votes >= min_votes
137
- }
138
- )
139
- unique_labels = np.array(get_unique_labels(labels))
140
- vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
141
- labels = labels.apply(lambda i: vectorizer(i, unique_labels))
142
-
143
- audio_paths = [
144
- os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
145
- ]
146
-
147
- return np.array(audio_paths), np.stack(labels)
148
 
149
 
150
  if __name__ == "__main__":
151
- links = pd.read_csv("data/backup_2.csv", index_col="index")
152
- df = pd.read_csv("data/songs.csv")
153
- l = links["link"].str.strip()
154
- l = l.apply(lambda url: url if "http" in url else np.nan)
155
- l = l.dropna()
156
- df["Sample"].update(l)
157
- addna = lambda url: url if type(url) == str and "http" in url else np.nan
158
- df["Sample"] = df["Sample"].apply(addna)
159
- is_valid = validate_audio(df["Sample"], "data/samples")
160
- df["valid"] = is_valid
161
- df.to_csv("data/songs_validated.csv")
 
3
  import re
4
  import json
5
  from pathlib import Path
6
+ import glob
7
  import os
8
+ import shutil
9
  import torchaudio
10
  import torch
11
  from tqdm import tqdm
 
97
  for k, v in labels.items():
98
  item_vec = (unique_labels == k) * v
99
  label_vec += item_vec
 
100
  label_vec[label_vec < 0] = 0
101
  label_vec /= label_vec.sum()
102
  assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
 
114
  return probs
115
 
116
 
117
+ def sort_yt_files(
118
+ aliases_path="data/dance_aliases.json",
119
+ all_dances_folder="data/best-ballroom-music",
120
+ original_location="data/yt-ballroom-music/",
121
+ ):
122
+ def normalize_string(s):
123
+ # Lowercase string and remove special characters
124
+ return re.sub(r"\W+", "", s.lower())
125
+
126
+ with open(aliases_path, "r") as f:
127
+ dances = json.load(f)
128
+
129
+ # Normalize the dance inputs and aliases
130
+ normalized_dances = {
131
+ normalize_string(dance_id): [normalize_string(alias) for alias in aliases]
132
+ for dance_id, aliases in dances.items()
133
+ }
134
+
135
+ # For every wav file in the target folder
136
+ bad_files = []
137
+ progress_bar = tqdm(os.listdir(all_dances_folder), unit="files moved")
138
+ for file_name in progress_bar:
139
+ if file_name.endswith(".wav"):
140
+ # check if the normalized wav file name contains the normalized dance alias
141
+ normalized_file_name = normalize_string(file_name)
142
+
143
+ matching_dance_ids = [
144
+ dance_id
145
+ for dance_id, aliases in normalized_dances.items()
146
+ if any(alias in normalized_file_name for alias in aliases)
147
+ ]
148
+
149
+ if len(matching_dance_ids) == 0:
150
+ # See if the dance is in the path
151
+ original_filename = file_name.replace(".wav", "")
152
+ matches = glob.glob(
153
+ os.path.join(original_location, "**", original_filename),
154
+ recursive=True,
155
+ )
156
+ if len(matches) == 1:
157
+ normalized_file_name = normalize_string(matches[0])
158
+ matching_dance_ids = [
159
+ dance_id
160
+ for dance_id, aliases in normalized_dances.items()
161
+ if any(alias in normalized_file_name for alias in aliases)
162
+ ]
163
+
164
+ if "swz" in matching_dance_ids and "vwz" in matching_dance_ids:
165
+ matching_dance_ids.remove("swz")
166
+ if len(matching_dance_ids) > 1 and "lhp" in matching_dance_ids:
167
+ matching_dance_ids.remove("lhp")
168
+
169
+ if len(matching_dance_ids) != 1:
170
+ bad_files.append(file_name)
171
+ progress_bar.set_description(f"bad files: {len(bad_files)}")
172
+ continue
173
+ dst = os.path.join("data", "ballroom-songs", matching_dance_ids[0].upper())
174
+ os.makedirs(dst, exist_ok=True)
175
+ filepath = os.path.join(all_dances_folder, file_name)
176
+ shutil.copy(filepath, os.path.join(dst, file_name))
177
+
178
+ with open("data/bad_files.json", "w") as f:
179
+ json.dump(bad_files, f)
180
 
181
 
182
  if __name__ == "__main__":
183
+ sort_yt_files()
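sort_yt_files assumes data/dance_aliases.json maps a short dance id to the alias strings that may appear in downloaded file names. A hypothetical entry (the "swz" and "vwz" ids appear in the disambiguation logic above; the alias lists themselves are illustrative):

    {
        "swz": ["waltz", "slow waltz"],
        "vwz": ["viennese waltz"]
    }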
tests.py DELETED
@@ -1,22 +0,0 @@
1
- import torchaudio
2
- import numpy as np
3
- from audio_utils import play_audio
4
- from preprocessing.dataset import SongDataset
5
-
6
- def test_audio_splitting():
7
-
8
-
9
-
10
- audio_paths = ["data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav"]
11
- labels = [np.array([1,0,1,0])]
12
- whole_song, sr = torchaudio.load("data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav")
13
-
14
- ds = SongDataset(audio_paths, labels)
15
- song_parts = (ds._waveform_from_index(i) for i in range(len(ds)))
16
- print("Sample Parts")
17
- for part in song_parts:
18
- play_audio(part,sr)
19
-
20
-
21
- print("Whole Sample")
22
- play_audio(whole_song,sr)
tests/test_datasets.py ADDED
@@ -0,0 +1,17 @@
1
+ from utils import set_path
2
+ import pytest
3
+
4
+ set_path()
5
+ from preprocessing.dataset import PipelinedDataset, BestBallroomDataset, SongDataset
6
+ import numpy as np
7
+
8
+
9
+ def test_preprocess_dataset():
10
+ dataset = BestBallroomDataset()
11
+ dataset = PipelinedDataset(dataset, lambda x: x * 0.0)
12
+ assert isinstance(dataset._data.song_dataset, SongDataset)
13
+ assert hasattr(dataset, "feature_extractor")
14
+ features, _ = dataset[0]
15
+ assert np.unique(features.numpy())[0] == 0.0
16
+ with pytest.raises(AttributeError):
17
+ dataset.foo
tests/test_pipelines.py ADDED
@@ -0,0 +1,13 @@
1
+ from utils import set_path
2
+
3
+ set_path()
4
+ from preprocessing.dataset import BestBallroomDataset
5
+ from preprocessing.pipelines import SpectrogramTrainingPipeline
6
+
7
+
8
+ def test_spectrogram_training_pipeline():
9
+ ds = BestBallroomDataset()
10
+ pipeline = SpectrogramTrainingPipeline()
11
+ waveform, _ = ds[0]
12
+ out = pipeline(waveform)
13
+ assert len(out.shape) == 3
tests/utils.py ADDED
1
+ import sys
2
+ import os
3
+
4
+
5
+ # Add parent directory to Python path
6
+ def set_path():
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
train.py CHANGED
@@ -1,49 +1,16 @@
1
- from torch.utils.data import DataLoader
2
- import pandas as pd
3
  from typing import Callable
4
- from torch import nn
5
- from torch.utils.data import SubsetRandomSampler
6
- from sklearn.model_selection import KFold
7
- import pytorch_lightning as pl
8
- from pytorch_lightning import callbacks as cb
9
- from models.utils import LabelWeightedBCELoss
10
- from models.audio_spectrogram_transformer import (
11
- train as train_audio_spectrogram_transformer,
12
- get_id_label_mapping,
13
- )
14
- from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
15
- from preprocessing.preprocess import get_examples
16
- from models.residual import ResidualDancer, TrainingEnvironment
17
- from models.decision_tree import DanceTreeClassifier, features_from_path
18
  import yaml
19
- from preprocessing.dataset import (
20
- DanceDataModule,
21
- WaveformSongDataset,
22
- HuggingFaceWaveformSongDataset,
23
- )
24
- from torch.utils.data import random_split
25
- import numpy as np
26
- from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
27
  from argparse import ArgumentParser
 
28
 
29
-
30
- import torch
31
- from torch import nn
32
- from sklearn.utils.class_weight import compute_class_weight
33
 
34
 
35
  def get_training_fn(id: str) -> Callable:
36
- match id:
37
- case "ast_ptl":
38
- return train_ast_lightning
39
- case "ast_hf":
40
- return train_ast
41
- case "residual_dancer":
42
- return train_model
43
- case "decision_tree":
44
- return train_decision_tree
45
- case _:
46
- raise Exception(f"Couldn't find a training function for '{id}'.")
47
 
48
 
49
  def get_config(filepath: str) -> dict:
@@ -52,141 +19,6 @@ def get_config(filepath: str) -> dict:
52
  return config
53
 
54
 
55
- def cross_validation(config, k=5):
56
- df = pd.read_csv("data/songs.csv")
57
- g_config = config["global"]
58
- batch_size = config["data_module"]["batch_size"]
59
- x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
60
- dataset = SongDataset(x, y)
61
- splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
62
- trainer = pl.Trainer(accelerator=g_config["device"])
63
- for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
64
- print(f"Fold {fold+1}")
65
- model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
66
- train_env = TrainingEnvironment(model, nn.BCELoss())
67
- train_sampler = SubsetRandomSampler(train_idx)
68
- test_sampler = SubsetRandomSampler(val_idx)
69
- train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
70
- test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
71
- trainer.fit(train_env, train_loader)
72
- trainer.test(train_env, test_loader)
73
-
74
-
75
- def train_model(config: dict):
76
- TARGET_CLASSES = config["global"]["dance_ids"]
77
- DEVICE = config["global"]["device"]
78
- SEED = config["global"]["seed"]
79
- pl.seed_everything(SEED, workers=True)
80
- data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
81
- model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
82
- label_weights = data.get_label_weights().to(DEVICE)
83
- criterion = LabelWeightedBCELoss(
84
- label_weights
85
- ) # nn.CrossEntropyLoss(label_weights)
86
- train_env = TrainingEnvironment(model, criterion, config)
87
- callbacks = [
88
- # cb.LearningRateFinder(update_attr=True),
89
- cb.EarlyStopping("val/loss", patience=5),
90
- cb.StochasticWeightAveraging(1e-2),
91
- cb.RichProgressBar(),
92
- cb.DeviceStatsMonitor(),
93
- ]
94
- trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
95
- trainer.fit(train_env, datamodule=data)
96
- trainer.test(train_env, datamodule=data)
97
-
98
-
99
- def train_ast(config: dict):
100
- TARGET_CLASSES = config["global"]["dance_ids"]
101
- DEVICE = config["global"]["device"]
102
- SEED = config["global"]["seed"]
103
- dataset_kwargs = config["data_module"]["dataset_kwargs"]
104
- test_proportion = config["data_module"].get("test_proportion", 0.2)
105
- train_proportion = 1.0 - test_proportion
106
- song_data_path = "data/songs_cleaned.csv"
107
- song_audio_path = "data/samples"
108
- pl.seed_everything(SEED, workers=True)
109
-
110
- df = pd.read_csv(song_data_path)
111
- x, y = get_examples(
112
- df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
113
- )
114
- train_i, test_i = random_split(
115
- np.arange(len(x)), [train_proportion, test_proportion]
116
- )
117
- train_ds = HuggingFaceWaveformSongDataset(
118
- x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
119
- )
120
- test_ds = HuggingFaceWaveformSongDataset(
121
- x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
122
- )
123
- train_audio_spectrogram_transformer(
124
- TARGET_CLASSES, train_ds, test_ds, device=DEVICE
125
- )
126
-
127
-
128
- def train_ast_lightning(config: dict):
129
- """
130
- work on integration between waveform dataset and environment. Should work for both HF and PTL.
131
- """
132
- TARGET_CLASSES = config["global"]["dance_ids"]
133
- DEVICE = config["global"]["device"]
134
- SEED = config["global"]["seed"]
135
- pl.seed_everything(SEED, workers=True)
136
- data = DanceDataModule(
137
- target_classes=TARGET_CLASSES,
138
- dataset_cls=WaveformSongDataset,
139
- **config["data_module"],
140
- )
141
- id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
142
- model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
143
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
144
-
145
- model = AutoModelForAudioClassification.from_pretrained(
146
- model_checkpoint,
147
- num_labels=len(label2id),
148
- label2id=label2id,
149
- id2label=id2label,
150
- ignore_mismatched_sizes=True,
151
- ).to(DEVICE)
152
- label_weights = data.get_label_weights().to(DEVICE)
153
- criterion = LabelWeightedBCELoss(
154
- label_weights
155
- ) # nn.CrossEntropyLoss(label_weights)
156
- train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
157
- callbacks = [
158
- # cb.LearningRateFinder(update_attr=True),
159
- cb.EarlyStopping("val/loss", patience=5),
160
- cb.StochasticWeightAveraging(1e-2),
161
- cb.RichProgressBar(),
162
- ]
163
- trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
164
- trainer.fit(train_env, datamodule=data)
165
- trainer.test(train_env, datamodule=data)
166
-
167
-
168
- def train_decision_tree(config: dict):
169
- TARGET_CLASSES = config["global"]["dance_ids"]
170
- DEVICE = config["global"]["device"]
171
- SEED = config["global"]["seed"]
172
- song_data_path = config["data_module"]["song_data_path"]
173
- song_audio_path = config["data_module"]["song_audio_path"]
174
- pl.seed_everything(SEED, workers=True)
175
-
176
- df = pd.read_csv(song_data_path)
177
- x, y = get_examples(
178
- df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
179
- )
180
- # Convert y back to string classes
181
- y = np.array(TARGET_CLASSES)[y.argmax(-1)]
182
- train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
183
- train_paths, train_y = x[train_i], y[train_i]
184
- train_x = features_from_path(train_paths)
185
- model = DanceTreeClassifier(device=DEVICE)
186
- model.fit(train_x, train_y)
187
- model.save()
188
-
189
-
190
  if __name__ == "__main__":
191
  parser = ArgumentParser(
192
  description="Trains models on the dance dataset and saves weights."
@@ -198,6 +30,7 @@ if __name__ == "__main__":
198
  )
199
  args = parser.parse_args()
200
  config = get_config(args.config)
201
- training_id = config["global"]["id"]
202
- train = get_training_fn(training_id)
 
203
  train(config)
 
 
 
1
  from typing import Callable
2
+ import importlib
3
  import yaml
4
  from argparse import ArgumentParser
5
+ import os
6
 
7
+ ROOT_DIR = os.path.basename(os.path.dirname(__file__))
 
 
 
8
 
9
 
10
  def get_training_fn(id: str) -> Callable:
11
+ module_name, fn_name = id.rsplit(".", 1)
12
+ module = importlib.import_module("models." + module_name, ROOT_DIR)
13
+ return getattr(module, fn_name)
14
 
15
 
16
  def get_config(filepath: str) -> dict:
 
19
  return config
20
 
21
 
22
  if __name__ == "__main__":
23
  parser = ArgumentParser(
24
  description="Trains models on the dance dataset and saves weights."
 
30
  )
31
  args = parser.parse_args()
32
  config = get_config(args.config)
33
+ training_fn_path = config["training_fn"]
34
+ print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
35
+ train = get_training_fn(training_fn_path)
36
  train(config)
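With this refactor, the trainer is resolved from the config rather than a hard-coded match statement. A minimal, hypothetical config fragment (the function name below is a placeholder, and real configs also carry the model and data settings consumed by the selected training function):

    # config.yaml (illustrative)
    training_fn: residual.train_residual_dancer   # resolved as models.residual.train_residual_dancer

get_training_fn splits the id on the last ".", imports "models." plus the module path via importlib, and returns the named function.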