Spaces:
Running
Running
| """ | |
| HuggingFace Space for dots.ocr (GOT-OCR2_0) | |
| 高精度OCRモデルをAPIとして提供 | |
| """ | |
| import gradio as gr | |
| import torch | |
| import os | |
| import io | |
| import base64 | |
| import json | |
| import time | |
| from PIL import Image | |
| from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig | |
| import logging | |
| # ロギング設定 | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # GPU使用可能性チェック | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| logger.info(f"使用デバイス: {device}") | |
| # グローバル変数 | |
| model = None | |
| tokenizer = None | |
| def load_model(): | |
| """dots.ocrモデルを読み込み""" | |
| global model, tokenizer | |
| try: | |
| logger.info("dots.ocr (GOT-OCR2_0) モデルを読み込み中...") | |
| # 8bit量子化設定 | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_8bit=True, | |
| bnb_8bit_compute_dtype=torch.float16 | |
| ) | |
| # モデルとトークナイザーを読み込み(最大メモリ効率化) | |
| model = AutoModel.from_pretrained( | |
| 'ucaslcl/GOT-OCR2_0', | |
| trust_remote_code=True, | |
| low_cpu_mem_usage=True, | |
| device_map='auto', | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, # メモリ使用量を半減 | |
| quantization_config=quantization_config, # 現代的な量子化設定 | |
| pad_token_id=151643 | |
| ).eval() | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| 'ucaslcl/GOT-OCR2_0', | |
| trust_remote_code=True | |
| ) | |
| logger.info("モデル読み込み完了") | |
| return True | |
| except Exception as e: | |
| logger.error(f"モデル読み込みエラー: {e}") | |
| return False | |
| def process_image(image, ocr_type="ocr", ocr_box="", ocr_color=""): | |
| """ | |
| 画像をOCR処理 | |
| Args: | |
| image: PIL Image または画像パス | |
| ocr_type: OCRタイプ("ocr", "format", "fine-grained") | |
| ocr_box: OCRボックス座標(オプション) | |
| ocr_color: OCR色指定(オプション) | |
| Returns: | |
| dict: OCR結果 | |
| """ | |
| global model, tokenizer | |
| start_time = time.time() | |
| try: | |
| # モデル未読み込みの場合は読み込み | |
| if model is None or tokenizer is None: | |
| if not load_model(): | |
| raise Exception("モデルの読み込みに失敗しました") | |
| # 画像処理 | |
| if isinstance(image, str): | |
| # Base64文字列の場合 | |
| if image.startswith('data:image'): | |
| image = image.split(',')[1] | |
| image_data = base64.b64decode(image) | |
| image = Image.open(io.BytesIO(image_data)) | |
| elif not isinstance(image, Image.Image): | |
| # その他の形式の場合はPIL Imageに変換 | |
| image = Image.open(image) | |
| # PIL ImageをRGB形式に変換 | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| logger.info(f"画像サイズ: {image.size}") | |
| # OCR処理実行 | |
| with torch.no_grad(): | |
| result = model.chat( | |
| tokenizer, | |
| image, | |
| ocr_type=ocr_type, | |
| ocr_box=ocr_box, | |
| ocr_color=ocr_color | |
| ) | |
| processing_time = time.time() - start_time | |
| logger.info(f"OCR処理完了: {processing_time:.2f}秒, 結果長: {len(result)}文字") | |
| return { | |
| "text": result, | |
| "confidence": 0.95, # dots.ocrは高精度なので固定値 | |
| "processing_time": processing_time, | |
| "model_used": "ucaslcl/GOT-OCR2_0", | |
| "device": str(device), | |
| "image_size": list(image.size) | |
| } | |
| except Exception as e: | |
| logger.error(f"OCR処理エラー: {e}") | |
| processing_time = time.time() - start_time | |
| return { | |
| "text": f"[エラー] OCR処理でエラーが発生しました: {str(e)}", | |
| "confidence": 0.0, | |
| "processing_time": processing_time, | |
| "model_used": "error", | |
| "device": str(device), | |
| "error": str(e) | |
| } | |
| def gradio_interface(image, ocr_type="ocr"): | |
| """Gradio用のインターフェース関数""" | |
| result = process_image(image, ocr_type=ocr_type) | |
| # 結果を整形して返す | |
| output_text = result["text"] | |
| # メタデータ情報を追加 | |
| metadata = f""" | |
| 処理時間: {result['processing_time']:.2f}秒 | |
| 信頼度: {result['confidence']:.1%} | |
| 使用モデル: {result['model_used']} | |
| デバイス: {result['device']} | |
| """ | |
| if 'image_size' in result: | |
| metadata += f"画像サイズ: {result['image_size'][0]}x{result['image_size'][1]}" | |
| return output_text, metadata, json.dumps(result, ensure_ascii=False, indent=2) | |
| def api_interface(image): | |
| """API用のインターフェース関数(JSON返却)""" | |
| result = process_image(image) | |
| return result | |
| # Gradio インターフェース設定 | |
| with gr.Blocks( | |
| title="dots.ocr (GOT-OCR2_0) - 高精度OCR API", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .gradio-container { | |
| max-width: 1200px !important; | |
| } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| # 🔍 dots.ocr (GOT-OCR2_0) - 高精度OCR API | |
| 最先端の視覚言語モデルによる高精度OCR処理 | |
| - **多言語対応**: 日本語、英語、中国語など80以上の言語 | |
| - **レイアウト検出**: テキスト、テーブル、図表の構造認識 | |
| - **高精度**: 95%以上の認識精度 | |
| ## 使用方法 | |
| 1. 画像をアップロード | |
| 2. OCRタイプを選択 | |
| 3. 「処理開始」ボタンをクリック | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| # 入力部分 | |
| image_input = gr.Image( | |
| type="pil", | |
| label="📷 画像をアップロード", | |
| height=400 | |
| ) | |
| ocr_type = gr.Dropdown( | |
| choices=["ocr", "format", "fine-grained"], | |
| value="ocr", | |
| label="🔧 OCRタイプ", | |
| info="ocr: 基本OCR, format: フォーマット保持, fine-grained: 詳細解析" | |
| ) | |
| process_btn = gr.Button("🚀 処理開始", variant="primary") | |
| with gr.Column(scale=2): | |
| # 出力部分 | |
| with gr.Tab("📄 テキスト結果"): | |
| text_output = gr.Textbox( | |
| label="抽出されたテキスト", | |
| lines=15, | |
| placeholder="ここに抽出されたテキストが表示されます..." | |
| ) | |
| with gr.Tab("📊 処理情報"): | |
| metadata_output = gr.Textbox( | |
| label="処理メタデータ", | |
| lines=8, | |
| placeholder="処理時間、信頼度などの情報が表示されます..." | |
| ) | |
| with gr.Tab("🔧 JSON結果"): | |
| json_output = gr.Code( | |
| label="完全なJSON結果", | |
| language="json" | |
| ) | |
| # 処理ボタンのイベント設定 | |
| process_btn.click( | |
| fn=gradio_interface, | |
| inputs=[image_input, ocr_type], | |
| outputs=[text_output, metadata_output, json_output] | |
| ) | |
| # API用のシンプルなエンドポイント(独立したInterface) | |
| with gr.Row(): | |
| gr.Markdown("# API Endpoint") | |
| with gr.Row(): | |
| gr.Markdown("このエンドポイントはプログラムからの呼び出し用です") | |
| # API専用のInterface | |
| api_image = gr.Image(type="pil", label="image") | |
| api_submit = gr.Button("Submit") | |
| api_output = gr.JSON(label="output") | |
| # API用の関数 | |
| api_submit.click( | |
| fn=api_interface, | |
| inputs=[api_image], | |
| outputs=[api_output], | |
| api_name="predict" | |
| ) | |
| # アプリケーション起動時にモデルを読み込み | |
| if __name__ == "__main__": | |
| logger.info("アプリケーション起動中...") | |
| # 環境情報表示 | |
| logger.info(f"PyTorch version: {torch.__version__}") | |
| logger.info(f"CUDA available: {torch.cuda.is_available()}") | |
| if torch.cuda.is_available(): | |
| logger.info(f"CUDA version: {torch.version.cuda}") | |
| logger.info(f"GPU count: {torch.cuda.device_count()}") | |
| for i in range(torch.cuda.device_count()): | |
| logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}") | |
| # モデル事前読み込み | |
| load_model() | |
| # Gradioアプリ起動 | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| show_api=True, | |
| show_error=True # エラー詳細表示を有効化 | |
| ) |