# NOTE(review): removed non-Python residue from a web file-viewer scrape
# (status lines, file size, commit hash, line-number gutter) that preceded
# the module docstring and made the file unparseable.
"""
vector_store.py
-----------------------------------------------------
Maintains FAISS runtime index + metadata cache.
Features
--------
- Ensure local FAISS runtime index exists (download from HF if missing)
- FAISS semantic search and BM25 text access
- Automatic TTL reload
- Full cache clearing for Hugging Face Space
- Explicit "♻️ FAISS memory cache reset" logging on rebuild
"""
import os
import json
import time
import shutil
from typing import List, Dict, Any, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
# ------------------------------------------------------------------
# 🔧 Paths & constants
# ------------------------------------------------------------------
# Persistent dir survives Space restarts and holds the canonical artifact copies;
# runtime dir holds the working copies that FAISS actually opens.
PERSISTENT_DIR = "/home/user/app/persistent"
RUNTIME_DIR = "/home/user/app/runtime_faiss"
INDEX_NAME = "faiss.index"
META_NAME = "faiss.index.meta.json"
GLOSSARY_META = "glossary.json"
# HF dataset repo the prebuilt index is downloaded from when no local copy exists.
HF_INDEX_REPO = "essprasad/CT-Chat-Index"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBED_MODEL = None # lazy loaded
# in-memory cache
_runtime_index: Optional[faiss.Index] = None
_runtime_meta: Optional[List[Dict[str, Any]]] = None
_meta_loaded_time = 0.0  # epoch seconds of the last successful metadata load
_META_TTL_SECONDS = 300.0  # metadata considered stale after this many seconds
# ------------------------------------------------------------------
# 🔹 Helpers
# ------------------------------------------------------------------
def _ensure_dirs():
    """Create the persistent and runtime directories if they are missing."""
    for directory in (PERSISTENT_DIR, RUNTIME_DIR):
        os.makedirs(directory, exist_ok=True)
def _ensure_model():
    """Lazily load and memoize the sentence-transformer embedding model.

    Returns the module-level singleton, loading it on first call only.
    """
    global EMBED_MODEL
    if EMBED_MODEL is not None:
        return EMBED_MODEL
    print("📥 Loading embedding model for FAISS retrieval…")
    EMBED_MODEL = SentenceTransformer(EMBED_MODEL_NAME)
    print("✅ Embedding model loaded.")
    return EMBED_MODEL
# ------------------------------------------------------------------
# 🔹 Cache control
# ------------------------------------------------------------------
def clear_local_faiss():
    """Delete all local FAISS + glossary caches (safe in HF Space)."""
    cache_targets = (
        os.path.join(PERSISTENT_DIR, INDEX_NAME),
        os.path.join(PERSISTENT_DIR, META_NAME),
        os.path.join(PERSISTENT_DIR, GLOSSARY_META),
        RUNTIME_DIR,
    )
    for target in cache_targets:
        try:
            if os.path.isdir(target):
                shutil.rmtree(target, ignore_errors=True)
            elif os.path.exists(target):
                os.remove(target)
            # Printed even for already-absent paths, matching best-effort intent.
            print(f"🗑️ Cleared: {target}")
        except Exception as e:
            print(f"⚠️ Failed to clear {target}: {e}")
    print("♻️ FAISS memory cache reset (runtime + persistent cleared)")
# ------------------------------------------------------------------
# 🔹 Loaders
# ------------------------------------------------------------------
def _load_local_index() -> bool:
    """Load FAISS index + metadata from persistent into runtime.

    Copies the persistent artifacts into RUNTIME_DIR, opens the index, and
    refreshes the in-memory cache. Returns True on success; on any failure
    the cache is cleared and False is returned.
    """
    global _runtime_index, _runtime_meta, _meta_loaded_time
    _ensure_dirs()
    src_index = os.path.join(PERSISTENT_DIR, INDEX_NAME)
    src_meta = os.path.join(PERSISTENT_DIR, META_NAME)
    try:
        if not os.path.exists(src_index) or not os.path.exists(src_meta):
            return False
        os.makedirs(RUNTIME_DIR, exist_ok=True)
        dst_index = os.path.join(RUNTIME_DIR, INDEX_NAME)
        dst_meta = os.path.join(RUNTIME_DIR, META_NAME)
        shutil.copy2(src_index, dst_index)
        shutil.copy2(src_meta, dst_meta)
        _runtime_index = faiss.read_index(dst_index)
        with open(dst_meta, "r", encoding="utf-8") as fh:
            _runtime_meta = json.load(fh)
        _meta_loaded_time = time.time()
        print(f"✅ Loaded FAISS index ({len(_runtime_meta)} vectors).")
        return True
    except Exception as e:
        print(f"⚠️ Could not load local FAISS index: {e}")
        _runtime_index = None
        _runtime_meta = None
        return False
