TerraTunes / app.py
Utsaha's picture
Update app.py
28432e9 verified
raw
history blame
2.7 kB
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchaudio
from sklearn.preprocessing import LabelEncoder
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor, safetensors
# Set device: run on the GPU when one is visible to torch, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and processor from the local checkpoint directory.
# `use_safetensors` is the keyword `from_pretrained` documents for selecting
# safetensors weights; the previous `from_safetensors` spelling is not a
# recognised option and would be forwarded to the model constructor.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "./fine_tuned_model", use_safetensors=True
).to(device)
processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")

# Load the label encoder by re-fitting it on the training labels.
# NOTE(review): this assumes training encoded "Common Name" with a
# LabelEncoder fitted on the same column, so class indices line up with the
# model's output logits — confirm against the training script.
label_encoder = LabelEncoder()
label_encoder.fit(pd.read_csv("dataset/train_wav.csv")["Common Name"])

# Fixed audio length in samples (10 seconds at 16000 Hz).
fixed_length = 10 * 16000  # 10 seconds * 16000 Hz
# Function to get top 5 predictions with probabilities
def get_top_5_predictions(logits, label_encoder):
    """Return the five most probable labels for the first row of *logits*.

    Args:
        logits: Tensor of raw class scores; softmax is applied over the last
            axis. Only the first row (index 0) is reported.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to label names.

    Returns:
        A list of ``(label, probability)`` pairs ordered from most to least
        probable. Contains fewer than five pairs when there are fewer than
        five classes.
    """
    class_probs = torch.softmax(logits, dim=-1).cpu().numpy()

    # argsort is ascending, so flip it and keep the leading five columns.
    ranked_idx = np.argsort(class_probs, axis=-1)[:, ::-1][:, :5]
    ranked_probs = np.take_along_axis(class_probs, ranked_idx, axis=-1)

    first_idx = ranked_idx[0]
    first_probs = ranked_probs[0]
    names = label_encoder.inverse_transform(first_idx)
    return [(name, prob) for name, prob in zip(names, first_probs)]
# Prediction function
def predict(file_path):
    """Classify one audio file and return its five most probable labels.

    The clip is loaded with torchaudio, down-mixed to mono, resampled to
    16 kHz when it was recorded at a different rate, then cropped or
    zero-padded to exactly ``fixed_length`` samples before being run through
    the module-level ``processor`` and ``model``.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        The value of ``get_top_5_predictions`` for the model's logits: a list
        of ``(label, probability)`` pairs, most probable first.
    """
    # Rate declared to the processor below; fixed_length is counted in
    # samples at this rate.
    target_rate = 16000

    waveform, sample_rate = torchaudio.load(file_path)

    # Ensure the audio is mono by averaging the channels.
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Bring the clip to the rate the processor is told it has; without this a
    # file at any other rate would be cropped to the wrong duration and fed
    # to the model at the wrong speed.
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    # Ensure the audio is exactly fixed_length samples: crop long clips,
    # right-pad short ones with zeros.
    num_samples = waveform.size(1)
    if num_samples > fixed_length:
        waveform = waveform[:, :fixed_length]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, fixed_length - num_samples))

    inputs = processor(
        waveform.squeeze(0),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    return get_top_5_predictions(logits, label_encoder)
# Streamlit interface: upload a WAV file, play it back, and show the top-5
# predicted labels when the user presses "Predict".
st.title("Bird Sound Classification")
uploaded_file = st.file_uploader("Upload an audio file", type=["wav"])
if uploaded_file is not None:
    # Save the uploaded file temporarily so it can be read back by path.
    # The upload's name comes from the client, so keep only its final path
    # component to make sure the file is written inside temp/.
    file_path = f"temp/{os.path.basename(uploaded_file.name)}"
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    st.audio(file_path, format='audio/wav')

    if st.button("Predict"):
        with st.spinner("Classifying..."):
            top5_predictions = predict(file_path)
        st.success("Top 5 Predicted Bird Species with Probabilities:")
        for label, prob in top5_predictions:
            st.write(f"{label}: {prob:.4f}")