# Spaces: Running on Zero
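On ZeroGPU hardware, a Space does not hold a GPU permanently: a device from the shared pool is attached only while a function decorated with `spaces.GPU` is executing, and it is released afterwards. A minimal sketch of the pattern (the function name is illustrative):

```python
import spaces
import torch

@spaces.GPU  # a GPU from the shared pool is attached only for this call
def run_on_gpu():
    # CUDA is available here, but not at module level on a ZeroGPU Space
    return torch.zeros(1).to("cuda").device
```

The full demo below applies this pattern to DeepSeek-OCR: the model is loaded once on CPU at startup, and each request moves it to the GPU inside the decorated handler.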
```python
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import spaces
import os
import tempfile
from PIL import Image

# --- 1. Load Model and Tokenizer (done only once at startup) ---
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the model to CPU first; it will be moved to GPU during processing
model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.eval()
print("Model loaded successfully.")
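# Note: on a ZeroGPU Space, CUDA is not available in the main process at
# startup, so the model is kept on CPU here and only moved to the GPU inside
# the @spaces.GPU-decorated handler below.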

# --- 2. Main Processing Function ---
@spaces.GPU  # on ZeroGPU, this attaches a GPU for the duration of the call
def process_ocr_task(image, model_size, task_type, ref_text):
    """
    Processes an image with DeepSeek-OCR for all supported tasks.

    Args:
        image (PIL.Image): The input image.
        model_size (str): The model size configuration.
        task_type (str): The type of OCR task to perform.
        ref_text (str): The reference text for the 'Locate' task.
    """
    if image is None:
        return "Please upload an image first.", None

    # Move the model to GPU and use bfloat16 for better performance
    print("Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("Model is on GPU.")
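    # bfloat16 halves memory use versus float32 while keeping its dynamic
    # range; the flash_attention_2 kernels requested above also require a
    # half-precision dtype (fp16 or bf16).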

    # Create a temporary directory to store files
    with tempfile.TemporaryDirectory() as output_path:
        # --- Build the prompt based on the selected task type ---
        if task_type == "Free OCR":
            prompt = "<image>\nFree OCR."
        elif task_type == "Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown."
        elif task_type == "Parse Figure":
            prompt = "<image>\nParse the figure."
        elif task_type == "Locate Object by Reference":
            if not ref_text or ref_text.strip() == "":
                raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
            # Use an f-string to embed the user's reference text into the prompt
            prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
        else:
            prompt = "<image>\nFree OCR."  # Default fallback

        # Save the uploaded image to the temporary path
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        # Configure model size parameters
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
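        # Note (based on DeepSeek-OCR's published mode descriptions): with
        # crop_mode=True the page is tiled into image_size crops plus one
        # base_size global view, which is why "Gundam" pairs 1024 with 640.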
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

        print(f"Running inference with prompt: {prompt}")
        # --- Run the model's inference method ---
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,  # Important: must be True to get the output image
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\nText Result: {text_result}\n====")

        # --- Handle the output (both text and image) ---
        image_result_path = None
        # Tasks that generate a visual output usually create a 'grounding' or 'result' image
        if task_type in ["Locate Object by Reference", "Convert to Markdown", "Parse Figure"]:
            # Find the result image in the output directory
            for filename in os.listdir(output_path):
                if "grounding" in filename or "result" in filename:
                    image_result_path = os.path.join(output_path, filename)
                    break

        # If an image was found, copy it fully into memory: PIL opens files
        # lazily, and the temporary directory is deleted when the with-block exits
        result_image_pil = Image.open(image_result_path).copy() if image_result_path else None
        return text_result, result_image_pil

# --- 3. Build the Gradio Interface ---
with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Full Demo of DeepSeek-OCR

        Upload an image to explore the document recognition and understanding capabilities of DeepSeek-OCR.

        **How to use:**

        1. **Upload an image** using the upload box.
        2. Select a **Model Size**. `Gundam` is recommended for most documents, offering a good balance of speed and accuracy.
        3. Choose a **Task Type**:
           - **Free OCR**: extracts raw text from the image. Best for simple text extraction.
           - **Convert to Markdown**: converts the entire document into Markdown, preserving structure such as headers, lists, and tables.
           - **Parse Figure**: analyzes and extracts structured data from charts, graphs, and geometric figures.
           - **Locate Object by Reference**: finds a specific object or piece of text in the image. You **must** type what you're looking for into the **"Reference Text"** box that appears.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)",
                label="Model Size",
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown", "Parse Figure", "Locate Object by Reference"],
                value="Convert to Markdown",
                label="Task Type",
            )
            ref_text_input = gr.Textbox(
                label="Reference Text (for Locate task)",
                placeholder="e.g., the teacher, 11-2=, a red car...",
                visible=False,  # Initially hidden
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Text Result", lines=15, show_copy_button=True)
            output_image = gr.Image(label="Image Result (if any)", type="pil")

    # --- UI Interaction Logic ---
    def toggle_ref_text_visibility(task):
        # If the user selects the 'Locate' task, make the reference textbox visible
        if task == "Locate Object by Reference":
            return gr.Textbox(visible=True)
        else:
            return gr.Textbox(visible=False)

    # When the 'task_type' dropdown changes, update the textbox's visibility
    task_type.change(
        fn=toggle_ref_text_visibility,
        inputs=task_type,
        outputs=ref_text_input,
    )

    # Define what happens when the submit button is clicked
    submit_btn.click(
        fn=process_ocr_task,
        inputs=[image_input, model_size, task_type, ref_text_input],
        outputs=[output_text, output_image],
    )

    # --- Example Images and Tasks ---
    gr.Examples(
        examples=[
            ["examples/doc_markdown.png", "Gundam (Recommended)", "Convert to Markdown", ""],
            ["examples/chart.png", "Gundam (Recommended)", "Parse Figure", ""],
            ["examples/teacher.jpg", "Base", "Locate Object by Reference", "the teacher"],
            ["examples/math_locate.jpg", "Small", "Locate Object by Reference", "20-10"],
            ["examples/receipt.jpg", "Base", "Free OCR", ""],
        ],
        inputs=[image_input, model_size, task_type, ref_text_input],
        outputs=[output_text, output_image],
        fn=process_ocr_task,
        cache_examples=False,  # Disable caching so examples run fresh every time
    )

# --- 4. Launch the App ---
if __name__ == "__main__":
    # Create the 'examples' directory if it doesn't exist; the example images
    # (doc_markdown.png, chart.png, teacher.jpg, math_locate.jpg, receipt.jpg)
    # must be downloaded into it manually.
    os.makedirs("examples", exist_ok=True)
    demo.queue(max_size=20)
    demo.launch(share=True)  # share=True creates a public link for local runs
```
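
Once deployed, the Space can also be called programmatically. A sketch using the `gradio_client` package; the Space id is a placeholder, and `/process_ocr_task` is Gradio's default endpoint name derived from the handler's function name:

```python
from gradio_client import Client, handle_file

client = Client("your-username/deepseek-ocr-demo")  # placeholder Space id
text, image_path = client.predict(
    handle_file("receipt.jpg"),   # image to OCR
    "Gundam (Recommended)",       # model size
    "Free OCR",                   # task type
    "",                           # reference text (only used by the Locate task)
    api_name="/process_ocr_task",
)
print(text)
```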