"""Streamlit RAG chatbot for clinical-trial discovery.

Pipeline: BioBERT sentence embeddings -> Pinecone similarity search ->
HuggingFace inference endpoint for answer generation. Firestore holds the
full trial records and per-user bookmarks.
"""

import os  # for reading environment variables

import firebase_admin  # for access to Firebase
import streamlit as st  # for web UI creation
from dotenv import load_dotenv
from firebase_admin import credentials, firestore
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint  # HuggingFace inference endpoint
from pinecone import Pinecone, ServerlessSpec  # for accessing the Pinecone vector DB
from sentence_transformers import SentenceTransformer  # embeds queries into dense vectors

# === Load environment variables ===
load_dotenv(".env.local")

# === CONFIG ===
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")  # read for env parity; not used below
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# === Require Authentication ===
# The Google sign-in flow (handled elsewhere in the app) stores the user dict
# in session state; without it we refuse to render the tool.
if "user" not in st.session_state or "email" not in st.session_state["user"]:
    st.error("Please sign in with your Google account to access the Clinical Trial Discovery tool.")
    st.stop()
USER_ID = st.session_state["user"]["email"]  # email doubles as the unique user ID

# === Firebase Setup ===
# initialize_app() raises if called twice in one process, so guard on _apps.
if not firebase_admin._apps:
    cred = credentials.Certificate("service-account-key.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()


# === Cached resources ===
# Streamlit re-executes this script on every interaction; st.cache_resource
# ensures the heavy model and clients are constructed only once per process.
@st.cache_resource
def _get_embed_model():
    """BioBERT sentence-transformer used to embed user queries."""
    return SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")


@st.cache_resource
def _get_pinecone_index():
    """Handle to the clinical-trials Pinecone index."""
    return Pinecone(api_key=PINECONE_API_KEY).Index("clinical-trials-rag")


@st.cache_resource
def _get_llm():
    """Dedicated HuggingFace inference endpoint (pay per hour)."""
    return HuggingFaceEndpoint(
        endpoint_url="https://f9eftfrz5qna6j32.us-east-1.aws.endpoints.huggingface.cloud",
        huggingfacehub_api_token=HF_TOKEN,
        temperature=0.7,
        max_new_tokens=256,
    )


embed_model = _get_embed_model()
index = _get_pinecone_index()
llm = _get_llm()

# === Prompt Template ===
prompt_template = PromptTemplate.from_template(
    """
Context:
{context}

Conversation so far:
{history}

User: {question}
Bot:"""
)

# === Session State Setup ===
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# === Tabs ===
tab1, tab2 = st.tabs(["🔍 Ask a Question", "⭐ Bookmarked Trials"])

# === TAB 1: Conversational Chatbot ===
with tab1:
    st.title("🔎 Clinical Trial Discovery Chatbot")
    st.markdown("""
💡 **Example question formats:**
- What clinical trials are available for non-small cell lung cancer in California?
- List phase 3 trials for Type 1 Diabetes recruiting in 2025.
- What studies on immunotherapy for melanoma are active in Europe?
- Are there trials targeting heart disease patients over 65?
""")

    # Replay earlier turns so the conversation stays visible across reruns.
    for q, a in st.session_state.chat_history:
        st.markdown(f"**User:** {q}")
        st.markdown(f"**Bot:** {a}")

    user_query = st.text_input("🔍 Enter your clinical trial question below:")  # actual query input

    if user_query:  # triggers the query when the user submits text
        with st.spinner("Retrieving relevant trials..."):
            # Embed the query with BioBERT and fetch the 5 nearest trial chunks.
            vec = embed_model.encode(user_query).tolist()
            results = index.query(vector=vec, top_k=5, include_metadata=True)
        contexts = [r["metadata"]["text"] for r in results["matches"]]

        # Build the prompt: retrieved context plus the running conversation.
        joined_context = "\n".join(contexts)
        chat_history_text = "\n".join(
            f"User: {q}\nBot: {a}" for q, a in st.session_state.chat_history
        )
        prompt = prompt_template.format(
            context=joined_context, question=user_query, history=chat_history_text
        )

        with st.spinner("Generating answer..."):
            # BUG FIX: calling the LLM object directly (llm(prompt)) is a
            # deprecated LangChain API; invoke() is the supported call.
            answer = llm.invoke(prompt)

        st.session_state.chat_history.append((user_query, answer))
        st.subheader("🧠 Answer:")  # display the answer in the UI
        st.write(answer)

        st.markdown("---")
        st.subheader("📋 Related Clinical Trials")
        # Loop through the Pinecone matches and display each trial.
        for i, match in enumerate(results["matches"]):
            meta = match["metadata"]
            nct_id = meta.get("nct_id", f"chunk_{i}")  # fallback chunk ID if 'nct_id' is missing
            with st.expander(f"Trial: {nct_id}"):
                # Prefer the full record from Firestore; fall back to the
                # first 400 chars of the chunk stored beside the vector.
                trial_doc = db.collection("ClinicalTrials").document(nct_id).get()
                if trial_doc.exists:
                    for k, v in trial_doc.to_dict().items():
                        st.markdown(f"**{k.replace('_', ' ').title()}:** {v}")
                else:
                    st.warning("⚠️ Full trial details not found in Firestore. Showing partial match.")
                    st.write(meta.get("text", "")[:400] + "...")

                # Bookmarks are saved to Users/{USER_ID}/Bookmarks/{nct_id}.
                if st.button(f"⭐ Bookmark {nct_id}", key=f"bookmark_{i}"):
                    db.collection("Users").document(USER_ID).collection("Bookmarks").document(nct_id).set({
                        "nct_id": nct_id,
                        "text": meta.get("text", "")[:400],
                    })
                    st.success(f"Bookmarked {nct_id} to Firestore.")

# === TAB 2: Bookmarked Trials ===
with tab2:
    st.title("⭐ Your Bookmarked Trials")

    # Retrieve this user's bookmarks from Firestore.
    docs = db.collection("Users").document(USER_ID).collection("Bookmarks").stream()
    bookmarks = [doc.to_dict() for doc in docs]

    if not bookmarks:
        st.info("You haven't bookmarked any trials yet.")
    else:
        for b in bookmarks:
            # .get() guards against legacy documents missing a field,
            # which previously raised KeyError and crashed the tab.
            with st.expander(b.get("nct_id", "Unknown trial")):
                st.write(b.get("text", ""))