codemo committed on
Commit
6290d3f
·
verified ·
1 Parent(s): 17d69f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -64
app.py CHANGED
@@ -1,70 +1,337 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
67
 
 
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import torch
4
+ import spaces
5
+ import os
6
+ import sys
7
+ import tempfile
8
+ import shutil
9
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
10
+ import fitz
11
+ import re
12
+ import warnings
13
+ import numpy as np
14
+ import base64
15
+ from io import StringIO, BytesIO
16
+
17
+ # Model path configuration
18
+ # Option 1: use the online model (default)
19
+ MODEL_PATH = 'deepseek-ai/DeepSeek-OCR'
20
+
21
+ # Option 2: use a locally downloaded model (recommended)
22
+ # After downloading the model locally, change this to the local path, e.g.:
23
+ # MODEL_PATH = './models/DeepSeek-OCR' # local model path
24
+ # MODEL_PATH = 'E:/hugging_face/models/DeepSeek-OCR' # or use an absolute path
25
+
26
+ # If the local path does not exist, fall back to the online model
27
+ if not os.path.exists(MODEL_PATH):
28
+ print(f"本地模型路径不存在: {MODEL_PATH}")
29
+ print("将使用在线模型: deepseek-ai/DeepSeek-OCR")
30
+ MODEL_PATH = 'deepseek-ai/DeepSeek-OCR'
31
+ else:
32
+ print(f"使用本地模型: {MODEL_PATH}")
33
+
34
+ # Auto-detect device (GPU if available, else CPU)
35
+ device = "cuda" if torch.cuda.is_available() else "cpu"
36
+ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
+ print(f"使用设备: {device}, 数据类型: {torch_dtype}")
38
+
39
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
40
+ model = AutoModel.from_pretrained(
41
+ MODEL_PATH,
42
+ trust_remote_code=True,
43
+ use_safetensors=True,
44
+ torch_dtype=torch_dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  )
46
+ model = model.eval().to(device)
47
+
48
+ MODEL_CONFIGS = {
49
+ "⚡ Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
50
+ "🚀 Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
51
+ "📄 Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
52
+ "📊 Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
53
+ "🎯 Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}
54
+ }
55
+
56
+ TASK_PROMPTS = {
57
+ "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
58
+ "📝 Free OCR": {"prompt": "<image>\nFree OCR.", "has_grounding": False},
59
+ "📍 Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
60
+ "🔍 Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
61
+ "✏️ Custom": {"prompt": "", "has_grounding": False}
62
+ }
63
+
64
+
65
def extract_grounding_references(text):
    """Collect every grounding annotation found in ``text``.

    Returns a list of ``(full_match, label, coords_text)`` tuples, one for
    each ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` span, in the order the
    spans appear. Matching is non-greedy and spans may cross line breaks.
    """
    grounding_re = re.compile(
        r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', re.DOTALL)
    return grounding_re.findall(text)
68
+
69
+
70
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw labelled grounding boxes on a copy of ``image``.

    Args:
        image: PIL image the boxes refer to. It is copied, not modified.
        refs: Iterable of ``(full_match, label, coords_text)`` tuples, the
            shape returned by ``extract_grounding_references``.
            ``coords_text`` is parsed as a Python literal holding a sequence
            of ``[x1, y1, x2, y2]`` boxes.
        extract_images: When true, every box whose label is ``'image'`` is
            also cropped out of the original image.

    Returns:
        ``(annotated_image, crops)`` where ``crops`` is the list of cropped
        regions (empty unless ``extract_images`` is true).
    """
    import ast  # local import: used only to parse coordinates safely

    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    # Translucent fills are drawn on a separate RGBA layer and pasted at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()
    crops = []

    for ref in refs:
        label = ref[1]
        # The coordinate text comes straight from model output, so it must
        # never be passed to eval(); literal_eval only accepts plain literals.
        try:
            coords = ast.literal_eval(ref[2].strip())
        except (ValueError, SyntaxError, TypeError):
            continue  # unparsable coordinates: skip this reference
        if not isinstance(coords, (list, tuple)):
            continue

        color = (np.random.randint(50, 255), np.random.randint(
            50, 255), np.random.randint(50, 255))
        color_a = color + (60,)

        for box in coords:
            if not isinstance(box, (list, tuple)) or len(box) < 4:
                continue  # malformed box: nothing sensible to draw

            # Each coordinate is divided by 999 and scaled to the image size.
            x1, y1, x2, y2 = int(
                box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)

            if extract_images and label == 'image':
                crops.append(image.crop((x1, y1, x2, y2)))

            width = 5 if label == 'title' else 3
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle([x1, y1, x2, y2], fill=color_a)

            # Label tag sits just above the box, clamped to the top edge.
            text_bbox = draw.textbbox((0, 0), label, font=font)
            tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
            ty = max(0, y1 - 20)
            draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw, crops
105
+
106
+
107
def clean_output(text, include_images=False, remove_labels=False):
    """Strip grounding markup from raw model output.

    Each ``<|ref|>…<|/ref|><|det|>…<|/det|>`` span is replaced once, in order:

    * spans containing ``<|ref|>image<|/ref|>`` become a numbered
      ``**[Figure N]**`` placeholder when ``include_images`` is true,
      otherwise they are removed;
    * every other span is removed when ``remove_labels`` is true, otherwise
      it is replaced by its label text.

    Returns the stripped result, or ``""`` for empty or ``None`` input.
    """
    if not text:
        return ""

    grounding_re = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    figure_count = 0

    for whole, label, _coords in re.findall(grounding_re, text, re.DOTALL):
        if '<|ref|>image<|/ref|>' in whole:
            if include_images:
                figure_count += 1
                substitute = f'\n\n**[Figure {figure_count}]**\n\n'
            else:
                substitute = ''
        elif remove_labels:
            substitute = ''
        else:
            substitute = label
        # Replace only the first remaining occurrence so duplicates stay ordered.
        text = text.replace(whole, substitute, 1)

    return text.strip()
129
+
130
+
131
def embed_images(markdown, crops):
    """Inline cropped figures into ``markdown`` as base64 PNG data URIs.

    The N-th crop (1-based) replaces the first ``**[Figure N]**`` placeholder.
    ``markdown`` is returned unchanged when ``crops`` is empty or ``None``.
    """
    if not crops:
        return markdown

    for number, crop in enumerate(crops, start=1):
        png = BytesIO()
        crop.save(png, format="PNG")
        encoded = base64.b64encode(png.getvalue()).decode()
        placeholder = f'**[Figure {number}]**'
        inline = f'\n\n![Figure {number}](data:image/png;base64,{encoded})\n\n'
        markdown = markdown.replace(placeholder, inline, 1)

    return markdown
141
+
142
+
143
@spaces.GPU(duration=60)
def process_image(image, mode, task, custom_prompt):
    """Run the OCR model on one image and post-process what it prints.

    Args:
        image: PIL image to analyse, or ``None``.
        mode: Key into ``MODEL_CONFIGS`` selecting base size, image size and
            crop mode.
        task: Key into ``TASK_PROMPTS``. The Custom and Locate tasks build
            their prompt from ``custom_prompt`` instead.
        custom_prompt: Free text for the Custom and Locate tasks.

    Returns:
        A 5-tuple ``(text, markdown, raw, boxed_image, crops)``. When there is
        no image, no prompt, or no output, the first element is a short status
        message and the rest are empty (``""``, ``""``, ``None``, ``[]``).
    """
    from contextlib import redirect_stdout  # local import: scoped stdout capture

    if image is None:
        return " Error Upload image", "", "", None, []
    custom_prompt = custom_prompt or ""  # tolerate a missing prompt value
    if task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
        return "Enter prompt", "", "", None, []

    if image.mode in ('RGBA', 'LA', 'P'):
        image = image.convert('RGB')
    image = ImageOps.exif_transpose(image)

    config = MODEL_CONFIGS[mode]

    if task == "✏️ Custom":
        prompt = f"<image>\n{custom_prompt.strip()}"
        has_grounding = '<|grounding|>' in custom_prompt
    elif task == "📍 Locate":
        prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
        has_grounding = True
    else:
        prompt = TASK_PROMPTS[task]["prompt"]
        has_grounding = TASK_PROMPTS[task]["has_grounding"]

    # Lines containing any of these markers are dropped from the captured output.
    noise_markers = ('image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size')

    # Close the handle before saving so the path can be reopened on any OS.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    tmp.close()
    out_dir = None
    captured = StringIO()
    try:
        image.save(tmp.name, 'JPEG', quality=95)
        out_dir = tempfile.mkdtemp()
        # The text printed by model.infer is the result, so capture stdout.
        # redirect_stdout restores sys.stdout even if inference raises.
        with redirect_stdout(captured):
            model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
                        base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
    finally:
        # Always remove the scratch file and directory, success or failure.
        os.unlink(tmp.name)
        if out_dir is not None:
            shutil.rmtree(out_dir, ignore_errors=True)

    result = '\n'.join(
        line for line in captured.getvalue().split('\n')
        if not any(marker in line for marker in noise_markers)
    ).strip()

    if not result:
        return "No text", "", "", None, []

    cleaned = clean_output(result, False, False)
    markdown = clean_output(result, True, True)

    img_out = None
    crops = []

    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        if refs:
            img_out, crops = draw_bounding_boxes(image, refs, True)

    markdown = embed_images(markdown, crops)

    return cleaned, markdown, result, img_out, crops
201
+
202
+
203
@spaces.GPU(duration=300)
def process_pdf(path, mode, task, custom_prompt):
    """Run ``process_image`` on every page of a PDF and merge the results.

    Each page is rendered with a 300/72 zoom matrix and processed on its own.
    Pages whose text is empty or ``"No text"`` are left out of the merged
    output.

    Args:
        path: Filesystem path of the PDF.
        mode, task, custom_prompt: Forwarded unchanged to ``process_image``.

    Returns:
        A 5-tuple ``(text, markdown, raw, None, crops)``. Page sections are
        joined with horizontal rules. The fourth element is always ``None``.
    """
    # Check the prompt once up front. Otherwise every page is rendered only
    # for "Enter prompt" to be recorded as that page's text.
    if task in ["✏️ Custom", "📍 Locate"] and not (custom_prompt or "").strip():
        return "Enter prompt", "", "", None, []

    texts, markdowns, raws, all_crops = [], [], [], []

    doc = fitz.open(path)
    try:
        for i in range(len(doc)):
            page = doc.load_page(i)
            pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
            img = Image.open(BytesIO(pix.tobytes("png")))

            text, md, raw, _, crops = process_image(img, mode, task, custom_prompt)

            if text and text != "No text":
                texts.append(f"### Page {i + 1}\n\n{text}")
                markdowns.append(f"### Page {i + 1}\n\n{md}")
                raws.append(f"=== Page {i + 1} ===\n{raw}")
                all_crops.extend(crops)
    finally:
        # Release the document even when a page fails to render or process.
        doc.close()

    return ("\n\n---\n\n".join(texts) if texts else "No text in PDF",
            "\n\n---\n\n".join(markdowns) if markdowns else "No text in PDF",
            "\n\n".join(raws), None, all_crops)
226
+
227
+
228
def process_file(path, mode, task, custom_prompt):
    """Route an uploaded file to the PDF or single-image pipeline.

    Paths ending in ``.pdf`` (case-insensitive) go to ``process_pdf``. Any
    other path is opened as an image and passed to ``process_image``. An
    empty path returns an error tuple.
    """
    if not path:
        return "Error Upload file", "", "", None, []

    is_pdf = path.lower().endswith('.pdf')
    if is_pdf:
        return process_pdf(path, mode, task, custom_prompt)
    return process_image(Image.open(path), mode, task, custom_prompt)
236
+
237
+
238
def toggle_prompt(task):
    """Return the prompt-textbox update that matches the selected task.

    The textbox is shown, with a task-specific label and placeholder, for the
    Custom and Locate tasks. It is hidden for every other task.
    """
    if task == "✏️ Custom":
        label, hint = "Custom Prompt", "Add <|grounding|> for boxes"
    elif task == "📍 Locate":
        label, hint = "Text to Locate", "Enter text"
    else:
        return gr.update(visible=False)
    return gr.update(visible=True, label=label, placeholder=hint)
244
+
245
+
246
def load_image(file_path):
    """Load a preview image for an uploaded file.

    For a path ending in ``.pdf`` (case-insensitive) the first page is
    rendered with a 300/72 zoom matrix. Any other path is opened directly as
    an image.

    Returns:
        A PIL image, or ``None`` when ``file_path`` is empty.
    """
    if not file_path:
        return None

    if not file_path.lower().endswith('.pdf'):
        return Image.open(file_path)

    doc = fitz.open(file_path)
    try:
        page = doc.load_page(0)
        pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
        return Image.open(BytesIO(pix.tobytes("png")))
    finally:
        # Close the document even if rendering the first page fails.
        doc.close()
258
+
259
+
260
+ with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR") as demo:
261
+ gr.Markdown("""
262
+ # 🚀 DeepSeek-OCR Demo
263
+ **Convert documents to markdown, extract raw text, and locate specific content with bounding boxes. Check the info at the bottom of the page for more information.**
264
+
265
+ **Hope this tool was helpful! If so, a quick like ❤️ would mean a lot :)**
266
+ """)
267
+
268
+ with gr.Row():
269
+ with gr.Column(scale=1):
270
+ file_in = gr.File(label="Upload Image or PDF", file_types=[
271
+ "image", ".pdf"], type="filepath")
272
+ input_img = gr.Image(label="Input Image", type="pil", height=300)
273
+ mode = gr.Dropdown(list(MODEL_CONFIGS.keys()),
274
+ value="⚡ Gundam", label="Mode")
275
+ task = gr.Dropdown(list(TASK_PROMPTS.keys()),
276
+ value="📋 Markdown", label="Task")
277
+ prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
278
+ btn = gr.Button("Extract", variant="primary", size="lg")
279
+
280
+ with gr.Column(scale=2):
281
+ with gr.Tabs():
282
+ with gr.Tab("📝 Text"):
283
+ text_out = gr.Textbox(
284
+ lines=20, show_copy_button=True, show_label=False)
285
+ with gr.Tab("🎨 Markdown"):
286
+ md_out = gr.Markdown("")
287
+ with gr.Tab("🖼️ Boxes"):
288
+ img_out = gr.Image(
289
+ type="pil", height=500, show_label=False)
290
+ with gr.Tab("🖼️ Cropped Images"):
291
+ gallery = gr.Gallery(
292
+ show_label=False, columns=3, height=400)
293
+ with gr.Tab("🔍 Raw"):
294
+ raw_out = gr.Textbox(
295
+ lines=20, show_copy_button=True, show_label=False)
296
+
297
+ gr.Examples(
298
+ examples=[
299
+ ["examples/ocr.jpg", "⚡ Gundam", "📋 Markdown", ""],
300
+ ["examples/reachy-mini.jpg", "⚡ Gundam", "📍 Locate", "Robot"]
301
+ ],
302
+ inputs=[input_img, mode, task, prompt],
303
+ cache_examples=False
304
+ )
305
+
306
+ with gr.Accordion("ℹ️ Info", open=False):
307
+ gr.Markdown("""
308
+ ### Modes
309
+ - **Gundam**: 1024 base + 640 tiles with cropping - Best balance
310
+ - **Tiny**: 512×512, no crop - Fastest
311
+ - **Small**: 640×640, no crop - Quick
312
+ - **Base**: 1024×1024, no crop - Standard
313
+ - **Large**: 1280×1280, no crop - Highest quality
314
+
315
+ ### Tasks
316
+ - **Markdown**: Convert document to structured markdown (grounding ✅)
317
+ - **Free OCR**: Simple text extraction
318
+ - **Locate**: Find specific text in image (grounding ✅)
319
+ - **Describe**: General image description
320
+ - **Custom**: Your own prompt (add `<|grounding|>` for boxes)
321
+ """)
322
+
323
+ file_in.change(load_image, [file_in], [input_img])
324
+ task.change(toggle_prompt, [task], [prompt])
325
 
326
def run(image, file_path, mode, task, custom_prompt):
    """Dispatch the Extract button to the right processor.

    An uploaded PDF always goes through ``process_file`` so that every page
    is processed. The upload handler fills the image preview with the first
    page only, so checking the preview first would drop all later pages.

    For anything else the preview image wins, then the uploaded file path.
    With neither present an error tuple is returned.

    Returns:
        The 5-tuple produced by ``process_image`` / ``process_file``:
        ``(text, markdown, raw, boxed_image, crops)``.
    """
    if file_path and file_path.lower().endswith('.pdf'):
        return process_file(file_path, mode, task, custom_prompt)
    if image is not None:
        return process_image(image, mode, task, custom_prompt)
    if file_path:
        return process_file(file_path, mode, task, custom_prompt)
    return "Error uploading file or image", "", "", None, []
332
 
333
+ btn.click(run, [input_img, file_in, mode, task, prompt],
334
+ [text_out, md_out, raw_out, img_out, gallery])
335
 
336
  if __name__ == "__main__":
337
+ demo.queue(max_size=20).launch()