Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,13 @@ from pathlib import Path
|
|
| 10 |
from typing import Optional, Dict, Any
|
| 11 |
import uuid, shutil, cv2, json, time, urllib.parse, sys
|
| 12 |
import threading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import subprocess
|
|
|
|
| 14 |
import shutil as _shutil
|
| 15 |
# --- POINTEUR DE BACKEND (lit l'URL actuelle depuis une source externe) ------
|
| 16 |
import os
|
|
@@ -72,8 +78,103 @@ async def proxy_all(full_path: str, request: Request):
|
|
| 72 |
# -------------------------------------------------------------------------------
|
| 73 |
# Global progress dict (vid_stem -> {percent, logs, done})
|
| 74 |
progress_data: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# ---------- Helpers ----------
|
|
|
|
| 77 |
def _is_video(p: Path) -> bool:
|
| 78 |
return p.suffix.lower() in {".mp4", ".mov", ".mkv", ".webm"}
|
| 79 |
|
|
@@ -361,22 +462,36 @@ def window(vid: str, center: int = 0, count: int = 21):
|
|
| 361 |
center = max(0, min(int(center), max(0, frames-1)))
|
| 362 |
if frames <= 0:
|
| 363 |
print(f"[WINDOW] frames=0 for {vid}", file=sys.stdout)
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
# ----- Masques -----
|
|
|
|
| 380 |
@app.post("/mask", tags=["mask"])
|
| 381 |
async def save_mask(payload: Dict[str, Any] = Body(...)):
|
| 382 |
vid = payload.get("vid")
|
|
|
|
| 10 |
from typing import Optional, Dict, Any
|
| 11 |
import uuid, shutil, cv2, json, time, urllib.parse, sys
|
| 12 |
import threading
|
| 13 |
+
# <<<< HUGINFACE PATCH: WARMUP IMPORTS START >>>>
|
| 14 |
+
from typing import List
|
| 15 |
+
from huggingface_hub import snapshot_download
|
| 16 |
+
import traceback
|
| 17 |
+
# <<<< HUGINFACE PATCH: WARMUP IMPORTS END >>>>
|
| 18 |
import subprocess
|
| 19 |
+
|
| 20 |
import shutil as _shutil
|
| 21 |
# --- POINTEUR DE BACKEND (lit l'URL actuelle depuis une source externe) ------
|
| 22 |
import os
|
|
|
|
| 78 |
# -------------------------------------------------------------------------------
|
| 79 |
# Global progress dict (vid_stem -> {percent, logs, done})
|
| 80 |
progress_data: Dict[str, Dict[str, Any]] = {}
|
| 81 |
+
# <<<< HUGINFACE PATCH: WARMUP STATE+HELPERS START >>>>
|
| 82 |
+
# État global du warm-up (progression, logs, etc.)
|
| 83 |
+
warmup_state: Dict[str, Any] = {
|
| 84 |
+
"running": False,
|
| 85 |
+
"percent": 0,
|
| 86 |
+
"logs": [],
|
| 87 |
+
"done": False,
|
| 88 |
+
"current": None,
|
| 89 |
+
"total": 0,
|
| 90 |
+
"idx": 0,
|
| 91 |
+
"job_id": "default",
|
| 92 |
+
}
|
| 93 |
+
warmup_lock = threading.Lock()
|
| 94 |
+
warmup_stop = threading.Event()
|
| 95 |
+
|
| 96 |
+
def _default_model_list() -> List[str]:
|
| 97 |
+
"""
|
| 98 |
+
Liste par défaut lue depuis l'env WARMUP_MODELS (JSON array) sinon vide.
|
| 99 |
+
Exemple env:
|
| 100 |
+
WARMUP_MODELS=["runwayml/stable-diffusion-v1-5","facebook/sam2-hiera-base"]
|
| 101 |
+
"""
|
| 102 |
+
env = (os.getenv("WARMUP_MODELS") or "").strip()
|
| 103 |
+
if env:
|
| 104 |
+
try:
|
| 105 |
+
lst = json.loads(env)
|
| 106 |
+
if isinstance(lst, list):
|
| 107 |
+
return [str(x).strip() for x in lst if str(x).strip()]
|
| 108 |
+
except Exception:
|
| 109 |
+
pass
|
| 110 |
+
return []
|
| 111 |
|
| 112 |
+
def _log_warmup(msg: str):
|
| 113 |
+
print(f"[WARMUP] {msg}", file=sys.stdout)
|
| 114 |
+
with warmup_lock:
|
| 115 |
+
warmup_state["logs"].append(msg)
|
| 116 |
+
if len(warmup_state["logs"]) > 400:
|
| 117 |
+
warmup_state["logs"] = warmup_state["logs"][-400:]
|
| 118 |
+
|
| 119 |
+
def _download_one(repo_id: str, tries: int = 3) -> bool:
|
| 120 |
+
"""
|
| 121 |
+
Télécharge un repo HF en réessayant si besoin.
|
| 122 |
+
"""
|
| 123 |
+
cache_home = os.path.expanduser(os.getenv("HF_HOME", "/home/user/.cache/huggingface"))
|
| 124 |
+
local_dir = os.path.join(cache_home, "models", repo_id.replace("/", "__"))
|
| 125 |
+
for attempt in range(1, tries + 1):
|
| 126 |
+
if warmup_stop.is_set():
|
| 127 |
+
return False
|
| 128 |
+
try:
|
| 129 |
+
snapshot_download(
|
| 130 |
+
repo_id,
|
| 131 |
+
local_dir=local_dir,
|
| 132 |
+
local_dir_use_symlinks=False,
|
| 133 |
+
resume_download=True,
|
| 134 |
+
)
|
| 135 |
+
return True
|
| 136 |
+
except Exception as e:
|
| 137 |
+
_log_warmup(f"{repo_id} -> tentative {attempt}/{tries} échouée: {e}")
|
| 138 |
+
time.sleep(min(10, 2 * attempt))
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
def _warmup_thread(models: List[str]):
|
| 142 |
+
with warmup_lock:
|
| 143 |
+
warmup_state.update({
|
| 144 |
+
"running": True,
|
| 145 |
+
"percent": 0,
|
| 146 |
+
"done": False,
|
| 147 |
+
"logs": [],
|
| 148 |
+
"current": None,
|
| 149 |
+
"total": len(models),
|
| 150 |
+
"idx": 0,
|
| 151 |
+
})
|
| 152 |
+
warmup_stop.clear()
|
| 153 |
+
ok_count = 0
|
| 154 |
+
for i, repo in enumerate(models):
|
| 155 |
+
if warmup_stop.is_set():
|
| 156 |
+
_log_warmup("Arrêt demandé par l’utilisateur.")
|
| 157 |
+
break
|
| 158 |
+
with warmup_lock:
|
| 159 |
+
warmup_state["idx"] = i
|
| 160 |
+
warmup_state["current"] = repo
|
| 161 |
+
warmup_state["percent"] = int((i / max(1, len(models))) * 100)
|
| 162 |
+
_log_warmup(f"Téléchargement: {repo}")
|
| 163 |
+
ok = _download_one(repo)
|
| 164 |
+
if ok:
|
| 165 |
+
ok_count += 1
|
| 166 |
+
_log_warmup(f"OK: {repo}")
|
| 167 |
+
else:
|
| 168 |
+
_log_warmup(f"ÉCHEC: {repo}")
|
| 169 |
+
|
| 170 |
+
with warmup_lock:
|
| 171 |
+
warmup_state["percent"] = 100
|
| 172 |
+
warmup_state["done"] = True
|
| 173 |
+
warmup_state["running"] = False
|
| 174 |
+
_log_warmup(f"Terminé: {ok_count}/{len(models)} modèles.")
|
| 175 |
+
# <<<< HUGINFACE PATCH: WARMUP STATE+HELPERS END >>>>
|
| 176 |
# ---------- Helpers ----------
|
| 177 |
+
|
| 178 |
def _is_video(p: Path) -> bool:
|
| 179 |
return p.suffix.lower() in {".mp4", ".mov", ".mkv", ".webm"}
|
| 180 |
|
|
|
|
| 462 |
center = max(0, min(int(center), max(0, frames-1)))
|
| 463 |
if frames <= 0:
|
| 464 |
print(f"[WINDOW] frames=0 for {vid}", file=sys.stdout)
|
| 465 |
+
return {"vid": vid, "start": start, "count": n, "selected": sel, "items": items, "frames": frames}
|
| 466 |
+
# <<<< HUGINFACE PATCH: WARMUP ROUTES START >>>>
|
| 467 |
+
@app.post("/warmup/start", tags=["warmup"])
|
| 468 |
+
async def warmup_start(payload: Optional[Dict[str, Any]] = Body(None)):
|
| 469 |
+
"""
|
| 470 |
+
Lance un téléchargement séquentiel d'une liste de modèles HF.
|
| 471 |
+
Body JSON: {"models": ["org/modelA","org/modelB", ...]} (optionnel)
|
| 472 |
+
À défaut, lit WARMUP_MODELS (env).
|
| 473 |
+
"""
|
| 474 |
+
models = (payload or {}).get("models") or _default_model_list()
|
| 475 |
+
if not isinstance(models, list) or not models:
|
| 476 |
+
raise HTTPException(400, "Liste 'models' vide. Fournir JSON {models:[\"org/model\"]} ou variable d'environnement WARMUP_MODELS.")
|
| 477 |
+
if warmup_state.get("running"):
|
| 478 |
+
return {"started": False, "already_running": True, "state": warmup_state}
|
| 479 |
+
t = threading.Thread(target=_warmup_thread, args=(models,), daemon=True)
|
| 480 |
+
t.start()
|
| 481 |
+
return {"started": True, "count": len(models)}
|
| 482 |
+
|
| 483 |
+
@app.get("/warmup/status", tags=["warmup"])
|
| 484 |
+
def warmup_status():
|
| 485 |
+
with warmup_lock:
|
| 486 |
+
return dict(warmup_state)
|
| 487 |
|
| 488 |
+
@app.post("/warmup/stop", tags=["warmup"])
|
| 489 |
+
def warmup_stop_api():
|
| 490 |
+
warmup_stop.set()
|
| 491 |
+
return {"stopping": True}
|
| 492 |
+
# <<<< HUGINFACE PATCH: WARMUP ROUTES END >>>>
|
| 493 |
# ----- Masques -----
|
| 494 |
+
|
| 495 |
@app.post("/mask", tags=["mask"])
|
| 496 |
async def save_mask(payload: Dict[str, Any] = Body(...)):
|
| 497 |
vid = payload.get("vid")
|