Spaces:
Runtime error
Runtime error
File size: 6,554 Bytes
217780a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
#!/usr/bin/env python
#
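# This tool downloads the requested deepspeed checkpoints of a run from s3, converts the ones
# that haven't been converted yet to hf format and uploads the converted weights back to s3
#
# Example:
#
# ./convert-checkpoints.py run_name opt_step_num [opt_step_num ...] repo_path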
#
import argparse
import subprocess
import sys
from pathlib import Path
import boto3
def check_s3_directory(directory_path, bucket_name="m4-exps"):
    """Tell whether an S3 "directory" exists, i.e. whether any object key starts with it.

    Args:
        directory_path: key prefix of the directory, with or without a trailing slash.
        bucket_name: bucket to look into.

    Returns:
        True if at least one object exists under the prefix, False otherwise.
    """
    s3 = boto3.client("s3")

    # Add a trailing slash to the directory path, so that only keys inside the directory
    # match and not sibling keys that merely share the same leading characters
    prefix = directory_path if directory_path.endswith("/") else f"{directory_path}/"

    # A single key is enough to prove that the directory exists, no need to list more
    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)

    # "Contents" is only present in the response when objects were found
    return "Contents" in response
def check_s3_file(file_key, bucket_name="m4-exps"):
    """Tell whether an object exists in S3.

    Args:
        file_key: full key of the object.
        bucket_name: bucket to look into.

    Returns:
        True if the object exists, False if S3 reports it as not found.

    Raises:
        Any error other than "not found" (credentials, permissions, network, ...) is
        propagated: reporting those as "the file doesn't exist" would make the caller
        believe that a checkpoint still has to be converted.
    """
    s3 = boto3.client("s3")
    try:
        s3.head_object(Bucket=bucket_name, Key=file_key)
    except s3.exceptions.ClientError as exc:
        error_code = exc.response.get("Error", {}).get("Code")
        if error_code in ("404", "NoSuchKey", "NotFound"):
            return False
        raise
    return True
def run_cmd(cmd, check=True):
    """Run a command and return what it printed on stdout.

    Args:
        cmd: the program and its arguments, as a sequence handed to `subprocess.run`.
        check: when True a non-zero exit status is turned into an exception.

    Returns:
        The utf-8 decoded stdout of the command, stripped of surrounding whitespace.

    Raises:
        EnvironmentError: if `check` is True and the command exits with a non-zero status.
            The message is the stderr of the command and the original
            `subprocess.CalledProcessError` is attached as the cause.
    """
    try:
        completed = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        )
    except subprocess.CalledProcessError as exc:
        # keep the message equal to stderr (callers search it), but chain the original
        # exception so that the command and its return code remain visible in tracebacks
        raise EnvironmentError(exc.stderr) from exc
    return completed.stdout.strip()
def get_args():
    """Build the command line parser of this script and parse the process arguments.

    Returns:
        The parsed namespace with `run_name`, `opt_step_num_list` (one or more values),
        `repo_path` and the boolean `force` flag.
    """
    positional_args = (
        ("run_name", {"type": str, "help": "run name"}),
        ("opt_step_num_list", {"nargs": "+", "help": "list of opt-steps to download"}),
        ("repo_path", {"type": str, "help": "repo path"}),
    )

    cli = argparse.ArgumentParser()
    for arg_name, arg_options in positional_args:
        cli.add_argument(arg_name, **arg_options)
    cli.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
    return cli.parse_args()
def exit(msg):
    """Print `msg` and terminate the program by raising `SystemExit` without a status.

    Note that inside this module the name hides the interactive `exit` builtin.
    """
    print(msg)
    raise SystemExit
def cmd_retry_loop(cmd, max_retries=5):
    """Run a command through `run_cmd`, retrying it when the transfer fails on an MD5 mismatch.

    Args:
        cmd: the program and its arguments, as a sequence.
        max_retries: maximum number of times the command is run.

    Returns:
        The stdout of the first successful run (it is printed as well).

    Raises:
        EnvironmentError: immediately if the command fails for a reason other than an MD5
            mismatch, or once `max_retries` runs have all failed on an MD5 mismatch.
    """
    # s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
    # ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
    # you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
    # XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
    last_error = None
    for tries in range(1, max_retries + 1):
        try:
            response = run_cmd(cmd)
        except EnvironmentError as e:
            if "InvalidDigest" not in str(e):
                # not a checksum problem: don't hide the failure behind silent retries
                raise
            print(f"MD5 checksum failed, download retry {tries}")
            last_error = e
            continue
        print(response)
        return response

    # every attempt failed: report it instead of returning a response that was never set
    raise EnvironmentError(f"giving up on {cmd} after {max_retries} failed attempts") from last_error
def main():
    """Download, convert and upload the requested checkpoints of a run.

    For each requested opt-step that exists under `s3://m4-exps/<run_name>/`:

    - if `unwrapped_model` holds neither `pytorch_model.bin` nor `pytorch_model.bin.index.json`
      (or if `--force` was passed), the whole checkpoint is downloaded, converted with
      `m4/models/zero_checkpoint_to_hf.py` from `repo_path` and the resulting
      `unwrapped_model` dir is uploaded back to s3;
    - otherwise only the `unwrapped_model` and `tokenizer` dirs are downloaded.

    Opt-steps that don't exist on s3 are skipped.
    """
    args = get_args()

    run_name = args.run_name
    repo_path = Path(args.repo_path)
    zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
    bucket_name = "m4-exps"
    cluster_run_dir = f"/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/{run_name}"

    if not check_s3_directory(run_name):
        print(f"The run {run_name} does not exist in s3://{bucket_name} - nothing to do")
        return

    # Check each folder in real time to allow for overlapping jobs starting at different times
    for opt_step_num in args.opt_step_num_list:
        opt_step_dirname = f"opt_step-{opt_step_num}"
        opt_step_s3_file_key = f"{run_name}/{opt_step_dirname}"
        opt_step_s3_url = f"s3://{bucket_name}/{opt_step_s3_file_key}"
        cluster_opt_step_dir = f"{cluster_run_dir}/{opt_step_dirname}"

        print(f"\n*** Checking {opt_step_s3_file_key}")
        if not check_s3_directory(opt_step_s3_file_key):
            print(f"The checkpoint {opt_step_s3_file_key} does not exist - skipping")
            continue

        unwrapped_model_s3_file_key = f"{opt_step_s3_file_key}/unwrapped_model"
        bin_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin"
        index_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin.index.json"

        # a converted checkpoint has either a single weights file or a sharded weights index;
        # --force rebuilds the checkpoint regardless
        needs_conversion = args.force or not (check_s3_file(bin_s3_file_key) or check_s3_file(index_s3_file_key))

        # commands are built as argument lists (not split strings) so that names or paths
        # containing whitespace reach s5cmd as a single argument
        if needs_conversion:
            reason = "A rebuild was forced" if args.force else "The checkpoint hasn't been converted"
            print(f"{reason}, launching download for {opt_step_s3_file_key} - it could take a long time")

            cmd = ["s5cmd", "sync", f"{opt_step_s3_url}/*", cluster_opt_step_dir]
            download_response_opt_step_dir = cmd_retry_loop(cmd, max_retries=5)
            print(f"download_response_opt_step_dir: {download_response_opt_step_dir}")
        else:
            print(
                "The checkpoint has been converted already, downloading only the unwrapped checkpoint and"
                " tokenizer dir"
            )
            unwrapped_model_dir = f"{cluster_opt_step_dir}/unwrapped_model"
            tokenizer_dir = f"{cluster_opt_step_dir}/tokenizer"

            cmd_model = ["s5cmd", "sync", f"{opt_step_s3_url}/unwrapped_model/*", unwrapped_model_dir]
            cmd_tokenizer = ["s5cmd", "sync", f"{opt_step_s3_url}/tokenizer/*", tokenizer_dir]

            download_response_model = cmd_retry_loop(cmd_model, max_retries=5)
            print(f"download_response_model: {download_response_model}")
            download_response_tokenizer = cmd_retry_loop(cmd_tokenizer, max_retries=5)
            print(f"download_response_tokenizer: {download_response_tokenizer}")

        print(f"opt_step_dirname: {opt_step_dirname} downloaded to cluster_opt_step_dir: {cluster_opt_step_dir}")

        if not needs_conversion:
            continue

        print(f"Converting {cluster_opt_step_dir}")
        convert_cmd = [str(zero_checkpoint_to_hf_path), cluster_opt_step_dir]
        conversion_response = run_cmd(convert_cmd)
        print(f"conversion_response: {conversion_response}")

        print(f"upload converted checkpoint: {cluster_opt_step_dir}")
        upload_cmd = [
            "s5cmd",
            "sync",
            f"{cluster_opt_step_dir}/unwrapped_model/",
            f"{opt_step_s3_url}/unwrapped_model/",
        ]
        upload_response = cmd_retry_loop(upload_cmd, max_retries=5)
        print(f"upload_response: {upload_response}")
        print(f"Uploaded {cluster_opt_step_dir}/unwrapped_model to {opt_step_s3_url}/unwrapped_model")


if __name__ == "__main__":
    main()
|