generate basic
Browse files
- app.py +186 -42
- utils/__pycache__/generator.cpython-311.pyc +0 -0
- utils/__pycache__/retriever.cpython-311.pyc +0 -0
- utils/generator.py +0 -287
app.py
CHANGED

@@ -1,15 +1,178 @@
 import streamlit as st
 from utils.retriever import retrieve_paragraphs
-
+import ast
+import time
+import asyncio
+import logging
+import logging
+logging.basicConfig(level=logging.INFO)
+import os
+import configparser

-
-
-
-
-
+
+def getconfig(configfile_path: str):
+    """
+    Read the config file
+    Params
+    ----------------
+    configfile_path: file path of .cfg file
+    """
+    config = configparser.ConfigParser()
+    try:
+        config.read_file(open(configfile_path))
+        return config
+    except:
+        logging.warning("config file not found")
+
+# ---------------------------------------------------------------------
+# Provider-agnostic authentication and configuration
+# ---------------------------------------------------------------------
+
+def get_auth(provider: str) -> dict:
+    """Get authentication configuration for different providers"""
+    auth_configs = {
+        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
+        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
+        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
+        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
+    }
+
+    if provider not in auth_configs:
+        raise ValueError(f"Unsupported provider: {provider}")
+
+    auth_config = auth_configs[provider]
+    api_key = auth_config.get("api_key")
+
+    if not api_key:
+        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
+
+    return auth_config
+
+# ---------------------------------------------------------------------
+# Model / client initialization (non exaustive list of providers)
+# ---------------------------------------------------------------------
+
+config = getconfig("model_params.cfg")
+
+PROVIDER = config.get("generator", "PROVIDER")
+MODEL = config.get("generator", "MODEL")
+MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
+TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
+INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
+ORGANIZATION = config.get("generator", "ORGANIZATION")
+
+# Set up authentication for the selected provider
+auth_config = get_auth(PROVIDER)
+
+
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+
+def build_messages(question: str, context: str) -> list:
+    """
+    Build messages in LangChain format.
+
+    Args:
+        question: The user's question
+        context: The relevant context for answering
+
+    Returns:
+        List of LangChain message objects
+    """
+    system_content = (
+        """
+        You are an expert assistant. Your task is to generate accurate, helpful responses using only the
+        information contained in the "CONTEXT" provided.
+        Instructions:
+        - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
+        - Language matching: Respond in the same language as the user's query.
+        - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
+        - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
+        - Stay focused: Answer only what is asked. Do not provide additional information not requested.
+        - Structure your response effectively:
+            * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
+            * Use bullet points and lists when it makes sense to improve readability.
+            * You do not need to use every passage. Only use the ones that help answer the question.
+        - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
+
+        Input Format:
+        - Query: {query}
+        - Retrieved Paragraphs: {retrieved_paragraphs}
+        Generate your response based on these guidelines.
+        """
     )
+
+    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+
+    return [
+        SystemMessage(content=system_content),
+        HumanMessage(content=user_content)
+    ]
+def get_chat_model():
+    """Initialize the appropriate LangChain chat model based on provider"""
+    common_params = {
+        "temperature": TEMPERATURE,
+        "max_tokens": MAX_TOKENS,
+    }
+
+    # if PROVIDER == "openai":
+    #     return ChatOpenAI(
+    #         model=MODEL,
+    #         openai_api_key=auth_config["api_key"],
+    #         **common_params
+    #     )
+    # elif PROVIDER == "anthropic":
+    #     return ChatAnthropic(
+    #         model=MODEL,
+    #         anthropic_api_key=auth_config["api_key"],
+    #         **common_params
+    #     )
+    # elif PROVIDER == "cohere":
+    #     return ChatCohere(
+    #         model=MODEL,
+    #         cohere_api_key=auth_config["api_key"],
+    #         **common_params
+    #     )
+    if PROVIDER == "huggingface":
+        # Initialize HuggingFaceEndpoint with explicit parameters
+        llm = HuggingFaceEndpoint(
+            repo_id=MODEL,
+            huggingfacehub_api_token=auth_config["api_key"],
+            task="text-generation",
+            provider=INFERENCE_PROVIDER,
+            server_kwargs={"bill_to": ORGANIZATION},
+            temperature=TEMPERATURE,
+            max_new_tokens=MAX_TOKENS
+        )
+        return ChatHuggingFace(llm=llm)
+    else:
+        raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+# Initialize provider-agnostic chat model
+chat_model = get_chat_model()
+
+async def _call_llm(messages: list) -> str:
+    """
+    Provider-agnostic LLM call using LangChain.
+
+    Args:
+        messages: List of LangChain message objects
+
+    Returns:
+        Generated response content as string
+    """
+    try:
+        # Use async invoke for better performance
+        response = await chat_model.ainvoke(messages)
+        logging.info(f"answer: {response.content}")
+        return response.content
+        #return response.content.strip()
+    except Exception as e:
+        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+        raise
+

-
+def chat_response(query):
     """Generate chat response based on method and inputs"""

     try:

@@ -19,49 +182,30 @@ async def chat_response(query):
         # Build list of only content, no metadata
         context_retrieved_formatted = "||".join(doc['answer'] for doc in context_retrieved)
         context_retrieved_lst = [doc['answer'] for doc in context_retrieved]
-
-
-
-
-
-
-
-        # Generate response
-        response = await generate(query=query, context=context_retrieved_lst)
-
-        # Add disclaimer to the response
-        response_with_disclaimer = BEGINNING_TEXT + response
-        # Log the interaction
-        # try:
-        #     chat_logger.log(
-        #         query=query,
-        #         answer=response,
-        #         retrieved_content=context_retrieved_lst,
-        #         request=request
-        #     )
-        # except Exception as e:
-        #     print(f"Logging error: {str(e)}")
-
+        logging.info("Context Retrieval done")
+
+        messages = build_messages(query, context_retrieved_lst)
+        answer = asyncio.run(_call_llm(messages))
+
+
+        return answer

-        # Stream response character by character
-        displayed_response = ""
-        for i, char in enumerate(response_with_disclaimer):
-            displayed_response += char
-
-            yield displayed_response
-            # Only add delay every few characters to avoid being too slow
-            if i % 3 == 0:
-                await asyncio.sleep(0.02)

     except Exception as e:
         error_message = f"Error processing request: {str(e)}"
-
+        return error_message
+
+col_title, col_about = st.columns([8, 2])
+with col_title:
+    st.markdown(
+        "<h1 style='text-align:center;'> Montreal AI Decisions (MVP)</h1>",
+        unsafe_allow_html=True
+    )

 # 10.1. Question input
 query = st.text_input(
     label="Enter your question:",
     key="query",
-    on_change=reset_page
 )

 # Only run search & display if user has entered something

@@ -69,4 +213,4 @@ if not query.strip():
     st.info("Please enter a question to see results.")
     st.stop()
 else:
-    st.
+    st.write(chat_response(query))
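Both the new app.py and the deleted utils/generator.py load their generator settings from a model_params.cfg file that is not part of this diff. A minimal sketch of the [generator] section those config.get(...) calls expect; the values below are illustrative placeholders, not the Space's actual settings:

[generator]
PROVIDER = huggingface
MODEL = HuggingFaceH4/zephyr-7b-beta
MAX_TOKENS = 768
TEMPERATURE = 0.1
INFERENCE_PROVIDER = hf-inference
ORGANIZATION = your-org-name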
utils/__pycache__/generator.cpython-311.pyc
CHANGED

Binary files a/utils/__pycache__/generator.cpython-311.pyc and b/utils/__pycache__/generator.cpython-311.pyc differ

utils/__pycache__/retriever.cpython-311.pyc
CHANGED

Binary files a/utils/__pycache__/retriever.cpython-311.pyc and b/utils/__pycache__/retriever.cpython-311.pyc differ
utils/generator.py
DELETED

@@ -1,287 +0,0 @@
-import logging
-import asyncio
-import json
-import ast
-from typing import List, Dict, Any, Union
-from dotenv import load_dotenv
-
-# LangChain imports
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
-from langchain_core.messages import SystemMessage, HumanMessage
-
-import os
-import configparser
-
-
-def getconfig(configfile_path: str):
-    """
-    Read the config file
-    Params
-    ----------------
-    configfile_path: file path of .cfg file
-    """
-    config = configparser.ConfigParser()
-    try:
-        config.read_file(open(configfile_path))
-        return config
-    except:
-        logging.warning("config file not found")
-
-# ---------------------------------------------------------------------
-# Provider-agnostic authentication and configuration
-# ---------------------------------------------------------------------
-
-def get_auth(provider: str) -> dict:
-    """Get authentication configuration for different providers"""
-    auth_configs = {
-        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
-        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
-        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
-        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
-    }
-
-    if provider not in auth_configs:
-        raise ValueError(f"Unsupported provider: {provider}")
-
-    auth_config = auth_configs[provider]
-    api_key = auth_config.get("api_key")
-
-    if not api_key:
-        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
-
-    return auth_config
-
-# ---------------------------------------------------------------------
-# Model / client initialization (non exaustive list of providers)
-# ---------------------------------------------------------------------
-
-config = getconfig("model_params.cfg")
-
-PROVIDER = config.get("generator", "PROVIDER")
-MODEL = config.get("generator", "MODEL")
-MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
-TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
-INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
-ORGANIZATION = config.get("generator", "ORGANIZATION")
-
-# Set up authentication for the selected provider
-auth_config = get_auth(PROVIDER)
-
-def get_chat_model():
-    """Initialize the appropriate LangChain chat model based on provider"""
-    common_params = {
-        "temperature": TEMPERATURE,
-        "max_tokens": MAX_TOKENS,
-    }
-
-    # if PROVIDER == "openai":
-    #     return ChatOpenAI(
-    #         model=MODEL,
-    #         openai_api_key=auth_config["api_key"],
-    #         **common_params
-    #     )
-    # elif PROVIDER == "anthropic":
-    #     return ChatAnthropic(
-    #         model=MODEL,
-    #         anthropic_api_key=auth_config["api_key"],
-    #         **common_params
-    #     )
-    # elif PROVIDER == "cohere":
-    #     return ChatCohere(
-    #         model=MODEL,
-    #         cohere_api_key=auth_config["api_key"],
-    #         **common_params
-    #     )
-    if PROVIDER == "huggingface":
-        # Initialize HuggingFaceEndpoint with explicit parameters
-        llm = HuggingFaceEndpoint(
-            repo_id=MODEL,
-            huggingfacehub_api_token=auth_config["api_key"],
-            task="text-generation",
-            provider=INFERENCE_PROVIDER,
-            server_kwargs={"bill_to": ORGANIZATION},
-            temperature=TEMPERATURE,
-            max_new_tokens=MAX_TOKENS
-        )
-        return ChatHuggingFace(llm=llm)
-    else:
-        raise ValueError(f"Unsupported provider: {PROVIDER}")
-
-# Initialize provider-agnostic chat model
-chat_model = get_chat_model()
-
-# ---------------------------------------------------------------------
-# Context processing - may need further refinement (i.e. to manage other data sources)
-# ---------------------------------------------------------------------
-# def extract_relevant_fields(retrieval_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-#     """
-#     Extract only relevant fields from retrieval results.
-
-#     Args:
-#         retrieval_results: List of JSON objects from retriever
-
-#     Returns:
-#         List of processed objects with only relevant fields
-#     """
-
-#     retrieval_results = ast.literal_eval(retrieval_results)
-
-#     processed_results = []
-
-#     for result in retrieval_results:
-#         # Extract the answer content
-#         answer = result.get('answer', '')
-
-#         # Extract document identification from metadata
-#         metadata = result.get('answer_metadata', {})
-#         doc_info = {
-#             'answer': answer,
-#             'filename': metadata.get('filename', 'Unknown'),
-#             'page': metadata.get('page', 'Unknown'),
-#             'year': metadata.get('year', 'Unknown'),
-#             'source': metadata.get('source', 'Unknown'),
-#             'document_id': metadata.get('_id', 'Unknown')
-#         }
-
-#         processed_results.append(doc_info)
-
-#     return processed_results
-
-# def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
-#     """
-#     Format processed retrieval results into a context string for the LLM.
-
-#     Args:
-#         processed_results: List of processed objects with relevant fields
-
-#     Returns:
-#         Formatted context string
-#     """
-#     if not processed_results:
-#         return ""
-
-#     context_parts = []
-
-#     for i, result in enumerate(processed_results, 1):
-#         doc_reference = f"[Document {i}: {result['filename']}"
-#         if result['page'] != 'Unknown':
-#             doc_reference += f", Page {result['page']}"
-#         if result['year'] != 'Unknown':
-#             doc_reference += f", Year {result['year']}"
-#         doc_reference += "]"
-
-#         context_part = f"{doc_reference}\n{result['answer']}\n"
-#         context_parts.append(context_part)
-
-#     return "\n".join(context_parts)
-
-# ---------------------------------------------------------------------
-# Core generation function for both Gradio UI and MCP
-# ---------------------------------------------------------------------
-async def _call_llm(messages: list) -> str:
-    """
-    Provider-agnostic LLM call using LangChain.
-
-    Args:
-        messages: List of LangChain message objects
-
-    Returns:
-        Generated response content as string
-    """
-    try:
-        # Use async invoke for better performance
-        response = await chat_model.ainvoke(messages)
-        print(response)
-        return response.content
-        #return response.content.strip()
-    except Exception as e:
-        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
-        raise
-
-def build_messages(question: str, context: str) -> list:
-    """
-    Build messages in LangChain format.
-
-    Args:
-        question: The user's question
-        context: The relevant context for answering
-
-    Returns:
-        List of LangChain message objects
-    """
-    system_content = (
-        """
-        You are an expert assistant. Your task is to generate accurate, helpful responses using only the
-        information contained in the "CONTEXT" provided.
-        Instructions:
-        - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
-        - Language matching: Respond in the same language as the user's query.
-        - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
-        - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
-        - Stay focused: Answer only what is asked. Do not provide additional information not requested.
-        - Structure your response effectively:
-            * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
-            * Use bullet points and lists when it makes sense to improve readability.
-            * You do not need to use every passage. Only use the ones that help answer the question.
-        - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
-
-        Input Format:
-        - Query: {query}
-        - Retrieved Paragraphs: {retrieved_paragraphs}
-        Generate your response based on these guidelines.
-        """
-    )
-
-    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
-
-    return [
-        SystemMessage(content=system_content),
-        HumanMessage(content=user_content)
-    ]
-
-
-async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
-    """
-    Generate an answer to a query using provided context through RAG.
-
-    This function takes a user query and relevant context, then uses a language model
-    to generate a comprehensive answer based on the provided information.
-
-    Args:
-        query (str): User query
-        context (list): List of retrieval result objects (dictionaries)
-    Returns:
-        str: The generated answer based on the query and context
-    """
-    if not query.strip():
-        return "Error: Query cannot be empty"
-
-    # Handle both string context (for Gradio UI) and list context (from retriever)
-    if isinstance(context, list):
-        if not context:
-            return "Error: No retrieval results provided"
-
-        # # Process the retrieval results
-        # processed_results = extract_relevant_fields(context)
-        formatted_context = context
-
-        # if not formatted_context.strip():
-        #     return "Error: No valid content found in retrieval results"
-
-    elif isinstance(context, str):
-        if not context.strip():
-            return "Error: Context cannot be empty"
-        formatted_context = context
-
-    else:
-        return "Error: Context must be either a string or list of retrieval results"
-
-    try:
-        messages = build_messages(query, formatted_context)
-        answer = await _call_llm(messages)
-
-        return answer
-
-    except Exception as e:
-        logging.exception("Generation failed")
-        return f"Error: {str(e)}"
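For reference, utils.retriever.retrieve_paragraphs remains the single retrieval entry point, and both the deleted generate() and the new in-app chat_response() read its output as a list of dicts keyed by 'answer' (with an 'answer_metadata' dict, per the commented-out extract_relevant_fields helper). A minimal sketch of that assumed shape and of the new synchronous call path; the retriever signature and the sample values are assumptions for illustration, not taken from the repo:

# Assumed shape of one item returned by retrieve_paragraphs (illustrative values only)
example_result = {
    "answer": "Paragraph text returned by the retriever ...",
    "answer_metadata": {"filename": "decision_2024_017.pdf", "page": 3, "year": 2024},
}

# New call path in app.py: retrieve, then run the async LLM call synchronously for Streamlit
# (assumes retrieve_paragraphs accepts the raw query string)
context_retrieved = retrieve_paragraphs(query)
context_retrieved_lst = [doc["answer"] for doc in context_retrieved]
messages = build_messages(query, context_retrieved_lst)
answer = asyncio.run(_call_llm(messages))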