khang119966 committed
Commit 3b28ff1 · verified · 1 Parent(s): c2f8f51

Update app.py

Files changed (1)
  1. app.py +55 -29
app.py CHANGED
@@ -4,7 +4,8 @@ from transformers import AutoModel, AutoTokenizer
 import spaces
 import os
 import tempfile
-from PIL import Image
+from PIL import Image, ImageDraw
+import re  # Import the regular expression library
 
 # --- 1. Load Model and Tokenizer (Done only once at startup) ---
 print("Loading model and tokenizer...")
@@ -20,29 +21,32 @@ model = AutoModel.from_pretrained(
 model = model.eval()
 print("✅ Model loaded successfully.")
 
-
-# --- 2. Main Processing Function ---
+# --- Helper function to find pre-generated result images ---
+def find_result_image(path):
+    for filename in os.listdir(path):
+        if "grounding" in filename or "result" in filename:
+            try:
+                image_path = os.path.join(path, filename)
+                return Image.open(image_path)
+            except Exception as e:
+                print(f"Error opening result image {filename}: {e}")
+    return None
+
+# --- 2. Main Processing Function (UPDATED) ---
 @spaces.GPU
 def process_ocr_task(image, model_size, task_type, ref_text):
     """
     Processes an image with DeepSeek-OCR for all supported tasks.
-    Args:
-        image (PIL.Image): The input image.
-        model_size (str): The model size configuration.
-        task_type (str): The type of OCR task to perform.
-        ref_text (str): The reference text for the 'Locate' task.
     """
     if image is None:
         return "Please upload an image first.", None
 
-    # Move the model to GPU and use bfloat16 for better performance
     print("🚀 Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
     print("✅ Model is on GPU.")
 
-    # Create a temporary directory to store files
     with tempfile.TemporaryDirectory() as output_path:
-        # --- Build the prompt based on the selected task type ---
+        # Build the prompt... (same as before)
         if task_type == "📝 Free OCR":
             prompt = "<image>\nFree OCR."
         elif task_type == "📄 Convert to Markdown":
@@ -52,16 +56,14 @@ def process_ocr_task(image, model_size, task_type, ref_text):
         elif task_type == "🔍 Locate Object by Reference":
             if not ref_text or ref_text.strip() == "":
                 raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
-            # Use an f-string to embed the user's reference text into the prompt
             prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
         else:
-            prompt = "<image>\nFree OCR." # Default fallback
+            prompt = "<image>\nFree OCR."
 
-        # Save the uploaded image to the temporary path
         temp_image_path = os.path.join(output_path, "temp_image.png")
         image.save(temp_image_path)
 
-        # Configure model size parameters
+        # Configure model size... (same as before)
         size_configs = {
             "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
             "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
@@ -72,7 +74,6 @@ def process_ocr_task(image, model_size, task_type, ref_text):
         config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
 
         print(f"🏃 Running inference with prompt: {prompt}")
-        # --- Run the model's inference method ---
         text_result = model_gpu.infer(
             tokenizer,
             prompt=prompt,
@@ -81,26 +82,51 @@ def process_ocr_task(image, model_size, task_type, ref_text):
             base_size=config["base_size"],
             image_size=config["image_size"],
             crop_mode=config["crop_mode"],
-            save_results=True, # Important: Must be True to get the output image
+            save_results=True,
             test_compress=True,
             eval_mode=True,
         )
 
         print(f"====\n📄 Text Result: {text_result}\n====")
 
-        # --- Handle the output (both text and image) ---
-        image_result_path = None
-        # Tasks that generate a visual output usually create a 'grounding' or 'result' image
-        if task_type in ["🔍 Locate Object by Reference", "📄 Convert to Markdown", "📈 Parse Figure"]:
-            # Find the result image in the output directory
-            for filename in os.listdir(output_path):
-                if "grounding" in filename or "result" in filename:
-                    image_result_path = os.path.join(output_path, filename)
-                    break
-
-        # If an image was found, open it with PIL; otherwise, return None
-        result_image_pil = Image.open(image_result_path) if image_result_path else None
+        # --- NEW: Handle the output with custom bounding box drawing ---
+        result_image_pil = None
 
+        if task_type == "🔍 Locate Object by Reference":
+            # Define the pattern to find coordinates like [[280, 15, 696, 997]]
+            pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
+            match = pattern.search(text_result)
+
+            if match:
+                print("✅ Found bounding box coordinates. Drawing on the original image.")
+                # Extract coordinates as integers
+                coords_norm = [int(c) for c in match.groups()]
+                x1_norm, y1_norm, x2_norm, y2_norm = coords_norm
+
+                # Get the original image's dimensions
+                w, h = image.size
+
+                # Scale the normalized coordinates (from 1000x1000 space) to the image's actual size
+                x1 = int(x1_norm / 1000 * w)
+                y1 = int(y1_norm / 1000 * h)
+                x2 = int(x2_norm / 1000 * w)
+                y2 = int(y2_norm / 1000 * h)
+
+                # Create a copy of the original image to draw on
+                image_with_bbox = image.copy()
+                draw = ImageDraw.Draw(image_with_bbox)
+
+                # Draw the rectangle with a red outline, 3 pixels wide
+                draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
+
+                result_image_pil = image_with_bbox
+            else:
+                print("⚠️ Could not parse bbox from text. Falling back to searching for a result image.")
+                result_image_pil = find_result_image(output_path)
+        else:
+            # For other tasks, use the old method of finding the generated image
+            result_image_pil = find_result_image(output_path)
+
         return text_result, result_image_pil
 
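Note on the new coordinate handling: the commit's comments indicate that a "Locate" response wraps box coordinates like [[280, 15, 696, 997]] in <|det|>...<|/det|> tags on a 1000x1000 normalized grid, which is why the code rescales them by the real image size. Below is a minimal, self-contained sketch of that parse-and-scale step; the sample_output string is a made-up illustration, not real model output.

import re

# Hypothetical model output for a "Locate" prompt (illustrative only).
sample_output = "<|ref|>the title<|/ref|><|det|>[[280, 15, 696, 997]]<|/det|>"

# The same pattern the commit compiles inside process_ocr_task.
pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
match = pattern.search(sample_output)

w, h = 800, 600  # stand-in for the uploaded image's dimensions
x1n, y1n, x2n, y2n = (int(c) for c in match.groups())

# Normalized 1000-grid -> pixels: e.g. x1 = 280/1000 * 800 = 224.
x1, y1 = int(x1n / 1000 * w), int(y1n / 1000 * h)
x2, y2 = int(x2n / 1000 * w), int(y2n / 1000 * h)
print((x1, y1, x2, y2))  # (224, 9, 556, 598)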
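The committed code draws only the first <|det|> match, since it uses pattern.search. If a response ever contained several boxes, a variant built on re.finditer could draw them all. This is a sketch under that assumption, not part of the commit:

import re
from PIL import Image, ImageDraw

DET_PATTERN = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")

def draw_all_boxes(image: Image.Image, text_result: str) -> Image.Image:
    """Draw every <|det|> box found in text_result onto a copy of image."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    w, h = image.size
    for m in DET_PATTERN.finditer(text_result):
        x1n, y1n, x2n, y2n = (int(c) for c in m.groups())
        box = (int(x1n / 1000 * w), int(y1n / 1000 * h),
               int(x2n / 1000 * w), int(y2n / 1000 * h))
        draw.rectangle(box, outline="red", width=3)
    return annotated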
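The diff stops at process_ocr_task, so the Gradio UI that supplies (image, model_size, task_type, ref_text) and displays the returned (text, image) pair is not shown. A minimal, hypothetical wiring consistent with those inputs and outputs might look like the following; the component labels, and any choices beyond the sizes and tasks visible in the diff, are assumptions rather than the app's actual layout:

import gradio as gr

# Hypothetical UI wiring -- the real layout in app.py is not part of this diff.
demo = gr.Interface(
    fn=process_ocr_task,  # assumes the function above is in scope
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=["Tiny", "Small", "Gundam (Recommended)"],  # sizes visible in the diff
            value="Gundam (Recommended)",
            label="Model Size",
        ),
        gr.Radio(
            choices=[
                "📝 Free OCR",
                "📄 Convert to Markdown",
                "🔍 Locate Object by Reference",
            ],
            label="Task Type",
        ),
        gr.Textbox(label="Reference Text (for the 'Locate' task)"),
    ],
    outputs=[gr.Textbox(label="Text Result"), gr.Image(label="Image Result")],
)
demo.launch()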