import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import os
from typing import Dict, Any

# System prompt (must match training)
SYSTEM_PROMPT = """Determine if the message is a restaurant reservation request. If yes, extract the following three fields as strings:

- "num_people": number of people (as a string, e.g., "4"). If not mentioned, use an empty string ("").
- "reservation_date": the exact date/time phrase from the message (as a string, do not convert or interpret; e.g., keep "this Saturday at 7 PM" as is). If not mentioned, use an empty string ("").
- "phone_num": the phone number (as a string, digits only, remove any hyphens or formatting; e.g., "0912345678"). If not mentioned, use an empty string ("").

If the message is NOT a reservation request, return:
```json
{
  "num_people": "",
  "reservation_date": "",
  "phone_num": ""
}
```

Output must be valid JSON only, with exactly these three fields and no additional text, fields, or explanations.
"""

# Global variables for model caching
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer with caching"""
    global model, tokenizer

    if model is not None and tokenizer is not None:
        return model, tokenizer

    try:
        print("Loading model...")
        model_name = "Luigi/gemma-3-270m-it-dinercall-ner"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True
        )

        # Set padding token if not set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print("Model loaded successfully!")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None


def validate_json(output: str) -> tuple:
    """Validate and extract JSON from model output - supports both plain JSON and code block formats"""
    try:
        # First, try to extract JSON from code blocks (new model version)
        json_match = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', output)
        if json_match:
            json_str = json_match.group(1)
        else:
            # If no code block, look for JSON directly (old model version)
            json_match = re.search(r'\{[\s\S]*\}', output)
            if not json_match:
                return False, None, "No JSON found / 未找到JSON"
            json_str = json_match.group(0)

        # Fix common JSON issues for both formats
        # 1. Add quotes around phone numbers (they often start with 0)
        json_str = re.sub(r'("phone_num":\s*)(\d[-\d]*)', r'\1"\2"', json_str)
        # 2. Add quotes around num_people if it's a number
        json_str = re.sub(r'("num_people":\s*)(\d+)', r'\1"\2"', json_str)
        # 3. Fix trailing commas
        json_str = re.sub(r',\s*\}', '}', json_str)

        parsed = json.loads(json_str)
        return True, parsed, "Valid JSON / 有效的JSON"
    except json.JSONDecodeError:
        return False, None, "Invalid JSON format / 無效的JSON格式"
    except Exception:
        return False, None, "Error parsing JSON / 解析JSON時出錯"


def extract_reservation_info(text: str):
    """Extract reservation information from text"""
    # Load model if not already loaded
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        return {"error": "Model not loaded, please refresh the page / 模型未加載成功,請刷新頁面重試"}, ""

    try:
        # Create chat template
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text}
        ]

        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize input and move it to the model's device (needed when device_map places the model on GPU)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)

        # Generate response (greedy decoding; a temperature setting would be ignored with do_sample=False)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=64,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        # Extract assistant's response
        prompt_length = len(inputs.input_ids[0])
        assistant_output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Validate and parse JSON
        is_valid, parsed, message = validate_json(assistant_output)

        if is_valid:
            return parsed, assistant_output
        else:
            return {"error": message}, assistant_output
    except Exception as e:
        return {"error": f"Processing error / 處理時出錯: {str(e)}"}, ""


# Create Gradio interface
def create_interface():
    """Create the Gradio interface"""

    chinese_examples = [
        "你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "週六下午三點,兩位,電話0987654321",
        "預約下週三中午12點半,5人用餐,聯絡電話0912345678",
        "我要訂位,3個人,今天下午6點"
    ]

    english_examples = [
        "Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "Saturday 3 PM, 2 people, phone 0987654321",
        "Reservation for next Wednesday at 12:30 PM, 5 people, contact number 0912345678",
        "I want to make a reservation, 3 people, today at 6 PM"
    ]

    # Language-specific text dictionaries
    text_en = {
        "title": "🍽️ Restaurant Reservation Info Extractor",
        "description": "Use AI to automatically extract reservation information from messages",
        "input_label": "Input reservation message",
        "input_placeholder": "e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "button_text": "Extract Information",
        "json_label": "Extracted Result",
        "raw_label": "Raw Output",
        "instructions_title": "ℹ️ Instructions",
        "instructions": """**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)

**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""",
        "footer": "Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "Examples",
        "chinese_examples": "Chinese Examples",
        "english_examples": "English Examples"
    }

    text_zh = {
        "title": "🍽️ 餐廳訂位資訊提取器",
        "description": "使用AI從中文訊息中自動提取訂位資訊",
        "input_label": "輸入訂位訊息",
        "input_placeholder": "例如: 你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "button_text": "提取資訊",
        "json_label": "提取結果",
        "raw_label": "原始輸出",
        "instructions_title": "ℹ️ 使用說明",
        "instructions": """**支援提取的資訊:**
- 👥 人數 (num_people)
- 📅 預訂日期/時間 (reservation_date)
- 📞 電話號碼 (phone_num)

**注意事項:**
- 首次加載模型可能需要幾分鐘時間
- 如果遇到錯誤,請嘗試刷新頁面
- 模型會輸出JSON格式的結果""",
        "footer": "由 [Together AI](https://together.ai) 提供技術支持 | 模型: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "示例",
        "chinese_examples": "中文示例",
        "english_examples": "英文示例"
    }

    with gr.Blocks(
        title="Restaurant Reservation Info Extractor",
        theme=gr.themes.Soft()
    ) as demo:
        # Language selector
        language = gr.Radio(
            choices=["English", "中文"],
            value="English",
            label="Language / 語言",
            interactive=True
        )

        # Create components that will be updated based on language
        title_md = gr.Markdown("# 🍽️ Restaurant Reservation Info Extractor")
        description_md = gr.Markdown("Use AI to automatically extract reservation information from messages")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input reservation message",
                    placeholder="e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
                    lines=3
                )
                submit_btn = gr.Button("Extract Information", variant="primary")

                examples_title_md = gr.Markdown("### Examples")
                chinese_examples_title_md = gr.Markdown("### Chinese Examples")
                gr.Examples(
                    examples=chinese_examples,
                    inputs=input_text,
                    label="Chinese Examples"
                )
                english_examples_title_md = gr.Markdown("### English Examples")
                gr.Examples(
                    examples=english_examples,
                    inputs=input_text,
                    label="English Examples"
                )

            with gr.Column():
                json_output = gr.JSON(label="Extracted Result")
                raw_output = gr.Textbox(
                    label="Raw Output",
                    interactive=False,
                    lines=3
                )

        # Info panel - Create the Accordion but don't use it as an output
        with gr.Accordion("ℹ️ Instructions", open=False) as instructions_accordion:
            instructions_md = gr.Markdown("""**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)

**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""")

        # Footer
        footer_md = gr.Markdown("Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner")

        # Function to update interface based on language selection
        def update_interface(language):
            texts = text_en if language == "English" else text_zh
            return [
                f"# {texts['title']}",                  # title_md
                texts['description'],                   # description_md
                gr.update(label=texts['input_label'], placeholder=texts['input_placeholder']),  # input_text
                texts['button_text'],                   # submit_btn
                gr.update(label=texts['json_label']),   # json_output
                gr.update(label=texts['raw_label']),    # raw_output
                texts['instructions'],                  # instructions_md
                texts['footer'],                        # footer_md
                f"### {texts['examples_title']}",       # examples_title_md
                f"### {texts['chinese_examples']}",     # chinese_examples_title_md
                f"### {texts['english_examples']}"      # english_examples_title_md
            ]

        # Connect the function to the button
        submit_btn.click(
            fn=extract_reservation_info,
            inputs=input_text,
            outputs=[json_output, raw_output]
        )

        # Connect language selector to update interface - REMOVE ACCORDION FROM OUTPUTS
        language.change(
            fn=update_interface,
            inputs=language,
            outputs=[
                title_md,
                description_md,
                input_text,
                submit_btn,
                json_output,
                raw_output,
                instructions_md,  # Update only the Markdown inside Accordion
                footer_md,
                examples_title_md,
                chinese_examples_title_md,
                english_examples_title_md
            ]
        )

    return demo


# Create and launch the interface
if __name__ == "__main__":
    # Pre-load the model when the app starts
    load_model()

    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
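
# Optional: a hypothetical smoke-test helper, not used anywhere above and not part of
# the original app. It is only a sketch showing how extract_reservation_info() can be
# exercised without launching the Gradio UI; it assumes the model download succeeds,
# and the actual values printed depend on the model's output (per SYSTEM_PROMPT they
# should be the three string fields num_people, reservation_date, phone_num).
def _smoke_test():
    """Run a single example message through the extractor and print the results."""
    parsed, raw = extract_reservation_info(
        "Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, "
        "phone number is 0912-345-678"
    )
    print(parsed)  # dict with keys num_people, reservation_date, phone_num (or an "error" key)
    print(raw)     # raw assistant text; the JSON may arrive wrapped in a ```json code block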