Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,8 +12,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
12 |
# Load the processor
|
13 |
processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
|
14 |
|
15 |
-
#
|
16 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
model.to(device)
|
18 |
model.eval()
|
19 |
|
@@ -75,4 +81,4 @@ if uploaded_file is not None:
|
|
75 |
top5_predictions = predict(file_path)
|
76 |
st.success("Top 5 Predicted Bird Species with Probabilities:")
|
77 |
for label, prob in top5_predictions:
|
78 |
-
st.write(f"{label}: {prob:.4f}")
|
|
|
12 |
# Load the processor
|
13 |
processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
|
14 |
|
15 |
+
# Initialize the model
|
16 |
+
model = Wav2Vec2ForSequenceClassification.from_pretrained(
|
17 |
+
"facebook/wav2vec2-base-960h",
|
18 |
+
num_labels=len(pd.read_csv("dataset/train_wav.csv")["Common Name"].unique())
|
19 |
+
)
|
20 |
+
|
21 |
+
# Load the model's state dictionary
|
22 |
+
model.load_state_dict(torch.load("./fine_tuned_model/model_state_dict.pt", map_location=device))
|
23 |
model.to(device)
|
24 |
model.eval()
|
25 |
|
|
|
81 |
top5_predictions = predict(file_path)
|
82 |
st.success("Top 5 Predicted Bird Species with Probabilities:")
|
83 |
for label, prob in top5_predictions:
|
84 |
+
st.write(f"{label}: {prob:.4f}")
|