import json

import torch

import transformers
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def _resolve_hf_file(model_name, filename):
    """Resolve *filename* for *model_name* via ``cached_file`` or raise.

    Args:
        model_name: Model identifier forwarded to transformers' ``cached_file``
            (e.g. a Hub repo id).
        filename: Name of the file to resolve (e.g. ``CONFIG_NAME``).

    Returns:
        The resolved local path returned by ``cached_file``.

    Raises:
        OSError: If ``cached_file`` reports the file as missing.
    """
    resolved_archive_file = cached_file(
        model_name, filename, _raise_exceptions_for_missing_entries=False
    )
    # With _raise_exceptions_for_missing_entries=False, cached_file returns None
    # for a missing entry instead of raising. Fail here with a clear message
    # rather than letting None reach open()/torch.load().
    if resolved_archive_file is None:
        raise OSError(f"Could not find {filename!r} for model {model_name!r}.")
    return resolved_archive_file


def load_config_hf(model_name):
    """Load and parse the model's config JSON file.

    Args:
        model_name: Model identifier forwarded to transformers' ``cached_file``.

    Returns:
        The parsed JSON content of the ``CONFIG_NAME`` file.

    Raises:
        OSError: If the config file cannot be resolved.
    """
    resolved_archive_file = _resolve_hf_file(model_name, CONFIG_NAME)
    # Close the handle deterministically; JSON is UTF-8 regardless of the
    # platform's default encoding.
    with open(resolved_archive_file, encoding="utf-8") as f:
        return json.load(f)


def load_state_dict_hf(model_name, device="cpu"):
    """Load the model's ``WEIGHTS_NAME`` checkpoint with ``torch.load``.

    Args:
        model_name: Model identifier forwarded to transformers' ``cached_file``.
        device: Passed as ``map_location`` to ``torch.load`` (default ``"cpu"``).

    Returns:
        Whatever ``torch.load`` deserializes from the checkpoint file.

    Raises:
        OSError: If the weights file cannot be resolved.
    """
    resolved_archive_file = _resolve_hf_file(model_name, WEIGHTS_NAME)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from trusted sources.
    return torch.load(resolved_archive_file, map_location=device)