Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import io | |
| from PIL import Image | |
| import gradio as gr | |
| from src.bot.bot import Medibot | |
| from bs4 import BeautifulSoup | |
| import markdown | |
| from src.auth.auth import register_user, login_user | |
| from src.auth.db import initialize_db | |
| from groq import Groq | |
| from src import config | |
| #====================================== | |
| #=============utils==================== | |
| #====================================== | |
| # Helper functions | |
def markdown_to_plain_text(md_text: str) -> str:
    """Return the text content of *md_text* with its Markdown markup removed.

    The Markdown is first rendered to HTML with ``markdown.markdown`` and the
    resulting HTML is then parsed with BeautifulSoup's ``html.parser`` so that
    only the text nodes are returned.
    """
    rendered_html = markdown.markdown(md_text)
    return BeautifulSoup(rendered_html, "html.parser").get_text()
| # Ensure base64 strings are properly formatted (no newlines/whitespace) | |
def decode_base64_to_image(base64_string):
    """Decode a base64-encoded image string and open it with PIL.

    Newline and space characters are removed from *base64_string* before it is
    decoded; the decoded bytes are wrapped in an in-memory buffer and handed
    to ``Image.open``, whose result is returned.
    """
    # One pass that deletes exactly the two characters "\n" and " ".
    cleaned = base64_string.translate(str.maketrans("", "", "\n "))
    raw_bytes = base64.b64decode(cleaned)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
| # Step 1: API Key Validation Logic | |
def validate_api_key(user_api_key):
    """Check a Groq Cloud API key and, if it is accepted, store it for later use.

    Args:
        user_api_key: The key typed by the user. ``None``, an empty string or
            a whitespace-only string is treated as "no key entered".

    Returns:
        A 3-tuple ``(status_message, first_update, second_update)``. The two
        updates are ``gr.update(visible=...)`` values: on success the first is
        hidden and the second shown; on a missing or rejected key the first
        stays visible and the second stays hidden.

    Side effects (only when the key is accepted):
        The key is written to the module-level ``api_key`` global, to
        ``config.api_key`` (the location ``handle_login`` also writes to) and
        to the ``GROQ_API_KEY`` environment variable.
    """
    # NOTE(review): the leading "β" in the status strings looks like a
    # mis-encoded status icon -- confirm the intended character.
    global api_key

    # Drop surrounding whitespace (e.g. from a pasted key); None counts as empty.
    candidate = (user_api_key or "").strip()
    if not candidate:
        return "β Please enter your Groq Cloud API key.", gr.update(visible=True), gr.update(visible=False)

    try:
        # Probe with an authenticated model-listing call rather than a chat
        # completion, so validation does not depend on one hard-coded model id
        # staying available and does not spend completion tokens.
        Groq(api_key=candidate).models.list()
    except Exception as e:
        return f"β Invalid API key: {str(e)}", gr.update(visible=True), gr.update(visible=False)

    api_key = candidate
    config.api_key = candidate
    os.environ["GROQ_API_KEY"] = candidate
    return "β API key is valid and saved!", gr.update(visible=False), gr.update(visible=True)