def _download_index_from_hub() -> bool:
    """Download FAISS artifacts from Hugging Face dataset repo.

    Fetches both artifacts into the persistent dir, then delegates to
    `_load_local_index()`. Returns False on any download failure.
    """
    _ensure_dirs()
    try:
        print("☁️ Downloading FAISS artifacts from HF dataset…")
        for artifact in (INDEX_NAME, META_NAME):
            cached_path = hf_hub_download(
                repo_id=HF_INDEX_REPO,
                filename=f"persistent/{artifact}",
                repo_type="dataset",
            )
            shutil.copy2(cached_path, os.path.join(PERSISTENT_DIR, artifact))
        print("✅ FAISS artifacts downloaded and stored persistently.")
        return _load_local_index()
    except Exception as e:
        print(f"⚠️ HF download failed: {e}")
        return False
def _ensure_faiss_index(force_refresh: bool = False) -> bool:
    """
    Ensure runtime FAISS is available.
    If force_refresh=True, clears runtime and reloads fresh.
    """
    global _runtime_index, _runtime_meta, _meta_loaded_time
    _ensure_dirs()
    if force_refresh:
        try:
            shutil.rmtree(RUNTIME_DIR, ignore_errors=True)
            _runtime_index = None
            _runtime_meta = None
            print("♻️ Forced FAISS runtime reload requested.")
        except Exception as e:
            print(f"⚠️ Force refresh failed: {e}")
    # Fast path: in-memory index still within its TTL window.
    cache_is_fresh = (
        _runtime_index is not None
        and (time.time() - _meta_loaded_time) < _META_TTL_SECONDS
    )
    if cache_is_fresh:
        return True
    # Fall back to the persistent copy, then to the HF dataset repo.
    if _load_local_index() or _download_index_from_hub():
        return True
    print("⚠️ No FAISS index found locally or remotely.")
    return False
# ------------------------------------------------------------------
# 🔹 Accessors
# ------------------------------------------------------------------
def load_all_text_chunks() -> List[Dict[str, Any]]:
    """Return metadata list for BM25 fallback or analysis."""
    global _runtime_meta, _meta_loaded_time
    if _runtime_meta is None and not _ensure_faiss_index():
        return []
    stale = (time.time() - _meta_loaded_time) > _META_TTL_SECONDS
    if stale:
        # Best-effort refresh from the runtime copy; keep the cached
        # metadata if the file is missing or unreadable.
        try:
            meta_file = os.path.join(RUNTIME_DIR, META_NAME)
            with open(meta_file, "r", encoding="utf-8") as fh:
                _runtime_meta = json.load(fh)
            _meta_loaded_time = time.time()
        except Exception:
            pass
    return _runtime_meta or []
# ------------------------------------------------------------------
# 🔹 Core Search
# ------------------------------------------------------------------
def search_index(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """Perform semantic FAISS search and return metadata hits.

    Each hit is a copy of the matching metadata entry augmented with
    `score`, plus normalized `file` and `text` fields. Returns [] when
    no index is available or the search fails.
    """
    if not _ensure_faiss_index():
        return []
    try:
        encoder = _ensure_model()
        query_vec = encoder.encode([query], convert_to_numpy=True).astype("float32")
        # NOTE(review): L2-normalizing the query assumes a cosine/IP-style
        # index — confirm against how the index was built.
        faiss.normalize_L2(query_vec)
        scores, indices = _runtime_index.search(query_vec, top_k)
        hits: List[Dict[str, Any]] = []
        for score, pos in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1; also guard stale metadata.
            if not (0 <= pos < len(_runtime_meta)):
                continue
            entry = dict(_runtime_meta[pos])
            entry["score"] = float(score)
            entry["file"] = entry.get("file") or entry.get("source") or "unknown"
            entry["text"] = entry.get("text") or entry.get("definition", "")
            hits.append(entry)
        return hits
    except Exception as e:
        print(f"⚠️ FAISS search failed: {e}")
        return []