runner
Browse files- runner/inference.py +9 -8
runner/inference.py
CHANGED
|
@@ -214,14 +214,15 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
|
|
| 214 |
tos_url = URL[cache_name]
|
| 215 |
logger.info(f"Downloading data cache from\n {tos_url}...")
|
| 216 |
urllib.request.urlretrieve(tos_url, cache_path)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
# checkpoint_path = configs.load_checkpoint_path
|
|
|
|
| 214 |
tos_url = URL[cache_name]
|
| 215 |
logger.info(f"Downloading data cache from\n {tos_url}...")
|
| 216 |
urllib.request.urlretrieve(tos_url, cache_path)
|
| 217 |
+
|
| 218 |
+
if not os.path.exists('./checkpoint.pt'):
|
| 219 |
+
# Google Drive file ID
|
| 220 |
+
file_id = '17zBIRed3xZM8ux0bq2hpf1oFC75Y7OEw'
|
| 221 |
+
# URL to download the file
|
| 222 |
+
url = f'https://drive.google.com/uc?id={file_id}'
|
| 223 |
+
|
| 224 |
+
# Download the file and save it as 'checkpoint.pt'
|
| 225 |
+
gdown.download(url, './checkpoint.pt', quiet=False)
|
| 226 |
|
| 227 |
|
| 228 |
# checkpoint_path = configs.load_checkpoint_path
|