Spaces: Running on Zero
| import gradio as gr | |
| import polars as pl | |
| from search import search | |
| from table import df_orig | |
# Columns exposed through the UI and the tools below, in display order.
COLUMNS_MCP = [
    "title",
    "authors",
    "abstract",
    "arxiv_id",
    "paper_page",
    "space_ids",
    "model_ids",
    "dataset_ids",
    "upvotes",
    "num_comments",
    "project_page",
    "github",
    "row_index",
]

# Columns pre-selected in the "Columns" checkbox group. Derived by filtering
# COLUMNS_MCP so the defaults always follow the same relative order.
DEFAULT_COLUMNS_MCP = [
    name
    for name in COLUMNS_MCP
    if name in {"title", "authors", "abstract", "arxiv_id", "project_page", "github", "row_index"}
]

# Paper table served by the tools: the upstream `paper_id` column is exposed
# under the name `row_index`, and only the COLUMNS_MCP columns are kept.
df_mcp = (
    df_orig
    .rename({"paper_id": "row_index"})
    .select(COLUMNS_MCP)
)
def search_papers(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    columns: list[str],
) -> list[dict]:
    """Searches NeurIPS 2025 papers relevant to a user query in English.

    This function performs a semantic search over NeurIPS 2025 papers.
    It uses a dual-stage retrieval process:
    - First, it retrieves `candidate_pool_size` papers using dense vector similarity.
    - Then, it re-ranks them with a cross-encoder model to select the top `num_results` most relevant papers.
    - The search results are returned as a list of dictionaries.

    Note:
        The search query must be written in English. Queries in other languages are not supported.

    Args:
        search_query (str): The natural language query input by the user. Must be in English.
        candidate_pool_size (int): Number of candidate papers to retrieve using the dense vector model.
        num_results (int): Final number of top-ranked papers to return after re-ranking.
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of the top-ranked papers matching the query, sorted by relevance. Empty if the search yields no hits.

    Raises:
        ValueError: If the query is empty or whitespace-only, if either count is smaller than 1, or if `num_results` exceeds `candidate_pool_size`.
    """
    # A whitespace-only query carries nothing to search for; reject it like "".
    if not search_query or not search_query.strip():
        raise ValueError("Search query cannot be empty")
    # The sliders enforce minimum=1 in the UI, but MCP/API callers bypass them.
    if candidate_pool_size < 1 or num_results < 1:
        raise ValueError("Candidate pool size and number of results must be at least 1")
    if num_results > candidate_pool_size:
        raise ValueError("Number of results must be less than or equal to candidate pool size")

    hits = pl.DataFrame(search(search_query, candidate_pool_size, num_results))
    # With no hits the frame lacks the `paper_id` / `ce_score` columns, so the
    # rename and sort below cannot run; return an empty result instead.
    if hits.is_empty():
        return []

    # `join` returns a new frame, so df_mcp is used directly (no defensive clone).
    ranked = (
        hits.rename({"paper_id": "row_index"})
        .join(df_mcp, on="row_index", how="inner")
        .sort("ce_score", descending=True)
    )
    return ranked.select(columns).to_dicts()
def get_metadata(row_index: int) -> dict:
    """Returns a dictionary of metadata for a NeurIPS 2025 paper at the given table row index.

    Args:
        row_index (int): The index of the paper in the internal paper list table.

    Returns:
        dict: A dictionary containing metadata for the corresponding paper.

    Raises:
        IndexError: If no paper has the given row index.
    """
    matches = df_mcp.filter(pl.col("row_index") == row_index)
    if matches.is_empty():
        # Keep IndexError (what `[0]` raised before) but say which index was
        # missing instead of a bare "list index out of range".
        raise IndexError(f"No paper found with row_index={row_index}")
    return matches.to_dicts()[0]
def get_table(columns: list[str]) -> list[dict]:
    """Returns a list of dictionaries of all NeurIPS 2025 papers.

    Args:
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of all NeurIPS 2025 papers.
    """
    # Project the whole paper table onto the requested columns, then emit one
    # dict per row.
    projected = df_mcp.select(columns)
    return projected.to_dicts()
with gr.Blocks() as demo:
    search_query = gr.Textbox(label="Search", submit_btn=True)
    candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=500, step=1, value=200)
    num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
    column_names = gr.CheckboxGroup(label="Columns", choices=COLUMNS_MCP, value=DEFAULT_COLUMNS_MCP)
    row_index = gr.Slider(label="Row Index", minimum=0, maximum=len(df_mcp) - 1, step=1, value=0)
    out = gr.JSON()
    search_papers_btn = gr.Button("Search Papers")
    get_metadata_btn = gr.Button("Get Metadata")
    get_table_btn = gr.Button("Get Table")

    # The textbox renders its own submit button (submit_btn=True), so that
    # button / Enter must run the search too. Attaching both triggers to a
    # single gr.on event keeps the search a single event (one endpoint)
    # rather than registering a second, duplicate listener.
    gr.on(
        triggers=[search_papers_btn.click, search_query.submit],
        fn=search_papers,
        inputs=[search_query, candidate_pool_size, num_results, column_names],
        outputs=out,
    )
    get_metadata_btn.click(
        fn=get_metadata,
        inputs=row_index,
        outputs=out,
    )
    get_table_btn.click(
        fn=get_table,
        inputs=column_names,
        outputs=out,
    )

if __name__ == "__main__":
    # Launch the app with Gradio's MCP server enabled.
    demo.launch(mcp_server=True)