Spaces:
Runtime error
Runtime error
import torch | |
from omegaconf import OmegaConf | |
from ldm.util import instantiate_from_config | |
def get_state_dict(d):
    """Unwrap a loaded checkpoint to its weights mapping.

    Some training frameworks save checkpoints as a dict that nests the model
    weights under a ``'state_dict'`` key alongside other metadata; others
    save the weights mapping directly.

    Args:
        d: The object returned by the checkpoint loader (expected to support
            ``.get``, e.g. a ``dict``).

    Returns:
        ``d['state_dict']`` when that key is present, otherwise ``d`` itself.
    """
    unwrapped = d.get('state_dict', d)
    return unwrapped
def load_state_dict(ckpt_path, location='cpu', *, weights_only=None):
    """Load a checkpoint file and return its state dict.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``.
        location: Device spec (e.g. ``'cpu'``) wrapped in ``torch.device`` and
            used as ``map_location``, so loaded tensors are placed on that
            device.
        weights_only: When not ``None``, forwarded to ``torch.load``. Pass
            ``True`` to restrict unpickling to tensors and primitive
            containers when the checkpoint comes from an untrusted source.
            ``None`` (the default) omits the argument, keeping the installed
            torch version's default behaviour.

    Returns:
        The loaded object passed through ``get_state_dict``: its nested
        ``'state_dict'`` entry if present, otherwise the object itself.
    """
    load_kwargs = {'map_location': torch.device(location)}
    if weights_only is not None:
        # Only pass the flag when requested, so older torch releases that
        # predate the keyword keep working with the default call.
        load_kwargs['weights_only'] = weights_only
    # SECURITY NOTE: torch.load is pickle-based; without weights_only=True a
    # malicious checkpoint file can execute arbitrary code when loaded.
    checkpoint = torch.load(ckpt_path, **load_kwargs)
    state_dict = get_state_dict(checkpoint)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
def create_model(config_path):
    """Build the model described by a config file and move it to CPU.

    Args:
        config_path: Path to a config file readable by ``OmegaConf.load``;
            it must contain a ``model`` section understood by
            ``instantiate_from_config``.

    Returns:
        The instantiated object after calling ``.cpu()`` on it (presumably a
        ``torch.nn.Module`` — confirm against ``instantiate_from_config``).
    """
    cfg = OmegaConf.load(config_path)
    model_cfg = cfg.model
    built = instantiate_from_config(model_cfg)
    built = built.cpu()
    print(f'Loaded model config from [{config_path}]')
    return built