#!/usr/bin/env python

# This script combines the 2 steps of
# 1. calling zero_to_fp32.py to reconsolidate the sharded deepspeed checkpoint
# 2. then re-saving it as an HF checkpoint, which also takes care of sharding large checkpoints
#
# example usage:
#
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
#
# or you can override the destination by passing an explicit target dir, e.g.:
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir
import argparse
import sys
from pathlib import Path

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# automatically add the repo root to sys.path so m4 modules can be imported without setting PYTHONPATH
repodir = str(Path(__file__).resolve().parents[2])
sys.path.insert(0, repodir)

import m4.models
from m4.testing_utils import read_json_file

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
    )
    args = parser.parse_args()
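    # expected layout under checkpoint_dir: accelerator_state/ holds the DeepSpeed ZeRO
    # shards, unwrapped_model/ holds the HF config.json for the model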
    checkpoint_dir = Path(args.checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
    config_file_path = config_dir / "config.json"
    if args.output_dir is None:
        output_dir = checkpoint_dir / "unwrapped_model"
    else:
        output_dir = args.output_dir
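    # note: with the default, save_pretrained below writes the converted HF checkpoint
    # into the same unwrapped_model/ directory the config is read from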
    config = read_json_file(config_file_path)
    config_class = m4.models._SUPPORTED_MODELS.get(config["model_type"], None)
    if config_class is None:
        raise ValueError(f"{config['model_type']=} isn't supported by m4")
    modeling_class = m4.models.model_type_to_modeling_class.get(config["model_type"], None)

    print(f"Detected {config_class}")

    print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)  # already on cpu
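    # the reconsolidation merges the per-rank ZeRO partitions into a single fp32 state dict
    # held in CPU memory, so this step needs roughly the full fp32 model size in RAM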
    # Keeping debug to use if you ever need to debug state dict
    # print("Saved State Dict")
    # for k, v in state_dict.items():
    #     print(f"{k} {v.shape}")
    kwargs = {}

    print(f"Loading config from {config_dir}")
    model_config = config_class.from_pretrained(config_dir)

    print(f"Instantiating a {modeling_class} model in bf16")
    model = modeling_class.from_pretrained(
        None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
    )
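    # passing None as the model path together with an explicit state_dict makes
    # from_pretrained build the model from the config and load the provided weights,
    # cast to bf16 via torch_dtype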
    # Keeping debug to use if you ever need to debug state dict
    # print("Model State Dict")
    # for k, v in model.state_dict().items():
    #     print(f"{k} {v.shape}")

    print(f"Saving model to {output_dir}")
    model.save_pretrained(output_dir)