File size: 7,499 Bytes
e61e934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
vector_store.py
-----------------------------------------------------
Maintains FAISS runtime index + metadata cache.

Features
--------
- Ensure local FAISS runtime index exists (download from HF if missing)
- FAISS semantic search and BM25 text access
- Automatic TTL reload
- Full cache clearing for Hugging Face Space
- Explicit "♻️ FAISS memory cache reset" logging on rebuild
"""

import os
import json
import time
import shutil
from typing import List, Dict, Any, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download


# ------------------------------------------------------------------
# 🔧 Paths & constants
# ------------------------------------------------------------------
PERSISTENT_DIR = "/home/user/app/persistent"
RUNTIME_DIR = "/home/user/app/runtime_faiss"
INDEX_NAME = "faiss.index"
META_NAME = "faiss.index.meta.json"
GLOSSARY_META = "glossary.json"
HF_INDEX_REPO = "essprasad/CT-Chat-Index"

EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBED_MODEL = None  # lazy loaded

# in-memory cache
_runtime_index: Optional[faiss.Index] = None
_runtime_meta: Optional[List[Dict[str, Any]]] = None
_meta_loaded_time = 0.0
_META_TTL_SECONDS = 300.0


# ------------------------------------------------------------------
# 🔹 Helpers
# ------------------------------------------------------------------
def _ensure_dirs():
    """Create the persistent and runtime cache directories if missing."""
    for directory in (PERSISTENT_DIR, RUNTIME_DIR):
        os.makedirs(directory, exist_ok=True)


def _ensure_model():
    """Return the shared SentenceTransformer, loading it lazily on first use.

    The loaded model is cached in the module-level ``EMBED_MODEL`` so
    subsequent calls are free.
    """
    global EMBED_MODEL
    if EMBED_MODEL is not None:
        return EMBED_MODEL
    print("📥 Loading embedding model for FAISS retrieval…")
    EMBED_MODEL = SentenceTransformer(EMBED_MODEL_NAME)
    print("✅ Embedding model loaded.")
    return EMBED_MODEL


# ------------------------------------------------------------------
# 🔹 Cache control
# ------------------------------------------------------------------
def clear_local_faiss():
    """Delete all local FAISS + glossary caches (safe in HF Space).

    Removes the persistent index, metadata and glossary files plus the
    entire runtime directory; failures on individual paths are logged
    and do not abort the remaining deletions.
    """
    targets = (
        os.path.join(PERSISTENT_DIR, INDEX_NAME),
        os.path.join(PERSISTENT_DIR, META_NAME),
        os.path.join(PERSISTENT_DIR, GLOSSARY_META),
        RUNTIME_DIR,
    )
    for p in targets:
        try:
            # Directories (the runtime dir) need rmtree; files need remove.
            if os.path.isdir(p):
                shutil.rmtree(p, ignore_errors=True)
            elif os.path.exists(p):
                os.remove(p)
            print(f"🗑️ Cleared: {p}")
        except Exception as e:
            print(f"⚠️ Failed to clear {p}: {e}")
    print("♻️ FAISS memory cache reset (runtime + persistent cleared)")


# ------------------------------------------------------------------
# 🔹 Loaders
# ------------------------------------------------------------------
def _load_local_index() -> bool:
    """Load FAISS index + metadata from persistent into runtime.

    Copies both artifacts into the runtime directory, reads them into the
    module-level cache and stamps the load time. Returns True on success,
    False when the persistent files are absent or unreadable.
    """
    global _runtime_index, _runtime_meta, _meta_loaded_time
    _ensure_dirs()
    src_idx = os.path.join(PERSISTENT_DIR, INDEX_NAME)
    src_meta = os.path.join(PERSISTENT_DIR, META_NAME)
    try:
        if not os.path.exists(src_idx) or not os.path.exists(src_meta):
            return False
        os.makedirs(RUNTIME_DIR, exist_ok=True)
        dst_idx = os.path.join(RUNTIME_DIR, INDEX_NAME)
        dst_meta = os.path.join(RUNTIME_DIR, META_NAME)
        shutil.copy2(src_idx, dst_idx)
        shutil.copy2(src_meta, dst_meta)
        _runtime_index = faiss.read_index(dst_idx)
        with open(dst_meta, "r", encoding="utf-8") as fh:
            _runtime_meta = json.load(fh)
        _meta_loaded_time = time.time()
        print(f"✅ Loaded FAISS index ({len(_runtime_meta)} vectors).")
        return True
    except Exception as e:
        # Any failure leaves the cache empty so callers fall back cleanly.
        print(f"⚠️ Could not load local FAISS index: {e}")
        _runtime_index = None
        _runtime_meta = None
        return False


def _download_index_from_hub() -> bool:
    """Download FAISS artifacts from Hugging Face dataset repo.

    Fetches the index and metadata files, copies them into the persistent
    directory, then delegates to ``_load_local_index``. Returns False on
    any download or load failure.
    """
    _ensure_dirs()
    try:
        print("☁️ Downloading FAISS artifacts from HF dataset…")
        fetched = {}
        for filename in (INDEX_NAME, META_NAME):
            fetched[filename] = hf_hub_download(
                repo_id=HF_INDEX_REPO,
                filename=f"persistent/{filename}",
                repo_type="dataset",
            )
        for filename, cached_path in fetched.items():
            shutil.copy2(cached_path, os.path.join(PERSISTENT_DIR, filename))
        print("✅ FAISS artifacts downloaded and stored persistently.")
        return _load_local_index()
    except Exception as e:
        print(f"⚠️ HF download failed: {e}")
        return False


def _ensure_faiss_index(force_refresh: bool = False) -> bool:
    """
    Ensure runtime FAISS is available.
    If force_refresh=True, clears runtime and reloads fresh.
    """
    global _runtime_index, _runtime_meta, _meta_loaded_time
    _ensure_dirs()

    if force_refresh:
        try:
            shutil.rmtree(RUNTIME_DIR, ignore_errors=True)
            _runtime_index = None
            _runtime_meta = None
            print("♻️ Forced FAISS runtime reload requested.")
        except Exception as e:
            print(f"⚠️ Force refresh failed: {e}")

    # Serve from the in-memory cache while it is still within its TTL.
    cache_age = time.time() - _meta_loaded_time
    if _runtime_index is not None and cache_age < _META_TTL_SECONDS:
        return True

    # Try local persistent storage first, then the HF dataset repo.
    if _load_local_index() or _download_index_from_hub():
        return True

    print("⚠️ No FAISS index found locally or remotely.")
    return False


# ------------------------------------------------------------------
# 🔹 Accessors
# ------------------------------------------------------------------
def load_all_text_chunks() -> List[Dict[str, Any]]:
    """Return metadata list for BM25 fallback or analysis.

    Ensures the index is loaded on first use; when the cached metadata is
    older than the TTL, best-effort re-reads it from the runtime copy
    (silently keeping the stale cache if the file is unavailable).
    """
    global _runtime_meta, _meta_loaded_time
    if _runtime_meta is None and not _ensure_faiss_index():
        return []
    stale = (time.time() - _meta_loaded_time) > _META_TTL_SECONDS
    if stale:
        try:
            with open(os.path.join(RUNTIME_DIR, META_NAME), "r", encoding="utf-8") as fh:
                _runtime_meta = json.load(fh)
                _meta_loaded_time = time.time()
        except Exception:
            # Best-effort refresh: keep whatever metadata we already have.
            pass
    return _runtime_meta or []


# ------------------------------------------------------------------
# 🔹 Core Search
# ------------------------------------------------------------------
def search_index(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """Perform semantic FAISS search and return metadata hits.

    Encodes the query, L2-normalizes it, searches the runtime index and
    returns up to ``top_k`` metadata dicts augmented with ``score``,
    ``file`` and ``text`` keys. Returns [] if the index is unavailable
    or the search fails.
    """
    if not _ensure_faiss_index():
        return []

    try:
        embedder = _ensure_model()
        vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
        faiss.normalize_L2(vec)
        scores, ids = _runtime_index.search(vec, top_k)
        hits: List[Dict[str, Any]] = []
        for score, pos in zip(scores[0], ids[0]):
            # FAISS pads missing results with -1; also guard stale ids.
            if not (0 <= pos < len(_runtime_meta)):
                continue
            hit = dict(_runtime_meta[pos])
            hit["score"] = float(score)
            hit["file"] = hit.get("file") or hit.get("source") or "unknown"
            hit["text"] = hit.get("text") or hit.get("definition", "")
            hits.append(hit)
        return hits
    except Exception as e:
        print(f"⚠️ FAISS search failed: {e}")
        return []