chrisc36 committed
Commit 996e3b2
1 Parent(s): 8b8f7c1

Delete util.py

Files changed (1):
  1. util.py +0 -785
util.py DELETED
@@ -1,785 +0,0 @@
- import io
- import logging
- import os
- import re
- import socket
- import sys
- import time
- import warnings
- from datetime import datetime
- from enum import Enum
- from itertools import cycle, islice
- from pathlib import Path
- from queue import Queue
- from threading import Thread
- from typing import Any, Callable, Dict, Optional, Tuple, Union
- 
- import boto3
- import botocore.exceptions as boto_exceptions
- import rich
- from botocore.config import Config
- from cached_path.schemes import SchemeClient, add_scheme_client
- from rich.console import Console, ConsoleRenderable
- from rich.highlighter import NullHighlighter
- from rich.progress import Progress
- from rich.text import Text
- from rich.traceback import Traceback
- 
- from .aliases import PathOrStr
- from .exceptions import (
-     OLMoCliError,
-     OLMoEnvironmentError,
-     OLMoError,
-     OLMoNetworkError,
-     OLMoThreadError,
- )
- from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
- 
- try:
-     from functools import cache
- except ImportError:
-     from functools import lru_cache as cache
- 
- 
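- # Members of a StrEnum compare equal to their string values; e.g. a hypothetical
- # subclass `class Color(StrEnum): red = "red"` satisfies `Color.red == "red"`.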
- class StrEnum(str, Enum):
-     """
-     This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
-     We include this here for compatibility with older versions of Python.
-     """
- 
-     def __str__(self) -> str:
-         return self.value
- 
-     def __repr__(self) -> str:
-         return f"'{str(self)}'"
- 
- 
- _log_extra_fields: Dict[str, Any] = {}
- log = logging.getLogger(__name__)
- 
- 
- class LogFilterType(StrEnum):
-     rank0_only = "rank0_only"
-     local_rank0_only = "local_rank0_only"
-     all_ranks = "all_ranks"
- 
- 
- def log_extra_field(field_name: str, field_value: Any) -> None:
-     global _log_extra_fields
-     if field_value is None:
-         if field_name in _log_extra_fields:
-             del _log_extra_fields[field_name]
-     else:
-         _log_extra_fields[field_name] = field_value
- 
- 
- def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
-     """
-     :param log_filter_type: determines which ranks emit INFO-and-below messages.
-         Messages above INFO are always emitted on every rank.
-     """
-     log_extra_field("hostname", socket.gethostname())
-     if is_distributed():
-         log_extra_field("node_rank", get_node_rank())
-         log_extra_field("local_rank", get_local_rank())
-         log_extra_field("global_rank", get_global_rank())
-     else:
-         log_extra_field("node_rank", 0)
-         log_extra_field("local_rank", 0)
-         log_extra_field("global_rank", 0)
- 
-     old_log_record_factory = logging.getLogRecordFactory()
- 
-     def log_record_factory(*args, **kwargs) -> logging.LogRecord:
-         record = old_log_record_factory(*args, **kwargs)
-         for field_name, field_value in _log_extra_fields.items():
-             setattr(record, field_name, field_value)
-         return record
- 
-     logging.setLogRecordFactory(log_record_factory)
- 
-     handler: logging.Handler
-     if (
-         os.environ.get("OLMo_NONINTERACTIVE", False)
-         or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
-         or not sys.stdout.isatty()
-     ):
-         handler = logging.StreamHandler(sys.stdout)
-         formatter = logging.Formatter(
-             "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
-         )
-         formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
-         formatter.default_msec_format = "%s.%03d"
-         handler.setFormatter(formatter)
-     else:
-         handler = RichHandler()
- 
-     def rank0_filter(record: logging.LogRecord) -> int:
-         if record.levelno > logging.INFO:
-             return 1
-         if getattr(record, "global_rank", 0) == 0:
-             return 1
-         else:
-             return 0
- 
-     def local_rank0_filter(record: logging.LogRecord) -> int:
-         if record.levelno > logging.INFO:
-             return 1
-         if getattr(record, "local_rank", 0) == 0:
-             return 1
-         else:
-             return 0
- 
-     if log_filter_type == LogFilterType.rank0_only:
-         filter = rank0_filter
-     elif log_filter_type == LogFilterType.local_rank0_only:
-         filter = local_rank0_filter  # type: ignore
-     elif log_filter_type == LogFilterType.all_ranks:
-         filter = None
-     else:
-         raise ValueError(log_filter_type)
- 
-     if filter is not None:
-         handler.addFilter(filter)  # type: ignore
-     logging.basicConfig(handlers=[handler], level=logging.INFO)
- 
-     logging.captureWarnings(True)
-     logging.getLogger("urllib3").setLevel(logging.ERROR)
- 
- 
- def excepthook(exctype, value, traceback):
-     """
-     Used to patch `sys.excepthook` in order to log exceptions.
-     """
-     if issubclass(exctype, KeyboardInterrupt):
-         sys.__excepthook__(exctype, value, traceback)
-     elif issubclass(exctype, OLMoCliError):
-         rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
-     elif issubclass(exctype, OLMoError):
-         rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
-     else:
-         log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))
- 
- 
- def install_excepthook():
-     sys.excepthook = excepthook
- 
- 
- def filter_warnings():
-     # Filter internal deprecation warnings from torch
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="torch.distributed.*_base is a private function and will be deprecated.*",
-     )
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="TypedStorage is deprecated.*",
-     )
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="Please use DTensor instead.*",
-     )
-     # Torchvision warnings. We don't actually use torchvision.
-     warnings.filterwarnings(
-         action="ignore",
-         message="failed to load.*",
-         module="torchvision.io.image",
-     )
- 
- 
- def set_env_variables():
-     os.environ["TOKENIZERS_PARALLELISM"] = "false"
- 
- 
- def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
-     if log_filter_type is None:
-         log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
-     rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
-     setup_logging(log_filter_type=log_filter_type)
-     install_excepthook()
-     filter_warnings()
-     set_env_variables()
- 
- 
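- # Normalizes a CLI override; with illustrative inputs, "--model.n-layers=4"
- # becomes "model.n_layers=4", and a bare flag "--dry-run" becomes "dry_run=True".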
- def clean_opt(arg: str) -> str:
-     if "=" not in arg:
-         arg = f"{arg}=True"
-     name, val = arg.split("=", 1)
-     name = name.strip("-").replace("-", "_")
-     return f"{name}={val}"
- 
- 
- class RichHandler(logging.Handler):
-     """
-     A simplified version of rich.logging.RichHandler from
-     https://github.com/Textualize/rich/blob/master/rich/logging.py
-     """
- 
-     def __init__(
-         self,
-         *,
-         level: Union[int, str] = logging.NOTSET,
-         console: Optional[Console] = None,
-         markup: bool = False,
-     ) -> None:
-         super().__init__(level=level)
-         self.console = console or rich.get_console()
-         self.highlighter = NullHighlighter()
-         self.markup = markup
- 
-     def emit(self, record: logging.LogRecord) -> None:
-         try:
-             if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
-                 self.console.print(record.msg)
-             else:
-                 msg: Any = record.msg
-                 if isinstance(record.msg, str):
-                     msg = self.render_message(record=record, message=record.getMessage())
-                 renderables = [
-                     self.get_time_text(record),
-                     self.get_level_text(record),
-                     self.get_location_text(record),
-                     msg,
-                 ]
-                 if record.exc_info is not None:
-                     tb = Traceback.from_exception(*record.exc_info)  # type: ignore
-                     renderables.append(tb)
-                 self.console.print(*renderables)
-         except Exception:
-             self.handleError(record)
- 
-     def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
-         use_markup = getattr(record, "markup", self.markup)
-         message_text = Text.from_markup(message) if use_markup else Text(message)
- 
-         highlighter = getattr(record, "highlighter", self.highlighter)
-         if highlighter:
-             message_text = highlighter(message_text)
- 
-         return message_text
- 
-     def get_time_text(self, record: logging.LogRecord) -> Text:
-         log_time = datetime.fromtimestamp(record.created)
-         time_str = log_time.strftime("[%Y-%m-%d %X]")
-         return Text(time_str, style="log.time", end=" ")
- 
-     def get_level_text(self, record: logging.LogRecord) -> Text:
-         level_name = record.levelname
-         level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
-         level_text.style = "log.level"
-         level_text.end = " "
-         return level_text
- 
-     def get_location_text(self, record: logging.LogRecord) -> Text:
-         name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
-         text = f"[{name_and_line}, rank={record.local_rank}]"  # type: ignore
-         return Text(text, style="log.path")
- 
- 
- def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
-     """Wait for the condition function to return True."""
-     start_time = time.monotonic()
-     while not condition():
-         time.sleep(0.5)
-         if time.monotonic() - start_time > timeout:
-             raise TimeoutError(f"{description} timed out")
- 
- 
- def is_url(path: PathOrStr) -> bool:
-     return re.match(r"[a-z0-9]+://.*", str(path)) is not None
- 
- 
- def dir_is_empty(dir: PathOrStr) -> bool:
-     dir = Path(dir)
-     if not dir.is_dir():
-         return True
-     try:
-         next(dir.glob("*"))
-         return False
-     except StopIteration:
-         return True
- 
- 
- def get_progress_bar() -> Progress:
-     from cached_path import get_download_progress
- 
-     return get_download_progress()
- 
- 
- def resource_path(
-     folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
- ) -> Path:
-     if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
-         log.info(f"Found local cache of {fname} at {local_path}")
-         return local_path
-     else:
-         from cached_path import cached_path
- 
-         return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)
- 
- 
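- # Handles local paths and URLs alike; e.g. (hypothetical paths) both
- # file_size("s3://bucket/data.npy") and file_size("/tmp/data.npy") return a byte count.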
- def file_size(path: PathOrStr) -> int:
-     """
-     Get the size of a local or remote file in bytes.
-     """
-     if is_url(path):
-         from urllib.parse import urlparse
- 
-         parsed = urlparse(str(path))
-         if parsed.scheme == "gs":
-             return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme in ("s3", "r2", "weka"):
-             return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme in ("http", "https"):
-             return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme == "file":
-             return file_size(str(path).replace("file://", "", 1))
-         else:
-             raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
-     else:
-         return os.stat(path).st_size
- 
- 
- def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
-     """Upload source file to a target location on GCS, S3, R2, or Weka."""
-     from urllib.parse import urlparse
- 
-     source = Path(source)
-     assert source.is_file()
-     parsed = urlparse(target)
-     if parsed.scheme == "gs":
-         _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
-     elif parsed.scheme in ("s3", "r2", "weka"):
-         _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
-     else:
-         raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
- 
- 
- def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
-     if is_url(source):
-         from urllib.parse import urlparse
- 
-         parsed = urlparse(str(source))
-         if parsed.scheme == "gs":
-             return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
-         elif parsed.scheme in ("s3", "r2", "weka"):
-             return _s3_get_bytes_range(
-                 parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
-             )
-         elif parsed.scheme in ("http", "https"):
-             return _http_get_bytes_range(
-                 parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
-             )
-         elif parsed.scheme == "file":
-             return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
-         else:
-             raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files")
-     else:
-         with open(source, "rb") as f:
-             f.seek(bytes_start)
-             return f.read(num_bytes)
- 
- 
- def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
-     if is_url(dir):
-         from urllib.parse import urlparse
- 
-         parsed = urlparse(str(dir))
-         if parsed.scheme == "gs":
-             raise NotImplementedError
-         elif parsed.scheme in ("s3", "r2", "weka"):
-             return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme == "file":
-             return find_latest_checkpoint(str(dir).replace("file://", "", 1))
-         else:
-             raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
-     else:
-         latest_step = 0
-         latest_checkpoint: Optional[Path] = None
-         for path in Path(dir).glob("step*"):
-             if path.is_dir():
-                 try:
-                     step = int(path.name.replace("step", "").replace("-unsharded", ""))
-                 except ValueError:
-                     continue
-                 # We prioritize sharded checkpoints over unsharded checkpoints.
-                 if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
-                     latest_step = step
-                     latest_checkpoint = path
-         return latest_checkpoint
- 
- 
- def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
-     from google.cloud import storage as gcs
- 
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     if not save_overwrite and blob.exists():
-         raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
-     blob.upload_from_filename(source)
- 
- 
- def _gcs_file_size(bucket_name: str, key: str) -> int:
-     from google.api_core.exceptions import NotFound
-     from google.cloud import storage as gcs
- 
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     try:
-         blob.reload()
-     except NotFound:
-         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
-     assert blob.size is not None
-     return blob.size
- 
- 
- def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
-     from google.api_core.exceptions import NotFound
-     from google.cloud import storage as gcs
- 
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     try:
-         blob.reload()
-     except NotFound:
-         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
-     return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
- 
- 
- def _get_s3_profile_name(scheme: str) -> Optional[str]:
-     if scheme == "s3":
-         # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
-         return os.environ.get("S3_PROFILE")
-     if scheme == "r2":
-         profile_name = os.environ.get("R2_PROFILE")
-         if profile_name is None:
-             raise OLMoEnvironmentError(
-                 "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
-             )
- 
-         return profile_name
-     if scheme == "weka":
-         profile_name = os.environ.get("WEKA_PROFILE")
-         if profile_name is None:
-             raise OLMoEnvironmentError(
-                 "Weka profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
-             )
- 
-         return profile_name
- 
-     raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
- 
- 
- def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
-     if scheme == "s3":
-         return None
-     if scheme == "r2":
-         r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
-         if r2_endpoint_url is None:
-             raise OLMoEnvironmentError(
-                 "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
-             )
- 
-         return r2_endpoint_url
-     if scheme == "weka":
-         weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
-         if weka_endpoint_url is None:
-             raise OLMoEnvironmentError(
-                 "Weka endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
-             )
- 
-         return weka_endpoint_url
- 
-     raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
- 
- 
- @cache
- def _get_s3_client(scheme: str):
-     session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
-     return session.client(
-         "s3",
-         endpoint_url=_get_s3_endpoint_url(scheme),
-         config=Config(retries={"max_attempts": 10, "mode": "standard"}),
-         use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
-     )
- 
- 
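- # Exponential backoff between retries: 1s after attempt 1, 2s after attempt 2,
- # capped at 3s for every attempt after that.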
- def _wait_before_retry(attempt: int):
-     time.sleep(min(0.5 * 2**attempt, 3.0))
- 
- 
- def _s3_upload(
-     source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
- ):
-     err: Optional[Exception] = None
-     if not save_overwrite:
-         for attempt in range(1, max_attempts + 1):
-             try:
-                 _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
-                 raise FileExistsError(
-                     f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
-                 )
-             except boto_exceptions.ClientError as e:
-                 if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                     err = None
-                     break
-                 err = e
- 
-             if attempt < max_attempts:
-                 log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
-                 _wait_before_retry(attempt)
- 
-         if err is not None:
-             raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err
- 
-     try:
-         _get_s3_client(scheme).upload_file(source, bucket_name, key)
-     except boto_exceptions.ClientError as e:
-         raise OLMoNetworkError(f"Failed to upload to {scheme}") from e
- 
- 
- def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
-     err: Optional[Exception] = None
-     for attempt in range(1, max_attempts + 1):
-         try:
-             return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
-         except boto_exceptions.ClientError as e:
-             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                 raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
-             err = e
- 
-         if attempt < max_attempts:
-             log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
-             _wait_before_retry(attempt)
- 
-     raise OLMoNetworkError(f"Failed to get {scheme} file size") from err
- 
- 
- def _s3_get_bytes_range(
-     scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
- ) -> bytes:
-     err: Optional[Exception] = None
-     for attempt in range(1, max_attempts + 1):
-         try:
-             return (
-                 _get_s3_client(scheme)
-                 .get_object(
-                     Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
-                 )["Body"]
-                 .read()
-             )
-         except boto_exceptions.ClientError as e:
-             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                 raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
-             err = e
-         except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
-             # ResponseStreamingError (subclass of HTTPClientError) can happen as
-             # a result of a failed read from the stream (http.client.IncompleteRead).
-             # Retrying can help in this case.
-             err = e
- 
-         if attempt < max_attempts:
-             log.warning(
-                 "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
-             )
-             _wait_before_retry(attempt)
- 
-     # When torch's DataLoader intercepts exceptions, it may try to re-raise them
-     # by recalling their constructor with a single message arg. Torch has some
-     # logic to deal with the absence of a single-parameter constructor, but it
-     # doesn't gracefully handle other possible failures in calling such a constructor.
-     # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
-     # in us losing the true exception info. To avoid this, we change the exception
-     # to a type that has a single-parameter constructor.
-     raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
- 
- 
- def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
-     if not prefix.endswith("/"):
-         prefix = f"{prefix}/"
-     response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
-     assert not response["IsTruncated"]  # need to handle this if it happens
-     latest_step = 0
-     latest_checkpoint: Optional[str] = None
-     for item in response["CommonPrefixes"]:
-         prefix = item["Prefix"].strip("/")
-         checkpoint_name = os.path.split(prefix)[-1]
-         if not checkpoint_name.startswith("step"):
-             continue
-         try:
-             step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
-         except ValueError:
-             continue
-         # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
-         # (upload might have failed part way through).
-         try:
-             _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
-         except FileNotFoundError:
-             continue
-         # We prioritize sharded checkpoints over unsharded ones.
-         if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
-             latest_step = step
-             latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
-     return latest_checkpoint
- 
- 
- def _http_file_size(scheme: str, host_name: str, path: str) -> int:
-     import requests
- 
-     response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True)
-     return int(response.headers.get("content-length"))
- 
- 
- def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes:
-     import requests
- 
-     response = requests.get(
-         f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"}
-     )
-     result = response.content
-     assert (
-         len(result) == num_bytes
-     ), f"expected {num_bytes} bytes, got {len(result)}"  # Some web servers silently ignore range requests and send everything
-     return result
- 
- 
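- # Note: min(32, cpu_count + 4) mirrors the default max_workers heuristic used by
- # concurrent.futures.ThreadPoolExecutor in the standard library.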
- def default_thread_count() -> int:
-     return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
- 
- 
- def pass_through_fn(fn, *args, **kwargs):
-     return fn(*args, **kwargs)
- 
- 
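- # Runs generator g on a daemon thread, buffering up to maxsize items; exceptions
- # raised in g resurface in the consumer as OLMoThreadError. Usage sketch, with
- # read_batches/process as hypothetical helpers:
- #   for batch in threaded_generator(read_batches(path), maxsize=8):
- #       process(batch)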
- def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
-     q: Queue = Queue(maxsize=maxsize)
- 
-     sentinel = object()
- 
-     def fill_queue():
-         try:
-             for value in g:
-                 q.put(value)
-         except Exception as e:
-             q.put(e)
-         finally:
-             q.put(sentinel)
- 
-     thread_name = thread_name or repr(g)
-     thread = Thread(name=thread_name, target=fill_queue, daemon=True)
-     thread.start()
- 
-     for x in iter(q.get, sentinel):
-         if isinstance(x, Exception):
-             raise OLMoThreadError(f"generator thread {thread_name} failed") from x
-         else:
-             yield x
- 
- 
- def split_dict_of_list(batch, split_size):
-     out = None
-     for key, val in batch.items():
-         parts = split_list(val, split_size)
-         if out is None:
-             out = [{key: part} for part in parts]
-         else:
-             assert len(out) == len(parts)
-             for out_dict, part in zip(out, parts):
-                 out_dict[key] = part
-     return out
- 
- 
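- # For example: split_list([1, 2, 3, 4], 2) -> [[1, 2], [3, 4]], and
- # split_dict_of_list({"a": [1, 2, 3, 4]}, 2) -> [{"a": [1, 2]}, {"a": [3, 4]}].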
- def split_list(lst, split_size):
-     assert len(lst) % split_size == 0
-     n = len(lst) // split_size
-     return [lst[i*split_size:(i+1)*split_size] for i in range(n)]
- 
- 
- def flatten_list(lst):
-     return [x for xs in lst for x in xs]
- 
- 
- def roundrobin(*iterables):
-     """
-     Call the given iterables in a round-robin fashion. For example:
-     ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
-     """
-     # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
-     num_active = len(iterables)
-     nexts = cycle(iter(it).__next__ for it in iterables)
-     while num_active:
-         try:
-             for next in nexts:
-                 yield next()
-         except StopIteration:
-             # Remove the iterator we just exhausted from the cycle.
-             num_active -= 1
-             nexts = cycle(islice(nexts, num_active))
- 
- 
- def add_cached_path_clients():
-     add_scheme_client(WekaClient)
- 
- 
- class WekaClient(SchemeClient):
-     recoverable_errors = SchemeClient.recoverable_errors + (
-         boto_exceptions.HTTPClientError,
-         boto_exceptions.ConnectionError,
-     )
- 
-     scheme = "weka"
- 
-     def __init__(self, resource: str) -> None:
-         SchemeClient.__init__(self, resource)
-         self.bucket_name, self.path = WekaClient._split_cloud_path(resource, "weka")
-         self.s3 = _get_s3_client("weka")
-         self.object_info = None
- 
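-     # For example, an illustrative path "weka://my-bucket/dir/file.npy" splits
-     # into ("my-bucket", "dir/file.npy").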
-     @staticmethod
-     def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
-         """Split a full s3 path into the bucket name and path."""
-         from urllib.parse import urlparse
- 
-         parsed = urlparse(url)
-         if not parsed.netloc or not parsed.path:
-             raise ValueError("bad {} path {}".format(provider, url))
-         bucket_name = parsed.netloc
-         provider_path = parsed.path
-         # Remove '/' at beginning of path.
-         if provider_path.startswith("/"):
-             provider_path = provider_path[1:]
-         return bucket_name, provider_path
- 
-     def _ensure_object_info(self):
-         if self.object_info is None:
-             try:
-                 self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
-             except boto_exceptions.ClientError as e:
-                 if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                     raise FileNotFoundError(f"weka://{self.bucket_name}/{self.path}") from e
-                 raise e
- 
-     def get_etag(self) -> Optional[str]:
-         self._ensure_object_info()
-         assert self.object_info is not None
-         return self.object_info.get("ETag")
- 
-     def get_size(self) -> Optional[int]:
-         self._ensure_object_info()
-         assert self.object_info is not None
-         return self.object_info.get("ContentLength")
- 
-     def get_resource(self, temp_file: io.BufferedWriter) -> None:
-         self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)
- 
-     def get_bytes_range(self, index: int, length: int) -> bytes:
-         response = self.s3.get_object(
-             Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
-         )
-         return response["Body"].read()
- 