|
|
import streamlit as st |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from pinecone import Pinecone, ServerlessSpec |
|
|
import os |
|
|
from langchain_huggingface import HuggingFaceEndpoint |
|
|
from langchain.prompts import PromptTemplate |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, firestore |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Pull settings from the local dotenv file into the process environment
# before any of them are read below.
load_dotenv(".env.local")

# Service credentials; each one is None when its variable is not set.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENV = os.environ.get("PINECONE_ENV")
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Placeholder identity; reassigned from the signed-in session further down.
USER_ID = "demo_user"
|
|
|
|
|
|
|
|
# Gate the page: the session must hold a "user" entry that contains "email".
is_signed_in = "user" in st.session_state and "email" in st.session_state["user"]
if not is_signed_in:
    st.error("Please sign in with your Google account to access the Clinical Trial Discovery tool.")
    # Halt this script run so nothing below renders for anonymous visitors.
    st.stop()

# Per-user Firestore documents are keyed by the account email.
USER_ID = st.session_state["user"]["email"]
|
|
|
|
|
|
|
|
# Streamlit re-executes this script on every interaction, so the default
# Firebase app may already exist in this process. Probe for it through the
# public API instead of reading the private `firebase_admin._apps` registry.
try:
    firebase_admin.get_app()
except ValueError:
    # get_app() raises ValueError when the default app has not been created yet.
    cred = credentials.Certificate("service-account-key.json")
    firebase_admin.initialize_app(cred)

# Firestore client used for trial lookups and per-user bookmarks.
db = firestore.client()
|
|
|
|
|
|
|
|
# Name of the Pinecone index queried for trial chunks.
INDEX_NAME = "clinical-trials-rag"

# Pinecone client and a handle on the index above.
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(INDEX_NAME)
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_embed_model():
    """Build the sentence-embedding model once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without the resource cache the model would be constructed again on
    each rerun.
    """
    return SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")


# Encoder used to embed user questions before querying the vector index.
embed_model = _load_embed_model()
|
|
|
|
|
|
|
|
# Text-generation client for the hosted inference endpoint, authenticated
# with HF_TOKEN; replies are capped at 256 new tokens.
llm = HuggingFaceEndpoint(
    huggingfacehub_api_token=HF_TOKEN,
    endpoint_url="https://f9eftfrz5qna6j32.us-east-1.aws.endpoints.huggingface.cloud",
    max_new_tokens=256,
    temperature=0.7,
)
|
|
|
|
|
|
|
|
# Prompt layout: retrieved context first, then the prior turns, then the new
# question. The text ends on "Bot:".
_PROMPT_TEXT = """
Context:
{context}

Conversation so far:
{history}

User: {question}
Bot:"""

# Placeholders filled at format time: context, history, question.
prompt_template = PromptTemplate.from_template(_PROMPT_TEXT)
|
|
|
|
|
|
|
|
# Seed the per-session transcript: a list of (question, answer) tuples.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Two top-level views: the chat itself and the user's saved trials.
tab_labels = ["π Ask a Question", "β Bookmarked Trials"]
tab1, tab2 = st.tabs(tab_labels)
|
|
|
|
|
|
|
|
def _retrieve_trials(query):
    """Return the top 5 vector-index matches for *query* as plain dicts.

    Each dict holds "nct_id" (``chunk_<rank>`` when the match metadata has
    no id) and "text" (empty string when the metadata has no text), so the
    result can be kept in session state without holding client objects.
    """
    vec = embed_model.encode(query).tolist()
    results = index.query(vector=vec, top_k=5, include_metadata=True)
    return [
        {
            "nct_id": match["metadata"].get("nct_id", f"chunk_{rank}"),
            "text": match["metadata"].get("text", ""),
        }
        for rank, match in enumerate(results["matches"])
    ]


def _answer_query(query):
    """Retrieve context, generate an answer for *query*, and record the turn.

    Appends ``(query, answer)`` to ``st.session_state.chat_history`` and
    returns a dict with keys "query", "answer" and "trials".
    """
    with st.spinner("Retrieving relevant trials..."):
        trials = _retrieve_trials(query)

    joined_context = "\n".join(trial["text"] for trial in trials)
    chat_history_text = "\n".join(
        f"User: {q}\nBot: {a}" for q, a in st.session_state.chat_history
    )
    prompt = prompt_template.format(
        context=joined_context, question=query, history=chat_history_text
    )

    with st.spinner("Generating answer..."):
        # invoke() is the supported entry point; calling the LLM object
        # directly is deprecated in LangChain.
        answer = llm.invoke(prompt)

    st.session_state.chat_history.append((query, answer))
    return {"query": query, "answer": answer, "trials": trials}


# NOTE(review): several label glyphs below look mis-encoded (probably emoji);
# they are left untouched here — confirm the file's encoding.
with tab1:
    st.title("π Clinical Trial Discovery Chatbot")

    st.markdown("""
π‘ **Example question formats:**
- What clinical trials are available for non-small cell lung cancer in California?
- List phase 3 trials for Type 1 Diabetes recruiting in 2025.
- What studies on immunotherapy for melanoma are active in Europe?
- Are there trials targeting heart disease patients over 65?
""")

    # Replay the transcript recorded so far in this session.
    for q, a in st.session_state.chat_history:
        st.markdown(f"**User:** {q}")
        st.markdown(f"**Bot:** {a}")

    user_query = st.text_input("π Enter your clinical trial question below:")

    if user_query:
        # Every widget interaction (e.g. a bookmark click) reruns the script
        # with the same text still in the input box. Reuse the stored result
        # for an unchanged question so the LLM is not called again and the
        # turn is not appended to the history a second time.
        last_result = st.session_state.get("last_result")
        if last_result is None or last_result["query"] != user_query:
            last_result = _answer_query(user_query)
            st.session_state["last_result"] = last_result

        st.subheader("π§ Answer:")
        st.write(last_result["answer"])

        st.markdown("---")
        st.subheader("π Related Clinical Trials")

        for i, trial in enumerate(last_result["trials"]):
            nct_id = trial["nct_id"]
            # Only the first 400 characters are shown and stored.
            chunk_text = trial["text"][:400]
            with st.expander(f"Trial: {nct_id}"):
                trial_doc = db.collection("ClinicalTrials").document(nct_id).get()
                if trial_doc.exists:
                    for k, v in trial_doc.to_dict().items():
                        st.markdown(f"**{k.replace('_', ' ').title()}:** {v}")
                else:
                    st.warning("β οΈ Full trial details not found in Firestore. Showing partial match.")
                    st.write(chunk_text + "...")

                # The key is based on the match rank so buttons stay unique
                # even if two matches share an id.
                if st.button(f"β Bookmark {nct_id}", key=f"bookmark_{i}"):
                    db.collection("Users").document(USER_ID).collection("Bookmarks").document(nct_id).set({
                        "nct_id": nct_id,
                        "text": chunk_text,
                    })
                    st.success(f"Bookmarked {nct_id} to Firestore.")
|
|
|
|
|
|
|
|
with tab2:
    st.title("β Your Bookmarked Trials")

    # Read every bookmark document stored under the current user.
    bookmarks_ref = db.collection("Users").document(USER_ID).collection("Bookmarks")
    bookmarks = [snapshot.to_dict() for snapshot in bookmarks_ref.stream()]

    if bookmarks:
        # One collapsible panel per saved trial, titled by its id.
        for entry in bookmarks:
            with st.expander(f"{entry['nct_id']}"):
                st.write(entry["text"])
    else:
        st.info("You haven't bookmarked any trials yet.")
|
|
|