waidhoferj commited on
Commit
e82ec2b
1 Parent(s): 3b31903

fixed pandas set on copy error

Browse files
Files changed (1) hide show
  1. preprocessing/preprocess.py +64 -33
preprocessing/preprocess.py CHANGED
@@ -8,22 +8,28 @@ import torchaudio
8
  import torch
9
  from tqdm import tqdm
10
 
11
- def url_to_filename(url:str) -> str:
 
12
  return f"{url.split('/')[-1]}.wav"
13
 
14
- def has_valid_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
 
15
  audio_urls = audio_urls.replace(".", np.nan)
16
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
17
- valid_audio_mask = audio_urls.apply(lambda url : url is not np.nan and url_to_filename(url) in audio_files)
 
 
18
  return valid_audio_mask
19
 
20
- def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
 
21
  """
22
- Tests audio urls to ensure that their file exists and the contents is valid.
23
  """
24
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
 
25
  def is_valid(url):
26
- valid_url = type(url) == str and "http" in url
27
  if not valid_url:
28
  return False
29
  filename = url_to_filename(url)
@@ -33,23 +39,29 @@ def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
33
  w, _ = torchaudio.load(os.path.join(audio_dir, filename))
34
  except:
35
  return False
36
- contents_invalid = torch.any(torch.isnan(w)) or torch.any(torch.isinf(w)) or len(torch.unique(w)) <= 2
 
 
 
 
37
  return not contents_invalid
38
-
39
  idxs = []
40
  validations = []
41
- for index, url in tqdm(audio_urls.items(), total=len(audio_urls), desc="Audio URLs Validated"):
 
 
42
  idxs.append(index)
43
  validations.append(is_valid(url))
44
 
45
  return pd.Series(validations, index=idxs)
46
-
47
-
48
 
49
- def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
 
50
  tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
51
- dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
52
- def fix_labels(labels:dict) -> dict | float:
 
53
  new_labels = {}
54
  for k, v in labels.items():
55
  match = tag_pattern.search(k)
@@ -57,21 +69,25 @@ def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
57
  new_labels[k] = new_labels.get(k, 0) + v
58
  else:
59
  k = match[1]
60
- sign = 1 if match[2] == '+' else -1
61
  scale = int(match[3])
62
  new_labels[k] = new_labels.get(k, 0) + v * scale * sign
63
  valid = any(v > 0 for v in new_labels.values())
64
  return new_labels if valid else np.nan
 
65
  return dance_ratings.apply(fix_labels)
66
 
67
 
68
- def get_unique_labels(dance_labels:pd.Series) -> list:
69
  labels = set()
70
  for dances in dance_labels:
71
  labels |= set(dances)
72
  return sorted(labels)
73
 
74
- def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
 
 
 
75
  """
76
  Turns label dict into probability distribution vector based on each label count.
77
  """
@@ -80,37 +96,53 @@ def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np
80
  item_vec = (unique_labels == k) * v
81
  label_vec += item_vec
82
  lv_cache = label_vec.copy()
83
- label_vec[label_vec<0] = 0
84
  label_vec /= label_vec.sum()
85
  assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
86
  return label_vec
87
 
88
- def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
 
 
 
89
  """
90
  Turns label dict into binary label vectors for multi-label classification.
91
  """
92
- probs = vectorize_label_probs(labels,unique_labels)
93
  probs[probs > 0.0] = 1.0
94
  return probs
95
 
96
- def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
97
- sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)]
 
 
 
98
  sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
99
  if class_list is not None:
100
  class_list = set(class_list)
101
  sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
102
- lambda labels : {k: v for k,v in labels.items() if k in class_list}
103
- if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
104
- else np.nan)
 
 
105
  sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
106
- vote_mask = sampled_songs["DanceRating"].apply(lambda dances: any(votes >= min_votes for votes in dances.values()))
 
 
107
  sampled_songs = sampled_songs[vote_mask]
108
- labels = sampled_songs["DanceRating"].apply(lambda dances : {dance: votes for dance, votes in dances.items() if votes >= min_votes})
 
 
 
 
109
  unique_labels = np.array(get_unique_labels(labels))
110
  vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
111
- labels = labels.apply(lambda i : vectorizer(i, unique_labels))
112
 
113
- audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]
 
 
114
 
115
  return np.array(audio_paths), np.stack(labels)
116
 
@@ -119,12 +151,11 @@ if __name__ == "__main__":
119
  links = pd.read_csv("data/backup_2.csv", index_col="index")
120
  df = pd.read_csv("data/songs.csv")
121
  l = links["link"].str.strip()
122
- l = l.apply(lambda url : url if "http" in url else np.nan)
123
  l = l.dropna()
124
  df["Sample"].update(l)
125
- addna = lambda url : url if type(url) == str and "http" in url else np.nan
126
  df["Sample"] = df["Sample"].apply(addna)
127
- is_valid = validate_audio(df["Sample"],"data/samples")
128
  df["valid"] = is_valid
129
  df.to_csv("data/songs_validated.csv")
130
-
 
8
  import torch
9
  from tqdm import tqdm
10
 
11
+
12
+ def url_to_filename(url: str) -> str:
13
  return f"{url.split('/')[-1]}.wav"
