Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp" | |
| os.environ["HF_HOME"] = "/tmp" | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import json | |
| import torch | |
| from transformers import MT5ForConditionalGeneration, MT5Tokenizer | |
| from sentence_transformers import SentenceTransformer, util | |
# Load dataset
# Every non-blank line of the JSONL file is parsed as one JSON object and its
# "text" field is collected into `texts`.
with open("data/gpt2_ready_filtered.jsonl", "r", encoding="utf-8") as f:
    # Skip whitespace-only lines: json.loads() raises on an empty string, so a
    # single stray blank line would otherwise abort start-up.
    data = [json.loads(line) for line in f if line.strip()]
texts = [item["text"] for item in data]
# SomaliQA class
class SomaliQA:
    """Answer questions from a question/answer corpus, with an MT5 fallback.

    :meth:`answer` tries, in order:

    1. an exact match on the normalised question text,
    2. the closest corpus entry by sentence-embedding similarity,
    3. free-form generation with the MT5 model.
    """

    # Markers used by the corpus entries to delimit question and answer.
    QUESTION_PREFIX = "Su'aal:"
    ANSWER_SEPARATOR = "\nJawaab:"

    def __init__(self, dataset_texts, semantic_threshold=None):
        """Index the corpus and load the embedding and generation models.

        Args:
            dataset_texts: Corpus entries; entries that contain exactly one
                "Jawaab:" line separator are usable as question/answer pairs.
            semantic_threshold: Optional minimum similarity score the best
                semantic hit must reach to be returned. ``None`` (the
                default) accepts the best hit whatever its score.
        """
        self.texts = dataset_texts
        self.semantic_threshold = semantic_threshold
        # Built once so each request is a dict lookup instead of re-parsing
        # and re-normalising every corpus entry.
        self._exact_answers = self._build_exact_index(self.texts)
        self.embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        self.embeddings = self.embedder.encode(self.texts, convert_to_tensor=True)
        self.tokenizer = MT5Tokenizer.from_pretrained("nurfarah57/SQ-MT5")
        self.model = MT5ForConditionalGeneration.from_pretrained("nurfarah57/SQ-MT5")
        self.model.eval()

    def _build_exact_index(self, corpus):
        """Map each normalised corpus question to its answer (first entry wins)."""
        index = {}
        for entry in corpus:
            question, reply = self.extract_qa(entry)
            if question:
                # setdefault keeps the earliest entry for duplicated questions.
                index.setdefault(self.clean_text(question), reply)
        return index

    def extract_qa(self, text):
        """Split a corpus entry into ``(question, answer)``.

        Returns ``(None, None)`` unless the entry contains the answer
        separator exactly once.
        """
        parts = text.split(self.ANSWER_SEPARATOR)
        if len(parts) != 2:
            return None, None
        question, reply = parts
        return question.replace(self.QUESTION_PREFIX, "").strip(), reply.strip()

    def clean_text(self, text):
        """Normalise a question for comparison.

        Lower-cases, maps the typographic apostrophe to ``'``, collapses every
        run of whitespace to one space and drops trailing question marks
        (together with any whitespace around them).
        """
        lowered = text.lower().replace("’", "'")
        # split()/join trims the ends and collapses inner whitespace runs.
        collapsed = " ".join(lowered.split())
        # Strip "?" and spaces together so "x ?" and "x?" normalise alike.
        return collapsed.rstrip("? ")

    def generate_with_mt5(self, question):
        """Generate an answer for ``question`` with the MT5 model."""
        input_text = f"su'aal: {question}"
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=128)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def answer(self, user_question):
        """Answer ``user_question``.

        Returns:
            A dict with keys ``"answer"`` and ``"source"``, where source is
            one of ``"exact"``, ``"semantic"`` or ``"generated"``.
        """
        # Trim first so the appended "?" is not separated by stray whitespace.
        user_question = user_question.strip()
        if not user_question.endswith("?"):
            user_question += "?"
        user_clean = self.clean_text(user_question)

        # Exact match
        exact = self._exact_answers.get(user_clean)
        if exact is not None:
            return {"answer": exact, "source": "exact"}

        # Semantic match
        user_emb = self.embedder.encode(user_clean, convert_to_tensor=True)
        hits = util.semantic_search(user_emb, self.embeddings, top_k=1)
        if hits and hits[0]:
            best = hits[0][0]
            threshold = self.semantic_threshold
            if threshold is None or best["score"] >= threshold:
                _, jawaab = self.extract_qa(self.texts[best["corpus_id"]])
                # An entry that cannot be parsed has no answer to return;
                # fall through to generation instead of answering None.
                if jawaab is not None:
                    return {"answer": jawaab, "source": "semantic"}

        # Fallback to generation
        return {"answer": self.generate_with_mt5(user_question), "source": "generated"}
# Init model
# Constructed once at import time; the constructor encodes the corpus and
# loads the tokenizer and model, so every request reuses the same instance.
qa_system = SomaliQA(dataset_texts=texts)

# FastAPI
# Application object carrying the API title, description and version.
app = FastAPI(
    title="Somali QA API",
    description="Weydii su’aal oo hel jawaab sax ah laga helay dataset ama MT5 generation.",
    version="1.0",
)
class QuestionRequest(BaseModel):
    """Request model holding the question to be answered."""

    # The question text as submitted by the caller.
    question: str
def root():
    """Return a fixed status message showing that the API is up."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file; confirm it is registered with `app`.
    status = {"message": "✅ Somali QA API is running!"}
    return status
def get_answer(req: QuestionRequest):
    """Answer the question carried by ``req``.

    Returns the result of ``qa_system.answer`` for the submitted question.

    Raises:
        HTTPException: status 400 when the question is empty or contains
            only whitespace.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file; confirm it is registered with `app`.
    question = req.question
    if question.strip():
        return qa_system.answer(question)
    raise HTTPException(status_code=400, detail="Su’aal lama helin.")