"""Persistence helpers for the per-language FAISS vector stores.

Each supported language in ``VECTORSTORE_TYPES`` maps to its own index
directory below ``BASE_PATH``. Every public helper loads the store from
disk, operates on it, and (for mutating helpers) saves it back.
"""

import os
from typing import List, Optional

import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document

from app.models.model import Embedding_model_en, Embedding_model_de
from app import config

BASE_PATH = "vectorstore"
VECTORSTORE_TYPES = {
    "English": "English_index",
    "Deutsch": "Deutsch_index",
}


class VectorstoreTypeError(ValueError, AssertionError):
    """Raised when a store type is unknown or disagrees with the metadata.

    Subclasses ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``assert`` failures raised by earlier versions keep working.
    """


def _store_path(store_type: str) -> str:
    """Return the on-disk directory of the index for *store_type*.

    Raises:
        KeyError: If *store_type* is not a key of ``VECTORSTORE_TYPES``.
    """
    return os.path.join(BASE_PATH, VECTORSTORE_TYPES[store_type])


def get_embedding_model(language):
    """Return the embedding model for *language*.

    ``"English"`` selects ``Embedding_model_en``; every other value
    (including unknown ones) selects ``Embedding_model_de``.
    """
    return Embedding_model_en if language == "English" else Embedding_model_de


def create_new_vectorstore(embedding_model):
    """Build an empty in-memory FAISS store backed by *embedding_model*."""
    # Embed a throwaway query so the index is created with the same vector
    # length the model actually produces.
    dim = len(embedding_model.embed_query("hello world"))
    index = faiss.IndexFlatL2(dim)
    return FAISS(
        embedding_function=embedding_model,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )


def load_vectorstore(store_type: str) -> FAISS:
    """Load the vector store for *store_type*, creating it if it is missing.

    When no ``index.faiss`` file exists in the store directory, a new empty
    store is created, saved to disk, and returned.

    Raises:
        VectorstoreTypeError: If *store_type* is not in ``VECTORSTORE_TYPES``.
    """
    # Explicit raise instead of ``assert``: asserts are removed under
    # ``python -O`` and the check would silently disappear.
    if store_type not in VECTORSTORE_TYPES:
        raise VectorstoreTypeError("Invalid vectorstore type.")

    path = _store_path(store_type)
    print(f"Load vectorstore from language {store_type}")

    if os.path.exists(os.path.join(path, "index.faiss")):
        print("Reload existing faiss")
        # NOTE(security): this flag permits pickle-based deserialization of
        # the stored data. Keep BASE_PATH pointing only at directories this
        # application wrote itself; never load an index from untrusted input.
        return FAISS.load_local(
            path,
            get_embedding_model(store_type),
            allow_dangerous_deserialization=True,
        )

    print("Create new faiss")
    vs = create_new_vectorstore(get_embedding_model(store_type))
    save_vectorstore(vs, store_type)
    return vs


def save_vectorstore(vectorstore: FAISS, store_type: str) -> None:
    """Write *vectorstore* to the directory belonging to *store_type*.

    Raises:
        KeyError: If *store_type* is not a key of ``VECTORSTORE_TYPES``.
    """
    vectorstore.save_local(_store_path(store_type))


def add_document(content: str, metadata: dict, store_type: str) -> None:
    """Add one document to the *store_type* store and save the store.

    Args:
        content: Text used as the document's ``page_content``.
        metadata: Document metadata; its ``"type"`` entry must equal
            *store_type*.
        store_type: Key of ``VECTORSTORE_TYPES`` naming the target store.

    Raises:
        VectorstoreTypeError: If ``metadata["type"]`` is missing or differs
            from *store_type*, or if *store_type* is unknown.
    """
    declared_type = metadata.get("type")
    if store_type != declared_type:
        raise VectorstoreTypeError(
            f"Metadata type {declared_type!r} does not match "
            f"store type {store_type!r}."
        )

    vs = load_vectorstore(store_type)
    vs.add_documents([Document(page_content=content, metadata=metadata)])
    save_vectorstore(vs, store_type)


def add_multi_documents(processed_docs: List[Document], store_type: str) -> None:
    """Add a batch of documents to the *store_type* store and save the store.

    An empty batch is a no-op.

    Raises:
        VectorstoreTypeError: If *store_type* is unknown.
    """
    if not processed_docs:
        # Nothing to index: skip the load/save round trip entirely.
        return

    vs = load_vectorstore(store_type)
    vs.add_documents(processed_docs)
    save_vectorstore(vs, store_type)


def get_relevant_documents(store_type, query: str, top_k: int = 10) -> List[Document]:
    """Return up to *top_k* documents from the *store_type* store for *query*.

    Raises:
        VectorstoreTypeError: If *store_type* is unknown.
    """
    vs = load_vectorstore(store_type)
    return vs.similarity_search(query, k=top_k)