File size: 6,437 Bytes
798fe27 50d0b90 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 37e9a31 798fe27 19df612 37e9a31 798fe27 37e9a31 19df612 37e9a31 798fe27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import streamlit as st # for web UI creation
from sentence_transformers import SentenceTransformer # this is for embedding queries into dense vectors
from pinecone import Pinecone, ServerlessSpec # for accessing pinecone vector DB
import os # for readhing environment variable
from langchain_huggingface import HuggingFaceEndpoint # for accessing HuggingFace inference endpoint
from langchain.prompts import PromptTemplate
import firebase_admin # for access to firebase
from firebase_admin import credentials, firestore
from dotenv import load_dotenv
# === Load environment variables ===
load_dotenv(".env.local")

# === CONFIG ===
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
# NOTE(review): PINECONE_ENV is not referenced anywhere else in this script.
PINECONE_ENV = os.getenv("PINECONE_ENV")
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# === Require Authentication ===
# A signed-in user with a non-empty email must be present in session state.
# The login flow that populates st.session_state["user"] presumably lives in
# another page — confirm. Missing, None, or email-less users are rejected
# instead of crashing.
_user = st.session_state.get("user")
if not _user or "email" not in _user or not _user["email"]:
    st.error("Please sign in with your Google account to access the Clinical Trial Discovery tool.")
    st.stop()  # halts this script run; nothing below executes

# The email is the unique user ID and names the user's Firestore document
# under "Users".
USER_ID = _user["email"]
# === Firebase Setup ===
# Streamlit re-executes this script on every interaction, but firebase_admin
# keeps a process-wide default app, so initialize it only once. The public
# get_app() raises ValueError when no default app exists yet. This replaces
# the previous check on the private firebase_admin._apps.
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate("service-account-key.json")  # path resolved relative to the working directory
    firebase_admin.initialize_app(cred)
db = firestore.client()
# === Pinecone Setup ===
INDEX_NAME = "clinical-trials-rag"


@st.cache_resource
def _load_pinecone_index(api_key, index_name):
    """Create the Pinecone client and a handle to *index_name*.

    Wrapped in ``st.cache_resource`` so the client and index handle are built
    once and reused across Streamlit reruns instead of on every interaction.

    Returns:
        ``(client, index)`` tuple.
    """
    client = Pinecone(api_key=api_key)
    return client, client.Index(index_name)


pc, index = _load_pinecone_index(PINECONE_API_KEY, INDEX_NAME)
# === Embedding Model ===
@st.cache_resource
def _load_embed_model(model_name):
    """Load the sentence-transformer used to embed user queries.

    Loading model weights is expensive. ``st.cache_resource`` keeps a single
    instance across Streamlit reruns, so the model is not reloaded on every
    widget interaction.
    """
    return SentenceTransformer(model_name)


# BioBERT sentence-transformer. NOTE(review): this presumably must match the
# model used to build the Pinecone index — confirm.
embed_model = _load_embed_model("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")
# === LLM Setup ===
@st.cache_resource
def _load_llm(endpoint_url, token):
    """Build the LangChain client for the Hugging Face Inference Endpoint.

    Cached with ``st.cache_resource`` so the client is constructed once and
    reused across Streamlit reruns.

    Args:
        endpoint_url: URL of the deployed Inference Endpoint.
        token: Hugging Face API token used to authenticate.
    """
    return HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        huggingfacehub_api_token=token,
        temperature=0.7,
        max_new_tokens=256,
    )


# Inference Endpoint built from Hugging Face. Pay per hour.
llm = _load_llm(
    "https://f9eftfrz5qna6j32.us-east-1.aws.endpoints.huggingface.cloud",
    HF_TOKEN,
)
# === Prompt Template ===
# Raw template text. Placeholders:
#   {context}  - retrieved trial chunks joined by newlines
#   {history}  - earlier turns rendered as "User: ..." / "Bot: ..." lines
#   {question} - the user's current question
# The model continues the text after the trailing "Bot:".
_PROMPT_TEXT = (
    "\n"
    "Context:\n"
    "{context}\n"
    "Conversation so far:\n"
    "{history}\n"
    "User: {question}\n"
    "Bot:"
)
prompt_template = PromptTemplate.from_template(_PROMPT_TEXT)
# === Session State Setup ===
# Create the per-session transcript of (question, answer) pairs on the first
# run only; later reruns keep the existing list.
st.session_state.setdefault("chat_history", [])

# === Tabs ===
# Two views: the chatbot and the signed-in user's saved bookmarks.
tab1, tab2 = st.tabs(["π Ask a Question", "β Bookmarked Trials"])
# === TAB 1: Conversational Chatbot ===
def _retrieve_trials(query):
    """Return metadata for the 5 Pinecone chunks most similar to *query*.

    The query is embedded with ``embed_model`` and searched in ``index``.
    Each returned item is a plain dict copy of a match's metadata (an empty
    dict when a match carries none). Callers read its "text" and "nct_id"
    keys.
    """
    vec = embed_model.encode(query).tolist()  # dense query vector
    results = index.query(vector=vec, top_k=5, include_metadata=True)
    return [dict(match["metadata"] or {}) for match in results["matches"]]


