Spaces:
Build error
Build error
| import jax | |
| import flax | |
| import numpy as np | |
| from tqdm import tqdm | |
| import requests | |
| import os | |
| import tempfile | |
| import logging | |
logger = logging.getLogger(__name__)


def _url_without_query(url):
    """Return ``url`` with any ``#fragment`` and ``?query`` suffix removed."""
    return url.split('#', 1)[0].split('?', 1)[0]


def _stream_to_file(url, path, chunk_size=1024):
    """Stream the body of ``url`` into ``path`` while showing a progress bar.

    Returns:
        (tuple): ``(bytes_received, content_length)``; ``content_length`` is 0
        when the server did not send a Content-Length header.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    with requests.get(url, stream=True) as response:
        # Fail loudly instead of saving an HTTP error page as a checkpoint.
        response.raise_for_status()
        expected = int(response.headers.get('content-length', 0))
        received = 0
        with open(path, 'wb') as file, \
                tqdm(total=expected, unit='iB', unit_scale=True) as progress_bar:
            for data in response.iter_content(chunk_size=chunk_size):
                progress_bar.update(len(data))
                file.write(data)
                received += len(data)
    return received, expected


def download(url, ckpt_dir=None):
    """Download the file at ``url`` into a local cache and return its path.

    The cached file name is the last path component of ``url`` with any query
    string / fragment removed (e.g. ``.../vgg16.h5?dl=1`` -> ``vgg16.h5``).
    Files are stored under ``<ckpt_dir>/flaxmodels``. If the file already
    exists there, it is returned without making any request.

    Args:
        url (str): URL of the file to download.
        ckpt_dir (str): Base cache directory. Defaults to
            ``tempfile.gettempdir()``.

    Returns:
        (str): Path of the cached file.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        requests.HTTPError: If the server responds with an error status.
        OSError: If the number of bytes received does not match the
            Content-Length header.
    """
    base_url = _url_without_query(url)
    name = base_url[base_url.rfind('/') + 1:]
    if not name:
        raise ValueError(f'Cannot determine a file name from URL: {url!r}')
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    logger.info(f'Downloading: "{base_url}" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)
    # Download into a temp file first, so a failed or interrupted download never
    # leaves a partial file at ckpt_file (which would then be treated as cached).
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    try:
        received, expected = _stream_to_file(url, ckpt_file_temp)
        if expected != 0 and received != expected:
            logger.error('An error occured while downloading, please try again.')
            raise OSError(
                f'Incomplete download of {base_url}: '
                f'received {received} of {expected} bytes.')
        # os.replace (unlike os.rename) also overwrites an existing target on Windows.
        os.replace(ckpt_file_temp, ckpt_file)
    finally:
        # No-op after a successful replace; discards partial data otherwise.
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    return ckpt_file
def get(dictionary, key):
    """Look up ``key`` in ``dictionary``, tolerating a missing mapping or key.

    Args:
        dictionary: A container supporting ``in`` and ``[]``, or ``None``.
        key: The key to look up.

    Returns:
        ``dictionary[key]`` when ``dictionary`` is not ``None`` and contains
        ``key``; otherwise ``None``.
    """
    has_entry = dictionary is not None and key in dictionary
    return dictionary[key] if has_entry else None
def prefetch(dataset, n_prefetch):
    """Turn ``dataset`` into an iterator of NumPy batches, optionally prefetched to device.

    Adapted from:
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: Iterable of pytrees whose leaves support the buffer protocol
            (e.g. NumPy arrays; presumably eager TF tensors too -- TODO confirm).
        n_prefetch (int): Number of batches to keep prefetched on device via
            ``flax.jax_utils.prefetch_to_device``. ``0`` or ``None`` disables
            device prefetching.

    Returns:
        A lazy iterator yielding the converted (and possibly device-resident) batches.
    """
    def to_numpy(batch):
        # Expose each leaf's buffer as a NumPy array. jax.tree_util.tree_map is
        # used because jax.tree_map is deprecated and removed in newer JAX releases.
        return jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), batch)

    ds_iter = map(to_numpy, iter(dataset))
    if n_prefetch:
        # NOTE: flax's prefetch_to_device expects each batch to already be
        # sharded with a leading axis equal to the number of local devices.
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter