File size: 18,001 Bytes
39ebfbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
"""
Model downloader for BackgroundFX Pro.
Handles downloading, caching, and verification of models.
"""

import os
import shutil
import tempfile
import hashlib
import requests
from pathlib import Path
from typing import Optional, Callable, Dict, Any, List
from dataclasses import dataclass
from enum import Enum
import time
import threading
from urllib.parse import urlparse
from concurrent.futures import ThreadPoolExecutor, Future
import logging

from .registry import ModelInfo, ModelStatus, ModelRegistry

logger = logging.getLogger(__name__)


class DownloadStatus(Enum):
    """Lifecycle states reported through ``DownloadProgress.status``."""
    # Initial state of a DownloadProgress created by download_model /
    # download_models_async, before the worker task starts.
    PENDING = "pending"
    # Set at the start of _download_model_task, while URLs are being tried.
    DOWNLOADING = "downloading"
    # Set once a file has been fetched and is about to be checked by
    # _verify_download.
    VERIFYING = "verifying"
    # NOTE(review): not assigned anywhere in the code visible in this file.
    EXTRACTING = "extracting"
    # Set after verification passed and the registry was saved.
    COMPLETED = "completed"
    # Set when every URL failed or the download task raised.
    FAILED = "failed"
    # Set when the model's stop event is observed, or by cancel_download.
    CANCELLED = "cancelled"


@dataclass
class DownloadProgress:
    """Mutable progress record for a single model download.

    One instance is created per download request, updated in place by the
    worker as chunks arrive, and handed to progress callbacks.
    """
    # Identifier of the model this record belongs to.
    model_id: str
    # Current lifecycle state.
    status: DownloadStatus
    # Bytes on disk so far; the download loop starts counting from the size
    # of any partial file it resumes.
    current_bytes: int = 0
    # Expected size; seeded from the registry and replaced by the size
    # derived from the Content-Length header when the server sends one.
    total_bytes: int = 0
    # Transfer rate in megabits per second (bytes/s * 8 / 1_000_000).
    speed_mbps: float = 0.0
    # Estimated seconds remaining (remaining bytes / bytes per second).
    eta_seconds: float = 0.0
    # Human-readable failure reason, filled in when the download fails.
    error: Optional[str] = None

    @property
    def progress(self) -> float:
        """Return completion as a percentage, or 0.0 when the total is unknown."""
        if not self.total_bytes > 0:
            return 0.0
        fraction = self.current_bytes / self.total_bytes
        return fraction * 100


