Commit 22bbe4e
1 Parent(s): c43e786

reduce latency to 3s

app.py CHANGED
@@ -19,7 +19,7 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.schema.runnable.passthrough import RunnableAssign
 from langchain_huggingface import HuggingFaceEmbeddings
-from
+from langchain_community.vectorstores import FAISS
 from langchain_community.retrievers import BM25Retriever
 from langchain_openai import ChatOpenAI
 from langchain.output_parsers import PydanticOutputParser
@@ -105,7 +105,7 @@ knowledge_base = KnowledgeBase()
 # repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
 repharser_llm = ChatNVIDIA(model="microsoft/phi-3-mini-4k-instruct") | StrOutputParser()
 instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
-relevance_llm = ChatNVIDIA(model="
+relevance_llm = ChatNVIDIA(model="nvidia/llama-3.1-nemotron-70b-instruct") | StrOutputParser()
 answer_llm = ChatOpenAI(
     model="gpt-4o",
     temperature=0.3,
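Note on the definitions in this hunk: piping a chat model into StrOutputParser is LangChain Expression Language (LCEL) composition, so each *_llm above is a runnable that returns a plain string. A minimal sketch of invoking one of them, assuming langchain_nvidia_ai_endpoints is installed and NVIDIA_API_KEY is set (the import itself is outside this hunk):

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.output_parsers import StrOutputParser

# ChatNVIDIA returns an AIMessage; StrOutputParser unwraps it to a plain str.
relevance_llm = ChatNVIDIA(model="nvidia/llama-3.1-nemotron-70b-instruct") | StrOutputParser()

text = relevance_llm.invoke("Summarize hybrid retrieval in one sentence.")
print(type(text).__name__)  # str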
@@ -116,16 +116,17 @@ answer_llm = ChatOpenAI(
 
 # Prompts
 repharser_prompt = ChatPromptTemplate.from_template(
-    "You are a smart retrieval assistant
-    "
-    "
-    "
-    "-
-    "-
+    "You are a smart retrieval assistant helping a search engine understand user intent more precisely.\n\n"
+    "Given a user question, generate **1 diverse rewrite** that is semantically equivalent but phrased differently. \n"
+    "The rewrite should be optimized for **retrieval from a hybrid system** using BM25 (keyword match) and dense vector embeddings.\n\n"
+    "Guidelines:\n"
+    "- Expand abbreviations or implied intent when useful\n"
+    "- Add relevant technical terms, tools, frameworks, or synonyms (e.g., 'LLM', 'pipeline', 'project')\n"
+    "- Rephrase using different sentence structure or tone\n"
+    "- Use field-specific vocabulary (e.g., data science, ML, software, research) if it fits the query\n"
+    "- Prioritize clarity and retrievability over stylistic variation\n\n"
     "Original Question:\n{query}\n\n"
-    "
-    "1.\n"
-    "2."
+    "Rewrite:\n1."
 )
 
 relevance_prompt = ChatPromptTemplate.from_template("""
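The rewritten prompt above now asks for a single rewrite instead of a numbered list. The wiring that consumes it is not part of this hunk, so the following is only a sketch of how repharser_prompt, repharser_llm, and parse_rewrites plausibly fit together, reusing names defined elsewhere in app.py:

# Hypothetical composition; the actual chain wiring is defined elsewhere in app.py.
rewrite_chain = repharser_prompt | repharser_llm  # repharser_llm already ends in StrOutputParser

raw = rewrite_chain.invoke({"query": "best LLM for low-latency RAG?"})
# Given the trailing "Rewrite:\n1." in the prompt, raw is expected to be a single rephrasing, e.g.
# "Which large language model minimizes latency in a retrieval-augmented generation pipeline?"
all_queries = ["best LLM for low-latency RAG?"] + parse_rewrites(raw)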
@@ -240,7 +241,7 @@ parser_prompt = ChatPromptTemplate.from_template(
 # Helper Functions
 def parse_rewrites(raw_response: str) -> list[str]:
     lines = raw_response.strip().split("\n")
-    return [line.strip("0123456789. ").strip() for line in lines if line.strip()][:
+    return [line.strip("0123456789. ").strip() for line in lines if line.strip()][:1]
 
 def hybrid_retrieve(inputs, exclude_terms=None):
     bm25_retriever = inputs["bm25_retriever"]
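hybrid_retrieve itself is untouched by this commit and its body is not shown here. Purely as an illustration of the alpha-weighted blend its inputs (bm25_retriever, vectorstore, alpha, k_per_query) suggest, here is a self-contained sketch; none of the helper names below come from the repository:

# Illustrative only: one common way to fuse keyword (BM25) and dense (vector) results.
# This is NOT the repository's hybrid_retrieve.
def blend_results(bm25_docs, vector_docs, alpha=0.5, k=8):
    """Weighted reciprocal-rank fusion of two ranked Document lists."""
    scores = {}
    for rank, doc in enumerate(bm25_docs):
        scores[doc.page_content] = scores.get(doc.page_content, 0.0) + (1 - alpha) / (rank + 1)
    for rank, doc in enumerate(vector_docs):
        scores[doc.page_content] = scores.get(doc.page_content, 0.0) + alpha / (rank + 1)
    return sorted(scores, key=scores.get, reverse=True)[:k]

# Typical call, mirroring the inputs assembled in chat_interface:
# bm25_docs   = bm25_retriever.invoke(query)                          # keyword matches
# vector_docs = vectorstore.similarity_search(query, k=k_per_query)   # dense matches
# top_chunks  = blend_results(bm25_docs, vector_docs, alpha=0.5, k=8)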
@@ -393,7 +394,7 @@ select_and_prompt = RunnableLambda(lambda x:
 answer_chain = (
     prepare_answer_inputs
     | select_and_prompt
-    |
+    | relevance_llm
 )
 
 def RExtract(pydantic_class: Type[BaseModel], llm, prompt):
@@ -463,19 +464,17 @@ def chat_interface(message, history):
         "query": message,
         "all_queries": [message],
         "all_texts": all_chunks,
-        "k_per_query":
+        "k_per_query": 8,
         "alpha": 0.5,
         "vectorstore": vectorstore,
         "bm25_retriever": bm25_retriever,
     }
     full_response = ""
-    collected = None
 
     # Stream the response to user
     for chunk in full_pipeline.stream(inputs):
         if isinstance(chunk, dict) and "answer" in chunk:
             full_response += chunk["answer"]
-            collected = chunk
             yield full_response
         elif isinstance(chunk, str):
             full_response += chunk
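chat_interface is a generator that keeps yielding the growing full_response, which is the contract Gradio uses for streaming chat output. The launch code is outside this diff, so the hookup below is an assumption, not the app's actual code:

# Assumed wiring (not shown in this diff): Gradio re-renders the chat bubble on every yield.
import gradio as gr

demo = gr.ChatInterface(fn=chat_interface)  # a generator fn enables incremental streaming
demo.launch()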