Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
•
c914273
0
Parent(s):
first commit
Browse files- .gitattributes +1 -0
- .gitignore +5 -0
- app.py +107 -0
- assets/song-samples/alejandro.wav +3 -0
- assets/song-samples/exs_and_ohs.wav +3 -0
- assets/song-samples/take_it_to_the_limit.wav +3 -0
- dancer_net/dancer_net.py +85 -0
- environment.yml +20 -0
- main.py +46 -0
- preprocessing/dataset.py +49 -0
- preprocessing/preprocess.py +104 -0
- scrapers/music4dance.py +113 -0
- train.py +215 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.DS_Store
|
3 |
+
data
|
4 |
+
logs
|
5 |
+
gradio_cached_examples
|
app.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from preprocessing.preprocess import AudioPipeline
|
6 |
+
from preprocessing.preprocess import AudioPipeline
|
7 |
+
from dancer_net.dancer_net import ShortChunkCNN
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
from functools import cache
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
@cache
|
14 |
+
def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
|
15 |
+
model_path = "logs/20221226-230930"
|
16 |
+
weights = os.path.join(model_path, "dancer_net.pt")
|
17 |
+
config_path = os.path.join(model_path, "config.json")
|
18 |
+
|
19 |
+
with open(config_path) as f:
|
20 |
+
config = json.load(f)
|
21 |
+
labels = np.array(sorted(config["classes"]))
|
22 |
+
|
23 |
+
model = ShortChunkCNN(n_class=len(labels))
|
24 |
+
model.load_state_dict(torch.load(weights))
|
25 |
+
model = model.to(device).eval()
|
26 |
+
return model, labels
|
27 |
+
|
28 |
+
@cache
|
29 |
+
def get_pipeline(sample_rate:int) -> AudioPipeline:
|
30 |
+
return AudioPipeline(input_freq=sample_rate)
|
31 |
+
|
32 |
+
@cache
|
33 |
+
def get_dance_map() -> dict:
|
34 |
+
df = pd.read_csv("data/dance_mapping.csv")
|
35 |
+
return df.set_index("id").to_dict()["name"]
|
36 |
+
|
37 |
+
|
38 |
+
def predict(audio: tuple[int, np.ndarray]) -> list[str]:
|
39 |
+
sample_rate, waveform = audio
|
40 |
+
|
41 |
+
expected_duration = 6
|
42 |
+
threshold = 0.5
|
43 |
+
sample_len = sample_rate * expected_duration
|
44 |
+
device = "mps"
|
45 |
+
|
46 |
+
audio_pipeline = get_pipeline(sample_rate)
|
47 |
+
model, labels = get_model(device)
|
48 |
+
|
49 |
+
if sample_len > len(waveform):
|
50 |
+
raise gr.Error("You must record for at least 6 seconds")
|
51 |
+
if len(waveform.shape) > 1 and waveform.shape[1] > 1:
|
52 |
+
waveform = waveform.transpose(1,0)
|
53 |
+
waveform = waveform.mean(axis=0, keepdims=True)
|
54 |
+
else:
|
55 |
+
waveform = np.expand_dims(waveform, 0)
|
56 |
+
waveform = waveform[: ,:sample_len]
|
57 |
+
waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
|
58 |
+
waveform = waveform.astype("float32")
|
59 |
+
waveform = torch.from_numpy(waveform)
|
60 |
+
spectrogram = audio_pipeline(waveform)
|
61 |
+
spectrogram = spectrogram.unsqueeze(0).to(device)
|
62 |
+
|
63 |
+
with torch.no_grad():
|
64 |
+
results = model(spectrogram)
|
65 |
+
dance_mapping = get_dance_map()
|
66 |
+
results = results.squeeze(0).detach().cpu().numpy()
|
67 |
+
result_mask = results > threshold
|
68 |
+
probs = results[result_mask]
|
69 |
+
dances = labels[result_mask]
|
70 |
+
|
71 |
+
return {dance_mapping[dance_id]:float(prob) for dance_id, prob in zip(dances, probs)} if len(dances) else "Couldn't find a dance."
|
72 |
+
|
73 |
+
|
74 |
+
def demo():
|
75 |
+
title = "Dance Classifier"
|
76 |
+
description = "Record 6 seconds of a song and find out what dance fits the music."
|
77 |
+
with gr.Blocks() as app:
|
78 |
+
gr.Markdown(f"# {title}")
|
79 |
+
gr.Markdown(description)
|
80 |
+
with gr.Tab("Record Song"):
|
81 |
+
mic_audio = gr.Audio(source="microphone", label="Song Recording")
|
82 |
+
mic_submit = gr.Button("Predict")
|
83 |
+
|
84 |
+
with gr.Tab("Upload Song") as t:
|
85 |
+
audio_file = gr.Audio(label="Song Audio File")
|
86 |
+
audio_file_submit = gr.Button("Predict")
|
87 |
+
song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
|
88 |
+
example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
|
89 |
+
|
90 |
+
labels = gr.Label(label="Dances")
|
91 |
+
|
92 |
+
gr.Markdown("## Examples")
|
93 |
+
gr.Examples(
|
94 |
+
examples=example_audio,
|
95 |
+
inputs=audio_file,
|
96 |
+
outputs=labels,
|
97 |
+
fn=predict,
|
98 |
+
)
|
99 |
+
|
100 |
+
audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
|
101 |
+
mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
|
102 |
+
|
103 |
+
return app
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
demo().launch()
|
assets/song-samples/alejandro.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85f9a65fc4adb1fc0cbdbfafb7f7268a0934d97a120110d3f3a43375e59cba54
|
3 |
+
size 5292078
|
assets/song-samples/exs_and_ohs.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e53fe157ff687b5464e98c7d0c03d0712527c3a7ed24b6b063a328fcf7bf608
|
3 |
+
size 5292082
|
assets/song-samples/take_it_to_the_limit.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c69e0eeb4321c44daaaaf95dd596b1d813b9f7e9b5ef4ac5ae9fe11878d4b13b
|
3 |
+
size 5292082
|
dancer_net/dancer_net.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchaudio import transforms as taT, functional as taF
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
DEVICE = "mps"
|
9 |
+
class ShortChunkCNN(nn.Module):
|
10 |
+
def __init__(self,
|
11 |
+
n_channels=128,
|
12 |
+
sample_rate=16000,
|
13 |
+
n_class=50):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# Spectrogram
|
17 |
+
self.spec_bn = nn.BatchNorm2d(1)
|
18 |
+
|
19 |
+
# CNN
|
20 |
+
self.res_layers = nn.Sequential(
|
21 |
+
Res_2d(1, n_channels, stride=2),
|
22 |
+
Res_2d(n_channels, n_channels, stride=2),
|
23 |
+
Res_2d(n_channels, n_channels*2, stride=2),
|
24 |
+
Res_2d(n_channels*2, n_channels*2, stride=2),
|
25 |
+
Res_2d(n_channels*2, n_channels*2, stride=2),
|
26 |
+
Res_2d(n_channels*2, n_channels*2, stride=2),
|
27 |
+
Res_2d(n_channels*2, n_channels*4, stride=2)
|
28 |
+
)
|
29 |
+
|
30 |
+
# Dense
|
31 |
+
self.dense1 = nn.Linear(n_channels*4, n_channels*4)
|
32 |
+
self.bn = nn.BatchNorm1d(n_channels*4)
|
33 |
+
self.dense2 = nn.Linear(n_channels*4, n_class)
|
34 |
+
self.dropout = nn.Dropout(0.3)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.spec_bn(x)
|
38 |
+
|
39 |
+
# CNN
|
40 |
+
x = self.res_layers(x)
|
41 |
+
x = x.squeeze(2)
|
42 |
+
|
43 |
+
# Global Max Pooling
|
44 |
+
if x.size(-1) != 1:
|
45 |
+
x = nn.MaxPool1d(x.size(-1))(x)
|
46 |
+
x = x.squeeze(2)
|
47 |
+
|
48 |
+
# Dense
|
49 |
+
x = self.dense1(x)
|
50 |
+
x = self.bn(x)
|
51 |
+
x = F.relu(x)
|
52 |
+
x = self.dropout(x)
|
53 |
+
x = self.dense2(x)
|
54 |
+
x = nn.Sigmoid()(x)
|
55 |
+
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class Res_2d(nn.Module):
|
60 |
+
def __init__(self, input_channels, output_channels, shape=3, stride=2):
|
61 |
+
super().__init__()
|
62 |
+
# convolution
|
63 |
+
self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
|
64 |
+
self.bn_1 = nn.BatchNorm2d(output_channels)
|
65 |
+
self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
|
66 |
+
self.bn_2 = nn.BatchNorm2d(output_channels)
|
67 |
+
|
68 |
+
# residual
|
69 |
+
self.diff = False
|
70 |
+
if (stride != 1) or (input_channels != output_channels):
|
71 |
+
self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
|
72 |
+
self.bn_3 = nn.BatchNorm2d(output_channels)
|
73 |
+
self.diff = True
|
74 |
+
self.relu = nn.ReLU()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
# convolution
|
78 |
+
out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
|
79 |
+
|
80 |
+
# residual
|
81 |
+
if self.diff:
|
82 |
+
x = self.bn_3(self.conv_3(x))
|
83 |
+
out = x + out
|
84 |
+
out = self.relu(out)
|
85 |
+
return out
|
environment.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: dancer-net
|
2 |
+
channels:
|
3 |
+
- anaconda
|
4 |
+
- conda-forge
|
5 |
+
dependencies:
|
6 |
+
- torchvision
|
7 |
+
- pytorch
|
8 |
+
- numpy
|
9 |
+
- pandas
|
10 |
+
- seaborn
|
11 |
+
- python=3.10
|
12 |
+
- matplotlib
|
13 |
+
- torchaudio
|
14 |
+
- bs4
|
15 |
+
- requests
|
16 |
+
- bidict
|
17 |
+
- tqdm
|
18 |
+
- pip
|
19 |
+
- gradio
|
20 |
+
prefix: /opt/homebrew/Caskroom/miniforge/base/envs/dancer-net
|
main.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
from preprocessing.preprocess import AudioPipeline
|
3 |
+
from dancer_net.dancer_net import ShortChunkCNN
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
|
11 |
+
audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
|
12 |
+
seconds = 6
|
13 |
+
model_path = "logs/20221226-230930"
|
14 |
+
weights = os.path.join(model_path, "dancer_net.pt")
|
15 |
+
config_path = os.path.join(model_path, "config.json")
|
16 |
+
device = "mps"
|
17 |
+
threshold = 0.5
|
18 |
+
|
19 |
+
with open(config_path) as f:
|
20 |
+
config = json.load(f)
|
21 |
+
labels = np.array(sorted(config["classes"]))
|
22 |
+
|
23 |
+
audio_pipeline = AudioPipeline()
|
24 |
+
waveform, sample_rate = torchaudio.load(audio_file)
|
25 |
+
waveform = waveform[:, :seconds * sample_rate]
|
26 |
+
spectrogram = audio_pipeline(waveform)
|
27 |
+
spectrogram = spectrogram.unsqueeze(0).to(device)
|
28 |
+
|
29 |
+
model = ShortChunkCNN(n_class=len(labels))
|
30 |
+
model.load_state_dict(torch.load(weights))
|
31 |
+
model = model.to(device).eval()
|
32 |
+
|
33 |
+
with torch.no_grad():
|
34 |
+
results = model(spectrogram)
|
35 |
+
results = results.squeeze(0).detach().cpu().numpy()
|
36 |
+
results = results > threshold
|
37 |
+
results = labels[results]
|
38 |
+
print(results)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
preprocessing/dataset.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
import torchaudio as ta
|
5 |
+
from .preprocess import AudioPipeline
|
6 |
+
|
7 |
+
|
8 |
+
class SongDataset(Dataset):
|
9 |
+
def __init__(self,
|
10 |
+
audio_paths: list[str],
|
11 |
+
dance_labels: list[np.ndarray],
|
12 |
+
audio_duration=30, # seconds
|
13 |
+
audio_window_duration=6, # seconds
|
14 |
+
):
|
15 |
+
assert audio_duration % audio_window_duration == 0, "Audio window should divide duration evenly."
|
16 |
+
|
17 |
+
self.audio_paths = audio_paths
|
18 |
+
self.dance_labels = dance_labels
|
19 |
+
audio_info = ta.info(audio_paths[0])
|
20 |
+
self.sample_rate = audio_info.sample_rate
|
21 |
+
self.audio_window_duration = int(audio_window_duration)
|
22 |
+
self.audio_duration = int(audio_duration)
|
23 |
+
|
24 |
+
self.audio_pipeline = AudioPipeline(input_freq=self.sample_rate)
|
25 |
+
|
26 |
+
def __len__(self):
|
27 |
+
return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
|
28 |
+
|
29 |
+
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
|
30 |
+
waveform = self._waveform_from_index(idx)
|
31 |
+
spectrogram = self.audio_pipeline(waveform)
|
32 |
+
|
33 |
+
dance_labels = self._label_from_index(idx)
|
34 |
+
|
35 |
+
return spectrogram, dance_labels
|
36 |
+
|
37 |
+
|
38 |
+
def _waveform_from_index(self, idx:int) -> torch.Tensor:
|
39 |
+
audio_file_idx = idx * self.audio_window_duration // self.audio_duration
|
40 |
+
frame_offset = idx % self.audio_duration // self.audio_window_duration
|
41 |
+
num_frames = self.sample_rate * self.audio_window_duration
|
42 |
+
waveform, sample_rate = ta.load(self.audio_paths[audio_file_idx], frame_offset=frame_offset, num_frames=num_frames)
|
43 |
+
assert sample_rate == self.sample_rate, f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
|
44 |
+
return waveform
|
45 |
+
|
46 |
+
|
47 |
+
def _label_from_index(self, idx:int) -> torch.Tensor:
|
48 |
+
label_idx = idx * self.audio_window_duration // self.audio_duration
|
49 |
+
return torch.from_numpy(self.dance_labels[label_idx])
|
preprocessing/preprocess.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import re
|
4 |
+
import json
|
5 |
+
from pathlib import Path
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import torchaudio.transforms as taT
|
9 |
+
|
10 |
+
def url_to_filename(url:str) -> str:
|
11 |
+
return f"{url.split('/')[-1]}.wav"
|
12 |
+
|
13 |
+
def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
|
14 |
+
audio_urls = df["Sample"].replace(".", np.nan)
|
15 |
+
audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
|
16 |
+
valid_audio = audio_urls.apply(lambda url : url is not np.nan and url_to_filename(url) in audio_files)
|
17 |
+
df = df[valid_audio]
|
18 |
+
return df
|
19 |
+
|
20 |
+
def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
|
21 |
+
tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
|
22 |
+
dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
|
23 |
+
def fix_labels(labels:dict) -> dict | float:
|
24 |
+
new_labels = {}
|
25 |
+
for k, v in labels.items():
|
26 |
+
match = tag_pattern.search(k)
|
27 |
+
if match is None:
|
28 |
+
new_labels[k] = new_labels.get(k, 0) + v
|
29 |
+
else:
|
30 |
+
k = match[1]
|
31 |
+
sign = 1 if match[2] == '+' else -1
|
32 |
+
scale = int(match[3])
|
33 |
+
new_labels[k] = new_labels.get(k, 0) + v * scale * sign
|
34 |
+
valid = any(v > 0 for v in new_labels.values())
|
35 |
+
return new_labels if valid else np.nan
|
36 |
+
return dance_ratings.apply(fix_labels)
|
37 |
+
|
38 |
+
|
39 |
+
def get_unique_labels(dance_labels:pd.Series) -> list:
|
40 |
+
labels = set()
|
41 |
+
for dances in dance_labels:
|
42 |
+
labels |= set(dances)
|
43 |
+
return sorted(labels)
|
44 |
+
|
45 |
+
def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
|
46 |
+
"""
|
47 |
+
Turns label dict into probability distribution vector based on each label count.
|
48 |
+
"""
|
49 |
+
label_vec = np.zeros((len(unique_labels),), dtype="float32")
|
50 |
+
for k, v in labels.items():
|
51 |
+
item_vec = (unique_labels == k) * v
|
52 |
+
label_vec += item_vec
|
53 |
+
lv_cache = label_vec.copy()
|
54 |
+
label_vec[label_vec<0] = 0
|
55 |
+
label_vec /= label_vec.sum()
|
56 |
+
assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
|
57 |
+
return label_vec
|
58 |
+
|
59 |
+
def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
|
60 |
+
"""
|
61 |
+
Turns label dict into binary label vectors for multi-label classification.
|
62 |
+
"""
|
63 |
+
probs = vectorize_label_probs(labels,unique_labels)
|
64 |
+
probs[probs > 0.0] = 1.0
|
65 |
+
return probs
|
66 |
+
|
67 |
+
def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[str], list[np.ndarray]]:
|
68 |
+
sampled_songs = get_songs_with_audio(df, audio_dir)
|
69 |
+
sampled_songs.loc[:,"DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
|
70 |
+
if class_list is not None:
|
71 |
+
class_list = set(class_list)
|
72 |
+
sampled_songs.loc[:,"DanceRating"] = sampled_songs["DanceRating"].apply(
|
73 |
+
lambda labels : {k: v for k,v in labels.items() if k in class_list}
|
74 |
+
if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
|
75 |
+
else np.nan)
|
76 |
+
sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
|
77 |
+
labels = sampled_songs["DanceRating"]
|
78 |
+
unique_labels = np.array(get_unique_labels(labels))
|
79 |
+
labels = labels.apply(lambda i : vectorize_multi_label(i, unique_labels))
|
80 |
+
|
81 |
+
audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]
|
82 |
+
|
83 |
+
return audio_paths, list(labels)
|
84 |
+
|
85 |
+
class AudioPipeline(torch.nn.Module):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
input_freq=16000,
|
89 |
+
resample_freq=16000,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
|
93 |
+
self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
|
94 |
+
self.to_db = taT.AmplitudeToDB()
|
95 |
+
|
96 |
+
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
|
97 |
+
if waveform.shape[0] > 1:
|
98 |
+
waveform = waveform.mean(0, keepdim=True)
|
99 |
+
waveform = self.resample(waveform)
|
100 |
+
spectrogram = self.spec(waveform)
|
101 |
+
spectrogram = self.to_db(spectrogram)
|
102 |
+
|
103 |
+
return spectrogram
|
104 |
+
|
scrapers/music4dance.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup as bs
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
from pathlib import Path
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
import re
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def scrape_song_library(page_count=2054) -> pd.DataFrame:
|
15 |
+
columns = [
|
16 |
+
"Title",
|
17 |
+
"Artist",
|
18 |
+
"Length",
|
19 |
+
"Tempo",
|
20 |
+
"Beat",
|
21 |
+
"Energy",
|
22 |
+
"Danceability",
|
23 |
+
"Valence",
|
24 |
+
"Sample",
|
25 |
+
"Tags",
|
26 |
+
"DanceRating",
|
27 |
+
]
|
28 |
+
song_df = pd.DataFrame(columns=columns)
|
29 |
+
for i in tqdm(range(1, page_count + 1), desc="Pages processed"):
|
30 |
+
link = "https://www.music4dance.net/song/Index?filter=v2-Index&page=" + str(i)
|
31 |
+
page = requests.get(link)
|
32 |
+
soup = bs(page.content, "html.parser")
|
33 |
+
songs = pd.DataFrame(get_songs(soup))
|
34 |
+
song_df = pd.concat([song_df, songs], axis=0, ignore_index=True)
|
35 |
+
return song_df
|
36 |
+
|
37 |
+
|
38 |
+
def get_songs(soup: bs) -> dict:
|
39 |
+
js_obj = re.compile(r"{(.|\n)*}")
|
40 |
+
reset_keys = [
|
41 |
+
"Title",
|
42 |
+
"Artist",
|
43 |
+
"Length",
|
44 |
+
"Tempo",
|
45 |
+
"Beat",
|
46 |
+
"Energy",
|
47 |
+
"Danceability",
|
48 |
+
"Valence",
|
49 |
+
"Sample",
|
50 |
+
]
|
51 |
+
song_text = [str(v) for v in soup.find_all("script") if "histories" in str(v)][0]
|
52 |
+
songs_data = json.loads(js_obj.search(song_text).group(0))
|
53 |
+
songs = []
|
54 |
+
for song_data in songs_data["histories"]:
|
55 |
+
song = {"Tags": set(), "DanceRating": {}}
|
56 |
+
for feature in song_data["properties"]:
|
57 |
+
if "name" not in feature or "value" not in feature:
|
58 |
+
continue
|
59 |
+
key = feature["name"]
|
60 |
+
value = feature["value"]
|
61 |
+
if key in reset_keys:
|
62 |
+
song[key] = value
|
63 |
+
elif key == "Tag+":
|
64 |
+
song["Tags"].add(value)
|
65 |
+
elif key == "DeleteTag":
|
66 |
+
try:
|
67 |
+
song["Tags"].remove(value)
|
68 |
+
except:
|
69 |
+
continue
|
70 |
+
elif key == "DanceRating":
|
71 |
+
dance = value.replace("+1", "")
|
72 |
+
prev = song["DanceRating"].get(dance, 0)
|
73 |
+
song["DanceRating"][dance] = prev + 1
|
74 |
+
songs.append(song)
|
75 |
+
return songs
|
76 |
+
|
77 |
+
|
78 |
+
def download_song(url: str, out_dir: str):
|
79 |
+
response = requests.get(url)
|
80 |
+
filename = url.split("/")[-1]
|
81 |
+
out_file = Path(out_dir, f"{filename}.mp3")
|
82 |
+
with open(out_file, "wb") as f:
|
83 |
+
f.write(response.content)
|
84 |
+
|
85 |
+
def scrape_dance_info() -> pd.DataFrame:
|
86 |
+
js_obj = re.compile(r"{(.|\n)*}")
|
87 |
+
link = "https://www.music4dance.net/song/Index?filter=v2-Index"
|
88 |
+
page = requests.get(link)
|
89 |
+
soup = bs(page.content, "html.parser")
|
90 |
+
|
91 |
+
dance_info_text = [str(v) for v in soup.find_all("script") if "environment" in str(v)][0]
|
92 |
+
dance_info = json.loads(js_obj.search(dance_info_text).group(0))
|
93 |
+
dance_info = dance_info["dances"]
|
94 |
+
wanted_keys = ["name", "id", "synonyms", "tempoRange", "songCount"]
|
95 |
+
dance_df = pd.DataFrame([{k:v for k, v in dance.items() if k in wanted_keys}
|
96 |
+
for dance
|
97 |
+
in dance_info])
|
98 |
+
return dance_df
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
parser = argparse.ArgumentParser()
|
104 |
+
parser.add_argument("--page-count", default=2, type=int)
|
105 |
+
parser.add_argument("--out", default="data/song.csv")
|
106 |
+
|
107 |
+
args = parser.parse_args()
|
108 |
+
out_path = Path(args.out)
|
109 |
+
out_dir = os.path.dirname(out_path)
|
110 |
+
if not os.path.exists(out_dir):
|
111 |
+
print(f"Output location does not exist: {out_dir}")
|
112 |
+
df = scrape_song_library(args.page_count)
|
113 |
+
df.to_csv(out_path)
|
train.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import torch.nn as nn
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
from torch.utils.data import random_split, SubsetRandomSampler
|
10 |
+
import json
|
11 |
+
from sklearn.model_selection import KFold
|
12 |
+
|
13 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
14 |
+
from preprocessing.dataset import SongDataset
|
15 |
+
from preprocessing.preprocess import get_examples
|
16 |
+
from dancer_net.dancer_net import ShortChunkCNN
|
17 |
+
|
18 |
+
DEVICE = "mps"
|
19 |
+
SEED = 42
|
20 |
+
|
21 |
+
def get_timestamp() -> str:
|
22 |
+
return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
23 |
+
|
24 |
+
class EarlyStopping:
|
25 |
+
def __init__(self, patience=0):
|
26 |
+
self.patience = patience
|
27 |
+
self.last_measure = np.inf
|
28 |
+
self.consecutive_increase = 0
|
29 |
+
|
30 |
+
def step(self, val) -> bool:
|
31 |
+
if self.last_measure <= val:
|
32 |
+
self.consecutive_increase +=1
|
33 |
+
else:
|
34 |
+
self.consecutive_increase = 0
|
35 |
+
self.last_measure = val
|
36 |
+
|
37 |
+
return self.patience < self.consecutive_increase
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def calculate_metrics(pred, target, threshold=0.5, prefix=""):
|
42 |
+
target = target.detach().cpu().numpy()
|
43 |
+
pred = pred.detach().cpu().numpy()
|
44 |
+
pred = np.array(pred > threshold, dtype=float)
|
45 |
+
metrics= {
|
46 |
+
'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
|
47 |
+
'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
|
48 |
+
'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
|
49 |
+
'accuracy': accuracy_score(y_true=target, y_pred=pred),
|
50 |
+
}
|
51 |
+
if prefix != "":
|
52 |
+
metrics = {prefix + k : v for k, v in metrics.items()}
|
53 |
+
|
54 |
+
return metrics
|
55 |
+
|
56 |
+
|
57 |
+
def evaluate(model:nn.Module, data_loader:DataLoader, criterion, device="mps") -> pd.Series:
|
58 |
+
val_metrics = []
|
59 |
+
for features, labels in (prog_bar := tqdm(data_loader)):
|
60 |
+
features = features.to(device)
|
61 |
+
labels = labels.to(device)
|
62 |
+
with torch.no_grad():
|
63 |
+
outputs = model(features)
|
64 |
+
loss = criterion(outputs, labels)
|
65 |
+
batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
|
66 |
+
batch_metrics["val_loss"] = loss.item()
|
67 |
+
prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
|
68 |
+
val_metrics.append(batch_metrics)
|
69 |
+
return pd.DataFrame(val_metrics).mean()
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def train(
|
74 |
+
model: nn.Module,
|
75 |
+
data_loader: DataLoader,
|
76 |
+
val_loader=None,
|
77 |
+
epochs=3,
|
78 |
+
lr=1e-3,
|
79 |
+
device="mps"):
|
80 |
+
criterion = nn.BCELoss()
|
81 |
+
optimizer = torch.optim.Adam(model.parameters(),lr=lr)
|
82 |
+
early_stop = EarlyStopping(1)
|
83 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
|
84 |
+
steps_per_epoch=int(len(data_loader)),
|
85 |
+
epochs=epochs,
|
86 |
+
anneal_strategy='linear')
|
87 |
+
metrics = []
|
88 |
+
for epoch in range(1,epochs+1):
|
89 |
+
train_metrics = []
|
90 |
+
prog_bar = tqdm(data_loader)
|
91 |
+
for features, labels in prog_bar:
|
92 |
+
features = features.to(device)
|
93 |
+
labels = labels.to(device)
|
94 |
+
optimizer.zero_grad()
|
95 |
+
outputs = model(features)
|
96 |
+
loss = criterion(outputs, labels)
|
97 |
+
loss.backward()
|
98 |
+
optimizer.step()
|
99 |
+
scheduler.step()
|
100 |
+
batch_metrics = calculate_metrics(outputs, labels)
|
101 |
+
batch_metrics["loss"] = loss.item()
|
102 |
+
train_metrics.append(batch_metrics)
|
103 |
+
prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
|
104 |
+
train_metrics = pd.DataFrame(train_metrics).mean()
|
105 |
+
if val_loader is not None:
|
106 |
+
val_metrics = evaluate(model, val_loader, criterion)
|
107 |
+
if early_stop.step(val_metrics["val_f1"]):
|
108 |
+
break
|
109 |
+
epoch_metrics = pd.concat([train_metrics, val_metrics], axis=0)
|
110 |
+
else:
|
111 |
+
epoch_metrics = train_metrics
|
112 |
+
metrics.append(dict(epoch_metrics))
|
113 |
+
|
114 |
+
return model, metrics
|
115 |
+
|
116 |
+
|
117 |
+
def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
|
118 |
+
target_classes = ['ATN',
|
119 |
+
'BBA',
|
120 |
+
'BCH',
|
121 |
+
'BLU',
|
122 |
+
'CHA',
|
123 |
+
'CMB',
|
124 |
+
'CSG',
|
125 |
+
'ECS',
|
126 |
+
'HST',
|
127 |
+
'JIV',
|
128 |
+
'LHP',
|
129 |
+
'QST',
|
130 |
+
'RMB',
|
131 |
+
'SFT',
|
132 |
+
'SLS',
|
133 |
+
'SMB',
|
134 |
+
'SWZ',
|
135 |
+
'TGO',
|
136 |
+
'VWZ',
|
137 |
+
'WCS']
|
138 |
+
df = pd.read_csv("data/songs.csv")
|
139 |
+
x,y = get_examples(df, "data/samples",class_list=target_classes)
|
140 |
+
|
141 |
+
dataset = SongDataset(x,y)
|
142 |
+
splits=KFold(n_splits=k,shuffle=True,random_state=seed)
|
143 |
+
metrics = []
|
144 |
+
for fold, (train_idx,val_idx) in enumerate(splits.split(x,y)):
|
145 |
+
print(f"Fold {fold+1}")
|
146 |
+
|
147 |
+
train_sampler = SubsetRandomSampler(train_idx)
|
148 |
+
test_sampler = SubsetRandomSampler(val_idx)
|
149 |
+
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
|
150 |
+
test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
|
151 |
+
n_classes = len(y[0])
|
152 |
+
model = ShortChunkCNN(n_class=n_classes).to(device)
|
153 |
+
model, _ = train(model,train_loader, epochs=2, device=device)
|
154 |
+
val_metrics = evaluate(model, test_loader, nn.BCELoss())
|
155 |
+
metrics.append(val_metrics)
|
156 |
+
metrics = pd.DataFrame(metrics)
|
157 |
+
log_dir = os.path.join(
|
158 |
+
"logs", get_timestamp()
|
159 |
+
)
|
160 |
+
os.makedirs(log_dir, exist_ok=True)
|
161 |
+
|
162 |
+
metrics.to_csv(model.state_dict(), os.path.join(log_dir, "cross_val.csv"))
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
def train_model():
|
167 |
+
target_classes = ['ATN',
|
168 |
+
'BBA',
|
169 |
+
'BCH',
|
170 |
+
'BLU',
|
171 |
+
'CHA',
|
172 |
+
'CMB',
|
173 |
+
'CSG',
|
174 |
+
'ECS',
|
175 |
+
'HST',
|
176 |
+
'JIV',
|
177 |
+
'LHP',
|
178 |
+
'QST',
|
179 |
+
'RMB',
|
180 |
+
'SFT',
|
181 |
+
'SLS',
|
182 |
+
'SMB',
|
183 |
+
'SWZ',
|
184 |
+
'TGO',
|
185 |
+
'VWZ',
|
186 |
+
'WCS']
|
187 |
+
df = pd.read_csv("data/songs.csv")
|
188 |
+
x,y = get_examples(df, "data/samples",class_list=target_classes)
|
189 |
+
dataset = SongDataset(x,y)
|
190 |
+
train_count = int(len(dataset) * 0.9)
|
191 |
+
datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
|
192 |
+
data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
|
193 |
+
train_data, val_data = data_loaders
|
194 |
+
example_spec, example_label = dataset[0]
|
195 |
+
n_classes = len(example_label)
|
196 |
+
model = ShortChunkCNN(n_class=n_classes).to(DEVICE)
|
197 |
+
model, metrics = train(model,train_data, val_data, epochs=3, device=DEVICE)
|
198 |
+
|
199 |
+
log_dir = os.path.join(
|
200 |
+
"logs", get_timestamp()
|
201 |
+
)
|
202 |
+
os.makedirs(log_dir, exist_ok=True)
|
203 |
+
|
204 |
+
torch.save(model.state_dict(), os.path.join(log_dir, "dancer_net.pt"))
|
205 |
+
metrics = pd.DataFrame(metrics)
|
206 |
+
metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
|
207 |
+
config = {
|
208 |
+
"classes": target_classes
|
209 |
+
}
|
210 |
+
with open(os.path.join(log_dir, "config.json")) as f:
|
211 |
+
json.dump(config, f)
|
212 |
+
print("Training information saved!")
|
213 |
+
|
214 |
+
if __name__ == "__main__":
|
215 |
+
cross_validation()
|