File size: 8,697 Bytes
cf8c487
 
41a69fa
cf8c487
 
dc9d69c
 
404967f
ee04d83
 
f95c546
40c89eb
 
4bfe855
 
 
40c89eb
 
 
024a2b8
f95c546
dc9d69c
cf8c487
f95c546
764b436
 
f95c546
cf8c487
f95c546
 
cf8c487
501d06f
f95c546
 
 
 
 
 
 
 
 
 
 
501d06f
f95c546
501d06f
 
 
f95c546
501d06f
 
f95c546
ee04d83
 
 
501d06f
f95c546
 
 
 
 
 
b3b839e
f95c546
 
 
84a6150
f95c546
 
 
b3b839e
 
 
 
 
f95c546
 
b3b839e
 
 
f95c546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84a6150
 
 
 
 
 
 
f95c546
764b436
cf8c487
ee04d83
404967f
29a026e
 
 
f95c546
ee04d83
29a026e
501d06f
f95c546
ee04d83
29a026e
 
 
 
b3b839e
29a026e
 
b3b839e
404967f
f95c546
404967f
f95c546
404967f
 
f95c546
404967f
 
 
f95c546
404967f
 
 
b3b839e
f95c546
 
404967f
b3b839e
182cf21
b3b839e
f95c546
 
 
404967f
 
 
 
f95c546
182cf21
 
 
f95c546
84a6150
f95c546
84a6150
404967f
f95c546
29a026e
 
 
 
ee04d83
f95c546
ee04d83
 
cf8c487
c6f3d95
 
 
 
b3b839e
 
 
c6f3d95
d893f72
4bfe855
 
cf8c487
 
 
182cf21
4bfe855
 
 
 
 
182cf21
f95c546
 
 
4bfe855
f95c546
 
 
 
84a6150
4bfe855
 
f95c546
cf8c487
 
f95c546
c6f3d95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm  # Add this import
import subprocess
import cv2  # Add this import
from datetime import datetime

# Module-level setup. These statements run once, at import time, before the
# Gradio interface is built. `device`, `model` and `transform` are module
# globals read later by predict_depth.

# Print the installed timm version. This is a startup diagnostic only; it
# checks nothing beyond the `import timm` above having succeeded.
print(f"Timm version: {timm.__version__}")

# Run the script that downloads the pretrained checkpoints.
# NOTE(review): the return code is not checked (no check=True), so a failed
# download presumably only surfaces later, when the model is created. Confirm
# this is intended.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the DepthPro model together with the preprocessing transform that
# predict_depth applies to each input image before inference.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model weights to the selected device
model.eval()  # Inference mode: disables training-only layer behavior (e.g. dropout)

def resize_image(image_path, max_size=1024):
    """
    Downscale the input image so its largest dimension does not exceed max_size.

    The aspect ratio is preserved. Images that already fit within max_size are
    NOT upscaled. The result is always written to a NEW temporary PNG file, even
    when no resize was needed. The caller therefore owns the returned path and
    may delete it without touching the original input file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the temporary PNG file. It is created with delete=False, so
        the caller is responsible for removing it.
    """
    with Image.open(image_path) as img:
        # PNG cannot store every Pillow mode (e.g. CMYK, YCbCr). Convert anything
        # the PNG writer does not support to RGB so the save below cannot fail.
        if img.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
            img = img.convert("RGB")

        # Only shrink: upscaling a small input adds interpolation artifacts
        # without adding any real detail.
        ratio = max_size / max(img.size)
        if ratio < 1:
            # Clamp each axis to >= 1 px so very elongated images never yield a
            # zero-sized dimension after truncation.
            new_size = tuple(max(1, int(x * ratio)) for x in img.size)
            # LANCZOS filter for high-quality downsampling
            img = img.resize(new_size, Image.LANCZOS)

        # Save to a fresh temporary file (kept on disk; the caller cleans it up)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name

def generate_3d_model(depth, image_path, focallength_px):
    """
    Generate a vertex-colored 3D mesh from a depth map and the matching RGB image.

    Each pixel becomes one vertex. It is back-projected with a pinhole camera
    model whose principal point is the image center. Every 2x2 pixel quad is
    split into two triangles.

    Args:
        depth (np.ndarray): 2D array representing depth in meters. It is resized
            (bilinear) to the image resolution if the shapes differ.
        image_path (str): Path to the image used for vertex colors. It is
            converted to 3-channel RGB regardless of its stored mode.
        focallength_px (float): Focal length in pixels, used for both fx and fy.

    Returns:
        tuple[str, str]: Paths to two identical OBJ files in the current working
        directory, one for viewing and one for downloading.
    """
    import shutil
    import uuid

    # Force 3 channels. RGBA / grayscale / palette images would otherwise break
    # the reshape(-1, 3) below and misalign colors with vertices. The `with`
    # block also closes the underlying file handle.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))

    # Resize depth to match image dimensions if necessary
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

    height, width = depth.shape

    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")

    # Compute camera intrinsic parameters
    fx = fy = focallength_px  # Assuming square pixels and fx = fy
    cx, cy = width / 2, height / 2  # Principal point at the image center

    # Create a grid of (u, v) pixel coordinates
    u = np.arange(0, width)
    v = np.arange(0, height)
    uu, vv = np.meshgrid(u, v)

    # Convert pixel coordinates to 3D camera coordinates using the pinhole model
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    # Stack the coordinates to form vertices (X, Y, Z), one per pixel, row-major
    vertices = np.vstack((X, Y, Z)).T

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Build the triangle faces with NumPy instead of a Python double loop,
    # which is slow for megapixel grids. idx is the top-left vertex of each
    # quad, in row-major order.
    rows, cols = np.meshgrid(np.arange(height - 1), np.arange(width - 1), indexing="ij")
    idx = (rows * width + cols).ravel()
    upper = np.stack([idx, idx + width, idx + 1], axis=1)
    lower = np.stack([idx + 1, idx + width, idx + width + 1], axis=1)
    # Interleave so each quad contributes (upper, lower) consecutively. A
    # degenerate 1-pixel-wide/high input yields an empty (0, 3) array.
    faces = np.stack([upper, lower], axis=1).reshape(-1, 3)

    # Create the mesh using Trimesh with vertex colors
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # A timestamp alone collides when two requests finish within the same
    # second, so append a random suffix to keep file names unique.
    suffix = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    view_model_path = f'view_model_{suffix}.obj'
    download_model_path = f'download_model_{suffix}.obj'
    # Serialize once and copy the bytes; exporting a per-pixel mesh is expensive
    mesh.export(view_model_path)
    shutil.copyfile(view_model_path, download_model_path)
    return view_model_path, download_model_path