def _generate_answer(query, metas):
    """Ask the LLM to answer *query* from the retrieved chunks and prior turns.

    The prompt fills ``prompt_template`` with three values: the chunk texts
    joined by newlines as context, ``st.session_state.chat_history``
    rendered as "User:/Bot:" lines, and the question itself.

    Returns:
        The LLM's text output.
    """
    joined_context = "\n".join(meta.get("text") or "" for meta in metas)
    chat_history_text = "\n".join(
        f"User: {q}\nBot: {a}" for q, a in st.session_state.chat_history
    )
    prompt = prompt_template.format(
        context=joined_context, question=query, history=chat_history_text
    )
    # invoke() is LangChain's supported entry point. Calling the LLM object
    # directly, as in llm(prompt), is deprecated.
    return llm.invoke(prompt)


def _render_trial(i, meta):
    """Render one retrieved chunk as an expander with details and a Bookmark button.

    Full trial fields come from the Firestore document ClinicalTrials/{nct_id}
    when it exists. Otherwise the first 400 characters of the chunk are shown.
    The Bookmark button writes Users/{USER_ID}/Bookmarks/{nct_id}.

    Args:
        i: Position of the match in the result list. Used for the fallback
            ID and the button's widget key.
        meta: Metadata dict of the match.
    """
    # Fall back to a per-position ID when nct_id is missing *or empty*, so an
    # empty string never becomes the document ID or the expander label.
    nct_id = meta.get("nct_id") or f"chunk_{i}"
    snippet = (meta.get("text") or "")[:400]  # first 400 characters of the chunk
    with st.expander(f"Trial: {nct_id}"):
        trial_doc = db.collection("ClinicalTrials").document(nct_id).get()
        if trial_doc.exists:
            for k, v in trial_doc.to_dict().items():
                st.markdown(f"**{k.replace('_', ' ').title()}:** {v}")
        else:
            st.warning("β οΈ Full trial details not found in Firestore. Showing partial match.")
            st.write(snippet + "...")
        if st.button(f"β Bookmark {nct_id}", key=f"bookmark_{i}"):
            bookmark_ref = (
                db.collection("Users").document(USER_ID)
                .collection("Bookmarks").document(nct_id)
            )
            bookmark_ref.set({"nct_id": nct_id, "text": snippet})
            st.success(f"Bookmarked {nct_id} to Firestore.")


with tab1:
    st.title("π Clinical Trial Discovery Chatbot")
    st.markdown("""
π‘ **Example question formats:**
- What clinical trials are available for non-small cell lung cancer in California?
- List phase 3 trials for Type 1 Diabetes recruiting in 2025.
- What studies on immunotherapy for melanoma are active in Europe?
- Are there trials targeting heart disease patients over 65?
""")
    for q, a in st.session_state.chat_history:
        st.markdown(f"**User:** {q}")
        st.markdown(f"**Bot:** {a}")

    user_query = st.text_input("π Enter your clinical trial question below:")
    if user_query:
        # Streamlit reruns this whole script on every interaction, including
        # a click on a Bookmark button below, and text_input keeps its value.
        # Only search Pinecone, call the pay-per-hour LLM endpoint, and extend
        # the chat history when the question actually changed. Otherwise
        # redisplay the stored result, so the clicked button is rendered again
        # and its click is handled.
        if user_query != st.session_state.get("last_query"):
            with st.spinner("Retrieving relevant trials..."):
                metas = _retrieve_trials(user_query)
            with st.spinner("Generating answer..."):
                answer = _generate_answer(user_query, metas)
            st.session_state.chat_history.append((user_query, answer))
            # Store the result only after it has been fully computed, so a
            # failed call is retried on the next run.
            st.session_state["last_query"] = user_query
            st.session_state["last_answer"] = answer
            st.session_state["last_trial_metas"] = metas

        st.subheader("π§ Answer:")
        st.write(st.session_state["last_answer"])
        st.markdown("---")
        st.subheader("π Related Clinical Trials")
        for i, meta in enumerate(st.session_state["last_trial_metas"]):
            _render_trial(i, meta)
# === TAB 2: Bookmarked Trials ===
with tab2:
    st.title("β Your Bookmarked Trials")
    # Load every bookmark the signed-in user saved under Users/{USER_ID}/Bookmarks.
    bookmarks_ref = db.collection("Users").document(USER_ID).collection("Bookmarks")
    saved = [snapshot.to_dict() for snapshot in bookmarks_ref.stream()]
    if saved:
        # One expander per bookmark: the trial ID as label, the stored snippet as body.
        for entry in saved:
            with st.expander(f"{entry['nct_id']}"):
                st.write(entry["text"])
    else:
        st.info("You haven't bookmarked any trials yet.")
|