davanstrien HF Staff commited on
Commit
b7b3c0d
·
1 Parent(s): 0a5527f

Switch to batch processing pattern from official run_dpsk_ocr_eval_batch.py

Browse files

- Use LLM class instead of AsyncLLMEngine (fixes segfault)
- Use ThreadPoolExecutor for parallel image preprocessing
- Single llm.generate() call for true batch processing
- Added max_num_seqs and num_workers parameters
- Mirrors official DeepSeek batch processing script

Files changed (1) hide show
  1. process_dataset.py +59 -58
process_dataset.py CHANGED
@@ -1,16 +1,15 @@
1
  #!/usr/bin/env python3
2
  """
3
  DeepSeek-OCR Dataset Processing
4
- Minimal adaptation of official run_dpsk_ocr_image.py for dataset processing
5
  """
6
 
7
  import argparse
8
- import asyncio
9
  import json
10
  import os
11
  import sys
12
- import time
13
  from datetime import datetime
 
14
 
15
  import torch
16
  if torch.version.cuda == '11.8':
@@ -18,13 +17,12 @@ if torch.version.cuda == '11.8':
18
 
19
  os.environ['VLLM_USE_V1'] = '0'
20
 
21
- from vllm import AsyncLLMEngine, SamplingParams
22
- from vllm.engine.arg_utils import AsyncEngineArgs
23
  from vllm.model_executor.models.registry import ModelRegistry
24
  from PIL import Image, ImageOps
25
  from tqdm.auto import tqdm
26
  from datasets import load_dataset
27
- from huggingface_hub import DatasetCard, login
28
 
29
  # Import DeepSeek-OCR modules (unchanged from original)
30
  from deepseek_ocr import DeepseekOCRForCausalLM
@@ -44,27 +42,19 @@ def check_cuda():
44
  print(f"Using GPU: {torch.cuda.get_device_name(0)}")
45
 
46
 
47
- async def process_single_image(engine, sampling_params, image_features, prompt):
48
- """Process a single image through the engine (unchanged from original)"""
49
- request_id = f"request-{int(time.time() * 1000000)}"
50
-
51
- if image_features and '<image>' in prompt:
52
- request = {
53
- "prompt": prompt,
54
- "multi_modal_data": {"image": image_features}
55
- }
56
- else:
57
- request = {"prompt": prompt}
58
-
59
- final_output = ""
60
- async for request_output in engine.generate(request, sampling_params, request_id):
61
- if request_output.outputs:
62
- final_output = request_output.outputs[0].text
63
-
64
- return final_output.strip()
65
 
66
 
67
- async def main_async(args):
68
  """Main processing function"""
69
  check_cuda()
70
 
@@ -95,23 +85,24 @@ async def main_async(args):
95
  dataset = dataset.select(range(min(args.max_samples, len(dataset))))
96
  print(f"Processing {len(dataset)} samples")
97
 
98
- # Initialize vLLM engine (UNCHANGED from original)
99
  print("Initializing vLLM engine...")
100
- engine_args = AsyncEngineArgs(
101
  model=MODEL_PATH,
102
  hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
103
  block_size=256,
104
- max_model_len=args.max_model_len,
105
  enforce_eager=False,
106
  trust_remote_code=True,
 
 
 
107
  tensor_parallel_size=1,
108
  gpu_memory_utilization=args.gpu_memory_utilization,
109
  )
110
- engine = AsyncLLMEngine.from_engine_args(engine_args)
111
 
112
- # Sampling params (UNCHANGED from original)
113
  logits_processors = [NoRepeatNGramLogitsProcessor(
114
- ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
115
  )]
116
 
117
  sampling_params = SamplingParams(
@@ -121,37 +112,44 @@ async def main_async(args):
121
  skip_special_tokens=False,
122
  )
123
 
124
- # Process images
125
- print(f"Processing {len(dataset)} images...")
126
- all_markdown = []
127
- processor = DeepseekOCRProcessor()
128
-
129
- for idx in tqdm(range(len(dataset)), desc="OCR processing"):
130
  try:
131
- # Load image
132
  image = dataset[idx][args.image_column]
133
  if not isinstance(image, Image.Image):
134
  image = Image.open(image) if isinstance(image, str) else image
135
-
136
  image = ImageOps.exif_transpose(image.convert('RGB'))
137
-
138
- # Preprocess image (UNCHANGED from original)
139
- if '<image>' in PROMPT:
140
- image_features = processor.tokenize_with_images(
141
- images=[image], bos=True, eos=True, cropping=CROP_MODE
142
- )
143
- else:
144
- image_features = ''
145
-
146
- # Process
147
- result = await process_single_image(
148
- engine, sampling_params, image_features, PROMPT
149
- )
150
- all_markdown.append(result)
151
-
152
  except Exception as e:
153
- print(f"Error processing image {idx}: {e}")
154
- all_markdown.append("[OCR FAILED]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # Add markdown column
157
  print("Adding markdown column...")
@@ -177,8 +175,9 @@ async def main_async(args):
177
  "max_tokens": args.max_tokens,
178
  "max_model_len": args.max_model_len,
179
  "gpu_memory_utilization": args.gpu_memory_utilization,
 
180
  "script": "process_dataset.py",
181
- "implementation": "vllm-async (official deepseek code)",
182
  }
183
  existing_info.append(new_info)
184
 
@@ -207,8 +206,10 @@ if __name__ == "__main__":
207
  parser.add_argument("--max-model-len", type=int, default=8192)
208
  parser.add_argument("--max-tokens", type=int, default=8192)
209
  parser.add_argument("--gpu-memory-utilization", type=float, default=0.75)
 
 
