ppsingh commited on
Commit
056423f
·
1 Parent(s): 31cb392

refactor the generator code

Browse files
app.py CHANGED
@@ -1,176 +1,14 @@
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
 
3
  import ast
4
  import time
5
  import asyncio
6
  import logging
7
  import logging
8
  logging.basicConfig(level=logging.INFO)
9
- import os
10
- import configparser
11
 
12
 
13
- def getconfig(configfile_path: str):
14
- """
15
- Read the config file
16
- Params
17
- ----------------
18
- configfile_path: file path of .cfg file
19
- """
20
- config = configparser.ConfigParser()
21
- try:
22
- config.read_file(open(configfile_path))
23
- return config
24
- except:
25
- logging.warning("config file not found")
26
-
27
- # ---------------------------------------------------------------------
28
- # Provider-agnostic authentication and configuration
29
- # ---------------------------------------------------------------------
30
-
31
- def get_auth(provider: str) -> dict:
32
- """Get authentication configuration for different providers"""
33
- auth_configs = {
34
- "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
35
- "huggingface": {"api_key": os.getenv("HF_TOKEN")},
36
- "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
37
- "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
38
- }
39
-
40
- if provider not in auth_configs:
41
- raise ValueError(f"Unsupported provider: {provider}")
42
-
43
- auth_config = auth_configs[provider]
44
- api_key = auth_config.get("api_key")
45
-
46
- if not api_key:
47
- raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
48
-
49
- return auth_config
50
-
51
- # ---------------------------------------------------------------------
52
- # Model / client initialization (non exaustive list of providers)
53
- # ---------------------------------------------------------------------
54
-
55
- config = getconfig("model_params.cfg")
56
-
57
- PROVIDER = config.get("generator", "PROVIDER")
58
- MODEL = config.get("generator", "MODEL")
59
- MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
60
- TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
61
- INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
62
- ORGANIZATION = config.get("generator", "ORGANIZATION")
63
-
64
- # Set up authentication for the selected provider
65
- auth_config = get_auth(PROVIDER)
66
-
67
-
68
- from langchain_core.messages import SystemMessage, HumanMessage
69
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
70
-
71
- def build_messages(question: str, context: str) -> list:
72
- """
73
- Build messages in LangChain format.
74
-
75
- Args:
76
- question: The user's question
77
- context: The relevant context for answering
78
-
79
- Returns:
80
- List of LangChain message objects
81
- """
82
- system_content = (
83
- """
84
- You are an expert assistant. Your task is to generate accurate, helpful responses using only the
85
- information contained in the "CONTEXT" provided.
86
- Instructions:
87
- - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
88
- - Language matching: Respond in the same language as the user's query.
89
- - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
90
- - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
91
- - Stay focused: Answer only what is asked. Do not provide additional information not requested.
92
- - Structure your response effectively:
93
- * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
94
- * Use bullet points and lists when it makes sense to improve readability.
95
- * You do not need to use every passage. Only use the ones that help answer the question.
96
- - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
97
-
98
- Input Format:
99
- - Query: {query}
100
- - Retrieved Paragraphs: {retrieved_paragraphs}
101
- Generate your response based on these guidelines.
102
- """
103
- )
104
-
105
- user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
106
-
107
- return [
108
- SystemMessage(content=system_content),
109
- HumanMessage(content=user_content)
110
- ]
111
- def get_chat_model():
112
- """Initialize the appropriate LangChain chat model based on provider"""
113
- common_params = {
114
- "temperature": TEMPERATURE,
115
- "max_tokens": MAX_TOKENS,
116
- }
117
-
118
- # if PROVIDER == "openai":
119
- # return ChatOpenAI(
120
- # model=MODEL,
121
- # openai_api_key=auth_config["api_key"],
122
- # **common_params
123
- # )
124
- # elif PROVIDER == "anthropic":
125
- # return ChatAnthropic(
126
- # model=MODEL,
127
- # anthropic_api_key=auth_config["api_key"],
128
- # **common_params
129
- # )
130
- # elif PROVIDER == "cohere":
131
- # return ChatCohere(
132
- # model=MODEL,
133
- # cohere_api_key=auth_config["api_key"],
134
- # **common_params
135
- # )
136
- if PROVIDER == "huggingface":
137
- # Initialize HuggingFaceEndpoint with explicit parameters
138
- llm = HuggingFaceEndpoint(
139
- repo_id=MODEL,
140
- huggingfacehub_api_token=auth_config["api_key"],
141
- task="text-generation",
142
- provider=INFERENCE_PROVIDER,
143
- server_kwargs={"bill_to": ORGANIZATION},
144
- temperature=TEMPERATURE,
145
- max_new_tokens=MAX_TOKENS
146
- )
147
- return ChatHuggingFace(llm=llm)
148
- else:
149
- raise ValueError(f"Unsupported provider: {PROVIDER}")
150
-
151
- # Initialize provider-agnostic chat model
152
- chat_model = get_chat_model()
153
-
154
- async def _call_llm(messages: list) -> str:
155
- """
156
- Provider-agnostic LLM call using LangChain.
157
-
158
- Args:
159
- messages: List of LangChain message objects
160
-
161
- Returns:
162
- Generated response content as string
163
- """
164
- try:
165
- # Use async invoke for better performance
166
- response = await chat_model.ainvoke(messages)
167
- logging.info(f"answer: {response.content}")
168
- return response.content
169
- #return response.content.strip()
170
- except Exception as e:
171
- logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
172
- raise
173
-
174
 
