A19grey committed on
Commit f95c546
1 Parent(s): 0efa5f1

First try to add 3D model creation

Files changed (1): app.py (+141 -36)
app.py CHANGED
@@ -8,100 +8,205 @@ import spaces
 import torch
 import tempfile
 import os

-# Run the script to get pretrained models
 subprocess.run(["bash", "get_pretrained_models.sh"])

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-# Load model and preprocessing transform
 model, transform = depth_pro.create_model_and_transforms()
-model = model.to(device)
-model.eval()

 def resize_image(image_path, max_size=1024):
     with Image.open(image_path) as img:
-        # Calculate the new size while maintaining aspect ratio
         ratio = max_size / max(img.size)
         new_size = tuple([int(x * ratio) for x in img.size])

-        # Resize the image
         img = img.resize(new_size, Image.LANCZOS)

-        # Create a temporary file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
             img.save(temp_file, format="PNG")
             return temp_file.name

 @spaces.GPU(duration=20)
 def predict_depth(input_image):
     temp_file = None
     try:
-        # Resize the input image
         temp_file = resize_image(input_image)

-        # Preprocess the image
         result = depth_pro.load_rgb(temp_file)
         image = result[0]
-        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
-        image = transform(image)
-        image = image.to(device)

-        # Run inference
         prediction = model.infer(image, f_px=f_px)
-        depth = prediction["depth"]  # Depth in [m]
         focallength_px = prediction["focallength_px"]  # Focal length in pixels

-        # Convert depth to numpy array if it's a torch tensor
         if isinstance(depth, torch.Tensor):
             depth = depth.cpu().numpy()

-        # Ensure depth is a 2D numpy array
         if depth.ndim != 2:
             depth = depth.squeeze()

-        # Normalize depth for visualization
-        # agk - No never normalize depth. It is already in meters. EMBRACE REALITY. TOUCH GRASS.
         depth_min = np.min(depth)
         depth_max = np.max(depth)
-        depth_normalized = depth  # it is normal to have depth in meters. Normalize reality.
-
-        # Create a color map
         plt.figure(figsize=(10, 10))
         plt.imshow(depth_normalized, cmap='gist_rainbow')
         plt.colorbar(label='Depth [m]')
         plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
-        plt.axis('off')
-
-        # Save the plot to a file
         output_path = "depth_map.png"
         plt.savefig(output_path)
         plt.close()

-        # Save raw depth data as CSV
         raw_depth_path = "raw_depth_map.csv"
         np.savetxt(raw_depth_path, depth, delimiter=',')

-        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path
     except Exception as e:
-        return None, f"An error occurred: {str(e)}", None
     finally:
-        # Clean up the temporary file
         if temp_file and os.path.exists(temp_file):
             os.remove(temp_file)

-# Create Gradio interface
 iface = gr.Interface(
     fn=predict_depth,
     inputs=gr.Image(type="filepath"),
     outputs=[
-        gr.Image(type="filepath", label="Depth Map"),
-        gr.Textbox(label="Focal Length or Error Message"),
-        gr.File(label="Download Raw Depth Map (CSV)")
     ],
-    title="DepthPro Demo in Meters fork from akhaliq (original description below)",
-    description="forked from https://huggingface.co/spaces/akhaliq/depth-pro to add raw meters output. [DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its depth map and focal length. Large images will be automatically resized. You can also download the raw depth map data as a CSV file."
 )

-# Launch the interface
-iface.launch(share=True)  # share=True allows you to share the interface with others.

 import torch
 import tempfile
 import os
+import trimesh

+# Run the script to download pretrained models
 subprocess.run(["bash", "get_pretrained_models.sh"])

+# Set the device to GPU if available, else CPU
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+# Load the depth prediction model and its preprocessing transforms
 model, transform = depth_pro.create_model_and_transforms()
+model = model.to(device)  # Move the model to the selected device
+model.eval()  # Set the model to evaluation mode

 def resize_image(image_path, max_size=1024):
+    """
+    Resize the input image to ensure its largest dimension does not exceed max_size.
+    Maintains the aspect ratio and saves the resized image as a temporary PNG file.
+
+    Args:
+        image_path (str): Path to the input image.
+        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.
+
+    Returns:
+        str: Path to the resized temporary image file.
+    """
     with Image.open(image_path) as img:
+        # Calculate the resizing ratio while maintaining aspect ratio
         ratio = max_size / max(img.size)
         new_size = tuple([int(x * ratio) for x in img.size])

+        # Resize the image using LANCZOS filter for high-quality downsampling
         img = img.resize(new_size, Image.LANCZOS)

+        # Save the resized image to a temporary file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
             img.save(temp_file, format="PNG")
             return temp_file.name

+def generate_3d_model(depth, image_path, focallength_px):
+    """
+    Generate a textured 3D mesh from the depth map and the original image.
+
+    Args:
+        depth (np.ndarray): 2D array representing depth in meters.
+        image_path (str): Path to the resized RGB image.
+        focallength_px (float): Focal length in pixels.
+
+    Returns:
+        str: Path to the exported 3D model file in OBJ format.
+    """
+    # Load the RGB image and convert to a NumPy array
+    image = np.array(Image.open(image_path))
+    height, width = depth.shape
+
+    # Compute camera intrinsic parameters
+    fx = fy = focallength_px  # Assuming square pixels and fx = fy
+    cx, cy = width / 2, height / 2  # Principal point at the image center
+
+    # Create a grid of (u, v) pixel coordinates
+    u = np.arange(0, width)
+    v = np.arange(0, height)
+    uu, vv = np.meshgrid(u, v)
+
+    # Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
+    Z = depth.flatten()
+    X = ((uu.flatten() - cx) * Z) / fx
+    Y = ((vv.flatten() - cy) * Z) / fy
+
+    # Stack the coordinates to form vertices (X, Y, Z)
+    vertices = np.vstack((X, Y, Z)).T
+
+    # Normalize RGB colors to [0, 1] for vertex coloring
+    colors = image.reshape(-1, 3) / 255.0
+
+    # Generate faces by connecting adjacent vertices to form triangles
+    faces = []
+    for i in range(height - 1):
+        for j in range(width - 1):
+            idx = i * width + j
+            # Triangle 1
+            faces.append([idx, idx + width, idx + 1])
+            # Triangle 2
+            faces.append([idx + 1, idx + width, idx + width + 1])
+    faces = np.array(faces)
+
+    # Create the mesh using Trimesh with vertex colors
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
+
+    # Export the mesh to an OBJ file
+    model_path = 'output_model.obj'
+    mesh.export(model_path)
+    return model_path
+
 @spaces.GPU(duration=20)
 def predict_depth(input_image):
+    """
+    Predict the depth map from the input image, generate visualizations and a 3D model.
+
+    Args:
+        input_image (str): Path to the input image file.
+
+    Returns:
+        tuple:
+            - str: Path to the depth map image.
+            - str: Focal length in pixels or an error message.
+            - str: Path to the raw depth data CSV file.
+            - str: Path to the generated 3D model file.
+    """
     temp_file = None
     try:
+        # Resize the input image to a manageable size
         temp_file = resize_image(input_image)

+        # Preprocess the image for depth prediction
         result = depth_pro.load_rgb(temp_file)
         image = result[0]
+        f_px = result[-1]  # Focal length in pixels
+        image = transform(image)  # Apply preprocessing transforms
+        image = image.to(device)  # Move the image tensor to the selected device

+        # Run the depth prediction model
         prediction = model.infer(image, f_px=f_px)
+        depth = prediction["depth"]  # Depth map in meters
         focallength_px = prediction["focallength_px"]  # Focal length in pixels

+        # Convert depth from torch tensor to NumPy array if necessary
         if isinstance(depth, torch.Tensor):
             depth = depth.cpu().numpy()

+        # Ensure the depth map is a 2D array
         if depth.ndim != 2:
             depth = depth.squeeze()

+        # **Downsample depth map and image to improve processing speed**
+        downscale_factor = 2  # Factor by which to downscale (e.g., 2 reduces dimensions by half)
+        depth = depth[::downscale_factor, ::downscale_factor]
+        # Convert image tensor to CPU and NumPy for slicing
+        image_np = image.cpu().detach().numpy()[0].transpose(1, 2, 0)
+        image_ds = image_np[::downscale_factor, ::downscale_factor, :]
+        # Update focal length based on downscaling
+        focallength_px = focallength_px / downscale_factor
+
+        # **Note:** The downscaled image is saved back to the temporary file for consistency
+        downscaled_image = Image.fromarray((image_ds * 255).astype(np.uint8))
+        downscaled_image.save(temp_file)
+
+        # No normalization of depth map as it is already in meters
         depth_min = np.min(depth)
         depth_max = np.max(depth)
+        depth_normalized = depth  # Depth remains in meters
+
+        # Create a color map for visualization using matplotlib
         plt.figure(figsize=(10, 10))
         plt.imshow(depth_normalized, cmap='gist_rainbow')
         plt.colorbar(label='Depth [m]')
         plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
+        plt.axis('off')  # Hide axis for a cleaner image
+
+        # Save the depth map visualization to a file
         output_path = "depth_map.png"
         plt.savefig(output_path)
         plt.close()

+        # Save the raw depth data to a CSV file for download
         raw_depth_path = "raw_depth_map.csv"
         np.savetxt(raw_depth_path, depth, delimiter=',')

+        # Generate the 3D model from the depth map and resized image
+        model_path = generate_3d_model(depth, temp_file, focallength_px)
+
+        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, model_path
     except Exception as e:
+        # Return error messages in case of failures
+        return None, f"An error occurred: {str(e)}", None, None
     finally:
+        # Clean up by removing the temporary resized image file
         if temp_file and os.path.exists(temp_file):
             os.remove(temp_file)

+# Create the Gradio interface with appropriate input and output components
 iface = gr.Interface(
     fn=predict_depth,
     inputs=gr.Image(type="filepath"),
     outputs=[
+        gr.Image(type="filepath", label="Depth Map"),  # Displays the depth map image
+        gr.Textbox(label="Focal Length or Error Message"),  # Shows focal length or error messages
+        gr.File(label="Download Raw Depth Map (CSV)"),  # Allows downloading the raw depth data
+        gr.Model3D(label="3D Model")  # Displays the generated 3D model
     ],
+    title="DepthPro Demo with 3D Visualization",
+    description=(
+        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
+        "**Instructions:**\n"
+        "1. Upload an image.\n"
+        "2. The app will predict the depth map, display it, and provide the focal length.\n"
+        "3. Download the raw depth data as a CSV file.\n"
+        "4. View the generated 3D model textured with the original image."
+    ),
 )

+# Launch the Gradio interface with sharing enabled
+iface.launch(share=True)  # share=True allows you to share the interface with others.
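
A note on the geometry used in generate_3d_model: the new code back-projects each pixel with the standard pinhole camera model, X = (u - cx) * Z / fx and Y = (v - cy) * Z / fy, so every pixel of the depth map becomes one vertex in camera space. The snippet below is a minimal, self-contained sketch of that mapping using made-up numbers (the 4x4 depth map and the focal length are illustrative values only, not outputs of the app); it also illustrates why focallength_px is divided by the downscale factor, since intrinsics expressed in pixels shrink together with the image grid.

import numpy as np

# Hypothetical values for illustration only (not from the commit)
depth = np.full((4, 4), 2.0)   # 4x4 depth map, every pixel 2 m away
focallength_px = 100.0         # made-up focal length in pixels

h, w = depth.shape
fx = fy = focallength_px
cx, cy = w / 2, h / 2          # principal point at the image center

# Pinhole back-projection: pixel (u, v) with depth Z -> camera-space (X, Y, Z)
uu, vv = np.meshgrid(np.arange(w), np.arange(h))
Z = depth.flatten()
X = (uu.flatten() - cx) * Z / fx
Y = (vv.flatten() - cy) * Z / fy
points = np.column_stack((X, Y, Z))  # one 3D point per pixel
print(points.shape)                  # (16, 3)

# Downsampling the image grid by a factor s shrinks every pixel-valued
# intrinsic (fx, fy, cx, cy) by the same factor, which is why the commit
# divides focallength_px by downscale_factor after slicing the depth map.
s = 2
fx_ds, fy_ds, cx_ds, cy_ds = fx / s, fy / s, cx / s, cy / s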
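The double loop that builds two triangles per pixel quad is easy to read but iterates in Python per pixel, which can be slow for large depth maps even after downscaling. If that ever becomes a bottleneck, a vectorised construction along the lines of the sketch below would produce the same set of faces with the same winding (only the ordering of the face list differs); grid_faces is an illustrative helper name, not part of the commit.

import numpy as np

def grid_faces(height: int, width: int) -> np.ndarray:
    """Build the same two triangles per pixel quad as the loop in
    generate_3d_model, but without Python-level iteration."""
    # Flat index of the top-left corner of every quad
    idx = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)).ravel()
    tri1 = np.stack((idx, idx + width, idx + 1), axis=1)              # [idx, idx+width, idx+1]
    tri2 = np.stack((idx + 1, idx + width, idx + width + 1), axis=1)  # [idx+1, idx+width, idx+width+1]
    return np.concatenate((tri1, tri2), axis=0)

# Example: a 3x4 grid has 2 * (3 - 1) * (4 - 1) = 12 faces
print(grid_faces(3, 4).shape)  # (12, 3)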
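Finally, a small usage example for the new 3D output: the exported output_model.obj can be inspected locally with trimesh before relying on the in-browser gr.Model3D viewer. This is only a sketch of a sanity check, assuming a file produced by one run of the app; force="mesh" asks trimesh to return a single Trimesh rather than a Scene.

import trimesh

# Assumes output_model.obj was produced by a run of generate_3d_model()
mesh = trimesh.load("output_model.obj", force="mesh")

# One vertex per (downsampled) pixel, two triangles per pixel quad
print(mesh.vertices.shape, mesh.faces.shape)

# Opens an interactive viewer window if a display and a viewer backend are available
mesh.show()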