# depth-any-video: dav/pipelines/dav_pipeline.py
import torch
import tqdm
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
import matplotlib
def colorize_depth(depth, cmap="Spectral"):
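    """Map a depth/disparity array in [0, 1] to RGB using a matplotlib colormap.

    Returns an array with a trailing channel dimension of 3, values in [0, 1].
    """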
# colorize
cm = matplotlib.colormaps[cmap]
# (B, N, H, W, 3)
depth_colored = cm(depth, bytes=False)[..., 0:3] # value from 0 to 1
return depth_colored
class DAVOutput(BaseOutput):
r"""
Output class for zero-shot text-to-video pipeline.
Args:
frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
"""
disparity: np.ndarray
disparity_colored: np.ndarray
image: np.ndarray
class DAVPipeline(DiffusionPipeline):
def __init__(self, vae, unet, unet_interp, scheduler):
super().__init__()
self.register_modules(
vae=vae, unet=unet, unet_interp=unet_interp, scheduler=scheduler
)
def encode(self, input):
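        """VAE-encode RGB frames into scaled latents.

        Expects `input` of shape (B, N, C, H, W); frames are flattened for the
        per-image VAE encoder, scaled by the VAE scaling factor, and reshaped
        back to (B, N, C_latent, h, w).
        """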
num_frames = input.shape[1]
input = input.flatten(0, 1)
latent = self.vae.encode(input.to(self.vae.dtype)).latent_dist.mode()
latent = latent * self.vae.config.scaling_factor
latent = latent.reshape(-1, num_frames, *latent.shape[1:])
return latent
def decode(self, latents, decode_chunk_size=16):
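        """Decode depth latents back to image space in chunks to limit peak memory.

        Expects `latents` of shape (B, N, C_latent, h, w) and returns float32
        frames of shape (B, N, C, H, W).
        """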
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
num_frames = latents.shape[1]
latents = latents.flatten(0, 1)
latents = latents / self.vae.config.scaling_factor
# decode decode_chunk_size frames at a time to avoid OOM
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
frame = self.vae.decode(
latents[i : i + decode_chunk_size].to(self.vae.dtype),
num_frames=num_frames_in,
).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
# [batch, frames, channels, height, width]
frames = frames.reshape(-1, num_frames, *frames.shape[1:])
return frames.to(torch.float32)
def single_infer(self, rgb, position_ids=None, num_inference_steps=None):
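        """Denoise a depth latent for one RGB clip with `self.unet`.

        The RGB latent is concatenated to the noisy depth latent along the
        channel dimension at every step; `position_ids` optionally carries the
        original frame indices (used when key frames are sparsely sampled).
        """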
rgb_latent = self.encode(rgb)
noise_latent = torch.randn_like(rgb_latent)
self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
timesteps = self.scheduler.timesteps
image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
noise_latent
)
for i, t in enumerate(timesteps):
latent_model_input = noise_latent
latent_model_input = torch.cat([latent_model_input, rgb_latent], dim=2)
# [batch_size, num_frame, 4, h, w]
model_output = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
position_ids=position_ids,
).sample
# compute the previous noisy sample x_t -> x_t-1
noise_latent = self.scheduler.step(
model_output, t, noise_latent
).prev_sample
return noise_latent
def single_interp_infer(
self, rgb, masked_depth_latent, mask, num_inference_steps=None
):
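        """Denoise depth latents for in-between frames with `self.unet_interp`.

        `masked_depth_latent` holds the two key-frame latents at its endpoints
        and `mask` marks which positions are known; both are concatenated to the
        noisy latent and the RGB latent along the channel dimension at every step.
        """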
rgb_latent = self.encode(rgb)
noise_latent = torch.randn_like(rgb_latent)
self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
timesteps = self.scheduler.timesteps
image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
noise_latent
)
for i, t in enumerate(timesteps):
latent_model_input = noise_latent
latent_model_input = torch.cat(
[latent_model_input, rgb_latent, masked_depth_latent, mask], dim=2
)
# [batch_size, num_frame, 4, h, w]
model_output = self.unet_interp(
latent_model_input, t, encoder_hidden_states=image_embeddings
).sample
# compute the previous noisy sample x_t -> x_t-1
noise_latent = self.scheduler.step(
model_output, t, noise_latent
).prev_sample
return noise_latent
def __call__(
self,
image,
num_frames,
num_overlap_frames,
num_interp_frames,
decode_chunk_size,
num_inference_steps,
):
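        """Predict disparity for a video clip.

        Args:
            image: tensor of shape (N, 3, H, W) with values in [0, 1].
            num_frames: maximum number of key frames denoised in a single pass.
            num_overlap_frames: frames shared by consecutive interpolation
                windows; their interiors are blended linearly when stitching.
            num_interp_frames: number of frames interpolated between two
                consecutive key frames.
            decode_chunk_size: frames decoded by the VAE per chunk.
            num_inference_steps: number of scheduler denoising steps.

        Returns:
            `DAVOutput` with `disparity`, `disparity_colored`, and `image`.
        """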
self.vae.to(dtype=torch.float16)
# (1, N, 3, H, W)
image = image.unsqueeze(0)
B, N = image.shape[:2]
rgb = image * 2 - 1 # [-1, 1]
if N <= num_frames or N <= num_interp_frames + 2 - num_overlap_frames:
depth_latent = self.single_infer(
rgb, num_inference_steps=num_inference_steps
)
else:
assert 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
assert num_frames % 2 == 0
key_frame_indices = []
for i in range(0, N, num_interp_frames + 2 - num_overlap_frames):
if (
i + num_interp_frames + 1 >= N
or len(key_frame_indices) >= num_frames
):
break
key_frame_indices.append(i)
key_frame_indices.append(i + num_interp_frames + 1)
key_frame_indices = torch.tensor(key_frame_indices, device=rgb.device)
sorted_key_frame_indices, origin_indices = torch.sort(key_frame_indices)
key_rgb = rgb[:, sorted_key_frame_indices]
key_depth_latent = self.single_infer(
key_rgb,
sorted_key_frame_indices.unsqueeze(0).repeat(B, 1),
num_inference_steps=num_inference_steps,
)
key_depth_latent = key_depth_latent[:, origin_indices]
torch.cuda.empty_cache()
depth_latent = []
pre_latent = None
for i in tqdm.tqdm(range(0, len(key_frame_indices), 2)):
frame1 = key_depth_latent[:, i]
frame2 = key_depth_latent[:, i + 1]
masked_depth_latent = torch.zeros(
(B, num_interp_frames + 2, *key_depth_latent.shape[2:])
).to(key_depth_latent)
masked_depth_latent[:, 0] = frame1
masked_depth_latent[:, -1] = frame2
mask = torch.zeros_like(masked_depth_latent)
mask[:, [0, -1]] = 1.0
latent = self.single_interp_infer(
rgb[:, key_frame_indices[i] : key_frame_indices[i + 1] + 1],
masked_depth_latent,
mask,
num_inference_steps=num_inference_steps,
)
latent = latent[:, 1:-1]
if pre_latent is not None:
overlap_a = pre_latent[
:, pre_latent.shape[1] - (num_overlap_frames - 2) :
]
overlap_b = latent[:, : (num_overlap_frames - 2)]
ratio = (
torch.linspace(0, 1, num_overlap_frames - 2)
.to(overlap_a)
.view(1, -1, 1, 1, 1)
)
overlap = overlap_a * (1 - ratio) + overlap_b * ratio
pre_latent[:, pre_latent.shape[1] - (num_overlap_frames - 2) :] = (
overlap
)
depth_latent.append(pre_latent)
pre_latent = latent[:, (num_overlap_frames - 2) if i > 0 else 0 :]
torch.cuda.empty_cache()
depth_latent.append(pre_latent)
depth_latent = torch.cat(depth_latent, dim=1)
            # discard the first and last key frames so image matches the interpolated depth_latent range
image = image[:, key_frame_indices[0] + 1 : key_frame_indices[-1]]
assert depth_latent.shape[1] == image.shape[1]
disparity = self.decode(depth_latent, decode_chunk_size=decode_chunk_size)
disparity = disparity.mean(dim=2, keepdim=False)
disparity = torch.clamp(disparity * 0.5 + 0.5, 0.0, 1.0)
# (N, H, W)
disparity = disparity.squeeze(0)
# (N, H, W, 3)
        min_d, max_d = disparity.min(), disparity.max()
        # normalize and invert so larger disparities map to the low end of the colormap
        disparity_colored = torch.clamp((max_d - disparity) / (max_d - min_d), 0.0, 1.0)
disparity_colored = colorize_depth(disparity_colored.cpu().numpy())
disparity_colored = (disparity_colored * 255).astype(np.uint8)
image = image.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()
image = (image * 255).astype(np.uint8)
disparity = disparity.cpu().numpy()
return DAVOutput(
disparity=disparity,
disparity_colored=disparity_colored,
image=image,
)
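

# A minimal usage sketch (not part of the original file): it assumes a checkpoint
# directory with `vae`, `unet`, `unet_interp`, and `scheduler` subfolders that
# `DiffusionPipeline.from_pretrained` can resolve, and a pre-loaded float tensor
# of frames in [0, 1]. Paths, tensor sizes, and hyperparameters below are
# placeholders; adjust them to the actual repository configuration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()

    # A GPU is assumed here because the pipeline casts the VAE to float16.
    device = "cuda"
    pipe = DAVPipeline.from_pretrained(args.model_path).to(device)

    # (N, 3, H, W) float frames in [0, 1]; replace with real video frames.
    frames = torch.rand(32, 3, 256, 256, device=device)

    with torch.no_grad():
        out = pipe(
            frames,
            num_frames=16,
            num_overlap_frames=3,
            num_interp_frames=16,
            decode_chunk_size=8,
            num_inference_steps=5,
        )
    print(out.disparity.shape, out.disparity_colored.shape)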