spycoder commited on
Commit
1c411ce
1 Parent(s): c4ab564

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -12,13 +12,14 @@ import torch.nn.functional as F
12
  from torch.utils.data import Dataset, DataLoader
13
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
14
  from collections import Counter
 
 
15
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
16
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
17
  model_path = "dysarthria_classifier12.pth"
18
  if os.path.exists(model_path):
19
  print(f"Loading saved model {model_path}")
20
  model.load_state_dict(torch.load(model_path))
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  def predict(file_path):
23
  max_length = 100000
24
 
 
12
  from torch.utils.data import Dataset, DataLoader
13
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
14
  from collections import Counter
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
19
  model_path = "dysarthria_classifier12.pth"
20
  if os.path.exists(model_path):
21
  print(f"Loading saved model {model_path}")
22
  model.load_state_dict(torch.load(model_path))
 
23
  def predict(file_path):
24
  max_length = 100000
25