def handle_login(userid, password, user_api_key):
    """Register a new user (API key supplied) or log an existing one in (no key).

    Args:
        userid: Identifier passed to ``register_user`` / ``login_user``.
        password: Password passed to ``register_user`` / ``login_user``.
        user_api_key: Groq API key. When truthy the call is treated as a
            registration: the key is validated against Groq and then stored
            via ``register_user``. When falsy the call is treated as a login
            and the key returned by ``login_user`` is used instead.

    Returns:
        A 3-tuple ``(status_message, first_update, second_update)``. The two
        updates are ``gr.update(visible=...)`` values: on success the first is
        hidden and the second shown; on any failure the first stays visible
        and the second stays hidden.

    Side effects (only on success):
        The active key is written to ``config.api_key`` and to the
        ``GROQ_API_KEY`` environment variable.
    """
    # NOTE(review): the leading "β" in the status strings looks like a
    # mis-encoded status icon -- confirm the intended character.
    if user_api_key:
        # Step 1: validate the key on its own, so that only a failure of the
        # Groq call is reported as an invalid key.
        try:
            # Authenticated model-listing call: does not depend on one
            # hard-coded model id staying available.
            Groq(api_key=user_api_key).models.list()
        except Exception as e:
            return f"β Invalid API Key: {str(e)}", gr.update(visible=True), gr.update(visible=False)

        # Step 2: register. Errors raised here are unrelated to the key, so
        # they get their own message instead of "Invalid API Key".
        try:
            success, msg = register_user(userid, password, user_api_key)
        except Exception as e:
            return f"β Registration failed: {str(e)}", gr.update(visible=True), gr.update(visible=False)

        if not success:
            return msg, gr.update(visible=True), gr.update(visible=False)

        config.api_key = user_api_key
        os.environ["GROQ_API_KEY"] = user_api_key
        return "β API Key validated & registered!", gr.update(visible=False), gr.update(visible=True)

    # No key supplied: the user is trying to log in.
    success, saved_api_key = login_user(userid, password)
    if not success:
        return "β Incorrect userid or password.", gr.update(visible=True), gr.update(visible=False)

    config.api_key = saved_api_key
    os.environ["GROQ_API_KEY"] = saved_api_key
    return "β Login successful!", gr.update(visible=False), gr.update(visible=True)
| #====================================== | |
| #=============Interface================ | |
| #====================================== | |
class Interface:
    """Thin UI-facing wrapper that formats ``Medibot`` query results for display."""

    def __init__(self, config_path: str = "src/bot/configs/prompt.toml",
                 metadata_database: str = "database/metadata.csv",
                 faiss_database: str = "database/faiss_index"):
        """Create the underlying ``Medibot``.

        All three arguments are forwarded unchanged to ``Medibot`` as keyword
        arguments of the same name.
        """
        self.bot = Medibot(config_path=config_path,
                           metadata_database=metadata_database,
                           faiss_database=faiss_database,
                           )

    def get_answer(self, question: str):
        """Query the bot and format its result as display-ready Markdown.

        ``self.bot.query(question)`` is expected to return a 4-tuple
        ``(answer, retrieved_docs, tables, images)`` where ``tables`` and
        ``images`` are mappings (name -> content / name -> base64 string) and
        each retrieved document exposes a ``page_content`` attribute.

        Returns:
            A 4-tuple ``(answer, tables_md, images_md, retrieved_md)``:

            * ``answer`` -- the bot's answer, passed through unchanged.
            * ``tables_md`` -- a Markdown section with every referenced table,
              or a "no tables" placeholder.
            * ``images_md`` -- a list with one Markdown image embed (PNG data
              URI) per referenced image, or ``None`` when there are none.
            * ``retrieved_md`` -- a Markdown section listing the retrieved
              documents, or a "no documents" placeholder.

            If anything raises, ``("Error: <message>", "", [], "")`` is
            returned instead of propagating the exception.
        """
        try:
            answer_md, retrieved_docs, refered_tables, refered_images = self.bot.query(question)

            # Referenced tables: only the table contents are rendered.
            if refered_tables:
                tables_body = "".join(
                    f"{table_content}\n\n" for table_content in refered_tables.values()
                )
            else:
                tables_body = "_No tables referenced._"
            tables_display = "### Referenced Tables:\n\n" + tables_body

            # Referenced images: embed each base64 payload as a Markdown image
            # whose source is a PNG data URI and whose alt text is its name.
            if refered_images:
                images_display = [
                    f"![{image_name}](data:image/png;base64,{base64_string})"
                    for image_name, base64_string in refered_images.items()
                ]
            else:
                images_display = None

            # Retrieved documents, numbered from 1.
            if retrieved_docs:
                docs_body = "".join(
                    f"**Doc {position}:**\n{doc.page_content}\n\n"
                    for position, doc in enumerate(retrieved_docs, start=1)
                )
            else:
                docs_body = "_No documents retrieved._"
            retrieved_display = "### Retrieved Documents:\n\n" + docs_body

            return answer_md, tables_display, images_display, retrieved_display
        except Exception as e:
            return f"Error: {str(e)}", "", [], ""