File size: 6,284 Bytes
18652d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import dataclasses
import hashlib
import sys
import typing
import warnings
import socket
from typing import Optional, Any, Dict
import os
import logging
import absl.flags
from flax.traverse_util import flatten_dict

from ml_collections import ConfigDict, config_flags
from ml_collections.config_dict import placeholder
from mlxu import function_args_to_config

# Module-level dict of extra fields, initially empty. Nothing in this chunk
# reads or writes it — presumably extra key/values attached to log output by
# code elsewhere in the file or project (TODO confirm).
_log_extra_fields: Dict[str, Any] = {}


def is_float_printable(x):
    """Return True if `x` accepts the ``0.2f`` format spec, else False.

    Only ValueError and TypeError raised by formatting are treated as
    "not printable"; any other exception propagates.
    """
    try:
        format(x, "0.2f")
    except (ValueError, TypeError):
        return False
    return True


def compute_hash(string: str) -> str:
    """Return the hex-encoded SHA-256 digest of `string`, encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(string.encode("utf-8"))
    return hasher.hexdigest()


def pop_metadata(data):
    """Move every key starting with "metadata" out of `data` into a new dict.

    `data` is modified in place and returned alongside the extracted entries,
    which keep their original order.
    """
    extracted = {}
    # Snapshot the keys first: popping while iterating the live dict is illegal.
    for key in list(data):
        if key.startswith("metadata"):
            extracted[key] = data.pop(key)
    return data, extracted


def setup_logging():
    """Configure INFO-level logging to stdout and quiet noisy libraries.

    Installs a stdout handler with a compact
    ``[<level letter> HH:MM:SS file:line] message`` format via
    ``logging.basicConfig`` (a no-op if the root logger already has handlers),
    routes ``warnings`` through logging, and raises the ``urllib3`` logger's
    threshold to ERROR.
    """
    stdout_handler: logging.Handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    logging.basicConfig(level=logging.INFO, handlers=[stdout_handler])

    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)


def get_maybe_optional_type(field_type):
    """Unwrap ``Optional[T]`` to ``T``; return any other type unchanged.

    Raises AssertionError if the union contains None plus more than one other
    member (e.g. ``Optional[Union[A, B]]``).
    """
    none_type = type(None)
    type_args = typing.get_args(field_type)
    if none_type not in type_args:
        return field_type
    remaining = [arg for arg in type_args if arg != none_type]
    assert len(remaining) == 1
    return remaining[0]


def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
    """Build a `ConfigDict` matching the possibly nested dataclass

    dataclass: A dataclass instance or a dataclass type. For an instance, defaults
               are the instance's current attribute values; for a type, defaults
               are the field defaults, or None if the field is required (only
               `field.default` is consulted, so `default_factory` fields likely
               end up with no default).
    defaults_to_none: Make all defaults None

    Fields with ``init=False`` are skipped. A nested dataclass field with no
    default is omitted unless `defaults_to_none` is set. Any other field whose
    default is None becomes a typed `placeholder`.
    """
    out = {}
    for field in dataclasses.fields(dataclass):
        if not field.init:
            continue

        if defaults_to_none:
            default = None
        elif hasattr(dataclass, field.name):
            default = getattr(dataclass, field.name)
        elif field.default is dataclasses.MISSING:
            default = None
        else:
            default = field.default

        # Strip `Optional[...]` so nested dataclasses / placeholders see the inner type.
        field_type = get_maybe_optional_type(field.type)

        if hasattr(field_type, "__dataclass_fields__"):
            if not defaults_to_none and default is None:
                # Nested dataclass with no default: leave it out of the config.
                continue
            # Recurse on the default instance when there is one, otherwise on the
            # unwrapped type. `field.type` may be `Optional[...]`, which
            # `dataclasses.fields` cannot inspect.
            nested = field_type if default is None else default
            out[field.name] = config_from_dataclass(
                nested, defaults_to_none=defaults_to_none)
        elif default is None:
            assert not field_type == typing.Any, (
                f"Cannot build a placeholder for field {field.name!r} of type Any")
            # Use the runtime class for generics, e.g. `list` for `List[int]`.
            origin = getattr(field_type, "__origin__", None)
            if origin is not None:
                field_type = origin
            out[field.name] = placeholder(field_type)
        else:
            out[field.name] = default
    return ConfigDict(out)


def dataclass_with_none(cls):
    """Instantiate dataclass `cls` with every init field set to None.

    Fields whose annotated type is itself a dataclass are filled with an
    all-None instance of that type, built recursively. Fields declared with
    ``init=False`` are not passed, so `cls` initialises them itself.
    """
    kwargs = {
        f.name: dataclass_with_none(f.type) if dataclasses.is_dataclass(f.type) else None
        for f in dataclasses.fields(cls)
        if f.init
    }
    return cls(**kwargs)


def dataclass_from_config(cls, config: Dict):
    """Build an instance of `cls` with attributes from `config`

    cls: dataclass type to instantiate.
    config: dict or `ConfigDict` keyed by field name.
      - Nested dataclass fields are built recursively, or via the field type's
        `from_dict` when it defines one.
      - A missing or None entry for a nested dataclass field becomes None.
      - Missing non-dataclass fields are not passed, so `cls` defaults apply.
      - `ConfigDict` values for non-dataclass fields are converted with `to_dict()`.

    Raises ValueError if `config` has a key that is not a field name of `cls`.
    """
    fields = dataclasses.fields(cls)
    field_names = {f.name for f in fields}
    for key in config.keys():
        if key not in field_names:
            raise ValueError(f"Config has unknown arg {key} for {cls}")

    kwargs = {}
    for field in fields:
        if not field.init:
            continue

        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if config.get(field.name) is None:
                kwargs[field.name] = None
            elif hasattr(field_type, "from_dict"):
                src = config[field.name]
                if isinstance(src, ConfigDict):
                    src = src.to_dict()
                kwargs[field.name] = field_type.from_dict(src)
            else:
                kwargs[field.name] = dataclass_from_config(field_type, config[field.name])
        elif field.name in config:
            value = config[field.name]
            if isinstance(value, ConfigDict):
                value = value.to_dict()
            kwargs[field.name] = value
    return cls(**kwargs)


def update_dataclass(obj, updates):
    """Sets attributes in `obj` to match non-None fields in `updates`

    obj: dataclass instance, modified in place. Fields holding dataclass
         instances are updated recursively.
    updates: mapping (dict or `ConfigDict`) from field name to new value;
             None values are ignored.

    A nested mapping aimed at a field whose current value is not a dataclass
    (e.g. an unset Optional sub-config) is accepted only if every leaf in it
    is None. Otherwise an AssertionError is raised.
    """
    for field in dataclasses.fields(obj):
        if not field.init:
            continue
        update = updates.get(field.name)
        if update is None:
            continue
        current_value = getattr(obj, field.name)
        if dataclasses.is_dataclass(current_value):
            update_dataclass(current_value, update)
            continue
        if isinstance(update, ConfigDict):
            # flax's `flatten_dict` expects a plain (frozen) dict, and a
            # `ConfigDict` is not one, so convert it first.
            update = update.to_dict()
        if isinstance(update, dict):
            assert all(x is None for x in flatten_dict(update).values()), (
                f"Cannot apply non-None nested updates to field {field.name!r}, "
                f"whose current value is not a dataclass")
        else:
            setattr(obj, field.name, update)


def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
    # Stolen from the OLMo codebase
    def format_value(value: float) -> str:
        if isinstance(value, str):
            return value
        if value < 0.0001:
            return str(value)  # scientific notation
        elif value > 1000:
            return f"{int(value):,d}"
        elif value > 100:
            return f"{value:.1f}"
        elif value > 10:
            return f"{value:.2f}"
        elif value > 1:
            return f"{value:.3f}"
        else:
            return f"{value:.4f}"

    logging.info(
        f"{prefix}\n"
        + "\n".join(
            [
                f"    {name}={format_value(value)}"
                for name, value in metrics.items()
                if not name.startswith("optim/")  # there's too many optimizer metrics
            ]
        )
    )