Spaces:
Running
Running
fix: tiktoken依存関係追加とAPI構造修正
Browse files- requirements.txtにtiktoken>=0.5.0を追加
- BitsAndBytesConfigで現代的な量子化設定に更新
- GradioのAPI構造を修正してocr_apiエンドポイントを適切に公開
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
- app.py +24 -8
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -11,7 +11,7 @@ import base64
|
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
from PIL import Image
|
| 14 |
-
from transformers import AutoModel, AutoTokenizer
|
| 15 |
import logging
|
| 16 |
|
| 17 |
# ロギング設定
|
|
@@ -33,6 +33,12 @@ def load_model():
|
|
| 33 |
try:
|
| 34 |
logger.info("dots.ocr (GOT-OCR2_0) モデルを読み込み中...")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# モデルとトークナイザーを読み込み(最大メモリ効率化)
|
| 37 |
model = AutoModel.from_pretrained(
|
| 38 |
'ucaslcl/GOT-OCR2_0',
|
|
@@ -41,7 +47,7 @@ def load_model():
|
|
| 41 |
device_map='auto',
|
| 42 |
use_safetensors=True,
|
| 43 |
torch_dtype=torch.float16, # メモリ使用量を半減
|
| 44 |
-
|
| 45 |
pad_token_id=151643
|
| 46 |
).eval()
|
| 47 |
|
|
@@ -226,13 +232,23 @@ with gr.Blocks(
|
|
| 226 |
outputs=[text_output, metadata_output, json_output]
|
| 227 |
)
|
| 228 |
|
| 229 |
-
# API
|
| 230 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
fn=api_interface,
|
| 232 |
-
inputs=
|
| 233 |
-
outputs=
|
| 234 |
-
title="API Endpoint",
|
| 235 |
-
description="このエンドポイントはプログラムからの呼び出し用です",
|
| 236 |
api_name="ocr_api"
|
| 237 |
)
|
| 238 |
|
|
|
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
from PIL import Image
|
| 14 |
+
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
|
| 15 |
import logging
|
| 16 |
|
| 17 |
# ロギング設定
|
|
|
|
| 33 |
try:
|
| 34 |
logger.info("dots.ocr (GOT-OCR2_0) モデルを読み込み中...")
|
| 35 |
|
| 36 |
+
# 8bit量子化設定
|
| 37 |
+
quantization_config = BitsAndBytesConfig(
|
| 38 |
+
load_in_8bit=True,
|
| 39 |
+
bnb_8bit_compute_dtype=torch.float16
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
# モデルとトークナイザーを読み込み(最大メモリ効率化)
|
| 43 |
model = AutoModel.from_pretrained(
|
| 44 |
'ucaslcl/GOT-OCR2_0',
|
|
|
|
| 47 |
device_map='auto',
|
| 48 |
use_safetensors=True,
|
| 49 |
torch_dtype=torch.float16, # メモリ使用量を半減
|
| 50 |
+
quantization_config=quantization_config, # 現代的な量子化設定
|
| 51 |
pad_token_id=151643
|
| 52 |
).eval()
|
| 53 |
|
|
|
|
| 232 |
outputs=[text_output, metadata_output, json_output]
|
| 233 |
)
|
| 234 |
|
| 235 |
+
# API用のシンプルなエンドポイント(独立したInterface)
|
| 236 |
+
with gr.Row():
|
| 237 |
+
gr.Markdown("# API Endpoint")
|
| 238 |
+
|
| 239 |
+
with gr.Row():
|
| 240 |
+
gr.Markdown("このエンドポイントはプログラムからの呼び出し用です")
|
| 241 |
+
|
| 242 |
+
# API専用のInterface
|
| 243 |
+
api_image = gr.Image(type="pil", label="image")
|
| 244 |
+
api_submit = gr.Button("Submit")
|
| 245 |
+
api_output = gr.JSON(label="output")
|
| 246 |
+
|
| 247 |
+
# API用の関数
|
| 248 |
+
api_submit.click(
|
| 249 |
fn=api_interface,
|
| 250 |
+
inputs=[api_image],
|
| 251 |
+
outputs=[api_output],
|
|
|
|
|
|
|
| 252 |
api_name="ocr_api"
|
| 253 |
)
|
| 254 |
|
requirements.txt
CHANGED
|
@@ -10,4 +10,5 @@ bitsandbytes>=0.41.0
|
|
| 10 |
scipy>=1.10.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
huggingface-hub>=0.17.0
|
| 13 |
-
verovio>=4.0.0
|
|
|
|
|
|
| 10 |
scipy>=1.10.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
huggingface-hub>=0.17.0
|
| 13 |
+
verovio>=4.0.0
|
| 14 |
+
tiktoken>=0.5.0
|