import os
import subprocess
import numpy as np
import multiprocessing as mp
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
from torch.nn import Module

class DistModule(Module):
    """Module wrapper that broadcasts the wrapped module's parameters from rank 0
    to all processes at construction time."""

    def __init__(self, module):
        super(DistModule, self).__init__()
        self.module = module
        broadcast_params(self.module)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        super(DistModule, self).train(mode)
        self.module.train(mode)

def average_gradients(model):
    """ average gradients """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            # all_reduce sums the gradients; divide so the result is a true average
            dist.all_reduce(param.grad.data)
            param.grad.data.div_(world_size)

def broadcast_params(model):
    """ broadcast model parameters """
    for p in model.state_dict().values():
        dist.broadcast(p, 0)

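# Illustrative sketch only, not part of the original utilities: how the wrapper
# and the gradient helper above are typically combined in one manual
# data-parallel training step. `criterion`, `optimizer`, `images` and `targets`
# are caller-provided placeholders.
def _example_train_step(model, criterion, optimizer, images, targets):
    output = model(images)                # model is assumed to be wrapped in DistModule
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    average_gradients(model)              # reduce gradients across processes
    optimizer.step()
    return loss.item()
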
def dist_init(launcher, backend='nccl', **kwargs):
    """Initialize the default process group according to the launcher type
    ('pytorch', 'mpi' or 'slurm')."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))

def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)

def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError

def _init_dist_slurm(backend, port=10086, **kwargs):
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # resolve the first node in the allocation and use it as the master address
    addr = subprocess.getoutput(
        'scontrol show hostname {} | head -n1'.format(node_list))
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)

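# Illustrative sketch only: typical start-up under the 'slurm' launcher, where
# SLURM_PROCID / SLURM_NTASKS / SLURM_NODELIST are exported by `srun`. The port
# value is an arbitrary example, not a requirement of this module.
def _example_slurm_setup(port=29500):
    dist_init('slurm', backend='nccl', port=port)
    return dist.get_rank(), dist.get_world_size()
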
def gather_tensors(input_array):
    """All-gather a numpy array from every process.

    The arrays may differ in shape across ranks (but must have the same number
    of dimensions); they are padded to the largest size before the all_gather
    and unpadded afterwards. Returns a list of world_size numpy arrays.
    """
    world_size = dist.get_world_size()
    ## gather shapes first
    myshape = input_array.shape
    mycount = input_array.size
    shape_tensor = torch.Tensor(np.array(myshape)).cuda()
    all_shape = [torch.Tensor(np.array(myshape)).cuda() for _ in range(world_size)]
    dist.all_gather(all_shape, shape_tensor)
    ## compute largest shapes
    all_shape = [x.cpu().numpy() for x in all_shape]
    all_count = [int(x.prod()) for x in all_shape]
    all_shape = [list(map(int, x)) for x in all_shape]
    max_count = max(all_count)
    ## pad tensors and gather them
    output_tensors = [torch.Tensor(max_count).cuda() for _ in range(world_size)]
    padded_input_array = np.zeros(max_count)
    padded_input_array[:mycount] = input_array.reshape(-1)
    input_tensor = torch.Tensor(padded_input_array).cuda()
    dist.all_gather(output_tensors, input_tensor)
    ## unpad the gathered tensors
    padded_output = [x.cpu().numpy() for x in output_tensors]
    output = [x[:all_count[i]].reshape(all_shape[i]) for i, x in enumerate(padded_output)]
    return output

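# Illustrative sketch only (assumes the process group is initialised and every
# rank holds a numpy array with the same number of dimensions): gather per-rank
# feature arrays and stack them into one array.
def _example_gather_features(local_features):
    gathered = gather_tensors(local_features)    # list of world_size arrays
    return np.concatenate(gathered, axis=0)
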
def gather_tensors_batch(input_array, part_size=10):
    """Gather a large array across ranks in chunks of `part_size` rows using
    gather_tensors, then re-assemble the per-rank results."""
    rank = dist.get_rank()
    all_features = []
    part_num = (input_array.shape[0] + part_size - 1) // part_size
    for i in range(part_num):
        part_feat = input_array[i * part_size:min((i + 1) * part_size, input_array.shape[0]), ...]
        assert part_feat.shape[0] > 0, "rank: {}, length of part features should be > 0".format(rank)
        print("rank: {}, gather part: {}/{}, length: {}".format(rank, i, part_num, len(part_feat)))
        gather_part_feat = gather_tensors(part_feat)
        all_features.append(gather_part_feat)
    print("rank: {}, gather done.".format(rank))
    # concatenate the parts for each rank, then concatenate across ranks
    all_features = np.concatenate(
        [np.concatenate([all_features[i][j] for i in range(part_num)], axis=0)
         for j in range(len(all_features[0]))], axis=0)
    return all_features

def reduce_tensors(tensor):
    """All-reduce (sum) a tensor across all processes and return the result."""
    reduced_tensor = tensor.clone()
    dist.all_reduce(reduced_tensor)
    return reduced_tensor

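# Illustrative sketch only: turning a per-rank scalar (e.g. a loss value) into
# the average over all processes with reduce_tensors.
def _example_average_scalar(value):
    t = torch.tensor([float(value)], device='cuda')
    return (reduce_tensors(t) / dist.get_world_size()).item()
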
class DistributedSequentialSampler(Sampler):
    """Sampler that splits a dataset into `world_size` contiguous, equally sized
    chunks (padding with the first indices when the dataset does not divide
    evenly) and yields the chunk belonging to the current rank, in order."""

    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size)
        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        self.beg = sub_num * self.rank
        #self.end = min(self.beg+sub_num, len(self.dataset))
        self.end = self.beg + sub_num
        self.padded_ind = list(range(len(self.dataset))) + list(range(sub_num * self.world_size - len(self.dataset)))

    def __iter__(self):
        indices = [self.padded_ind[i] for i in range(self.beg, self.end)]
        return iter(indices)

    def __len__(self):
        return self.end - self.beg

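# Illustrative sketch only (assumes `val_dataset` is a map-style dataset and the
# process group is initialised): wiring the sequential sampler into a DataLoader
# for rank-sliced evaluation.
def _example_val_loader(val_dataset, batch_size=32, num_workers=2):
    from torch.utils.data import DataLoader
    sampler = DistributedSequentialSampler(val_dataset)
    return DataLoader(val_dataset, batch_size=batch_size, sampler=sampler,
                      shuffle=False, num_workers=num_workers, pin_memory=True)
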
class GivenIterationSampler(Sampler):
    """Sampler that yields exactly total_iter * batch_size indices, repeating and
    shuffling the dataset indices with a fixed seed so runs are reproducible.
    Supports resuming from `last_iter`."""

    def __init__(self, dataset, total_iter, batch_size, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.last_iter = last_iter
        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # repeat the dataset indices to cover total_size, then shuffle with a fixed seed
        np.random.seed(0)
        all_size = self.total_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]
        np.random.shuffle(indices)
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.total_size

class DistributedGivenIterationSampler(Sampler):
    """Distributed variant of GivenIterationSampler: every process shuffles the
    same padded index list with the same seed and keeps the contiguous slice
    corresponding to its rank, so each rank sees total_iter * batch_size indices."""

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter
        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # each process shuffles the full (repeated) index list with the same seed,
        # then picks its own contiguous slice according to its rank
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]
        np.random.shuffle(indices)
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        # note: the last iter is deliberately not subtracted here, since __len__
        # is only used for display; the remaining size is handled by the dataloader
        #return self.total_size - (self.last_iter+1)*self.batch_size
        return self.total_size
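
# Illustrative sketch only (assumes `train_dataset` is a map-style dataset and
# the process group is initialised): the iteration-based sampler yields exactly
# total_iter * batch_size indices per rank, so one pass over this loader runs
# the requested number of iterations and can resume from `last_iter`.
def _example_train_loader(train_dataset, total_iter, batch_size, last_iter=-1, num_workers=4):
    from torch.utils.data import DataLoader
    sampler = DistributedGivenIterationSampler(
        train_dataset, total_iter, batch_size, last_iter=last_iter)
    return DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                      shuffle=False, num_workers=num_workers, pin_memory=True)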