import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm
import cv2
from datetime import datetime
# Confirm that timm imported correctly by logging its version
print(f"Timm version: {timm.__version__}")
# Run the script to download pretrained models
subprocess.run(["bash", "get_pretrained_models.sh"])
# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device) # Move the model to the selected device
model.eval() # Set the model to evaluation mode
def resize_image(image_path, max_size=1024):
"""
Resize the input image to ensure its largest dimension does not exceed max_size.
Maintains the aspect ratio and saves the resized image as a temporary PNG file.
Args:
image_path (str): Path to the input image.
max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.
Returns:
str: Path to the resized temporary image file.
"""
with Image.open(image_path) as img:
# Calculate the resizing ratio while maintaining aspect ratio
ratio = max_size / max(img.size)
new_size = tuple([int(x * ratio) for x in img.size])
# Resize the image using LANCZOS filter for high-quality downsampling
img = img.resize(new_size, Image.LANCZOS)
# Save the resized image to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
img.save(temp_file, format="PNG")
return temp_file.name
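# A minimal usage sketch for resize_image (hypothetical path, for illustration):
#   small_png = resize_image("photo.jpg", max_size=1024)
#   ...run inference on small_png...
#   os.remove(small_png)  # the caller is responsible for deleting the temp file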
def generate_3d_model(depth, image_path, focallength_px):
"""
Generate a textured 3D mesh from the depth map and the original image.
Args:
depth (np.ndarray): 2D array representing depth in meters.
image_path (str): Path to the RGB image.
focallength_px (float): Focal length in pixels.
Returns:
tuple: Paths to the exported 3D model files for viewing and downloading.
"""
    # Load the image and force 3 channels (handles RGBA or grayscale input)
    image = np.array(Image.open(image_path).convert("RGB"))
# Ensure depth is a NumPy array
if isinstance(depth, torch.Tensor):
depth = depth.cpu().numpy()
# Resize depth to match image dimensions if necessary
if depth.shape != image.shape[:2]:
depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
height, width = depth.shape
print(f"3D model generation - Depth shape: {depth.shape}")
print(f"3D model generation - Image shape: {image.shape}")
# Compute camera intrinsic parameters
fx = fy = float(focallength_px) # Ensure focallength_px is a float
cx, cy = width / 2, height / 2 # Principal point at the image center
# Create a grid of (u, v) pixel coordinates
u = np.arange(0, width)
v = np.arange(0, height)
uu, vv = np.meshgrid(u, v)
# Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
Z = depth.flatten()
X = ((uu.flatten() - cx) * Z) / fx
Y = ((vv.flatten() - cy) * Z) / fy
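    # Pinhole back-projection: a pixel (u, v) at depth Z maps to camera space as
    #   X = (u - cx) * Z / fx,   Y = (v - cy) * Z / fy,   Z = Z
    # so distant points spread proportionally farther from the optical axis.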
# Stack the coordinates to form vertices (X, Y, Z)
vertices = np.vstack((X, Y, Z)).T
# Normalize RGB colors to [0, 1] for vertex coloring
colors = image.reshape(-1, 3) / 255.0
# Generate faces by connecting adjacent vertices to form triangles
faces = []
for i in range(height - 1):
for j in range(width - 1):
idx = i * width + j
# Triangle 1
faces.append([idx, idx + width, idx + 1])
# Triangle 2
faces.append([idx + 1, idx + width, idx + width + 1])
faces = np.array(faces)
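    # Each quad of adjacent pixels is split into two triangles:
    #   idx ---------- idx+1
    #    |   \            |
    #    |      \         |
    #   idx+width --- idx+width+1
    # An equivalent vectorized construction (untested sketch) would avoid the loop:
    #   base = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)).ravel()
    #   t1 = np.stack([base, base + width, base + 1], axis=1)
    #   t2 = np.stack([base + 1, base + width, base + width + 1], axis=1)
    #   faces = np.concatenate([t1, t2])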
# Create the mesh using Trimesh with vertex colors
mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
# Export the mesh to OBJ files with unique filenames
timestamp = int(time.time())
view_model_path = f'view_model_{timestamp}.obj'
download_model_path = f'download_model_{timestamp}.obj'
mesh.export(view_model_path)
mesh.export(download_model_path)
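    # Note: trimesh writes vertex colors as an extended "v x y z r g b" OBJ line,
    # a common but non-standard convention that most (though not all) viewers honor.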
return view_model_path, download_model_path
@spaces.GPU(duration=20)
def predict_depth(input_image):
    """
    Predict a depth map for the uploaded image, visualize it, and build a 3D model.
    Args:
        input_image (str): Path to the input image file.
    Returns:
        tuple: (depth map image path, focal length or error message,
                raw depth CSV path, 3D model path for viewing,
                3D model path for download)
    """
    temp_file = None
try:
print(f"Input image type: {type(input_image)}")
print(f"Input image path: {input_image}")
# Resize the input image to a manageable size
temp_file = resize_image(input_image)
print(f"Resized image path: {temp_file}")
# Preprocess the image for depth prediction
result = depth_pro.load_rgb(temp_file)
if len(result) < 2:
raise ValueError(f"Unexpected result from load_rgb: {result}")
        # Unpack the result tuple. load_rgb may return extra values, so take the
        # image from the front and the focal length from the end of the tuple.
        # This unpacking is intentional; do not change it.
        image = result[0]
        f_px = result[-1]
print(f"Extracted focal length: {f_px}")
image = transform(image).to(device)
# Run the depth prediction model
prediction = model.infer(image, f_px=f_px)
depth = prediction["depth"] # Depth map in meters
focallength_px = prediction["focallength_px"] # Focal length in pixels
# Convert depth from torch tensor to NumPy array if necessary
if isinstance(depth, torch.Tensor):
depth = depth.cpu().numpy()
# Ensure the depth map is a 2D array
if depth.ndim != 2:
depth = depth.squeeze()
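        # squeeze() drops singleton dimensions, covering outputs shaped like (1, H, W)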
print(f"Depth map shape: {depth.shape}")
# Create a color map for visualization using matplotlib
plt.figure(figsize=(10, 10))
plt.imshow(depth, cmap='gist_rainbow')
plt.colorbar(label='Depth [m]')
plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
plt.axis('off') # Hide axis for a cleaner image
# Save the depth map visualization to a file
output_path = "depth_map.png"
plt.savefig(output_path)
plt.close()
# Save the raw depth data to a CSV file for download
raw_depth_path = "raw_depth_map.csv"
np.savetxt(raw_depth_path, depth, delimiter=',')
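        # The CSV mirrors the depth array: one row per image row, one value per
        # pixel, each giving the predicted distance from the camera in meters.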
# Generate the 3D model from the depth map and resized image
view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
except Exception as e:
# Return error messages in case of failures
import traceback
error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
print(error_message) # Print the full error message to the console
return None, error_message, None, None, None
finally:
# Clean up by removing the temporary resized image file
if temp_file and os.path.exists(temp_file):
os.remove(temp_file)
def get_last_commit_timestamp():
    """Return the timestamp of the most recent git commit, formatted for display."""
    try:
        # iso-strict yields e.g. "2024-01-15T10:30:00+01:00", which fromisoformat can parse
        timestamp = subprocess.check_output(
            ['git', 'log', '-1', '--format=%cd', '--date=iso-strict']
        ).decode('utf-8').strip()
        return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        print(f"Could not determine the last commit timestamp: {e}")
        return "unknown"
# Create the Gradio interface with appropriate input and output components.
last_updated = get_last_commit_timestamp()
iface = gr.Interface(
fn=predict_depth,
inputs=gr.Image(type="filepath"),
outputs=[
gr.Image(type="filepath", label="Depth Map"),
gr.Textbox(label="Focal Length or Error Message"),
gr.File(label="Download Raw Depth Map (CSV)"),
gr.Model3D(label="View 3D Model"),
gr.File(label="Download 3D Model (OBJ)")
],
title="DepthPro Demo with 3D Visualization",
description=(
"An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
"Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
"**Instructions:**\n"
"1. Upload an image.\n"
"2. The app will predict the depth map, display it, and provide the focal length.\n"
"3. Download the raw depth data as a CSV file.\n"
"4. View the generated 3D model textured with the original image.\n"
"5. Download the 3D model as an OBJ file if desired.\n\n"
f"Last updated: {last_updated}"
),
)
# Launch the Gradio interface; share=True generates a publicly shareable link
iface.launch(share=True)