import os

from dotenv import load_dotenv
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages.base import BaseMessage

from basic_chain import basic_chain, get_model
from remote_loader import load_wiki_articles
from splitter import split_documents
from vector_store import create_vector_db

def find_similar(vs, query):
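    """Return the documents in the vector store most similar to the query."""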
    docs = vs.similarity_search(query)
    return docs


def format_docs(docs):
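    """Concatenate the retrieved documents into a single context string for the prompt."""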
    return "\n\n".join(doc.page_content for doc in docs)


def get_question(input):
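    """Extract the question text from whatever input type the chain receives."""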
    if not input:
        return None
    elif isinstance(input, str):
        return input
    elif isinstance(input, dict) and 'question' in input:
        return input['question']
    elif isinstance(input, BaseMessage):
        return input.content
    else:
        raise TypeError("Expected a string, a dict with a 'question' key, or a BaseMessage as RAG chain input.")


def make_rag_chain(model, retriever, rag_prompt=None):
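    """Compose a RAG chain that retrieves context for a question and prompts the model."""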
    # We will use a prompt template from the LangChain hub.
    if not rag_prompt:
        rag_prompt = hub.pull("rlm/rag-prompt")

    # And we will use RunnablePassthrough to add some custom processing into our chain.
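    # The dict below fans the input out: the question is routed through the
    # retriever to build "context", and passed through unchanged as "question".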
    rag_chain = (
        {
            "context": RunnableLambda(get_question) | retriever | format_docs,
            "question": RunnablePassthrough()
        }
        | rag_prompt
        | model
    )
    return rag_chain


def main():
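    """Build a plain chat chain and a RAG chain over Wikipedia articles, then compare their answers."""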
    load_dotenv()
    model = get_model("ChatGPT")
    docs = load_wiki_articles(query="Bertrand Russell", load_max_docs=5)
    texts = split_documents(docs)
    vs = create_vector_db(texts)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a professor who teaches philosophical concepts to beginners."),
        ("user", "{input}")
    ])

    # Besides similarity search, you can also use maximal marginal relevance (MMR) for selecting results.
    # retriever = vs.as_retriever(search_type="mmr")
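    # MMR trades off relevance against diversity; you can tune it via search_kwargs,
    # e.g. (the k and fetch_k values here are illustrative, not from the original):
    # retriever = vs.as_retriever(search_type="mmr", search_kwargs={"k": 4, "fetch_k": 20})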
    retriever = vs.as_retriever()

    output_parser = StrOutputParser()
    chain = basic_chain(model, prompt)
    base_chain = chain | output_parser
    rag_chain = make_rag_chain(model, retriever) | output_parser

    questions = [
        "What were the most important contributions of Bertrand Russell to philosophy?",
        "What was the first book Bertrand Russell published?",
        "What was most notable about \"An Essay on the Foundations of Geometry\"?",
    ]
    for q in questions:
        print("\n--- QUESTION: ", q)
        print("* BASE:\n", base_chain.invoke({"input": q}))
        print("* RAG:\n", rag_chain.invoke(q))


if __name__ == '__main__':
    # This quiets the parallel tokenizers warning from Hugging Face tokenizers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()