Spaces:
Paused
Paused
| """Gradio helpers for caching, downloading etc.""" | |
| import concurrent.futures | |
| import contextlib | |
| import datetime | |
| import functools | |
| import logging | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import unittest.mock | |
| import huggingface_hub | |
| import jax | |
| import numpy as np | |
| import psutil | |
def _clone_git(url, destination_folder, commit_hash=None):
  """Clones `url` into `destination_folder`, optionally at `commit_hash`.

  Args:
    url: Repository URL passed to `git clone`.
    destination_folder: Directory that `git clone` should create.
    commit_hash: Optional commit to check out after cloning. When unset, only
      a shallow clone (`--depth=1`) of the default branch is made.

  Raises:
    subprocess.CalledProcessError: If a `git` command exits with non-zero
      status.
  """
  clone_cmd = ['git', 'clone']
  if not commit_hash:
    # Only the tip of the default branch is needed.
    clone_cmd.append('--depth=1')
  # A shallow clone contains nothing but the branch tip, so checking out any
  # other commit would fail; the full history is fetched when pinning.
  subprocess.run(clone_cmd + [url, destination_folder], check=True)
  if commit_hash:
    subprocess.run(
        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
    )
def setup():
  """Makes the big_vision checkout importable and stubs `tensorflow_text`.

  Every repository listed below is cloned into the system temp directory,
  unless the target directory already exists, and its path is put at the
  front of `sys.path`.
  """
  repos = (
      # (url, directory name below the temp dir, commit hash or `None`)
      ('https://github.com/google-research/big_vision', 'big_vision_repo', None),
  )
  for url, dst_name, commit_hash in repos:
    dst_path = os.path.join(tempfile.gettempdir(), dst_name)
    if not os.path.exists(dst_path):
      print('Cloning "%s" into "%s"' % (url, dst_path))
      _clone_git(url, dst_path, commit_hash)
    else:
      print('Found existing "%s" at "%s"' % (url, dst_path))
    if dst_path not in sys.path:
      sys.path.insert(0, dst_path)
  # Imported in `big_vision.pp.ops_text` but we don't use it.
  sys.modules['tensorflow_text'] = unittest.mock.MagicMock()
# Must be run in main app before other BV imports: `setup()` puts the
# big_vision checkout on `sys.path` and registers the `tensorflow_text` stub.
setup()
def should_mock():
  """Tells whether mocking was requested via `MOCK_MODEL=yes` in the env."""
  return os.getenv('MOCK_MODEL') == 'yes'
@contextlib.contextmanager
def timed(name, start_message=False):
  """Emits "Timed {name}: .1f secs" message to INFO logs.

  Context manager that measures the wall time of its `with` body using the
  monotonic clock.

  Args:
    name: Label used in the log messages.
    start_message: If set, also logs "Timing {name}..." on entry.

  Yields:
    A dictionary whose `'secs'` entry is `None` inside the `with` body and
    holds the elapsed seconds once the body has exited (also if it raised).
  """
  t0 = time.monotonic()
  # `dt` is never filled in; it is only kept so the yielded dictionary still
  # has every key it had before.
  timing = dict(dt=None, secs=None)
  try:
    if start_message:
      logging.info('Timing %s...', name)
    yield timing
  finally:
    # In `finally` so the duration is recorded and logged even on error.
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])
def synced(f):
  """Syncs calls to `f` with a `threading.Lock()`.

  Args:
    f: Function to wrap.

  Returns:
    A wrapper with the name and docstring of `f` that runs at most one call
    of `f` at a time and logs how long each call waited for the lock.
  """
  lock = threading.Lock()

  @functools.wraps(f)  # Keeps `__name__`, `__doc__` etc. of `f`.
  def wrapper(*args, **kw):
    t0 = time.monotonic()
    with lock:
      lock_dt = time.monotonic() - t0
      logging.info('synced wait: %.1f secs', lock_dt)
      return f(*args, **kw)

  return wrapper
# Names for which the warmup function returned without raising.
_warmed_up = set()
# Optional callable taking a download name; see `set_warmup_function()`.
_warmup_function = None
def set_warmup_function(warmup_function):
  """Sets the callable that the download thread runs after each download.

  Args:
    warmup_function: Callable invoked with the download's `name`; a falsy
      value disables warmup.
  """
  global _warmup_function
  _warmup_function = warmup_function
