Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
-
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
|
6 |
from sklearn.preprocessing import LabelEncoder
|
7 |
import numpy as np
|
8 |
|
@@ -10,7 +10,7 @@ import numpy as np
|
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
# Load the fine-tuned model and processor
|
13 |
-
model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model").to(device)
|
14 |
processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
|
15 |
|
16 |
# Load the label encoder
|
@@ -20,6 +20,15 @@ label_encoder.fit(pd.read_csv("dataset/train_wav.csv")["Common Name"])
|
|
20 |
# Fixed audio length (e.g., 10 seconds)
|
21 |
fixed_length = 10 * 16000 # 10 seconds * 16000 Hz
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# Prediction function
|
24 |
def predict(file_path):
|
25 |
waveform, sample_rate = torchaudio.load(file_path)
|
@@ -40,12 +49,7 @@ def predict(file_path):
|
|
40 |
with torch.no_grad():
|
41 |
logits = model(inputs.input_values).logits
|
42 |
|
43 |
-
|
44 |
-
top5_idx = np.argsort(probabilities)[-5:][::-1]
|
45 |
-
top5_probs = probabilities[top5_idx]
|
46 |
-
top5_labels = label_encoder.inverse_transform(top5_idx)
|
47 |
-
|
48 |
-
return list(zip(top5_labels, top5_probs))
|
49 |
|
50 |
# Streamlit interface
|
51 |
st.title("Bird Sound Classification")
|
@@ -66,4 +70,4 @@ if uploaded_file is not None:
|
|
66 |
top5_predictions = predict(file_path)
|
67 |
st.success("Top 5 Predicted Bird Species with Probabilities:")
|
68 |
for label, prob in top5_predictions:
|
69 |
-
st.write(f"{label}: {prob:.4f}")
|
|
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
+
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor, safetensors
|
6 |
from sklearn.preprocessing import LabelEncoder
|
7 |
import numpy as np
|
8 |
|
|
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
# Load the fine-tuned model and processor
|
13 |
+
# `use_safetensors` is the from_pretrained keyword that selects .safetensors weights;
# an unrecognised keyword would be forwarded to the model constructor and fail.
model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model", use_safetensors=True).to(device)
|
14 |
processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
|
15 |
|
16 |
# Load the label encoder
|
|
|
20 |
# Fixed audio length (e.g., 10 seconds)
|
21 |
fixed_length = 10 * 16000 # 10 seconds * 16000 Hz
|
22 |
|
23 |
+
# Function to get top 5 predictions with probabilities
|
24 |
+
def get_top_5_predictions(logits, label_encoder):
    """Return the five most probable labels for the first row of ``logits``.

    Args:
        logits: Torch tensor of raw class scores. It is indexed as
            ``[:, -5:]`` after softmax, so it must have at least two
            dimensions, with classes on the last axis.
        label_encoder: Object exposing ``inverse_transform(indices)`` that
            maps class indices back to label names.

    Returns:
        A list of ``(label, probability)`` pairs ordered from most to least
        probable. Only the first row of the batch is reported.
    """
    # Softmax over the class axis, then move to host memory as a NumPy array.
    probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    # argsort is ascending: keep the last five columns and flip them so the
    # most probable class comes first.
    ascending = np.argsort(probs, axis=-1)
    best_idx = ascending[:, -5:][:, ::-1]
    best_probs = np.take_along_axis(probs, best_idx, axis=-1)

    # Report the first batch row only.
    names = label_encoder.inverse_transform(best_idx[0])
    return [(name, prob) for name, prob in zip(names, best_probs[0])]
|
31 |
+
|
32 |
# Prediction function
|
33 |
def predict(file_path):
|
34 |
waveform, sample_rate = torchaudio.load(file_path)
|
|
|
49 |
with torch.no_grad():
|
50 |
logits = model(inputs.input_values).logits
|
51 |
|
52 |
+
return get_top_5_predictions(logits, label_encoder)
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# Streamlit interface
|
55 |
st.title("Bird Sound Classification")
|
|
|
70 |
top5_predictions = predict(file_path)
|
71 |
st.success("Top 5 Predicted Bird Species with Probabilities:")
|
72 |
for label, prob in top5_predictions:
|
73 |
+
st.write(f"{label}: {prob:.4f}")
|