# lit-demo-bv / gradio_helpers.py
"""Gradio utilities.

Note that the optional `progress` parameter can be either the `tqdm` module
or a `gr.Progress` instance.
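
Example usage (the key, loader, and paths below are hypothetical):

  params = get_memory_cache(
      'vit-b16', lambda: load_checkpoint('gs://bucket/vit-b16.npz'),
      max_cache_size_bytes=4_000_000_000)
  local_path = get_disk_cache(
      'https://example.org/vit-b16.npz', max_cache_size_bytes=20_000_000_000)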
"""

import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
t0 = time.monotonic()
  timing = dict(secs=None)
try:
yield timing
finally:
timing['secs'] = time.monotonic() - t0
logging.info('Timed %s: %.1f secs', name, timing['secs'])


def copy_file(
src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
):
"""Copies a file with progress bar.
Args:
src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
progress: An object with a `.tqdm` attribute, or `None`.
block_size: Size of individual blocks to be read/written.
overwrite: If `True`, overwrite `dst` if it exists.
"""
if os.path.dirname(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
if os.path.exists(dst) and not overwrite:
return
if src.startswith('http://') or src.startswith('https://'):
opener = urllib.request.urlopen
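    # Issue a HEAD request first to learn the content length, so the number of
    # blocks (and hence the length of the progress bar) is known up front.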
request = urllib.request.Request(src, method='HEAD')
response = urllib.request.urlopen(request)
content_length = response.headers.get('Content-Length')
n = int(np.ceil(int(content_length) / block_size))
    logging.info('content_length=%s', content_length)
else:
opener = lambda path: gfile.GFile(path, 'rb')
stats = gfile.stat(src)
n = int(np.ceil(stats.length / block_size))
if progress is None:
range_or_trange = range
else:
range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')
with opener(src) as fin:
with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
for _ in range_or_trange(n):
fout.write(fin.read(block_size))
gfile.rename(f'{dst}-PARTIAL', dst)
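

# Pairs of (estimated_secs, actual_secs) from previous cache loads, seeded with
# a neutral (10, 10) so the first correction ratio in `get_memory_cache` is 1.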
_estimated_real = [(10, 10)]
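# Insertion-ordered dict used as an LRU cache: `get_memory_cache` re-inserts a
# key on every access, so the least recently accessed entry is always first.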
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  if progress is None:
    return getter()
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    # Tick the progress bar in `step`-sized increments for the expected `secs`
    # duration, stopping the sleeps early once the result is available.
    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()


def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(
key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
"""Keeps cache below specified size by removing elements not last accessed."""
if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # update "last accessed" order
return _memory_cache[key]
est, real = zip(*_estimated_real)
if estimated_secs is None:
estimated_secs = sum(est) / len(est)
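  # Scale the estimate by the ratio of actual to estimated durations observed
  # in previous loads.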
with timed(f'loading {key}') as timing:
estimated_secs *= sum(real) / sum(est)
_memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
_estimated_real.append((estimated_secs, timing['secs']))
sz = sum(_get_array_sizes(list(_memory_cache.values())))
logging.info('New memory cache size=%.1f MB', sz/1e6)
while sz > max_cache_size_bytes:
k, v = next(iter(_memory_cache.items()))
    if k == key:
      break  # never evict the entry that was just added
s = sum(_get_array_sizes(v))
logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
_memory_cache.pop(k)
sz -= s
return _memory_cache[key]


def get_memory_cache_info():
"""Returns number of items and total size in bytes."""
sizes = _get_array_sizes(_memory_cache)
return len(_memory_cache), sum(sizes)


CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
"""Keeps cache below specified size by removing elements not last accessed."""
fname = os.path.basename(path_or_url)
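  # Hash the full path/URL into the directory name, so files with the same
  # basename but different origins don't collide in the cache.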
path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
dst = os.path.join(CACHE_DIR, path_hash, fname)
if os.path.exists(dst):
return dst
os.makedirs(os.path.dirname(dst), exist_ok=True)
with timed(f'copying {path_or_url}'):
copy_file(path_or_url, dst, progress=progress)
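  # All cached files, ordered by access time (least recently accessed first).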
atimes_sizes_paths = sorted([
(os.path.getatime(p), os.path.getsize(p), p)
for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
if os.path.isfile(p)
])
sz = sum(sz for _, sz, _ in atimes_sizes_paths)
logging.info('New disk cache size=%.1f MB', sz/1e6)
while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)  # least recently accessed file
    if path == dst:
      break  # never evict the file that was just copied
    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
    os.unlink(path)
    sz -= s
return dst


def get_disk_cache_info():
"""Returns number of items and total size in bytes."""
sizes = [
os.path.getsize(p)
for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
]
return len(sizes), sum(sizes)