@spaces.GPU(duration=20)
def predict_depth(input_image):
    """
    Predict a metric depth map for an image and produce the artifacts shown in the UI.

    Args:
        input_image (str): Filepath of the uploaded image.

    Returns:
        tuple: (depth-map PNG path, focal-length text, raw depth CSV path,
        viewer OBJ path, download OBJ path). On any failure it returns
        (None, error message with traceback, None, None, None), so the UI
        shows the error instead of crashing.
    """
    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        # Resize the input image to a manageable size (always a new temp file)
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)

        # Validate against the 5-way unpacking performed right below
        if len(result) != 5:
            raise ValueError(f"Unexpected result from load_rgb: {result}")

        image, _, _, _, f_px = result
        print(f"Extracted focal length: {f_px}")

        image = transform(image).to(device)

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        # Normalize to a plain Python float. The value may be a (possibly CUDA)
        # tensor, which cannot be mixed with NumPy arrays during the 3D
        # back-projection.
        focallength_px = float(prediction["focallength_px"])

        # Convert depth from torch tensor to NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()

        print(f"Depth map shape: {depth.shape}")

        # Write this request's outputs into its own fresh directory so
        # concurrent requests never overwrite each other's files. The file
        # names, and therefore the download names, stay the same.
        output_dir = tempfile.mkdtemp(prefix="depth_pro_")

        # Render a color-mapped visualization. Use an explicit figure so it is
        # always closed, even if plotting or saving raises.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            im = ax.imshow(depth, cmap='gist_rainbow')
            fig.colorbar(im, ax=ax, label='Depth [m]')
            ax.set_title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
            ax.axis('off')  # Hide axis for a cleaner image
            output_path = os.path.join(output_dir, "depth_map.png")
            fig.savefig(output_path)
        finally:
            plt.close(fig)

        # Save the raw depth data to a CSV file for download
        raw_depth_path = os.path.join(output_dir, "raw_depth_map.csv")
        np.savetxt(raw_depth_path, depth, delimiter=',')

        # Generate the 3D model from the depth map and resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
    except Exception as e:
        # UI boundary: report the failure in the textbox rather than crashing
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print the full error message to the console
        return None, error_message, None, None, None
    finally:
        # Clean up by removing the temporary resized image file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

def get_last_commit_timestamp():
    """
    Return the committer date of the current HEAD commit as "YYYY-MM-DD HH:MM:SS".

    This is best effort. If git is unavailable, the working directory is not a
    repository, or the date cannot be parsed, the error text is printed and
    returned instead, so the caller can still render something.

    Returns:
        str: Formatted timestamp, or the error message on failure.
    """
    try:
        # --date=iso-strict yields e.g. "2024-10-05T12:34:56+02:00", which
        # datetime.fromisoformat accepts. Plain --date=iso
        # ("2024-10-05 12:34:56 +0200") is rejected on every Python version.
        timestamp = subprocess.check_output(
            ['git', 'log', '-1', '--format=%cd', '--date=iso-strict']
        ).decode('utf-8').strip()
        # Some git versions print UTC as "Z", which Python < 3.11 cannot parse
        if timestamp.endswith("Z"):
            timestamp = timestamp[:-1] + "+00:00"
        return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        print(f"{str(e)}")
        return str(e)
    
# Build the Gradio UI: one image input (passed to predict_depth as a file path)
# and five outputs matching the 5-tuple that predict_depth returns.
last_updated = get_last_commit_timestamp()

iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),
        gr.Textbox(label="Focal Length or Error Message"),
        gr.File(label="Download Raw Depth Map (CSV)"),
        gr.Model3D(label="View 3D Model"),
        gr.File(label="Download 3D Model (OBJ)")
    ],
    title="DepthPro Demo with 3D Visualization",
    # The description is rendered as Markdown. A single "\n" does not break a
    # paragraph, so blank lines ("\n\n") separate the credits line from the
    # instructions heading and the heading from the list.
    description=(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n\n"
        "**Instructions:**\n\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    ),
)

# Launch the Gradio interface with sharing enabled
iface.launch(share=True)  # share=True allows you to share the interface with others.