Utsaha committed on
Commit
0b620d9
1 Parent(s): 28432e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import pandas as pd
3
  import torch
4
  import torchaudio
5
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor, safetensors
6
  from sklearn.preprocessing import LabelEncoder
7
  import numpy as np
8
 
@@ -10,7 +10,7 @@ import numpy as np
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Load the fine-tuned model and processor
13
- model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model", from_safetensors=True).to(device)
14
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
15
 
16
  # Load the label encoder
@@ -22,6 +22,7 @@ fixed_length = 10 * 16000 # 10 seconds * 16000 Hz
22
 
23
  # Function to get top 5 predictions with probabilities
24
  def get_top_5_predictions(logits, label_encoder):
 
25
  probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
26
  top5_idx = np.argsort(probabilities, axis=-1)[:, -5:][:, ::-1] # Top 5 indices
27
  top5_probs = np.take_along_axis(probabilities, top5_idx, axis=-1)
@@ -49,7 +50,7 @@ def predict(file_path):
49
  with torch.no_grad():
50
  logits = model(inputs.input_values).logits
51
 
52
- return get_top_5_predictions(logits, label_encoder)
53
 
54
  # Streamlit interface
55
  st.title("Bird Sound Classification")
 
2
  import pandas as pd
3
  import torch
4
  import torchaudio
5
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
6
  from sklearn.preprocessing import LabelEncoder
7
  import numpy as np
8
 
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Load the fine-tuned model and processor
13
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model", use_safetensors=True).to(device)
14
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
15
 
16
  # Load the label encoder
 
22
 
23
  # Function to get top 5 predictions with probabilities
24
  def get_top_5_predictions(logits, label_encoder):
25
+ logits = torch.tensor(logits) # Convert numpy array to PyTorch tensor
26
  probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
27
  top5_idx = np.argsort(probabilities, axis=-1)[:, -5:][:, ::-1] # Top 5 indices
28
  top5_probs = np.take_along_axis(probabilities, top5_idx, axis=-1)
 
50
  with torch.no_grad():
51
  logits = model(inputs.input_values).logits
52
 
53
+ return get_top_5_predictions(logits.cpu().numpy(), label_encoder)
54
 
55
  # Streamlit interface
56
  st.title("Bird Sound Classification")