khang119966 commited on
Commit
d03ac84
·
verified ·
1 Parent(s): 9d5127a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -64
app.py CHANGED
@@ -4,10 +4,14 @@ from transformers import AutoModel, AutoTokenizer
4
  import spaces
5
  import os
6
  import tempfile
 
7
 
8
- # Load model and tokenizer
 
 
9
  model_name = "deepseek-ai/DeepSeek-OCR"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
11
  model = AutoModel.from_pretrained(
12
  model_name,
13
  _attn_implementation="flash_attention_2",
@@ -15,52 +19,63 @@ model = AutoModel.from_pretrained(
15
  use_safetensors=True,
16
  )
17
  model = model.eval()
 
18
 
19
 
 
20
  @spaces.GPU
21
- def process_image(image, model_size, task_type):
22
  """
23
- Process image with DeepSeek-OCR
24
-
25
  Args:
26
- image: PIL Image or file path
27
- model_size: Model size configuration
28
- task_type: OCR task type
 
29
  """
30
- # GPU 函数内部移动模型到 GPU
 
 
 
 
31
  model_gpu = model.cuda().to(torch.bfloat16)
 
32
 
33
- # Create temporary directory for output
34
  with tempfile.TemporaryDirectory() as output_path:
35
- # Set prompt based on task type
36
  if task_type == "Free OCR":
37
- prompt = "<image>\nFree OCR. "
38
  elif task_type == "Convert to Markdown":
39
- prompt = "<image>\n<|grounding|>Convert the document to markdown. "
 
 
 
 
 
 
 
40
  else:
41
- prompt = "<image>\nFree OCR. "
 
42
 
43
- # Save uploaded image temporarily
44
- temp_image_path = os.path.join(output_path, "temp_image.jpg")
45
  image.save(temp_image_path)
46
 
47
- # Configure model size parameters
48
  size_configs = {
49
  "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
50
  "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
51
  "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
52
  "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
53
- "Gundam (Recommended)": {
54
- "base_size": 1024,
55
- "image_size": 640,
56
- "crop_mode": True,
57
- },
58
  }
59
-
60
  config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
61
 
62
- # Run inference
63
- result = model_gpu.infer(
 
64
  tokenizer,
65
  prompt=prompt,
66
  image_file=temp_image_path,
@@ -68,38 +83,50 @@ def process_image(image, model_size, task_type):
68
  base_size=config["base_size"],
69
  image_size=config["image_size"],
70
  crop_mode=config["crop_mode"],
71
- save_results=True,
72
  test_compress=True,
73
  eval_mode=True,
74
  )
75
 
76
- print(f"====\nresult: {result}\n====\n")
77
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
- # Create Gradio interface
81
- with gr.Blocks(title="DeepSeek-OCR") as demo:
82
  gr.Markdown(
83
  """
84
- # DeepSeek-OCR Document Recognition
 
85
 
86
- Upload an image to extract text using DeepSeek-OCR model.
87
- Supports various document types and handwriting recognition.
88
-
89
- **Model Sizes:**
90
- - **Tiny**: Fastest, lower accuracy (512x512)
91
- - **Small**: Fast, good accuracy (640x640)
92
- - **Base**: Balanced performance (1024x1024)
93
- - **Large**: Best accuracy, slower (1280x1280)
94
- - **Gundam (Recommended)**: Optimized for documents (1024 base, 640 image, crop mode)
95
  """
96
  )
97
 
98
  with gr.Row():
99
- with gr.Column():
100
- image_input = gr.Image(
101
- type="pil", label="Upload Image", sources=["upload", "clipboard"]
102
- )
103
 
104
  model_size = gr.Dropdown(
105
  choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
@@ -108,37 +135,68 @@ with gr.Blocks(title="DeepSeek-OCR") as demo:
108
  )
109
 
110
  task_type = gr.Dropdown(
111
- choices=["Free OCR", "Convert to Markdown"],
112
  value="Convert to Markdown",
113
  label="Task Type",
114
  )
 
 
 
 
 
 
 
115
 
116
- submit_btn = gr.Button("Process Image", variant="primary")
117
 
118
- with gr.Column():
119
- output_text = gr.Textbox(
120
- label="OCR Result", lines=20, show_copy_button=True
121
- )
122
 
123
- # Examples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  gr.Examples(
125
  examples=[
126
- ["examples/math.png", "Gundam (Recommended)", "Convert to Markdown"],
127
- ["examples/receipt.jpg", "Base", "Free OCR"],
 
 
 
128
  ],
129
- inputs=[image_input, model_size, task_type],
130
- outputs=output_text,
131
- fn=process_image,
132
- cache_examples=False,
133
- )
134
-
135
- submit_btn.click(
136
- fn=process_image,
137
- inputs=[image_input, model_size, task_type],
138
- outputs=output_text,
139
  )
