import os
from typing import List

import gradio as gr
from typing_extensions import TypedDict

from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
# Environment setup: set TAVILY_API_KEY and GROQ_API_KEY in the environment
# before running; never hardcode API keys in source.

# Model and embedding setup
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
llm = ChatGroq(temperature=0, model_name="Llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"])
# Load documents from URLs
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# Document splitting
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Vectorstore setup
vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
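# Quick sanity check (illustrative; the question text is hypothetical):
#   retriever.invoke("What are LLM agents?")
# should return a list of up to 2 Document chunks drawn from the pages above.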
# Prompt templates
question_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
    user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
    prompt engineering, and adversarial attacks. Otherwise, use web-search. Give a binary choice 'web_search'
    or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and no preamble.
    Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)
question_router = question_router_prompt | llm | JsonOutputParser()
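# Illustrative example (hypothetical question): the router emits the single-key
# JSON the prompt asks for, e.g.
#   question_router.invoke({"question": "What is prompt engineering?"})
#   -> {"datasource": "vectorstore"}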
rag_chain_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
    Use the following pieces of retrieved context to answer the question concisely. <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question}
    Context: {context}
    Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],  # must match the {question} and {context} placeholders above
)
# Chain
rag_chain = rag_chain_prompt | llm | StrOutputParser()

# Web search tool
web_search_tool = TavilySearchResults(max_results=3)
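# The tool's invoke({"query": ...}) returns a list of result dicts that include
# "url" and "content" keys; web_search below joins the "content" fields.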
# Workflow functions
def retrieve(state):
    """Fetch the most relevant chunks from the vectorstore for the question."""
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """Generate an answer from the retrieved documents via the RAG chain."""
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def route_question(state):
    """Decide whether the question goes to web search or the vectorstore."""
    question = state["question"]
    source = question_router.invoke({"question": question})
    return "websearch" if source["datasource"] == "web_search" else "vectorstore"

def web_search(state):
    """Run a Tavily web search and append the results as a single Document."""
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    documents = state.get("documents", [])
    documents.append(web_results)
    return {"documents": documents, "question": question}
# Graph state shared across all nodes
class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]

workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("websearch", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "websearch",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)  # without a terminal edge, "generate" is a dead-end node

# Compile the app
app = workflow.compile()
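# Minimal usage sketch (illustrative; the question text is hypothetical):
# the compiled graph can be invoked directly, bypassing Gradio.
#   result = app.invoke({"question": "How do adversarial attacks on LLMs work?"})
#   print(result["generation"])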
# Gradio integration with Chatbot
def ask_question_conversation(history, question):
    inputs = {"question": question}
    generation_result = "No generation found"
    # Run the workflow; keep the latest generation emitted by any node
    for output in app.stream(inputs):
        for key, value in output.items():
            generation_result = value.get("generation", generation_result)
    # Append the new question and response to the history
    history.append((question, generation_result))
    # Return the updated history to the chatbot and clear the question textbox
    return history, ""
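# Note: app.stream(inputs) emits each node's output as it executes, keyed by
# node name; only the final "generate" node carries a "generation" value, which
# is why the loop above keeps the most recent one rather than overwriting it.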
# Gradio conversation UI
with gr.Blocks(css="""
#title {
    font-size: 26px;
    font-weight: bold;
    text-align: center;
    color: #4A90E2;
}
#subtitle {
    font-size: 18px;
    text-align: center;
    margin-top: -15px;
    color: #7D7D7D;
}
.gr-chatbot, .gr-textbox, .gr-button {
    max-width: 600px;
    margin: 0 auto;
}
.gr-chatbot {
    height: 400px;
}
.gr-button {
    display: block;
    width: 100px;
    margin: 20px auto;
    background-color: #4A90E2;
    color: white;
}
""") as demo:
    gr.Markdown("<div id='title'>🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!</div>")
    chatbot = gr.Chatbot(label="Chat with AI Assistant")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1)
    clear = gr.Button("Clear Chat")

    # Submit action for the question textbox
    question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
    clear.click(lambda: [], None, chatbot)  # Clear conversation history

demo.launch()