# NOTE: The following lines are Hugging Face Space status-banner residue
# captured when this file was scraped ("Spaces: Sleeping Sleeping");
# they are not part of the program source.
| import os | |
| import json | |
| import base64 | |
| from PIL import Image | |
| from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
| from vllm import LLM | |
| from vllm.sampling_params import SamplingParams | |
# Hugging Face token from environment (optional); read here but not used
# elsewhere in this file — presumably picked up implicitly by the HF stack.
hf_token = os.getenv("HF_TOKEN")

# Disable Pillow's decompression-bomb limit so arbitrarily large
# diagram images can be opened without raising.
Image.MAX_IMAGE_PIXELS = None

# Global placeholders (lazy-loaded later): the vLLM engine and the
# Pix2Struct OCR model/processor are expensive, so they are created on
# first use instead of at import time.
llm = None
ocr_model = None
ocr_processor = None

# Generation budget for the VLM response (expected to contain JSON).
sampling_params = SamplingParams(max_tokens=5000)
def load_prompt():
    """Return the model prompt from the PROMPT_TEXT environment variable.

    The prompt is supplied as a Space secret. If it is missing, a visible
    warning string is returned instead of raising, so the problem surfaces
    in the model output rather than crashing the app.

    (Previously the prompt was read from prompts/prompt.txt; that dead,
    commented-out code has been removed.)
    """
    return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")
def try_extract_json(text):
    """Best-effort JSON extraction from model output.

    Tries to parse ``text`` directly; if that fails, parses the first JSON
    object embedded in the text (ignoring any prose before or after it).

    Returns the parsed object, or None when no valid JSON can be found.

    BUG FIX: the previous fallback counted '{'/'}' characters manually,
    which miscounted braces that appear inside JSON string values
    (e.g. '{"s": "}"}') and returned None for valid payloads.
    json.JSONDecoder.raw_decode parses properly and stops at the end of
    the first complete value, so trailing junk is still tolerated.
    """
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            return None
        try:
            obj, _end = json.JSONDecoder().raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            return None
def encode_image_as_base64(pil_image):
    """Serialize an image to a base64-encoded JPEG string.

    BUG FIX: JPEG cannot store alpha or palette data, so saving an
    RGBA/P/LA image raised OSError ("cannot write mode RGBA as JPEG").
    Such images are converted to RGB before encoding; RGB and grayscale
    images pass through unchanged, preserving prior behavior.
    """
    from io import BytesIO
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    buffer = BytesIO()
    pil_image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def extract_all_text_pix2struct(image: Image.Image):
    """Run Pix2Struct over the image and return the decoded text, stripped.

    The processor and model are lazy-loaded into module globals on first
    call and moved to GPU when available.

    NOTE(review): "pix2struct-textcaps-base" is an image-captioning
    checkpoint rather than a dedicated OCR model — confirm it extracts
    the labels this pipeline expects.
    """
    global ocr_processor, ocr_model
    if ocr_processor is None or ocr_model is None:
        # BUG FIX: torch was referenced below but never imported anywhere
        # in this file, causing a NameError on first call. Import it here,
        # at the single point of use.
        import torch
        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(device)
    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill placeholder names on unlabeled events and gateways.

    Any event/gateway whose "name" is missing, empty, or whitespace-only
    gets the literal placeholder "(label unknown)". Mutates and returns
    ``json_data``; returned untouched when either argument is falsy.

    NOTE(review): despite the function name, ``ocr_text`` is only used as
    a presence gate here — no OCR text is actually matched to elements.
    """
    if not ocr_text or not json_data:
        return json_data

    def _ensure_label(node):
        # One-line purpose: give nameless nodes a visible placeholder.
        label = node.get("name")
        if not label or not label.strip():
            node["name"] = "(label unknown)"

    for node in json_data.get("events", []):
        _ensure_label(node)
    for node in json_data.get("gateways", []):
        _ensure_label(node)
    return json_data
def run_model(image: Image.Image):
    """Run the Pixtral VLM on ``image`` and return its structured answer.

    Returns {"json": parsed dict or None, "raw": full model text}.
    The vLLM engine is lazy-loaded into the module-level ``llm`` global
    on first call because construction is expensive.
    """
    global llm
    if llm is None:
        llm = LLM(
            model="mistralai/Pixtral-12B-2409",
            tokenizer_mode="mistral",
            dtype="bfloat16",
            max_model_len=30000,
        )

    prompt_text = load_prompt()
    image_b64 = encode_image_as_base64(image)

    # Single user turn: the prompt text plus the image as a data URL.
    chat_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_text},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}}
            ]
        }
    ]

    completions = llm.chat(chat_messages, sampling_params=sampling_params)
    raw_text = completions[0].outputs[0].text
    structured = try_extract_json(raw_text)

    # Post-process: OCR the image and backfill placeholder labels on any
    # unnamed events/gateways in the parsed result.
    ocr_text = extract_all_text_pix2struct(image)
    structured = assign_event_gateway_names_from_ocr(structured, ocr_text)

    return {
        "json": structured,
        "raw": raw_text
    }