Create utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
|
9 |
+
def _max_by_axis(the_list):
|
10 |
+
# type: (List[List[int]]) -> List[int]
|
11 |
+
maxes = the_list[0]
|
12 |
+
for sublist in the_list[1:]:
|
13 |
+
for index, item in enumerate(sublist):
|
14 |
+
maxes[index] = max(maxes[index], item)
|
15 |
+
return maxes
|
16 |
+
|
17 |
+
|
18 |
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
19 |
+
# TODO make this more general
|
20 |
+
if tensor_list[0].ndim == 3:
|
21 |
+
# TODO make it support different-sized images
|
22 |
+
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
23 |
+
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
24 |
+
batch_shape = [len(tensor_list)] + max_size
|
25 |
+
b, c, h, w = batch_shape
|
26 |
+
dtype = tensor_list[0].dtype
|
27 |
+
device = tensor_list[0].device
|
28 |
+
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
29 |
+
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
30 |
+
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
31 |
+
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
32 |
+
m[: img.shape[1], :img.shape[2]] = False
|
33 |
+
else:
|
34 |
+
raise ValueError('not supported')
|
35 |
+
return NestedTensor(tensor, mask)
|
36 |
+
|
37 |
+
|
38 |
+
class NestedTensor(object):
|
39 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
40 |
+
self.tensors = tensors
|
41 |
+
self.mask = mask
|
42 |
+
|
43 |
+
def to(self, device):
|
44 |
+
# type: (Device) -> NestedTensor # noqa
|
45 |
+
cast_tensor = self.tensors.to(device)
|
46 |
+
mask = self.mask
|
47 |
+
if mask is not None:
|
48 |
+
assert mask is not None
|
49 |
+
cast_mask = mask.to(device)
|
50 |
+
else:
|
51 |
+
cast_mask = None
|
52 |
+
return NestedTensor(cast_tensor, cast_mask)
|
53 |
+
|
54 |
+
def decompose(self):
|
55 |
+
return self.tensors, self.mask
|
56 |
+
|
57 |
+
def __repr__(self):
|
58 |
+
return str(self.tensors)
|
59 |
+
|
60 |
+
|
61 |
+
def is_dist_avail_and_initialized():
|
62 |
+
if not dist.is_available():
|
63 |
+
return False
|
64 |
+
if not dist.is_initialized():
|
65 |
+
return False
|
66 |
+
return True
|
67 |
+
|
68 |
+
|
69 |
+
def get_rank():
|
70 |
+
if not is_dist_avail_and_initialized():
|
71 |
+
return 0
|
72 |
+
return dist.get_rank()
|
73 |
+
|
74 |
+
|
75 |
+
def is_main_process():
|
76 |
+
return get_rank() == 0
|