"""Gradio utilities. | |
Note that the optional `progress` parameter can be both a `tqdm` module or a | |
`gr.Progress` instance. | |
""" | |
import concurrent.futures | |
import contextlib | |
import glob | |
import hashlib | |
import logging | |
import os | |
import tempfile | |
import time | |
import urllib.request | |
import jax | |
import numpy as np | |
from tensorflow.io import gfile | |


@contextlib.contextmanager
def timed(name):
  """Context manager that logs, and yields via a dict, its block's duration."""
  t0 = time.monotonic()
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])
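
# Example usage (illustrative sketch, not part of the original module; the
# loader call is a hypothetical stand-in). The `timing` dict is filled in only
# once the block exits, so it is read afterwards:
#
#   with timed('loading checkpoint') as timing:
#     params = load_checkpoint()
#   print(f'Loading took {timing["secs"]:.1f} secs')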


def copy_file(
    src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
):
  """Copies a file with a progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If `True`, overwrite `dst` if it exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    # A HEAD request is enough to learn the size without downloading the body.
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    logging.debug('content_length=%s', content_length)
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  # Write to a temporary name first so that a partially-copied file is never
  # mistaken for a complete one.
  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  gfile.rename(f'{dst}-PARTIAL', dst, overwrite=overwrite)
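
# Example usage (illustrative sketch; the URL is a placeholder, not from the
# original module). The `tqdm` module itself has a `.tqdm` attribute, so it
# satisfies the `progress` interface:
#
#   import tqdm
#   copy_file('https://example.com/weights.npz', '/tmp/weights.npz',
#             progress=tqdm)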


# Rolling history of (estimated, actual) load durations, seeded with a neutral
# prior; used to refine future progress-bar time estimates.
_estimated_real = [(10, 10)]
# Maps cache key to loaded value; insertion order doubles as LRU order.
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    # Advance the bar in `step`-sized ticks over the estimated `secs`; if the
    # getter finishes early, the remaining ticks complete without sleeping.
    for _ in progress.tqdm(list(range(int(np.ceil(secs / step)))), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()
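
# Example usage (illustrative sketch; `slow_load` is a hypothetical callable,
# and `gr.Progress()` assumes `import gradio as gr` in the caller):
#
#   value = get_with_progress(
#       getter=slow_load, secs=5.0, progress=gr.Progress())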


def _get_array_sizes(tree):
  """Returns the size in bytes of every array leaf in `tree`."""
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Keeps cache below specified size by evicting least-recently-used items."""
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # Update "last accessed" order.
    return _memory_cache[key]

  # Refine the time estimate by the observed ratio of actual to estimated
  # durations from previous loads.
  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  # Evict least-recently-used entries until the cache fits the size budget.
  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz / 1e6)
  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s / 1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]
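
# Example usage (illustrative sketch; `load_params` and the checkpoint path are
# hypothetical stand-ins, and `gr.Progress()` assumes `import gradio as gr` in
# the caller):
#
#   params = get_memory_cache(
#       key='vit-b16',
#       getter=lambda: load_params('gs://bucket/vit-b16.npz'),
#       max_cache_size_bytes=8 * 10**9,
#       progress=gr.Progress(),
#   )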


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by evicting least-recently-used files."""
  fname = os.path.basename(path_or_url)
  # Hash the full path/URL so distinct sources sharing a basename don't collide.
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  # Evict least-recently-accessed files until the cache fits the size budget.
  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(s for _, s, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz / 1e6)
  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s / 1e6)
    os.unlink(path)
    sz -= s

  return dst
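
# Example usage (illustrative sketch; the checkpoint URL is a placeholder, not
# from the original module):
#
#   local_path = get_disk_cache(
#       'https://example.com/checkpoints/model.npz',
#       max_cache_size_bytes=20 * 10**9,
#   )
#   # `local_path` points at a file under CACHE_DIR that later calls reuse.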


def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)