# Held while adding entries to / removing entries from `_scheduled`.
_lock = threading.Lock()
# name -> (repo, filename, revision) of downloads that are still pending.
_scheduled = {}
# Accumulated seconds of the downloads that did not fail.
_download_secs = 0
# Accumulated seconds spent in the warmup function.
_warmup_secs = 0
# NOTE(review): not referenced anywhere in the visible code.
_loading_secs = 0
# name -> value returned by `hf_hub_download()` (`None` when mocking).
_done = {}
# name -> error message of downloads that raised.
_failed = {}
def _do_download():
  """Downloading files, to be started in background thread.

  Loops forever: takes the first entry of `_scheduled`, downloads it from the
  HF hub (or sleeps instead when `should_mock()`), stores the result in
  `_done` or the error message in `_failed`, and finally removes the entry
  from `_scheduled`. If a warmup function is set, it is run for every
  successful download on a separate single-worker executor.
  """
  global _download_secs
  # One worker: warmups run one at a time without blocking further downloads.
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  while True:
    if not _scheduled:
      # Nothing to do; poll again in a second.
      time.sleep(1)
      continue

    # The entry stays in `_scheduled` until it has been processed, so that it
    # keeps counting as "remaining" in `downloads_status()`.
    name, (repo, filename, revision) = next(iter(_scheduled.items()))
    logging.info('Downloading "%s" %s/%s/%s...', name, repo, filename, revision)
    with timed(f'downloading {name}', True) as t:
      if should_mock():
        logging.warning('Mocking loading')
        time.sleep(10.)
        _done[name] = None
      else:
        try:
          _done[name] = huggingface_hub.hf_hub_download(
              repo_id=repo, filename=filename, revision=revision)
        except Exception as e:  # pylint: disable=broad-exception-caught
          logging.exception('Could not download "%s" from hub!', name)
          _failed[name] = str(e)
          with _lock:
            _scheduled.pop(name)
          # Failed downloads are neither warmed up nor added to
          # `_download_secs`.
          continue

    if _warmup_function:
      def warmup(name):
        global _warmup_secs
        with timed(f'warming up {name}', True) as t:
          try:
            _warmup_function(name)
            _warmed_up.add(name)
          except Exception:  # pylint: disable=broad-exception-caught
            logging.exception('Could not warmup "%s"!', name)
        # The time is counted whether or not the warmup raised.
        _warmup_secs += t['secs']
      executor.submit(warmup, name)

    # `t['secs']` is set by `timed()` when the `with` block above exits.
    _download_secs += t['secs']
    with _lock:
      _scheduled.pop(name)
def register_download(name, repo, filename, revision='main'):
  """Schedules `filename` from HF `repo` for download in background thread.

  A `name` that is already scheduled keeps its existing entry.
  """
  with _lock:
    _scheduled.setdefault(name, (repo, filename, revision))
| def _hms(secs): | |
| """Formats `secs=3700` to `"01:01:40"`.""" | |
| secs = int(secs) | |
| h = secs // 3600 | |
| m = (secs - h * 3600) // 60 | |
| s = secs % 60 | |
| return (f'{h}:' if h else '') + f'{m:02}:{s:02}' | |
def downloads_status():
  """Summarizes finished, pending, warmed up and failed downloads as text."""
  elapsed = eta = ''
  if _done:
    elapsed = f' in {_hms(_download_secs)}'
    # Extrapolates from the average time per finished download.
    secs_per_download = _download_secs / len(_done)
    eta = f' {_hms(secs_per_download * len(_scheduled))}'
  parts = [f'Downloaded {len(_done)}{elapsed}']
  if _scheduled:
    parts.append(f'{len(_scheduled)}{eta} remaining')
  if _warmup_function:
    parts.append(f'warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}')
  if _failed:
    parts.append(f'{len(_failed)} failed')
  return ', '.join(parts)
def get_paths():
  """Returns a copy of the `name` to `path` mapping of finished downloads."""
  return _done.copy()
