# NOTE: "Spaces: Sleeping" is Hugging Face Spaces page-status residue from extraction, not source code.
| from fastapi import FastAPI, HTTPException, UploadFile, File,Request,Depends,status | |
| from fastapi.security import OAuth2PasswordBearer | |
| from pydantic import BaseModel, Json | |
| from uuid import uuid4, UUID | |
| from typing import Optional | |
| import pymupdf | |
| from pinecone import Pinecone, ServerlessSpec | |
| import os | |
| from dotenv import load_dotenv | |
| from rag import * | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| from prompts import * | |
| from typing import Literal | |
| from models import * | |
| from fastapi.middleware.cors import CORSMiddleware | |
# --- Environment & Pinecone setup (runs once at import time) ---
load_dotenv()
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
common_namespace = os.environ.get("COMMON_NAMESPACE")  # shared namespace queried alongside each enterprise's own
pc = Pinecone(api_key=pinecone_api_key)
import time
index_name = os.environ.get("INDEX_NAME") # change if desired
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
# Create the serverless index on first run and block until it is ready.
# 3072 dimensions — presumably matches the embedding model used in rag.py (TODO confirm).
if index_name not in existing_indexes:
    pc.create_index(
        name=index_name,
        dimension=3072,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)
index = pc.Index(index_name)
# Single static API key accepted as a bearer token by api_key_auth below.
api_keys = [os.environ.get("FASTAPI_API_KEY")]
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")  # use token authentication
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Bearer-token dependency: reject requests whose token is not a configured key.

    Raises:
        HTTPException: 401 when the presented token is not in ``api_keys``.
    """
    if api_key not in api_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            # Bug fix: the detail said "Forbidden" (the 403 reason phrase) while the
            # status is 401. Use a 401-appropriate message and advertise the expected
            # auth scheme per RFC 6750 / FastAPI security docs.
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
# Every route requires the bearer-token check above (no public endpoints).
app = FastAPI(dependencies=[Depends(api_key_auth)])
# Legacy header-based auth middleware, superseded by the api_key_auth dependency;
# kept commented for reference.
# FASTAPI_KEY_NAME = os.environ.get("FASTAPI_KEY_NAME")
# FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")
# @app.middleware("http")
# async def api_key_middleware(request: Request, call_next):
#     if request.url.path not in ["/","/docs","/openapi.json"]:
#         api_key = request.headers.get(FASTAPI_KEY_NAME)
#         if api_key != FASTAPI_API_KEY:
#             raise HTTPException(status_code=403, detail="invalid API key :/")
#     response = await call_next(request)
#     return response
class StyleWriter(BaseModel):
    """Optional style/tonality preferences applied when formatting the answer prompt."""
    style: Optional[str] = "neutral"
    tonality: Optional[str] = "formal"
# Model names exposed via get_models. NOTE(review): UserInput.model also accepts
# "o1-preview", which is absent here — confirm whether that is intentional.
models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
class UserInput(BaseModel):
    """Request payload for the answer-generation endpoints."""
    prompt: str                         # user question / instruction
    enterprise_id: str                  # Pinecone namespace holding the enterprise's documents
    user_id: Optional[str] = None       # scopes per-user knowledge-base ids ("kb_<user_id>_...")
    stream: Optional[bool] = False      # True -> StreamingResponse instead of a JSON body
    messages: Optional[list[dict]] = [] # prior chat history (pydantic copies the default per instance)
    style_tonality: Optional[StyleWriter] = None
    marque: Optional[str] = None        # brand / enterprise display name injected into the prompt
    model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest","o1-preview"] = "gpt-4o"
class EnterpriseData(BaseModel):
    """Metadata accompanying a PDF upload (sent as a JSON-encoded form field)."""
    name: str                       # enterprise display name; sanitized into the namespace id
    id: Optional[str] = None        # Pinecone namespace id; generated from name + uuid4 when absent
    filename: Optional[str] = None  # overrides the uploaded file's own name when set
# NOTE(review): `tasks` is never used in this file — possibly dead code or consumed elsewhere.
tasks = []
def greet_json():
    """Simple health-check greeting for the root endpoint."""
    greeting = {"Hello": "World!"}
    return greeting
async def upload_file(file: UploadFile, enterprise_data: Json[EnterpriseData]):
    """Ingest a PDF: extract its text, chunk it, and index the chunks in Pinecone.

    Args:
        file: the uploaded PDF.
        enterprise_data: JSON-encoded metadata; a namespace id is generated when absent.

    Returns:
        dict with the file name, enterprise id/name, chunk count and filename_id.

    Raises:
        HTTPException: 500 when the vector store cannot be created or on any other failure.
    """
    try:
        # Read the uploaded file
        contents = await file.read()
        # Single-pass sanitization of separators so the name is safe inside a Pinecone id
        # (replaces the chain of five .replace() calls).
        enterprise_name = enterprise_data.name.translate(
            str.maketrans({" ": "_", "-": "_", ".": "_", "/": "_", "\\": "_"})
        ).strip()
        if enterprise_data.filename is not None:
            filename = enterprise_data.filename
        else:
            filename = file.filename
        # Assign a new UUID-based namespace id if none is provided
        if enterprise_data.id is None:
            clean_name = remove_non_standard_ascii(enterprise_name)
            enterprise_data.id = f"{clean_name}_{uuid4()}"
        # Open the file with PyMuPDF and extract all text
        pdf_document = pymupdf.open(stream=contents, filetype="pdf")
        text = "".join(page.get_text() for page in pdf_document)
        # Split the text into chunks and index them
        text_chunks = get_text_chunks(text)
        vector_store = get_vectorstore(text_chunks, filename=filename, file_type="pdf",
                                       namespace=enterprise_data.id, index=index,
                                       enterprise_name=enterprise_name)
        if not vector_store:
            raise HTTPException(status_code=500, detail="Could not create vector store")
        return {
            "file_name": filename,
            "enterprise_id": enterprise_data.id,
            "number_of_chunks": len(text_chunks),
            "filename_id": vector_store["filename_id"],
            "enterprise_name": enterprise_name,
        }
    except HTTPException:
        # Bug fix: the generic handler below used to re-wrap the 500 above,
        # mangling its detail into "An error occurred: 500: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
    finally:
        await file.close()
def get_documents(enterprise_id: str):
    """List distinct document names in an enterprise namespace.

    Vector ids look like "<doc_name>_<suffix>"; the trailing segment is stripped and
    names are de-duplicated preserving first-seen order.

    NOTE(review): a 2-arg `get_documents` is defined later in this file; without route
    decorators the later definition shadows this one — confirm decorators were lost.

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        docs_names = []
        seen = set()  # O(1) membership instead of `not in list` (was O(n) per id)
        for id_batch in index.list(namespace=enterprise_id):
            for vector_id in id_batch:  # renamed: `id` shadowed the builtin
                name_doc = "_".join(vector_id.split("_")[:-1])
                if name_doc not in seen:
                    seen.add(name_doc)
                    docs_names.append(name_doc)
        return docs_names
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def get_documents(enterprise_id: str, user_id: str):
    """List every knowledge-base vector id ("kb_<user_id>_...") for a user."""
    try:
        collected = []
        for batch in index.list(prefix=f"kb_{user_id}_", namespace=enterprise_id):
            collected.extend(batch)
        return collected
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(exc)}")
def delete_document(enterprise_id: str, filename_id: str):
    """Delete every chunk of a document (ids prefixed "<filename_id>_") from the namespace.

    Returns:
        dict with a confirmation message and the full list of deleted ids.

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        deleted_ids = []
        for ids in index.list(prefix=f"{filename_id}_", namespace=enterprise_id):
            index.delete(ids=ids, namespace=enterprise_id)
            deleted_ids.extend(ids)
        # Bug fix: `ids` was referenced after the loop — a NameError (masked as a 500)
        # when no id matched, and only the last page was reported when several matched.
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def delete_document(enterprise_id: str, user_id: str):
    """Delete all of a user's knowledge-base entries ("kb_<user_id>_..." ids).

    NOTE(review): same name as the function above — without route decorators this
    redefinition shadows it; confirm the original decorators were lost.

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        deleted_ids = []
        for ids in index.list(prefix=f"kb_{user_id}_", namespace=enterprise_id):
            index.delete(ids=ids, namespace=enterprise_id)
            deleted_ids.extend(ids)
        # Bug fix: `ids` was referenced after the loop — NameError when nothing matched,
        # and only the last page of ids was reported.
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def delete_document(enterprise_id: str, user_id: str, info_id: str):
    """Delete knowledge-base entries whose ids start with `info_id` and report them.

    NOTE(review): `user_id` is unused here; the prefix is `info_id` alone — confirm
    the kb id already embeds the user id.

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        all_ids = []
        for ids in index.list(prefix=f"{info_id}", namespace=enterprise_id):
            # Bug fix: the extend() was commented out, so "chunks_deleted" was always
            # empty; the debug print is removed.
            all_ids.extend(ids)
            index.delete(ids=ids, namespace=enterprise_id)
        return {"message": "Document deleted", "chunks_deleted": all_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def delete_all_documents(enterprise_id: str):
    """Drop every vector stored in the enterprise's namespace."""
    try:
        index.delete(namespace=enterprise_id, delete_all=True)
        return {"message": "All documents deleted"}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(exc)}")
