prithivMLmods committed
Commit 41993a8 · verified · 1 Parent(s): 3e612b3

Upload app.py

Files changed (1)
app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
import os
import sys
from typing import Iterable, Optional, Tuple, Dict, Any, List
import hashlib
import spaces
import re
import time
import click
import gradio as gr
from io import BytesIO
from PIL import Image
from loguru import logger
from pathlib import Path
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.image_utils import load_image
import fitz
import html2text
import markdown
import tempfile

from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

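# Dependency note: gradio provides the web UI, transformers loads the three
# Qwen2.5-VL checkpoints, fitz (PyMuPDF) rasterizes PDF pages, and
# html2text / markdown convert the models' HTML output into Markdown source
# and a rendered preview for display.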
# --- Theme and CSS Definition ---

colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
    c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
    c800="#2E5378", c900="#264364", c950="#1E3450",
)

class SteelBlueTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
            text_size=text_size, font=font, font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

steel_blue_theme = SteelBlueTheme()
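# The theme subclasses gradio's Soft theme, registers a custom "steel_blue"
# palette, and overrides background, button, and block styles; it is purely
# cosmetic and independent of the parsing logic below.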

# --- Model and App Logic ---

pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
device = "cuda" if torch.cuda.is_available() else "cpu"

logger.info(f"Using device: {device}")

# Model 1: Logics-Parsing
MODEL_ID_1 = "Logics-MLLM/Logics-Parsing"
logger.info(f"Loading model 1: {MODEL_ID_1}")
processor_1 = AutoProcessor.from_pretrained(MODEL_ID_1, trust_remote_code=True)
model_1 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_1,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device).eval()
logger.info(f"Model '{MODEL_ID_1}' loaded successfully.")

# Model 2: Gliese-OCR-7B-Post1.0
MODEL_ID_2 = "prithivMLmods/Gliese-OCR-7B-Post1.0"
logger.info(f"Loading model 2: {MODEL_ID_2}")
processor_2 = AutoProcessor.from_pretrained(MODEL_ID_2, trust_remote_code=True)
model_2 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_2,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device).eval()
logger.info(f"Model '{MODEL_ID_2}' loaded successfully.")

# Model 3: olmOCR-7B-0825
MODEL_ID_3 = "allenai/olmOCR-7B-0825"
logger.info(f"Loading model 3: {MODEL_ID_3}")
processor_3 = AutoProcessor.from_pretrained(MODEL_ID_3, trust_remote_code=True)
model_3 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_3,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device).eval()
logger.info(f"Model '{MODEL_ID_3}' loaded successfully.")
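# All three checkpoints share the Qwen2.5-VL architecture, so the same
# processor/model loading pattern is reused: fp16 on CUDA, fp32 on CPU.
# Keeping all three resident in memory lets the model dropdown switch
# checkpoints without reloading weights.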

@spaces.GPU
def parse_page(image: Image.Image, model_name: str) -> str:
    if model_name == "Logics-Parsing":
        current_processor, current_model = processor_1, model_1
    elif model_name == "Gliese-OCR-7B-Post1.0":
        current_processor, current_model = processor_2, model_2
    elif model_name == "olmOCR-7B-0825":
        current_processor, current_model = processor_3, model_3
    else:
        raise ValueError(f"Unknown model choice: {model_name}")

    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Parse this document page into a clean, structured HTML representation. Preserve the logical structure with appropriate tags for content blocks such as paragraphs (<p>), headings (<h1>-<h6>), tables (<table>), figures (<figure>), formulas (<formula>), and others. Include category tags, and filter out irrelevant elements like headers and footers."}]}]
    prompt_full = current_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = current_processor(text=prompt_full, images=[image.convert("RGB")], return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = current_model.generate(**inputs, max_new_tokens=2048, do_sample=False)

    generated_ids = generated_ids[:, inputs['input_ids'].shape[1]:]
    output_text = current_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return output_text
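# parse_page slices off the prompt tokens before decoding, so only the newly
# generated HTML is returned. Standalone usage sketch (assumes a local image
# file "sample_page.png"; not executed by the app):
#     page_img = Image.open("sample_page.png").convert("RGB")
#     html_out = parse_page(page_img, "Logics-Parsing")
#     print(html_out)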

def convert_file_to_images(file_path: str, dpi: int = 200) -> List[Image.Image]:
    images = []
    file_ext = Path(file_path).suffix.lower()

    if file_ext in image_suffixes:
        images.append(Image.open(file_path).convert("RGB"))
        return images

    if file_ext not in pdf_suffixes:
        raise ValueError(f"Unsupported file type: {file_ext}")

    try:
        pdf_document = fitz.open(file_path)
        zoom = dpi / 72.0
        mat = fitz.Matrix(zoom, zoom)
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            pix = page.get_pixmap(matrix=mat)
            img_data = pix.tobytes("png")
            images.append(Image.open(BytesIO(img_data)).convert("RGB"))
        pdf_document.close()
    except Exception as e:
        logger.error(f"Failed to convert PDF using PyMuPDF: {e}")
        raise
    return images
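# PDF pages are laid out in points (72 per inch), so zoom = dpi / 72 scales the
# render matrix to the requested resolution; at the default dpi=200 a US-Letter
# page rasterizes to roughly 1700 x 2200 pixels.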

def get_initial_state() -> Dict[str, Any]:
    return {"pages": [], "total_pages": 0, "current_page_index": 0, "page_results": []}
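# Per-session state kept in gr.State: "pages" holds the PIL page images,
# "page_results" holds one {"raw_html": ...} dict per processed page, and
# "current_page_index" / "total_pages" drive the pager buttons.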

def load_and_preview_file(file_path: Optional[str]) -> Tuple[Optional[Image.Image], str, Dict[str, Any]]:
    state = get_initial_state()
    if not file_path:
        return None, '<div class="page-info">No file loaded</div>', state

    try:
        pages = convert_file_to_images(file_path)
        if not pages:
            return None, '<div class="page-info">Could not load file</div>', state

        state["pages"] = pages
        state["total_pages"] = len(pages)
        page_info_html = f'<div class="page-info">Page 1 / {state["total_pages"]}</div>'
        return pages[0], page_info_html, state
    except Exception as e:
        logger.error(f"Failed to load and preview file: {e}")
        return None, '<div class="page-info">Failed to load preview</div>', state
186
+
187
+ async def process_all_pages(state: Dict[str, Any], model_choice: str, progress=gr.Progress(track_tqdm=True)):
188
+ if not state or not state["pages"]:
189
+ error_msg = "<h3>Please upload a file first.</h3>"
190
+ return error_msg, "", "", None, "Error: No file to process", state
191
+
192
+ logger.info(f'Processing {state["total_pages"]} pages with model: {model_choice}')
193
+ start_time = time.time()
194
+
195
+ try:
196
+ page_results = []
197
+ for i, page_img in progress.tqdm(enumerate(state["pages"]), desc="Processing Pages"):
198
+ html_result = parse_page(page_img, model_choice)
199
+ page_results.append({'raw_html': html_result})
200
+
201
+ state["page_results"] = page_results
202
+
203
+ full_html_content = "\n\n".join([f'<!-- Page {i+1} -->\n{res["raw_html"]}' for i, res in enumerate(page_results)])
204
+ full_markdown = html2text.html2text(full_html_content)
205
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
206
+ f.write(full_markdown)
207
+ md_path = f.name
208
+
209
+ parsing_time = time.time() - start_time
210
+ cost_time_str = f'Total processing time: {parsing_time:.2f}s'
211
+
212
+ current_page_results = get_page_outputs(state)
213
+
214
+ return *current_page_results, md_path, cost_time_str, state
215
+
216
+ except Exception as e:
217
+ logger.error(f"Parsing failed: {e}", exc_info=True)
218
+ error_html = f"<h3>An error occurred during processing:</h3><p>{str(e)}</p>"
219
+ return error_html, "", "", None, f"Error: {str(e)}", state
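# The success return order (rendered markdown, markdown source, raw HTML,
# download path, timing string, state) must stay aligned with the outputs list
# wired to process_btn.click() below; the combined Markdown is written to a
# temporary .md file so gr.File can offer it for download.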
220
+
221
+ def navigate_page(direction: str, state: Dict[str, Any]):
222
+ if not state or not state["pages"]:
223
+ return None, '<div class="page-info">No file loaded</div>', *get_page_outputs(state), state
224
+
225
+ current_index = state["current_page_index"]
226
+ total_pages = state["total_pages"]
227
+
228
+ if direction == "prev":
229
+ new_index = max(0, current_index - 1)
230
+ elif direction == "next":
231
+ new_index = min(total_pages - 1, current_index + 1)
232
+ else:
233
+ new_index = current_index
234
+
235
+ state["current_page_index"] = new_index
236
+
237
+ image_preview = state["pages"][new_index]
238
+ page_info_html = f'<div class="page-info">Page {new_index + 1} / {total_pages}</div>'
239
+
240
+ page_outputs = get_page_outputs(state)
241
+
242
+ return image_preview, page_info_html, *page_outputs, state
243
+
244
+ def get_page_outputs(state: Dict[str, Any]) -> Tuple[str, str, str]:
245
+ if not state or not state.get("page_results"):
246
+ return "<h3>Process the document to see results.</h3>", "", ""
247
+
248
+ index = state["current_page_index"]
249
+ if index >= len(state["page_results"]):
250
+ return "<h3>Result not available for this page.</h3>", "", ""
251
+
252
+ result = state["page_results"][index]
253
+ raw_html = result['raw_html']
254
+
255
+ md_source = html2text.html2text(raw_html)
256
+ md_render = markdown.markdown(md_source, extensions=['fenced_code', 'tables'])
257
+
258
+ return md_render, md_source, raw_html
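# Only the raw HTML is stored per page; the Markdown source (via html2text) and
# the rendered preview (via markdown) are regenerated on demand whenever the
# user navigates, which keeps the session state small.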

def clear_all():
    return None, None, "<h3>Results will be displayed here after processing.</h3>", "", "", None, "", '<div class="page-info">No file loaded</div>', get_initial_state()

@click.command()
def main():
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; margin-bottom: 20px; }
    .page-info { text-align: center; padding: 8px 16px; font-weight: bold; margin: 10px 0; }
    """
    with gr.Blocks(theme=steel_blue_theme, css=css, title="Logics-Parsing Demo") as demo:
        app_state = gr.State(value=get_initial_state())

        gr.HTML("""
        <div class="header-text">
            <h1>📄 Multimodal: VLM Parsing</h1>
            <p style="font-size: 1.1em;">An advanced Vision Language Model to parse documents and images into clean Markdown and HTML</p>
            <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
                <a href="https://huggingface.co/collections/prithivMLmods/mm-vlm-parsing-68e33e52bfb9ae60b50602dc" target="_blank" style="text-decoration: none; font-weight: 500;">🤗 Model Info</a>
                <a href="https://github.com/PRITHIVSAKTHIUR/VLM-Parsing" target="_blank" style="text-decoration: none; font-weight: 500;">💻 GitHub</a>
                <a href="https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending" target="_blank" style="text-decoration: none; font-weight: 500;">📝 Multimodal VLMs</a>
            </div>
        </div>
        """)

        with gr.Row(elem_classes=["main-container"]):
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(choices=["Logics-Parsing", "Gliese-OCR-7B-Post1.0", "olmOCR-7B-0825"], label="Select Model", value="Logics-Parsing")
                file_input = gr.File(label="Upload PDF or Image", file_types=[".pdf", ".jpg", ".jpeg", ".png"], type="filepath")

                process_btn = gr.Button("🚀 Process Document", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=320)

                with gr.Row():
                    prev_page_btn = gr.Button("◀ Previous")
                    page_info = gr.HTML('<div class="page-info">No file loaded</div>')
                    next_page_btn = gr.Button("Next ▶")

                example_root = "examples"
                if os.path.exists(example_root) and os.path.isdir(example_root):
                    example_files = [os.path.join(example_root, f) for f in os.listdir(example_root) if f.endswith(tuple(pdf_suffixes + image_suffixes))]
                    if example_files:
                        gr.Examples(examples=example_files, inputs=file_input, label="Examples")

                with gr.Accordion("Download & Details", open=False):
                    output_file = gr.File(label='Download Markdown Result', interactive=False)
                    cost_time = gr.Textbox(label='Time Cost', interactive=False)

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("Markdown Source"):
                        md_source_output = gr.Code(language="markdown", label="Markdown Source")
                    with gr.Tab("Rendered Markdown"):
                        md_render_output = gr.Markdown(label='Markdown Rendering')
                    with gr.Tab("Generated HTML"):
                        raw_html_output = gr.Code(language="html", label="Generated HTML")

        file_input.change(fn=load_and_preview_file, inputs=file_input, outputs=[image_preview, page_info, app_state], show_progress="full")

        process_btn.click(fn=process_all_pages, inputs=[app_state, model_choice], outputs=[md_render_output, md_source_output, raw_html_output, output_file, cost_time, app_state], show_progress="full")

        prev_page_btn.click(fn=lambda s: navigate_page("prev", s), inputs=app_state, outputs=[image_preview, page_info, md_render_output, md_source_output, raw_html_output, app_state])

        next_page_btn.click(fn=lambda s: navigate_page("next", s), inputs=app_state, outputs=[image_preview, page_info, md_render_output, md_source_output, raw_html_output, app_state])

        clear_btn.click(fn=clear_all, outputs=[file_input, image_preview, md_render_output, md_source_output, raw_html_output, output_file, cost_time, page_info, app_state])
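        # Each handler's outputs list must match, in order, the tuple returned
        # by its callback (see process_all_pages and navigate_page above);
        # app_state is passed in and returned so per-session data survives
        # between clicks.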

    demo.queue().launch(debug=True, show_error=True)

if __name__ == '__main__':
    if not os.path.exists("examples"):
        os.makedirs("examples")
        logger.info("Created 'examples' directory. Please add some sample PDF/image files there.")
    main()
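# Local run sketch (an assumption, not part of the committed behavior): with the
# dependencies above installed (gradio, spaces, transformers, torch, pymupdf,
# html2text, markdown, loguru, click), `python app.py` invokes main() via click
# and starts the queued Gradio server; the @spaces.GPU decorator typically acts
# as a pass-through outside Hugging Face ZeroGPU Spaces.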