Utsaha committed on
Commit
bbbd935
1 Parent(s): 55d6996

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchaudio
from sklearn.preprocessing import LabelEncoder
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
8
+
9
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Streamlit re-executes this script on every widget interaction, so the
# expensive loads below are wrapped in st.cache_resource to run once per
# server process instead of once per click.
@st.cache_resource
def _load_model_and_processor(model_dir="./fine_tuned_model"):
    """Load the fine-tuned classifier and its processor from *model_dir*.

    Returns:
        A ``(model, processor)`` tuple; the model is moved to ``device``.
    """
    loaded_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_dir).to(device)
    loaded_processor = Wav2Vec2Processor.from_pretrained(model_dir)
    return loaded_model, loaded_processor


@st.cache_resource
def _load_label_encoder(csv_path="dataset/train_wav.csv"):
    """Fit a LabelEncoder on the "Common Name" column of *csv_path*.

    NOTE(review): this presumably reproduces the class-index mapping used
    during fine-tuning; confirm the training script fitted the encoder on
    the same column of the same file.
    """
    encoder = LabelEncoder()
    encoder.fit(pd.read_csv(csv_path)["Common Name"])
    return encoder


# Load the fine-tuned model and processor
model, processor = _load_model_and_processor()

# Load the label encoder
label_encoder = _load_label_encoder()

# Fixed audio length (e.g., 10 seconds)
fixed_length = 10 * 16000  # 10 seconds * 16000 Hz
22
+
23
+ # Prediction function
24
def predict(file_path):
    """Classify an audio file and return the five most probable labels.

    The audio is loaded with torchaudio, down-mixed to mono, resampled to
    16 kHz, cropped or zero-padded to ``fixed_length`` samples, and passed
    through the module-level ``processor`` and ``model``.

    Args:
        file_path: Location of the audio file, as accepted by
            ``torchaudio.load``.

    Returns:
        A list of ``(label, probability)`` tuples ordered from most to
        least probable. It holds at most five entries (fewer if the model
        has fewer than five classes).
    """
    target_rate = 16000  # rate passed to the processor and assumed by fixed_length

    waveform, sample_rate = torchaudio.load(file_path)

    # Ensure the audio is mono
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample so that the sample count below really corresponds to seconds
    # at 16 kHz; without this, files recorded at another rate are cropped
    # to the wrong duration and mislabelled as 16 kHz to the processor.
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )

    # Ensure the audio is exactly fixed_length samples long
    num_samples = waveform.size(1)
    if num_samples > fixed_length:
        waveform = waveform[:, :fixed_length]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, fixed_length - num_samples))

    inputs = processor(
        waveform.squeeze(0),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = model(inputs.input_values).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
    # argsort is ascending: take the last five indices and reverse them.
    top5_idx = np.argsort(probabilities)[-5:][::-1]
    top5_probs = probabilities[top5_idx]
    top5_labels = label_encoder.inverse_transform(top5_idx)

    return list(zip(top5_labels, top5_probs))
49
+
50
# Streamlit interface: upload a WAV file, play it back, and on request show
# the five most probable labels returned by predict().
st.title("Bird Sound Classification")

uploaded_file = st.file_uploader("Upload an audio file", type=["wav"])

if uploaded_file is not None:
    # Save the uploaded file temporarily. The name comes from the client, so
    # keep only its final path component to make sure the write stays
    # inside temp/ even if the name contains directory parts.
    safe_name = os.path.basename(uploaded_file.name)
    file_path = f"temp/{safe_name}"
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    st.audio(file_path, format='audio/wav')

    if st.button("Predict"):
        with st.spinner("Classifying..."):
            top5_predictions = predict(file_path)
        st.success("Top 5 Predicted Bird Species with Probabilities:")
        for label, prob in top5_predictions:
            st.write(f"{label}: {prob:.4f}")