Spaces:
Runtime error
Runtime error
| import math | |
| import os | |
| import requests | |
| from torch.hub import download_url_to_file, get_dir | |
| from tqdm import tqdm | |
| from urllib.parse import urlparse | |
| from .misc import sizeof_fmt | |
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    # The session (and every streamed response it opens) is closed on exit so that
    # pooled connections are released even if saving fails.
    with requests.Session() as session:
        response = session.get(URL, params=params, stream=True)
        token = get_confirm_token(response)
        if token:
            # A confirmation cookie was returned: discard the first response and
            # re-request with the confirm token.
            response.close()
            params['confirm'] = token
            response = session.get(URL, params=params, stream=True)

        # get file size: request only the first bytes and read the total size from the
        # "Content-Range: bytes 0-2/<total>" header. The total may be "*" (unknown),
        # in which case no size (and no progress bar) is used.
        with session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) as response_file_size:
            content_range = response_file_size.headers.get('Content-Range')
        total = content_range.rsplit('/', 1)[-1].strip() if content_range else ''
        file_size = int(total) if total.isdigit() else None

        with response:
            save_response_content(response, save_path, file_size)
def get_confirm_token(response):
    """Return the download-confirmation token carried in the response cookies.

    Args:
        response: An object exposing ``cookies.items()`` (e.g. a ``requests.Response``).

    Returns:
        str | None: The value of the first cookie whose name starts with
        ``download_warning``, or None if no such cookie exists.
    """
    matches = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matches, None)
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: A streamed response whose body is read via
            ``response.iter_content(chunk_size)`` (e.g. ``requests.Response``
            opened with ``stream=True``).
        destination (str): Path of the output file; it is opened in ``'wb'`` mode,
            so an existing file is overwritten.
        file_size (int | None): Total size in bytes, used only for the progress bar.
            When None, no progress bar is shown. Default: None.
        chunk_size (int): Chunk size passed to ``iter_content``. Default: 32768.
    """
    if file_size is not None:
        # One progress step per expected chunk of the requested size.
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                if not chunk:  # filter out keep-alive new chunks
                    continue
                f.write(chunk)
                # Count bytes actually received; chunks (notably the last one) may
                # be shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
    finally:
        # Close the progress bar even if the download or the write fails.
        if pbar is not None:
            pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The absolute path to the downloaded (or already cached) file.

    Raises:
        ValueError: If no file name can be determined, i.e. ``file_name`` is empty
            or None and the url path has no final component (e.g. it ends with '/').
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    # Only the path part of the url is used, so query strings are ignored.
    filename = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    if not filename:
        # An empty name would make cached_file resolve to model_dir itself, which
        # always exists: the directory would be returned and nothing downloaded.
        raise ValueError(f'Cannot infer a file name from url "{url}"; please pass file_name.')

    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file