# inference_api.py
import os
import json

import fitz  # PyMuPDF
import fasttext
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from IndicTransToolkit.processor import IndicProcessor
import google.generativeai as genai
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
# === CONFIGURATION ===
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
OCR_MODEL_ID = "microsoft/trocr-base-printed"
# NLLB-style code emitted by the fastText language-identification model (Malayalam, Malayalam script).
LANGUAGE_TO_TRANSLATE = "mal_Mlym"
DEVICE = "cpu"

# --- Configure Gemini ---
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")
# --- Load Models ---
translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
).to(DEVICE)
ip = IndicProcessor(inference=True)
ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
lang_detect_model = fasttext.load_model(ft_model_path)
ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)
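
# Loaded components, for reference:
#   - translation_tokenizer / translation_model: IndicTrans2 Indic-to-English (1B), run on CPU in float32
#   - ip: IndicProcessor handling IndicTrans2 pre- and post-processing
#   - lang_detect_model: fastText language identification; predictions look like "__label__mal_Mlym"
#   - ocr_pipeline: TrOCR (printed text) image-to-text pipeline on CPU (device=-1)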
# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    response = model.generate_content([prompt, image])
    classification = response.text.strip().lower()
    return "diagram" if "diagram" in classification else "document"

def summarize_diagram_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Describe the contents of this technical diagram in a concise summary."
    response = model.generate_content([prompt, image])
    return response.text.strip()

def extract_text_from_image(path):
    image = Image.open(path).convert("RGB")
    image_type = classify_image_with_gemini(image)
    if image_type == "diagram":
        return summarize_diagram_with_gemini(image)
    else:
        out = ocr_pipeline(image)
        return out[0]["generated_text"] if out else ""
def extract_text_from_pdf(path):
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)

def read_text_from_txt(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def detect_language(text_snippet):
    s = text_snippet.replace("\n", " ").strip()
    if not s:
        return None
    # fastText returns labels such as "__label__mal_Mlym"; keep only the language code.
    preds = lang_detect_model.predict(s, k=1)
    return preds[0][0].split("__")[-1] if preds and preds[0] else None

def translate_chunk(chunk):
    batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
    inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
    decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return ip.postprocess_batch(decoded, lang="eng_Latn")[0]
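
# Illustrative usage of the two helpers above (values are placeholders, not real output):
#   detect_language("<Malayalam text>")  -> "mal_Mlym"
#   translate_chunk("<Malayalam text>")  -> "<English translation>"
# IndicTrans2 expects FLORES-style language codes, hence src_lang="mal_Mlym" and tgt_lang="eng_Latn".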
def generate_structured_json(text_to_analyze):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    json_schema = {
        "type": "OBJECT",
        "properties": {
            "summary": {"type": "STRING"},
            "actions_required": {
                "type": "ARRAY",
                "items": {
                    "type": "OBJECT",
                    "properties": {
                        "action": {"type": "STRING"},
                        "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]},
                        "deadline": {"type": "STRING"},
                        "notes": {"type": "STRING"},
                    },
                    "required": ["action", "priority", "deadline", "notes"],
                },
            },
            "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
            "cross_document_flags": {
                "type": "ARRAY",
                "items": {
                    "type": "OBJECT",
                    "properties": {
                        "related_document_type": {"type": "STRING"},
                        "related_issue": {"type": "STRING"},
                    },
                    "required": ["related_document_type", "related_issue"],
                },
            },
        },
        "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"],
    }
    generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
    response = model.generate_content(prompt, generation_config=generation_config)
    return json.loads(response.text)
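
# Because the response is constrained by json_schema, the parsed result has this shape
# (field values below are placeholders, not real model output):
#   {
#     "summary": "...",
#     "actions_required": [{"action": "...", "priority": "High", "deadline": "...", "notes": "..."}],
#     "departments_to_notify": ["..."],
#     "cross_document_flags": [{"related_document_type": "...", "related_issue": "..."}]
#   }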
def check_relevance_with_gemini(summary_text):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    response = model.generate_content(prompt)
    return "yes" in response.text.strip().lower()
# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    file_path: str

# Register the handler with FastAPI; the route path below is illustrative.
@app.post("/predict")
def predict(file: InputFile):
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}
    path = file.file_path
    ext = os.path.splitext(path)[1].lower()

    # Phase 1: Extract text
    if ext == ".pdf":
        original_text = extract_text_from_pdf(path)
    elif ext == ".txt":
        original_text = read_text_from_txt(path)
    elif ext in [".png", ".jpg", ".jpeg"]:
        original_text = extract_text_from_image(path)
    else:
        return {"error": "Unsupported file type."}

    # Phase 2: Translate Malayalam if detected
    lines = original_text.split("\n")
    translated_lines = []
    for ln in lines:
        if not ln.strip():
            continue
        lang = detect_language(ln)
        if lang == LANGUAGE_TO_TRANSLATE:
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    final_text = "\n".join(translated_lines)

    # Phase 3: Gemini analysis
    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}
    is_relevant = check_relevance_with_gemini(summary_data["summary"])
    if is_relevant:
        return summary_data
    else:
        return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}