from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import json
import os
import time
import uuid

import faiss
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# Optional engines: Spark for distributed chunking, CrossEncoder for reranking.
try:
    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import StringType
    SPARK_AVAILABLE = True
except Exception:
    SPARK_AVAILABLE = False

try:
    from sentence_transformers import CrossEncoder
    RERANK_AVAILABLE = True
except Exception:
    RERANK_AVAILABLE = False
APP_VERSION = "1.0.0"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATA_DIR = Path("./data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_FP = DATA_DIR / "index.faiss"
META_FP = DATA_DIR / "meta.jsonl"
PARQ_FP = DATA_DIR / "meta.parquet"
CFG_FP = DATA_DIR / "store.json"
# --------- Schemas ----------
class EchoRequest(BaseModel):
    message: str
class HealthResponse(BaseModel):
    status: str
    version: str
    index_size: int = 0
    model: str = ""
    spark: bool = False
    persisted: bool = False
    rerank: bool = False
    index_type: str = "flat"
class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., min_length=1)  # Pydantic v2 name (v1's min_items is deprecated)
    preview_n: int = Field(default=6, ge=0, le=32)
    normalize: bool = True
class EmbedResponse(BaseModel):
    dim: int
    count: int
    preview: List[List[float]]
class Doc(BaseModel):
    id: Optional[str] = None
    text: str
    meta: Dict[str, Any] = Field(default_factory=dict)
class ChunkConfig(BaseModel):
    size: int = Field(default=800, gt=0)
    overlap: int = Field(default=120, ge=0)
class IngestRequest(BaseModel):
    docs: List[Doc]
    chunk: ChunkConfig = Field(default_factory=ChunkConfig)
    normalize: bool = True
    use_spark: Optional[bool] = None
class Match(BaseModel):
    id: str
    score: float
    text: Optional[str] = None
    meta: Dict[str, Any] = Field(default_factory=dict)
class QueryRequest(BaseModel):
    q: str
    k: int = Field(default=5, ge=1, le=50)
    return_text: bool = True
class QueryResponse(BaseModel):
    matches: List[Match]
class ExplainMatch(Match):
    start: int
    end: int
    token_overlap: float
class ExplainRequest(QueryRequest):
    pass
class ExplainResponse(BaseModel):
    matches: List[ExplainMatch]
class AnswerRequest(BaseModel):
    q: str
    k: int = Field(default=5, ge=1, le=50)
    model: str = Field(default="mock")
    max_context_chars: int = Field(default=1600, ge=200, le=20000)
    return_contexts: bool = True
    rerank: bool = False
    rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class AnswerResponse(BaseModel):
    answer: str
    contexts: List[Match] = []
class ReindexParams(BaseModel):
    index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
    nlist: int = Field(default=64, ge=1, le=65536)
    M: int = Field(default=32, ge=4, le=128)
# --------- Embeddings ----------
class LazyEmbedder:
    """Loads the sentence-transformer on first use and caches the embedding dim."""
    def __init__(self, model_name: str = EMBED_MODEL_NAME):
        self.model_name = model_name
        self._model: Optional[SentenceTransformer] = None
        self._dim: Optional[int] = None
    def _ensure(self):
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
            # Probe once to discover the embedding dimensionality.
            self._dim = int(self._model.encode(["_probe_"], convert_to_numpy=True).shape[1])
    def dim(self) -> int:
        self._ensure()
        return int(self._dim)  # type: ignore[arg-type]
    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        self._ensure()
        vecs = self._model.encode(texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True)  # type: ignore[union-attr]
        if normalize:
            # Unit-normalize so inner product equals cosine similarity.
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
            vecs = vecs / norms
        return vecs.astype("float32")

_embedder = LazyEmbedder()
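# Quick sanity check (illustrative, not part of the service): all-MiniLM-L6-v2
# produces 384-dim embeddings, so one would expect roughly:
#   _embedder.encode(["hello world"]).shape        -> (1, 384)
#   np.linalg.norm(_embedder.encode(["hi"])[0])    -> ~1.0 (unit norm)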
# --------- Reranker ----------
class LazyReranker:
    """Loads a CrossEncoder lazily; swaps models when a different name is requested."""
    def __init__(self):
        self._model = None
        self._name = None
    def ensure(self, name: str):
        if not RERANK_AVAILABLE:
            return
        if self._model is None or self._name != name:
            self._model = CrossEncoder(name)
            self._name = name
    def score(self, q: str, texts: List[str]) -> List[float]:
        if not RERANK_AVAILABLE or self._model is None:
            return [0.0] * len(texts)
        return [float(s) for s in self._model.predict([(q, t) for t in texts])]

_reranker = LazyReranker()
# --------- Chunking ----------
def chunk_text_py(text: str, size: int, overlap: int) -> List[Tuple[str, Tuple[int, int]]]:
    """Whitespace-normalize text, then emit fixed-size character windows with overlap."""
    t = " ".join((text or "").split())
    n = len(t)
    out: List[Tuple[str, Tuple[int, int]]] = []
    s = 0
    if overlap >= size:
        # An overlap >= size would never advance the window; clamp it.
        overlap = max(size - 1, 0)
    while s < n:
        e = min(s + size, n)
        out.append((t[s:e], (s, e)))
        if e == n:
            break
        s = max(e - overlap, 0)
    return out
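# Worked example: each window starts `overlap` characters before the previous one ended.
#   chunk_text_py("abcdefghij", size=4, overlap=1)
#   -> [("abcd", (0, 4)), ("defg", (3, 7)), ("ghij", (6, 10))]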
def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int):
    if not SPARK_AVAILABLE:
        raise RuntimeError("Spark not available")
    spark = SparkSession.builder.appName("RAG-ETL").getOrCreate()
    import json as _j
    rows = [{"id": d.id or f"doc-{i}", "text": d.text, "meta_json": _j.dumps(d.meta)} for i, d in enumerate(docs)]
    df = (spark.createDataFrame(rows)
          .withColumn("text", F.regexp_replace(F.col("text"), r"\s+", " "))
          .withColumn("text", F.trim(F.col("text")))
          .filter(F.length("text") > 0))
    sz, ov = int(size), int(overlap)
    if ov >= sz:
        ov = max(sz - 1, 0)
    def _chunk(text: str, pid: str, meta_json: str) -> str:
        t = " ".join((text or "").split())
        n = len(t)
        s = 0
        base = _j.loads(meta_json) if meta_json else {}
        out = []
        while s < n:
            e = min(s + sz, n)
            cid = f"{pid}::offset:{s}-{e}"
            m = dict(base)
            m.update({"parent_id": pid, "start": s, "end": e})
            # Serialize meta to a string so each chunk fits the array<map<string,string>> schema below.
            out.append({"id": cid, "text": t[s:e], "meta": _j.dumps(m)})
            if e == n:
                break
            s = max(e - ov, 0)
        return _j.dumps(out)
    # Register as a real Spark UDF; calling a bare Python function on Columns would fail.
    chunk_udf = F.udf(_chunk, StringType())
    df = df.withColumn("chunks_json", chunk_udf(F.col("text"), F.col("id"), F.col("meta_json")))
    exploded = df.select(F.explode(F.from_json("chunks_json", "array<map<string,string>>")).alias("c"))
    out = exploded.select(
        F.col("c")["id"].alias("id"),
        F.col("c")["text"].alias("text"),
        F.col("c")["meta"].alias("meta_json"),
    ).collect()
    return [{"id": r["id"], "text": r["text"], "meta": _j.loads(r["meta_json"]) if r["meta_json"] else {}} for r in out]
# --------- Vector index ----------
class VectorIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.dim = dim
        self.type = index_type
        self.metric = "ip"
        self.nlist = nlist
        self.M = M
        if index_type == "flat":
            self.index = faiss.IndexFlatIP(dim)
        elif index_type == "ivf":
            quant = faiss.IndexFlatIP(dim)
            self.index = faiss.IndexIVFFlat(quant, dim, max(1, nlist), faiss.METRIC_INNER_PRODUCT)
        elif index_type == "hnsw":
            # HNSW here uses L2; search() converts distances back to cosine-like scores.
            self.index = faiss.IndexHNSWFlat(dim, max(4, M))
            self.metric = "l2"
        else:
            raise ValueError("bad index_type")
    def train(self, vecs: np.ndarray):
        if hasattr(self.index, "is_trained") and not self.index.is_trained:
            self.index.train(vecs)
    def add(self, vecs: np.ndarray):
        self.train(vecs)
        self.index.add(vecs)
    def search(self, qvec: np.ndarray, k: int):
        D, I = self.index.search(qvec, k)
        # For unit vectors, squared L2 distance d satisfies cos = 1 - d/2.
        scores = (1.0 - 0.5 * D[0]).tolist() if self.metric == "l2" else D[0].tolist()
        return I[0].tolist(), scores
    def save(self, fp: Path):
        faiss.write_index(self.index, str(fp))
    @staticmethod
    def load(fp: Path) -> "VectorIndex":
        idx = faiss.read_index(str(fp))
        vi = VectorIndex(idx.d, "flat")
        vi.index = idx
        vi.metric = "ip" if isinstance(idx, faiss.IndexFlatIP) or "IVF" in str(type(idx)) else "l2"
        return vi
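# Usage sketch (illustrative only): with normalized vectors, inner-product scores
# from a flat index are cosine similarities in [-1, 1]:
#   vi = VectorIndex(dim=384, index_type="flat")
#   vi.add(_embedder.encode(["alpha", "beta"]))
#   ids, scores = vi.search(_embedder.encode(["alpha"]), k=2)   # ids[0] == 0, scores[0] ~= 1.0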
# --------- Store ----------
class MemoryIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.ids: List[str] = []
        self.texts: List[str] = []
        self.metas: List[Dict[str, Any]] = []
        self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)
    def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]):
        if vecs.shape[0] != len(rows):
            raise ValueError("Vector count != row count")
        self.vindex.add(vecs)
        for r in rows:
            self.ids.append(r["id"])
            self.texts.append(r["text"])
            self.metas.append(r["meta"])
    def size(self) -> int:
        return self.vindex.index.ntotal
    def search(self, qvec: np.ndarray, k: int):
        return self.vindex.search(qvec, k)
    def save(self):
        self.vindex.save(INDEX_FP)
        with META_FP.open("w", encoding="utf-8") as f:
            for i in range(len(self.ids)):
                f.write(json.dumps({"id": self.ids[i], "text": self.texts[i], "meta": self.metas[i]}) + "\n")
        try:
            df = pd.DataFrame({"id": self.ids, "text": self.texts, "meta_json": [json.dumps(m) for m in self.metas]})
            df.to_parquet(PARQ_FP, index=False)
        except Exception:
            pass  # parquet is best-effort; pyarrow may be missing
        CFG_FP.write_text(
            json.dumps({"model": EMBED_MODEL_NAME, "dim": _embedder.dim(), "index_type": self.vindex.type,
                        "nlist": self.vindex.nlist, "M": self.vindex.M}),
            encoding="utf-8",
        )
    @staticmethod
    def load_if_exists() -> Optional["MemoryIndex"]:
        if not INDEX_FP.exists() or not META_FP.exists():
            return None
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try:
                cfg.update(json.loads(CFG_FP.read_text()))
            except Exception:
                pass
        vi = VectorIndex.load(INDEX_FP)
        store = MemoryIndex(dim=vi.dim, index_type=cfg.get("index_type", "flat"), nlist=cfg.get("nlist", 64), M=cfg.get("M", 32))
        store.vindex = vi
        ids, texts, metas = [], [], []
        with META_FP.open("r", encoding="utf-8") as f:
            for line in f:
                rec = json.loads(line)
                ids.append(rec["id"])
                texts.append(rec["text"])
                metas.append(rec.get("meta", {}))
        store.ids, store.texts, store.metas = ids, texts, metas
        return store
    @staticmethod
    def reset_files():
        for p in [INDEX_FP, META_FP, PARQ_FP, CFG_FP]:
            try:
                if p.exists():
                    p.unlink()
            except Exception:
                pass

_mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()

def require_store() -> MemoryIndex:
    if _mem_store is None or _mem_store.size() == 0:
        raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
    return _mem_store
# --------- Helpers ----------
def _token_overlap(q: str, txt: str) -> float:
    """Fraction of query tokens that also appear in the text (crude lexical signal)."""
    qt = {t for t in q.lower().split() if t}
    tt = {t for t in (txt or "").lower().split() if t}
    if not qt:
        return 0.0
    return float(len(qt & tt)) / float(len(qt))
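# e.g. _token_overlap("blue cheese", "Cheese is aged") == 0.5
#      (only "cheese" of the two query tokens appears in the text)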
def _topk(q: str, k: int) -> List[Match]:
    store = require_store()
    qvec = _embedder.encode([q], normalize=True)
    idxs, scores = store.search(qvec, k)
    out = []
    for i, s in zip(idxs, scores):
        if i == -1:  # FAISS pads missing results with -1
            continue
        out.append(Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i]))
    return out

def _compose_contexts(matches: List[Match], max_chars: int) -> str:
    """Concatenate match texts, truncating so the total stays within max_chars."""
    buf = []
    total = 0
    for m in matches:
        t = m.text or ""
        cut = min(len(t), max_chars - total)
        if cut <= 0:
            break
        buf.append(t[:cut])
        total += cut
        if total >= max_chars:
            break
    return "\n\n".join(buf).strip()

def _answer_with_mock(q: str, contexts: str) -> str:
    """Extractive stand-in for an LLM: surface context lines that share query tokens."""
    if not contexts:
        return "No indexed context available to answer the question."
    lines = [ln.strip() for ln in contexts.split("\n") if ln.strip()]
    hits = [ln for ln in lines if any(t in ln.lower() for t in q.lower().split())]
    if not hits:
        hits = lines[:2]
    return "Based on retrieved context, here’s a concise answer:\n- " + "\n- ".join(hits[:4])
def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
    if not enabled:
        return matches
    try:
        _reranker.ensure(model_name)
        scores = _reranker.score(q, [m.text or "" for m in matches])
        order = sorted(range(len(matches)), key=lambda i: scores[i], reverse=True)
        return [matches[i] for i in order]
    except Exception:
        return matches  # reranking is best-effort; fall back to vector order

def _write_parquet_if_missing():
    """Mirror meta.jsonl into parquet once, if the parquet file is absent."""
    if not PARQ_FP.exists() and META_FP.exists():
        try:
            rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
            if rows:
                pd.DataFrame({
                    "id": [r["id"] for r in rows],
                    "text": [r["text"] for r in rows],
                    "meta_json": [json.dumps(r.get("meta", {})) for r in rows],
                }).to_parquet(PARQ_FP, index=False)
        except Exception:
            pass
# --------- Auth/limits/metrics ----------
API_KEY = os.getenv("API_KEY", "")
_rate = {"capacity": 60, "refill_per_sec": 1.0}  # token bucket: burst of 60, refills 1/sec
_buckets: Dict[str, Dict[str, float]] = {}
_metrics = {"requests": 0, "by_endpoint": {}, "started": time.time()}

def _allow(ip: str) -> bool:
    now = time.time()
    b = _buckets.get(ip, {"tokens": _rate["capacity"], "ts": now})
    tokens = min(b["tokens"] + (now - b["ts"]) * _rate["refill_per_sec"], _rate["capacity"])
    if tokens < 1.0:
        _buckets[ip] = {"tokens": tokens, "ts": now}
        return False
    _buckets[ip] = {"tokens": tokens - 1.0, "ts": now}
    return True
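# Bucket math example: a client that bursts 60 requests drains to 0 tokens;
# after 30 idle seconds it holds min(0 + 30 * 1.0, 60) = 30 tokens again.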
async def guard(request: Request):
    if API_KEY and request.headers.get("x-api-key", "") != API_KEY:
        raise HTTPException(status_code=401, detail="invalid api key")
    ip = request.client.host if request.client else "local"
    if not _allow(ip):
        raise HTTPException(status_code=429, detail="rate limited")
app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")

@app.middleware("http")
async def req_meta(request: Request, call_next):
    rid = str(uuid.uuid4())
    _metrics["requests"] += 1
    ep = f"{request.method} {request.url.path}"
    _metrics["by_endpoint"][ep] = _metrics["by_endpoint"].get(ep, 0) + 1
    resp = await call_next(request)
    try:
        resp.headers["x-request-id"] = rid
    except Exception:
        pass
    return resp
# --------- API ----------
# Route paths below follow the handler names; /answer is the path the UI fetch calls.
@app.get("/", response_class=HTMLResponse)
def root():
    return """<!doctype html><html><head><meta charset="utf-8"><title>RAG-as-a-Service</title></head>
<body style="font-family:system-ui;margin:2rem;max-width:900px">
<h2>RAG-as-a-Service</h2>
<input id="q" style="width:70%" placeholder="Ask a question"><button onclick="ask()">Ask</button>
<pre id="out" style="background:#111;color:#eee;padding:1rem;border-radius:8px;white-space:pre-wrap"></pre>
<script>
async function ask(){
  const q=document.getElementById('q').value;
  const res=await fetch('/answer',{method:'POST',headers:{'content-type':'application/json'},body:JSON.stringify({q, k:5, return_contexts:true})});
  document.getElementById('out').textContent=JSON.stringify(await res.json(),null,2);
}
</script></body></html>"""
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    size = _mem_store.size() if _mem_store is not None else 0
    persisted = INDEX_FP.exists() and META_FP.exists()
    idx_type = "flat"
    if CFG_FP.exists():
        try:
            idx_type = json.loads(CFG_FP.read_text()).get("index_type", "flat")
        except Exception:
            pass
    return HealthResponse(status="ok", version=APP_VERSION, index_size=size, model=EMBED_MODEL_NAME,
                          spark=SPARK_AVAILABLE, persisted=persisted, rerank=RERANK_AVAILABLE, index_type=idx_type)
@app.get("/metrics")
def metrics():
    up = time.time() - _metrics["started"]
    return {"requests": _metrics["requests"], "by_endpoint": _metrics["by_endpoint"], "uptime_sec": round(up, 2)}

# `guard` wiring via Depends is inferred; the dependency was otherwise unused.
@app.post("/echo", dependencies=[Depends(guard)])
def echo(payload: EchoRequest) -> Dict[str, str]:
    return {"echo": payload.message, "length": str(len(payload.message))}

@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(guard)])
def embed(payload: EmbedRequest) -> EmbedResponse:
    vecs = _embedder.encode(payload.texts, normalize=payload.normalize)
    preview = [[float(round(v, 5)) for v in row[:payload.preview_n]] for row in vecs] if payload.preview_n > 0 else []
    return EmbedResponse(dim=int(vecs.shape[1]), count=int(vecs.shape[0]), preview=preview)
@app.post("/ingest", dependencies=[Depends(guard)])
def ingest(req: IngestRequest) -> Dict[str, Any]:
    global _mem_store
    if _mem_store is None:
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try:
                cfg.update(json.loads(CFG_FP.read_text()))
            except Exception:
                pass
        _mem_store = MemoryIndex(dim=_embedder.dim(), index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"])
    use_spark = SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
    engine = "python"
    rows = []
    if use_spark:
        try:
            rows = spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap)
            engine = "spark"
        except Exception:
            rows = []  # fall back to the pure-Python chunker below
    if not rows:
        engine = "python"
        for d in req.docs:
            pid = d.id or "doc"
            for ctext, (s, e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
                meta = dict(d.meta)
                meta.update({"parent_id": pid, "start": s, "end": e})
                rows.append({"id": f"{pid}::offset:{s}-{e}", "text": ctext, "meta": meta})
    if not rows:
        raise HTTPException(status_code=400, detail="No non-empty chunks produced")
    vecs = _embedder.encode([r["text"] for r in rows], normalize=req.normalize)
    _mem_store.add(vecs, rows)
    _mem_store.save()
    _write_parquet_if_missing()
    return {"docs": len(req.docs), "chunks": len(rows), "index_size": _mem_store.size(), "engine": engine, "persisted": True}
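# Example request body for POST /ingest (illustrative values):
#   {"docs": [{"id": "faq-1", "text": "Returns are accepted within 30 days."}],
#    "chunk": {"size": 800, "overlap": 120}}
# The response reports chunk count, index size, and which chunking engine ran.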
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(guard)])
def query(req: QueryRequest) -> QueryResponse:
    matches = _topk(req.q, req.k)
    if not req.return_text:
        matches = [Match(id=m.id, score=m.score, text=None, meta=m.meta) for m in matches]
    return QueryResponse(matches=matches)

@app.post("/explain", response_model=ExplainResponse, dependencies=[Depends(guard)])
def explain(req: ExplainRequest) -> ExplainResponse:
    matches = _topk(req.q, req.k)
    out = []
    for m in matches:
        meta = m.meta
        start = int(meta.get("start", 0))
        end = int(meta.get("end", 0))
        out.append(ExplainMatch(id=m.id, score=m.score, text=m.text if req.return_text else None, meta=meta,
                                start=start, end=end,
                                token_overlap=float(round(_token_overlap(req.q, m.text or ""), 4))))
    return ExplainResponse(matches=out)
@app.post("/answer", response_model=AnswerResponse, dependencies=[Depends(guard)])
def answer(req: AnswerRequest) -> AnswerResponse:
    matches = _topk(req.q, req.k)
    matches = _maybe_rerank(req.q, matches, enabled=req.rerank, model_name=req.rerank_model)
    ctx = _compose_contexts(matches, req.max_context_chars)
    # Only the mock model is implemented; any other `model` value falls back to it too.
    out = _answer_with_mock(req.q, ctx)
    return AnswerResponse(answer=out, contexts=matches if req.return_contexts else [])
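# Smoke test from a shell (assuming the service listens on localhost:8000):
#   curl -s -X POST localhost:8000/answer -H 'content-type: application/json' \
#        -d '{"q": "What is the return policy?", "k": 5}'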
@app.post("/reindex", dependencies=[Depends(guard)])
def reindex(params: ReindexParams) -> Dict[str, Any]:
    global _mem_store
    if not META_FP.exists():
        raise HTTPException(status_code=400, detail="no metadata on disk")
    rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
    if not rows:
        raise HTTPException(status_code=400, detail="empty metadata")
    texts = [r["text"] for r in rows]
    vecs = _embedder.encode(texts, normalize=True)
    # Cap nlist to dataset size for IVF (training needs at least one vector per cell).
    idx_type = params.index_type
    eff_nlist = params.nlist
    if idx_type == "ivf":
        eff_nlist = max(1, min(eff_nlist, len(rows)))
    try:
        _mem_store = MemoryIndex(dim=_embedder.dim(), index_type=idx_type, nlist=eff_nlist, M=params.M)
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {"reindexed": True, "index_type": idx_type, "index_size": _mem_store.size(),
                "nlist": eff_nlist, "M": params.M}
    except Exception as e:
        # Fall back to flat if IVF/HNSW training/add fails for any reason.
        _mem_store = MemoryIndex(dim=_embedder.dim(), index_type="flat")
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {"reindexed": True, "index_type": "flat", "index_size": _mem_store.size(),
                "note": f"fallback due to: {str(e)[:120]}"}
@app.post("/reset", dependencies=[Depends(guard)])
def reset() -> Dict[str, Any]:
    global _mem_store
    _mem_store = None
    MemoryIndex.reset_files()
    return {"reset": True}
def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str] = None,
                 meta_fields: Optional[List[str]] = None, chunk_size: int = 800, overlap: int = 120):
    """Convenience loader: pull a Hugging Face dataset and feed it through ingest()."""
    try:
        from datasets import load_dataset
        ds = load_dataset(repo, split=split)
        docs = []
        for rec in ds:
            rid = str(rec[id_field]) if id_field and id_field in rec else None
            meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
            docs.append(Doc(id=rid, text=str(rec[text_field]), meta=meta))
        return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}")
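# Local run sketch (assumes uvicorn is installed; the module name may differ in your setup):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)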