Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import json | |
| import re | |
| import os | |
| from typing import Dict, Any | |
# System prompt (must match training).
# It instructs the model to answer with a JSON object holding exactly three
# string fields: "num_people", "reservation_date" and "phone_num", all empty
# when the message is not a reservation request.
# NOTE(review): the leading whitespace and blank lines inside this literal
# could not be verified from the available view of the file — confirm the
# text against the prompt used at training time before relying on it.
SYSTEM_PROMPT = """Determine if the message is a restaurant reservation request.
If yes, extract the following three fields as strings:
- "num_people": number of people (as a string, e.g., "4"). If not mentioned, use an empty string ("").
- "reservation_date": the exact date/time phrase from the message (as a string, do not convert or interpret; e.g., keep "this Saturday at 7 PM" as is). If not mentioned, use an empty string ("").
- "phone_num": the phone number (as a string, digits only, remove any hyphens or formatting; e.g., "0912345678"). If not mentioned, use an empty string ("").
If the message is NOT a reservation request, return:
```json
{
"num_people": "",
"reservation_date": "",
"phone_num": ""
}
```
Output must be valid JSON only, with exactly these three fields and no additional text, fields, or explanations.
"""

# Global variables for model caching.
# Both start as None; load_model() fills them in and reuses them afterwards.
model = None
tokenizer = None
def load_model():
    """Load the model and tokenizer once and cache them in module globals.

    Returns:
        A ``(model, tokenizer)`` pair. When both globals are already set the
        cached objects are returned without reloading. If loading fails the
        error is printed (not raised) and ``(None, None)`` is returned, so a
        later call will try again.
    """
    global model, tokenizer

    # Fast path: a previous call already loaded both objects.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    try:
        print("Loading model...")
        model_name = "Luigi/gemma-3-270m-it-dinercall-ner"

        # Load into locals first: if the model download/initialisation fails
        # after the tokenizer succeeded, the globals must not be left
        # half-initialised while this function reports (None, None).
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            # NOTE(review): trust_remote_code executes Python shipped in the
            # model repository; confirm it is actually required here.
            trust_remote_code=True,
        )

        # Set padding token if not set
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    except Exception as e:
        # Best-effort by design: callers check for None instead of catching.
        print(f"Error loading model: {e}")
        return None, None

    # Publish both objects together, only after everything succeeded.
    model, tokenizer = loaded_model, loaded_tokenizer
    print("Model loaded successfully!")
    return model, tokenizer
def validate_json(output: str) -> tuple:
    """Extract and parse the JSON object contained in raw model output.

    Supports both a fenced code block (```json ... ```) and a bare JSON
    object. A few common model mistakes are repaired before parsing:
    unquoted phone numbers, unquoted ``num_people`` and a trailing comma
    before a closing brace.

    Args:
        output: Decoded text produced by the model.

    Returns:
        A ``(is_valid, parsed, message)`` tuple. ``parsed`` is the decoded
        object when ``is_valid`` is True and ``None`` otherwise; ``message``
        is a bilingual status string.
    """
    try:
        # First, try to extract JSON from code blocks (new model version)
        fenced = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', output)
        if fenced:
            json_str = fenced.group(1)
        else:
            # If no code block, look for JSON directly (old model version)
            bare = re.search(r'\{[\s\S]*\}', output)
            if not bare:
                return False, None, "No JSON found / 未找到JSON"
            json_str = bare.group(0)

        # Fix common JSON issues for both formats
        # 1. Add quotes around phone numbers (they often start with 0, which
        #    is not a legal JSON number)
        json_str = re.sub(r'("phone_num":\s*)(\d[-\d]*)', r'\1"\2"', json_str)
        # 2. Add quotes around num_people if it's a number
        json_str = re.sub(r'("num_people":\s*)(\d+)', r'\1"\2"', json_str)
        # 3. Fix trailing commas
        json_str = re.sub(r',\s*\}', '}', json_str)

        # The bare-object match is greedy and runs to the LAST '}', so any
        # extra text or repeated object after the first one would make
        # json.loads fail with "Extra data". raw_decode parses the first
        # complete object (json_str always starts at '{') and ignores the rest.
        parsed, _end = json.JSONDecoder().raw_decode(json_str)
        return True, parsed, "Valid JSON / 有效的JSON"
    except json.JSONDecodeError:
        return False, None, "Invalid JSON format / 無效的JSON格式"
    except Exception:
        # e.g. a non-string ``output``; reported instead of raised
        return False, None, "Error parsing JSON / 解析JSON時出錯"
