Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Part of the code is from
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
# Modified by Yue Zhao
# The original code is under MIT License
import torch | |
import torch.distributed as dist | |
from typing import Tuple | |
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    Prepare a tensor for collective communication.

    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper moves a CPU tensor to the GPU when
    the active backend is NCCL and reports where the tensor originally
    lived so the caller can move results back afterwards.

    Args:
        tensor: the tensor about to be communicated.

    Returns:
        A ``(tensor, orig_device)`` pair. ``orig_device`` is ``"cpu"`` if
        the input was not a CUDA tensor and ``"gpu"`` otherwise; the
        returned tensor is the input itself unless it had to be moved.
    """
    orig_device = "gpu" if tensor.is_cuda else "cpu"
    if (
        not tensor.is_cuda
        and torch.distributed.is_available()
        # get_backend() raises when no default process group exists, so
        # only query it once a group has actually been initialized.
        and torch.distributed.is_initialized()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    Undo the device move performed before communication.

    Some backends (e.g. NCCL) need tensors on the GPU to communicate. If
    the tensor is currently on CUDA but ``orig_device`` is ``"cpu"``, a CPU
    copy is returned; in every other case the tensor is returned as-is.
    """
    moved_for_comm = tensor.is_cuda and orig_device == "cpu"
    if not moved_for_comm:
        return tensor
    return tensor.cpu()
def is_distributed_training_run() -> bool:
    """
    Report whether this process is part of a multi-worker run.

    True only when torch.distributed is available, a process group has
    been initialized, and the world size is greater than one.
    """
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return torch.distributed.get_world_size() > 1
class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.

    Use through ``GatherLayer.apply(x)``; ``forward`` and ``backward`` are
    static methods, as ``torch.autograd.Function`` requires.
    """

    @staticmethod
    def forward(ctx, x):
        """
        All-gather ``x`` from every worker.

        Returns a tuple with one tensor per rank (``world_size`` entries),
        each allocated with ``zeros_like(x)`` and filled by ``all_gather``.
        """
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        """
        Route gradients back to the rank that produced each gathered slice.

        ``grads`` holds one gradient per gathered output. They are stacked,
        combined across workers with ``all_reduce`` (default reduction),
        and the slice belonging to this rank is returned as the gradient
        of the local input ``x``.
        """
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Concatenate ``tensor`` from every worker along dimension 0.

    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients. Outside of a distributed
    run the result is just the (possibly unsqueezed) input concatenated
    on its own.
    """
    # 0-dim tensors have no axis to gather/concatenate along, so give
    # them a leading dimension first.
    local = tensor.unsqueeze(0) if tensor.ndim == 0 else tensor

    if not is_distributed_training_run():
        return torch.cat([local], 0)

    # Move to the device the backend needs, gather with gradient support,
    # then put every piece back on the device the input came from.
    local, source_device = convert_to_distributed_tensor(local)
    pieces = GatherLayer.apply(local)
    restored = [convert_to_normal_tensor(piece, source_device) for piece in pieces]
    return torch.cat(restored, 0)