# mtyrrell's picture
# cleanup; readme
# 454dc04
import gradio as gr
from .retriever import get_context, get_vectorstore
# Open the vector store connection once, at import time, so both the UI and
# MCP handlers below share the same module-level `vectorstore` object.
print("Initializing vector store connection...")
try:
    vectorstore = get_vectorstore()
    print("Vector store connection initialized successfully")
except Exception as init_error:
    # Report the failure, then re-raise: the module is unusable without a
    # vector store, so startup is aborted rather than continuing.
    print(f"Failed to initialize vector store: {init_error}")
    raise
# ---------------------------------------------------------------------
# MCP - returns raw dictionary format
# ---------------------------------------------------------------------
def retrieve_mcp(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> list:
    """
    Retrieve semantically similar documents from the vector database for MCP clients.

    Empty or whitespace-only filter values are treated as "no filter".

    Args:
        query (str): The search query text
        reports_filter (str): Comma-separated list of specific report filenames (optional)
        sources_filter (str): Filter by document source type (optional)
        subtype_filter (str): Filter by document subtype (optional)
        year_filter (str): Comma-separated list of years to filter by (optional)

    Returns:
        list: The raw results returned by get_context (list of dictionaries containing document content, metadata, and scores)
    """
    # Comma-separated filters: split, trim each entry, and drop blank entries.
    # `or ""` keeps a None argument from raising on .split().
    reports = [r.strip() for r in (reports_filter or "").split(",") if r.strip()]
    years = [y.strip() for y in (year_filter or "").split(",") if y.strip()]

    # Single-value filters: a blank or whitespace-only value collapses to None
    # instead of being forwarded as an empty-string filter.
    sources = (sources_filter or "").strip() or None
    subtype = (subtype_filter or "").strip() or None

    # Call retriever function and return raw results.
    # `reports` is passed as a (possibly empty) list while `year` is passed as
    # None when nothing usable was supplied, matching the original call
    # contract for empty input.
    return get_context(
        vectorstore=vectorstore,
        query=query,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=years or None,
    )
# ---------------------------------------------------------------------
# UI - returns formatted string
# ---------------------------------------------------------------------
def retrieve_ui(query, reports_filter="", sources_filter="", subtype_filter="", year_filter=""):
    """
    Wrapper function for the gradio interface: run a retrieval and format the
    results as a single display string.

    Args:
        query: The search query text
        reports_filter: Comma-separated list of report filenames (optional)
        sources_filter: Document source type to filter by (optional)
        subtype_filter: Document subtype to filter by (optional)
        year_filter: Comma-separated list of years to filter by (optional)

    Returns:
        str: One "=== Result N (Score: ...) ===" section per retrieved
        document, joined by newlines; an empty string when nothing is returned.
    """
    # Delegate to the MCP entry point so filter parsing and the retriever call
    # live in exactly one place and the two endpoints cannot drift apart.
    results = retrieve_mcp(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
    )

    # Format results for display
    formatted_results = []
    for i, doc in enumerate(results, 1):
        # Result keys as used by the retriever output ('answer',
        # 'answer_metadata', 'score'); missing keys fall back to defaults.
        content = doc.get('answer', '')
        # `or {}` also covers a key that is present but set to None, which
        # would otherwise raise AttributeError on .items().
        metadata = doc.get('answer_metadata') or {}
        score = doc.get('score', 'N/A')
        metadata_str = ", ".join(f"{k}: {v}" for k, v in metadata.items())
        formatted_results.append(
            f"=== Result {i} (Score: {score}) ===\n"
            f"Content: {content}\n"
            f"Metadata: {metadata_str}\n"
        )
    return "\n".join(formatted_results)
# The four optional filter boxes differ only in label, placeholder and help
# text, so they are described here as (label, placeholder, info) rows and
# instantiated in this order inside the layout below.
_FILTER_FIELDS = (
    (
        "Reports Filter (optional)",
        "report1.pdf, report2.pdf",
        "Comma-separated list of specific report filenames to search within (leave empty for all)",
    ),
    (
        "Sources Filter (optional)",
        "annual_report",
        "Filter by document source type (leave empty for all)",
    ),
    (
        "Subtype Filter (optional)",
        "financial",
        "Filter by document subtype (leave empty for all)",
    ),
    (
        "Year Filter (optional)",
        "2023, 2024",
        "Comma-separated list of years to filter by (leave empty for all)",
    ),
)

# Build the Gradio app with Blocks so one object serves both the UI and MCP
with gr.Blocks() as ui:
    gr.Markdown("# ChatFed Retrieval/Reranker Module")
    gr.Markdown("Retrieves semantically similar documents from vector database and reranks. Intended for use in RAG pipelines as an MCP server with other ChatFed modules.")

    with gr.Row():
        # Left column: query, optional filters, submit button
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database"
            )
            # Components are created in table order, which is also the order
            # they appear in the column.
            reports_input, sources_input, subtype_input, year_input = [
                gr.Textbox(
                    label=field_label,
                    lines=1,
                    placeholder=field_placeholder,
                    info=field_info
                )
                for field_label, field_placeholder, field_info in _FILTER_FIELDS
            ]
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column: formatted retrieval output
        with gr.Column():
            output = gr.Textbox(
                label="Retrieved Context",
                lines=10,
                show_copy_button=True
            )

    # UI event handler: the button sends all five fields to retrieve_ui
    filter_inputs = [reports_input, sources_input, subtype_input, year_input]
    submit_btn.click(
        fn=retrieve_ui,
        inputs=[query_input, *filter_inputs],
        outputs=output
    )

    # MCP endpoint: expose the raw-result function as an API route
    gr.api(retrieve_mcp)
# Start the app with the MCP server enabled when run as a script
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "mcp_server": True,
        "show_error": True,
    }
    ui.launch(**launch_options)