yansong1616 commited on
Commit
c015010
1 Parent(s): 7da8c29

Update SAM2/sam2/sam2_video_predictor.py

Browse files
Files changed (1) hide show
  1. SAM2/sam2/sam2_video_predictor.py +2 -2
SAM2/sam2/sam2_video_predictor.py CHANGED
@@ -65,11 +65,11 @@ class SAM2VideoPredictor(SAM2Base):
65
  # the original video height and width, used for resizing final output scores
66
  inference_state["video_height"] = video_height
67
  inference_state["video_width"] = video_width
68
- inference_state["device"] = torch.device("cuda")
69
  if offload_state_to_cpu:
70
  inference_state["storage_device"] = torch.device("cpu")
71
  else:
72
- inference_state["storage_device"] = torch.device("cuda")
73
  # inputs on each frame
74
  inference_state["point_inputs_per_obj"] = {}
75
  inference_state["mask_inputs_per_obj"] = {}
 
65
  # the original video height and width, used for resizing final output scores
66
  inference_state["video_height"] = video_height
67
  inference_state["video_width"] = video_width
68
+ inference_state["device"] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
  if offload_state_to_cpu:
70
  inference_state["storage_device"] = torch.device("cpu")
71
  else:
72
+ inference_state["storage_device"] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73
  # inputs on each frame
74
  inference_state["point_inputs_per_obj"] = {}
75
  inference_state["mask_inputs_per_obj"] = {}