Sripriya16 committed
Commit 8232535 · verified · 1 Parent(s): 6d0d98c

Delete inference_api.py

Files changed (1): inference_api.py (+0, -161)
inference_api.py DELETED
@@ -1,161 +0,0 @@
# inference_api.py
import os
import fitz  # PyMuPDF
import fasttext
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from IndicTransToolkit.processor import IndicProcessor
import google.generativeai as genai
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import json

app = FastAPI()

# === CONFIGURATION ===
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
OCR_MODEL_ID = "microsoft/trocr-base-printed"
LANGUAGE_TO_TRANSLATE = "mal"
DEVICE = "cpu"

# --- Configure Gemini ---
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")

# --- Load Models ---
translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
).to(DEVICE)
ip = IndicProcessor(inference=True)

ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
lang_detect_model = fasttext.load_model(ft_model_path)

ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)

# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    response = model.generate_content([prompt, image])
    classification = response.text.strip().lower()
    return "diagram" if "diagram" in classification else "document"

def summarize_diagram_with_gemini(image: Image.Image):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Describe the contents of this technical diagram in a concise summary."
    response = model.generate_content([prompt, image])
    return response.text.strip()

def extract_text_from_image(path):
    image = Image.open(path).convert("RGB")
    image_type = classify_image_with_gemini(image)
    if image_type == "diagram":
        return summarize_diagram_with_gemini(image)
    else:
        out = ocr_pipeline(image)
        return out[0]["generated_text"] if out else ""

def extract_text_from_pdf(path):
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)

def read_text_from_txt(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def detect_language(text_snippet):
    s = text_snippet.replace("\n", " ").strip()
    if not s:
        return None
    preds = lang_detect_model.predict(s, k=1)
    if not preds or not preds[0]:
        return None
    # LID labels look like "__label__mal_Mlym"; keep only the language code
    # ("mal") so the result can be compared against LANGUAGE_TO_TRANSLATE.
    return preds[0][0].split("__")[-1].split("_")[0]

def translate_chunk(chunk):
    batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
    inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
    decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return ip.postprocess_batch(decoded, lang="eng_Latn")[0]

def generate_structured_json(text_to_analyze):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    json_schema = {
        "type": "OBJECT",
        "properties": {
            "summary": {"type": "STRING"},
            "actions_required": {"type": "ARRAY", "items": {
                "type": "OBJECT",
                "properties": {"action": {"type": "STRING"}, "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]}, "deadline": {"type": "STRING"}, "notes": {"type": "STRING"}},
                "required": ["action", "priority", "deadline", "notes"]
            }},
            "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
            "cross_document_flags": {"type": "ARRAY", "items": {
                "type": "OBJECT",
                "properties": {"related_document_type": {"type": "STRING"}, "related_issue": {"type": "STRING"}},
                "required": ["related_document_type", "related_issue"]
            }}
        },
        "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"]
    }
    generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
    response = model.generate_content(prompt, generation_config=generation_config)
    return json.loads(response.text)

def check_relevance_with_gemini(summary_text):
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    response = model.generate_content(prompt)
    return "yes" in response.text.strip().lower()

# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    file_path: str

@app.post("/predict")
def predict(file: InputFile):
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}
    path = file.file_path
    ext = os.path.splitext(path)[1].lower()

    # Phase 1: Extract text
    if ext == ".pdf":
        original_text = extract_text_from_pdf(path)
    elif ext == ".txt":
        original_text = read_text_from_txt(path)
    elif ext in [".png", ".jpg", ".jpeg"]:
        original_text = extract_text_from_image(path)
    else:
        return {"error": "Unsupported file type."}

    # Phase 2: Translate Malayalam if detected
    lines = original_text.split("\n")
    translated_lines = []
    for ln in lines:
        if not ln.strip():
            continue
        lang = detect_language(ln)
        if lang == LANGUAGE_TO_TRANSLATE:
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    final_text = "\n".join(translated_lines)

    # Phase 3: Gemini analysis
    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}

    is_relevant = check_relevance_with_gemini(summary_data["summary"])
    if is_relevant:
        return summary_data
    else:
        return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}