spycoder commited on
Commit
223eb95
1 Parent(s): 1c411ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -17,6 +17,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
19
  model_path = "dysarthria_classifier12.pth"
 
 
 
20
  if os.path.exists(model_path):
21
  print(f"Loading saved model {model_path}")
22
  model.load_state_dict(torch.load(model_path))
 
17
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
19
  model_path = "dysarthria_classifier12.pth"
20
+ model_path = '/home/user/app/dysarthria_classifier12.pth'
21
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
22
+
23
  if os.path.exists(model_path):
24
  print(f"Loading saved model {model_path}")
25
  model.load_state_dict(torch.load(model_path))