ppsingh commited on
Commit
31cb392
·
1 Parent(s): 7f5ac0d

generate basic

Browse files
app.py CHANGED
@@ -1,15 +1,178 @@
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
3
- from utils.generator import generate
 
 
 
 
 
 
 
4
 
5
- col_title, col_about = st.columns([8, 2])
6
- with col_title:
7
- st.markdown(
8
- "<h1 style='text-align:center;'> Montreal AI Decisions (MVP)</h1>",
9
- unsafe_allow_html=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- async def chat_response(query):
13
  """Generate chat response based on method and inputs"""
14
 
15
  try:
@@ -19,49 +182,30 @@ async def chat_response(query):
19
  # Build list of only content, no metadata
20
  context_retrieved_formatted = "||".join(doc['answer'] for doc in context_retrieved)
21
  context_retrieved_lst = [doc['answer'] for doc in context_retrieved]
22
-
23
- # # Prepare HTML for displaying source documents
24
- # docs_html = []
25
- # for i, d in enumerate(context_retrieved, 1):
26
- # docs_html.append(make_html_source(d, i))
27
- # docs_html = "".join(docs_html)
28
-
29
- # Generate response
30
- response = await generate(query=query, context=context_retrieved_lst)
31
-
32
- # Add disclaimer to the response
33
- response_with_disclaimer = BEGINNING_TEXT + response
34
- # Log the interaction
35
- # try:
36
- # chat_logger.log(
37
- # query=query,
38
- # answer=response,
39
- # retrieved_content=context_retrieved_lst,
40
- # request=request
41
- # )
42
- # except Exception as e:
43
- # print(f"Logging error: {str(e)}")
44
-
45
 
46
- # Stream response character by character
47
- displayed_response = ""
48
- for i, char in enumerate(response_with_disclaimer):
49
- displayed_response += char
50
-
51
- yield displayed_response
52
- # Only add delay every few characters to avoid being too slow
53
- if i % 3 == 0:
54
- await asyncio.sleep(0.02)
55
 
56
  except Exception as e:
57
  error_message = f"Error processing request: {str(e)}"
58
- yield error_message
 
 
 
 
 
 
 
59
 
60
  # 10.1. Question input
61
  query = st.text_input(
62
  label="Enter your question:",
63
  key="query",
64
- on_change=reset_page
65
  )
66
 
67
  # Only run search & display if user has entered something
@@ -69,4 +213,4 @@ if not query.strip():
69
  st.info("Please enter a question to see results.")
70
  st.stop()
71
  else:
72
- st.write_stream(chat_response(query))
 
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
3
+ import ast
4
+ import time
5
+ import asyncio
6
+ import logging
7
+ import logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ import os
10
+ import configparser
11
 
12
+
13
+ def getconfig(configfile_path: str):
14
+ """
15
+ Read the config file
16
+ Params
17
+ ----------------
18
+ configfile_path: file path of .cfg file
19
+ """
20
+ config = configparser.ConfigParser()
21
+ try:
22
+ config.read_file(open(configfile_path))
23
+ return config
24
+ except:
25
+ logging.warning("config file not found")
26
+
27
+ # ---------------------------------------------------------------------
28
+ # Provider-agnostic authentication and configuration
29
+ # ---------------------------------------------------------------------
30
+
31
+ def get_auth(provider: str) -> dict:
32
+ """Get authentication configuration for different providers"""
33
+ auth_configs = {
34
+ "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
35
+ "huggingface": {"api_key": os.getenv("HF_TOKEN")},
36
+ "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
37
+ "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
38
+ }
39
+
40
+ if provider not in auth_configs:
41
+ raise ValueError(f"Unsupported provider: {provider}")
42
+
43
+ auth_config = auth_configs[provider]
44
+ api_key = auth_config.get("api_key")
45
+
46
+ if not api_key:
47
+ raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
48
+
49
+ return auth_config
50
+
51
+ # ---------------------------------------------------------------------
52
+ # Model / client initialization (non exaustive list of providers)
53
+ # ---------------------------------------------------------------------
54
+
55
+ config = getconfig("model_params.cfg")
56
+
57
+ PROVIDER = config.get("generator", "PROVIDER")
58
+ MODEL = config.get("generator", "MODEL")
59
+ MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
60
+ TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
61
+ INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
62
+ ORGANIZATION = config.get("generator", "ORGANIZATION")
63
+
64
+ # Set up authentication for the selected provider
65
+ auth_config = get_auth(PROVIDER)
66
+
67
+
68
+ from langchain_core.messages import SystemMessage, HumanMessage
69
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
70
+
71
+ def build_messages(question: str, context: str) -> list:
72
+ """
73
+ Build messages in LangChain format.
74
+
75
+ Args:
76
+ question: The user's question
77
+ context: The relevant context for answering
78
+
79
+ Returns:
80
+ List of LangChain message objects
81
+ """
82
+ system_content = (
83
+ """
84
+ You are an expert assistant. Your task is to generate accurate, helpful responses using only the
85
+ information contained in the "CONTEXT" provided.
86
+ Instructions:
87
+ - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
88
+ - Language matching: Respond in the same language as the user's query.
89
+ - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
90
+ - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
91
+ - Stay focused: Answer only what is asked. Do not provide additional information not requested.
92
+ - Structure your response effectively:
93
+ * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
94
+ * Use bullet points and lists when it makes sense to improve readability.
95
+ * You do not need to use every passage. Only use the ones that help answer the question.
96
+ - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
97
+
98
+ Input Format:
99
+ - Query: {query}
100
+ - Retrieved Paragraphs: {retrieved_paragraphs}
101
+ Generate your response based on these guidelines.
102
+ """
103
  )
104
+
105
+ user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
106
+
107
+ return [
108
+ SystemMessage(content=system_content),
109
+ HumanMessage(content=user_content)
110
+ ]
111
+ def get_chat_model():
112
+ """Initialize the appropriate LangChain chat model based on provider"""
113
+ common_params = {
114
+ "temperature": TEMPERATURE,
115
+ "max_tokens": MAX_TOKENS,
116
+ }
117
+
118
+ # if PROVIDER == "openai":
119
+ # return ChatOpenAI(
120
+ # model=MODEL,
121
+ # openai_api_key=auth_config["api_key"],
122
+ # **common_params
123
+ # )
124
+ # elif PROVIDER == "anthropic":
125
+ # return ChatAnthropic(
126
+ # model=MODEL,
127
+ # anthropic_api_key=auth_config["api_key"],
128
+ # **common_params
129
+ # )
130
+ # elif PROVIDER == "cohere":
131
+ # return ChatCohere(
132
+ # model=MODEL,
133
+ # cohere_api_key=auth_config["api_key"],
134
+ # **common_params
135
+ # )
136
+ if PROVIDER == "huggingface":
137
+ # Initialize HuggingFaceEndpoint with explicit parameters
138
+ llm = HuggingFaceEndpoint(
139
+ repo_id=MODEL,
140
+ huggingfacehub_api_token=auth_config["api_key"],
141
+ task="text-generation",
142
+ provider=INFERENCE_PROVIDER,
143
+ server_kwargs={"bill_to": ORGANIZATION},
144
+ temperature=TEMPERATURE,
145
+ max_new_tokens=MAX_TOKENS
146
+ )
147
+ return ChatHuggingFace(llm=llm)
148
+ else:
149
+ raise ValueError(f"Unsupported provider: {PROVIDER}")
150
+
151
+ # Initialize provider-agnostic chat model
152
+ chat_model = get_chat_model()
153
+
154
+ async def _call_llm(messages: list) -> str:
155
+ """
156
+ Provider-agnostic LLM call using LangChain.
157
+
158
+ Args:
159
+ messages: List of LangChain message objects
160
+
161
+ Returns:
162
+ Generated response content as string
163
+ """
164
+ try:
165
+ # Use async invoke for better performance
166
+ response = await chat_model.ainvoke(messages)
167
+ logging.info(f"answer: {response.content}")
168
+ return response.content
169
+ #return response.content.strip()
170
+ except Exception as e:
171
+ logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
172
+ raise
173
+
174
 
175
+ def chat_response(query):
176
  """Generate chat response based on method and inputs"""
177
 
178
  try:
 
182
  # Build list of only content, no metadata
183
  context_retrieved_formatted = "||".join(doc['answer'] for doc in context_retrieved)
184
  context_retrieved_lst = [doc['answer'] for doc in context_retrieved]
185
+ logging.info("Context Retrieval done")
186
+
187
+ messages = build_messages(query, context_retrieved_lst)
188
+ answer = asyncio.run(_call_llm(messages))
189
+
190
+
191
+ return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
 
 
 
 
 
 
 
 
 
193
 
194
  except Exception as e:
195
  error_message = f"Error processing request: {str(e)}"
196
+ return error_message
197
+
198
+ col_title, col_about = st.columns([8, 2])
199
+ with col_title:
200
+ st.markdown(
201
+ "<h1 style='text-align:center;'> Montreal AI Decisions (MVP)</h1>",
202
+ unsafe_allow_html=True
203
+ )
204
 
205
  # 10.1. Question input
206
  query = st.text_input(
207
  label="Enter your question:",
208
  key="query",
 
209
  )
210
 
211
  # Only run search & display if user has entered something
 
213
  st.info("Please enter a question to see results.")
214
  st.stop()
215
  else:
216
+ st.write(chat_response(query))
utils/__pycache__/generator.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/generator.cpython-311.pyc and b/utils/__pycache__/generator.cpython-311.pyc differ
 
utils/__pycache__/retriever.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/retriever.cpython-311.pyc and b/utils/__pycache__/retriever.cpython-311.pyc differ
 
utils/generator.py DELETED
@@ -1,287 +0,0 @@
1
- import logging
2
- import asyncio
3
- import json
4
- import ast
5
- from typing import List, Dict, Any, Union
6
- from dotenv import load_dotenv
7
-
8
- # LangChain imports
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
10
- from langchain_core.messages import SystemMessage, HumanMessage
11
-
12
- import os
13
- import configparser
14
-
15
-
16
- def getconfig(configfile_path: str):
17
- """
18
- Read the config file
19
- Params
20
- ----------------
21
- configfile_path: file path of .cfg file
22
- """
23
- config = configparser.ConfigParser()
24
- try:
25
- config.read_file(open(configfile_path))
26
- return config
27
- except:
28
- logging.warning("config file not found")
29
-
30
- # ---------------------------------------------------------------------
31
- # Provider-agnostic authentication and configuration
32
- # ---------------------------------------------------------------------
33
-
34
- def get_auth(provider: str) -> dict:
35
- """Get authentication configuration for different providers"""
36
- auth_configs = {
37
- "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
38
- "huggingface": {"api_key": os.getenv("HF_TOKEN")},
39
- "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
40
- "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
41
- }
42
-
43
- if provider not in auth_configs:
44
- raise ValueError(f"Unsupported provider: {provider}")
45
-
46
- auth_config = auth_configs[provider]
47
- api_key = auth_config.get("api_key")
48
-
49
- if not api_key:
50
- raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
51
-
52
- return auth_config
53
-
54
- # ---------------------------------------------------------------------
55
- # Model / client initialization (non exaustive list of providers)
56
- # ---------------------------------------------------------------------
57
-
58
- config = getconfig("model_params.cfg")
59
-
60
- PROVIDER = config.get("generator", "PROVIDER")
61
- MODEL = config.get("generator", "MODEL")
62
- MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
63
- TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
64
- INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
65
- ORGANIZATION = config.get("generator", "ORGANIZATION")
66
-
67
- # Set up authentication for the selected provider
68
- auth_config = get_auth(PROVIDER)
69
-
70
- def get_chat_model():
71
- """Initialize the appropriate LangChain chat model based on provider"""
72
- common_params = {
73
- "temperature": TEMPERATURE,
74
- "max_tokens": MAX_TOKENS,
75
- }
76
-
77
- # if PROVIDER == "openai":
78
- # return ChatOpenAI(
79
- # model=MODEL,
80
- # openai_api_key=auth_config["api_key"],
81
- # **common_params
82
- # )
83
- # elif PROVIDER == "anthropic":
84
- # return ChatAnthropic(
85
- # model=MODEL,
86
- # anthropic_api_key=auth_config["api_key"],
87
- # **common_params
88
- # )
89
- # elif PROVIDER == "cohere":
90
- # return ChatCohere(
91
- # model=MODEL,
92
- # cohere_api_key=auth_config["api_key"],
93
- # **common_params
94
- # )
95
- if PROVIDER == "huggingface":
96
- # Initialize HuggingFaceEndpoint with explicit parameters
97
- llm = HuggingFaceEndpoint(
98
- repo_id=MODEL,
99
- huggingfacehub_api_token=auth_config["api_key"],
100
- task="text-generation",
101
- provider=INFERENCE_PROVIDER,
102
- server_kwargs={"bill_to": ORGANIZATION},
103
- temperature=TEMPERATURE,
104
- max_new_tokens=MAX_TOKENS
105
- )
106
- return ChatHuggingFace(llm=llm)
107
- else:
108
- raise ValueError(f"Unsupported provider: {PROVIDER}")
109
-
110
- # Initialize provider-agnostic chat model
111
- chat_model = get_chat_model()
112
-
113
- # ---------------------------------------------------------------------
114
- # Context processing - may need further refinement (i.e. to manage other data sources)
115
- # ---------------------------------------------------------------------
116
- # def extract_relevant_fields(retrieval_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
117
- # """
118
- # Extract only relevant fields from retrieval results.
119
-
120
- # Args:
121
- # retrieval_results: List of JSON objects from retriever
122
-
123
- # Returns:
124
- # List of processed objects with only relevant fields
125
- # """
126
-
127
- # retrieval_results = ast.literal_eval(retrieval_results)
128
-
129
- # processed_results = []
130
-
131
- # for result in retrieval_results:
132
- # # Extract the answer content
133
- # answer = result.get('answer', '')
134
-
135
- # # Extract document identification from metadata
136
- # metadata = result.get('answer_metadata', {})
137
- # doc_info = {
138
- # 'answer': answer,
139
- # 'filename': metadata.get('filename', 'Unknown'),
140
- # 'page': metadata.get('page', 'Unknown'),
141
- # 'year': metadata.get('year', 'Unknown'),
142
- # 'source': metadata.get('source', 'Unknown'),
143
- # 'document_id': metadata.get('_id', 'Unknown')
144
- # }
145
-
146
- # processed_results.append(doc_info)
147
-
148
- # return processed_results
149
-
150
- # def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
151
- # """
152
- # Format processed retrieval results into a context string for the LLM.
153
-
154
- # Args:
155
- # processed_results: List of processed objects with relevant fields
156
-
157
- # Returns:
158
- # Formatted context string
159
- # """
160
- # if not processed_results:
161
- # return ""
162
-
163
- # context_parts = []
164
-
165
- # for i, result in enumerate(processed_results, 1):
166
- # doc_reference = f"[Document {i}: {result['filename']}"
167
- # if result['page'] != 'Unknown':
168
- # doc_reference += f", Page {result['page']}"
169
- # if result['year'] != 'Unknown':
170
- # doc_reference += f", Year {result['year']}"
171
- # doc_reference += "]"
172
-
173
- # context_part = f"{doc_reference}\n{result['answer']}\n"
174
- # context_parts.append(context_part)
175
-
176
- # return "\n".join(context_parts)
177
-
178
- # ---------------------------------------------------------------------
179
- # Core generation function for both Gradio UI and MCP
180
- # ---------------------------------------------------------------------
181
- async def _call_llm(messages: list) -> str:
182
- """
183
- Provider-agnostic LLM call using LangChain.
184
-
185
- Args:
186
- messages: List of LangChain message objects
187
-
188
- Returns:
189
- Generated response content as string
190
- """
191
- try:
192
- # Use async invoke for better performance
193
- response = await chat_model.ainvoke(messages)
194
- print(response)
195
- return response.content
196
- #return response.content.strip()
197
- except Exception as e:
198
- logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
199
- raise
200
-
201
- def build_messages(question: str, context: str) -> list:
202
- """
203
- Build messages in LangChain format.
204
-
205
- Args:
206
- question: The user's question
207
- context: The relevant context for answering
208
-
209
- Returns:
210
- List of LangChain message objects
211
- """
212
- system_content = (
213
- """
214
- You are an expert assistant. Your task is to generate accurate, helpful responses using only the
215
- information contained in the "CONTEXT" provided.
216
- Instructions:
217
- - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
218
- - Language matching: Respond in the same language as the user's query.
219
- - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
220
- - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
221
- - Stay focused: Answer only what is asked. Do not provide additional information not requested.
222
- - Structure your response effectively:
223
- * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
224
- * Use bullet points and lists when it makes sense to improve readability.
225
- * You do not need to use every passage. Only use the ones that help answer the question.
226
- - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
227
-
228
- Input Format:
229
- - Query: {query}
230
- - Retrieved Paragraphs: {retrieved_paragraphs}
231
- Generate your response based on these guidelines.
232
- """
233
- )
234
-
235
- user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
236
-
237
- return [
238
- SystemMessage(content=system_content),
239
- HumanMessage(content=user_content)
240
- ]
241
-
242
-
243
- async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
244
- """
245
- Generate an answer to a query using provided context through RAG.
246
-
247
- This function takes a user query and relevant context, then uses a language model
248
- to generate a comprehensive answer based on the provided information.
249
-
250
- Args:
251
- query (str): User query
252
- context (list): List of retrieval result objects (dictionaries)
253
- Returns:
254
- str: The generated answer based on the query and context
255
- """
256
- if not query.strip():
257
- return "Error: Query cannot be empty"
258
-
259
- # Handle both string context (for Gradio UI) and list context (from retriever)
260
- if isinstance(context, list):
261
- if not context:
262
- return "Error: No retrieval results provided"
263
-
264
- # # Process the retrieval results
265
- # processed_results = extract_relevant_fields(context)
266
- formatted_context = context
267
-
268
- # if not formatted_context.strip():
269
- # return "Error: No valid content found in retrieval results"
270
-
271
- elif isinstance(context, str):
272
- if not context.strip():
273
- return "Error: Context cannot be empty"
274
- formatted_context = context
275
-
276
- else:
277
- return "Error: Context must be either a string or list of retrieval results"
278
-
279
- try:
280
- messages = build_messages(query, formatted_context)
281
- answer = await _call_llm(messages)
282
-
283
- return answer
284
-
285
- except Exception as e:
286
- logging.exception("Generation failed")
287
- return f"Error: {str(e)}"