# Streaming support: async_timeout bounds how long a model may keep the stream open.
import async_timeout
import asyncio
GENERATION_TIMEOUT_SEC = 60  # max seconds for the whole streamed generation
async def stream_generator(response, prompt, info_memoire):
    """Wrap each model chunk in a JSON envelope carrying the prompt and memory info."""
    async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
        try:
            async for chunk in response:
                text = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk
                envelope = {"prompt": prompt, "content": text, "info_memoire": info_memoire}
                yield json.dumps(envelope)
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail="Stream timed out")
def generate_answer(user_input: UserInput):
    """RAG answer flow: retrieve context, optionally memorize new facts, format the
    prompt, then generate a (possibly streamed) answer.

    Returns:
        StreamingResponse when user_input.stream is True; otherwise a dict with the
        formatted prompt, the answer, the retrieved context and any KB additions.

    Raises:
        HTTPException: 500 on any failure (wraps everything, including downstream
        HTTPExceptions — NOTE(review): consider re-raising those untouched).
    """
    try:
        print(user_input)
        prompt = user_input.prompt
        enterprise_id = user_input.enterprise_id
        template_prompt = base_template
        # Retrieve supporting chunks from the enterprise namespace + the common one.
        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)
        #final_prompt_simplified = prompt_formatting(prompt,template,context)
        # Let the model decide whether the prompt contains a fact worth memorizing.
        infos_added_to_kb = handle_calling_add_to_knowledge_base(prompt,enterprise_id,index,getattr(user_input,"marque",""),user_id=getattr(user_input,"user_id",""))
        if infos_added_to_kb:
            # Surface the stored fact back to the model inside the prompt (French text).
            prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
        else:
            infos_added_to_kb = {}
        if not context:
            context = ""
        if user_input.style_tonality is None:
            # No style preferences: format and generate with defaults.
            prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
            answer = generate_response_via_langchain(prompt,
                                                     model=getattr(user_input,"model","gpt-4o"),
                                                     stream=user_input.stream,context = context ,
                                                     messages=user_input.messages,
                                                     template=template_prompt,
                                                     enterprise_name=getattr(user_input,"marque",""),
                                                     enterprise_id=enterprise_id,
                                                     index=index)
        else:
            # Style preferences supplied: thread style/tonality through both calls.
            prompt_formated = prompt_reformatting(template_prompt,
                                                  context,
                                                  prompt,
                                                  style=getattr(user_input.style_tonality,"style","neutral"),
                                                  tonality=getattr(user_input.style_tonality,"tonality","formal"),
                                                  enterprise_name=getattr(user_input,"marque",""))
            answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),
                                                     stream=user_input.stream,context = context ,
                                                     messages=user_input.messages,
                                                     style=getattr(user_input.style_tonality,"style","neutral"),
                                                     tonality=getattr(user_input.style_tonality,"tonality","formal"),
                                                     template=template_prompt,
                                                     enterprise_name=getattr(user_input,"marque",""),
                                                     enterprise_id=enterprise_id,
                                                     index=index)
        if user_input.stream:
            # Each chunk is wrapped in a JSON envelope by stream_generator.
            return StreamingResponse(stream_generator(answer,prompt_formated,infos_added_to_kb), media_type="application/json")
        return {
            "prompt": prompt_formated,
            "answer": answer,
            "context": context,
            "info_memoire": infos_added_to_kb
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
async def stream_generator2(response, prompt, info_memoire):
    # Raw pass-through streamer: forwards bytes chunks unchanged (no JSON envelope).
    # NOTE(review): non-bytes chunks (e.g. str) are silently dropped — confirm the
    # upstream generator always yields bytes, otherwise the response can be empty.
    # `prompt` and `info_memoire` are unused; kept for signature parity with stream_generator.
    async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
        try:
            async for chunk in response:
                if isinstance(chunk, bytes):
                    yield chunk
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail="Stream timed out")
def generate_answer2(user_input: UserInput):
    """Same RAG answer flow as generate_answer, but streams raw chunks via
    stream_generator2 instead of JSON envelopes.

    NOTE(review): the body duplicates generate_answer except for the streamer —
    candidate for sharing a helper once decorators/callers are confirmed.

    Returns:
        StreamingResponse when user_input.stream is True; otherwise a dict with the
        formatted prompt, the answer, the retrieved context and any KB additions.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        print(user_input)
        prompt = user_input.prompt
        enterprise_id = user_input.enterprise_id
        template_prompt = base_template
        # Retrieve supporting chunks from the enterprise namespace + the common one.
        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)
        #final_prompt_simplified = prompt_formatting(prompt,template,context)
        # Let the model decide whether the prompt contains a fact worth memorizing.
        infos_added_to_kb = handle_calling_add_to_knowledge_base(prompt,enterprise_id,index,getattr(user_input,"marque",""),user_id=getattr(user_input,"user_id",""))
        if infos_added_to_kb:
            # Surface the stored fact back to the model inside the prompt (French text).
            prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
        else:
            infos_added_to_kb = {}
        if not context:
            context = ""
        if user_input.style_tonality is None:
            # No style preferences: format and generate with defaults.
            prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
            answer = generate_response_via_langchain(prompt,
                                                     model=getattr(user_input,"model","gpt-4o"),
                                                     stream=user_input.stream,context = context ,
                                                     messages=user_input.messages,
                                                     template=template_prompt,
                                                     enterprise_name=getattr(user_input,"marque",""),
                                                     enterprise_id=enterprise_id,
                                                     index=index)
        else:
            # Style preferences supplied: thread style/tonality through both calls.
            prompt_formated = prompt_reformatting(template_prompt,
                                                  context,
                                                  prompt,
                                                  style=getattr(user_input.style_tonality,"style","neutral"),
                                                  tonality=getattr(user_input.style_tonality,"tonality","formal"),
                                                  enterprise_name=getattr(user_input,"marque",""))
            answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),
                                                     stream=user_input.stream,context = context ,
                                                     messages=user_input.messages,
                                                     style=getattr(user_input.style_tonality,"style","neutral"),
                                                     tonality=getattr(user_input.style_tonality,"tonality","formal"),
                                                     template=template_prompt,
                                                     enterprise_name=getattr(user_input,"marque",""),
                                                     enterprise_id=enterprise_id,
                                                     index=index)
        if user_input.stream:
            # Raw byte pass-through (no JSON envelope) via stream_generator2.
            return StreamingResponse(stream_generator2(answer,prompt_formated,infos_added_to_kb), media_type="application/json")
        return {
            "prompt": prompt_formated,
            "answer": answer,
            "context": context,
            "info_memoire": infos_added_to_kb
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def get_models():
    """Return the selectable chat model names."""
    available = models
    return {"models": available}