import os
import json
import argparse
import subprocess  # NOTE(review): unused here — kept in case project tooling relies on it
from shutil import copyfile

import torch
import torch.distributed as dist  # NOTE(review): unused here; torch.distributed is reached via torch
import torch.multiprocessing as mp

import core
import core.trainer
import core.trainer_flow_w_edge
from core.dist import (
    get_world_size,      # NOTE(review): unused since the openmpi launch path was removed
    get_local_rank,      # NOTE(review): unused since the openmpi launch path was removed
    get_global_rank,     # NOTE(review): unused since the openmpi launch path was removed
    get_master_ip,
)

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', default='configs/train_propainter.json', type=str)
parser.add_argument('-p', '--port', default='23490', type=str)
# Parsed at module level so spawned workers (which re-import this module)
# see the same CLI arguments.
args = parser.parse_args()


def main_worker(rank, config):
    """Per-process training entry point, launched once per GPU by mp.spawn.

    Args:
        rank: process index assigned by mp.spawn (0..nprocs-1), or 0 when
            run in-process on a CPU-only machine.
        config: dict loaded from the JSON config file, augmented with the
            distributed-launch fields (world_size, init_method, distributed).

    Side effects: initializes the NCCL process group when distributed,
    creates the save directory (rank 0 only), and runs the trainer.
    """
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # Run name is "<net>_<config-stem>". splitext handles dotted filenames
    # correctly; the original split('.')[0] would truncate "a.b.json" to "a".
    config_stem = os.path.splitext(os.path.basename(args.config))[0]
    run_name = '{}_{}'.format(config['model']['net'], config_stem)
    config['save_dir'] = os.path.join(config['save_dir'], run_name)
    config['save_metric_dir'] = os.path.join('./scores', run_name)

    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only rank 0 (or the single process) creates dirs and archives the config.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        # os.path.basename is portable; the original split('/')[-1] breaks on
        # Windows-style paths.
        config_path = os.path.join(config['save_dir'], os.path.basename(args.config))
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    # Dynamic lookup: config['trainer']['version'] names a submodule of `core`
    # (e.g. 'trainer' or 'trainer_flow_w_edge') whose Trainer class is used.
    trainer_version = config['trainer']['version']
    trainer = core.__dict__[trainer_version].__dict__['Trainer'](config)
    trainer.train()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    mp.set_sharing_strategy('file_system')

    # Load the JSON config; `with` closes the handle instead of leaking it.
    with open(args.config) as f:
        config = json.load(f)

    # Distributed launch configuration: one worker per visible CUDA device.
    config['world_size'] = torch.cuda.device_count()
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = config['world_size'] > 1
    print('world_size:', config['world_size'])

    if config['world_size'] >= 1:
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # No CUDA devices: mp.spawn(nprocs=0) would raise, so run a single
        # CPU worker in-process instead.
        main_worker(0, config)