import math
from typing import Sequence

import torch
import torch.nn as nn
from torchvision import transforms


class Permute(nn.Module):
    """
    Permutation as an op.
    """

    def __init__(self, ordering):
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames: tensor of frames in some ordering, by default (C, T, H, W).
        Returns:
            The frames permuted into the ordering that was specified.
        """
        return frames.permute(self.ordering)
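
# Usage sketch (illustrative, not part of the original module): decoded video often
# arrives as (T, H, W, C) uint8 frames, so a Permute op can be composed into a
# pipeline to obtain the (C, T, H, W) layout the transforms below expect.
#
#   to_cthw = Permute((3, 0, 1, 2))
#   video_cthw = to_cthw(video_thwc)   # (T, H, W, C) -> (C, T, H, W)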


class TemporalCrop(nn.Module):
    """
    Convert the video into smaller clips temporally.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        assert video.ndim == 4, "Must be (C, T, H, W)"
        res = []
        # Slide a window of `frames * frame_stride` frames along the time axis;
        # any tail shorter than a full clip is dropped.
        for start in range(
            0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
        ):
            end = start + (self.frames) * self.frame_stride
            res.append(video[:, start:end:self.frame_stride, ...])
        return res
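
# Usage sketch (illustrative): with the defaults, a (C, 32, H, W) video yields
# four non-overlapping clips of 8 frames each.
#
#   cropper = TemporalCrop(frames_per_clip=8, stride=8)
#   clips = cropper(video)   # list of 4 tensors, each (C, 8, H, W)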


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    # Shift the x coordinates (columns 0 and 2) and y coordinates (columns 1 and 3)
    # into the cropped image's coordinate frame.
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes
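
# Worked example (illustrative): with boxes stored as (x1, y1, x2, y2), cropping at
# x_offset=10, y_offset=20 shifts a box [30, 40, 50, 60] to [20, 20, 40, 40].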


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size
            before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        # Resize so the shorter side equals scale_size, preserving aspect ratio.
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    # Shift the crop window along the longer side: spatial_idx 0 gives the
    # left/top crop, 2 the right/bottom crop, and 1 keeps the center crop.
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset:y_offset + size, x_offset:x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
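
# Usage sketch (illustrative): the three horizontal crops of a wide clip.
#
#   frames = torch.randn(8, 3, 224, 320)               # (T, C, H, W)
#   left, _ = uniform_crop(frames, 224, spatial_idx=0)
#   center, _ = uniform_crop(frames, 224, spatial_idx=1)
#   right, _ = uniform_crop(frames, 224, spatial_idx=2)
#   # each crop is (8, 3, 224, 224)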


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with
    -2 in the spatial crop at the slowfast augmentation stage (so full
    frames are passed in here). Will return a larger list with the
    3x spatial crops as well. It's useful for 3x4 testing (e.g. in SwinT)
    or 3x10 testing in SlowFast etc.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 6:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = [0, 1, 2]
        elif num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError(
                "Nothing else supported yet; num_crops must be 1, 3, or 6"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
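

if __name__ == "__main__":
    # Minimal sketch (not part of the original module): chain the transforms on a
    # dummy video to illustrate the intended temporal-then-spatial cropping order.
    # The shapes and parameter values here are assumptions for the demo only.
    dummy = torch.randn(3, 32, 256, 320)  # (C, T, H, W)

    clips = TemporalCrop(frames_per_clip=8, stride=8)(dummy)   # 4 clips of 8 frames
    crops = SpatialCrop(crop_size=224, num_crops=3)(clips)     # 3 spatial crops per clip

    print(len(clips), len(crops))  # 4 12
    print(crops[0].shape)          # torch.Size([3, 8, 224, 224])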