Luigi's picture
back to f32
6fed942 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import os
from typing import Dict, Any
# System prompt sent as the "system" message on every request by
# extract_reservation_info(). Per the original note it must match the prompt
# used when the model was fine-tuned, so do not reword or reformat it.
SYSTEM_PROMPT = """Determine if the message is a restaurant reservation request.
If yes, extract the following three fields as strings:
- "num_people": number of people (as a string, e.g., "4"). If not mentioned, use an empty string ("").
- "reservation_date": the exact date/time phrase from the message (as a string, do not convert or interpret; e.g., keep "this Saturday at 7 PM" as is). If not mentioned, use an empty string ("").
- "phone_num": the phone number (as a string, digits only, remove any hyphens or formatting; e.g., "0912345678"). If not mentioned, use an empty string ("").
If the message is NOT a reservation request, return:
```json
{
"num_people": "",
"reservation_date": "",
"phone_num": ""
}
```
Output must be valid JSON only, with exactly these three fields and no additional text, fields, or explanations.
"""
# Module-level cache for the loaded model and tokenizer. Both start as None;
# load_model() fills them on first success and returns the cached pair afterwards.
model = None
tokenizer = None
def load_model():
    """Return the shared ``(model, tokenizer)`` pair, loading it only once.

    While either module global ``model`` or ``tokenizer`` is still None, this
    loads ``Luigi/gemma-3-270m-it-dinercall-ner`` (weights in float32) and
    stores the results in those globals. Once both are set, the cached pair is
    returned without touching the Hub again.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
        raised. The exception is printed, not re-raised, so the UI can show a
        friendly message instead of crashing.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        checkpoint = "Luigi/gemma-3-270m-it-dinercall-ner"
        try:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            model = AutoModelForCausalLM.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
                device_map="auto",
                # NOTE(review): trust_remote_code=True lets the Hub repo run
                # its own Python code while loading. Gemma 3 is supported
                # natively by recent transformers, so this is probably
                # unnecessary; confirm the repo ships no custom modeling code
                # and drop the flag.
                trust_remote_code=True,
            )
            # Generation needs a pad token id; fall back to EOS when the
            # tokenizer does not define one.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            print("Model loaded successfully!")
        except Exception as exc:
            print(f"Error loading model: {exc}")
            return None, None
    return model, tokenizer
def validate_json(output: str) -> tuple:
    """Locate, repair and parse the JSON object in a raw model reply.

    Two reply formats are accepted: a fenced code block (```json {...} ```),
    used by newer model versions, and a bare ``{...}`` object, used by older
    ones. Before parsing, a few common model slips are repaired:

    * an unquoted ``phone_num`` value (digits/hyphens, which may start with 0)
      is wrapped in quotes;
    * an unquoted integer ``num_people`` is wrapped in quotes;
    * a trailing comma before a closing brace is removed.

    Args:
        output: Decoded model text.

    Returns:
        ``(is_valid, parsed, message)``. ``parsed`` is the decoded dict when
        ``is_valid`` is True and ``None`` otherwise; ``message`` is a
        bilingual status string for display.
    """
    try:
        fenced = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', output)
        if fenced:
            json_str = fenced.group(1)
        else:
            # No code fence: take everything from the first "{" to the last
            # "}" (the same span the old greedy regex matched).
            start = output.find("{")
            end = output.rfind("}")
            if start == -1 or end < start:
                return False, None, "No JSON found / 未找到JSON"
            json_str = output[start:end + 1]

        # 1. Quote phone numbers (they often start with 0, invalid as a JSON number).
        json_str = re.sub(r'("phone_num":\s*)(\d[-\d]*)', r'\1"\2"', json_str)
        # 2. Quote num_people when emitted as a bare integer.
        json_str = re.sub(r'("num_people":\s*)(\d+)', r'\1"\2"', json_str)
        # 3. Drop trailing commas before a closing brace.
        json_str = re.sub(r',\s*\}', '}', json_str)

        # raw_decode parses the first complete object and ignores whatever
        # follows it, so extra text containing braces after the answer no
        # longer turns a valid reply into a decode error (json.loads would
        # reject it as "Extra data").
        parsed, _ = json.JSONDecoder().raw_decode(json_str)
        return True, parsed, "Valid JSON / 有效的JSON"
    except json.JSONDecodeError:
        return False, None, "Invalid JSON format / 無效的JSON格式"
    except Exception:
        return False, None, "Error parsing JSON / 解析JSON時出錯"
def extract_reservation_info(text: str):
    """Extract reservation fields from a message using the fine-tuned model.

    The message is sent as the user turn after ``SYSTEM_PROMPT``. The model is
    decoded greedily, and its reply is parsed by ``validate_json``.

    Args:
        text: Free-form message typed by the user.

    Returns:
        ``(result, raw_output)`` for the two Gradio outputs. ``result`` is the
        parsed field dict on success and ``{"error": <message>}`` otherwise.
        ``raw_output`` is the decoded model reply, or ``""`` when generation
        did not run or raised.
    """
    # A blank message cannot be a reservation request. Answer with the
    # all-empty object SYSTEM_PROMPT prescribes for non-reservations instead
    # of spending a model call on it.
    if not text or not text.strip():
        return {"num_people": "", "reservation_date": "", "phone_num": ""}, ""

    # Load model if not already loaded (cached after the first success).
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        return {"error": "Model not loaded, please refresh the page / 模型未加載成功,請刷新頁面重試"}, ""

    try:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text},
        ]
        # Let the chat template tokenize directly. Rendering it to a string
        # and re-tokenizing with the default add_special_tokens=True would
        # prepend a second BOS when the template already emits one (Gemma's
        # does). Move the tensors to the model's device, because
        # device_map="auto" may have placed the model off the CPU.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # Passing eos_token_id overrides the model's generation config, so
        # list every token that should end the answer. Gemma chat turns end
        # with <end_of_turn>; include it when the vocabulary has it, or
        # decoding may run on to max_new_tokens.
        stop_ids = [tokenizer.eos_token_id]
        end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
        if end_of_turn_id not in (None, tokenizer.unk_token_id, tokenizer.eos_token_id):
            stop_ids.append(end_of_turn_id)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # Headroom for the fenced JSON answer; greedy decoding stops
                # at a stop token, so this is only used when actually needed.
                max_new_tokens=128,
                # Greedy decoding. No temperature: it is ignored when
                # do_sample=False and only triggers a warning.
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,  # load_model() ensures it is set
                eos_token_id=stop_ids,
            )

        # Decode only the newly generated tokens (skip the prompt).
        prompt_length = inputs["input_ids"].shape[1]
        assistant_output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        is_valid, parsed, message = validate_json(assistant_output)
        if is_valid:
            return parsed, assistant_output
        return {"error": message}, assistant_output
    except Exception as e:
        return {"error": f"Processing error / 處理時出錯: {str(e)}"}, ""
# Create Gradio interface
def create_interface():
    """Build the bilingual (English / Traditional Chinese) Gradio Blocks UI.

    The page starts in English; the language radio swaps all translatable text
    between ``text_en`` and ``text_zh``. The initial render reads from
    ``text_en`` too, so the first view and the view after toggling back to
    English cannot drift apart.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` demo.
    """
    chinese_examples = [
        "你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "週六下午三點,兩位,電話0987654321",
        "預約下週三中午12點半,5人用餐,聯絡電話0912345678",
        "我要訂位,3個人,今天下午6點",
    ]
    english_examples = [
        "Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "Saturday 3 PM, 2 people, phone 0987654321",
        "Reservation for next Wednesday at 12:30 PM, 5 people, contact number 0912345678",
        "I want to make a reservation, 3 people, today at 6 PM",
    ]
    # Language-specific text dictionaries (single source for every UI string
    # that changes with the selected language).
    text_en = {
        "title": "🍽️ Restaurant Reservation Info Extractor",
        "description": "Use AI to automatically extract reservation information from messages",
        "input_label": "Input reservation message",
        "input_placeholder": "e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "button_text": "Extract Information",
        "json_label": "Extracted Result",
        "raw_label": "Raw Output",
        "instructions_title": "ℹ️ Instructions",
        "instructions": """**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)
**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""",
        "footer": "Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "Examples",
        "chinese_examples": "Chinese Examples",
        "english_examples": "English Examples",
    }
    text_zh = {
        "title": "🍽️ 餐廳訂位資訊提取器",
        "description": "使用AI從中文訊息中自動提取訂位資訊",
        "input_label": "輸入訂位訊息",
        "input_placeholder": "例如: 你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "button_text": "提取資訊",
        "json_label": "提取結果",
        "raw_label": "原始輸出",
        "instructions_title": "ℹ️ 使用說明",
        "instructions": """**支援提取的資訊:**
- 👥 人數 (num_people)
- 📅 預訂日期/時間 (reservation_date)
- 📞 電話號碼 (phone_num)
**注意事項:**
- 首次加載模型可能需要幾分鐘時間
- 如果遇到錯誤,請嘗試刷新頁面
- 模型會輸出JSON格式的結果""",
        "footer": "由 [Together AI](https://together.ai) 提供技術支持 | 模型: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "示例",
        "chinese_examples": "中文示例",
        "english_examples": "英文示例",
    }
    # The radio below defaults to "English", so the first render uses text_en.
    initial = text_en

    with gr.Blocks(
        title="Restaurant Reservation Info Extractor",
        theme=gr.themes.Soft(),
    ) as demo:
        # Language selector
        language = gr.Radio(
            choices=["English", "中文"],
            value="English",
            label="Language / 語言",
            interactive=True,
        )
        # Components whose text is swapped by update_interface()
        title_md = gr.Markdown(f"# {initial['title']}")
        description_md = gr.Markdown(initial["description"])
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label=initial["input_label"],
                    placeholder=initial["input_placeholder"],
                    lines=3,
                )
                submit_btn = gr.Button(initial["button_text"], variant="primary")
                examples_title_md = gr.Markdown(f"### {initial['examples_title']}")
                chinese_examples_title_md = gr.Markdown(f"### {initial['chinese_examples']}")
                gr.Examples(
                    examples=chinese_examples,
                    inputs=input_text,
                    label="Chinese Examples",
                )
                english_examples_title_md = gr.Markdown(f"### {initial['english_examples']}")
                gr.Examples(
                    examples=english_examples,
                    inputs=input_text,
                    label="English Examples",
                )
            with gr.Column():
                json_output = gr.JSON(label=initial["json_label"])
                raw_output = gr.Textbox(
                    label=initial["raw_label"],
                    interactive=False,
                    lines=3,
                )
        # Info panel. The Accordion itself is deliberately NOT an event output
        # (the original wiring removed it); only the Markdown inside is updated.
        with gr.Accordion(initial["instructions_title"], open=False):
            instructions_md = gr.Markdown(initial["instructions"])
        # Footer
        footer_md = gr.Markdown(initial["footer"])

        def update_interface(selected_language):
            """Return new values for the translatable components, in output order."""
            texts = text_en if selected_language == "English" else text_zh
            return [
                f"# {texts['title']}",  # title_md
                texts["description"],  # description_md
                gr.update(label=texts["input_label"], placeholder=texts["input_placeholder"]),  # input_text
                texts["button_text"],  # submit_btn
                gr.update(label=texts["json_label"]),  # json_output
                gr.update(label=texts["raw_label"]),  # raw_output
                texts["instructions"],  # instructions_md
                texts["footer"],  # footer_md
                f"### {texts['examples_title']}",  # examples_title_md
                f"### {texts['chinese_examples']}",  # chinese_examples_title_md
                f"### {texts['english_examples']}",  # english_examples_title_md
            ]

        # Connect the extraction function to the button
        submit_btn.click(
            fn=extract_reservation_info,
            inputs=input_text,
            outputs=[json_output, raw_output],
        )
        # Connect language selector; order must match update_interface's list.
        language.change(
            fn=update_interface,
            inputs=language,
            outputs=[
                title_md,
                description_md,
                input_text,
                submit_btn,
                json_output,
                raw_output,
                instructions_md,
                footer_md,
                examples_title_md,
                chinese_examples_title_md,
                english_examples_title_md,
            ],
        )
    return demo
# Script entry point: warm the model cache, then build and serve the UI.
if __name__ == "__main__":
    # Loading up front means the first user request does not pay the
    # model download/initialisation cost (load_model caches its result).
    load_model()
    demo = create_interface()
    # Listen on all interfaces, port 7860, without a public gradio.live link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)