14
 
15
+
16
+ def has_valid_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
17
  audio_urls = audio_urls.replace(".", np.nan)
18
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
19
+ valid_audio_mask = audio_urls.apply(
20
+ lambda url: url is not np.nan and url_to_filename(url) in audio_files
21
+ )
22
  return valid_audio_mask
23
 
24
+
25
+ def validate_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
26
  """
27
+ Tests audio urls to ensure that their file exists and the contents is valid.
28
  """
29
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
30
+
31
  def is_valid(url):
32
+ valid_url = type(url) == str and "http" in url
33
  if not valid_url:
34
  return False
35
  filename = url_to_filename(url)
 
39
  w, _ = torchaudio.load(os.path.join(audio_dir, filename))
40
  except:
41
  return False
42
+ contents_invalid = (
43
+ torch.any(torch.isnan(w))
44
+ or torch.any(torch.isinf(w))
45
+ or len(torch.unique(w)) <= 2
46
+ )
47
  return not contents_invalid
48
+
49
  idxs = []
50
  validations = []
51
+ for index, url in tqdm(
52
+ audio_urls.items(), total=len(audio_urls), desc="Audio URLs Validated"
53
+ ):
54
  idxs.append(index)
55
  validations.append(is_valid(url))
56
 
57
  return pd.Series(validations, index=idxs)
 
 
58
 
59
+
60
+ def fix_dance_rating_counts(dance_ratings: pd.Series) -> pd.Series:
61
  tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
62
+ dance_ratings = dance_ratings.apply(lambda v: json.loads(v.replace("'", '"')))
63
+
64
+ def fix_labels(labels: dict) -> dict | float:
65
  new_labels = {}
66
  for k, v in labels.items():
67
  match = tag_pattern.search(k)
 
69
  new_labels[k] = new_labels.get(k, 0) + v
70
  else:
71
  k = match[1]
72
+ sign = 1 if match[2] == "+" else -1
73
  scale = int(match[3])
74
  new_labels[k] = new_labels.get(k, 0) + v * scale * sign
75
  valid = any(v > 0 for v in new_labels.values())
76
  return new_labels if valid else np.nan
77
+
78
  return dance_ratings.apply(fix_labels)
79
 
80
 
81
+ def get_unique_labels(dance_labels: pd.Series) -> list:
82
  labels = set()
83
  for dances in dance_labels:
84
  labels |= set(dances)
85
  return sorted(labels)
86
 
87
+
88
+ def vectorize_label_probs(
89
+ labels: dict[str, int], unique_labels: np.ndarray
90
+ ) -> np.ndarray:
91
  """
92
  Turns label dict into probability distribution vector based on each label count.
93
  """
 
96
  item_vec = (unique_labels == k) * v
97
  label_vec += item_vec
98
  lv_cache = label_vec.copy()
99
+ label_vec[label_vec < 0] = 0
100
  label_vec /= label_vec.sum()
101
  assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
102
  return label_vec
103
 
104
+
105
+ def vectorize_multi_label(
106
+ labels: dict[str, int], unique_labels: np.ndarray
107
+ ) -> np.ndarray:
108
  """
109
  Turns label dict into binary label vectors for multi-label classification.
110
  """
111
+ probs = vectorize_label_probs(labels, unique_labels)
112
  probs[probs > 0.0] = 1.0
113
  return probs
114
 
115
+
116
+ def get_examples(
117
+ df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
118
+ ) -> tuple[np.ndarray, np.ndarray]:
119
+ sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
120
  sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
121
  if class_list is not None:
122
  class_list = set(class_list)
123
  sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
124
+ lambda labels: {k: v for k, v in labels.items() if k in class_list}
125
+ if not pd.isna(labels)
126
+ and any(label in class_list and amt > 0 for label, amt in labels.items())
127
+ else np.nan
128
+ )
129
  sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
130
+ vote_mask = sampled_songs["DanceRating"].apply(
131
+ lambda dances: any(votes >= min_votes for votes in dances.values())
132
+ )
133
  sampled_songs = sampled_songs[vote_mask]
134
+ labels = sampled_songs["DanceRating"].apply(
135
+ lambda dances: {
136
+ dance: votes for dance, votes in dances.items() if votes >= min_votes
137
+ }
138
+ )
139
  unique_labels = np.array(get_unique_labels(labels))
140
  vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
141
+ labels = labels.apply(lambda i: vectorizer(i, unique_labels))
142
 
143
+ audio_paths = [
144
+ os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
145
+ ]
146
 
147
  return np.array(audio_paths), np.stack(labels)
148
 
 
151
  links = pd.read_csv("data/backup_2.csv", index_col="index")
152
  df = pd.read_csv("data/songs.csv")
153
  l = links["link"].str.strip()
154
+ l = l.apply(lambda url: url if "http" in url else np.nan)
155
  l = l.dropna()
156
  df["Sample"].update(l)
157
+ addna = lambda url: url if type(url) == str and "http" in url else np.nan
158
  df["Sample"] = df["Sample"].apply(addna)
159
+ is_valid = validate_audio(df["Sample"], "data/samples")
160
  df["valid"] = is_valid
161
  df.to_csv("data/songs_validated.csv")