140
 
141
- # Launch the app
142
  if __name__ == "__main__":
 
 
 
 
 
 
143
  demo.queue(max_size=20)
144
- demo.launch()
 
4
  import spaces
5
  import os
6
  import tempfile
7
+ from PIL import Image
8
 
9
+ # --- Tải Model và Tokenizer (Chỉ một lần khi khởi động) ---
10
+ # Di chuyển việc tải model ra ngoài để tránh tải lại mỗi lần gọi hàm
11
+ print("Loading model and tokenizer...")
12
  model_name = "deepseek-ai/DeepSeek-OCR"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
14
+ # Tải model lên CPU trước, sau đó chuyển sang GPU trong hàm xử lý
15
  model = AutoModel.from_pretrained(
16
  model_name,
17
  _attn_implementation="flash_attention_2",
 
19
  use_safetensors=True,
20
  )
21
  model = model.eval()
22
+ print("Model loaded successfully.")
23
 
24
 
25
+ # --- Hàm xử lý chính ---
26
  @spaces.GPU
27
+ def process_ocr_task(image, model_size, task_type, ref_text):
28
  """
29
+ Xử hình ảnh với DeepSeek-OCR cho tất cả các tác vụ.
 
30
  Args:
31
+ image: Đối tượng PIL Image
32
+ model_size: Cấu hình kích thước model
33
+ task_type: Loại tác vụ OCR
34
+ ref_text: Văn bản tham chiếu cho tác vụ 'Locate'
35
  """
36
+ if image is None:
37
+ return "Please upload an image first.", None
38
+
39
+ # Chuyển model sang GPU và định dạng bfloat16 để tối ưu hiệu suất
40
+ print("Moving model to GPU...")
41
  model_gpu = model.cuda().to(torch.bfloat16)
42
+ print("Model on GPU.")
43
 
44
+ # Tạo thư mục tạm thời để lưu trữ đầu ra
45
  with tempfile.TemporaryDirectory() as output_path:
46
+ # --- Xây dựng prompt dựa trên loại tác vụ ---
47
  if task_type == "Free OCR":
48
+ prompt = "<image>\nFree OCR."
49
  elif task_type == "Convert to Markdown":
50
+ prompt = "<image>\n<|grounding|>Convert the document to markdown."
51
+ elif task_type == "Parse Figure":
52
+ prompt = "<image>\nParse the figure."
53
+ elif task_type == "Locate Object by Reference":
54
+ if not ref_text or ref_text.strip() == "":
55
+ raise gr.Error("For 'Locate' task, please provide the reference text to find.")
56
+ # Sử dụng f-string để chèn văn bản tham chiếu vào prompt
57
+ prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
58
  else:
59
+ # Mặc định là Free OCR
60
+ prompt = "<image>\nFree OCR."
61
 
62
+ # Lưu ảnh được tải lên vào thư mục tạm
63
+ temp_image_path = os.path.join(output_path, "temp_image.png")
64
  image.save(temp_image_path)
65
 
66
+ # Cấu hình các tham số kích thước model
67
  size_configs = {
68
  "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
69
  "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
70
  "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
71
  "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
72
+ "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
 
 
 
 
73
  }
 
74
  config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
75
 
76
+ print(f"Running inference with prompt: {prompt}")
77
+ # --- Chạy inference ---
78
+ text_result = model_gpu.infer(
79
  tokenizer,
80
  prompt=prompt,
81
  image_file=temp_image_path,
 
83
  base_size=config["base_size"],
84
  image_size=config["image_size"],
85
  crop_mode=config["crop_mode"],
86
+ save_results=True, # Quan trọng: phải lưu kết quả để lấy ảnh output
87
  test_compress=True,
88
  eval_mode=True,
89
  )
90
 
91
+ print(f"====\nText Result: {text_result}\n====")
92
+
93
+ # --- Xử lý output (văn bản và hình ảnh) ---
94
+ image_result_path = None
95
+ # Tác vụ 'Locate' và 'Markdown' thường tạo ra ảnh kết quả c�� chữ 'grounding'
96
+ if task_type in ["Locate Object by Reference", "Convert to Markdown", "Parse Figure"]:
97
+ # Tìm file ảnh kết quả trong thư mục output
98
+ for filename in os.listdir(output_path):
99
+ if "grounding" in filename or "result" in filename:
100
+ image_result_path = os.path.join(output_path, filename)
101
+ break
102
+
103
+ # Nếu tìm thấy ảnh, tải nó, nếu không trả về None
104
+ result_image_pil = Image.open(image_result_path) if image_result_path else None
105
+
106
+ return text_result, result_image_pil
107
 
108
 
