File size: 4,851 Bytes
aea73e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
# Base CLI to parse Arguments
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import argparse
import logging
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import yaml
from pydantic import BaseModel


class ABCParser(ABC):
    """Blueprint for Argument Parser.

    Concrete parsers must implement ``__init__``, ``get_config`` and
    ``store_config``; because all three are abstract, a subclass that misses
    any of them cannot be instantiated.
    """

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the parser."""
        pass

    @abstractmethod
    def get_config(self) -> Tuple[Union[BaseModel, dict], logging.Logger]:
        """Load configuration and create a logger

        Returns:
            Tuple[Union[BaseModel, dict], logging.Logger]: Configuration (either a
                pydantic model or a plain dict) and Logger
        """
        pass

    @abstractmethod
    def store_config(self) -> None:
        """Store the config file in the logging directory to keep track of the configuration."""
        pass


class ExperimentBaseParser:
    """Configuration Parser for Machine Learning Experiments.

    Combines a YAML configuration file (``--config``) with a few command-line
    options. ``--gpu`` overrides the ``gpu`` entry of the YAML file, while the
    run-mode options ``--sweep``, ``--agent`` and ``--checkpoint`` are mutually
    exclusive and can only be given on the command line: matching keys
    (``run_sweep``, ``agent``, ``checkpoint``) in the YAML file are discarded.

    Attributes:
        parser (argparse.ArgumentParser): The CLI parser built in ``__init__``.
        config (dict): Merged configuration; set by ``parse_arguments``.
    """

    def __init__(self) -> None:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description="Start an experiment with given configuration file.",
        )
        required_named = parser.add_argument_group("required named arguments")
        required_named.add_argument(
            "--config", type=str, help="Path to a config file", required=True
        )
        parser.add_argument("--gpu", type=int, help="Cuda-GPU ID")
        # At most one run mode may be selected per invocation.
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument(
            "--sweep",
            action="store_true",
            help="Starting a sweep. For this the configuration file must be structured according to WandB sweeping. "
            "Compare https://docs.wandb.ai/guides/sweeps and https://community.wandb.ai/t/nested-sweep-configuration/3369/3 "
            "for further information. This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--agent",
            type=str,
            help="Add a new agent to the sweep. "
            "Please pass the sweep ID as argument in the way entity/project/sweep_id, e.g., user1/test_project/v4hwbijh. "
            "The agent configuration can be found in the WandB dashboard for the running sweep in the sweep overview tab "
            "under launch agent. Just paste the entity/project/sweep_id given there. The provided config file must be a sweep config file. "
            "This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--checkpoint",
            type=str,
            help="Path to a PyTorch checkpoint file. "
            "The file is loaded and continued to train with the provided settings. "
            "If this is passed, no sweeps are possible. "
            "This parameter cannot be set in the config file!",
        )

        self.parser = parser

    def parse_arguments(self, args: Optional[Sequence[str]] = None) -> dict:
        """Parse the arguments from CLI and merge them into the loaded yaml config.

        Args:
            args (Optional[Sequence[str]], optional): Argument strings to parse.
                Defaults to None, in which case ``sys.argv[1:]`` is parsed.

        Raises:
            TypeError: If the top level of the YAML file is not a mapping.

        Returns:
            dict: The YAML configuration with CLI options applied. It always
                contains ``run_sweep`` (bool) and ``agent`` (str or None);
                ``checkpoint`` is only present when passed on the CLI.
        """
        opt = self.parser.parse_args(args)
        # YAML is UTF-8 by specification; don't rely on the platform locale.
        with open(opt.config, "r", encoding="utf-8") as config_file:
            yaml_config = yaml.safe_load(config_file)
        # An empty file loads as None; treat it as an empty configuration.
        if yaml_config is None:
            yaml_config = {}
        if not isinstance(yaml_config, dict):
            raise TypeError(
                f"Config file {opt.config} must contain a YAML mapping at the top level, "
                f"got {type(yaml_config).__name__}"
            )
        yaml_config_dict = dict(yaml_config)

        opt_dict = vars(opt)
        # A GPU id given on the CLI takes precedence over the config file.
        if opt_dict.get("gpu") is not None:
            yaml_config_dict["gpu"] = opt_dict["gpu"]

        # Run-mode options may only come from the CLI: drop them from the file.
        for key in ("run_sweep", "agent", "checkpoint"):
            yaml_config_dict.pop(key, None)

        yaml_config_dict["run_sweep"] = opt_dict.get("sweep") is True
        # 'agent' is always stored (None when not given) so callers can index it.
        if "agent" in opt_dict:
            yaml_config_dict["agent"] = opt_dict["agent"]
        # 'checkpoint' is only stored when actually provided.
        if opt_dict.get("checkpoint") is not None:
            yaml_config_dict["checkpoint"] = opt_dict["checkpoint"]

        self.config = yaml_config_dict

        return self.config