|
|
""" |
|
|
Model downloader for BackgroundFX Pro. |
|
|
Handles downloading, caching, and verification of models. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
import hashlib |
|
|
import requests |
|
|
from pathlib import Path |
|
|
from typing import Optional, Callable, Dict, Any, List |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
import time |
|
|
import threading |
|
|
from urllib.parse import urlparse |
|
|
from concurrent.futures import ThreadPoolExecutor, Future |
|
|
import logging |
|
|
|
|
|
from .registry import ModelInfo, ModelStatus, ModelRegistry |
|
|
|
|
|
# Module-level logger, one per module per stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class DownloadStatus(Enum):
    """Lifecycle states a model download moves through.

    A download normally progresses PENDING -> DOWNLOADING -> VERIFYING
    (-> EXTRACTING) -> COMPLETED; FAILED and CANCELLED are terminal
    error/abort states.
    """

    PENDING = "pending"          # queued, worker not started yet
    DOWNLOADING = "downloading"  # bytes are being transferred
    VERIFYING = "verifying"      # size / SHA256 check in progress
    EXTRACTING = "extracting"    # archive unpacking (if applicable)
    COMPLETED = "completed"      # file verified and registered
    FAILED = "failed"            # all attempts exhausted or fatal error
    CANCELLED = "cancelled"      # stopped via cancel_download()
|
|
|
|
|
|
|
|
@dataclass
class DownloadProgress:
    """Mutable snapshot of a single model download's progress.

    Updated in place by the downloader worker and read by callers /
    progress callbacks.
    """

    model_id: str                 # registry id of the model being fetched
    status: DownloadStatus        # current lifecycle state
    current_bytes: int = 0        # bytes written so far (includes resumed bytes)
    total_bytes: int = 0          # expected size; 0 means unknown
    speed_mbps: float = 0.0       # instantaneous throughput, megabits/second
    eta_seconds: float = 0.0      # estimated time remaining
    error: Optional[str] = None   # failure description when status is FAILED

    @property
    def progress(self) -> float:
        """Completion percentage in [0, 100]; 0.0 while total size is unknown."""
        if self.total_bytes <= 0:
            return 0.0
        return (self.current_bytes / self.total_bytes) * 100
|
|
|
|
|
|
|
|
class ModelDownloader:
    """Download models with progress tracking, mirror fallback, and resume.

    Downloads run on a ThreadPoolExecutor and can be cancelled through a
    per-model ``threading.Event``. Files are streamed into a ``.part`` temp
    file (so interrupted transfers can resume via HTTP Range requests),
    verified by size and optional SHA256, then moved into place and recorded
    in the registry.
    """

    def __init__(self,
                 registry: ModelRegistry,
                 max_workers: int = 3,
                 chunk_size: int = 8192,
                 timeout: int = 30,
                 max_retries: int = 3):
        """
        Initialize model downloader.

        Args:
            registry: Model registry instance
            max_workers: Maximum concurrent downloads
            chunk_size: Download chunk size in bytes
            timeout: Request timeout in seconds
            max_retries: Maximum retry attempts (informational; retrying is
                currently driven by the primary + mirror URL list)
        """
        self.registry = registry
        self.max_workers = max_workers
        self.chunk_size = chunk_size
        self.timeout = timeout
        self.max_retries = max_retries

        # Per-model bookkeeping, all keyed by model_id.
        self.downloads: Dict[str, DownloadProgress] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures: Dict[str, Future] = {}
        self._stop_events: Dict[str, threading.Event] = {}

        # Scratch area under the models directory.
        self.cache_dir = registry.models_dir / ".cache"
        self.cache_dir.mkdir(exist_ok=True)

    def download_model(self,
                       model_id: str,
                       progress_callback: Optional[Callable[[DownloadProgress], None]] = None,
                       force: bool = False) -> bool:
        """
        Download a model, blocking until it completes.

        Args:
            model_id: Model ID to download
            progress_callback: Optional progress callback
            force: Force re-download even if exists

        Returns:
            True if download successful
        """
        model = self.registry.get_model(model_id)
        if not model:
            logger.error(f"Model not found: {model_id}")
            return False

        # Nothing to do when the model is already on disk (unless forced).
        if not force and model.status == ModelStatus.AVAILABLE:
            logger.info(f"Model already available: {model_id}")
            return True

        progress = DownloadProgress(
            model_id=model_id,
            status=DownloadStatus.PENDING,
            total_bytes=model.file_size
        )
        self.downloads[model_id] = progress
        self._stop_events[model_id] = threading.Event()

        future = self.executor.submit(
            self._download_model_task,
            model,
            progress,
            progress_callback,
            force
        )
        self.futures[model_id] = future

        # Synchronous API: wait for the worker and surface its result.
        try:
            return future.result()
        except Exception as e:
            logger.error(f"Download failed for {model_id}: {e}")
            return False

    def download_models_async(self,
                              model_ids: List[str],
                              progress_callback: Optional[Callable[[str, DownloadProgress], None]] = None,
                              force: bool = False) -> Dict[str, Future]:
        """
        Download multiple models asynchronously.

        Args:
            model_ids: List of model IDs
            progress_callback: Optional progress callback with model_id
            force: Force re-download

        Returns:
            Dictionary of futures, keyed by model_id (models that are
            unknown or already available are skipped)
        """
        futures = {}

        for model_id in model_ids:
            model = self.registry.get_model(model_id)
            if not model:
                logger.warning(f"Model not found: {model_id}")
                continue

            if not force and model.status == ModelStatus.AVAILABLE:
                continue

            progress = DownloadProgress(
                model_id=model_id,
                status=DownloadStatus.PENDING,
                total_bytes=model.file_size
            )
            self.downloads[model_id] = progress
            self._stop_events[model_id] = threading.Event()

            # BUGFIX: bind model_id eagerly via a default argument. A plain
            # closure late-binds the loop variable, so every wrapper would
            # report progress under the *last* model's id.
            def progress_wrapper(p, _model_id=model_id):
                if progress_callback:
                    progress_callback(_model_id, p)

            future = self.executor.submit(
                self._download_model_task,
                model,
                progress,
                progress_wrapper,
                force
            )
            futures[model_id] = future
            self.futures[model_id] = future

        return futures

    def _download_model_task(self,
                             model: ModelInfo,
                             progress: DownloadProgress,
                             progress_callback: Optional[Callable],
                             force: bool) -> bool:
        """
        Worker task: try each URL (primary, then mirrors) until one verifies.

        Args:
            model: Model information
            progress: Progress tracker
            progress_callback: Progress callback
            force: Force re-download (accepted for signature parity; the
                availability check already happened in the caller)

        Returns:
            True if successful
        """
        try:
            progress.status = DownloadStatus.DOWNLOADING
            self._notify_progress(progress, progress_callback)

            # Primary URL first, then mirrors in declared order.
            urls = [model.url] + model.mirror_urls

            for url in urls:
                # Honour cancellation between attempts.
                if self._stop_events[model.model_id].is_set():
                    progress.status = DownloadStatus.CANCELLED
                    self._notify_progress(progress, progress_callback)
                    return False

                try:
                    output_path = self.registry.models_dir / model.filename
                    success = self._download_file(
                        url,
                        output_path,
                        progress,
                        progress_callback,
                        model.model_id
                    )

                    if success:
                        progress.status = DownloadStatus.VERIFYING
                        self._notify_progress(progress, progress_callback)

                        if self._verify_download(output_path, model):
                            # Record the verified artifact in the registry.
                            model.status = ModelStatus.AVAILABLE
                            model.local_path = str(output_path)
                            model.download_date = time.time()
                            self.registry._save_registry()

                            progress.status = DownloadStatus.COMPLETED
                            self._notify_progress(progress, progress_callback)

                            logger.info(f"Successfully downloaded: {model.model_id}")
                            return True
                        else:
                            # Corrupt download: discard and try the next URL.
                            output_path.unlink(missing_ok=True)
                            logger.warning(f"Verification failed for {model.model_id}")

                except Exception as e:
                    logger.warning(f"Download failed from {url}: {e}")
                    continue

            # Every URL failed or failed verification.
            progress.status = DownloadStatus.FAILED
            progress.error = "All download attempts failed"
            self._notify_progress(progress, progress_callback)
            return False

        except Exception as e:
            progress.status = DownloadStatus.FAILED
            progress.error = str(e)
            self._notify_progress(progress, progress_callback)
            logger.error(f"Download task failed: {e}")
            return False

    def _download_file(self,
                       url: str,
                       output_path: Path,
                       progress: DownloadProgress,
                       progress_callback: Optional[Callable],
                       model_id: str) -> bool:
        """
        Download one file with HTTP Range resume support.

        Args:
            url: Download URL
            output_path: Output file path
            progress: Progress tracker
            progress_callback: Progress callback
            model_id: Model ID for stop event

        Returns:
            True if successful
        """
        temp_path = output_path.with_suffix('.part')

        # Resume from an existing partial file if present.
        resume_pos = temp_path.stat().st_size if temp_path.exists() else 0
        if resume_pos > 0:
            logger.info(f"Resuming download from {resume_pos} bytes")

        headers = {}
        if resume_pos > 0:
            headers['Range'] = f'bytes={resume_pos}-'

        start_time = time.time()
        bytes_downloaded = resume_pos

        try:
            response = requests.get(
                url,
                headers=headers,
                stream=True,
                timeout=self.timeout
            )
            response.raise_for_status()

            # BUGFIX: if we requested a byte range but the server replied
            # with the whole file (200 instead of 206 Partial Content),
            # appending would corrupt the result. Discard the partial file
            # and start from scratch instead.
            if resume_pos > 0 and response.status_code != 206:
                logger.info("Server ignored Range header; restarting download")
                temp_path.unlink(missing_ok=True)
                resume_pos = 0
                bytes_downloaded = 0

            # content-length covers only the remaining bytes when resuming.
            if 'content-length' in response.headers:
                total_size = int(response.headers['content-length']) + resume_pos
                progress.total_bytes = total_size
            else:
                total_size = None

            mode = 'ab' if resume_pos > 0 else 'wb'
            with open(temp_path, mode) as f:
                for chunk in response.iter_content(chunk_size=self.chunk_size):
                    # Honour cancellation between chunks; the .part file is
                    # kept so a later attempt can resume.
                    if self._stop_events[model_id].is_set():
                        logger.info(f"Download cancelled: {model_id}")
                        return False

                    if chunk:
                        f.write(chunk)
                        bytes_downloaded += len(chunk)
                        progress.current_bytes = bytes_downloaded

                        # Throughput / ETA based on this session's bytes only.
                        elapsed = time.time() - start_time
                        if elapsed > 0:
                            speed_bps = (bytes_downloaded - resume_pos) / elapsed
                            progress.speed_mbps = (speed_bps * 8) / 1_000_000

                            if total_size and speed_bps > 0:
                                remaining = total_size - bytes_downloaded
                                progress.eta_seconds = remaining / speed_bps

                        self._notify_progress(progress, progress_callback)

            # Publish the finished file under its final name.
            shutil.move(str(temp_path), str(output_path))
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Download error: {e}")
            return False
        except Exception as e:
            logger.error(f"File write error: {e}")
            return False

    def _verify_download(self, file_path: Path, model: ModelInfo) -> bool:
        """
        Verify a downloaded file by size and optional SHA256.

        Args:
            file_path: Downloaded file path
            model: Model information

        Returns:
            True if verification passed
        """
        if not file_path.exists():
            return False

        # Size check with a small tolerance (some hosts report approximate
        # sizes); skipped when the registry has no expected size.
        actual_size = file_path.stat().st_size
        if model.file_size > 0:
            size_diff = abs(actual_size - model.file_size)
            if size_diff > 1000:
                logger.warning(f"Size mismatch: expected {model.file_size}, got {actual_size}")
                return False

        # Strong integrity check when a hash is available.
        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
                if sha256 != model.sha256:
                    logger.warning(f"SHA256 mismatch for {model.model_id}")
                    return False
            except Exception as e:
                logger.error(f"SHA256 calculation failed: {e}")
                return False

        return True

    def _calculate_sha256(self, file_path: Path) -> str:
        """Calculate the SHA256 hex digest of a file, streamed in chunks."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def _notify_progress(self, progress: DownloadProgress, callback: Optional[Callable]):
        """Invoke the progress callback, shielding workers from callback bugs."""
        if callback:
            try:
                callback(progress)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")

    def cancel_download(self, model_id: str) -> bool:
        """
        Cancel an ongoing download.

        Args:
            model_id: Model ID to cancel

        Returns:
            True if cancelled
        """
        if model_id not in self._stop_events:
            return False

        self._stop_events[model_id].set()

        # Best-effort: give the worker a moment to observe the stop event.
        # (Narrowed from a bare ``except:``; TimeoutError from result() is
        # an Exception subclass, and BaseExceptions should propagate.)
        if model_id in self.futures:
            try:
                self.futures[model_id].result(timeout=5)
            except Exception:
                pass
            del self.futures[model_id]

        if model_id in self.downloads:
            self.downloads[model_id].status = DownloadStatus.CANCELLED

        logger.info(f"Download cancelled: {model_id}")
        return True

    def get_progress(self, model_id: str) -> Optional[DownloadProgress]:
        """Get download progress for a model, or None if never started."""
        return self.downloads.get(model_id)

    def get_all_progress(self) -> Dict[str, DownloadProgress]:
        """Get a shallow copy of all download progress records."""
        return self.downloads.copy()

    def cleanup_partial_downloads(self):
        """Delete leftover ``*.part`` files in the models directory."""
        for file in self.registry.models_dir.glob("*.part"):
            try:
                file.unlink()
                logger.info(f"Removed partial download: {file.name}")
            except Exception as e:
                logger.error(f"Failed to remove {file}: {e}")

    def download_required_models(self,
                                 task: Optional[str] = None,
                                 gpu_available: bool = True) -> bool:
        """
        Download all required models for a task.

        Args:
            task: Optional task filter
            gpu_available: GPU availability

        Returns:
            True if all downloads successful (vacuously True when nothing
            needs downloading)
        """
        required = []

        if task:
            # Deferred import to avoid a circular dependency with .registry.
            from .registry import ModelTask
            task_enum = ModelTask(task)
            # Simplified from ``gpu_available if gpu_available else False``,
            # which is a no-op for a bool.
            model = self.registry.get_best_model(
                task_enum,
                require_gpu=gpu_available
            )
            if model:
                required.append(model.model_id)
        else:
            # No task given: fetch the essential general-purpose models.
            essential = ['rmbg-1.4', 'u2netp', 'modnet']
            for model_id in essential:
                if self.registry.get_model(model_id):
                    required.append(model_id)

        if not required:
            return True

        logger.info(f"Downloading required models: {required}")
        futures = self.download_models_async(required)

        # Wait for every download; any failure fails the batch.
        success = True
        for model_id, future in futures.items():
            try:
                if not future.result():
                    success = False
            except Exception:
                success = False

        return success