#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import os
import random
import shlex
import shutil
import subprocess
import sys
from glob import iglob
def copy_all_python_files(
source, snapshot_main_dir, code_snapshot_hash, recurse_dirs="fairseq"
):
"""
Copies following files from source to destination:
a) all *.py files at direct source location.
b) all fairseq/*.py recursively (default); recurse through comma-separated recurse_dirs
"""
os.makedirs(snapshot_main_dir, exist_ok=True)
destination = os.path.join(snapshot_main_dir, code_snapshot_hash)
    assert not os.path.exists(destination), "Code snapshot: {0} already exists".format(
        code_snapshot_hash
    )
os.makedirs(destination)
def all_pys(recurse_dirs):
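        # Despite the name, this also yields compiled extensions (*.so) and
        # config files (*.yaml) from the recursed directories.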
yield from iglob(os.path.join(source, "*.py"))
for d in recurse_dirs.split(","):
yield from iglob(os.path.join(source, d, "**/*.py"), recursive=True)
yield from iglob(os.path.join(source, d, "**/*.so"), recursive=True)
yield from iglob(os.path.join(source, d, "**/*.yaml"), recursive=True)
    for filepath in all_pys(recurse_dirs):
        # iglob already returns paths prefixed with ``source``, so recover the
        # path relative to the source root before mirroring it; joining
        # ``source`` onto ``filepath`` again would break for any source
        # directory other than ".".
        relpath = os.path.relpath(filepath, source)
        directory = os.path.dirname(relpath)
        if directory:
            os.makedirs(os.path.join(destination, directory), exist_ok=True)
        shutil.copy2(filepath, os.path.join(destination, relpath))
return destination
def launch_cluster(slurm_args, model_args):
    # prepare
    jobname = slurm_args.get('job-name', 'test')
    if slurm_args.get('workplace') is not None:
        os.makedirs(slurm_args['workplace'], exist_ok=True)
        train_log = os.path.join(slurm_args['workplace'], 'train.%A.out')
        train_stderr = os.path.join(slurm_args['workplace'], 'train.%A.stderr.%j')
    else:
        train_log = train_stderr = None
nodes, gpus = slurm_args.get('nodes', 1), slurm_args.get('gpus', 8)
if not slurm_args.get('local', False):
assert (train_log is not None) and (train_stderr is not None)
# parse slurm
destination = ""
# if slurm_args.get('workplace', None) is not None:
# # Currently hash is just the current time in ISO format.
# # Remove colons since they cannot be escaped in POSIX PATH env vars.
# code_snapshot_hash = datetime.datetime.now().isoformat().replace(":", "_")
# destination = copy_all_python_files(
# ".",
# os.path.join(slurm_args['workplace'], "slurm_snapshot_code"),
# code_snapshot_hash,
# 'fairseq',
# )
# os.environ["PYTHONPATH"] = destination + ":" + os.environ.get("PYTHONPATH", "")
    #     print('create snapshot at {}'.format(destination))
    # ``destination`` is empty unless the snapshot code above is re-enabled,
    # so run_train.py is resolved relative to the current working directory.
    train_cmd = ['python', os.path.join(destination, 'run_train.py')]
    train_cmd.extend([f'gpus={nodes * gpus}'])
    train_cmd.extend([f'port={get_random_port()}'])
    train_cmd += model_args
base_srun_cmd = [
'srun',
'--job-name', jobname,
'--output', train_log,
'--error', train_stderr,
'--open-mode', 'append',
'--unbuffered',
]
srun_cmd = base_srun_cmd + train_cmd
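    # Quote every token so the command survives being embedded in the shell
    # string handed to ``sbatch --wrap`` below.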
srun_cmd_str = ' '.join(map(shlex.quote, srun_cmd))
srun_cmd_str = srun_cmd_str + ' &'
sbatch_cmd = [
'sbatch',
'--job-name', jobname,
'--partition', slurm_args.get('partition', 'learnfair'),
'--gres', 'gpu:volta:{}'.format(gpus),
'--nodes', str(nodes),
'--ntasks-per-node', '1',
'--cpus-per-task', '20',
'--output', train_log,
'--error', train_stderr,
'--open-mode', 'append',
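        # ``B:USR1@180`` makes SLURM send USR1 to the batch shell only (B:)
        # 180 s before the time limit, so the trap installed by
        # requeue_support() has time to requeue the job.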
'--signal', 'B:USR1@180',
        '--time', str(slurm_args.get('time', '4320')),
        '--mem', str(slurm_args.get('mem', '500gb')),
'--exclusive',
'--exclude', 'learnfair5035,learnfair5289,learnfair5088,learnfair5028,learnfair5032,learnfair5033,learnfair5056,learnfair5098,learnfair5122,learnfair5124,learnfair5156,learnfair5036,learnfair5258,learnfair5205,learnfair5201,learnfair5240,learnfair5087,learnfair5119,learnfair5246,learnfair7474,learnfair7585,learnfair5150,learnfair5166,learnfair5215,learnfair5142,learnfair5070,learnfair5236,learnfair7523'
]
if 'constraint' in slurm_args:
sbatch_cmd += ['-C', slurm_args.get('constraint')]
if 'comment' in slurm_args:
sbatch_cmd += ['--comment', slurm_args.get('comment')]
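    # srun is backgrounded (the trailing '&') so the batch shell stays free to
    # receive signals; ``wait $!`` blocks on it, and the final
    # ``sleep 610 & wait $!`` keeps the script alive a bit longer, presumably
    # so the USR1 requeue handler can finish before the batch script exits.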
wrapped_cmd = requeue_support() + '\n' + srun_cmd_str + ' \n wait $! \n sleep 610 & \n wait $!'
sbatch_cmd += ['--wrap', wrapped_cmd]
sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd))
# start training
    env = os.environ.copy()
    env['OMP_NUM_THREADS'] = '2'
    # An empty value lets NCCL pick the network interface on its own.
    env['NCCL_SOCKET_IFNAME'] = ''
    # SLURM_ARGS is presumably set by the caller; drop it so it does not
    # reach the training process.
    env.pop('SLURM_ARGS', None)
if nodes > 1:
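        # Multi-node: exclude the docker bridge and loopback interfaces from
        # NCCL's interface selection (the leading '^' negates the list).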
env['NCCL_SOCKET_IFNAME'] = '^docker0,lo'
env['NCCL_DEBUG'] = 'INFO'
if slurm_args.get('dry-run', False):
print(sbatch_cmd_str)
elif slurm_args.get('local', False):
        assert nodes == 1, 'multi-node training cannot be combined with local execution'
if 'CUDA_VISIBLE_DEVICES' not in env:
env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(gpus)))
env['NCCL_DEBUG'] = 'INFO'
if train_log is not None:
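            # Pipe stdout through ``tee -a`` so output is shown live and also
            # appended to the log file.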
train_proc = subprocess.Popen(train_cmd, env=env, stdout=subprocess.PIPE)
tee_proc = subprocess.Popen(['tee', '-a', train_log], stdin=train_proc.stdout)
train_proc.stdout.close()
train_proc.wait()
tee_proc.wait()
else:
train_proc = subprocess.Popen(train_cmd, env=env)
train_proc.wait()
else:
with open(train_log, 'a') as train_log_h:
print(f'running command: {sbatch_cmd_str}\n')
with subprocess.Popen(sbatch_cmd, stdout=subprocess.PIPE, env=env) as train_proc:
stdout = train_proc.stdout.read().decode('utf-8')
print(stdout, file=train_log_h)
                try:
                    # sbatch prints e.g. "Submitted batch job 12345"; the job
                    # ID is the last whitespace-separated token.
                    job_id = int(stdout.rstrip().split()[-1])
                    return job_id
                except (IndexError, ValueError):
                    return None
def launch(slurm_args, model_args):
job_id = launch_cluster(slurm_args, model_args)
if job_id is not None:
print('Launched {}'.format(job_id))
else:
print('Failed.')
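# Shell preamble injected into the ``sbatch --wrap`` script: on USR1 (sent by
# the --signal option above) the job requeues itself; TERM is caught but
# deliberately ignored so an ordinary termination does not trigger a requeue.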
def requeue_support():
return """
trap_handler () {
echo "Caught signal: " $1
# SIGTERM must be bypassed
if [ "$1" = "TERM" ]; then
echo "bypass sigterm"
else
# Submit a new job to the queue
echo "Requeuing " $SLURM_JOB_ID
scontrol requeue $SLURM_JOB_ID
fi
}
# Install signal handler
trap 'trap_handler USR1' USR1
trap 'trap_handler TERM' TERM
"""
def get_random_port():
    """Pick a random port in [10000, 20000] without disturbing the global RNG
    state (training code may have seeded ``random`` deterministically)."""
    old_state = random.getstate()
    random.seed()
    port = random.randint(10000, 20000)
    random.setstate(old_state)
    return port
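# Minimal usage sketch (hypothetical values; the slurm/model dictionaries are
# normally supplied by the caller, e.g. from a parsed sweep config):
#
#   launch(
#       {'job-name': 'demo', 'nodes': 1, 'gpus': 8, 'local': True,
#        'workplace': '/tmp/demo'},
#       ['model=transformer_lm', 'lr=0.0005'],
#   )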