chrisc36 commited on
Commit
8d085c7
1 Parent(s): 996e3b2

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -195
utils.py DELETED
@@ -1,195 +0,0 @@
1
- import dataclasses
2
- import hashlib
3
- import sys
4
- import typing
5
- import warnings
6
- import socket
7
- from typing import Optional, Any, Dict
8
- import os
9
- import logging
10
- import absl.flags
11
- from flax.traverse_util import flatten_dict
12
-
13
- from ml_collections import ConfigDict, config_flags
14
- from ml_collections.config_dict import placeholder
15
- from mlxu import function_args_to_config
16
-
17
- _log_extra_fields: Dict[str, Any] = {}
18
-
19
-
20
- def is_float_printable(x):
21
- try:
22
- f"{x:0.2f}"
23
- return True
24
- except (ValueError, TypeError):
25
- return False
26
-
27
-
28
- def compute_hash(string: str) -> str:
29
- """Computes the hash of a string."""
30
- return hashlib.sha256(string.encode("utf-8")).hexdigest()
31
-
32
-
33
- def pop_metadata(data):
34
- meta = {k: data.pop(k) for k in list(data) if k.startswith("metadata")}
35
- return data, meta
36
-
37
-
38
- def setup_logging():
39
- handler: logging.Handler
40
- handler = logging.StreamHandler(sys.stdout)
41
- formatter = logging.Formatter(
42
- "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
43
- datefmt="%H:%M:%S"
44
- )
45
- handler.setFormatter(formatter)
46
- logging.basicConfig(handlers=[handler], level=logging.INFO)
47
-
48
- logging.captureWarnings(True)
49
- logging.getLogger("urllib3").setLevel(logging.ERROR)
50
-
51
-
52
- def get_maybe_optional_type(field_type):
53
- if type(None) in typing.get_args(field_type):
54
- # Handle optional type
55
- args = [x for x in typing.get_args(field_type) if x != type(None)]
56
- assert len(args) == 1
57
- field_type = args[0]
58
- return field_type
59
-
60
-
61
- def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
62
- """Build a `ConfigDict` matching the possibly nested dataclass
63
-
64
- dataclass: A dataclass instance or a dataclass type, if an instance defaults
65
- will be set to the values in the class, if a class defaults will be
66
- set to the field defaults, or None if the field is required
67
- defaults_to_none: Make all defaults None
68
- """
69
- out = {}
70
- fields = dataclasses.fields(dataclass)
71
- for field in fields:
72
- if not field.init:
73
- continue
74
-
75
- if defaults_to_none:
76
- default = None
77
- elif hasattr(dataclass, field.name):
78
- default = getattr(dataclass, field.name)
79
- elif field.default is dataclasses.MISSING:
80
- default = None
81
- else:
82
- default = field.default
83
-
84
- field_type = get_maybe_optional_type(field.type)
85
-
86
- if hasattr(field_type, "__dataclass_fields__"):
87
- if not defaults_to_none and default is None:
88
- pass
89
- else:
90
- out[field.name] = config_from_dataclass(
91
- default or field.type, defaults_to_none=defaults_to_none)
92
- else:
93
- if default is None:
94
- assert not field_type == typing.Any
95
- origin = getattr(field_type, "__origin__", None)
96
- if origin is not None:
97
- field_type = origin
98
- out[field.name] = placeholder(field_type)
99
- else:
100
- out[field.name] = default
101
- return ConfigDict(out)
102
-
103
-
104
- def dataclass_with_none(cls):
105
- """Build an instance of possibly nested dataclass `cls` with all attributes None"""
106
- fields = dataclasses.fields(cls)
107
- args = {}
108
- for field in fields:
109
- if not field.init:
110
- pass
111
- elif dataclasses.is_dataclass(field.type):
112
- args[field.name] = dataclass_with_none(field.type)
113
- else:
114
- args[field.name] = None
115
- return cls(**args)
116
-
117
-
118
- def dataclass_from_config(cls, config: Dict):
119
- """Build an instance of `cls` with attributes from `config``"""
120
- fields = dataclasses.fields(cls)
121
- args = set(x.name for x in fields)
122
- for k in config.keys():
123
- if k not in args:
124
- raise ValueError(f"Config has unknown arg {k} fr {cls}")
125
- args = {}
126
- for field in fields:
127
- if not field.init:
128
- continue
129
-
130
- field_type = get_maybe_optional_type(field.type)
131
- if hasattr(field_type, "__dataclass_fields__"):
132
- if config.get(field.name) is None:
133
- args[field.name] = None
134
- elif hasattr(field_type, "from_dict"):
135
- src = config[field.name]
136
- if isinstance(src, ConfigDict):
137
- src = src.to_dict()
138
- args[field.name] = field_type.from_dict(src)
139
- else:
140
- args[field.name] = dataclass_from_config(field_type, config[field.name])
141
- elif field.name in config:
142
- if isinstance(config[field.name], ConfigDict):
143
- args[field.name] = config[field.name].to_dict()
144
- else:
145
- args[field.name] = config[field.name]
146
- return cls(**args)
147
-
148
-
149
- def update_dataclass(obj, updates):
150
- """Sets attributes in `obj` to match non-None fields in `updates`"""
151
- fields = dataclasses.fields(obj)
152
- for field in fields:
153
- if not field.init:
154
- continue
155
- update = updates.get(field.name)
156
- if update is None:
157
- continue
158
- current_value = getattr(obj, field.name)
159
- if dataclasses.is_dataclass(current_value):
160
- update_dataclass(current_value, update)
161
- else:
162
- if isinstance(update, (ConfigDict, dict)):
163
- assert all(x is None for x in flatten_dict(update).values())
164
- else:
165
- setattr(obj, field.name, update)
166
-
167
-
168
- def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
169
- # Stolen from the OLMo codebase
170
- def format_value(value: float) -> str:
171
- if isinstance(value, str):
172
- return value
173
- if value < 0.0001:
174
- return str(value) # scientific notation
175
- elif value > 1000:
176
- return f"{int(value):,d}"
177
- elif value > 100:
178
- return f"{value:.1f}"
179
- elif value > 10:
180
- return f"{value:.2f}"
181
- elif value > 1:
182
- return f"{value:.3f}"
183
- else:
184
- return f"{value:.4f}"
185
-
186
- logging.info(
187
- f"{prefix}\n"
188
- + "\n".join(
189
- [
190
- f" {name}={format_value(value)}"
191
- for name, value in metrics.items()
192
- if not name.startswith("optim/") # there's too many optimizer metrics
193
- ]
194
- )
195
- )