3v324v23's picture
init
2f85de4
raw
history blame
1.59 kB
# python3.8
"""Contains some utility operators."""
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
# Public API of this module: the names exported by a wildcard import.
__all__ = [
'all_gather',
'upsample',
'downsample',
]
def all_gather(tensor):
    """Averages `tensor` over every process in the default process group.

    When `torch.distributed` has not been initialized, the input is handed
    back untouched (the very same object, no copy).

    Args:
        tensor: The local tensor to be synchronized.

    Returns:
        The element-wise mean of the tensors gathered from all processes, or
        `tensor` itself when no process group is initialized.
    """
    if dist.is_initialized():
        num_replicas = dist.get_world_size()
        # One receive buffer per process; contents are overwritten by the
        # collective call below.
        gathered = [torch.ones_like(tensor) for _ in range(num_replicas)]
        dist.all_gather(gathered, tensor, async_op=False)
        return torch.stack(gathered, dim=0).mean(dim=0)
    return tensor
def upsample(img_nerf, size, filter=None):
    """Upsamples an image tensor to the target spatial resolution.

    The scaling factor is computed as `size // img_nerf.size(-1)`, i.e. from
    the last dimension only. If that factor is not larger than 1, the input
    is returned as-is.

    Args:
        img_nerf: Input tensor (assumed to be laid out as `(N, C, H, W)`,
            which is what the bilinear branch requires).
        size: Target resolution; the bilinear branch resizes to
            `(size, size)`.
        filter: Optional FIR filter forwarded to
            `upfirdn2d.upsample2d()`. If `None`, bilinear interpolation is
            used instead. (The name shadows the builtin but is kept for
            backward compatibility with keyword callers.)

    Returns:
        The upsampled tensor, or `img_nerf` itself if no upsampling is needed.

    NOTE: With a filter, `int(log2(factor))` successive 2x steps are applied,
    so the output only reaches exactly `size` when the factor is a power of 2.
    """
    up = size // img_nerf.size(-1)
    if up <= 1:
        return img_nerf

    if filter is not None:
        # Imported lazily so that the bilinear path does not depend on the
        # third-party ops.
        from third_party.stylegan2_official_ops import upfirdn2d
        for _ in range(int(math.log2(up))):
            # `upsample2d` is the function accepting the `up` factor;
            # `downsample2d` has no such keyword and would raise `TypeError`.
            img_nerf = upfirdn2d.upsample2d(img_nerf, filter, up=2)
    else:
        img_nerf = F.interpolate(img_nerf, (size, size),
                                 mode='bilinear',
                                 align_corners=False)
    return img_nerf
def downsample(img0, size, filter=None):
    """Downsamples an image tensor to the target spatial resolution.

    The reduction factor is `img0.size(-1) // size`, i.e. it is derived from
    the last dimension only. If that factor is not larger than 1, the input
    is returned as-is.

    Args:
        img0: Input tensor (assumed to be laid out as `(N, C, H, W)`, which
            is what the bilinear branch requires).
        size: Target resolution; the bilinear branch resizes to
            `(size, size)`.
        filter: Optional FIR filter forwarded to
            `upfirdn2d.downsample2d()`. If `None`, bilinear interpolation is
            used instead.

    Returns:
        The downsampled tensor, or `img0` itself if no downsampling is needed.

    NOTE: With a filter, `int(log2(factor))` successive 2x steps are applied,
    so the output only reaches exactly `size` when the factor is a power of 2.
    """
    ratio = img0.size(-1) // size
    if ratio <= 1:
        return img0

    if filter is None:
        return F.interpolate(img0, (size, size),
                             mode='bilinear',
                             align_corners=False)

    # Imported lazily so that the bilinear path does not depend on the
    # third-party ops.
    from third_party.stylegan2_official_ops import upfirdn2d
    num_halvings = int(math.log2(ratio))
    result = img0
    for _ in range(num_halvings):
        result = upfirdn2d.downsample2d(result, filter, down=2)
    return result