guarin commited on
Commit
e6eaebf
1 Parent(s): ded9d1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
 
4
 
5
  # import spaces
6
 
@@ -23,7 +24,8 @@ choice_mapping = {
23
  # @spaces.GPU
24
  def predict(model_choice: str, annotations, image):
25
  config_file, ckpt_path = choice_mapping[str(model_choice)]
26
- sam2_model = build_sam2(config_file, ckpt_path, device="cuda")
 
27
  predictor = SAM2ImagePredictor(sam2_model)
28
  predictor.set_image(image)
29
  coordinates = np.array(
 
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
4
+ import torch
5
 
6
  # import spaces
7
 
 
24
  # @spaces.GPU
25
  def predict(model_choice: str, annotations, image):
26
  config_file, ckpt_path = choice_mapping[str(model_choice)]
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ sam2_model = build_sam2(config_file, ckpt_path, device=device)
29
  predictor = SAM2ImagePredictor(sam2_model)
30
  predictor.set_image(image)
31
  coordinates = np.array(