ClinicalTrialBasics / core /vector_sync.py
essprasad's picture
Upload 10 files
e61e934 verified
raw
history blame
7.89 kB
"""
vector_sync.py

Responsibilities:
- rebuild_faiss_from_glossary(glossary_path) -> builds a new faiss.Index plus a metadata list aligned with its vectors
- _upload_to_dataset(index_path, meta_path, repo_id) -> uploads both files to a Hugging Face dataset repo via huggingface_hub
- safe helpers (e.g. _normalize_meta_row) for creating normalized metadata entries
"""
import os
import re
import json
import shutil
from typing import Tuple, List, Dict, Any
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import upload_file
# default embedder (same model used elsewhere)
# This name is what gets passed to SentenceTransformer(...) in this module.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# directories
# NOTE(review): absolute paths are hard-coded; presumably they match the
# deployment container layout — confirm before running elsewhere.
PERSISTENT_DIR = "/home/user/app/persistent"
TMP_DIR = "/home/user/app/tmp"
# Import-time side effect: both directories are created as soon as this module
# is imported. exist_ok=True makes the calls no-ops when they already exist.
os.makedirs(PERSISTENT_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)
# Loaded models keyed by model name, so repeated calls reuse one instance
# instead of re-reading the weights every time.
_MODEL_CACHE: Dict[str, Any] = {}


def _ensure_model():
    """Return the shared sentence-transformer model, loading it on first use.

    The model named by ``EMBED_MODEL_NAME`` is instantiated once and then
    served from a module-level cache. The cache is keyed by model name, so
    rebinding ``EMBED_MODEL_NAME`` at runtime yields a model for the new name
    rather than a stale instance.

    Returns:
        The cached ``SentenceTransformer`` instance for ``EMBED_MODEL_NAME``.
    """
    model = _MODEL_CACHE.get(EMBED_MODEL_NAME)
    if model is None:
        model = SentenceTransformer(EMBED_MODEL_NAME)
        _MODEL_CACHE[EMBED_MODEL_NAME] = model
    return model
