# TerraTunes / app.py
# Last commit: "Update app.py" (57933ea, verified) by Utsaha
# File size: 2.94 kB
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchaudio
from sklearn.preprocessing import LabelEncoder
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the processor, the fine-tuned model and the label encoder.
@st.cache_resource
def _load_artifacts():
    """Load the processor, fine-tuned classifier and label encoder once.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the model from being reloaded from disk on
    each rerun.

    Returns:
        Tuple ``(processor, model, label_encoder)``. The model has the
        fine-tuned weights loaded, is moved to the module-level ``device``
        and is put in eval mode.
    """
    # Read the training labels once: they determine both the size of the
    # classification head and the index -> species-name mapping.
    common_names = pd.read_csv("dataset/train_wav.csv")["Common Name"]

    proc = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")

    # Build the base architecture with a head matching the label count, then
    # overwrite its weights with the fine-tuned state dict.
    mdl = Wav2Vec2ForSequenceClassification.from_pretrained(
        "facebook/wav2vec2-base-960h",
        num_labels=len(common_names.unique()),
    )
    mdl.load_state_dict(torch.load("./fine_tuned_model/model.pt", map_location=device))
    mdl.to(device)
    mdl.eval()

    encoder = LabelEncoder()
    encoder.fit(common_names)
    return proc, mdl, encoder


processor, model, label_encoder = _load_artifacts()
# Every clip is padded or truncated to this many samples:
# 10 seconds at 16000 Hz.
fixed_length = 16000 * 10
# Function to get top 5 predictions with probabilities
def get_top_5_predictions(logits, label_encoder):
    """Return the five most probable labels for the first row of ``logits``.

    Args:
        logits: Array-like of raw class scores. Softmax is taken over the
            last axis and only row 0 is reported (presumably shape
            ``(1, num_labels)`` -- TODO confirm for batched use).
        label_encoder: Fitted encoder whose ``inverse_transform`` maps class
            indices back to label names.

    Returns:
        List of ``(label, probability)`` tuples, most probable first (fewer
        than five if there are fewer classes).
    """
    probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy()

    # Sort ascending, flip to descending, then keep the first five columns.
    top_indices = np.argsort(probs, axis=-1)[:, ::-1][:, :5]
    top_probs = np.take_along_axis(probs, top_indices, axis=-1)

    first_row = top_indices[0]
    labels = label_encoder.inverse_transform(first_row)
    return list(zip(labels, top_probs[0]))
# Prediction function
def _prepare_waveform(waveform, sample_rate, target_rate, target_length):
    """Down-mix to mono, resample to ``target_rate`` and pad/trim the audio.

    Args:
        waveform: Tensor shaped ``(channels, samples)``, as returned by
            ``torchaudio.load``.
        sample_rate: Sample rate of ``waveform`` in Hz.
        target_rate: Sample rate the model expects, in Hz.
        target_length: Exact number of output samples.

    Returns:
        Tensor shaped ``(1, target_length)`` sampled at ``target_rate``.
    """
    # Average multi-channel audio into a single channel.
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # The processor is told the audio is at target_rate, so it must really be
    # resampled; otherwise e.g. 44.1 kHz input would be misinterpreted and the
    # fixed-length window would cover the wrong duration.
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )
    # Truncate, or right-pad with zeros, to exactly target_length samples.
    if waveform.size(1) > target_length:
        return waveform[:, :target_length]
    return torch.nn.functional.pad(waveform, (0, target_length - waveform.size(1)))


def predict(file_path):
    """Classify the audio file at ``file_path``.

    The audio is converted to mono 16 kHz, fixed to ``fixed_length`` samples,
    run through the processor and model, and scored.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        List of ``(label, probability)`` pairs from ``get_top_5_predictions``,
        most probable first.
    """
    target_rate = 16000
    waveform, sample_rate = torchaudio.load(file_path)
    waveform = _prepare_waveform(waveform, sample_rate, target_rate, fixed_length)
    inputs = processor(
        waveform.squeeze(0), sampling_rate=target_rate, return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    return get_top_5_predictions(logits.cpu().numpy(), label_encoder)
# Streamlit interface
st.title("Bird Sound Classification")
uploaded_file = st.file_uploader("Upload an audio file", type=["wav"])
if uploaded_file is not None:
    # Save the upload to disk so torchaudio can read it. Only the base name of
    # the client-supplied filename is kept, so a crafted name such as
    # "../../app.py" cannot write outside the temp/ directory.
    safe_name = os.path.basename(uploaded_file.name) or "upload.wav"
    file_path = os.path.join("temp", safe_name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.audio(file_path, format='audio/wav')
    if st.button("Predict"):
        with st.spinner("Classifying..."):
            top5_predictions = predict(file_path)
        st.success("Top 5 Predicted Bird Species with Probabilities:")
        for label, prob in top5_predictions:
            st.write(f"{label}: {prob:.4f}")