Hugging Face Spaces — status: Running on Zero (ZeroGPU hardware)
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| import spaces | |
| import os | |
| import tempfile | |
| from PIL import Image | |
# Load the DeepSeek-OCR tokenizer and model once at import time; every request
# reuses these module-level objects. The model stays on CPU here and is moved
# to the GPU inside the request handler. `trust_remote_code=True` lets the
# checkpoint's own modeling code load (that code presumably provides the
# non-standard `infer` method used by the handler — confirm against the repo).
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
# nn.Module.eval() returns the module itself, so it can be chained here.
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_safetensors=True,
    _attn_implementation="flash_attention_2",
).eval()
def _save_as_rgb_jpeg(image, path):
    """Write ``image`` to ``path`` as an RGB JPEG.

    ``image`` may be a PIL Image or a filesystem path to an image file.
    JPEG cannot store alpha or palette modes (e.g. an RGBA PNG), so the
    pixels are converted to RGB before saving instead of letting
    ``Image.save`` raise.
    """
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            opened.convert("RGB").save(path, format="JPEG")
    else:
        image.convert("RGB").save(path, format="JPEG")


# On ZeroGPU Spaces, CUDA is only usable inside functions decorated with
# `spaces.GPU`; without it the `.cuda()` call below cannot get a device.
@spaces.GPU
def process_image(image, model_size, task_type, is_eval_mode=False):
    """
    Process an image with DeepSeek-OCR and return the three UI outputs.

    Args:
        image: PIL Image, or a filesystem path to an image file. ``None``
            returns an "upload an image" message in both text outputs.
        model_size: Name of a size preset ("Tiny", "Small", "Base", "Large",
            "Gundam (Recommended)"). Unknown names fall back to the
            "Gundam (Recommended)" preset.
        task_type: "Free OCR" or "Convert to Markdown". Any other value is
            treated as "Free OCR".
        is_eval_mode: Forwarded to ``model.infer`` as ``eval_mode``. A falsy
            value such as ``None`` is normalized to ``False``.

    Returns:
        A tuple ``(result_image, markdown_content, text_result)``:
        - ``result_image``: the annotated ``result_with_boxes.jpg`` written by
          ``infer``, loaded as an in-memory PIL Image, or ``None`` if it was
          not written.
        - ``markdown_content``: contents of ``result.mmd`` if written,
          otherwise an explanatory message.
        - ``text_result``: the value returned by ``infer`` if truthy,
          otherwise ``markdown_content``.
    """
    if image is None:
        return None, "Please upload an image first.", "Please upload an image first."

    # Prompt per task; unknown task types fall back to plain Free OCR.
    prompts = {
        "Free OCR": "<image>\nFree OCR. ",
        "Convert to Markdown": "<image>\n<|grounding|>Convert the document to markdown. ",
    }
    prompt = prompts.get(task_type, prompts["Free OCR"])

    # Resolution presets passed through to `infer`.
    size_configs = {
        "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
        "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
        "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
        "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
        "Gundam (Recommended)": {
            "base_size": 1024,
            "image_size": 640,
            "crop_mode": True,
        },
    }
    config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

    # nn.Module.cuda() / .to() modify the module in place, so this also moves
    # the global `model`; `model_gpu` is the same object.
    model_gpu = model.cuda().to(torch.bfloat16)

    # All intermediate and result files live in a per-request temp directory
    # that is deleted when the `with` block exits.
    with tempfile.TemporaryDirectory() as output_path:
        temp_image_path = os.path.join(output_path, "temp_image.jpg")
        _save_as_rgb_jpeg(image, temp_image_path)

        plain_text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,  # ask infer to write its result files to output_path
            test_compress=True,
            eval_mode=bool(is_eval_mode),
        )

        image_result_path = os.path.join(output_path, "result_with_boxes.jpg")
        markdown_result_path = os.path.join(output_path, "result.mmd")

        if os.path.exists(markdown_result_path):
            with open(markdown_result_path, "r", encoding="utf-8") as f:
                markdown_content = f.read()
        else:
            markdown_content = "Markdown result was not generated. This is expected for 'Free OCR' task."

        result_image = None
        if os.path.exists(image_result_path):
            # Copy pixels into memory and close the file before the temporary
            # directory (and this file) is removed.
            with Image.open(image_result_path) as annotated:
                result_image = annotated.copy()

    text_result = plain_text_result or markdown_content
    return result_image, markdown_content, text_result
# Build the Gradio UI. Inputs are in the left column; the three outputs
# returned by `process_image` are shown in tabs on the right.
with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # DeepSeek-OCR Demo
        Upload an image to extract text using DeepSeek-OCR model.
        Supports various document types and handwriting recognition.
        **Model Sizes:**
        - **Tiny**: Fastest, lower accuracy (512x512)
        - **Small**: Fast, good accuracy (640x640)
        - **Base**: Balanced performance (1024x1024)
        - **Large**: Best accuracy, slower (1280x1280)
        - **Gundam (Recommended)**: Optimized for documents (1024 base, 640 image, crop mode)
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil", label="Upload Image", sources=["upload", "clipboard"]
            )
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)",
                label="Model Size",
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown",
                label="Task Type",
            )
            eval_mode_checkbox = gr.Checkbox(
                value=False,
                label="Enable Evaluation Mode",
                info="Returns only plain text, but might be faster. Uncheck to get annotated image and markdown.",
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Annotated Image"):
                    output_image = gr.Image(
                        interactive=False
                    )
                with gr.TabItem("Markdown Preview"):
                    output_markdown = gr.Markdown()
                with gr.TabItem("Markdown Source(or Eval Output)"):
                    output_text = gr.Textbox(
                        lines=20,
                        show_copy_button=True,
                        interactive=False,
                    )
    # Examples: each row supplies one value per component in `inputs`,
    # including the eval-mode checkbox (off, so the cached examples include
    # the annotated image and markdown outputs).
    gr.Examples(
        examples=[
            ["examples/math.png", "Gundam (Recommended)", "Convert to Markdown", False],
            ["examples/receipt.jpg", "Base", "Convert to Markdown", False],
            ["examples/receipt-2.png", "Base", "Convert to Markdown", False],
        ],
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
        fn=process_image,
        cache_examples=True,
    )
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
    )
# Entry point: enable request queueing (at most 20 waiting requests), then
# start the server. Blocks.queue() returns the Blocks object, so the calls chain.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()