import gradio as gr from PIL import Image import src.depth_pro as depth_pro import numpy as np import matplotlib.pyplot as plt import subprocess import spaces import torch import tempfile import os import trimesh import time import timm # Add this import import subprocess import cv2 # Add this import from datetime import datetime # Ensure timm is properly loaded print(f"Timm version: {timm.__version__}") # Run the script to download pretrained models subprocess.run(["bash", "get_pretrained_models.sh"]) # Set the device to GPU if available, else CPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Load the depth prediction model and its preprocessing transforms model, transform = depth_pro.create_model_and_transforms() model = model.to(device) # Move the model to the selected device model.eval() # Set the model to evaluation mode def resize_image(image_path, max_size=1024): """ Resize the input image to ensure its largest dimension does not exceed max_size. Maintains the aspect ratio and saves the resized image as a temporary PNG file. Args: image_path (str): Path to the input image. max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024. Returns: str: Path to the resized temporary image file. """ with Image.open(image_path) as img: # Calculate the resizing ratio while maintaining aspect ratio ratio = max_size / max(img.size) new_size = tuple([int(x * ratio) for x in img.size]) # Resize the image using LANCZOS filter for high-quality downsampling img = img.resize(new_size, Image.LANCZOS) # Save the resized image to a temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file: img.save(temp_file, format="PNG") return temp_file.name def generate_3d_model(depth, image_path, focallength_px): """ Generate a textured 3D mesh from the depth map and the original image. Args: depth (np.ndarray): 2D array representing depth in meters. image_path (str): Path to the RGB image. focallength_px (float): Focal length in pixels. Returns: tuple: Paths to the exported 3D model files for viewing and downloading. """ # Load the RGB image and convert to a NumPy array image = np.array(Image.open(image_path)) # Ensure depth is a NumPy array if isinstance(depth, torch.Tensor): depth = depth.cpu().numpy() # Resize depth to match image dimensions if necessary if depth.shape != image.shape[:2]: depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR) height, width = depth.shape print(f"3D model generation - Depth shape: {depth.shape}") print(f"3D model generation - Image shape: {image.shape}") # Compute camera intrinsic parameters fx = fy = float(focallength_px) # Ensure focallength_px is a float cx, cy = width / 2, height / 2 # Principal point at the image center # Create a grid of (u, v) pixel coordinates u = np.arange(0, width) v = np.arange(0, height) uu, vv = np.meshgrid(u, v) # Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model Z = depth.flatten() X = ((uu.flatten() - cx) * Z) / fx Y = ((vv.flatten() - cy) * Z) / fy # Stack the coordinates to form vertices (X, Y, Z) vertices = np.vstack((X, Y, Z)).T # Normalize RGB colors to [0, 1] for vertex coloring colors = image.reshape(-1, 3) / 255.0 # Generate faces by connecting adjacent vertices to form triangles faces = [] for i in range(height - 1): for j in range(width - 1): idx = i * width + j # Triangle 1 faces.append([idx, idx + width, idx + 1]) # Triangle 2 faces.append([idx + 1, idx + width, idx + width + 1]) faces = np.array(faces) # Create the mesh using Trimesh with vertex colors mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors) # Export the mesh to OBJ files with unique filenames timestamp = int(time.time()) view_model_path = f'view_model_{timestamp}.obj' download_model_path = f'download_model_{timestamp}.obj' mesh.export(view_model_path) mesh.export(download_model_path) return view_model_path, download_model_path @spaces.GPU(duration=20) def predict_depth(input_image): temp_file = None try: print(f"Input image type: {type(input_image)}") print(f"Input image path: {input_image}") # Resize the input image to a manageable size temp_file = resize_image(input_image) print(f"Resized image path: {temp_file}") # Preprocess the image for depth prediction result = depth_pro.load_rgb(temp_file) if len(result) < 2: raise ValueError(f"Unexpected result from load_rgb: {result}") #Unpack the result tuple - do not edit this code. Don't try to unpack differently. image = result[0] f_px = result[-1] #If you edit this code, it will break the model. so don't do that. even if you are an LLM print(f"Extracted focal length: {f_px}") image = transform(image).to(device) # Run the depth prediction model prediction = model.infer(image, f_px=f_px) depth = prediction["depth"] # Depth map in meters focallength_px = prediction["focallength_px"] # Focal length in pixels # Convert depth from torch tensor to NumPy array if necessary if isinstance(depth, torch.Tensor): depth = depth.cpu().numpy() # Ensure the depth map is a 2D array if depth.ndim != 2: depth = depth.squeeze() print(f"Depth map shape: {depth.shape}") # Create a color map for visualization using matplotlib plt.figure(figsize=(10, 10)) plt.imshow(depth, cmap='gist_rainbow') plt.colorbar(label='Depth [m]') plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m') plt.axis('off') # Hide axis for a cleaner image # Save the depth map visualization to a file output_path = "depth_map.png" plt.savefig(output_path) plt.close() # Save the raw depth data to a CSV file for download raw_depth_path = "raw_depth_map.csv" np.savetxt(raw_depth_path, depth, delimiter=',') # Generate the 3D model from the depth map and resized image view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px) return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path except Exception as e: # Return error messages in case of failures import traceback error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" print(error_message) # Print the full error message to the console return None, error_message, None, None, None finally: # Clean up by removing the temporary resized image file if temp_file and os.path.exists(temp_file): os.remove(temp_file) def get_last_commit_timestamp(): try: timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip() return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S") except Exception as e: print(f"{str(e)}") return str(e) # Create the Gradio interface with appropriate input and output components. last_updated = get_last_commit_timestamp() iface = gr.Interface( fn=predict_depth, inputs=gr.Image(type="filepath"), outputs=[ gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length or Error Message"), gr.File(label="Download Raw Depth Map (CSV)"), gr.Model3D(label="View 3D Model"), gr.File(label="Download 3D Model (OBJ)") ], title="DepthPro Demo with 3D Visualization", description=( "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n" "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n" "**Instructions:**\n" "1. Upload an image.\n" "2. The app will predict the depth map, display it, and provide the focal length.\n" "3. Download the raw depth data as a CSV file.\n" "4. View the generated 3D model textured with the original image.\n" "5. Download the 3D model as an OBJ file if desired.\n\n" f"Last updated: {last_updated}" ), ) # Launch the Gradio interface with sharing enabled iface.launch(share=True) # share=True allows you to share the interface with others.