class ModelDownloader:
    """Handle model downloading with progress tracking and resume support."""
    
    def __init__(self, 
                 registry: ModelRegistry,
                 max_workers: int = 3,
                 chunk_size: int = 8192,
                 timeout: int = 30,
                 max_retries: int = 3):
        """
        Initialize model downloader.
        
        Args:
            registry: Model registry instance
            max_workers: Maximum concurrent downloads
            chunk_size: Download chunk size in bytes
            timeout: Request timeout in seconds
            max_retries: Maximum retry attempts
        """
        self.registry = registry
        self.max_workers = max_workers
        self.chunk_size = chunk_size
        self.timeout = timeout
        self.max_retries = max_retries
        
        # Download management
        self.downloads: Dict[str, DownloadProgress] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures: Dict[str, Future] = {}
        self._stop_events: Dict[str, threading.Event] = {}
        
        # Cache directory
        self.cache_dir = registry.models_dir / ".cache"
        self.cache_dir.mkdir(exist_ok=True)
    
    def download_model(self,
                       model_id: str,
                       progress_callback: Optional[Callable[[DownloadProgress], None]] = None,
                       force: bool = False) -> bool:
        """
        Download one model and block until the worker finishes.

        The work runs on the shared thread pool; this method only waits for
        the resulting future.

        Args:
            model_id: Model ID to download
            progress_callback: Optional callable receiving the
                DownloadProgress record on every update
            force: Download even if the registry already marks the model
                as available

        Returns:
            True if the model is available afterwards, False if it is
            unknown or the download failed
        """
        model = self.registry.get_model(model_id)
        if not model:
            logger.error(f"Model not found: {model_id}")
            return False

        # Nothing to do when the model is already present, unless forced.
        if not force:
            if model.status == ModelStatus.AVAILABLE:
                logger.info(f"Model already available: {model_id}")
                return True

        # Register the bookkeeping entries before the worker starts so that
        # get_progress() and cancel_download() can see this download.
        tracker = DownloadProgress(
            model_id=model_id,
            status=DownloadStatus.PENDING,
            total_bytes=model.file_size
        )
        self.downloads[model_id] = tracker
        self._stop_events[model_id] = threading.Event()

        task = self.executor.submit(
            self._download_model_task,
            model,
            tracker,
            progress_callback,
            force
        )
        self.futures[model_id] = task

        # Block until the worker is done; any exception counts as failure.
        try:
            outcome = task.result()
        except Exception as e:
            logger.error(f"Download failed for {model_id}: {e}")
            return False
        return outcome
    
    def download_models_async(self,
                             model_ids: List[str],
                             progress_callback: Optional[Callable[[str, DownloadProgress], None]] = None,
                             force: bool = False) -> Dict[str, Future]:
        """
        Download multiple models asynchronously.

        Each model is submitted to the shared thread pool and the method
        returns immediately. Unknown model IDs are skipped with a warning;
        models already available are skipped silently unless ``force`` is
        set.

        Args:
            model_ids: List of model IDs
            progress_callback: Optional callable invoked as
                ``progress_callback(model_id, progress)``
            force: Force re-download

        Returns:
            Dictionary mapping each submitted model ID to its future.
            Skipped models have no entry.
        """
        futures = {}

        for model_id in model_ids:
            model = self.registry.get_model(model_id)
            if not model:
                logger.warning(f"Model not found: {model_id}")
                continue

            # Skip if already available
            if not force and model.status == ModelStatus.AVAILABLE:
                continue

            # Initialize progress
            progress = DownloadProgress(
                model_id=model_id,
                status=DownloadStatus.PENDING,
                total_bytes=model.file_size
            )
            self.downloads[model_id] = progress

            # Create stop event
            self._stop_events[model_id] = threading.Event()

            # Bind the current model_id as a default argument. A plain
            # closure would read the loop variable when the callback fires,
            # i.e. after the loop has moved on, and report the wrong model.
            def progress_wrapper(p, _model_id=model_id):
                if progress_callback:
                    progress_callback(_model_id, p)

            # Submit download task
            future = self.executor.submit(
                self._download_model_task,
                model,
                progress,
                progress_wrapper,
                force
            )
            futures[model_id] = future
            self.futures[model_id] = future

        return futures
    
    def _download_model_task(self,
                            model: ModelInfo,
                            progress: DownloadProgress,
                            progress_callback: Optional[Callable],
                            force: bool) -> bool:
        """
        Worker body: fetch a model from its primary URL or a mirror.

        URLs are tried in order. The first file that downloads and passes
        verification is recorded in the registry. Status changes are
        written to ``progress`` and forwarded to ``progress_callback``.

        Args:
            model: Model information
            progress: Progress tracker, updated in place
            progress_callback: Progress callback
            force: Force re-download (accepted for interface symmetry; not
                read by this method)

        Returns:
            True if successful, False on failure or cancellation
        """
        try:
            # Update status
            progress.status = DownloadStatus.DOWNLOADING
            self._notify_progress(progress, progress_callback)

            # Try primary URL first, then mirrors. A missing or None mirror
            # list is treated as "no mirrors" rather than aborting the task.
            urls = [model.url] + list(model.mirror_urls or [])

            for url in urls:
                if self._stop_events[model.model_id].is_set():
                    progress.status = DownloadStatus.CANCELLED
                    self._notify_progress(progress, progress_callback)
                    return False

                try:
                    # Download file
                    output_path = self.registry.models_dir / model.filename
                    success = self._download_file(
                        url,
                        output_path,
                        progress,
                        progress_callback,
                        model.model_id
                    )

                    if success:
                        # Verify file
                        progress.status = DownloadStatus.VERIFYING
                        self._notify_progress(progress, progress_callback)

                        if self._verify_download(output_path, model):
                            # Update registry
                            model.status = ModelStatus.AVAILABLE
                            model.local_path = str(output_path)
                            model.download_date = time.time()
                            self.registry._save_registry()

                            progress.status = DownloadStatus.COMPLETED
                            self._notify_progress(progress, progress_callback)

                            logger.info(f"Successfully downloaded: {model.model_id}")
                            return True
                        else:
                            # Verification failed: drop the bad file and
                            # fall through to the next URL.
                            output_path.unlink(missing_ok=True)
                            logger.warning(f"Verification failed for {model.model_id}")

                except Exception as e:
                    logger.warning(f"Download failed from {url}: {e}")
                    continue

            # The loop only checks the stop event before each attempt, so a
            # cancel during the last URL would otherwise be reported as
            # FAILED instead of CANCELLED.
            if self._stop_events[model.model_id].is_set():
                progress.status = DownloadStatus.CANCELLED
                self._notify_progress(progress, progress_callback)
                return False

            # All attempts failed
            progress.status = DownloadStatus.FAILED
            progress.error = "All download attempts failed"
            self._notify_progress(progress, progress_callback)
            return False

        except Exception as e:
            progress.status = DownloadStatus.FAILED
            progress.error = str(e)
            self._notify_progress(progress, progress_callback)
            logger.error(f"Download task failed: {e}")
            return False
    
    def _download_file(self,
                      url: str,
                      output_path: Path,
                      progress: DownloadProgress,
                      progress_callback: Optional[Callable],
                      model_id: str) -> bool:
        """
        Download file with resume support.

        Data is streamed into a ``.part`` file next to ``output_path`` and
        moved into place once the stream ends. If a ``.part`` file already
        exists, a ``Range`` request is sent to continue from its size. The
        partial data is kept only when the server answers 206; any other
        successful status restarts the file from scratch.

        Args:
            url: Download URL
            output_path: Output file path
            progress: Progress tracker (bytes, speed and ETA are updated
                in place)
            progress_callback: Progress callback
            model_id: Model ID for stop event

        Returns:
            True if successful, False on cancellation, HTTP/network error
            or file error
        """
        # NOTE(review): with_suffix() replaces the last extension, so files
        # that differ only by extension share one .part file. Kept as is so
        # existing partial downloads remain resumable.
        temp_path = output_path.with_suffix('.part')
        resume_pos = 0

        if temp_path.exists():
            resume_pos = temp_path.stat().st_size
            logger.info(f"Resuming download from {resume_pos} bytes")

        # Start download
        start_time = time.time()

        try:
            # Prepare headers for resume
            headers = {}
            if resume_pos > 0:
                headers['Range'] = f'bytes={resume_pos}-'

            response = requests.get(
                url,
                headers=headers,
                stream=True,
                timeout=self.timeout
            )

            # 416 means the requested offset is not valid for this resource
            # (stale or oversized partial file). Discard it and start over,
            # otherwise every later attempt would hit the same error.
            if resume_pos > 0 and response.status_code == 416:
                response.close()
                logger.info("Partial file not resumable; restarting download from the beginning")
                temp_path.unlink(missing_ok=True)
                resume_pos = 0
                response = requests.get(
                    url,
                    stream=True,
                    timeout=self.timeout
                )

            # The context manager releases the connection on every exit
            # path, including cancellation and errors.
            with response:
                response.raise_for_status()

                # Only 206 means the body starts at resume_pos. A 200 reply
                # carries the whole file, so appending it would corrupt the
                # result.
                if resume_pos > 0 and response.status_code != 206:
                    logger.info("Server ignored Range request; restarting download from the beginning")
                    resume_pos = 0

                bytes_downloaded = resume_pos

                # Get total size (Content-Length covers only what is sent
                # in this response, hence the added offset).
                if 'content-length' in response.headers:
                    total_size = int(response.headers['content-length']) + resume_pos
                    progress.total_bytes = total_size
                else:
                    total_size = None

                # Download with progress
                mode = 'ab' if resume_pos > 0 else 'wb'
                with open(temp_path, mode) as f:
                    for chunk in response.iter_content(chunk_size=self.chunk_size):
                        # Check for cancellation; the partial file is left
                        # in place so a later call can resume.
                        if self._stop_events[model_id].is_set():
                            logger.info(f"Download cancelled: {model_id}")
                            return False

                        if chunk:
                            f.write(chunk)
                            bytes_downloaded += len(chunk)

                            # Update progress
                            progress.current_bytes = bytes_downloaded

                            # Speed and ETA count only the bytes fetched in
                            # this session, not the resumed prefix.
                            elapsed = time.time() - start_time
                            if elapsed > 0:
                                speed_bps = (bytes_downloaded - resume_pos) / elapsed
                                progress.speed_mbps = (speed_bps * 8) / 1_000_000

                                if total_size and speed_bps > 0:
                                    remaining = total_size - bytes_downloaded
                                    progress.eta_seconds = remaining / speed_bps

                            self._notify_progress(progress, progress_callback)

            # Move to final location
            shutil.move(str(temp_path), str(output_path))
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Download error: {e}")
            return False
        except Exception as e:
            logger.error(f"File write error: {e}")
            return False
    
    def _verify_download(self, file_path: Path, model: ModelInfo) -> bool:
        """
        Verify downloaded file.
        
        Args:
            file_path: Downloaded file path
            model: Model information
            
        Returns:
            True if verification passed
        """
        # Check file exists
        if not file_path.exists():
            return False
        
        # Check file size
        actual_size = file_path.stat().st_size
        if model.file_size > 0:
            size_diff = abs(actual_size - model.file_size)
            if size_diff > 1000:  # Allow 1KB difference
                logger.warning(f"Size mismatch: expected {model.file_size}, got {actual_size}")
                return False
        
        # Check SHA256 if available
        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
                if sha256 != model.sha256:
                    logger.warning(f"SHA256 mismatch for {model.model_id}")
                    return False
            except Exception as e:
                logger.error(f"SHA256 calculation failed: {e}")
                return False
        
        return True
    
    def _calculate_sha256(self, file_path: Path) -> str:
        """Calculate SHA256 hash of file."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
    
    def _notify_progress(self, progress: DownloadProgress, callback: Optional[Callable]):
        """Notify progress callback."""
        if callback:
            try:
                callback(progress)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")
    
    def cancel_download(self, model_id: str) -> bool:
        """
        Cancel ongoing download.
        
        Args:
            model_id: Model ID to cancel
            
        Returns:
            True if cancelled
        """
        if model_id in self._stop_events:
            self._stop_events[model_id].set()
            
            # Wait for cancellation
            if model_id in self.futures:
                try:
                    self.futures[model_id].result(timeout=5)
                except:
                    pass
                del self.futures[model_id]
            
            # Update progress
            if model_id in self.downloads:
                self.downloads[model_id].status = DownloadStatus.CANCELLED
            
            logger.info(f"Download cancelled: {model_id}")
            return True
        
        return False
    
    def get_progress(self, model_id: str) -> Optional[DownloadProgress]:
        """Get download progress for model."""
        return self.downloads.get(model_id)
    
    def get_all_progress(self) -> Dict[str, DownloadProgress]:
        """Get all download progress."""
        return self.downloads.copy()
    
    def cleanup_partial_downloads(self):
        """Clean up partial download files."""
        for file in self.registry.models_dir.glob("*.part"):
            try:
                file.unlink()
                logger.info(f"Removed partial download: {file.name}")
            except Exception as e:
                logger.error(f"Failed to remove {file}: {e}")
    
    def download_required_models(self,
                                task: str = None,
                                gpu_available: bool = True) -> bool:
        """
        Download all required models for a task.
        
        Args:
            task: Optional task filter
            gpu_available: GPU availability
            
        Returns:
            True if all downloads successful
        """
        # Get required models
        required = []
        
        if task:
            # Get best model for task
            from .registry import ModelTask
            task_enum = ModelTask(task)
            model = self.registry.get_best_model(
                task_enum,
                require_gpu=gpu_available if gpu_available else False
            )
            if model:
                required.append(model.model_id)
        else:
            # Get all essential models
            essential = ['rmbg-1.4', 'u2netp', 'modnet']
            for model_id in essential:
                if self.registry.get_model(model_id):
                    required.append(model_id)
        
        # Download models
        if required:
            logger.info(f"Downloading required models: {required}")
            futures = self.download_models_async(required)
            
            # Wait for completion
            success = True
            for model_id, future in futures.items():
                try:
                    if not future.result():
                        success = False
                except Exception:
                    success = False
            
            return success
        
        return True