# NOTE(review): removed stray "Spaces: Running on A10G" web-page residue that
# made this module unparseable; the actual code starts at the imports below.
import numbers
import random

import torch
from torchvision.transforms import RandomCrop, RandomResizedCrop
def _is_tensor_video_clip(clip): | |
if not torch.is_tensor(clip): | |
raise TypeError("clip should be Tensor. Got %s" % type(clip)) | |
if not clip.ndimension() == 4: | |
raise ValueError("clip should be 4D. Got %dD" % clip.dim()) | |
return True | |
def to_tensor(clip):
    """
    Cast a uint8 video clip to float and scale its values into [0, 1]
    by dividing by 255.0.

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    Raises:
        TypeError: if ``clip`` is not a uint8 tensor.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError(f"clip tensor should have data type uint8. Got {clip.dtype}")
    return clip.float() / 255.0
def resize(clip, target_size, interpolation_mode):
    """
    Resize a video clip to a fixed spatial size via ``torch.nn.functional.interpolate``.

    Args:
        clip (torch.tensor): Size is (T, C, H, W); interpolate treats dim 0
            as the batch dimension and dim 1 as channels.
        target_size (tuple): target (height, width).
        interpolation_mode (str): any mode accepted by ``interpolate``
            (e.g. "bilinear", "nearest", "bicubic").
    Returns:
        torch.tensor: resized clip of size (T, C, h, w).
    Raises:
        ValueError: if ``target_size`` is not a 2-element (height, width) pair.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # Fix: align_corners may only be passed for the interpolating modes —
    # the original hard-coded align_corners=False, which makes interpolate
    # raise ValueError for modes like "nearest". Pass None otherwise.
    align_corners = False if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=align_corners
    )
class ToTensorVideo:
    """
    Callable transform: cast a uint8 video clip (T, C, H, W) to float and
    scale its values into [0, 1] by dividing by 255.0.
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return type(self).__name__
class ResizeVideo:
    """
    Callable transform: resize a video clip (T, C, H, W) to a fixed
    spatial size.

    Args:
        size (int, tuple or list): target (height, width); a single int
            gives a square (size, size) output.
        interpolation_mode (str): mode passed to interpolate
            ("bilinear" by default).
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # Fix: also accept a list — previously a 2-element list fell into
        # the scalar branch and produced a broken (list, list) size pair.
        if isinstance(size, (list, tuple)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized video clip.
                size is (T, C, h, w)
        """
        return resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        # Fix: the original f-string was missing the closing ")".
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class TemporalRandomCrop(object):
    """Pick a random temporal window of frame indices.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Return (begin, end) indices of a random window within total_frames.

        NOTE(review): the "- 1" caps the start at total_frames - size - 1,
        so a full-length crop never ends on the very last frame — preserved
        from the original; confirm this off-by-one is intentional.
        """
        last_start = max(0, total_frames - self.size - 1)
        start = random.randint(0, last_start)
        stop = min(start + self.size, total_frames)
        return start, stop