"""Somali question-answering HTTP API.

Each question is answered in up to three stages:

1. an exact match against the question/answer pairs loaded from
   ``data/gpt2_ready_filtered.jsonl``;
2. a semantic (sentence-embedding) match against the same entries;
3. free-form generation with the ``nurfarah57/SQ-MT5`` model.
"""

import os

# Point the Hugging Face caches at /tmp. This must happen before the
# transformers / sentence_transformers imports below so the setting takes
# effect. /tmp is presumably the only writable location in the deployment
# environment — TODO confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"

import json

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

DATA_PATH = "data/gpt2_ready_filtered.jsonl"

# Load dataset: one JSON object per line; only its "text" field is used.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. stray empty lines at the end of the file),
    # which json.loads would reject.
    data = [json.loads(line) for line in f if line.strip()]
texts = [item["text"] for item in data]


class SomaliQA:
    """Answer Somali questions from a Q/A dataset, falling back to MT5.

    Dataset entries that ``extract_qa`` can parse have the form
    ``"Su'aal: <question>\\nJawaab: <answer>"``. Entries that do not parse
    are still embedded for semantic search, but never produce an answer.
    """

    def __init__(self, dataset_texts, min_score=None):
        """Load the models and index *dataset_texts*.

        Args:
            dataset_texts: Sequence of dataset entry strings.
            min_score: Optional minimum similarity score (as reported by
                ``util.semantic_search``) that the best semantic hit needs
                before it is returned. ``None`` (the default) accepts the best
                hit whatever its score. When a threshold is set, weaker hits
                fall through to MT5 generation.
        """
        self.texts = dataset_texts
        self.min_score = min_score

        self.embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        # NOTE(review): whole entries (question + answer) are embedded, so a
        # user question is compared against full Q/A texts, not questions only.
        self.embeddings = self.embedder.encode(self.texts, convert_to_tensor=True)

        # Build the exact-match index once rather than re-parsing every entry
        # per request. setdefault keeps the FIRST entry for a repeated
        # question, matching a first-hit linear scan.
        self._exact_answers = {}
        for text in self.texts:
            su_aal, jawaab = self.extract_qa(text)
            if su_aal:
                self._exact_answers.setdefault(self.clean_text(su_aal), jawaab)

        self.tokenizer = MT5Tokenizer.from_pretrained("nurfarah57/SQ-MT5")
        self.model = MT5ForConditionalGeneration.from_pretrained("nurfarah57/SQ-MT5")
        self.model.eval()

    def extract_qa(self, text: str):
        """Split an entry into ``(question, answer)``.

        Returns ``(None, None)`` unless *text* contains exactly one
        ``"\\nJawaab:"`` separator. The ``"Su'aal:"`` label is removed from
        the question part; both parts are whitespace-stripped.
        """
        parts = text.split("\nJawaab:")
        if len(parts) == 2:
            return parts[0].replace("Su'aal:", "").strip(), parts[1].strip()
        return None, None

    def clean_text(self, text: str) -> str:
        """Normalize a question for exact-match comparison.

        The text is lower-cased and trailing ``?`` marks are dropped. Curly
        apostrophes (’) become straight ones ('). Every whitespace run,
        including non-breaking spaces, collapses to a single space, with no
        leading or trailing space left.
        """
        normalized = text.strip().lower().rstrip("?").replace("’", "'")
        # str.split() with no argument splits on any Unicode whitespace, so
        # this also removes the space a "word ?" ending leaves behind.
        return " ".join(normalized.split())

    def generate_with_mt5(self, question: str) -> str:
        """Generate an answer for *question* with the MT5 model (max 128 tokens)."""
        input_text = f"su'aal: {question}"
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=128)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def answer(self, user_question: str) -> dict:
        """Answer *user_question*.

        Returns:
            ``{"answer": <str>, "source": <"exact" | "semantic" | "generated">}``
            naming the stage that produced the answer.
        """
        if not user_question.strip().endswith("?"):
            user_question += "?"
        user_clean = self.clean_text(user_question)

        # Exact match
        if user_clean in self._exact_answers:
            return {"answer": self._exact_answers[user_clean], "source": "exact"}

        # Semantic match (skipped for an empty dataset: nothing to search).
        if self.texts:
            user_emb = self.embedder.encode(user_clean, convert_to_tensor=True)
            hits = util.semantic_search(user_emb, self.embeddings, top_k=1)
            if hits and hits[0]:
                best = hits[0][0]
                _, jawaab = self.extract_qa(self.texts[best["corpus_id"]])
                good_enough = self.min_score is None or best["score"] >= self.min_score
                # A hit that is not a Su'aal/Jawaab entry has no answer; fall
                # through to generation instead of returning None.
                if jawaab is not None and good_enough:
                    return {"answer": jawaab, "source": "semantic"}

        # Fallback to generation
        return {"answer": self.generate_with_mt5(user_question), "source": "generated"}


# Init model (loads both models and embeds the dataset at import time).
qa_system = SomaliQA(texts)

# FastAPI
app = FastAPI(
    title="Somali QA API",
    description="Weydii su’aal oo hel jawaab sax ah laga helay dataset ama MT5 generation.",
    version="1.0",
)


class QuestionRequest(BaseModel):
    """Request body for ``POST /qa``."""

    question: str


@app.get("/")
def root():
    """Liveness check."""
    return {"message": "✅ Somali QA API is running!"}


@app.post("/qa")
def get_answer(req: QuestionRequest):
    """Answer the posted question; respond 400 if it is empty or whitespace-only."""
    if not req.question.strip():
        raise HTTPException(status_code=400, detail="Su’aal lama helin.")
    return qa_system.answer(req.question)