"""Helpers for loading YAML experiment configs."""
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf

# Default location of the base config. It is a relative path, so it is
# resolved against the current working directory, not against this module.
BASE_CFG_PATH = 'configs/base.yml'

# String spellings that convert_types replaces with None.
_NONE_STRINGS = ('None', 'none')

# Sections of an experiment config that load_cfg merges onto the base config.
_MERGED_SECTIONS = ('model', 'data')


def _read_yaml(path):
    """Parse the YAML file at *path* with PyYAML's safe loader.

    Returns whatever the document parses to, or ``None`` for an empty file.
    """
    with open(path, 'r', encoding='utf-8') as fp:
        return yaml.safe_load(fp)


def load_base_cfg(path=BASE_CFG_PATH):
    """Load and return the base config.

    Args:
        path: YAML file to read; defaults to ``configs/base.yml``.

    Returns:
        The parsed document (a plain ``dict`` for a mapping document), or
        ``None`` if the file is empty.
    """
    return _read_yaml(path)


def load_cfg(cfg_file):
    """Load the base config and overlay an experiment config on top of it.

    Only the ``model`` and ``data`` sections of *cfg_file* are merged, and the
    merge is shallow: each key in those sections replaces the corresponding
    base value wholesale (nested mappings are not merged recursively). Any
    other top-level sections of *cfg_file* are ignored.

    Args:
        cfg_file: Path to the experiment YAML file.

    Returns:
        The merged config as a plain ``dict``.
    """
    cfg = load_base_cfg()
    if cfg is None:
        # Empty base file: start from an empty config instead of crashing.
        cfg = {}

    exp_cfg = _read_yaml(cfg_file)
    if exp_cfg is None:
        # Empty experiment file means "no overrides".
        exp_cfg = {}

    for section in _MERGED_SECTIONS:
        # A section missing from, or null in, the base config becomes empty.
        if cfg.get(section) is None:
            cfg[section] = {}
        overrides = exp_cfg.get(section)
        # A section that is absent or written as a bare `model:` (null)
        # contributes no overrides.
        if overrides is not None:
            cfg[section].update(overrides)
    return cfg


def _is_none_string(value):
    """Return True if *value* is one of the strings treated as null."""
    return isinstance(value, str) and value in _NONE_STRINGS


def convert_types(config):
    """Replace ``'None'`` / ``'none'`` strings with ``None`` throughout *config*.

    *config* is modified in place. Nested ``DictConfig`` and ``ListConfig``
    nodes are walked recursively, and every value equal to one of those
    strings is replaced by ``None``. Anything that is neither a
    ``DictConfig`` nor a ``ListConfig`` is returned untouched.

    Args:
        config: The OmegaConf node to clean up.

    Returns:
        *config* itself, so the call can be chained.
    """
    if isinstance(config, DictConfig):
        entries = list(config.items())
    elif isinstance(config, ListConfig):
        entries = list(enumerate(config))
    else:
        return config

    # Entries were snapshotted above, so assigning into the node below
    # cannot interfere with the iteration.
    for key, value in entries:
        if isinstance(value, (DictConfig, ListConfig)):
            convert_types(value)
        elif _is_none_string(value):
            # Item assignment (not setattr) also works for non-string keys
            # and for list indices.
            config[key] = None
    return config


def setup_config(config_path):
    """Load *config_path* with OmegaConf and convert ``'None'`` strings to ``None``."""
    return convert_types(OmegaConf.load(config_path))