depthanyvideo committed on
Commit
47ac829
1 Parent(s): 4be2365
Files changed (1) hide show
  1. app.py +89 -88
app.py CHANGED
@@ -70,98 +70,99 @@ def depth_any_video(
70
  """
71
  Perform depth estimation on the uploaded video/image.
72
  """
73
- with tempfile.TemporaryDirectory() as tmp_dir:
74
- # Save the uploaded file
75
- input_path = os.path.join(tmp_dir, file.name)
76
- with open(input_path, "wb") as f:
77
- f.write(file.read())
78
-
79
- # Set up output directory
80
- output_dir = os.path.join(tmp_dir, "output")
81
- os.makedirs(output_dir, exist_ok=True)
82
-
83
- # Prepare configuration
84
- cfg = EasyDict(
85
- {
86
- "model_base": MODEL_BASE,
87
- "data_path": input_path,
88
- "output_dir": output_dir,
89
- "denoise_steps": denoise_steps,
90
- "num_frames": num_frames,
91
- "decode_chunk_size": decode_chunk_size,
92
- "num_interp_frames": num_interp_frames,
93
- "num_overlap_frames": num_overlap_frames,
94
- "max_resolution": max_resolution,
95
- "seed": 666,
96
- }
97
- )
98
-
99
- seed_all(cfg.seed)
100
-
101
- file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
102
- is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
103
-
104
- if is_video:
105
- num_interp_frames = cfg.num_interp_frames
106
- num_overlap_frames = cfg.num_overlap_frames
107
- num_frames = cfg.num_frames
108
- assert num_frames % 2 == 0, "num_frames should be even."
109
- assert (
110
- 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
111
- ), "Invalid frame overlap."
112
- max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
113
- num_frames // 2
114
- )
115
- image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
116
- else:
117
- image = img_utils.read_image(cfg.data_path)
118
-
119
- image = img_utils.imresize_max(image, cfg.max_resolution)
120
- image = img_utils.imcrop_multi(image)
121
- image_tensor = np.ascontiguousarray(
122
- [_img.transpose(2, 0, 1) / 255.0 for _img in image]
123
- )
124
- image_tensor = torch.from_numpy(image_tensor).to(DEVICE)
125
-
126
- with torch.no_grad(), torch.autocast(
127
- device_type=DEVICE_TYPE, dtype=torch.float16
128
- ):
129
- pipe_out = pipe(
130
- image_tensor,
131
- num_frames=cfg.num_frames,
132
- num_overlap_frames=cfg.num_overlap_frames,
133
- num_interp_frames=cfg.num_interp_frames,
134
- decode_chunk_size=cfg.decode_chunk_size,
135
- num_inference_steps=cfg.denoise_steps,
136
  )
137
 
138
- disparity = pipe_out.disparity
139
- disparity_colored = pipe_out.disparity_colored
140
- image = pipe_out.image
141
- # (N, H, 2 * W, 3)
142
- merged = np.concatenate(
143
- [
144
- image,
145
- disparity_colored,
146
- ],
147
- axis=2,
148
- )
149
-
150
- if is_video:
151
- output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
152
- img_utils.write_video(
153
- output_path,
154
- merged,
155
- fps,
 
 
 
 
 
 
156
  )
157
- return output_path
158
- else:
159
- output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
160
- img_utils.write_image(
161
- output_path,
162
- merged[0],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  # Define Gradio interface
 
70
  """
71
  Perform depth estimation on the uploaded video/image.
72
  """
73
+ with open(file, "rb") as _file:
74
+ with tempfile.TemporaryDirectory() as tmp_dir:
75
+ # Save the uploaded file
76
+ input_path = os.path.join(tmp_dir, file.name)
77
+ with open(input_path, "wb") as f:
78
+ f.write(_file.read())
79
+
80
+ # Set up output directory
81
+ output_dir = os.path.join(tmp_dir, "output")
82
+ os.makedirs(output_dir, exist_ok=True)
83
+
84
+ # Prepare configuration
85
+ cfg = EasyDict(
86
+ {
87
+ "model_base": MODEL_BASE,
88
+ "data_path": input_path,
89
+ "output_dir": output_dir,
90
+ "denoise_steps": denoise_steps,
91
+ "num_frames": num_frames,
92
+ "decode_chunk_size": decode_chunk_size,
93
+ "num_interp_frames": num_interp_frames,
94
+ "num_overlap_frames": num_overlap_frames,
95
+ "max_resolution": max_resolution,
96
+ "seed": 666,
97
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
99
 
100
+ seed_all(cfg.seed)
101
+
102
+ file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
103
+ is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
104
+
105
+ if is_video:
106
+ num_interp_frames = cfg.num_interp_frames
107
+ num_overlap_frames = cfg.num_overlap_frames
108
+ num_frames = cfg.num_frames
109
+ assert num_frames % 2 == 0, "num_frames should be even."
110
+ assert (
111
+ 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
112
+ ), "Invalid frame overlap."
113
+ max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
114
+ num_frames // 2
115
+ )
116
+ image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
117
+ else:
118
+ image = img_utils.read_image(cfg.data_path)
119
+
120
+ image = img_utils.imresize_max(image, cfg.max_resolution)
121
+ image = img_utils.imcrop_multi(image)
122
+ image_tensor = np.ascontiguousarray(
123
+ [_img.transpose(2, 0, 1) / 255.0 for _img in image]
124
  )
125
+ image_tensor = torch.from_numpy(image_tensor).to(DEVICE)
126
+
127
+ with torch.no_grad(), torch.autocast(
128
+ device_type=DEVICE_TYPE, dtype=torch.float16
129
+ ):
130
+ pipe_out = pipe(
131
+ image_tensor,
132
+ num_frames=cfg.num_frames,
133
+ num_overlap_frames=cfg.num_overlap_frames,
134
+ num_interp_frames=cfg.num_interp_frames,
135
+ decode_chunk_size=cfg.decode_chunk_size,
136
+ num_inference_steps=cfg.denoise_steps,
137
+ )
138
+
139
+ disparity = pipe_out.disparity
140
+ disparity_colored = pipe_out.disparity_colored
141
+ image = pipe_out.image
142
+ # (N, H, 2 * W, 3)
143
+ merged = np.concatenate(
144
+ [
145
+ image,
146
+ disparity_colored,
147
+ ],
148
+ axis=2,
149
  )
150
+
151
+ if is_video:
152
+ output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
153
+ img_utils.write_video(
154
+ output_path,
155
+ merged,
156
+ fps,
157
+ )
158
+ return output_path
159
+ else:
160
+ output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
161
+ img_utils.write_image(
162
+ output_path,
163
+ merged[0],
164
+ )
165
+ return output_path
166
 
167
 
168
  # Define Gradio interface