# KOSMOS-2.5 Document AI Demo -- a Hugging Face Space running on ZeroGPU
import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
import re

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Check if Flash Attention 2 is available
def is_flash_attention_available():
    try:
        import flash_attn
        return True
    except ImportError:
        return False

# Initialize models and processors lazily
base_model = None
base_processor = None
chat_model = None
chat_processor = None
def load_base_model():
    global base_model, base_processor
    if base_model is None:
        base_repo = "microsoft/kosmos-2.5"
        # Use Flash Attention 2 if available, otherwise use the default attention
        model_kwargs = {
            "device_map": device,
            "dtype": dtype,
        }
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            base_repo,
            **model_kwargs,
        )
        base_processor = AutoProcessor.from_pretrained(base_repo)
    return base_model, base_processor
def load_chat_model():
    global chat_model, chat_processor
    if chat_model is None:
        chat_repo = "microsoft/kosmos-2.5-chat"
        # Use Flash Attention 2 if available, otherwise use the default attention
        model_kwargs = {
            "device_map": device,
            "dtype": dtype,
        }
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            chat_repo,
            **model_kwargs,
        )
        chat_processor = AutoProcessor.from_pretrained(chat_repo)
    return chat_model, chat_processor
def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        if i < len(bboxs):
            box = bboxs[i]
            x0, y0, x1, y1 = box
            if not (x0 >= x1 or y0 >= y1):
                x0 = int(x0 * scale_width)
                y0 = int(y0 * scale_height)
                x1 = int(x1 * scale_width)
                y1 = int(y1 * scale_height)
                info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
    return info.strip()
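# Illustrative example of what post_process_ocr produces (made-up values, not real
# model output): given the decoded text
#     "<ocr><bbox><x_10><y_20><x_110><y_40></bbox>Total: $5.00"
# and scale_height = scale_width = 2.0, the function returns the single line
#     "20,40,220,40,220,80,20,80,Total: $5.00"
# i.e. the four corners of the box (clockwise from the top-left), then the text.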
@spaces.GPU  # run on ZeroGPU-allocated hardware for the duration of this call
def generate_markdown(image):
    if image is None:
        return "Please upload an image."
    model, processor = load_base_model()
    prompt = "<md>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    result = generated_text[0].replace(prompt, "").strip()
    return result
@spaces.GPU
def generate_ocr(image):
    if image is None:
        return "Please upload an image.", None
    model, processor = load_base_model()
    prompt = "<ocr>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Post-process OCR output into "x0,y0,x1,y0,x1,y1,x0,y1,text" lines
    output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
    # Create visualization: draw each detected text box on a copy of the image
    from PIL import ImageDraw
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)
    lines = output_text.split("\n")
    for line in lines:
        if not line.strip():
            continue
        parts = line.split(",")
        if len(parts) >= 8:
            try:
                coords = list(map(int, parts[:8]))
                draw.polygon(coords, outline="red", width=2)
            except ValueError:
                continue
    return output_text, vis_image
@spaces.GPU
def generate_chat_response(image, question):
    if image is None:
        return "Please upload an image."
    if not question.strip():
        return "Please ask a question."
    model, processor = load_chat_model()
    template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
    prompt = template.format(question)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep only the assistant's part of the decoded prompt + response
    result = generated_text[0]
    if "ASSISTANT:" in result:
        result = result.split("ASSISTANT:")[-1].strip()
    return result
# Create Gradio interface
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # KOSMOS-2.5 Document AI Demo

    Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!

    This demo showcases three capabilities:
    1. **Markdown Generation**: Convert document images to markdown format
    2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
    3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat

    Upload a document image (receipt, form, article, etc.) and try different tasks!
    """)
    with gr.Tabs():
        # Markdown Generation Tab
        with gr.TabItem("📝 Markdown Generation"):
            with gr.Row():
                with gr.Column():
                    md_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=md_image
                    )
                    md_button = gr.Button("Generate Markdown", variant="primary")
                with gr.Column():
                    md_output = gr.Textbox(
                        label="Generated Markdown",
                        lines=15,
                        max_lines=20,
                        show_copy_button=True
                    )
        # OCR Tab
        with gr.TabItem("🔍 OCR with Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    ocr_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=ocr_image
                    )
                    ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
                with gr.Column():
                    with gr.Row():
                        ocr_text = gr.Textbox(
                            label="Extracted Text with Coordinates",
                            lines=10,
                            show_copy_button=True
                        )
                    ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
        # Chat Tab
        with gr.TabItem("💬 Document Q&A (Chat)"):
            with gr.Row():
                with gr.Column():
                    chat_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=chat_image
                    )
                    chat_question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What is the total amount on this receipt?",
                        lines=2
                    )
                    gr.Examples(
                        examples=["What is the total amount on this receipt?", "What items were purchased?", "When was this receipt issued?", "What is the subtotal?"],
                        inputs=chat_question
                    )
                    chat_button = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    chat_output = gr.Textbox(
                        label="Answer",
                        lines=8,
                        show_copy_button=True
                    )
    # Event handlers
    md_button.click(
        fn=generate_markdown,
        inputs=[md_image],
        outputs=[md_output]
    )
    ocr_button.click(
        fn=generate_ocr,
        inputs=[ocr_image],
        outputs=[ocr_text, ocr_vis]
    )
    chat_button.click(
        fn=generate_chat_response,
        inputs=[chat_image, chat_question],
        outputs=[chat_output]
    )
    # Examples section
    gr.Markdown("""
    ## Example Use Cases:
    - **Receipts**: Extract itemized information or ask about totals
    - **Forms**: Convert to structured format or answer specific questions
    - **Articles**: Get markdown format or ask about content
    - **Screenshots**: Extract text or get information about specific elements

    ## Note:
    This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
    """)
if __name__ == "__main__":
    demo.launch()
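# Rough notes for running this demo outside of Spaces (assumed, not part of the
# original app): the imports above suggest a dependency set along the lines of
#     pip install torch transformers gradio pillow spaces
#     pip install flash-attn  # optional; only useful on a compatible CUDA GPU
# after which the script can be launched directly, e.g. `python app.py` (file name assumed).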