LLM_Powered_Legal_RAG / app /managers /vector_manager.py
nanfangwuyu21's picture
Deutsch is now available and the language can be switched: different embedding models are used, resources and prompts are ingested for both languages, and vector retrieval is now supported for both languages.
db81bb8
import os
import faiss
from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
from app.models.model import Embedding_model_en, Embedding_model_de
from app import config
# Root directory (a relative path) under which every FAISS index is stored.
BASE_PATH = "vectorstore"

# Maps each supported language name to the sub-directory of BASE_PATH that
# holds the FAISS index for that language.
VECTORSTORE_TYPES = {
    "English": "English_index",
    "Deutsch": "Deutsch_index",
}
def get_embedding_model(language):
    """Pick the embedding model that matches ``language``.

    ``"English"`` selects ``Embedding_model_en``; every other value falls
    back to ``Embedding_model_de``.
    """
    if language == "English":
        return Embedding_model_en
    return Embedding_model_de
def create_new_vectorstore(embedding_model):
    """Build an empty FAISS store sized for ``embedding_model``.

    The vector dimension is discovered by embedding a probe string; a flat
    L2 index of that width is then wrapped together with an empty in-memory
    docstore and an empty index-to-docstore-id mapping.
    """
    probe_vector = embedding_model.embed_query("hello world")
    flat_index = faiss.IndexFlatL2(len(probe_vector))
    empty_docstore = InMemoryDocstore()
    return FAISS(
        embedding_function=embedding_model,
        index=flat_index,
        docstore=empty_docstore,
        index_to_docstore_id={},
    )
def load_vectorstore(store_type: str) -> FAISS:
    """Load the FAISS store for ``store_type``, creating it if it is missing.

    Args:
        store_type: A key of ``VECTORSTORE_TYPES``.

    Returns:
        The store read from ``BASE_PATH/<index dir>`` when an ``index.faiss``
        file exists there; otherwise a newly created empty store, which is
        saved to that location before being returned.

    Raises:
        AssertionError: If ``store_type`` is not a key of
            ``VECTORSTORE_TYPES``.
    """
    # Raised explicitly rather than through ``assert`` so the validation is
    # not stripped when Python runs with ``-O``. The exception type and
    # message are kept so existing callers observe the same failure.
    if store_type not in VECTORSTORE_TYPES:
        raise AssertionError("Invalid vectorstore type.")

    path = os.path.join(BASE_PATH, VECTORSTORE_TYPES[store_type])
    print(f"Load vectorstore from language {store_type}")

    if os.path.exists(os.path.join(path, 'index.faiss')):
        print("Reload existing faiss")
        # SECURITY: load_local unpickles data stored next to the index, and
        # this flag switches off LangChain's guard against that. Only point
        # BASE_PATH at index directories written by this application.
        return FAISS.load_local(
            path,
            get_embedding_model(store_type),
            allow_dangerous_deserialization=True,
        )

    print("Create new faiss")
    vs = create_new_vectorstore(get_embedding_model(store_type))
    # Persist the empty store right away so the next call takes the
    # reload branch above.
    save_vectorstore(vs, store_type)
    return vs
def save_vectorstore(vectorstore: FAISS, store_type: str):
    """Persist ``vectorstore`` into the index directory of ``store_type``.

    The destination is ``BASE_PATH`` joined with the directory name that
    ``VECTORSTORE_TYPES`` maps ``store_type`` to.
    """
    index_dir = VECTORSTORE_TYPES[store_type]
    vectorstore.save_local(os.path.join(BASE_PATH, index_dir))
def add_document(content: str, metadata: dict, store_type: str):
    """Add a single document to the store for ``store_type`` and save it.

    Args:
        content: Text stored as the document's ``page_content``.
        metadata: Metadata attached to the document; its ``"type"`` entry
            must equal ``store_type``.
        store_type: Key selecting which vector store receives the document.

    Raises:
        AssertionError: If ``metadata.get("type")`` differs from
            ``store_type``.
    """
    # Raised explicitly rather than through ``assert`` so the check still
    # runs under ``python -O``; otherwise a mismatched document would be
    # written into the wrong store. AssertionError is kept so existing
    # callers see the same exception type.
    doc_type = metadata.get("type")
    if store_type != doc_type:
        raise AssertionError(
            f"metadata type {doc_type!r} does not match store type {store_type!r}"
        )

    vs = load_vectorstore(store_type)
    doc = Document(page_content=content, metadata=metadata)
    vs.add_documents([doc])
    save_vectorstore(vs, store_type)
def add_multi_documents(processed_docs: list, store_type: str):
    """Add a batch of documents to the store for ``store_type`` and save it.

    Args:
        processed_docs: Documents handed to the store's ``add_documents``.
            An empty (or ``None``) batch is a no-op.
        store_type: Key selecting which vector store receives the documents.
    """
    # Nothing to embed or persist: skip the load / add / save round trip
    # instead of passing an empty batch down to the vector store.
    if not processed_docs:
        return

    vs = load_vectorstore(store_type)
    vs.add_documents(processed_docs)
    save_vectorstore(vs, store_type)
def get_relevant_documents(store_type, query: str, top_k: int = 10) -> List[Document]:
    """Run a similarity search for ``query`` against one language's store.

    The store for ``store_type`` is obtained through ``load_vectorstore`` and
    queried with ``similarity_search`` using ``k=top_k``; the search result
    is returned unchanged.
    """
    store = load_vectorstore(store_type)
    matches = store.similarity_search(query, k=top_k)
    return matches