waidhoferj commited on
Commit
dad3c09
1 Parent(s): 17f9fb1

map_location on torch.load

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
3
  import numpy as np
4
  import torch
5
  from preprocessing.preprocess import AudioPipeline
6
- from preprocessing.preprocess import AudioPipeline
7
  from models.residual import ResidualDancer
8
  import os
9
  import json
@@ -23,7 +22,7 @@ def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
23
  labels = np.array(sorted(config["classes"]))
24
 
25
  model = ResidualDancer(n_classes=len(labels))
26
- model.load_state_dict(torch.load(weights))
27
  model = model.to(device).eval()
28
  return model, labels
29
 
 
3
  import numpy as np
4
  import torch
5
  from preprocessing.preprocess import AudioPipeline
 
6
  from models.residual import ResidualDancer
7
  import os
8
  import json
 
22
  labels = np.array(sorted(config["classes"]))
23
 
24
  model = ResidualDancer(n_classes=len(labels))
25
+ model.load_state_dict(torch.load(weights, map_location=DEVICE))
26
  model = model.to(device).eval()
27
  return model, labels
28