# inference_api.py
import os
import fitz # PyMuPDF
import fasttext
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from IndicTransToolkit.processor import IndicProcessor
import google.generativeai as genai
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import json
app = FastAPI()
# === CONFIGURATION ===
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
OCR_MODEL_ID = "microsoft/trocr-base-printed"
LANGUAGE_TO_TRANSLATE = "mal_Mlym"  # FLORES-200 code, matching the labels the fastText LID model returns
DEVICE = "cpu"
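# Inference runs on CPU here; to use a GPU instead, set DEVICE = "cuda" and
# pass device=0 to the OCR pipeline() call below.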
# --- Configure Gemini ---
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")
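# Without a key the app still starts (so the container can boot), but
# /predict refuses requests; see the guard at the top of predict() below.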
# --- Load Models ---
translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
).to(DEVICE)
ip = IndicProcessor(inference=True)
ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
lang_detect_model = fasttext.load_model(ft_model_path)
ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)
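# All models are loaded once at import time, so requests pay no
# model-initialization cost; the first startup downloads the weights
# (the 1B-parameter IndicTrans2 checkpoint alone is several GB).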
# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    response = model.generate_content([prompt, image])
    classification = response.text.strip().lower()
    # Default to "document" when the reply is anything other than "diagram".
    return "diagram" if "diagram" in classification else "document"
def summarize_diagram_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Describe the contents of this technical diagram in a concise summary."
    response = model.generate_content([prompt, image])
    return response.text.strip()
def extract_text_from_image(path):
    image = Image.open(path).convert("RGB")
    image_type = classify_image_with_gemini(image)
    if image_type == "diagram":
        return summarize_diagram_with_gemini(image)
    else:
        out = ocr_pipeline(image)
        return out[0]["generated_text"] if out else ""
def extract_text_from_pdf(path):
    doc = fitz.open(path)
    return "".join(page.get_text("text") + "\n" for page in doc)
def read_text_from_txt(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
def detect_language(text_snippet):
    s = text_snippet.replace("\n", " ").strip()
    if not s:
        return None
    preds = lang_detect_model.predict(s, k=1)
    return preds[0][0].split("__")[-1] if preds and preds[0] else None
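# The fastText LID model returns FLORES-style labels, e.g.
# predict(...) -> (('__label__mal_Mlym',), ...); the split above therefore
# yields "mal_Mlym", which is the value LANGUAGE_TO_TRANSLATE must match.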
def translate_chunk(chunk):
    batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
    inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
    decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return ip.postprocess_batch(decoded, lang="eng_Latn")[0]
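# Translation is invoked line by line (see /predict below), which keeps each
# chunk comfortably under the 512-token truncation limit set above.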
def generate_structured_json(text_to_analyze):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    json_schema = {
        "type": "OBJECT",
        "properties": {
            "summary": {"type": "STRING"},
            "actions_required": {"type": "ARRAY", "items": {
                "type": "OBJECT",
                "properties": {"action": {"type": "STRING"}, "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]}, "deadline": {"type": "STRING"}, "notes": {"type": "STRING"}},
                "required": ["action", "priority", "deadline", "notes"]
            }},
            "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
            "cross_document_flags": {"type": "ARRAY", "items": {
                "type": "OBJECT",
                "properties": {"related_document_type": {"type": "STRING"}, "related_issue": {"type": "STRING"}},
                "required": ["related_document_type", "related_issue"]
            }}
        },
        "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"]
    }
    generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
    response = model.generate_content(prompt, generation_config=generation_config)
    return json.loads(response.text)
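# response_schema constrains Gemini to emit JSON conforming to json_schema,
# which is why the raw response text can be passed straight to json.loads.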
def check_relevance_with_gemini(summary_text):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    response = model.generate_content(prompt)
    return "yes" in response.text.strip().lower()
# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    file_path: str
@app.post("/predict")
def predict(file: InputFile):
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}
    path = file.file_path
    ext = os.path.splitext(path)[1].lower()
    # Phase 1: Extract text
    if ext == ".pdf":
        original_text = extract_text_from_pdf(path)
    elif ext == ".txt":
        original_text = read_text_from_txt(path)
    elif ext in [".png", ".jpg", ".jpeg"]:
        original_text = extract_text_from_image(path)
    else:
        return {"error": "Unsupported file type."}
    # Phase 2: Translate Malayalam if detected
    lines = original_text.split("\n")
    translated_lines = []
    for ln in lines:
        if not ln.strip():
            continue
        lang = detect_language(ln)
        if lang == LANGUAGE_TO_TRANSLATE:
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    final_text = "\n".join(translated_lines)
    # Phase 3: Gemini analysis
    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}
    is_relevant = check_relevance_with_gemini(summary_data["summary"])
    if is_relevant:
        return summary_data
    else:
        return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}