Dan Bochman committed
Commit 2233e64
1 Parent(s): a76f764

move spaces.GPU to main segment scope

Files changed (1):
  app.py +9 -5
app.py CHANGED
@@ -12,6 +12,9 @@ from torchvision import transforms
 
 # ----------------- ENV ----------------- #
 
+if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 
 ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
 
@@ -107,12 +110,12 @@ model.eval()
 model.to("cuda")
 
 
-@spaces.GPU
 @torch.inference_mode()
 def run_model(input_tensor, height, width):
-    output = model(input_tensor)
-    output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
-    _, preds = torch.max(output, 1)
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        output = model(input_tensor)
+        output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
+        _, preds = torch.max(output, 1)
     return preds
 
 
@@ -126,8 +129,9 @@ transform_fn = transforms.Compose(
 # ----------------- CORE FUNCTION ----------------- #
 
 
+@spaces.GPU
 def segment(image: Image.Image) -> Image.Image:
-    input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
+    input_tensor = transform_fn(image).unsqueeze(0)
     preds = run_model(input_tensor, height=image.height, width=image.width)
     mask = preds.squeeze(0).cpu().numpy()
     mask_image = Image.fromarray(mask.astype("uint8"))
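
For orientation, a minimal sketch of how the touched parts of app.py read after this commit. The Space's real model, transform_fn, assets, and Gradio wiring sit outside the hunks above, so a torchvision DeepLabV3 and a bare ToTensor() stand in for them here (assumptions), and the device move inside run_model is likewise this sketch's addition, since the diff itself no longer calls .to("cuda") on the input tensor.

import spaces  # Hugging Face ZeroGPU helper; documented as a no-op outside Spaces
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50  # stand-in for the Space's real model

# ----------------- ENV ----------------- #
# TF32 matmuls only apply on Ampere (compute capability 8.x) or newer GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

model = deeplabv3_resnet50(weights="DEFAULT")
model.eval()
model.to("cuda")  # as in the Space, a CUDA device is assumed


transform_fn = transforms.Compose([transforms.ToTensor()])  # placeholder for the Space's real transform


@torch.inference_mode()
def run_model(input_tensor, height, width):
    input_tensor = input_tensor.to("cuda")  # sketch's assumption: keep the input on the model's device
    # bfloat16 autocast around the forward pass, as introduced by this commit
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(input_tensor)["out"]  # torchvision heads return a dict; the real model may return a tensor
        output = torch.nn.functional.interpolate(
            output, size=(height, width), mode="bilinear", align_corners=False
        )
        _, preds = torch.max(output, 1)
    return preds


@spaces.GPU  # the GPU is now requested for the whole segment() call, not just run_model()
def segment(image: Image.Image) -> Image.Image:
    input_tensor = transform_fn(image).unsqueeze(0)
    preds = run_model(input_tensor, height=image.height, width=image.width)
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    return mask_image  # the real segment() continues past the lines shown in the hunk

On a ZeroGPU Space, spaces.GPU holds a GPU only for the duration of the decorated call, so placing it on segment() keeps preprocessing, the forward pass, and postprocessing inside one allocation, which is what the commit message's "main segment scope" refers to.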