spycoder commited on
Commit
5914cfd
1 Parent(s): 2cbb9da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -16,14 +16,13 @@ from collections import Counter
16
  device = torch.device("cpu")
17
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
19
- # model_path = "dysarthria_classifier12.pth"
20
  # model_path = 'model_weights2.pth'
21
- model_path = '/home/user/app/dysarthria_classifier10.pth'
22
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
23
 
24
- # if os.path.exists(model_path):
25
- # print(f"Loading saved model {model_path}")
26
- # model.load_state_dict(torch.load(model_path))
27
 
28
 
29
  title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
@@ -65,13 +64,13 @@ def predict(file_path):
65
  predicted_class_id = torch.argmax(logits, dim=-1).item()
66
 
67
  return predicted_class_id
68
- # gr.Interface(
69
- # fn=predict,
70
- # inputs="file",
71
- # outputs="text",
72
- # title=title,
73
- # description=description,
74
- # ).launch()
75
 
76
- iface = gr.Interface(fn=predict, inputs="file", outputs="text")
77
- iface.launch()
 
16
  device = torch.device("cpu")
17
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
19
+ model_path = "dysarthria_classifier12.pth"
20
  # model_path = 'model_weights2.pth'
21
+ # model_path = '/home/user/app/dysarthria_classifier10.pth'
 
22
 
23
+ if os.path.exists(model_path):
24
+ print(f"Loading saved model {model_path}")
25
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
26
 
27
 
28
  title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
 
64
  predicted_class_id = torch.argmax(logits, dim=-1).item()
65
 
66
  return predicted_class_id
67
+ gr.Interface(
68
+ fn=predict,
69
+ inputs="file",
70
+ outputs="text",
71
+ title=title,
72
+ description=description,
73
+ ).launch()
74
 
75
+ # iface = gr.Interface(fn=predict, inputs="file", outputs="text")
76
+ # iface.launch()