175
  def chat_response(query):
176
  """Generate chat response based on method and inputs"""
@@ -186,8 +24,6 @@ def chat_response(query):
186
 
187
  messages = build_messages(query, context_retrieved_lst)
188
  answer = asyncio.run(_call_llm(messages))
189
-
190
-
191
  return answer
192
 
193
 
@@ -214,3 +50,4 @@ if not query.strip():
214
  st.stop()
215
  else:
216
  st.write(chat_response(query))
 
 
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
3
+ from utils.generator import build_messages, _call_llm
4
  import ast
5
  import time
6
  import asyncio
7
  import logging
8
  import logging
9
  logging.basicConfig(level=logging.INFO)
 
 
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def chat_response(query):
14
  """Generate chat response based on method and inputs"""
 
24
 
25
  messages = build_messages(query, context_retrieved_lst)
26
  answer = asyncio.run(_call_llm(messages))
 
 
27
  return answer
28
 
29
 
 
50
  st.stop()
51
  else:
52
  st.write(chat_response(query))
53
+
utils/__pycache__/generator.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/generator.cpython-311.pyc and b/utils/__pycache__/generator.cpython-311.pyc differ
 
utils/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.84 kB). View file
 
utils/generator.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.utils import getconfig, get_auth
2
+ from langchain_core.messages import SystemMessage, HumanMessage
3
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
4
+ import logging
5
+ # ---------------------------------------------------------------------
6
+ # Model / client initialization (non exaustive list of providers)
7
+
8
+ config = getconfig("model_params.cfg")
9
+ # Reading Params
10
+ PROVIDER = config.get("generator", "PROVIDER")
11
+ MODEL = config.get("generator", "MODEL")
12
+ MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
13
+ TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
14
+ INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
15
+ ORGANIZATION = config.get("generator", "ORGANIZATION")
16
+
17
+ # Set up authentication for the selected provider
18
+ auth_config = get_auth(PROVIDER)
19
+
20
+
21
+ def get_chat_model():
22
+ """Initialize the appropriate LangChain chat model based on provider"""
23
+ common_params = {
24
+ "temperature": TEMPERATURE,
25
+ "max_tokens": MAX_TOKENS,
26
+ }
27
+
28
+ # if PROVIDER == "openai":
29
+ # return ChatOpenAI(
30
+ # model=MODEL,
31
+ # openai_api_key=auth_config["api_key"],
32
+ # **common_params
33
+ # )
34
+ # elif PROVIDER == "anthropic":
35
+ # return ChatAnthropic(
36
+ # model=MODEL,
37
+ # anthropic_api_key=auth_config["api_key"],
38
+ # **common_params
39
+ # )
40
+ # elif PROVIDER == "cohere":
41
+ # return ChatCohere(
42
+ # model=MODEL,
43
+ # cohere_api_key=auth_config["api_key"],
44
+ # **common_params
45
+ # )
46
+ if PROVIDER == "huggingface":
47
+ # Initialize HuggingFaceEndpoint with explicit parameters
48
+ llm = HuggingFaceEndpoint(
49
+ repo_id=MODEL,
50
+ huggingfacehub_api_token=auth_config["api_key"],
51
+ task="text-generation",
52
+ provider=INFERENCE_PROVIDER,
53
+ server_kwargs={"bill_to": ORGANIZATION},
54
+ temperature=TEMPERATURE,
55
+ max_new_tokens=MAX_TOKENS
56
+ )
57
+ return ChatHuggingFace(llm=llm)
58
+ else:
59
+ raise ValueError(f"Unsupported provider: {PROVIDER}")
60
+
61
+ # Initialize provider-agnostic chat model
62
+ chat_model = get_chat_model()
63
+
64
+ #------------------------------------Define Prompt -----------------------------------------
65
+
66
+
67
+ def build_messages(question: str, context: str) -> list:
68
+ """
69
+ Build messages in LangChain format.
70
+
71
+ Args:
72
+ question: The user's question
73
+ context: The relevant context for answering
74
+
75
+ Returns:
76
+ List of LangChain message objects
77
+ """
78
+ system_content = (
79
+ """
80
+ You are an expert assistant. Your task is to generate accurate, helpful responses using only the
81
+ information contained in the "CONTEXT" provided.
82
+ Instructions:
83
+ - Answer based only on provided context: Use only the information present in the retrieved_paragraphs below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
84
+ - Language matching: Respond in the same language as the user's query.
85
+ - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
86
+ - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
87
+ - Stay focused: Answer only what is asked. Do not provide additional information not requested.
88
+ - Structure your response effectively:
89
+ * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
90
+ * Use bullet points and lists when it makes sense to improve readability.
91
+ * You do not need to use every passage. Only use the ones that help answer the question.
92
+ - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read. Example: <br> for linebreaks
93
+
94
+ Input Format:
95
+ - Query: {query}
96
+ - Retrieved Paragraphs: {retrieved_paragraphs}
97
+ Generate your response based on these guidelines.
98
+ """
99
+ )
100
+
101
+ user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
102
+
103
+ return [
104
+ SystemMessage(content=system_content),
105
+ HumanMessage(content=user_content)
106
+ ]
107
+
108
+ #--------------------------------Get the async response ---------------------------------------------
109
+ async def _call_llm(messages: list) -> str:
110
+ """
111
+ Provider-agnostic LLM call using LangChain.
112
+
113
+ Args:
114
+ messages: List of LangChain message objects
115
+
116
+ Returns:
117
+ Generated response content as string
118
+ """
119
+ try:
120
+ # Use async invoke for better performance
121
+ response = await chat_model.ainvoke(messages)
122
+ logging.info(f"answer: {response.content}")
123
+ return response.content
124
+ #return response.content.strip()
125
+ except Exception as e:
126
+ logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
127
+ raise
128
+
utils/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import configparser
3
+
4
+
5
+ def getconfig(configfile_path: str):
6
+ """
7
+ Read the config file
8
+ Params
9
+ ----------------
10
+ configfile_path: file path of .cfg file
11
+ """
12
+ config = configparser.ConfigParser()
13
+ try:
14
+ config.read_file(open(configfile_path))
15
+ return config
16
+ except:
17
+ logging.warning("config file not found")
18
+
19
+ # ---------------------------------------------------------------------
20
+ # Provider-agnostic authentication and configuration
21
+ # ---------------------------------------------------------------------
22
+
23
+ def get_auth(provider: str) -> dict:
24
+ """Get authentication configuration for different providers"""
25
+ auth_configs = {
26
+ "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
27
+ "huggingface": {"api_key": os.getenv("HF_TOKEN")},
28
+ "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
29
+ "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
30
+ }
31
+
32
+ if provider not in auth_configs:
33
+ raise ValueError(f"Unsupported provider: {provider}")
34
+
35
+ auth_config = auth_configs[provider]
36
+ api_key = auth_config.get("api_key")
37
+
38
+ if not api_key:
39
+ raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
40
+
41
+ return auth_config