spycoder committed
Commit 9beef86
1 Parent(s): 547b1d3

Update app.py

Files changed (1): app.py (+16 -39)
app.py CHANGED
@@ -16,18 +16,18 @@ from collections import Counter
 device = torch.device("cpu")
 processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
 model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
-# model_path = "dysarthria_classifier12.pth"
-# model_path = 'model_weights2.pth'
-model_path = '/home/user/app/dysarthria_classifier12.pth'
+model_path = "dysarthria_classifier12.pth"
+# model_path = '/home/user/app/dysarthria_classifier12.pth'
+model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

-if os.path.exists(model_path):
-    print(f"Loading saved model {model_path}")
-    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+# if os.path.exists(model_path):
+#     print(f"Loading saved model {model_path}")
+#     model.load_state_dict(torch.load(model_path))


-title = "Upload an mp3 file for supranuclear palsy (SP) detection! (Thai Language)"
+title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
 description = """
-The model was trained on Thai audio recordings with the following sentences, so submit audio recordings for one of these sentences:\n
+The model was trained on Thai audio recordings with the following sentences: \n
 ชาวไร่ตัดต้นสนทำท่อนซุง\n
 ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
 อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
@@ -39,7 +39,13 @@ The model was trained on Thai audio recordings with the following sentences, so
 <img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
 """

-def actualpredict(file_path):
+
+
+
+
+def predict(file_path):
+    max_length = 100000
+
     model.eval()
     with torch.no_grad():
         wav_data, _ = sf.read(file_path.name)
@@ -56,44 +62,15 @@ def actualpredict(file_path):
         logits = model(**inputs).logits
         logits = logits.squeeze()
         predicted_class_id = torch.argmax(logits, dim=-1).item()
-        return predicted_class_id
-
-
-def predict(file_upload):
-
-    max_length = 100000
-    warn_output = " "
-    ans = " "
-    # file_path = file_upload
-    # if (microphone is not None) and (file_upload is not None):
-    #     warn_output = (
-    #         "WARNING: You've uploaded an audio file and used the microphone. "
-    #         "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
-    #     )
-
-    # elif (microphone is None) and (file_upload is None):
-    #     return "ERROR: You have to either use the microphone or upload an audio file"
-    # if(microphone is not None):
-    #     file_path = microphone
-    # if(file_upload is not None):
-    #     file_path = file_upload

-    predicted_class_id = actualpredict(file_upload)
-    if(predicted_class_id==0):
-        ans = "no_parkinson"
-    else:
-        ans = "parkinson"
     return predicted_class_id
 gr.Interface(
     fn=predict,
-    inputs=[
-        gr.inputs.Audio(source="upload", type="filepath", optional=True),
-    ],
+    inputs="file",
     outputs="text",
     title=title,
     description=description,
 ).launch()

-# gr.inputs.Audio(source="microphone", type="filepath", optional=True),
 # iface = gr.Interface(fn=predict, inputs="file", outputs="text")
 # iface.launch()
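Net effect of the commit: the checkpoint is now loaded unconditionally from a repo-relative path (a missing dysarthria_classifier12.pth raises at startup instead of silently serving the base wav2vec2 weights), the actualpredict/predict split is collapsed into a single predict, the no_parkinson/parkinson label mapping is dropped so the app returns the raw class id as text, and the deprecated gr.inputs.Audio upload component is replaced by the "file" shorthand. A minimal sketch of the resulting inference path follows; the processor call in the middle of predict falls between the two hunks and is not shown in this diff, so its truncation and arguments here are an assumption, not the committed code:

import torch
import soundfile as sf
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

device = torch.device("cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)
# Unconditional load, as committed: a missing checkpoint now fails fast at startup.
model.load_state_dict(torch.load("dysarthria_classifier12.pth", map_location=device))

def predict(file_path):
    max_length = 100000  # samples kept per clip (value taken from the diff)
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path.name)
        # Assumed feature extraction: the committed lines between the two hunks
        # are elided, so the clip truncation and processor arguments are a guess.
        inputs = processor(wav_data[:max_length], sampling_rate=16000,
                           return_tensors="pt", padding=True)
        logits = model(**inputs.to(device)).logits
        logits = logits.squeeze()
        return torch.argmax(logits, dim=-1).item()  # raw class id, 0 or 1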
 
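On the input change: with inputs="file", Gradio passes predict a tempfile-like wrapper rather than a plain path, which is why the code reads file_path.name. A rough explicit equivalent, assuming the Gradio 3.x File component that the "file" shorthand resolves to:

import gradio as gr

# Hypothetical explicit form of inputs="file" (the commit itself uses the shorthand):
gr.Interface(
    fn=predict,
    inputs=gr.File(),  # fn receives a tempfile-like object exposing .name
    outputs="text",
    title=title,
    description=description,
).launch()

The removed gr.inputs.Audio(source="upload", type="filepath", optional=True) was the older 2.x-style spelling; with type="filepath" the function would have received a plain path string, on which the .name access would fail, which may be why this commit switches to the file input.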