ContentAgent / endpoint_utils.py
yetessam's picture
Update endpoint_utils.py
8aa8219 verified
raw
history blame
3.3 kB
# endpoint_utils.py
from __future__ import annotations
from typing import Optional, Tuple, Callable, Dict, Any
from urllib.parse import urlparse
import os, time, requests
def hf_headers():
    """Build HTTP auth headers from the ``HF_TOKEN`` environment variable.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` when ``HF_TOKEN`` is set to a
        non-blank value (surrounding whitespace is stripped), otherwise an
        empty dict.
    """
    token = os.environ.get("HF_TOKEN", "").strip()
    if not token:
        return {}
    return {"Authorization": f"Bearer {token}"}
def _valid_uri(uri: Optional[str]) -> bool:
    """Return True when *uri* is an absolute ``http``/``https`` URL with a host.

    This is a pure validator: it never raises. Missing, empty, non-string or
    unparseable values all yield False.

    Args:
        uri: Candidate endpoint URL, or None.

    Returns:
        True if the scheme is ``http`` or ``https`` and a network location
        (host[:port]) is present; False otherwise.
    """
    if not uri or not isinstance(uri, str):
        return False
    try:
        parts = urlparse(uri)
    except ValueError:
        # urlparse rejects malformed authorities (e.g. an unbalanced "[" in
        # an IPv6 literal) by raising; treat that as "not a valid URI".
        return False
    return parts.scheme in {"http", "https"} and bool(parts.netloc)
def _detail(resp: requests.Response) -> str:
    """Pull a short, human-readable message out of an HTTP response.

    The JSON body's ``"error"`` field is preferred, then ``"message"``. If the
    body cannot be read that way (not JSON, not an object, non-string field),
    the raw response text is used instead.

    Args:
        resp: The response to inspect.

    Returns:
        The whitespace-stripped message; may be an empty string.
    """
    try:
        body = resp.json()
        for key in ("error", "message"):
            value = body.get(key)
            if value:
                return value.strip()
        return ""
    except Exception:
        # Deliberate catch-all: any failure above falls back to the raw text.
        raw = resp.text
        return raw.strip() if raw else ""
def wake_endpoint(
    uri: Optional[str],
    *,
    token: Optional[str] = None,
    max_wait: int = 600,  # seconds; 10 minutes (previously 180)
    poll_every: float = 5.0,
    warm_payload: Optional[Dict[str, Any]] = None,
    log: Callable[[str], None] = lambda _: None,
) -> Tuple[bool, Optional[str]]:
    """
    Wake a scale-to-zero HF Inference Endpoint by nudging it, then polling until ready.

    Sequence: GET ``<uri>/health``; if that is not OK, POST one warm-up
    request, then POST ``{"inputs": "ping"}`` every ``poll_every`` seconds
    until a 2xx/3xx answer arrives or ``max_wait`` seconds have elapsed.

    Args:
        uri: Endpoint URL; must be an absolute http(s) URL.
        token: Bearer token to authenticate with. When omitted or blank, the
            ``HF_TOKEN`` environment variable is used instead.
        max_wait: Maximum time to spend polling, in seconds.
        poll_every: Delay between polls, in seconds.
        warm_payload: JSON body for the single warm-up POST; defaults to
            ``{"inputs": "wake"}``. Polling always sends ``{"inputs": "ping"}``.
        log: Callback that receives progress messages; defaults to a no-op.

    Returns:
        ``(True, None)`` if ready; otherwise ``(False, "<last status/message>")``.
        A 401/403 from either the health probe or a poll aborts immediately.
    """
    if not _valid_uri(uri):
        return False, "invalid or missing URI (expect http(s)://...)"

    # An explicitly passed token takes precedence over the HF_TOKEN env var.
    explicit_token = (token or "").strip()
    if explicit_token:
        headers = {"Authorization": f"Bearer {explicit_token}"}
    else:
        headers = hf_headers()
    if not headers:
        log("⚠️ HF_TOKEN not set — POST / will likely return 401/403.")

    last = "no response"

    # /health probe (auth included if required)
    try:
        hr = requests.get(f"{uri.rstrip('/')}/health", headers=headers, timeout=5)
    except requests.RequestException as e:
        last = type(e).__name__
        log(f"[health] {last}")
    else:
        if hr.ok:
            log("✅ /health reports ready.")
            return True, None
        last = f"HTTP {hr.status_code} – {_detail(hr) or 'warming?'}"
        log(f"[health] {last}")
        if hr.status_code in (401, 403):
            return False, f"Unauthorized (check HF_TOKEN). {last}"

    # warmup nudge: best effort, its outcome does not matter, but a failure
    # is still reported through the log callback instead of being dropped.
    payload = warm_payload if warm_payload is not None else {"inputs": "wake"}
    try:
        requests.post(uri, headers=headers, json=payload, timeout=5)
    except requests.RequestException as e:
        log(f"[warmup] {type(e).__name__} (ignored)")

    # poll; a monotonic clock keeps the deadline immune to wall-clock jumps
    interval = f"{poll_every:g}"  # keeps sub-second intervals readable
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            r = requests.post(uri, headers=headers, json={"inputs": "ping"}, timeout=8)
        except requests.RequestException as e:
            last = type(e).__name__
            log(f"[client] {last}; retrying in {interval}s…")
        else:
            if r.ok:
                log("✅ Endpoint is awake and responsive.")
                return True, None
            d = _detail(r)
            last = f"HTTP {r.status_code}" + (f" – {d}" if d else "")
            if r.status_code in (401, 403):
                return False, f"Unauthorized (check HF_TOKEN, org access). {last}"
            if r.status_code in (429, 503, 504):
                fallback = "warming up"
            else:
                fallback = "unexpected response"
            log(f"[server] {d or fallback} (HTTP {r.status_code}); retrying in {interval}s…")

        # Never sleep past the deadline: no further poll would follow anyway.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_every, remaining))

    return False, f"Timed out after {max_wait}s — last status: {last}"