def extract_reservation_info(text: str):
    """Extract reservation information from text.

    Builds a chat prompt from SYSTEM_PROMPT and the user message, generates
    up to 64 new tokens greedily, and parses the completion with
    validate_json().

    Args:
        text: The user's message.

    Returns:
        A ``(result, raw_output)`` pair. ``result`` is the parsed JSON object
        on success, or ``{"error": <message>}`` when the model is not
        available, the output is not valid JSON, or any step raised.
        ``raw_output`` is the decoded completion, or ``""`` when generation
        did not run to completion.
    """
    # Load model if not already loaded
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        return {"error": "Model not loaded, please refresh the page / 模型未加載成功,請刷新頁面重試"}, ""

    try:
        # Create chat template
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize input and move it to the model's device: the model is
        # loaded with device_map="auto", so it may live on an accelerator
        # while freshly tokenized tensors are always created on the CPU.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)

        # Generate response. Decoding is greedy (do_sample=False), so no
        # sampling temperature is passed — it would be ignored and only
        # trigger a warning.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=64,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Extract assistant's response: generate() returns prompt + completion,
        # so drop the prompt tokens before decoding.
        prompt_length = inputs["input_ids"].shape[-1]
        assistant_output = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )

        # Validate and parse JSON
        is_valid, parsed, message = validate_json(assistant_output)
        if is_valid:
            return parsed, assistant_output
        return {"error": message}, assistant_output
    except Exception as e:
        # UI boundary: surface the failure in the JSON panel instead of raising.
        return {"error": f"Processing error / 處理時出錯: {str(e)}"}, ""
