# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from typing import Sequence import torch import torch.nn as nn from torchvision import transforms class Permute(nn.Module): """ Permutation as an op """ def __init__(self, ordering): super().__init__() self.ordering = ordering def forward(self, frames): """ Args: frames in some ordering, by default (C, T, H, W) Returns: frames in the ordering that was specified """ return frames.permute(self.ordering) class TemporalCrop(nn.Module): """ Convert the video into smaller clips temporally. """ def __init__( self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1 ): super().__init__() self.frames = frames_per_clip self.stride = stride self.frame_stride = frame_stride def forward(self, video): assert video.ndim == 4, "Must be (C, T, H, W)" res = [] for start in range( 0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride ): end = start + (self.frames) * self.frame_stride res.append(video[:, start: end: self.frame_stride, ...]) return res def crop_boxes(boxes, x_offset, y_offset): """ Peform crop on the bounding boxes given the offsets. Args: boxes (ndarray or None): bounding boxes to peform crop. The dimension is `num boxes` x 4. x_offset (int): cropping offset in the x axis. y_offset (int): cropping offset in the y axis. Returns: cropped_boxes (ndarray or None): the cropped boxes with dimension of `num boxes` x 4. """ cropped_boxes = boxes.copy() cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset return cropped_boxes def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): """ Perform uniform spatial sampling on the images and corresponding boxes. Args: images (tensor): images to perform uniform crop. The dimension is `num frames` x `channel` x `height` x `width`. size (int): size of height and weight to crop the images. spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width is larger than height. Or 0, 1, or 2 for top, center, and bottom crop if height is larger than width. boxes (ndarray or None): optional. Corresponding boxes to images. Dimension is `num boxes` x 4. scale_size (int): optinal. If not None, resize the images to scale_size before performing any crop. Returns: cropped (tensor): images with dimension of `num frames` x `channel` x `size` x `size`. cropped_boxes (ndarray or None): the cropped boxes with dimension of `num boxes` x 4. """ assert spatial_idx in [0, 1, 2] ndim = len(images.shape) if ndim == 3: images = images.unsqueeze(0) height = images.shape[2] width = images.shape[3] if scale_size is not None: if width <= height: width, height = scale_size, int(height / width * scale_size) else: width, height = int(width / height * scale_size), scale_size images = torch.nn.functional.interpolate( images, size=(height, width), mode="bilinear", align_corners=False, ) y_offset = int(math.ceil((height - size) / 2)) x_offset = int(math.ceil((width - size) / 2)) if height > width: if spatial_idx == 0: y_offset = 0 elif spatial_idx == 2: y_offset = height - size else: if spatial_idx == 0: x_offset = 0 elif spatial_idx == 2: x_offset = width - size cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size] cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None if ndim == 3: cropped = cropped.squeeze(0) return cropped, cropped_boxes class SpatialCrop(nn.Module): """ Convert the video into 3 smaller clips spatially. Must be used after the temporal crops to get spatial crops, and should be used with -2 in the spatial crop at the slowfast augmentation stage (so full frames are passed in here). Will return a larger list with the 3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT) or 3x10 testing in SlowFast etc. """ def __init__(self, crop_size: int = 224, num_crops: int = 3): super().__init__() self.crop_size = crop_size if num_crops == 6: self.crops_to_ext = [0, 1, 2] # I guess Swin uses 5 crops without flipping, but that doesn't # make sense given they first resize to 224 and take 224 crops. # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf) # So I'm assuming we can use flipped crops and that will add sth.. self.flipped_crops_to_ext = [0, 1, 2] elif num_crops == 3: self.crops_to_ext = [0, 1, 2] self.flipped_crops_to_ext = [] elif num_crops == 1: self.crops_to_ext = [1] self.flipped_crops_to_ext = [] else: raise NotImplementedError( "Nothing else supported yet, " "slowfast only takes 0, 1, 2 as arguments" ) def forward(self, videos: Sequence[torch.Tensor]): """ Args: videos: A list of C, T, H, W videos. Returns: videos: A list with 3x the number of elements. Each video converted to C, T, H', W' by spatial cropping. """ assert isinstance(videos, list), "Must be a list of videos after temporal crops" assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" res = [] for video in videos: for spatial_idx in self.crops_to_ext: res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) if not self.flipped_crops_to_ext: continue flipped_video = transforms.functional.hflip(video) for spatial_idx in self.flipped_crops_to_ext: res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) return res