def _normalize_meta_row(row: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure consistent meta record fields."""
out = {
"term": row.get("term") or row.get("Term") or row.get("name") or "",
"text": row.get("text") or row.get("definition") or row.get("content") or "",
# keep both 'file' (local/basename) and full 'sources' list
"file": row.get("file") or row.get("source") or "",
"type": row.get("type") or "",
"sources": row.get("sources") if isinstance(row.get("sources"), list) else [row.get("source")] if row.get("source") else []
}
return out
# ==========================================================
# 🧠 Main Function: Rebuild FAISS from glossary.json
# ==========================================================
def rebuild_faiss_from_glossary(
    glossary_path: str,
    *,
    max_definition_chars: int = 3000,
) -> "Tuple[faiss.Index, List[Dict[str, Any]]]":
    """
    Build FAISS index + metadata from glossary JSON file.

    The glossary may be a JSON list of entries or a JSON object whose values
    are the entries. Each entry must be a dict; term, definition and source
    are read from several alternative keys (``term``/``Term``/``name``,
    ``definition``/``text``/``content``, ``sources``/``source``/``file``).
    HTML tags are stripped from definitions and whitespace is collapsed.

    Entries are skipped (and reported on stdout) when they are not dicts,
    lack a term or a non-empty cleaned definition, raise while being
    processed, or have a cleaned definition longer than
    ``max_definition_chars``.

    The embedding model is loaded only after at least one valid entry has
    been found, so an unusable glossary fails fast without loading it.

    Args:
        glossary_path: Path to the glossary JSON file.
        max_definition_chars: Cleaned definitions longer than this are
            skipped (likely raw HTML or large web content).

    Returns:
        ``(index, metas)``: an inner-product FAISS index over L2-normalized
        embeddings, and a list of metadata dicts aligned with the vectors.
        Each dict has the keys ``term``, ``definition``, ``sources``,
        ``source``, ``type`` and ``file`` (basename of ``glossary_path``).

    Raises:
        FileNotFoundError: ``glossary_path`` does not exist.
        RuntimeError: The JSON cannot be loaded, or no valid entries remain.
        ValueError: The top-level JSON value is neither a list nor a dict.
    """
    print(f"🧩 Building FAISS from glossary: {glossary_path}")
    glossary_items = _load_glossary_items(glossary_path)

    entries: List[str] = []
    metas: List[Dict[str, Any]] = []
    bad_entries: List[Any] = []
    long_entries: List[Dict[str, Any]] = []
    glossary_file = os.path.basename(glossary_path)

    # --- Process glossary items
    for i, item in enumerate(glossary_items):
        # Deliberately broad: one malformed row must never abort the rebuild.
        try:
            if not isinstance(item, dict):
                bad_entries.append(item)
                continue

            term = str(item.get("term") or item.get("Term") or item.get("name") or "").strip()
            definition = str(
                item.get("definition") or item.get("text") or item.get("content") or ""
            ).strip()

            # Normalize sources (keep list)
            src_field = item.get("sources") or item.get("source") or item.get("file") or ""
            if isinstance(src_field, list):
                src_list = [str(s).strip() for s in src_field if s]
                src = ", ".join(src_list)
            else:
                src_list = [str(src_field).strip()] if src_field else []
                src = str(src_field).strip()

            declared_type = str(item.get("type") or "").strip().lower()
            entry_type = _infer_type_from_source(src, declared_type)

            # Clean up noisy HTML tags and whitespace
            definition_clean = re.sub(r"<[^>]*>", "", definition)
            definition_clean = re.sub(r"\s+", " ", definition_clean).strip()

            # Skip if missing essentials
            if not term or not definition_clean:
                bad_entries.append(item)
                continue

            # Skip extremely long definitions (likely raw HTML or large web content)
            if len(definition_clean) > max_definition_chars:
                long_entries.append({
                    "term": term,
                    "len": len(definition_clean),
                    "source": src,
                })
                continue

            entries.append(f"Definition of {term}: {definition_clean}")
            metas.append({
                "term": term,
                "definition": definition_clean,
                # preserve the original source list and file name
                "sources": src_list if src_list else [src] if src else [],
                "source": src,
                "type": entry_type,
                "file": glossary_file,
            })
        except Exception as exc:
            bad_entries.append({
                "index": i,
                "error": str(exc),
                "raw": str(item)[:300],
            })

    # --- Diagnostics
    pdf_count = sum(1 for m in metas if m["type"] == "pdf")
    excel_count = sum(1 for m in metas if m["type"] == "excel")
    web_count = sum(1 for m in metas if m["type"] == "web")
    other_count = len(metas) - (pdf_count + excel_count + web_count)
    print(f"🧠 Encoding {len(entries)} entries (PDF={pdf_count}, Excel={excel_count}, Web={web_count}, Other={other_count})…")

    if bad_entries:
        print(f"⚠️ {len(bad_entries)} malformed entries skipped.")
        for bad in bad_entries[:3]:
            print(" →", json.dumps(bad, ensure_ascii=False)[:300])
    if long_entries:
        print(f"⚠️ {len(long_entries)} very long entries (>{max_definition_chars} chars) skipped.")
        for skipped in long_entries[:3]:
            print(f" → Skipped {skipped['term']} ({skipped['len']} chars) from {skipped['source']}")

    if not entries:
        raise RuntimeError("❌ No valid glossary entries found after cleanup!")

    # --- Encoding
    # Loaded only now, so invalid input never pays the model-loading cost.
    model = _ensure_model()
    embeddings = model.encode(entries, show_progress_bar=True, convert_to_numpy=True).astype("float32")
    # With unit-length vectors, inner product equals cosine similarity.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    print(f"✅ Glossary vectors built ({len(entries)} total entries).")

    # metas is list of dicts aligned with vectors — return exactly as before
    return index, metas


def _load_glossary_items(glossary_path: str) -> List[Any]:
    """Read the glossary JSON and return its entries as a list.

    A top-level list is returned as-is; for a top-level object the values
    are returned (keys are discarded).

    Raises:
        FileNotFoundError: The file does not exist.
        RuntimeError: The file cannot be read or parsed as JSON; the
            underlying error is attached as ``__cause__``.
        ValueError: The top-level JSON value is neither a list nor a dict.
    """
    if not os.path.exists(glossary_path):
        raise FileNotFoundError(f"Glossary not found: {glossary_path}")

    # --- Load JSON safely
    with open(glossary_path, "r", encoding="utf-8") as f:
        try:
            glossary_data = json.load(f)
        except Exception as exc:
            raise RuntimeError(f"❌ Failed to load glossary JSON: {exc}") from exc

    # Normalize structure
    if isinstance(glossary_data, dict):
        return list(glossary_data.values())
    if isinstance(glossary_data, list):
        return glossary_data
    raise ValueError("Invalid glossary format — must be list or dict.")


def _infer_type_from_source(src: str, declared_type: str = "") -> str:
    """Classify an entry as ``pdf``, ``excel``, ``web`` or ``other``.

    Checks run in that order, on the lower-cased source string and the
    lower-cased declared type; the first match wins.
    """
    src_l = (src or "").lower()
    declared = (declared_type or "").lower()
    if src_l.endswith(".pdf") or "pdf" in declared:
        return "pdf"
    # Any "xls" substring already covers the .xls / .xlsx suffixes.
    if "xls" in src_l or "excel" in declared:
        return "excel"
    # Any "http" substring already covers sources that start with http.
    if "http" in src_l or declared == "web":
        return "web"
    return "other"
# ==========================================================
# ☁️ Upload Helper
# ==========================================================
def _upload_to_dataset(index_path: str, meta_path: str, repo_id: str) -> None:
    """
    Push the FAISS index file and its metadata JSON to a Hugging Face dataset repo.

    Each file is stored under ``persistent/<basename>`` in ``repo_id``, index
    first and metadata second. On any failure a warning is printed and the
    original exception is re-raised to the caller.
    """
    try:
        print(f"☁️ Uploading {index_path} and {meta_path} to {repo_id}...")
        for local_path in (index_path, meta_path):
            remote_name = os.path.basename(local_path)
            upload_file(
                path_or_fileobj=local_path,
                path_in_repo=f"persistent/{remote_name}",
                repo_id=repo_id,
                repo_type="dataset",
            )
        print("✅ Upload complete.")
    except Exception as err:
        print(f"⚠️ Upload failed: {err}")
        raise