|
|
|
import yaml |
|
from omegaconf import OmegaConf, DictConfig |
|
|
|
def load_base_cfg(path='configs/base.yml'):
    """Load the base configuration from a YAML file.

    Args:
        path: Location of the base YAML file. Defaults to
            ``'configs/base.yml'``, which is resolved against the current
            working directory.

    Returns:
        Whatever the YAML document parses to under ``yaml.SafeLoader``
        (``None`` if the file is empty).
    """
    # Explicit encoding: without it, open() falls back to the platform's
    # locale encoding, which is not necessarily UTF-8.
    with open(path, 'r', encoding='utf-8') as fp:
        return yaml.load(fp, Loader=yaml.SafeLoader)
|
|
|
def load_cfg(cfg_file):
    """Load the base config and overlay an experiment config on top of it.

    Only the ``model`` and ``data`` sections of the experiment file are
    applied, each as a shallow ``dict.update`` of the matching base section.
    Any other top-level section in the experiment file is ignored.

    Args:
        cfg_file: Path to the experiment YAML file.

    Returns:
        The base configuration, updated in place with the experiment
        overrides.

    Raises:
        KeyError: If the base configuration has no ``model`` or ``data``
            section.
    """
    cfg = load_base_cfg()
    with open(cfg_file, 'r', encoding='utf-8') as fp:
        # An empty YAML document parses to None; treat it as "no overrides"
        # instead of failing on `.get` below.
        exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader) or {}

    for section in ('model', 'data'):
        # `or {}` also covers a section that is present but left empty
        # (parsed as None), which dict.update would reject.
        cfg[section].update(exp_cfg.get(section) or {})
    return cfg
|
|
|
def convert_types(config):
    """Replace the strings ``'None'`` / ``'none'`` with a real ``None``.

    Values that are themselves ``DictConfig`` instances are converted
    recursively. Other containers (e.g. lists) are not traversed. The
    config is modified in place.

    Args:
        config: The config to normalize (a ``DictConfig`` when called from
            ``setup_config``).

    Returns:
        The same ``config`` object, after conversion.
    """
    for k, v in config.items():
        if isinstance(v, DictConfig):
            # Item assignment rather than setattr: setattr raises TypeError
            # for keys that are not strings (e.g. integer YAML keys).
            config[k] = convert_types(v)
        elif isinstance(v, str) and v in ("None", "none"):
            config[k] = None
    return config
|
|
|
def setup_config(config_path):
    """Read the YAML file at ``config_path`` with OmegaConf and normalize it.

    The loaded config is passed through ``convert_types`` before being
    returned.
    """
    return convert_types(OmegaConf.load(config_path))