# Clinical Trial Discovery chatbot (Streamlit + Pinecone + Firestore RAG).
# Provenance: AlphaAngel444, "Add Google Auth Feature" (commit 50d0b90).
import streamlit as st # for web UI creation
from sentence_transformers import SentenceTransformer # this is for embedding queries into dense vectors
from pinecone import Pinecone, ServerlessSpec # for accessing pinecone vector DB
import os # for reading environment variables
from langchain_huggingface import HuggingFaceEndpoint # for accessing HuggingFace inference endpoint
from langchain.prompts import PromptTemplate
import firebase_admin # for access to firebase
from firebase_admin import credentials, firestore
from dotenv import load_dotenv
# === Load environment variables ===
load_dotenv(".env.local")  # pull secrets from .env.local into the process environment

# === CONFIG ===
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")  # read but not referenced below — kept for deployment-config parity
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# === Require Authentication ===
# Gate the whole app on a signed-in Google user; st.stop() halts the script
# for unauthenticated visitors, so nothing past this point runs without auth.
# (The old static USER_ID = "demo_user" was dead: it was always either
# unreachable past st.stop() or immediately overwritten by the email below.)
if "user" not in st.session_state or "email" not in st.session_state["user"]:
    st.error("Please sign in with your Google account to access the Clinical Trial Discovery tool.")
    st.stop()

USER_ID = st.session_state["user"]["email"]  # email doubles as the per-user Firestore document ID
# === Firebase Setup ===
# Streamlit reruns the script on every interaction, but initialize_app may only
# be called once per process — guard on the already-initialized app registry.
if not firebase_admin._apps:
    cred = credentials.Certificate("service-account-key.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()

# === Pinecone Setup ===
INDEX_NAME = "clinical-trials-rag"
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(INDEX_NAME)
# === Embedding Model ===
# BioBERT sentence-transformer tuned on biomedical NLI/STS corpora; used to
# embed user questions into the same space as the indexed trial chunks.
embed_model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")

# === LLM Setup ===
# Dedicated Hugging Face Inference Endpoint (billed per hour while running).
llm = HuggingFaceEndpoint(
    endpoint_url="https://f9eftfrz5qna6j32.us-east-1.aws.endpoints.huggingface.cloud",
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=256,
)
# === Prompt Template ===
# Slots: {context} = retrieved trial chunks, {history} = prior turns,
# {question} = the current user query.
prompt_template = PromptTemplate.from_template(
    """
Context:
{context}
Conversation so far:
{history}
User: {question}
Bot:"""
)

# === Session State Setup ===
# chat_history persists (question, answer) pairs across Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# === Tabs ===
tab1, tab2 = st.tabs(["πŸ” Ask a Question", "⭐ Bookmarked Trials"])

# === TAB 1: Conversational Chatbot ===
with tab1:
    st.title("πŸ”Ž Clinical Trial Discovery Chatbot")
    st.markdown("""
πŸ’‘ **Example question formats:**
- What clinical trials are available for non-small cell lung cancer in California?
- List phase 3 trials for Type 1 Diabetes recruiting in 2025.
- What studies on immunotherapy for melanoma are active in Europe?
- Are there trials targeting heart disease patients over 65?
""")
    # Replay the running conversation above the input box.
    for q, a in st.session_state.chat_history:
        st.markdown(f"**User:** {q}")
        st.markdown(f"**Bot:** {a}")

    user_query = st.text_input("πŸ” Enter your clinical trial question below:")
    if user_query:  # any non-empty input triggers the RAG pipeline on each rerun
        with st.spinner("Retrieving relevant trials..."):
            # Embed the question with BioBERT and fetch the 5 nearest trial chunks.
            vec = embed_model.encode(user_query).tolist()
            results = index.query(vector=vec, top_k=5, include_metadata=True)
            contexts = [r["metadata"]["text"] for r in results["matches"]]
            nct_ids = [r["metadata"].get("nct_id", "") for r in results["matches"]]

        # Build the prompt: retrieved context + prior turns + current question.
        # (History is assembled BEFORE appending this turn, so the current
        # question appears only in the {question} slot.)
        joined_context = "\n".join(contexts)
        chat_history_text = "\n".join(f"User: {q}\nBot: {a}" for q, a in st.session_state.chat_history)
        prompt = prompt_template.format(context=joined_context, question=user_query, history=chat_history_text)

        with st.spinner("Generating answer..."):
            # FIX: use .invoke() — calling the LLM object directly (llm(prompt))
            # goes through LangChain's deprecated __call__ path.
            answer = llm.invoke(prompt)

        # FIX: guard against duplicate history entries. Clicking any widget below
        # (e.g. a bookmark button) reruns the script with user_query still set,
        # which previously re-appended the same (question, answer) pair.
        if not st.session_state.chat_history or st.session_state.chat_history[-1][0] != user_query:
            st.session_state.chat_history.append((user_query, answer))

        st.subheader("🧠 Answer:")
        st.write(answer)
        st.markdown("---")
        st.subheader("πŸ“‹ Related Clinical Trials")
        for i, match in enumerate(results["matches"]):
            meta = match["metadata"]
            nct_id = meta.get("nct_id", f"chunk_{i}")  # fallback chunk ID when nct_id is missing
            chunk_text = meta.get("text", "")[:400]    # first 400 chars of the matched chunk
            with st.expander(f"Trial: {nct_id}"):
                # Prefer the full record from Firestore; otherwise show the raw chunk.
                trial_doc = db.collection("ClinicalTrials").document(nct_id).get()
                if trial_doc.exists:
                    trial_data = trial_doc.to_dict()
                    for k, v in trial_data.items():
                        st.markdown(f"**{k.replace('_', ' ').title()}:** {v}")
                else:
                    st.warning("⚠️ Full trial details not found in Firestore. Showing partial match.")
                    st.write(meta.get("text", "")[:400] + "...")
                # Bookmarks are saved per user at Users/{email}/Bookmarks/{nct_id}.
                if st.button(f"⭐ Bookmark {nct_id}", key=f"bookmark_{i}"):
                    db.collection("Users").document(USER_ID).collection("Bookmarks").document(nct_id).set({
                        "nct_id": nct_id,
                        "text": meta.get("text", "")[:400]
                    })
                    st.success(f"Bookmarked {nct_id} to Firestore.")
# === TAB 2: Bookmarked Trials ===
with tab2:
    st.title("⭐ Your Bookmarked Trials")
    # Fetch every bookmark document saved under this user.
    saved = [
        doc.to_dict()
        for doc in db.collection("Users").document(USER_ID).collection("Bookmarks").stream()
    ]
    if not saved:
        st.info("You haven't bookmarked any trials yet.")
    else:
        # One collapsible panel per saved trial.
        for entry in saved:
            with st.expander(f"{entry['nct_id']}"):
                st.write(entry["text"])