#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import os
import random
import shlex
import shutil
import subprocess
import sys
from glob import iglob

def copy_all_python_files(
    source, snapshot_main_dir, code_snapshot_hash, recurse_dirs="fairseq"
):
    """
    Copies the following files from source to destination:
    a) all *.py files directly under the source location.
    b) all *.py, *.so and *.yaml files under each directory in the
       comma-separated ``recurse_dirs`` (default: "fairseq"), recursively.
    """
    os.makedirs(snapshot_main_dir, exist_ok=True)
    destination = os.path.join(snapshot_main_dir, code_snapshot_hash)
    assert not os.path.exists(destination), "Code snapshot: {0} already exists".format(
        code_snapshot_hash
    )
    os.makedirs(destination)

    def all_pys(recurse_dirs):
        yield from iglob(os.path.join(source, "*.py"))
        for d in recurse_dirs.split(","):
            yield from iglob(os.path.join(source, d, "**/*.py"), recursive=True)
            yield from iglob(os.path.join(source, d, "**/*.so"), recursive=True)
            yield from iglob(os.path.join(source, d, "**/*.yaml"), recursive=True)

    # note: paths yielded by all_pys() already include `source`, so the joins
    # below assume `source` is a relative path such as "."
    for filepath in all_pys(recurse_dirs):
        directory, filename = os.path.split(filepath)
        if directory:
            os.makedirs(os.path.join(destination, directory), exist_ok=True)
        shutil.copy2(
            os.path.join(source, filepath), os.path.join(destination, filepath)
        )
    return destination
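
# Example usage (illustrative sketch only, not invoked anywhere in this script):
# snapshot the current tree before launching and point PYTHONPATH at the copy,
# mirroring the commented-out block in launch_cluster() below. The workplace
# path here is a made-up placeholder.
#
#   snapshot_hash = datetime.datetime.now().isoformat().replace(":", "_")
#   snapshot_dir = copy_all_python_files(
#       ".",
#       os.path.join("/path/to/workplace", "slurm_snapshot_code"),
#       snapshot_hash,
#       "fairseq",
#   )
#   os.environ["PYTHONPATH"] = snapshot_dir + ":" + os.environ.get("PYTHONPATH", "")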

def launch_cluster(slurm_args, model_args):
    # prepare
    jobname = slurm_args.get('job-name', 'test')
    if slurm_args.get('workplace') is not None:
        os.makedirs(slurm_args['workplace'], exist_ok=True)
        train_log = os.path.join(slurm_args['workplace'], 'train.%A.out')
        train_stderr = os.path.join(slurm_args['workplace'], 'train.%A.stderr.%j')
    else:
        train_log = train_stderr = None
    nodes, gpus = slurm_args.get('nodes', 1), slurm_args.get('gpus', 8)
    if not slurm_args.get('local', False):
        # cluster submission needs log paths for the srun/sbatch --output/--error flags
        assert (train_log is not None) and (train_stderr is not None)
    # parse slurm
    destination = ""
    # if slurm_args.get('workplace', None) is not None:
    #     # Currently hash is just the current time in ISO format.
    #     # Remove colons since they cannot be escaped in POSIX PATH env vars.
    #     code_snapshot_hash = datetime.datetime.now().isoformat().replace(":", "_")
    #     destination = copy_all_python_files(
    #         ".",
    #         os.path.join(slurm_args['workplace'], "slurm_snapshot_code"),
    #         code_snapshot_hash,
    #         'fairseq',
    #     )
    #     os.environ["PYTHONPATH"] = destination + ":" + os.environ.get("PYTHONPATH", "")
    #     print('create snapshot at {}'.format(destination))
    train_cmd = ['python', os.path.join(destination, 'run_train.py')]
    train_cmd.extend([f'gpus={nodes * gpus}'])
    train_cmd.extend([f'port={get_random_port()}'])
    train_cmd += model_args

    base_srun_cmd = [
        'srun',
        '--job-name', jobname,
        '--output', train_log,
        '--error', train_stderr,
        '--open-mode', 'append',
        '--unbuffered',
    ]
    srun_cmd = base_srun_cmd + train_cmd
    srun_cmd_str = ' '.join(map(shlex.quote, srun_cmd))
    srun_cmd_str = srun_cmd_str + ' &'

    sbatch_cmd = [
        'sbatch',
        '--job-name', jobname,
        '--partition', slurm_args.get('partition', 'learnfair'),
        '--gres', 'gpu:volta:{}'.format(gpus),
        '--nodes', str(nodes),
        '--ntasks-per-node', '1',
        '--cpus-per-task', '20',
        '--output', train_log,
        '--error', train_stderr,
        '--open-mode', 'append',
        '--signal', 'B:USR1@180',
        '--time', slurm_args.get('time', '4320'),
        '--mem', slurm_args.get('mem', '500gb'),
        '--exclusive',
        '--exclude', 'learnfair5035,learnfair5289,learnfair5088,learnfair5028,learnfair5032,learnfair5033,learnfair5056,learnfair5098,learnfair5122,learnfair5124,learnfair5156,learnfair5036,learnfair5258,learnfair5205,learnfair5201,learnfair5240,learnfair5087,learnfair5119,learnfair5246,learnfair7474,learnfair7585,learnfair5150,learnfair5166,learnfair5215,learnfair5142,learnfair5070,learnfair5236,learnfair7523',
    ]
    if 'constraint' in slurm_args:
        sbatch_cmd += ['-C', slurm_args.get('constraint')]
    if 'comment' in slurm_args:
        sbatch_cmd += ['--comment', slurm_args.get('comment')]

    wrapped_cmd = requeue_support() + '\n' + srun_cmd_str + ' \n wait $! \n sleep 610 & \n wait $!'
    sbatch_cmd += ['--wrap', wrapped_cmd]
    sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd))
    # start training
    env = os.environ.copy()
    env['OMP_NUM_THREADS'] = '2'
    env['NCCL_SOCKET_IFNAME'] = ''
    if env.get('SLURM_ARGS', None) is not None:
        del env['SLURM_ARGS']
    if nodes > 1:
        env['NCCL_SOCKET_IFNAME'] = '^docker0,lo'
        env['NCCL_DEBUG'] = 'INFO'

    if slurm_args.get('dry-run', False):
        print(sbatch_cmd_str)

    elif slurm_args.get('local', False):
        assert nodes == 1, 'distributed training cannot be combined with local'
        if 'CUDA_VISIBLE_DEVICES' not in env:
            env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(gpus)))
        env['NCCL_DEBUG'] = 'INFO'
        if train_log is not None:
            # pipe training output through `tee -a` so it is shown and appended to the log
            train_proc = subprocess.Popen(train_cmd, env=env, stdout=subprocess.PIPE)
            tee_proc = subprocess.Popen(['tee', '-a', train_log], stdin=train_proc.stdout)
            train_proc.stdout.close()
            train_proc.wait()
            tee_proc.wait()
        else:
            train_proc = subprocess.Popen(train_cmd, env=env)
            train_proc.wait()

    else:
        with open(train_log, 'a') as train_log_h:
            print(f'running command: {sbatch_cmd_str}\n')
            with subprocess.Popen(sbatch_cmd, stdout=subprocess.PIPE, env=env) as train_proc:
                stdout = train_proc.stdout.read().decode('utf-8')
                print(stdout, file=train_log_h)
                try:
                    job_id = int(stdout.rstrip().split()[-1])
                    return job_id
                except IndexError:
                    return None

def launch(slurm_args, model_args):
    job_id = launch_cluster(slurm_args, model_args)
    if job_id is not None:
        print('Launched {}'.format(job_id))
    else:
        print('Failed.')

def requeue_support():
    return """
        trap_handler () {
           echo "Caught signal: " $1
           # SIGTERM must be bypassed
           if [ "$1" = "TERM" ]; then
               echo "bypass sigterm"
           else
               # Submit a new job to the queue
               echo "Requeuing " $SLURM_JOB_ID
               scontrol requeue $SLURM_JOB_ID
           fi
        }
        # Install signal handler
        trap 'trap_handler USR1' USR1
        trap 'trap_handler TERM' TERM
    """

def get_random_port():
    # temporarily reseed so the chosen port does not depend on any fixed global seed
    old_state = random.getstate()
    random.seed()
    port = random.randint(10000, 20000)
    random.setstate(old_state)
    return port
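

if __name__ == '__main__':
    # Minimal sketch of how launch() might be driven. The argument values below
    # are assumptions for illustration only; this script does not define a CLI,
    # and the workplace path and model overrides are hypothetical.
    example_slurm_args = {
        'job-name': 'demo',
        'workplace': '/tmp/demo_workplace',  # hypothetical log/snapshot directory
        'nodes': 1,
        'gpus': 8,
        'partition': 'learnfair',
        'dry-run': True,  # only print the sbatch command, do not submit
    }
    # hypothetical key=value overrides appended to the run_train.py command line
    example_model_args = ['model=transformer', 'max_epochs=10']
    launch(example_slurm_args, example_model_args)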