import os
import logging

import toml
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_groq import ChatGroq

from src.bot.extract_metadata import Metadata

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Medibot:
    def __init__(
        self,
        config_path: str = "src/bot/configs/prompt.toml",
        metadata_database: str = "database/metadata.csv",
        faiss_database: str = "database/faiss_index",
    ):
        """Initialize Medibot with prompt configuration, vector store, and Groq client."""
        # Load environment variables
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            logger.error("GROQ_API_KEY not found in environment variables")
            raise ValueError("GROQ_API_KEY is required")
        # OpenAI embeddings require their own key; fail fast if it is missing
        if not os.environ.get("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables")
            raise ValueError("OPENAI_API_KEY is required for embeddings")
        # Load prompt configuration
        try:
            config = toml.load(config_path)
            system_prompt = config["rag_prompt"]["system_prompt"]
            user_prompt_template = config["rag_prompt"]["user_prompt_template"]
        except (FileNotFoundError, toml.TomlDecodeError) as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            raise

        # Initialize prompt template
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", user_prompt_template),
        ])
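        # A minimal sketch of the expected prompt.toml layout, inferred from the
        # keys read above. The "context" and "question" variables match the chain
        # inputs built in query(); the prompt wording itself is an assumption,
        # not the shipped configuration:
        #
        #   [rag_prompt]
        #   system_prompt = "You are a medical assistant. Answer using only the context."
        #   user_prompt_template = "Context:\n{context}\n\nQuestion: {question}"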
        # Initialize vector database
        embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        vector_store = FAISS.load_local(
            faiss_database, embeddings, allow_dangerous_deserialization=True
        )
        # MMR retrieval balances relevance with diversity across the top 10 hits
        self.retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10})
        # Initialize Groq client
        self.model = ChatGroq(
            model="llama-3.1-8b-instant",
            temperature=0.2,
            max_tokens=None,
            timeout=None,
            max_retries=2,
        )
        self.metadata_extractor = Metadata(metadata_database)

    def query(self, question: str) -> tuple:
        """Answer a question via RAG; return the answer, retrieved docs, and referenced tables/images."""
        # Retrieve once and reuse the documents in both the chain and the metadata lookup
        retrieved_docs = self.retriever.invoke(question)
        rag_chain = (
            RunnableParallel({
                "context": RunnableLambda(lambda _: retrieved_docs),  # reuse retrieved docs
                "question": RunnablePassthrough(),
            })
            | self.prompt_template
            | self.model
            | StrOutputParser()
        )
        # Pass the raw question string so RunnablePassthrough forwards it unchanged
        answer = rag_chain.invoke(question)
        referred_tables, referred_images = self.metadata_extractor.get_data_from_ref(retrieved_docs)
        return answer, retrieved_docs, referred_tables, referred_images
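

# A minimal usage sketch, assuming the API keys are exported, the FAISS index and
# metadata CSV exist at the default paths, and get_data_from_ref returns list-like
# collections; the sample question is illustrative, not part of the project.
if __name__ == "__main__":
    bot = Medibot()
    answer, docs, tables, images = bot.query("What are the symptoms of anemia?")
    logger.info("Answer: %s", answer)
    logger.info("Retrieved %d docs, %d tables, %d images", len(docs), len(tables), len(images))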