# Started at import time. Being a daemon thread, the endless loop in
# `_do_download()` does not keep the interpreter from exiting.
_download_thread = threading.Thread(target=_do_download)
_download_thread.daemon = True
_download_thread.start()
# (estimated secs, measured secs) per load done by `get_memory_cache()`. The
# first entry is a seed value that `get_status()` skips.
_estimated_real = [(10, 10)]
# key -> cached value; entries are re-inserted on access, so the first entry
# is the one that was used least recently.
_memory_cache = {}
def get_with_progress(getter, secs, progress, step=0.1):
  """Runs `getter` in a thread while advancing a progress bar, returns result.

  Args:
    getter: Callable without arguments that produces the result.
    secs: Expected duration; the bar has `ceil(secs / step)` ticks.
    progress: Object with a `tqdm()` method, or `None` to call `getter`
      directly without any progress bar.
    step: Seconds to sleep per tick while `getter` is still running.

  Returns:
    Whatever `getter` returns.
  """
  if progress is None:
    return getter()
  n_ticks = int(np.ceil(secs/step))
  with concurrent.futures.ThreadPoolExecutor() as pool:
    pending = pool.submit(getter)
    for _ in progress.tqdm(list(range(n_ticks)), desc='read'):
      if pending.done():
        # Finished early: run through the remaining ticks without waiting.
        continue
      time.sleep(step)
  return pending.result()
| def _get_array_sizes(tree): | |
| return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)] | |
def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Returns the value for `key`, loading it and evicting old entries if needed.

  Args:
    key: Cache key.
    getter: Callable without arguments that loads the value on a cache miss.
    max_cache_size_bytes: Size limit of the cache; a falsy value means the
      loaded value is returned without being cached.
    progress: Passed on to `get_with_progress()`.
    estimated_secs: Expected loading time; defaults to the mean of the
      previous estimates.

  Returns:
    The cached or freshly loaded value.
  """
  if key in _memory_cache:
    # Re-inserting moves the entry to the "most recently used" end.
    cached = _memory_cache.pop(key)
    _memory_cache[key] = cached
    return cached

  estimates, actuals = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(estimates) / len(estimates)
  with timed(f'loading {key}') as timing:
    # Corrects the estimate by how far off previous estimates were.
    estimated_secs *= sum(actuals) / sum(estimates)
    value = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  if not max_cache_size_bytes:
    return value

  _memory_cache[key] = value
  total = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', total/1e6)

  # Evicts from the least recently used end, but never the new entry itself.
  while total > max_cache_size_bytes:
    oldest = next(iter(_memory_cache))
    if oldest == key:
      break
    freed = sum(_get_array_sizes(_memory_cache[oldest]))
    logging.info('Removing %s from memory cache (%.1f MB)', oldest, freed/1e6)
    _memory_cache.pop(oldest)
    total -= freed

  return value
def get_memory_cache_info():
  """Returns `(number of cached items, total size of the cache in bytes)`."""
  n_items = len(_memory_cache)
  total_bytes = sum(_get_array_sizes(_memory_cache))
  return n_items, total_bytes
def get_system_info():
  """Returns a "RAM used/total, disk used/total" summary string in G.

  The totals are divided by the `HOST_COLOCATION` environment variable
  (default `1`); the used amounts are reported as they are.
  """
  host_colocation = int(os.environ.get('HOST_COLOCATION', '1'))
  ram = psutil.virtual_memory()
  disk = shutil.disk_usage('.')
  ram_used = ram.used / 1e9
  ram_total = ram.total / host_colocation / 1e9
  disk_used = disk.used / 1e9
  disk_total = disk.total / host_colocation / 1e9
  ram_info = f'RAM {ram_used:.1f}/{ram_total:.1f}G'
  disk_info = f'disk {disk_used:.1f}/{disk_total:.1f}G'
  return ram_info + ', ' + disk_info
def get_status(include_system_info=True):
  """Returns a one-line summary of download, memory cache and system status.

  Args:
    include_system_info: Whether to append the `get_system_info()` summary.
  """
  n_cached, cached_bytes = get_memory_cache_info()
  # The first entry of `_estimated_real` is skipped, only the rest is summed.
  load_time = _hms(sum(real for _, real in _estimated_real[1:]))
  timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  status = (
      f'Timestamp: {timestamp} – Model stats: {downloads_status()}, '
      f'memory-cached {n_cached} ({cached_bytes/1e9:.1f}G) in {load_time}'
  )
  if include_system_info:
    status += ' – System: ' + get_system_info()
  return status