Utsaha commited on
Commit
f84475d
1 Parent(s): 913c27a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -12,8 +12,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  # Load the processor
13
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
14
 
15
- # Load the entire model
16
- model = torch.load("./fine_tuned_model/model.pt", map_location=device)
 
 
 
 
 
 
17
  model.to(device)
18
  model.eval()
19
 
@@ -75,4 +81,4 @@ if uploaded_file is not None:
75
  top5_predictions = predict(file_path)
76
  st.success("Top 5 Predicted Bird Species with Probabilities:")
77
  for label, prob in top5_predictions:
78
- st.write(f"{label}: {prob:.4f}")
 
12
  # Load the processor
13
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
14
 
15
+ # Initialize the model
16
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(
17
+ "facebook/wav2vec2-base-960h",
18
+ num_labels=len(pd.read_csv("dataset/train_wav.csv")["Common Name"].unique())
19
+ )
20
+
21
+ # Load the model's state dictionary
22
+ model.load_state_dict(torch.load("./fine_tuned_model/model_state_dict.pt", map_location=device))
23
  model.to(device)
24
  model.eval()
25
 
 
81
  top5_predictions = predict(file_path)
82
  st.success("Top 5 Predicted Bird Species with Probabilities:")
83
  for label, prob in top5_predictions:
84
+ st.write(f"{label}: {prob:.4f}")