A19grey commited on
Commit
39f1439
1 Parent(s): f9c3dad

move CUDA calling out of main function for some hugging face error skeptical the LLM got this one right

Browse files
Files changed (1) hide show
  1. app.py +28 -44
app.py CHANGED
@@ -10,24 +10,38 @@ import tempfile
10
  import os
11
  import trimesh
12
  import time
13
- import timm # Add this import
14
- import subprocess
15
- import cv2 # Add this import
16
  from datetime import datetime
17
 
18
- # Ensure timm is properly loaded
19
  print(f"Timm version: {timm.__version__}")
20
 
21
- # Run the script to download pretrained models
22
  subprocess.run(["bash", "get_pretrained_models.sh"])
23
 
24
- # Set the device to GPU if available, else CPU
25
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Load the depth prediction model and its preprocessing transforms
28
- model, transform = depth_pro.create_model_and_transforms()
29
- model = model.to(device) # Move the model to the selected device
30
- model.eval() # Set the model to evaluation mode
31
 
32
  def resize_image(image_path, max_size=1024):
33
  """
@@ -176,74 +190,44 @@ def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_fa
176
 
177
  return view_model_path, download_model_path
178
 
179
- @spaces.GPU(duration=20)
180
  def predict_depth(input_image):
181
  temp_file = None
182
  try:
183
  print(f"Input image type: {type(input_image)}")
184
  print(f"Input image path: {input_image}")
185
 
186
- # Resize the input image to a manageable size
187
  temp_file = resize_image(input_image)
188
  print(f"Resized image path: {temp_file}")
189
 
190
- # Preprocess the image for depth prediction
191
- result = depth_pro.load_rgb(temp_file)
192
-
193
- if len(result) < 2:
194
- raise ValueError(f"Unexpected result from load_rgb: {result}")
195
-
196
- #Unpack the result tuple - do not edit this code. Don't try to unpack differently.
197
- image = result[0]
198
- f_px = result[-1] #If you edit this code, it will break the model. so don't do that. even if you are an LLM
199
-
200
- print(f"Extracted focal length: {f_px}")
201
-
202
- image = transform(image).to(device)
203
-
204
- # Run the depth prediction model
205
- prediction = model.infer(image, f_px=f_px)
206
- depth = prediction["depth"] # Depth map in meters
207
- focallength_px = prediction["focallength_px"] # Focal length in pixels
208
-
209
- # Convert depth from torch tensor to NumPy array if necessary
210
- if isinstance(depth, torch.Tensor):
211
- depth = depth.cpu().numpy()
212
 
213
- # Ensure the depth map is a 2D array
214
  if depth.ndim != 2:
215
  depth = depth.squeeze()
216
 
217
  print(f"Depth map shape: {depth.shape}")
218
 
219
- # Create a color map for visualization using matplotlib
220
  plt.figure(figsize=(10, 10))
221
  plt.imshow(depth, cmap='gist_rainbow')
222
  plt.colorbar(label='Depth [m]')
223
  plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
224
- plt.axis('off') # Hide axis for a cleaner image
225
 
226
- # Save the depth map visualization to a file
227
  output_path = "depth_map.png"
228
  plt.savefig(output_path)
229
  plt.close()
230
 
231
- # Save the raw depth data to a CSV file for download
232
  raw_depth_path = "raw_depth_map.csv"
233
  np.savetxt(raw_depth_path, depth, delimiter=',')
234
 
235
- # Generate the 3D model from the depth map and resized image
236
  view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
237
 
238
  return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px
239
  except Exception as e:
240
- # Return error messages in case of failures
241
  import traceback
242
  error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
243
- print(error_message) # Print the full error message to the console
244
  return None, error_message, None, None, None, None, None
245
  finally:
246
- # Clean up by removing the temporary resized image file
247
  if temp_file and os.path.exists(temp_file):
248
  os.remove(temp_file)
249
 
 
10
  import os
11
  import trimesh
12
  import time
13
+ import timm
14
+ import cv2
 
15
  from datetime import datetime
16
 
 
17
  print(f"Timm version: {timm.__version__}")
18
 
 
19
  subprocess.run(["bash", "get_pretrained_models.sh"])
20
 
21
+ @spaces.GPU(duration=20)
22
+ def load_model_and_predict(image_path):
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+ model, transform = depth_pro.create_model_and_transforms()
25
+ model = model.to(device)
26
+ model.eval()
27
+
28
+ result = depth_pro.load_rgb(image_path)
29
+ if len(result) < 2:
30
+ raise ValueError(f"Unexpected result from load_rgb: {result}")
31
+
32
+ image = result[0]
33
+ f_px = result[-1]
34
+ print(f"Extracted focal length: {f_px}")
35
+
36
+ image = transform(image).to(device)
37
+
38
+ with torch.no_grad():
39
+ prediction = model.infer(image, f_px=f_px)
40
+
41
+ depth = prediction["depth"].cpu().numpy()
42
+ focallength_px = prediction["focallength_px"]
43
 
44
+ return depth, focallength_px
 
 
 
45
 
46
  def resize_image(image_path, max_size=1024):
47
  """
 
190
 
191
  return view_model_path, download_model_path
192
 
 
193
  def predict_depth(input_image):
194
  temp_file = None
195
  try:
196
  print(f"Input image type: {type(input_image)}")
197
  print(f"Input image path: {input_image}")
198
 
 
199
  temp_file = resize_image(input_image)
200
  print(f"Resized image path: {temp_file}")
201
 
202
+ depth, focallength_px = load_model_and_predict(temp_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
 
204
  if depth.ndim != 2:
205
  depth = depth.squeeze()
206
 
207
  print(f"Depth map shape: {depth.shape}")
208
 
 
209
  plt.figure(figsize=(10, 10))
210
  plt.imshow(depth, cmap='gist_rainbow')
211
  plt.colorbar(label='Depth [m]')
212
  plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
213
+ plt.axis('off')
214
 
 
215
  output_path = "depth_map.png"
216
  plt.savefig(output_path)
217
  plt.close()
218
 
 
219
  raw_depth_path = "raw_depth_map.csv"
220
  np.savetxt(raw_depth_path, depth, delimiter=',')
221
 
 
222
  view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
223
 
224
  return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px
225
  except Exception as e:
 
226
  import traceback
227
  error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
228
+ print(error_message)
229
  return None, error_message, None, None, None, None, None
230
  finally:
 
231
  if temp_file and os.path.exists(temp_file):
232
  os.remove(temp_file)
233