# SViTT-Ego_Action_Recognition / svitt / video_transforms.py
# Committed by hvaldez in c18a21e ("first commit").
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Sequence
import torch
import torch.nn as nn
from torchvision import transforms
class Permute(nn.Module):
    """
    Expose ``Tensor.permute`` as an ``nn.Module`` so it can be composed
    with other transforms (e.g. inside ``transforms.Compose``).
    """

    def __init__(self, ordering):
        """
        Args:
            ordering: the dimension order passed verbatim to ``Tensor.permute``.
        """
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames: a tensor in some dimension order, by default (C, T, H, W).
        Returns:
            ``frames`` with its dimensions rearranged according to ``ordering``.
        """
        target_order = self.ordering
        reordered = frames.permute(target_order)
        return reordered
class TemporalCrop(nn.Module):
    """
    Split a video into (possibly overlapping) clips along the time axis.

    Args:
        frames_per_clip: number of frames kept in every clip.
        stride: distance, in frames, between the start of consecutive clips.
        frame_stride: step between the frames sampled inside one clip.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        """
        Args:
            video: a 4-D tensor laid out as (C, T, H, W).
        Returns:
            A list of (C, frames_per_clip, H, W) clips. It is empty when the
            video is too short for even one window.
        """
        assert video.ndim == 4, "Must be (C, T, H, W)"
        # Every window spans frames * frame_stride frames. The last sampled
        # index is start + (frames - 1) * frame_stride, so the window is
        # frame_stride - 1 frames longer than strictly needed. This is kept
        # as-is to preserve the number of clips produced.
        window = self.frames * self.frame_stride
        final_start = video.size(1) - window
        return [
            video[:, begin: begin + window: self.frame_stride, ...]
            for begin in range(0, final_start + 1, self.stride)
        ]
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.

    Shifts columns 0 and 2 by ``x_offset`` and columns 1 and 3 by
    ``y_offset``. The input array is not modified.

    Args:
        boxes (ndarray or None): bounding boxes to perform crop on. The
            dimension is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when ``boxes`` is None.
    """
    # The docstring has always advertised None as a valid input; honour it
    # instead of failing on ``None.copy()``.
    if boxes is None:
        return None
    # Copy, then assign in place, so the result keeps the input's dtype.
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.

    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`. A 3-D tensor
            (`channel` x `height` x `width`) is treated as a single image and
            the result is returned with the same 3-D rank.
        size (int): height and width of the square crop.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than or equal to height. Or 0, 1, or 2 for top, center,
            and bottom crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, bilinearly resize the images
            so that their shorter side equals scale_size before performing
            any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    Raises:
        ValueError: if ``size`` is larger than the (possibly rescaled) image
            height or width.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        # Resize the shorter side to scale_size, preserving the aspect ratio.
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # A crop larger than the image would produce negative offsets. Those
    # slice from the end of the tensor and silently return a wrongly sized,
    # meaningless crop, so reject it explicitly.
    if size > height or size > width:
        raise ValueError(
            f"Crop size {size} exceeds image size {height}x{width} (HxW)"
        )

    # Default to a centered crop; override along the longer axis for idx 0/2.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size

    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
class SpatialCrop(nn.Module):
    """
    Convert each video into several smaller clips spatially.

    Must be used after the temporal crops to get spatial crops. It should be
    used with -2 in the spatial crop at the slowfast augmentation stage, so
    that full frames are passed in here. Returns a larger list containing
    every spatial crop of every input video. It's useful for 3x4 testing
    (eg in SwinT) or 3x10 testing in SlowFast etc.

    Supported ``num_crops``:
        1 -> the center crop (spatial_idx 1 of ``uniform_crop``).
        3 -> crops at spatial_idx 0, 1, 2 (left/center/right for wide frames,
             top/center/bottom for tall ones).
        6 -> the 3 crops above plus the same 3 crops taken from a
             horizontally flipped copy of the video.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 6:
            self.crops_to_ext = [0, 1, 2]
            # I guess Swin uses 5 crops without flipping, but that doesn't
            # make sense given they first resize to 224 and take 224 crops.
            # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
            # So I'm assuming we can use flipped crops and that will add sth..
            self.flipped_crops_to_ext = [0, 1, 2]
        elif num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError(
                f"num_crops={num_crops} is not supported; expected 1, 3, or 6"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list (or tuple) of C, T, H, W videos.
        Returns:
            videos: A list with ``num_crops`` times the number of input
                videos. Each entry is converted to C, T, H', W' by spatial
                cropping. For each input video, its regular crops come
                first, followed by its flipped crops (if any).
        """
        # Tuples are accepted as well, to match the Sequence annotation.
        assert isinstance(videos, (list, tuple)), "Must be a list of videos after temporal crops"
        assert all(video.ndim == 4 for video in videos), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res