Utsaha committed on
Commit
28432e9
1 Parent(s): 37fd1e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import pandas as pd
3
  import torch
4
  import torchaudio
5
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
6
  from sklearn.preprocessing import LabelEncoder
7
  import numpy as np
8
 
@@ -10,7 +10,7 @@ import numpy as np
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Load the fine-tuned model and processor
13
- model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model").to(device)
14
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
15
 
16
  # Load the label encoder
@@ -20,6 +20,15 @@ label_encoder.fit(pd.read_csv("dataset/train_wav.csv")["Common Name"])
20
  # Fixed audio length (e.g., 10 seconds)
21
  fixed_length = 10 * 16000 # 10 seconds * 16000 Hz
22
 
 
 
 
 
 
 
 
 
 
23
  # Prediction function
24
  def predict(file_path):
25
  waveform, sample_rate = torchaudio.load(file_path)
@@ -40,12 +49,7 @@ def predict(file_path):
40
  with torch.no_grad():
41
  logits = model(inputs.input_values).logits
42
 
43
- probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
44
- top5_idx = np.argsort(probabilities)[-5:][::-1]
45
- top5_probs = probabilities[top5_idx]
46
- top5_labels = label_encoder.inverse_transform(top5_idx)
47
-
48
- return list(zip(top5_labels, top5_probs))
49
 
50
  # Streamlit interface
51
  st.title("Bird Sound Classification")
@@ -66,4 +70,4 @@ if uploaded_file is not None:
66
  top5_predictions = predict(file_path)
67
  st.success("Top 5 Predicted Bird Species with Probabilities:")
68
  for label, prob in top5_predictions:
69
- st.write(f"{label}: {prob:.4f}")
 
2
  import pandas as pd
3
  import torch
4
  import torchaudio
5
# NOTE: `safetensors` is not a name exported by `transformers`; importing it
# from there fails at start-up. Safetensors loading is requested through
# `from_pretrained`, so no extra import is needed.
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
6
  from sklearn.preprocessing import LabelEncoder
7
  import numpy as np
8
 
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Load the fine-tuned model and processor
13
# `use_safetensors` is the keyword `from_pretrained` recognises for loading
# `.safetensors` weights; `from_safetensors` is not a `from_pretrained` option.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "./fine_tuned_model", use_safetensors=True
).to(device)
14
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
15
 
16
  # Load the label encoder
 
20
  # Fixed audio length (e.g., 10 seconds)
21
  fixed_length = 10 * 16000 # 10 seconds * 16000 Hz
22
 
23
+ # Function to get top 5 predictions with probabilities
24
def get_top_5_predictions(logits, label_encoder):
    """Return the five most probable labels for the first row of ``logits``.

    Args:
        logits: Tensor of raw class scores, indexed as (row, class).
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to labels.

    Returns:
        A list of ``(label, probability)`` pairs ordered from most to least
        probable. Only the first row of ``logits`` is reported.
    """
    # detach() lets the NumPy conversion succeed even when the caller did not
    # run inference under torch.no_grad(); numpy() rejects grad-tracking tensors.
    probabilities = (
        torch.nn.functional.softmax(logits.detach(), dim=-1).cpu().numpy()
    )
    # argsort is ascending: keep the last five columns, then reverse them so
    # the most probable class comes first.
    top5_idx = np.argsort(probabilities, axis=-1)[:, -5:][:, ::-1]
    top5_probs = np.take_along_axis(probabilities, top5_idx, axis=-1)
    top5_labels = label_encoder.inverse_transform(top5_idx[0])

    return list(zip(top5_labels, top5_probs[0]))
31
+
32
  # Prediction function
33
  def predict(file_path):
34
  waveform, sample_rate = torchaudio.load(file_path)
 
49
  with torch.no_grad():
50
  logits = model(inputs.input_values).logits
51
 
52
+ return get_top_5_predictions(logits, label_encoder)
 
 
 
 
 
53
 
54
  # Streamlit interface
55
  st.title("Bird Sound Classification")
 
70
  top5_predictions = predict(file_path)
71
  st.success("Top 5 Predicted Bird Species with Probabilities:")
72
  for label, prob in top5_predictions:
73
+ st.write(f"{label}: {prob:.4f}")