# endpoint_utils.py
"""Helpers for waking a scale-to-zero Hugging Face Inference Endpoint."""
from __future__ import annotations

import os
import time
from typing import Optional, Tuple, Callable, Dict, Any
from urllib.parse import urlparse

import requests


def hf_headers(token: Optional[str] = None) -> Dict[str, str]:
    """Build the HTTP Authorization header for the endpoint.

    Args:
        token: Explicit bearer token. When empty or ``None`` the ``HF_TOKEN``
            environment variable is used instead.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` if a non-blank token is
        available, otherwise an empty dict.
    """
    tok = (token or "").strip() or os.environ.get("HF_TOKEN", "").strip()
    return {"Authorization": f"Bearer {tok}"} if tok else {}


def _valid_uri(uri: Optional[str]) -> bool:
    """Return True if *uri* is an http(s) URL with a network location."""
    if not uri:
        return False
    p = urlparse(uri)
    return p.scheme in {"http", "https"} and bool(p.netloc)


def _detail(resp: requests.Response) -> str:
    """Extract a short human-readable error detail from a response.

    Prefers the ``error`` (then ``message``) field of a JSON object body.
    Falls back to the stripped raw body text when the body is not JSON, is
    not a JSON object, or the field is not a string.
    """
    try:
        j = resp.json()
    except (ValueError, requests.RequestException):
        # Body is not valid JSON; use the raw text instead.
        return (resp.text or "").strip()
    if isinstance(j, dict):
        msg = j.get("error") or j.get("message") or ""
        if isinstance(msg, str):
            return msg.strip()
    # Non-object JSON or a structured (non-string) error value.
    return (resp.text or "").strip()


def wake_endpoint(
    uri: Optional[str],
    *,
    token: Optional[str] = None,
    max_wait: int = 600,  # seconds (10 minutes)
    poll_every: float = 5.0,
    warm_payload: Optional[Dict[str, Any]] = None,
    log: Callable[[str], None] = lambda _: None,
) -> Tuple[bool, Optional[str]]:
    """
    Wake a scale-to-zero HF Inference Endpoint by nudging it, then polling until ready.

    The sequence is: GET ``<uri>/health``; if that does not report ready,
    POST one warm-up request, then POST ``{"inputs": "ping"}`` repeatedly
    until a 2xx/3xx response arrives or ``max_wait`` seconds have elapsed.

    Args:
        uri: Endpoint URL; must be ``http://`` or ``https://`` with a host.
        token: Bearer token to authenticate with. Falls back to the
            ``HF_TOKEN`` environment variable when not given.
        max_wait: Maximum number of seconds to spend polling.
        poll_every: Seconds to wait between poll attempts.
        warm_payload: JSON body for the initial warm-up POST; defaults to
            ``{"inputs": "wake"}``.
        log: Callback receiving progress messages; defaults to a no-op.

    Returns:
        ``(True, None)`` if the endpoint is ready; otherwise
        ``(False, reason)`` where ``reason`` describes the failure
        (invalid URI, unauthorized, or timeout with the last seen status).
    """
    if not _valid_uri(uri):
        return False, "invalid or missing URI (expect http(s)://...)"

    headers = hf_headers(token)
    if not headers:
        log("⚠️ HF_TOKEN not set — POST / will likely return 401/403.")

    last = "no response"

    # /health probe (auth included if required)
    try:
        hr = requests.get(f"{uri.rstrip('/')}/health", headers=headers, timeout=5)
        if hr.ok:
            log("✅ /health reports ready.")
            return True, None
        last = f"HTTP {hr.status_code} – {_detail(hr) or 'warming?'}"
        log(f"[health] {last}")
        if hr.status_code in (401, 403):
            # Credentials are wrong; polling cannot fix that.
            return False, f"Unauthorized (check HF_TOKEN). {last}"
    except requests.RequestException as e:
        last = type(e).__name__
        log(f"[health] {last}")

    # warmup nudge: best-effort, the response (if any) is ignored
    payload = warm_payload if warm_payload is not None else {"inputs": "wake"}
    try:
        requests.post(uri, headers=headers, json=payload, timeout=5)
    except requests.RequestException as e:
        # A failure here (e.g. a timeout) is tolerated; polling decides.
        log(f"[nudge] {type(e).__name__}")

    # poll; monotonic clock so wall-clock adjustments cannot skew the deadline
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            r = requests.post(uri, headers=headers, json={"inputs": "ping"}, timeout=8)
        except requests.RequestException as e:
            last = type(e).__name__
            log(f"[client] {last}; retrying in {poll_every:g}s…")
        else:
            if r.ok:
                log("✅ Endpoint is awake and responsive.")
                return True, None
            d = _detail(r)
            last = f"HTTP {r.status_code}" + (f" – {d}" if d else "")
            if r.status_code in (401, 403):
                return False, f"Unauthorized (check HF_TOKEN, org access). {last}"
            hint = "warming up" if r.status_code in (429, 503, 504) else "unexpected response"
            log(f"[server] {d or hint} (HTTP {r.status_code}); retrying in {poll_every:g}s…")

        # Never sleep past the deadline.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(max(0.0, min(poll_every, remaining)))

    return False, f"Timed out after {max_wait}s — last status: {last}"