hvaldez's picture
first commit
c18a21e verified
raw
history blame
1.09 kB
import yaml
from omegaconf import OmegaConf, DictConfig
def load_base_cfg(path='configs/base.yml'):
    """Load the shared base configuration from a YAML file.

    Args:
        path: YAML file to read. Defaults to ``configs/base.yml``, which is
            resolved relative to the current working directory.

    Returns:
        The parsed YAML document (a dict for a mapping document, ``None`` if
        the file is empty).
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as fp:
        # safe_load is exactly yaml.load(fp, Loader=yaml.SafeLoader).
        return yaml.safe_load(fp)
def load_cfg(cfg_file):
    """Load the base config and overlay an experiment config on top of it.

    Only the ``model`` and ``data`` sections of ``cfg_file`` are applied, and
    the overlay is shallow: each key in those sections replaces the base
    value wholesale (nested mappings are not merged recursively). Any other
    top-level sections of ``cfg_file`` are ignored.

    Args:
        cfg_file: Path to the experiment YAML file.

    Returns:
        The base config dict with the experiment overrides applied. A
        ``model``/``data`` section missing from the base config is created.
    """
    # An empty base file parses to None; treat it as an empty mapping.
    cfg = load_base_cfg() or {}
    with open(cfg_file, 'r', encoding='utf-8') as fp:
        # An empty experiment file also parses to None, i.e. "no overrides".
        exp_cfg = yaml.safe_load(fp) or {}
    for section in ('model', 'data'):
        # `or {}` covers a missing section as well as one left empty in the
        # YAML (`model:` parses to None), which dict.update() cannot accept.
        base_section = cfg.get(section) or {}
        base_section.update(exp_cfg.get(section) or {})
        cfg[section] = base_section
    return cfg
def convert_types(config):
    """Replace the strings ``'None'`` / ``'none'`` with ``None``, in place.

    Recurses into nested ``DictConfig`` values. Values held inside
    ``ListConfig`` nodes (including dicts nested in lists) are not inspected.

    Args:
        config: An OmegaConf ``DictConfig``; it is modified in place.

    Returns:
        The same ``config`` object, for call chaining.
    """
    for key, value in config.items():
        if isinstance(value, DictConfig):
            config[key] = convert_types(value)
        # TODO: convert values inside ListConfig; they are currently skipped.
        elif isinstance(value, str) and value in ("None", "none"):
            # Item assignment rather than setattr(): setattr() only accepts
            # string names, but OmegaConf also allows e.g. int keys.
            config[key] = None
    return config
def setup_config(config_path):
    """Load a YAML config with OmegaConf and normalize string ``None`` values.

    Args:
        config_path: Path of the YAML file handed to ``OmegaConf.load``.

    Returns:
        The loaded config after it has been passed through ``convert_types``.
    """
    return convert_types(OmegaConf.load(config_path))