# DepthPro Gradio demo: depth prediction with 3D visualization.
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time  # Add this import at the top of the file

# Run the script to download pretrained models
_download = subprocess.run(["bash", "get_pretrained_models.sh"])
if _download.returncode != 0:
    # Report the failure instead of silently ignoring the exit status.
    # Execution continues as before; presumably the model load below fails
    # if the checkpoints are really missing -- this just makes the cause visible.
    print(
        "Warning: get_pretrained_models.sh exited with status "
        f"{_download.returncode}; pretrained weights may be missing."
    )

# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model to the selected device
model.eval()  # Set the model to evaluation mode

def resize_image(image_path, max_size=1024):
    """
    Rescale an image so that its largest dimension is about max_size pixels.

    The aspect ratio is preserved. The scale factor is max_size divided by the
    largest dimension, so images smaller than max_size are enlarged. The result
    is written to a new temporary PNG file that the caller must delete.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Target size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.

    Raises:
        Exception: Whatever PIL raises when the image cannot be opened,
            resized or saved. The temporary file is removed before re-raising
            if saving fails.
    """
    with Image.open(image_path) as img:
        # Calculate the resizing ratio while maintaining aspect ratio
        ratio = max_size / max(img.size)
        # Clamp each side to at least one pixel: for very elongated images the
        # short side would otherwise truncate to 0 and make resize() fail.
        new_size = tuple(max(1, int(dim * ratio)) for dim in img.size)

        # LANCZOS gives high-quality resampling; resize() returns a new image,
        # so the source file can be closed as soon as this block exits.
        resized = img.resize(new_size, Image.LANCZOS)

    # delete=False keeps the file on disk after closing so the path can be returned
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        with temp_file:
            resized.save(temp_file, format="PNG")
    except Exception:
        # Do not leave an orphaned temporary file behind when saving fails
        os.remove(temp_file.name)
        raise
    return temp_file.name

def generate_3d_model(depth, image_path, focallength_px):
    """
    Generate a vertex-colored 3D mesh from the depth map and the RGB image.

    Every depth pixel becomes one vertex, back-projected with a pinhole camera
    model whose principal point is the image center. Neighbouring pixels are
    joined into two triangles per grid cell. The mesh is exported as OBJ to two
    files in the current working directory.

    Args:
        depth (np.ndarray): 2D array representing depth in meters.
        image_path (str): Path to the RGB image used for vertex colors. It is
            converted to RGB and resized to the depth map's size if needed.
        focallength_px (float): Focal length in pixels. Anything accepted by
            float() works, e.g. a zero-dimensional tensor.

    Returns:
        tuple: Paths to the exported 3D model files for viewing and downloading.
    """
    import shutil  # local import: only needed to duplicate the exported file

    height, width = depth.shape

    # Load the RGB image and convert to a NumPy array. Forcing RGB guarantees
    # three channels for the reshape below; the size check keeps one color per
    # vertex even if the image and depth map disagree.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")
        if rgb_image.size != (width, height):
            rgb_image = rgb_image.resize((width, height), Image.LANCZOS)
        image = np.array(rgb_image)

    # Compute camera intrinsic parameters
    # float() keeps the arithmetic in NumPy even when a tensor is passed in
    fx = fy = float(focallength_px)  # Assuming square pixels and fx = fy
    cx, cy = width / 2, height / 2  # Principal point at the image center

    # Create a grid of (u, v) pixel coordinates
    uu, vv = np.meshgrid(np.arange(width), np.arange(height))

    # Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    # Stack the coordinates to form vertices (X, Y, Z)
    vertices = np.vstack((X, Y, Z)).T

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Generate faces by connecting adjacent vertices to form triangles.
    # top_left holds the flat index of the upper-left vertex of every grid
    # cell, row by row; building the triangles with array operations avoids a
    # Python loop over every pixel.
    top_left = (
        np.arange(height - 1)[:, None] * width + np.arange(width - 1)[None, :]
    ).ravel()
    triangle_1 = np.stack((top_left, top_left + width, top_left + 1), axis=1)
    triangle_2 = np.stack(
        (top_left + 1, top_left + width, top_left + width + 1), axis=1
    )
    # Interleave so each cell's two triangles are adjacent in the face list
    faces = np.stack((triangle_1, triangle_2), axis=1).reshape(-1, 3)

    # Create the mesh using Trimesh with vertex colors
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # Export the mesh to OBJ files with unique filenames. A nanosecond
    # timestamp makes it far less likely that two requests handled within the
    # same second overwrite each other's files.
    timestamp = time.time_ns()
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    mesh.export(view_model_path)
    # Both files hold the same mesh, so copy instead of serializing it twice
    shutil.copyfile(view_model_path, download_model_path)
    return view_model_path, download_model_path

