# TerraTunes / app.py
# Source page metadata (copied from the Hugging Face file view):
#   author: Utsaha
#   commit: "Create app.py" (bbbd935, verified)
#   size: 2.44 kB
import os
import tempfile

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchaudio
from sklearn.preprocessing import LabelEncoder
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_resources(model_dir="./fine_tuned_model", labels_csv="dataset/train_wav.csv"):
    """Load the fine-tuned classifier, its processor, and the label encoder.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. ``st.cache_resource`` keeps the returned objects alive
    across reruns, so the model weights and the training CSV are loaded once
    instead of on every click.

    Args:
        model_dir: Directory holding the fine-tuned model and processor.
        labels_csv: CSV whose "Common Name" column defines the label set.

    Returns:
        A ``(model, processor, label_encoder)`` tuple.
    """
    clf = Wav2Vec2ForSequenceClassification.from_pretrained(model_dir).to(device)
    clf.eval()  # inference only: disable dropout
    proc = Wav2Vec2Processor.from_pretrained(model_dir)
    # NOTE(review): assumes the model's class indices were assigned by a
    # LabelEncoder fit on this same column (sorted unique names) — confirm.
    encoder = LabelEncoder()
    encoder.fit(pd.read_csv(labels_csv)["Common Name"])
    return clf, proc, encoder


# Load the fine-tuned model, processor and label encoder (cached across reruns)
model, processor, label_encoder = _load_resources()

# Fixed audio length (e.g., 10 seconds)
fixed_length = 10 * 16000  # 10 seconds * 16000 Hz
# Prediction function
def predict(file_path, top_k=5):
    """Classify a bird-sound recording and return the ``top_k`` best labels.

    The audio is down-mixed to mono, resampled to 16 kHz (the rate that
    both ``fixed_length`` and the processor call assume), then truncated or
    zero-padded to exactly ``fixed_length`` samples before inference.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.
        top_k: Number of predictions to return (default 5). If the model
            has fewer classes, all of them are returned.

    Returns:
        A list of ``(label, probability)`` tuples ordered by descending
        probability; labels are decoded with the module-level
        ``label_encoder``.
    """
    target_sr = 16000
    waveform, sample_rate = torchaudio.load(file_path)
    # Ensure the audio is mono
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # torchaudio.load returns the file's native rate; without resampling,
    # e.g. a 44.1 kHz file would be cut at ~3.6 s and mislabelled as 16 kHz.
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sr
        )
    # Ensure the audio is exactly fixed_length samples (truncate or zero-pad)
    if waveform.size(1) > fixed_length:
        waveform = waveform[:, :fixed_length]
    else:
        padding = fixed_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=target_sr,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
    # Reverse first, then take the head: unlike `[-k:]`, this is correct for k == 0.
    top_idx = np.argsort(probabilities)[::-1][:top_k]
    top_probs = probabilities[top_idx]
    top_labels = label_encoder.inverse_transform(top_idx)
    return list(zip(top_labels, top_probs))
# Streamlit interface
st.title("Bird Sound Classification")
uploaded_file = st.file_uploader("Upload an audio file", type=["wav"])
if uploaded_file is not None:
    # Play the upload straight from memory; no file on disk is needed for this.
    st.audio(uploaded_file.getvalue(), format='audio/wav')
    if st.button("Predict"):
        with st.spinner("Classifying..."):
            # predict() needs a path, so write the upload into a private,
            # per-run temporary directory that is deleted afterwards. A fixed
            # file name avoids trusting the client-supplied name (path
            # traversal), and the per-run directory stops concurrent sessions
            # from overwriting each other and temp files from accumulating.
            with tempfile.TemporaryDirectory() as tmp_dir:
                file_path = os.path.join(tmp_dir, "upload.wav")
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                top5_predictions = predict(file_path)
        st.success("Top 5 Predicted Bird Species with Probabilities:")
        for label, prob in top5_predictions:
            st.write(f"{label}: {prob:.4f}")