import os
import uuid
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
from dotenv import load_dotenv
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# from chainlit.input_widget import Select, Switch, Slider

load_dotenv()

BOR_FILE_PATH = "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
NIST_FILE_PATH = "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
SMALL_DOC = "https://arxiv.org/pdf/1908.10084"  # 11 pages; "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks"

documents_to_preload = [
    BOR_FILE_PATH,
    NIST_FILE_PATH,
    # SMALL_DOC
]

collection_name = "ai-safety"

welcome_message = """
Welcome! I am a chatbot here to answer your AI Safety questions.
The following documents are being preloaded:
1. Blueprint for an AI Bill of Rights
2. NIST AI Standards
Please wait a moment while the documents load.
"""

chat_model_name = "gpt-4o"
embedding_model_name = "Snowflake/snowflake-arctic-embed-l"

chat_model = ChatOpenAI(model=chat_model_name, temperature=0)
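

# Connect to the hosted Qdrant collection and return a similarity retriever
# over it (top 10 hits with a similarity score of at least 0.8). Expects
# QDRANT_URL, QDRANT_API_KEY, and COLLECTION_NAME in the environment
# (e.g. via the .env file loaded above).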
async def connect_to_qdrant():
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
    qdrant_url = os.environ["QDRANT_URL"]
    qdrant_api_key = os.environ["QDRANT_API_KEY"]
    collection_name = os.environ["COLLECTION_NAME"]
    qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        embedding=embedding_model,
    )
    return vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 10, "score_threshold": 0.8},
    )
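

# Wrap the base retriever in contextual compression: a second LLM pass extracts
# only the parts of each retrieved document that are relevant to the query.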
async def get_contextual_compressed_retriever(retriever):
    compressor_llm = ChatOpenAI(model="gpt-4o", temperature=0, max_tokens=4000)
    compressor = LLMChainExtractor.from_llm(compressor_llm)
    # Combine the retriever with the compressor
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
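

# Helpers for building a collection from scratch in an in-memory Qdrant
# instance (not wired into the app; see the sketch after populate_vectorstore).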
def initialize_vectorstore(
    collection_name: str,
    embedding_model,
    dimension,
    distance_metric: Distance = Distance.COSINE,
):
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=dimension, distance=distance_metric),
    )
    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embedding_model,
    )
    return vector_store
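

# Semantic chunking: start a new chunk wherever the embedding distance between
# adjacent sentence groups exceeds the 90th percentile of observed distances.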
def get_text_splitter(strategy, embedding_model):
    if strategy == "semantic":
        return SemanticChunker(
            embedding_model,
            buffer_size=3,
            breakpoint_threshold_type="percentile",
            breakpoint_threshold_amount=90,
        )
    raise ValueError(f"Unknown text-splitting strategy: {strategy}")
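

# Load an uploaded file with a loader matching its MIME type, split it, and tag
# every chunk with a source id and the document title.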
def process_file(file: AskFileResponse, text_splitter):
    if file.type == "text/plain":
        Loader = TextLoader
    elif file.type == "application/pdf":
        Loader = PyMuPDFLoader
    else:
        raise ValueError(f"Unsupported file type: {file.type}")
    loader = Loader(file.path)
    documents = loader.load()
    title = documents[0].metadata.get("title")
    docs = text_splitter.split_documents(documents)
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
        doc.metadata["title"] = title
    return docs


def populate_vectorstore(vector_store, docs: List[Document]):
    vector_store.add_documents(docs)
    return vector_store
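

# A minimal sketch of how the helpers above could preload `documents_to_preload`
# into the in-memory store (not executed by the app; the vector dimension of
# 1024 is an assumption for Snowflake/snowflake-arctic-embed-l, and
# PyMuPDFLoader is used directly since the preload entries are URLs, not
# uploaded files):
#
#     embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
#     vector_store = initialize_vectorstore(collection_name, embedding_model, dimension=1024)
#     text_splitter = get_text_splitter("semantic", embedding_model)
#     for url in documents_to_preload:
#         documents = PyMuPDFLoader(url).load()
#         docs = text_splitter.split_documents(documents)
#         populate_vectorstore(vector_store, docs)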


def create_history_aware_retriever_self(chat_model, retriever):
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question which might reference context in the chat history, "
        "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    return create_history_aware_retriever(chat_model, retriever, contextualize_q_prompt)
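

# Build the answer chain: stuff the retrieved context into the system prompt
# and answer the user's question with the chat history in view.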
def create_qa_chain(chat_model):
    qa_system_prompt = (
        "You are a helpful assistant named 'Shield' and your task is to answer any questions related to AI Safety from the given context. "
        "Use the following pieces of retrieved context to answer the question. "
        # "If any question is asked outside the AI Safety context, just say that you are a specialist in AI Safety and can't answer that. "
        # f"When introducing yourself, say that you are an AI assistant powered by embedding model {embedding_model_name} and chat model {chat_model_name}, and that your knowledge is limited to the 'Blueprint for an AI Bill of Rights' and 'NIST AI Standards' documents. "
        "If you don't know the answer, just say that you don't know.\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    return create_stuff_documents_chain(chat_model, qa_prompt)


def create_rag_chain(chat_model, retriever):
    history_aware_retriever = create_history_aware_retriever_self(chat_model, retriever)
    question_answer_chain = create_qa_chain(chat_model)
    return create_retrieval_chain(history_aware_retriever, question_answer_chain)


def create_session_id():
    return str(uuid.uuid4())
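

# On chat start: greet the user, create a session id, connect the RAG chain to
# the compressed Qdrant retriever, and wrap it with per-session message history.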
@cl.on_chat_start
async def start():
    msg = cl.Message(content=welcome_message)
    await msg.send()

    # Create a session id
    session_id = create_session_id()
    cl.user_session.set("session_id", session_id)

    retriever = await connect_to_qdrant()
    contextual_compressed_retriever = await get_contextual_compressed_retriever(retriever)
    rag_chain = create_rag_chain(chat_model, contextual_compressed_retriever)

    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

    # Let the user know that the system is ready
    msg.content = msg.content + "\nReady to answer your questions!"
    await msg.update()
    cl.user_session.set("conversational_rag_chain", conversational_rag_chain)
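

# On every user message: run the conversational RAG chain, then attach the
# retrieved source passages (de-duplicated by page) as message elements.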
@cl.on_message
async def main(message: cl.Message):
    session_id = cl.user_session.get("session_id")
    conversational_rag_chain = cl.user_session.get("conversational_rag_chain")
    response = await conversational_rag_chain.ainvoke(
        {"input": message.content},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [cl.AsyncLangchainCallbackHandler()],
        },
    )
    answer = response["answer"]
    source_documents = response["context"]

    text_elements = []
    unique_pages = set()
    if source_documents:
        for source_doc in source_documents:
            page_number = source_doc.metadata.get("page", "NA")
            page = f"Page {page_number}"
            text_element_content = source_doc.page_content or "No Content"
            if page not in unique_pages:
                unique_pages.add(page)
                text_elements.append(cl.Text(content=text_element_content, name=page))
        source_names = [text_el.name for text_el in text_elements]
        if source_names:
            answer += f"\n\nSources: {', '.join(source_names)}"
        else:
            answer += "\n\nNo sources found"
    await cl.Message(content=answer, elements=text_elements).send()


if __name__ == "__main__":
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)