Spaces:
Build error
Build error
| import os | |
| from typing import List, Iterable, Any | |
| from dotenv import load_dotenv | |
| from langchain.memory import ChatMessageHistory | |
| from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_core.documents import Document | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.retrievers import BaseRetriever | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from basic_chain import get_model | |
| from rag_chain import make_rag_chain | |
| import logging | |
| store = {} | |
| import json | |
def create_memory_chain(llm, base_chain):
    """Build a history-aware runnable around ``base_chain``.

    The returned runnable feeds the chat history and the incoming
    ``question`` through a "reformulate as a standalone question" prompt,
    pipes that prompt into ``llm``, and pipes the result into ``base_chain``.
    Per-session histories are kept in the module-level ``store`` dict,
    keyed by session id.

    Args:
        llm: Runnable placed directly after the contextualizing prompt.
        base_chain: Runnable that receives the output of ``llm``.

    Returns:
        A ``RunnableWithMessageHistory`` that reads its input from the
        ``"question"`` key and injects history under ``"chat_history"``.
    """
    # Adjacent literals concatenate to the same single-line instruction text.
    system_text = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, formulate a standalone question "
        "which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    prompt_messages = [
        ("system", system_text),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
    pipeline = ChatPromptTemplate.from_messages(prompt_messages) | llm | base_chain

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        """Return the history for ``session_id``, creating it on first use."""
        try:
            return store[session_id]
        except KeyError:
            store[session_id] = ChatMessageHistory()
            return store[session_id]

    return RunnableWithMessageHistory(
        pipeline,
        get_session_history,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
def clean_session_history(session_id):
    """Replace the stored history for ``session_id`` with a fresh, empty one.

    The entry in the module-level ``store`` is overwritten whether or not
    a history already existed for that id.
    """
    # Item assignment mutates the module-level dict in place, so no
    # ``global`` declaration is needed to reach it.
    empty_history = ChatMessageHistory()
    store[session_id] = empty_history
class SimpleTextRetriever(BaseRetriever):
    """Retriever that returns all of its stored documents for every query."""

    docs: List[Document]
    """Documents."""

    # Must be a classmethod: it is invoked on the class itself
    # (``SimpleTextRetriever.from_texts([...])``) and uses ``cls`` as the
    # constructor. Without the decorator the first positional argument
    # would be bound to ``cls`` and the call would fail.
    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        **kwargs: Any,
    ) -> "SimpleTextRetriever":
        """Create a retriever from raw strings.

        Args:
            texts: Strings to wrap; each becomes one ``Document`` whose
                ``page_content`` is that string.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance of ``cls`` holding the wrapped documents.
        """
        docs = [Document(page_content=t) for t in texts]
        return cls(docs=docs, **kwargs)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return every stored document; ``query`` and ``run_manager`` are unused."""
        return self.docs
def main():
    """Run a two-question demo of the history-aware RAG chain.

    Loads environment variables, obtains the model from ``get_model``,
    wraps the full text of ``examples/equity_faq.txt`` in a
    ``SimpleTextRetriever``, builds the RAG chain with memory, and invokes
    it once per sample query under session id ``"foo"``, printing each
    question and its answer.
    """
    load_dotenv()
    model = get_model("ChatGPT")

    text_path = "examples/equity_faq.txt"  # Updated path
    # Use a context manager so the file handle is closed as soon as the
    # text has been read instead of being left to the garbage collector.
    with open(text_path, "r") as text_file:
        text = text_file.read()

    retriever = SimpleTextRetriever.from_texts([text])
    rag_chain = make_rag_chain(model, retriever, rag_prompt=None)
    # create_memory_chain accepts exactly (llm, base_chain); it looks up
    # per-session history itself, so no history object is passed in.
    chain = create_memory_chain(model, rag_chain) | StrOutputParser()

    queries = [
        "What do I need to get from the grocery store besides milk?",
        "Which of these items can I find at a farmer's market?",
    ]
    for query in queries:
        print(f"\nQuestion: {query}")
        response = chain.invoke(
            {"question": query},
            config={"configurable": {"session_id": "foo"}},
        )
        print(f"Answer: {response}")
if __name__ == "__main__":
    # Quiet the parallel tokenizers warning before running the demo.
    os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
    main()