@spaces.GPU(duration=20)
def predict_depth(input_image):
    """
    Predict the depth map from the input image, generate visualizations and a 3D model.

    The image is resized, run through the depth model, and the resulting depth
    map and RGB image are both downsampled by a factor of 2 before the
    visualization, CSV and mesh are produced. Any exception is caught and
    reported through the text output instead of being raised.

    Args:
        input_image (str): Path to the input image file.

    Returns:
        tuple:
            - str: Path to the depth map image.
            - str: Focal length in pixels (after downscaling) or an error message.
            - str: Path to the raw depth data CSV file.
            - str: Path to the generated 3D model file for viewing.
            - str: Path to the downloadable 3D model file.
            On failure every element except the message is None.
    """
    temp_file = None
    try:
        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)

        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)
        # Keep the un-normalized pixels for texturing the mesh. The tensor
        # produced by `transform` is model input, not displayable color data.
        # NOTE(review): assumes load_rgb returns an (H, W, C) uint8-range
        # image as its first element -- confirm against depth_pro.
        rgb = np.asarray(result[0])
        f_px = result[-1]  # Focal length in pixels
        image = transform(result[0])  # Apply preprocessing transforms
        image = image.to(device)  # Move the image tensor to the selected device

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth from torch tensor to NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Use a plain Python float so later arithmetic with NumPy arrays and
        # string formatting do not depend on a (possibly GPU-resident) tensor
        focallength_px = float(focallength_px)

        # Downsample depth map and image to improve processing speed
        downscale_factor = 2  # Factor by which to downscale (e.g., 2 reduces dimensions by half)
        depth = depth[::downscale_factor, ::downscale_factor]
        # Apply the same stride to the RGB pixels so both stay aligned
        # (presumably the depth map has the input image's resolution)
        image_ds = rgb[::downscale_factor, ::downscale_factor, :3]
        # Update focal length based on downscaling
        focallength_px = focallength_px / downscale_factor

        # The downscaled image is saved back to the temporary file so the mesh
        # generator reads colors that match the downscaled depth map
        downscaled_image = Image.fromarray(
            np.ascontiguousarray(image_ds).astype(np.uint8)
        )
        downscaled_image.save(temp_file)

        # No normalization of depth map as it is already in meters
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # Unique output files so concurrent requests cannot overwrite each
        # other's results; delete=False keeps them for Gradio to serve
        with tempfile.NamedTemporaryFile(
            delete=False, prefix="depth_map_", suffix=".png"
        ) as depth_map_file:
            output_path = depth_map_file.name
        with tempfile.NamedTemporaryFile(
            delete=False, prefix="raw_depth_map_", suffix=".csv"
        ) as raw_depth_file:
            raw_depth_path = raw_depth_file.name

        # Create a color map for visualization using matplotlib. The figure is
        # handled explicitly and closed in `finally` so it is released even if
        # plotting or saving raises.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            depth_plot = ax.imshow(depth, cmap='gist_rainbow')
            fig.colorbar(depth_plot, ax=ax, label='Depth [m]')
            ax.set_title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
            ax.axis('off')  # Hide axis for a cleaner image

            # Save the depth map visualization to a file
            fig.savefig(output_path)
        finally:
            plt.close(fig)

        # Save the raw depth data to a CSV file for download
        np.savetxt(raw_depth_path, depth, delimiter=',')

        # Generate the 3D model from the depth map and resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
    except Exception as e:
        # Top-level UI boundary: report the failure in the text box rather
        # than letting the request crash
        return None, f"An error occurred: {str(e)}", None, None, None
    finally:
        # Clean up by removing the temporary resized image file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

# Input widget: the uploaded image is handed to predict_depth as a file path
_input_component = gr.Image(type="filepath")

# Output widgets, listed in the same order as the tuple returned by predict_depth
_output_components = [
    gr.Image(type="filepath", label="Depth Map"),  # Depth map visualization
    gr.Textbox(label="Focal Length or Error Message"),  # Focal length, or the error text
    gr.File(label="Download Raw Depth Map (CSV)"),  # Raw depth data download
    gr.Model3D(label="View 3D Model"),  # Interactive 3D viewer
    gr.File(label="Download 3D Model (OBJ)"),  # 3D model download
]

# Markdown text shown under the title
_description = (
    "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
    "**Instructions:**\n"
    "1. Upload an image.\n"
    "2. The app will predict the depth map, display it, and provide the focal length.\n"
    "3. Download the raw depth data as a CSV file.\n"
    "4. View the generated 3D model textured with the original image.\n"
    "5. Download the 3D model as an OBJ file if desired."
)

# Assemble the Gradio interface from the pieces above
iface = gr.Interface(
    fn=predict_depth,
    inputs=_input_component,
    outputs=_output_components,
    title="DepthPro Demo with 3D Visualization",
    description=_description,
)

# Start the app; share=True requests a shareable link for the interface
iface.launch(share=True)