# Create Gradio interface
def create_interface():
    """Build the bilingual (English / Chinese) Gradio Blocks app.

    The UI starts in English. Changing the language radio rewrites the
    labels and Markdown texts from the matching text table; the submit button
    runs extract_reservation_info() on the input box.

    Returns:
        The assembled ``gr.Blocks`` object (not launched).
    """
    # Clickable sample messages, one list per language.
    english_examples = [
        "Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "Saturday 3 PM, 2 people, phone 0987654321",
        "Reservation for next Wednesday at 12:30 PM, 5 people, contact number 0912345678",
        "I want to make a reservation, 3 people, today at 6 PM",
    ]
    chinese_examples = [
        "你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "週六下午三點,兩位,電話0987654321",
        "預約下週三中午12點半,5人用餐,聯絡電話0912345678",
        "我要訂位,3個人,今天下午6點",
    ]

    # Per-language UI text tables; both share the same set of keys.
    text_en = {
        "title": "🍽️ Restaurant Reservation Info Extractor",
        "description": "Use AI to automatically extract reservation information from messages",
        "input_label": "Input reservation message",
        "input_placeholder": "e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "button_text": "Extract Information",
        "json_label": "Extracted Result",
        "raw_label": "Raw Output",
        "instructions_title": "ℹ️ Instructions",
        "instructions": """**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)
**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""",
        "footer": "Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "Examples",
        "chinese_examples": "Chinese Examples",
        "english_examples": "English Examples",
    }
    text_zh = {
        "title": "🍽️ 餐廳訂位資訊提取器",
        "description": "使用AI從中文訊息中自動提取訂位資訊",
        "input_label": "輸入訂位訊息",
        "input_placeholder": "例如: 你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "button_text": "提取資訊",
        "json_label": "提取結果",
        "raw_label": "原始輸出",
        "instructions_title": "ℹ️ 使用說明",
        "instructions": """**支援提取的資訊:**
- 👥 人數 (num_people)
- 📅 預訂日期/時間 (reservation_date)
- 📞 電話號碼 (phone_num)
**注意事項:**
- 首次加載模型可能需要幾分鐘時間
- 如果遇到錯誤,請嘗試刷新頁面
- 模型會輸出JSON格式的結果""",
        "footer": "由 [Together AI](https://together.ai) 提供技術支持 | 模型: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "示例",
        "chinese_examples": "中文示例",
        "english_examples": "英文示例",
    }

    # The initial (English) texts come from text_en, so the first render and
    # a later switch back to English show the same strings.
    initial = text_en

    with gr.Blocks(
        title="Restaurant Reservation Info Extractor",
        theme=gr.themes.Soft(),
    ) as demo:
        # Language selector
        language = gr.Radio(
            choices=["English", "中文"],
            value="English",
            label="Language / 語言",
            interactive=True,
        )

        # Components whose text is rewritten when the language changes
        title_md = gr.Markdown(f"# {initial['title']}")
        description_md = gr.Markdown(initial["description"])

        with gr.Row():
            # Left column: input box, submit button and example messages
            with gr.Column():
                input_text = gr.Textbox(
                    label=initial["input_label"],
                    placeholder=initial["input_placeholder"],
                    lines=3,
                )
                submit_btn = gr.Button(initial["button_text"], variant="primary")
                examples_title_md = gr.Markdown(f"### {initial['examples_title']}")

                chinese_examples_title_md = gr.Markdown(f"### {initial['chinese_examples']}")
                gr.Examples(
                    examples=chinese_examples,
                    inputs=input_text,
                    label=initial["chinese_examples"],
                )

                english_examples_title_md = gr.Markdown(f"### {initial['english_examples']}")
                gr.Examples(
                    examples=english_examples,
                    inputs=input_text,
                    label=initial["english_examples"],
                )

            # Right column: parsed result and raw model text
            with gr.Column():
                json_output = gr.JSON(label=initial["json_label"])
                raw_output = gr.Textbox(
                    label=initial["raw_label"],
                    interactive=False,
                    lines=3,
                )

        # Info panel - the Accordion itself is never used as an event output;
        # only the Markdown inside it is updated on language change.
        with gr.Accordion(initial["instructions_title"], open=False) as instructions_accordion:
            instructions_md = gr.Markdown(initial["instructions"])

        # Footer
        footer_md = gr.Markdown(initial["footer"])

        # Components refreshed on language change, in the exact order in which
        # update_interface() returns their new values.
        localized_components = [
            title_md,
            description_md,
            input_text,
            submit_btn,
            json_output,
            raw_output,
            instructions_md,
            footer_md,
            examples_title_md,
            chinese_examples_title_md,
            english_examples_title_md,
        ]

        def update_interface(selected_language):
            """Return new values/updates for every localized component."""
            if selected_language == "English":
                texts = text_en
            else:
                texts = text_zh

            input_update = gr.update(
                label=texts["input_label"],
                placeholder=texts["input_placeholder"],
            )
            return [
                f"# {texts['title']}",             # title_md
                texts["description"],              # description_md
                input_update,                      # input_text
                texts["button_text"],              # submit_btn
                gr.update(label=texts["json_label"]),  # json_output
                gr.update(label=texts["raw_label"]),   # raw_output
                texts["instructions"],             # instructions_md
                texts["footer"],                   # footer_md
                f"### {texts['examples_title']}",      # examples_title_md
                f"### {texts['chinese_examples']}",    # chinese_examples_title_md
                f"### {texts['english_examples']}",    # english_examples_title_md
            ]

        # Run the extraction when the button is clicked
        submit_btn.click(
            fn=extract_reservation_info,
            inputs=input_text,
            outputs=[json_output, raw_output],
        )

        # Re-label the interface when the language selection changes
        language.change(
            fn=update_interface,
            inputs=language,
            outputs=localized_components,
        )

    return demo
# Script entry point: build the interface and serve it
if __name__ == "__main__":
    # Load the model up front so it is cached before the UI is built
    load_model()

    demo = create_interface()

    # Listen on all interfaces, port 7860, without a public share link
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)