import os
import asyncio
import logging
import signal

from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from langdetect import detect
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
from langchain.vectorstores import Qdrant
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.callbacks.base import BaseCallbackHandler
from qdrant_client import QdrantClient

# Module-level logger used by the route handlers below
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Get environment variables
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
COLLECTION_NAME = "arabic_rag_collection"

# Load model and tokenizer
model_name = "FreedomIntelligence/Apollo-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
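# Optional sketch (not part of the original flow): a 7B checkpoint is heavy to hold
# in full precision, so a half-precision / auto-device load is a common alternative.
# This assumes torch and accelerate are available; the kwargs below are illustrative.
# import torch
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16, device_map="auto"
# )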

# Generation settings
generation_config = GenerationConfig(
    max_new_tokens=150,
    temperature=0.2,
    top_k=20,
    do_sample=True,
    top_p=0.7,
    repetition_penalty=1.3,
)

# Text generation pipeline
llm_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    generation_config=generation_config,
    device=model.device.index if model.device.type == "cuda" else -1
)
llm = HuggingFacePipeline(pipeline=llm_pipeline)

# Connect to Qdrant + embedding
embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
vector_store = Qdrant(
    client=qdrant_client,
    collection_name=COLLECTION_NAME,
    embeddings=embedding
)
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
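# Quick smoke test (commented out, illustrative only): the retriever can be
# exercised on its own before it is wired into the chain; the query string
# reuses the example from the Query schema below.
# docs = retriever.get_relevant_documents("ما هي اسباب تساقط الشعر ؟")
# print([d.page_content[:100] for d in docs])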

# Set up RAG QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff"
)
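# Note: the "stuff" chain type simply concatenates the k retrieved documents into
# a single prompt for the LLM, so the retrieved context must fit in the model's
# context window.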

# FastAPI setup
app = FastAPI(title="Apollo RAG Medical Chatbot")

class Query(BaseModel):
    question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)

class TimeoutCallback(BaseCallbackHandler):
    def __init__(self, timeout_seconds: int = 60):
        self.timeout_seconds = timeout_seconds
        self.start_time = None

    async def on_llm_start(self, *args, **kwargs):
        self.start_time = asyncio.get_event_loop().time()

    async def on_llm_new_token(self, *args, **kwargs):
        if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
            raise TimeoutError("LLM processing timeout")
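# Note: an instance of TimeoutCallback is created in the /ask handler below, but the
# call that would pass it to qa_chain is currently commented out, so the effective
# timeout is the asyncio.wait_for() limit.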

# Prompt template: choose an Arabic or English instruction based on the detected
# language of the question. The Arabic template asks for a precise, detailed answer
# in Modern Standard Arabic, without repeating any point, phrase, or word, with
# clear and fluent wording and no filler, falling back to prior medical knowledge
# when the retrieved context is insufficient.
def generate_prompt(question: str) -> str:
    lang = detect(question)
    if lang == "ar":
        return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
وتأكد من ان:
- عدم تكرار أي نقطة أو عبارة أو كلمة
- وضوح وسلاسة كل نقطة
- تجنب الحشو والعبارات الزائدة
السؤال: {question}
الإجابة:"""
    else:
        return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If the context lacks information, rely on prior medical knowledge.
Question: {question}
Answer:"""

# Input schema
# class ChatRequest(BaseModel):
#     message: str

# # Output endpoint
# @app.post("/chat")
# def chat_rag(req: ChatRequest):
#     prompt = generate_prompt(req.message)
#     response = qa_chain.run(prompt)
#     return {"response": response}

# === ROUTES === #
@app.get("/")
async def root():
    return {"message": "Medical QA API is running!"}

@app.post("/ask")
async def ask(query: Query):
    try:
        logger.debug(f"Received question: {query.question}")
        prompt = generate_prompt(query.question)
        timeout_callback = TimeoutCallback(timeout_seconds=60)

        loop = asyncio.get_event_loop()
        answer = await asyncio.wait_for(
            # qa_chain.run(prompt, callbacks=[timeout_callback]),
            loop.run_in_executor(None, qa_chain.run, prompt),
            timeout=360
        )

        if not answer:
            raise ValueError("Empty answer returned from model")

        # Strip everything before the answer marker, whichever language was used
        if 'Answer:' in answer:
            response_text = answer.split('Answer:')[-1].strip()
        elif 'الإجابة:' in answer:
            response_text = answer.split('الإجابة:')[-1].strip()
        else:
            response_text = answer.strip()

        return {
            "status": "success",
            "response": response_text,
            "language": detect(query.question)
        }

    except TimeoutError as te:
        logger.error("Request timed out", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail={"status": "error", "message": "Request timed out", "error": str(te)}
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"status": "error", "message": "Internal server error", "error": str(e)}
        )

# === ENTRYPOINT === #
if __name__ == "__main__":
    def handle_exit(signum, frame):
        print("Shutting down gracefully...")
        exit(0)

    signal.signal(signal.SIGINT, handle_exit)

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
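
# Example client call (illustrative, commented out); assumes the service is
# reachable on localhost:8000 and that the requests package is installed.
# import requests
# resp = requests.post(
#     "http://localhost:8000/ask",
#     json={"question": "ما هي اسباب تساقط الشعر ؟"},
# )
# print(resp.json())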