# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Part of the code is from
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
# Modified by Yue Zhao
# The original code is under MIT License
from typing import Tuple

import torch
import torch.distributed as dist


def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
"""
For some backends, such as NCCL, communication only works if the
tensor is on the GPU. This helper function converts to the correct
device and returns the tensor + original device.
"""
orig_device = "cpu" if not tensor.is_cuda else "gpu"
if (
torch.distributed.is_available()
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
and not tensor.is_cuda
):
tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to the original device.
    """
if tensor.is_cuda and orig_device == "cpu":
tensor = tensor.cpu()
return tensor
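# Illustrative round trip using the two helpers above (a sketch, assuming a
# CUDA machine and an initialized NCCL process group; `t` is a placeholder):
#
#   t = torch.ones(4)                                   # CPU tensor
#   t, orig_device = convert_to_distributed_tensor(t)   # moved to GPU for NCCL
#   dist.all_reduce(t)                                  # collective runs on GPU
#   t = convert_to_normal_tensor(t, orig_device)        # restored to CPU
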
def is_distributed_training_run() -> bool:
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and (torch.distributed.get_world_size() > 1)
    )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    unlike torch.distributed.all_gather, this implementation does not cut
    the gradients.
    """
    @staticmethod
    def forward(ctx, x):
        # Collect a copy of x from every rank; all_gather itself does not
        # track autograd history, which is why backward is defined below.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank receives one gradient per gathered output. Summing them
        # across ranks and taking this rank's slice yields the gradient of
        # the loss w.r.t. the local input x.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
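# Minimal sketch of the difference from a plain all_gather (assumes an
# initialized process group; `local_features` is a hypothetical per-rank
# tensor): tensors returned by torch.distributed.all_gather carry no autograd
# history, whereas GatherLayer keeps the local slice differentiable.
#
#   x = local_features.requires_grad_()
#   gathered = GatherLayer.apply(x)        # tuple of world_size tensors
#   torch.cat(gathered).sum().backward()
#   # x.grad is now populated on every rank
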
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze first
        tensor = tensor.unsqueeze(0)
if is_distributed_training_run():
tensor, orig_device = convert_to_distributed_tensor(tensor)
gathered_tensors = GatherLayer.apply(tensor)
gathered_tensors = [
convert_to_normal_tensor(_tensor, orig_device)
for _tensor in gathered_tensors
]
else:
gathered_tensors = [tensor]
gathered_tensor = torch.cat(gathered_tensors, 0)
return gathered_tensor
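

if __name__ == "__main__":
    # Usage sketch (an addition, not part of the original training code):
    # gather_from_all is typically applied to per-GPU feature batches before a
    # contrastive loss, so that the similarity matrix spans the global batch
    # while gradients still flow back to the local inputs. The batch size and
    # feature dimension below are arbitrary placeholders.
    image_feat = torch.randn(8, 256, requires_grad=True)   # local image features
    text_feat = torch.randn(8, 256, requires_grad=True)    # local text features

    # Outside a distributed run this is a no-op concatenation of one tensor.
    all_image_feat = gather_from_all(image_feat)   # (world_size * 8, 256)
    all_text_feat = gather_from_all(text_feat)     # (world_size * 8, 256)

    logits = all_image_feat @ all_text_feat.t()
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = torch.nn.functional.cross_entropy(logits, labels)
    loss.backward()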