from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_pinecone import PineconeVectorStore
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_mistralai import ChatMistralAI
from uuid import uuid4
from pydantic import BaseModel, Field
from langchain_core.tools import tool
import unicodedata


class AddToKnowledgeBase(BaseModel):
    '''Add information to the knowledge base if the user asks for it in their query.'''
    information: str = Field(..., title="The information to add to the knowledge base")


def detect_language(text: str):
    '''Detect the language of a text and return its name as a lowercase string.'''
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    template = "détecte la langue du texte suivant: {text}. assure-toi que ta réponse contient seulement le nom de la langue détectée"
    prompt = PromptTemplate.from_template(template)
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"text": text}).strip().lower()
    print(response)
    return response
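
# Illustrative usage (a sketch, not part of the original module): the prompt
# is written in French, so the model is expected to answer with a French
# language name.
#
#   detect_language("Hello, how are you?")    # -> "anglais" (expected)
#   detect_language("Bonjour tout le monde")  # -> "français" (expected)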


def remove_non_standard_ascii(input_string: str) -> str:
    '''Strip accents and drop every character that is not a letter, digit, space or basic punctuation.'''
    normalized_string = unicodedata.normalize('NFKD', input_string)
    return ''.join(char for char in normalized_string if 'a' <= char <= 'z' or 'A' <= char <= 'Z' or char.isdigit() or char in ' .,!?')
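
# Example (sketch): NFKD normalization splits an accented character into a
# base letter plus a combining mark, and the filter keeps only the base
# letter; underscores and symbols are dropped as well.
#
#   remove_non_standard_ascii("café_n°1")  # -> "cafen1"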


def get_text_from_content_for_doc(content):
    '''Concatenate the "texte" field of every page of a parsed document.'''
    text = ""
    for page in content:
        text += content[page]["texte"]
    return text


def get_text_from_content_for_audio(content):
    '''Return the transcription of a parsed audio file.'''
    return content["transcription"]


def get_text_chunks(text):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,  # the character length of each chunk
        chunk_overlap=100,  # the character overlap between consecutive chunks
        length_function=len  # the length function - here, character length (the built-in len())
    )
    chunks = text_splitter.split_text(text)
    return chunks
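
# Illustrative usage (sketch): any text shorter than chunk_size comes back as
# a single chunk; longer texts are split into pieces of at most 500
# characters that share up to 100 characters of overlap with their neighbours.
#
#   chunks = get_text_chunks("word " * 400)
#   all(len(c) <= 500 for c in chunks)  # -> True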


def get_vectorstore(text_chunks, filename, file_type, namespace, index, enterprise_name):
    '''Embed the chunks of a file and upsert them into a Pinecone namespace.

    Returns {"filename_id": <sanitized file name>} on success, False on error.
    '''
    try:
        embedding = OpenAIEmbeddings(model="text-embedding-3-large")
        vector_store = PineconeVectorStore(index=index, embedding=embedding, namespace=namespace)
        file_name = filename.split(".")[0].replace(" ", "_").replace("-", "_").replace("/", "_").replace("\\", "_").strip()
        # computed once outside the loop so it is also defined when text_chunks is empty
        clean_filename = remove_non_standard_ascii(file_name)
        documents = []
        uuids = []
        for i, chunk in enumerate(text_chunks):
            document = Document(
                page_content=chunk,
                metadata={"filename": filename, "file_type": file_type, "filename_id": clean_filename, "entreprise_name": enterprise_name},
            )
            uuids.append(f"{clean_filename}_{i}")
            documents.append(document)
        vector_store.add_documents(documents=documents, ids=uuids)
        return {"filename_id": clean_filename}
    except Exception as e:
        print(e)
        return False
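
# Illustrative usage (sketch; "my-index" and the namespace are hypothetical
# placeholders, and a Pinecone API key must already be configured):
#
#   from pinecone import Pinecone
#   index = Pinecone().Index("my-index")
#   chunks = get_text_chunks(text)
#   result = get_vectorstore(chunks, "rapport 2024.pdf", "pdf",
#                            namespace="enterprise_42", index=index,
#                            enterprise_name="ACME")
#   # -> {"filename_id": "rapport2024"} on success (the sanitizer drops the
#   #    underscore inserted for the space), False on error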


def add_to_knowledge_base(enterprise_id, information, index, enterprise_name, user_id=""):
    '''Add a piece of information to the knowledge base

    Args:
        enterprise_id (str): the enterprise id, used as the Pinecone namespace
        information (str): the information to add
        index: the Pinecone index
        enterprise_name (str): the enterprise name stored in the metadata
        user_id (str): the id of the user who asked for the addition

    Returns the id of the stored document, or False on error.
    '''
    try:
        embedding = OpenAIEmbeddings(model="text-embedding-3-large")
        vector_store = PineconeVectorStore(index=index, embedding=embedding, namespace=enterprise_id)
        uuid = f"kb_{user_id}_{uuid4()}"
        document = Document(
            page_content=information,
            metadata={"filename": "knowledge_base", "file_type": "text", "filename_id": uuid, "entreprise_name": enterprise_name, "user_id": user_id},
        )
        vector_store.add_documents(documents=[document], ids=[uuid])
        return uuid
    except Exception as e:
        print(e)
        return False
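
# Illustrative usage (sketch; ids are hypothetical):
#
#   item_id = add_to_knowledge_base("enterprise_42",
#                                   "Le lancement produit est prévu en mai.",
#                                   index, "ACME", user_id="user_7")
#   # -> something like "kb_user_7_1c9e..." on success, False on error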


def get_retreive_answer(enterprise_id, prompt, index, common_id, user_id=""):
    '''Retrieve the documents relevant to a prompt from the enterprise namespace
    and, if common_id is given, from a shared namespace as well.

    Returns a list of Documents, or False on error.
    '''
    try:
        print("common_id ", common_id)
        embedding = OpenAIEmbeddings(model="text-embedding-3-large")
        vector_store = PineconeVectorStore(index=index, embedding=embedding, namespace=enterprise_id)
        retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 3, "score_threshold": 0.6},
        )
        enterprise_context = retriever.invoke(prompt)
        user_memory = retriever.invoke(prompt, filters={"user_id": user_id})
        if enterprise_context:
            print("found enterprise context")
            for chunk in enterprise_context:
                print(chunk.metadata)
        else:
            print("no enterprise context")
        if common_id:
            vector_store_commun = PineconeVectorStore(index=index, embedding=embedding, namespace=common_id)
            retriever_commun = vector_store_commun.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"k": 5, "score_threshold": 0.1},
            )
            commun_context = retriever_commun.invoke(prompt)
            for chunk in commun_context:
                print(chunk.metadata)
            if commun_context:
                print("found commun context")
            else:
                print("no commun context")
            response = user_memory + enterprise_context + commun_context
        else:
            response = enterprise_context  # already retrieved above, no need to query again
        return response
    except Exception as e:
        print(e)
        return False
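
# Illustrative usage (sketch; namespace ids are hypothetical):
#
#   docs = get_retreive_answer("enterprise_42", "Quels sont nos produits ?",
#                              index, common_id="shared_docs")
#   # docs is a list of langchain Documents whose metadata carries the source
#   # filename, usable directly by prompt_reformatting() below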


def handle_calling_add_to_knowledge_base(query, enterprise_id="", index="", enterprise_name="", user_id="", llm=None):
    '''Handle the calling of the add_to_knowledge_base function.

    If the user's query asks to add information to the knowledge base, the
    information is extracted and stored; otherwise nothing happens.
    '''
    template = """
You are an AI assistant that processes user queries.
Determine if the user wants to add something to the knowledge base.
- If the user wants to add something, extract the valuable information, reformulate and output 'add' followed by the information.
- If the user does not want to add something, output 'no action'.
Ensure your response is only 'add <content>' or 'no action'.
User Query: "{query}"
Response:
""".strip()
    prompt = PromptTemplate.from_template(template)
    if not llm:
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
    # note: the tool binding below is currently unused; the chain relies on
    # parsing the model's plain-text "add ..." / "no action" answer instead
    llm_with_tool = llm.bind_tools([AddToKnowledgeBase])
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"query": query}).strip().lower()
    if response.startswith("add"):
        item = response[len("add"):].strip()
        if item:
            item_id = add_to_knowledge_base(enterprise_id, item, index, enterprise_name, user_id)
            print("added to knowledge base")
            print(item)
            return {"item_id": item_id, "item": item}
    print(response)
    return False
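
# Illustrative usage (sketch): a query like the one below should make the
# model answer "add <content>", which triggers add_to_knowledge_base();
# unrelated queries yield "no action" and the function returns False.
#
#   handle_calling_add_to_knowledge_base(
#       "Ajoute à la base de connaissance que notre siège est à Bordeaux",
#       enterprise_id="enterprise_42", index=index, enterprise_name="ACME")
#   # -> {"item_id": "...", "item": "..."} if the addition succeeds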


def generate_response_via_langchain(query: str, stream: bool = False, model: str = "gpt-4o", context: str = "", messages=None, style: str = "formel", tonality: str = "neutre", template: str = "", enterprise_name: str = "", enterprise_id: str = "", index: str = ""):
    messages = messages if messages is not None else []  # avoid a shared mutable default
    # Define the prompt template
    if template == "":
        template = "En tant qu'IA experte en marketing, réponds avec un style {style} et une tonalité {tonality} dans ta communication pour l'entreprise {enterprise}, sachant le contexte suivant: {context}, et l'historique de la conversation, {messages}, {query}"
    # Initialize the LLM matching the requested model family
    if model.startswith("gpt"):
        llm = ChatOpenAI(model=model, temperature=0)
    elif model.startswith("mistral"):
        llm = ChatMistralAI(model=model, temperature=0)
    else:
        raise ValueError(f"unsupported model: {model}")
    # if handle_calling_add_to_knowledge_base(query, enterprise_id, index, enterprise_name):
    #     template += " la base de connaissance a été mise à jour"
    language = detect_language(query)
    template += f" Réponds en {language}"
    # Create an LLM chain with the prompt and the LLM
    prompt = PromptTemplate.from_template(template)
    print(f"model: {model}")
    print(f"marque: {enterprise_name}")
    llm_chain = prompt | llm | StrOutputParser()
    print(f"language: {language}")
    if stream:
        # Return an async generator that yields streamed response chunks
        return llm_chain.astream({"query": query, "context": context, "messages": messages, "style": style, "tonality": tonality, "enterprise": enterprise_name})
    # Invoke the LLM chain and return the full result
    return llm_chain.invoke({"query": query, "context": context, "messages": messages, "style": style, "tonality": tonality, "enterprise": enterprise_name})
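
# Illustrative usage (sketch): a non-streaming call returns a string; with
# stream=True the call returns an async generator to consume with async for.
#
#   answer = generate_response_via_langchain("Présente nos produits",
#                                            context="catalogue 2024",
#                                            enterprise_name="ACME")
#
#   async def consume():
#       async for token in generate_response_via_langchain("Présente nos produits", stream=True):
#           print(token, end="")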


def setup_rag(file_type, content, filename="", namespace="", index=None, enterprise_name=""):
    # get_vectorstore() needs the file and namespace information in addition
    # to the chunks, so those arguments are accepted here and passed through
    if file_type == "pdf":
        text = get_text_from_content_for_doc(content)
    elif file_type == "audio":
        text = get_text_from_content_for_audio(content)
    else:
        raise ValueError(f"unsupported file_type: {file_type}")
    chunks = get_text_chunks(text)
    vectorstore = get_vectorstore(chunks, filename, file_type, namespace, index, enterprise_name)
    return vectorstore
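
# Illustrative usage (sketch; the content dict shapes follow
# get_text_from_content_for_doc/_for_audio above):
#
#   pdf_content = {1: {"texte": "page one ..."}, 2: {"texte": "page two ..."}}
#   setup_rag("pdf", pdf_content, filename="doc.pdf",
#             namespace="enterprise_42", index=index, enterprise_name="ACME")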


def prompt_reformatting(prompt: str, context, query: str, style="formel", tonality="neutre", enterprise_name=""):
    '''Fill the prompt template, replacing the raw context with the list of distinct source file names.'''
    if context == "":
        print("no context found for prompt reformatting")
        return prompt.format(context="Pas de contexte pertinent", messages="", query=query, style=style, tonality=tonality, enterprise=enterprise_name)
    docs_names = []
    print("context found for prompt reformatting")
    for chunk in context:
        print(chunk.metadata)
        chunk_name = chunk.metadata["filename"]
        if chunk_name not in docs_names:
            docs_names.append(chunk_name)
    context = ", ".join(docs_names)
    prompt = prompt.format(context=context, messages="", query=query, style=style, tonality=tonality, enterprise=enterprise_name)
    return prompt
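
# Illustrative usage (sketch), chained with the retriever above: the formatted
# prompt cites only the distinct source file names, not the chunk contents.
#
#   docs = get_retreive_answer("enterprise_42", query, index, common_id="")
#   final_prompt = prompt_reformatting(
#       "Contexte: {context}. Historique: {messages}. Question: {query}. "
#       "Style: {style}, tonalité: {tonality}, entreprise: {enterprise}.",
#       docs, query, enterprise_name="ACME")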