# Streamlit app: clinical trial discovery chatbot with Pinecone retrieval,
# a Hugging Face inference endpoint for answers, and Firestore bookmarks.
import streamlit as st # for web UI creation 
from sentence_transformers import SentenceTransformer # this is for embedding queries into dense vectors 
from pinecone import Pinecone, ServerlessSpec # for accessing pinecone vector DB 
import os # for reading environment variables
from langchain_huggingface import HuggingFaceEndpoint # for accessing HuggingFace inference endpoint 
from langchain.prompts import PromptTemplate
import firebase_admin  # for access to firebase 
from firebase_admin import credentials, firestore
from dotenv import load_dotenv

# === Load environment variables ===
# Secrets are read from a local dotenv file and then from the environment.
load_dotenv(".env.local")

# === CONFIG ===
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# === Require Authentication ===
# The signed-in user is expected under st.session_state["user"] with an
# "email" entry (presumably set by a separate sign-in step; not visible here).
# A missing/None user or an empty email is treated as "not signed in", because
# the email is used below as a Firestore document id.
_signed_in_user = st.session_state.get("user")
if not _signed_in_user or "email" not in _signed_in_user or not _signed_in_user["email"]:
    st.error("Please sign in with your Google account to access the Clinical Trial Discovery tool.")
    st.stop()

USER_ID = _signed_in_user["email"]  # use email as unique user ID

