MogensR committed on
Commit 39ebfbd · 1 Parent(s): 05676b4

Create models/downloader.py

Files changed (1)
  1. models/downloader.py +521 -0
models/downloader.py ADDED
@@ -0,0 +1,521 @@
+"""
+Model downloader for BackgroundFX Pro.
+Handles downloading, caching, and verification of models.
+"""
+
+import os
+import shutil
+import tempfile
+import hashlib
+import requests
+from pathlib import Path
+from typing import Optional, Callable, Dict, Any, List
+from dataclasses import dataclass
+from enum import Enum
+import time
+import threading
+from urllib.parse import urlparse
+from concurrent.futures import ThreadPoolExecutor, Future
+import logging
+
+from .registry import ModelInfo, ModelStatus, ModelRegistry
+
+logger = logging.getLogger(__name__)
+
+
+class DownloadStatus(Enum):
+    """Download status."""
+    PENDING = "pending"
+    DOWNLOADING = "downloading"
+    VERIFYING = "verifying"
+    EXTRACTING = "extracting"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+@dataclass
+class DownloadProgress:
+    """Download progress information."""
+    model_id: str
+    status: DownloadStatus
+    current_bytes: int = 0
+    total_bytes: int = 0
+    speed_mbps: float = 0.0
+    eta_seconds: float = 0.0
+    error: Optional[str] = None
+
+    @property
+    def progress(self) -> float:
+        """Get progress percentage."""
+        if self.total_bytes > 0:
+            return (self.current_bytes / self.total_bytes) * 100
+        return 0.0
+
+
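A minimal sketch of a progress callback that consumes these DownloadProgress objects; the import path models.downloader is assumed from the file location, and the exact formatting is only illustrative:

from models.downloader import DownloadProgress, DownloadStatus

def print_progress(p: DownloadProgress) -> None:
    # One status line per update: percentage, speed, and ETA while downloading.
    if p.status == DownloadStatus.DOWNLOADING:
        print(f"{p.model_id}: {p.progress:.1f}% "
              f"({p.speed_mbps:.2f} Mbps, ETA {p.eta_seconds:.0f}s)", end="\r")
    else:
        print(f"\n{p.model_id}: {p.status.value}")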
+class ModelDownloader:
+    """Handle model downloading with progress tracking and resume support."""
+
+    def __init__(self,
+                 registry: ModelRegistry,
+                 max_workers: int = 3,
+                 chunk_size: int = 8192,
+                 timeout: int = 30,
+                 max_retries: int = 3):
+        """
+        Initialize model downloader.
+
+        Args:
+            registry: Model registry instance
+            max_workers: Maximum concurrent downloads
+            chunk_size: Download chunk size in bytes
+            timeout: Request timeout in seconds
+            max_retries: Maximum retry attempts
+        """
+        self.registry = registry
+        self.max_workers = max_workers
+        self.chunk_size = chunk_size
+        self.timeout = timeout
+        self.max_retries = max_retries
+
+        # Download management
+        self.downloads: Dict[str, DownloadProgress] = {}
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.futures: Dict[str, Future] = {}
+        self._stop_events: Dict[str, threading.Event] = {}
+
+        # Cache directory
+        self.cache_dir = registry.models_dir / ".cache"
+        self.cache_dir.mkdir(exist_ok=True)
+
+    def download_model(self,
+                       model_id: str,
+                       progress_callback: Optional[Callable[[DownloadProgress], None]] = None,
+                       force: bool = False) -> bool:
+        """
+        Download a model.
+
+        Args:
+            model_id: Model ID to download
+            progress_callback: Optional progress callback
+            force: Force re-download even if exists
+
+        Returns:
+            True if download successful
+        """
+        # Get model info
+        model = self.registry.get_model(model_id)
+        if not model:
+            logger.error(f"Model not found: {model_id}")
+            return False
+
+        # Check if already downloaded
+        if not force and model.status == ModelStatus.AVAILABLE:
+            logger.info(f"Model already available: {model_id}")
+            return True
+
+        # Initialize progress
+        progress = DownloadProgress(
+            model_id=model_id,
+            status=DownloadStatus.PENDING,
+            total_bytes=model.file_size
+        )
+        self.downloads[model_id] = progress
+
+        # Create stop event
+        self._stop_events[model_id] = threading.Event()
+
+        # Submit download task
+        future = self.executor.submit(
+            self._download_model_task,
+            model,
+            progress,
+            progress_callback,
+            force
+        )
+        self.futures[model_id] = future
+
+        # Wait for completion
+        try:
+            return future.result()
+        except Exception as e:
+            logger.error(f"Download failed for {model_id}: {e}")
+            return False
+
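A usage sketch for a single, blocking download. ModelRegistry comes from models/registry.py, which is not part of this commit, so its default constructor is an assumption here; print_progress is the helper sketched above:

from models.registry import ModelRegistry
from models.downloader import ModelDownloader

registry = ModelRegistry()             # assumed constructor; see models/registry.py
downloader = ModelDownloader(registry)

ok = downloader.download_model("rmbg-1.4", progress_callback=print_progress)
print("download ok" if ok else "download failed")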
+    def download_models_async(self,
+                              model_ids: List[str],
+                              progress_callback: Optional[Callable[[str, DownloadProgress], None]] = None,
+                              force: bool = False) -> Dict[str, Future]:
+        """
+        Download multiple models asynchronously.
+
+        Args:
+            model_ids: List of model IDs
+            progress_callback: Optional progress callback with model_id
+            force: Force re-download
+
+        Returns:
+            Dictionary of futures
+        """
+        futures = {}
+
+        for model_id in model_ids:
+            model = self.registry.get_model(model_id)
+            if not model:
+                logger.warning(f"Model not found: {model_id}")
+                continue
+
+            # Skip if already available
+            if not force and model.status == ModelStatus.AVAILABLE:
+                continue
+
+            # Initialize progress
+            progress = DownloadProgress(
+                model_id=model_id,
+                status=DownloadStatus.PENDING,
+                total_bytes=model.file_size
+            )
+            self.downloads[model_id] = progress
+
+            # Create stop event
+            self._stop_events[model_id] = threading.Event()
+
+            # Wrapper for progress callback (bind model_id now to avoid
+            # late binding of the loop variable in the closure)
+            def progress_wrapper(p, _model_id=model_id):
+                if progress_callback:
+                    progress_callback(_model_id, p)
+
+            # Submit download task
+            future = self.executor.submit(
+                self._download_model_task,
+                model,
+                progress,
+                progress_wrapper,
+                force
+            )
+            futures[model_id] = future
+            self.futures[model_id] = future
+
+        return futures
+
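A sketch of starting several downloads concurrently and collecting results from the returned futures, reusing the downloader constructed in the earlier sketch; the model IDs are the essential ones referenced in download_required_models below:

def on_progress(model_id: str, p: DownloadProgress) -> None:
    # download_models_async passes (model_id, progress) to this callback.
    print(f"{model_id}: {p.status.value} {p.progress:.0f}%")

futures = downloader.download_models_async(
    ["rmbg-1.4", "u2netp", "modnet"],
    progress_callback=on_progress,
)
results = {mid: f.result() for mid, f in futures.items()}
print(results)  # e.g. {'rmbg-1.4': True, 'u2netp': True, 'modnet': True}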
+    def _download_model_task(self,
+                             model: ModelInfo,
+                             progress: DownloadProgress,
+                             progress_callback: Optional[Callable],
+                             force: bool) -> bool:
+        """
+        Download model task.
+
+        Args:
+            model: Model information
+            progress: Progress tracker
+            progress_callback: Progress callback
+            force: Force re-download
+
+        Returns:
+            True if successful
+        """
+        try:
+            # Update status
+            progress.status = DownloadStatus.DOWNLOADING
+            self._notify_progress(progress, progress_callback)
+
+            # Try primary URL first, then mirrors
+            urls = [model.url] + model.mirror_urls
+
+            for url in urls:
+                if self._stop_events[model.model_id].is_set():
+                    progress.status = DownloadStatus.CANCELLED
+                    self._notify_progress(progress, progress_callback)
+                    return False
+
+                try:
+                    # Download file
+                    output_path = self.registry.models_dir / model.filename
+                    success = self._download_file(
+                        url,
+                        output_path,
+                        progress,
+                        progress_callback,
+                        model.model_id
+                    )
+
+                    if success:
+                        # Verify file
+                        progress.status = DownloadStatus.VERIFYING
+                        self._notify_progress(progress, progress_callback)
+
+                        if self._verify_download(output_path, model):
+                            # Update registry
+                            model.status = ModelStatus.AVAILABLE
+                            model.local_path = str(output_path)
+                            model.download_date = time.time()
+                            self.registry._save_registry()
+
+                            progress.status = DownloadStatus.COMPLETED
+                            self._notify_progress(progress, progress_callback)
+
+                            logger.info(f"Successfully downloaded: {model.model_id}")
+                            return True
+                        else:
+                            # Verification failed
+                            output_path.unlink(missing_ok=True)
+                            logger.warning(f"Verification failed for {model.model_id}")
+
+                except Exception as e:
+                    logger.warning(f"Download failed from {url}: {e}")
+                    continue
+
+            # All attempts failed
+            progress.status = DownloadStatus.FAILED
+            progress.error = "All download attempts failed"
+            self._notify_progress(progress, progress_callback)
+            return False
+
+        except Exception as e:
+            progress.status = DownloadStatus.FAILED
+            progress.error = str(e)
+            self._notify_progress(progress, progress_callback)
+            logger.error(f"Download task failed: {e}")
+            return False
+
+    def _download_file(self,
+                       url: str,
+                       output_path: Path,
+                       progress: DownloadProgress,
+                       progress_callback: Optional[Callable],
+                       model_id: str) -> bool:
+        """
+        Download file with resume support.
+
+        Args:
+            url: Download URL
+            output_path: Output file path
+            progress: Progress tracker
+            progress_callback: Progress callback
+            model_id: Model ID for stop event
+
+        Returns:
+            True if successful
+        """
+        # Check for partial download
+        temp_path = output_path.with_suffix('.part')
+        resume_pos = 0
+
+        if temp_path.exists():
+            resume_pos = temp_path.stat().st_size
+            logger.info(f"Resuming download from {resume_pos} bytes")
+
+        # Prepare headers for resume
+        headers = {}
+        if resume_pos > 0:
+            headers['Range'] = f'bytes={resume_pos}-'
+
+        # Start download
+        start_time = time.time()
+        bytes_downloaded = resume_pos
+
+        try:
+            response = requests.get(
+                url,
+                headers=headers,
+                stream=True,
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+
+            # If the server ignored the Range request, start over from byte 0
+            if resume_pos > 0 and response.status_code != 206:
+                logger.info("Server does not support resume; restarting download")
+                resume_pos = 0
+                bytes_downloaded = 0
+
+            # Get total size
+            if 'content-length' in response.headers:
+                total_size = int(response.headers['content-length']) + resume_pos
+                progress.total_bytes = total_size
+            else:
+                total_size = None
+
+            # Download with progress
+            mode = 'ab' if resume_pos > 0 else 'wb'
+            with open(temp_path, mode) as f:
+                for chunk in response.iter_content(chunk_size=self.chunk_size):
+                    # Check for cancellation
+                    if self._stop_events[model_id].is_set():
+                        logger.info(f"Download cancelled: {model_id}")
+                        return False
+
+                    if chunk:
+                        f.write(chunk)
+                        bytes_downloaded += len(chunk)
+
+                        # Update progress
+                        progress.current_bytes = bytes_downloaded
+
+                        # Calculate speed and ETA
+                        elapsed = time.time() - start_time
+                        if elapsed > 0:
+                            speed_bps = (bytes_downloaded - resume_pos) / elapsed
+                            progress.speed_mbps = (speed_bps * 8) / 1_000_000
+
+                            if total_size and speed_bps > 0:
+                                remaining = total_size - bytes_downloaded
+                                progress.eta_seconds = remaining / speed_bps
+
+                        self._notify_progress(progress, progress_callback)
+
+            # Move to final location
+            shutil.move(str(temp_path), str(output_path))
+            return True
+
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Download error: {e}")
+            return False
+        except Exception as e:
+            logger.error(f"File write error: {e}")
+            return False
+
+    def _verify_download(self, file_path: Path, model: ModelInfo) -> bool:
+        """
+        Verify downloaded file.
+
+        Args:
+            file_path: Downloaded file path
+            model: Model information
+
+        Returns:
+            True if verification passed
+        """
+        # Check file exists
+        if not file_path.exists():
+            return False
+
+        # Check file size
+        actual_size = file_path.stat().st_size
+        if model.file_size > 0:
+            size_diff = abs(actual_size - model.file_size)
+            if size_diff > 1000:  # Allow 1KB difference
+                logger.warning(f"Size mismatch: expected {model.file_size}, got {actual_size}")
+                return False
+
+        # Check SHA256 if available
+        if model.sha256:
+            try:
+                sha256 = self._calculate_sha256(file_path)
+                if sha256 != model.sha256:
+                    logger.warning(f"SHA256 mismatch for {model.model_id}")
+                    return False
+            except Exception as e:
+                logger.error(f"SHA256 calculation failed: {e}")
+                return False
+
+        return True
+
+    def _calculate_sha256(self, file_path: Path) -> str:
+        """Calculate SHA256 hash of file."""
+        sha256_hash = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
+                sha256_hash.update(byte_block)
+        return sha256_hash.hexdigest()
+
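A standalone sketch for pre-computing the sha256 value that _verify_download compares against, e.g. when adding a new entry to the registry; the file name shown is hypothetical:

import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 8192) -> str:
    # Stream the file in chunks so large model weights never sit fully in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()

print(sha256_of(Path("models/rmbg-1.4.onnx")))  # hypothetical local filename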
+    def _notify_progress(self, progress: DownloadProgress, callback: Optional[Callable]):
+        """Notify progress callback."""
+        if callback:
+            try:
+                callback(progress)
+            except Exception as e:
+                logger.error(f"Progress callback error: {e}")
+
+    def cancel_download(self, model_id: str) -> bool:
+        """
+        Cancel ongoing download.
+
+        Args:
+            model_id: Model ID to cancel
+
+        Returns:
+            True if cancelled
+        """
+        if model_id in self._stop_events:
+            self._stop_events[model_id].set()
+
+            # Wait for cancellation
+            if model_id in self.futures:
+                try:
+                    self.futures[model_id].result(timeout=5)
+                except Exception:
+                    pass
+                del self.futures[model_id]
+
+            # Update progress
+            if model_id in self.downloads:
+                self.downloads[model_id].status = DownloadStatus.CANCELLED
+
+            logger.info(f"Download cancelled: {model_id}")
+            return True
+
+        return False
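A sketch of cancelling an in-flight download from the calling thread, again reusing the downloader from the earlier sketch; the short sleep is only there to let the transfer start before cancelling:

import time

futures = downloader.download_models_async(["modnet"])
time.sleep(2.0)                       # let the download start
downloader.cancel_download("modnet")
p = downloader.get_progress("modnet")
print(p.status if p else "not scheduled")  # DownloadStatus.CANCELLED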
+
+    def get_progress(self, model_id: str) -> Optional[DownloadProgress]:
+        """Get download progress for model."""
+        return self.downloads.get(model_id)
+
+    def get_all_progress(self) -> Dict[str, DownloadProgress]:
+        """Get all download progress."""
+        return self.downloads.copy()
+
+    def cleanup_partial_downloads(self):
+        """Clean up partial download files."""
+        for file in self.registry.models_dir.glob("*.part"):
+            try:
+                file.unlink()
+                logger.info(f"Removed partial download: {file.name}")
+            except Exception as e:
+                logger.error(f"Failed to remove {file}: {e}")
+
+    def download_required_models(self,
+                                 task: Optional[str] = None,
+                                 gpu_available: bool = True) -> bool:
+        """
+        Download all required models for a task.
+
+        Args:
+            task: Optional task filter
+            gpu_available: GPU availability
+
+        Returns:
+            True if all downloads successful
+        """
+        # Get required models
+        required = []
+
+        if task:
+            # Get best model for task
+            from .registry import ModelTask
+            task_enum = ModelTask(task)
+            model = self.registry.get_best_model(
+                task_enum,
+                require_gpu=gpu_available
+            )
+            if model:
+                required.append(model.model_id)
+        else:
+            # Get all essential models
+            essential = ['rmbg-1.4', 'u2netp', 'modnet']
+            for model_id in essential:
+                if self.registry.get_model(model_id):
+                    required.append(model_id)
+
+        # Download models
+        if required:
+            logger.info(f"Downloading required models: {required}")
+            futures = self.download_models_async(required)
+
+            # Wait for completion
+            success = True
+            for model_id, future in futures.items():
+                try:
+                    if not future.result():
+                        success = False
+                except Exception:
+                    success = False
+
+            return success
+
+        return True
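A sketch of how this downloader might be wired into application startup: remove stale .part files, then fetch the essential models. How gpu_available is detected is an assumption here, since the commit does not show it:

from models.registry import ModelRegistry
from models.downloader import ModelDownloader

def ensure_models(gpu_available: bool = True) -> bool:
    # Hypothetical startup helper built on the API above.
    registry = ModelRegistry()              # assumed constructor
    downloader = ModelDownloader(registry)
    downloader.cleanup_partial_downloads()  # drop stale *.part files
    return downloader.download_required_models(gpu_available=gpu_available)

if __name__ == "__main__":
    ok = ensure_models()
    print("models ready" if ok else "some models failed to download")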