109
+ # --- Xây dựng giao diện Gradio ---
110
+ with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
111
  gr.Markdown(
112
  """
113
+ # Demo toàn diện DeepSeek-OCR
114
+ Tải lên một hình ảnh để thử nghiệm các khả năng nhận dạng và hiểu tài liệu của DeepSeek-OCR.
115
 
116
+ **Hướng dẫn:**
117
+ 1. Tải lên một hình ảnh.
118
+ 2. Chọn **Model Size** phù hợp (Gundam được khuyến nghị cho tài liệu).
119
+ 3. Chọn **Task Type**:
120
+ - **Free OCR**: Trích xuất văn bản thô.
121
+ - **Convert to Markdown**: Chuyển đổi tài liệu (giữ cấu trúc) sang định dạng Markdown.
122
+ - **Parse Figure**: Phân tích và trích xuất dữ liệu từ biểu đồ, hình vẽ.
123
+ - **Locate Object by Reference**: Tìm một đối tượng hoặc văn bản cụ thể trong ảnh. **Bạn cần nhập nội dung cần tìm vào ô "Reference Text" bên dưới.**
 
124
  """
125
  )
126
 
127
  with gr.Row():
128
+ with gr.Column(scale=1):
129
+ image_input = gr.Image(type="pil", label="Tải ảnh lên", sources=["upload", "clipboard"])
 
 
130
 
131
  model_size = gr.Dropdown(
132
  choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
 
135
  )
136
 
137
  task_type = gr.Dropdown(
138
+ choices=["Free OCR", "Convert to Markdown", "Parse Figure", "Locate Object by Reference"],
139
  value="Convert to Markdown",
140
  label="Task Type",
141
  )
142
+
143
+ # Ô nhập văn bản tham chiếu, ban đầu bị ẩn
144
+ ref_text_input = gr.Textbox(
145
+ label="Reference Text (cho tác vụ Locate)",
146
+ placeholder="Ví dụ: the teacher, 11-2=, a red car...",
147
+ visible=False, # Ban đầu ẩn đi
148
+ )
149
 
150
+ submit_btn = gr.Button("Xử ", variant="primary")
151
 
152
+ with gr.Column(scale=2):
153
+ output_text = gr.Textbox(label="Kết quả văn bản", lines=15, show_copy_button=True)
154
+ output_image = gr.Image(label="Kết quả hình ảnh (nếu có)", type="pil")
 
155
 
156
+ # --- Logic tương tác cho giao diện ---
157
+ def toggle_ref_text_visibility(task):
158
+ # Nếu người dùng chọn 'Locate', hiển thị ô nhập văn bản
159
+ if task == "Locate Object by Reference":
160
+ return gr.Textbox(visible=True)
161
+ else:
162
+ return gr.Textbox(visible=False)
163
+
164
+ # Khi dropdown 'task_type' thay đổi, gọi hàm để cập nhật trạng thái hiển thị của ô ref_text_input
165
+ task_type.change(
166
+ fn=toggle_ref_text_visibility,
167
+ inputs=task_type,
168
+ outputs=ref_text_input,
169
+ )
170
+
171
+ # Khi nhấn nút submit
172
+ submit_btn.click(
173
+ fn=process_ocr_task,
174
+ inputs=[image_input, model_size, task_type, ref_text_input],
175
+ outputs=[output_text, output_image],
176
+ )
177
+
178
+ # --- Các ví dụ minh họa ---
179
  gr.Examples(
180
  examples=[
181
+ ["./examples/doc_markdown.png", "Gundam (Recommended)", "Convert to Markdown", ""],
182
+ ["./examples/chart.png", "Gundam (Recommended)", "Parse Figure", ""],
183
+ ["./examples/teacher.png", "Base", "Locate Object by Reference", "the teacher"],
184
+ ["./examples/math_locate.png", "Small", "Locate Object by Reference", "11-2="],
185
+ ["./examples/receipt.jpg", "Base", "Free OCR", ""],
186
  ],
187
+ inputs=[image_input, model_size, task_type, ref_text_input],
188
+ outputs=[output_text, output_image],
189
+ fn=process_ocr_task,
190
+ cache_examples=False, # Tắt cache để đảm bảo chạy lại mỗi lần click
 
 
 
 
 
 
191
  )
192
 
193
+ # --- Khởi chạy ứng dụng ---
194
  if __name__ == "__main__":
195
+ # Tạo thư mục examples và tải ảnh ví dụ (nếu chưa có)
196
+ if not os.path.exists("examples"):
197
+ os.makedirs("examples")
198
+ # Bạn cần tự tải các file ảnh ví dụ vào thư mục "examples"
199
+ # Ví dụ: doc_markdown.png, chart.png, teacher.png, math_locate.png, receipt.jpg
200
+
201
  demo.queue(max_size=20)
202
+ demo.launch(share=True) # share=True để tạo link public