# === Firebase Setup ===
# Initialise the default Firebase app only once per process. get_app() raises
# ValueError when the default app has not been initialised yet, which avoids
# depending on the private firebase_admin._apps registry.
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate("service-account-key.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()  # Firestore client used for trial details and bookmarks

# === Pinecone Setup ===
# Name of the vector index queried for trial chunks.
INDEX_NAME = "clinical-trials-rag"
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(INDEX_NAME)  # handle used for similarity queries

# === Embedding Model ===
@st.cache_resource
def _load_embed_model():
    """Load the BioBERT sentence-transformer used to embed user queries.

    Cached with ``st.cache_resource`` so the model is constructed once and
    reused, instead of being reloaded every time the script is re-executed.

    Returns:
        The ``SentenceTransformer`` instance.
    """
    return SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")


embed_model = _load_embed_model()  # BioBERT sentence transformer model

# === LLM Setup ===
# Inference Endpoint built from Hugging Face (pay per hour).
HF_ENDPOINT_URL = "https://f9eftfrz5qna6j32.us-east-1.aws.endpoints.huggingface.cloud"

# Text-generation client for the endpoint above, authenticated with HF_TOKEN.
llm = HuggingFaceEndpoint(
    max_new_tokens=256,
    temperature=0.7,
    huggingfacehub_api_token=HF_TOKEN,
    endpoint_url=HF_ENDPOINT_URL,
)

# === Prompt Template ===
# Prompt layout: retrieved context, then the prior conversation, then the new
# user question; the model continues the text after "Bot:".
_RAG_PROMPT_TEXT = (
    "\n"
    "Context:\n"
    "{context}\n"
    "\n"
    "Conversation so far:\n"
    "{history}\n"
    "\n"
    "User: {question}\n"
    "Bot:"
)
prompt_template = PromptTemplate.from_template(_RAG_PROMPT_TEXT)

# === Session State Setup ===
# Start with an empty list of (question, answer) turns if none exists yet.
st.session_state.setdefault("chat_history", [])

# === Tabs ===
# Two top-level tabs: the chatbot and the user's saved bookmarks.
# (Tab labels repaired: the emoji were mis-encoded.)
tab1, tab2 = st.tabs(["🔍 Ask a Question", "⭐ Bookmarked Trials"])

# === TAB 1: Conversational Chatbot ===
def _run_rag_turn(question, history):
    """Retrieve matching trial chunks and generate an answer for one question.

    Args:
        question: The user's question text.
        history: Earlier ``(question, answer)`` pairs; rendered into the
            prompt's "Conversation so far" section.

    Returns:
        A dict with ``query`` (the question), ``answer`` (the LLM output) and
        ``matches`` (a list of metadata dicts, one per Pinecone match).
    """
    with st.spinner("Retrieving relevant trials..."):
        # Embed the query and fetch the 5 most similar vectors with metadata.
        vec = embed_model.encode(question).tolist()
        results = index.query(vector=vec, top_k=5, include_metadata=True)
        # Keep plain metadata dicts only, so the turn can live in session state.
        matches = [dict(m["metadata"] or {}) for m in results["matches"]]

    # Join the retrieved chunk texts into one context block; chunks without
    # text are skipped instead of raising.
    chunk_texts = [meta.get("text") or "" for meta in matches]
    joined_context = "\n".join(text for text in chunk_texts if text)
    history_text = "\n".join(f"User: {q}\nBot: {a}" for q, a in history)
    prompt = prompt_template.format(
        context=joined_context, question=question, history=history_text
    )

    with st.spinner("Generating answer..."):
        # invoke() is the supported entry point; calling the LLM object
        # directly is deprecated in LangChain.
        answer = llm.invoke(prompt)

    return {"query": question, "answer": answer, "matches": matches}


with tab1:
    st.title("🔎 Clinical Trial Discovery Chatbot")

    st.markdown("""
    💡 **Example question formats:**
    - What clinical trials are available for non-small cell lung cancer in California?
    - List phase 3 trials for Type 1 Diabetes recruiting in 2025.
    - What studies on immunotherapy for melanoma are active in Europe?
    - Are there trials targeting heart disease patients over 65?
    """)

    # Replay the conversation so far.
    for q, a in st.session_state.chat_history:
        st.markdown(f"**User:** {q}")
        st.markdown(f"**Bot:** {a}")

    user_query = st.text_input("🔍 Enter your clinical trial question below:")

    if user_query:
        # Streamlit re-executes the script on every widget interaction (e.g. a
        # bookmark click) while the text input keeps its value. Only run
        # retrieval + generation when the question actually changed; otherwise
        # reuse the stored turn so the endpoint is not called again and the
        # chat history does not collect duplicate entries.
        turn = st.session_state.get("last_turn")
        if turn is None or turn["query"] != user_query:
            turn = _run_rag_turn(user_query, st.session_state.chat_history)
            st.session_state.chat_history.append((turn["query"], turn["answer"]))
            st.session_state["last_turn"] = turn

        st.subheader("🧠 Answer:")
        st.write(turn["answer"])

        st.markdown("---")
        st.subheader("📋 Related Clinical Trials")

        for i, meta in enumerate(turn["matches"]):
            # Fall back to a positional id when 'nct_id' is missing or empty,
            # since the id is used as a Firestore document id below.
            nct_id = meta.get("nct_id") or f"chunk_{i}"
            snippet = (meta.get("text") or "")[:400]  # first 400 characters of the chunk
            with st.expander(f"Trial: {nct_id}"):
                # Fetch full trial details from Firestore.
                trial_doc = db.collection("ClinicalTrials").document(nct_id).get()
                if trial_doc.exists:
                    for k, v in trial_doc.to_dict().items():
                        st.markdown(f"**{k.replace('_', ' ').title()}:** {v}")
                else:
                    st.warning("⚠️ Full trial details not found in Firestore. Showing partial match.")
                    st.write(snippet + "...")
                # Bookmark button inside each expander. Bookmarks are saved to
                # Users/{USER_ID}/Bookmarks/{nct_id}.
                if st.button(f"⭐ Bookmark {nct_id}", key=f"bookmark_{i}"):
                    db.collection("Users").document(USER_ID).collection("Bookmarks").document(nct_id).set({
                        "nct_id": nct_id,
                        "text": snippet
                    })
                    st.success(f"Bookmarked {nct_id} to Firestore.")

# === TAB 2: Bookmarked Trials ===
with tab2:
    st.title("⭐ Your Bookmarked Trials")
    # Retrieve this user's bookmarks from Firestore, keeping each document id
    # alongside its data.
    bookmark_docs = db.collection("Users").document(USER_ID).collection("Bookmarks").stream()
    bookmarks = [(doc.id, doc.to_dict() or {}) for doc in bookmark_docs]

    if not bookmarks:
        st.info("You haven't bookmarked any trials yet.")
    else:
        # One expander per bookmark. Bookmarks are written with the trial id as
        # the document id, so the id is a safe label when the 'nct_id' field is
        # missing; a missing 'text' field renders as empty instead of raising.
        for doc_id, bookmark in bookmarks:
            with st.expander(f"{bookmark.get('nct_id') or doc_id}"):
                st.write(bookmark.get("text", ""))