Utsaha commited on
Commit
57933ea
1 Parent(s): e130848

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -9,10 +9,15 @@ import numpy as np
9
  # Set device
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
- # Load the fine-tuned model and processor
13
- model = Wav2Vec2ForSequenceClassification.from_pretrained("./fine_tuned_model", use_safetensors=True).to(device)
14
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
15
 
 
 
 
 
 
 
16
  # Load the label encoder
17
  label_encoder = LabelEncoder()
18
  label_encoder.fit(pd.read_csv("dataset/train_wav.csv")["Common Name"])
 
9
  # Set device
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # Load the processor
 
13
  processor = Wav2Vec2Processor.from_pretrained("./fine_tuned_model")
14
 
15
+ # Load the model
16
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=len(pd.read_csv("dataset/train_wav.csv")["Common Name"].unique()))
17
+ model.load_state_dict(torch.load("./fine_tuned_model/model.pt", map_location=device))
18
+ model.to(device)
19
+ model.eval()
20
+
21
  # Load the label encoder
22
  label_encoder = LabelEncoder()
23
  label_encoder.fit(pd.read_csv("dataset/train_wav.csv")["Common Name"])