210
  parser.add_argument("--hf-token", help="HF API token")
211
  parser.add_argument("--private", action="store_true", help="Make output private")
212
 
213
  args = parser.parse_args()
214
- asyncio.run(main_async(args))
 
1
  #!/usr/bin/env python3
2
  """
3
  DeepSeek-OCR Dataset Processing
4
+ Minimal adaptation of official run_dpsk_ocr_eval_batch.py for dataset processing
5
  """
6
 
7
  import argparse
 
8
  import json
9
  import os
10
  import sys
 
11
  from datetime import datetime
12
+ from concurrent.futures import ThreadPoolExecutor
13
 
14
  import torch
15
  if torch.version.cuda == '11.8':
 
17
 
18
  os.environ['VLLM_USE_V1'] = '0'
19
 
20
+ from vllm import LLM, SamplingParams
 
21
  from vllm.model_executor.models.registry import ModelRegistry
22
  from PIL import Image, ImageOps
23
  from tqdm.auto import tqdm
24
  from datasets import load_dataset
25
+ from huggingface_hub import login
26
 
27
  # Import DeepSeek-OCR modules (unchanged from original)
28
  from deepseek_ocr import DeepseekOCRForCausalLM
 
42
  print(f"Using GPU: {torch.cuda.get_device_name(0)}")
43
 
44
 
45
+ def process_single_image(image):
46
+ """Preprocess single image (unchanged from official batch script)"""
47
+ prompt_in = PROMPT
48
+ cache_item = {
49
+ "prompt": prompt_in,
50
+ "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(
51
+ images=[image], bos=True, eos=True, cropping=CROP_MODE
52
+ )},
53
+ }
54
+ return cache_item
 
 
 
 
 
 
 
 
55
 
56
 
57
+ def main(args):
58
  """Main processing function"""
59
  check_cuda()
60
 
 
85
  dataset = dataset.select(range(min(args.max_samples, len(dataset))))
86
  print(f"Processing {len(dataset)} samples")
87
 
88
+ # Initialize vLLM engine (UNCHANGED from official batch script)
89
  print("Initializing vLLM engine...")
90
+ llm = LLM(
91
  model=MODEL_PATH,
92
  hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
93
  block_size=256,
 
94
  enforce_eager=False,
95
  trust_remote_code=True,
96
+ max_model_len=args.max_model_len,
97
+ swap_space=0,
98
+ max_num_seqs=args.max_num_seqs,
99
  tensor_parallel_size=1,
100
  gpu_memory_utilization=args.gpu_memory_utilization,
101
  )
 
102
 
103
+ # Sampling params (UNCHANGED from official batch script)
104
  logits_processors = [NoRepeatNGramLogitsProcessor(
105
+ ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822}
106
  )]
107
 
108
  sampling_params = SamplingParams(
 
112
  skip_special_tokens=False,
113
  )
114
 
115
+ # Load and preprocess images
116
+ print(f"Loading images from dataset...")
117
+ images = []
118
+ for idx in range(len(dataset)):
 
 
119
  try:
 
120
  image = dataset[idx][args.image_column]
121
  if not isinstance(image, Image.Image):
122
  image = Image.open(image) if isinstance(image, str) else image
 
123
  image = ImageOps.exif_transpose(image.convert('RGB'))
124
+ images.append(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
+ print(f"Error loading image {idx}: {e}")
127
+ images.append(None)
128
+
129
+ # Preprocess images in parallel (UNCHANGED from official batch script)
130
+ print(f"Preprocessing images...")
131
+ with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
132
+ batch_inputs = list(tqdm(
133
+ executor.map(lambda img: process_single_image(img) if img else None, images),
134
+ total=len(images),
135
+ desc="Pre-processing images"
136
+ ))
137
+
138
+ # Filter out None entries and track their indices
139
+ valid_indices = [i for i, inp in enumerate(batch_inputs) if inp is not None]
140
+ valid_batch_inputs = [inp for inp in batch_inputs if inp is not None]
141
+
142
+ # Batch inference (UNCHANGED from official batch script)
143
+ print(f"Running batch inference on {len(valid_batch_inputs)} images...")
144
+ outputs_list = llm.generate(
145
+ valid_batch_inputs,
146
+ sampling_params=sampling_params
147
+ )
148
+
149
+ # Extract results
150
+ all_markdown = ["[OCR FAILED]"] * len(dataset)
151
+ for idx, output in zip(valid_indices, outputs_list):
152
+ all_markdown[idx] = output.outputs[0].text.strip()
153
 
154
  # Add markdown column
155
  print("Adding markdown column...")
 
175
  "max_tokens": args.max_tokens,
176
  "max_model_len": args.max_model_len,
177
  "gpu_memory_utilization": args.gpu_memory_utilization,
178
+ "max_num_seqs": args.max_num_seqs,
179
  "script": "process_dataset.py",
180
+ "implementation": "vllm-batch (official deepseek batch code)",
181
  }
182
  existing_info.append(new_info)
183
 
 
206
  parser.add_argument("--max-model-len", type=int, default=8192)
207
  parser.add_argument("--max-tokens", type=int, default=8192)
208
  parser.add_argument("--gpu-memory-utilization", type=float, default=0.75)
209
+ parser.add_argument("--max-num-seqs", type=int, default=100, help="Max concurrent sequences")
210
+ parser.add_argument("--num-workers", type=int, default=64, help="Image preprocessing workers")
211
  parser.add_argument("--hf-token", help="HF API token")
212
  parser.add_argument("--private", action="store_true", help="Make output private")
213
 
214
  args = parser.parse_args()
215
+ main(args)