m97j's picture
Initial commit
f2a7503
import json, torch
from fastapi import Request
from manager.agent_manager import agent_manager
from models.fallback_model import generate_fallback_response
from utils.context_parser import ContextParser
from sentence_transformers import util
def _short_history(context: dict, max_turns: int = 3) -> list:
short_history = []
for h in context.get("dialogue_history", [])[-max_turns:]:
if "player" in h and "npc" in h:
short_history.append({"role": "player", "text": h["player"]})
short_history.append({"role": "npc", "text": h["npc"]})
return short_history
# def _load_forbidden_trigger_data(npc_id: str) -> dict:
# docs = retrieve(f"{npc_id}:forbidden_trigger_list", filters={"npc_id": npc_id}, top_k=1)
# if not docs:
# return {}
# try:
# return json.loads(docs[0]) if isinstance(docs[0], str) else docs[0]
# except Exception:
# return {}
def _semantic_match_embedder(embedder, user_input: str, trigger_texts: list, threshold: float = 0.75):
if not trigger_texts:
return (False, 0.0, None)
inp_emb = embedder.encode(user_input, convert_to_tensor=True)
trg_embs = embedder.encode(trigger_texts, convert_to_tensor=True)
cos_scores = util.cos_sim(inp_emb, trg_embs).squeeze(0)
max_score, idx = torch.max(cos_scores, dim=0)
score_val = float(max_score.item())
matched_text = trigger_texts[int(idx.item())]
return (score_val >= threshold, score_val, matched_text)
async def extract_emotion_via_fallback(request: Request, user_input: str) -> str:
    """Ask the fallback LLM to describe the speaker's emotion in *user_input*.

    The prompt embeds *user_input* and asks for the emotion as a single word
    or short phrase.

    NOTE(review): the prompt literals look mis-decoded (mojibake) in this copy
    of the file; confirm the source is stored as UTF-8.

    Returns:
        The model reply with surrounding whitespace stripped.
    """
    prompt_parts = (
        "๋‹ค์Œ ๋ฌธ์žฅ์˜ ํ™”์ž ๊ฐ์ •์„ ํ•œ ๋‹จ์–ด ๋˜๋Š” ์งง์€ ๋ฌธ์žฅ์œผ๋กœ ์„ค๋ช…ํ•˜์‹œ์˜ค.\n\n",
        f"[๋ฌธ์žฅ]\n{user_input}\n\n",
        "์ง€์‹œ:\n- ๊ฐ์ •์„ ์ง์ ‘์ ์œผ๋กœ ํ‘œํ˜„ํ•˜์ง€ ์•Š์•„๋„ ๋ฌธ๋งฅ์„ ํ†ตํ•ด ์ถ”๋ก ํ•˜์‹œ์˜ค.\n",
        "- ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ๊ฐ์ •์˜ ๊ฐ•๋„๋‚˜ ๋‰˜์•™์Šค๋„ ๋ฐ˜์˜ํ•˜์‹œ์˜ค.\n",
        "- ์˜ˆ: ๋ถ„๋…ธ, ์Šฌํ””, ํ˜ผ๋ž€, ๊ธฐ๋Œ€, ๋ฌด๊ด€์‹ฌ, ์ดˆ์กฐํ•จ ๋“ฑ\n",
        "- ๋‹จ์–ด ํ•˜๋‚˜ ๋˜๋Š” ์งง์€ ๋ฌธ์žฅ์œผ๋กœ๋งŒ ์ถœ๋ ฅํ•˜์‹œ์˜ค.\n\n",
        "์ •๋‹ต:",
    )
    prompt = "".join(prompt_parts)
    raw_reply = await generate_fallback_response(request, prompt)
    return raw_reply.strip()
async def _llm_trigger_check(request: Request, user_input: str, label_list: list) -> bool:
    """Ask the fallback LLM whether *user_input* means the same as any label.

    Each label becomes a ``- label`` bullet inside a [CRITERIA] block, and the
    utterance goes inside an [INPUT] block. The model is asked to answer
    YES or NO.

    Args:
        request: FastAPI request, forwarded to ``generate_fallback_response``.
        user_input: The player utterance to judge.
        label_list: Candidate labels. If empty, returns False without an LLM call.

    Returns:
        True when the upper-cased reply, stripped of "." and "!", starts with
        "Y" or with one of the two non-ASCII affirmative prefixes below.

    NOTE(review): the prompt literals and the two non-ASCII prefixes look
    mis-decoded (mojibake) in this copy of the file. If so, a genuine
    Korean reply would never match them, so confirm the source encoding.
    """
    if not label_list:
        return False

    criteria_block = "\n".join([f"- {label}" for label in label_list])
    prompt = "".join((
        "๋‹ค์Œ์€ ์˜๋ฏธ ๋น„๊ต๋ฅผ ์œ„ํ•œ ํŒ๋‹จ ๊ธฐ์ค€๊ณผ ๊ฒ€์‚ฌ ๋Œ€์ƒ์ž…๋‹ˆ๋‹ค.\n\n",
        "[CRITERIA]\n",
        f"{criteria_block}\n",
        "[/CRITERIA]\n\n",
        "[INPUT]\n",
        f"{user_input}\n",
        "[/INPUT]\n\n",
        "์ง€์‹œ:\n",
        "- [INPUT] ๋‚ด์šฉ์ด [CRITERIA] ํ•ญ๋ชฉ ์ค‘ ํ•˜๋‚˜์™€ ์˜๋ฏธ๊ฐ€ ๊ฐ™๊ฑฐ๋‚˜ ์œ ์‚ฌํ•˜๋ฉด YES, ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด NO๋งŒ ์ถœ๋ ฅํ•˜์‹œ์˜ค.\n",
        "- ๋‹จ์–ด ๊ทธ๋Œ€๋กœ ํฌํ•จ๋˜์ง€ ์•Š์•„๋„ ์˜๋ฏธ๊ฐ€ ์œ ์‚ฌํ•˜๋ฉด YES๋กœ ๊ฐ„์ฃผํ•˜์‹œ์˜ค.\n",
        "- ํ™•์‹ ์ด ์—†๊ฑฐ๋‚˜ ํŒ๋‹จ์ด ์• ๋งคํ•˜๋ฉด NO๋ฅผ ์ถœ๋ ฅํ•˜์‹œ์˜ค.\n\n",
        "์ •๋‹ต:",
    ))

    reply = await generate_fallback_response(request, prompt)
    cleaned = reply.strip().upper().replace(".", "").replace("!", "").strip()
    # "YES" and "Y" (exact or as prefixes) are all covered by the "Y" prefix.
    return cleaned.startswith(("Y", "์˜ˆ", "๋„ค"))
# Bundle keys (in order) whose docs form the main RAG context. trigger_def docs
# are prepended only when the trigger definition is satisfied.
_MAIN_RAG_DOC_KEYS = (
    "lore",
    "description",
    "npc_persona",
    "dialogue_turn",
    "flag_def",
    "main_res_validate",
)


def _first_doc_as_dict(docs) -> dict:
    """Return the first entry of a RAG doc list as a dict, or ``{}``.

    A missing or empty list yields ``{}`` instead of raising IndexError. A
    string entry is decoded as JSON. NOTE(review): this assumes such docs may
    be stored as JSON text, so confirm against ``load_rag_bundle``. Anything
    that is not a dict after decoding yields ``{}``, so no trigger can match.
    """
    if not docs:
        return {}
    doc = docs[0]
    if isinstance(doc, str):
        try:
            doc = json.loads(doc)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return {}
    return doc if isinstance(doc, dict) else {}


def _trigger_def_satisfied(
    trig: dict,
    user_input: str,
    require_items: list,
    require_actions: list,
    require_game_state: list,
    require_delta: dict,
) -> bool:
    """Check every condition of a trigger definition; absent conditions pass.

    - ``required_text``: at least one string must occur in *user_input*.
    - ``required_items`` / ``required_actions`` / ``required_game_state``:
      every ``mandatory`` entry must be among the provided values.
    - ``required_delta.mandatory``: each provided delta must be >= its minimum,
      and missing deltas count as 0.
    """
    def mandatory_subset(key: str, provided: list) -> bool:
        mandatory = trig.get(key, {}).get("mandatory")
        return not mandatory or set(mandatory).issubset(provided)

    required_text = trig.get("required_text")
    text_ok = not required_text or any(t in user_input for t in required_text)
    items_ok = mandatory_subset("required_items", require_items)
    actions_ok = mandatory_subset("required_actions", require_actions)
    gs_ok = mandatory_subset("required_game_state", require_game_state)
    delta_ok = all(
        require_delta.get(k, 0) >= v
        for k, v in trig.get("required_delta", {}).get("mandatory", {}).items()
    )
    return text_ok and items_ok and actions_ok and gs_ok and delta_ok


async def _match_forbidden_trigger(request: Request, user_input: str, forbidden_data: dict):
    """Score *user_input* against forbidden-trigger keywords and texts.

    Returns ``(matched_key, confidence, preferred_labels)``.

    - ``matched_key`` is ``"keyword_match"``, ``"text_match"``,
      ``"semantic_match_llm"``, or None when nothing matched.
    - ``preferred_labels`` is the closest keyword and closest text, with the
      label that drove the match listed first.

    Scores >= 0.75 count as direct hits. Scores in [0.65, 0.75) are
    confirmed by the fallback LLM.
    """
    triggers = forbidden_data.get("triggers", {})
    keywords = triggers.get("keywords", [])
    trigger_texts = triggers.get("text", [])

    # NOTE(review): embedder.encode is called synchronously inside this
    # coroutine, so it holds the event loop while it runs.
    embedder = request.app.state.embedder
    kw_hit, kw_score, kw_match = _semantic_match_embedder(embedder, user_input, keywords, threshold=0.75)
    txt_hit, txt_score, txt_match = _semantic_match_embedder(embedder, user_input, trigger_texts, threshold=0.75)

    # Prefer the stronger direct hit; ties go to the keyword.
    if kw_hit and kw_score >= txt_score:
        return "keyword_match", kw_score, [kw_match, txt_match]
    if txt_hit:
        return "text_match", txt_score, [txt_match, kw_match]

    best_score = max(kw_score, txt_score)
    if best_score >= 0.65:
        # Grey zone: let the LLM judge against only the closest keyword and text.
        candidates = [m for m in (kw_match, txt_match) if m]
        if await _llm_trigger_check(request, user_input, candidates):
            ordered = [kw_match, txt_match] if kw_score >= txt_score else [txt_match, kw_match]
            return "semantic_match_llm", best_score, ordered
    return None, 0.0, []


def _resolve_trigger_meta(trigger_metas: list, preferred_labels: list, confidence: float) -> dict:
    """Return the trigger_meta entry for the first preferred label found.

    The result is a copy with ``confidence`` added, or ``{}`` if no label has
    a matching ``trigger``. Empty or None labels are skipped, so a meta entry
    without a ``"trigger"`` key can never be matched by accident.
    """
    for label in preferred_labels:
        if not label:
            continue
        for tm in trigger_metas:
            if tm.get("trigger") == label:
                # Copy instead of mutating: tm belongs to the RAG bundle, which
                # the agent may reuse across requests.
                return {**tm, "confidence": confidence}
    return {}


def _build_preprocess_result(
    *,
    session_id: str,
    user_input: str,
    npc_id: str,
    parser,
    context: dict,
    emotion: str,
    triggers,
    is_valid: bool,
    additional_trigger,
    leading_docs,
    bundle: dict,
    trigger_meta: dict,
) -> dict:
    """Assemble the payload shared by both ``preprocess_input`` outcomes."""
    main_docs = list(leading_docs) + [doc for key in _MAIN_RAG_DOC_KEYS for doc in bundle.get(key, [])]
    return {
        "session_id": session_id,
        "player_utterance": user_input,
        "npc_id": npc_id,
        "tags": parser.npc,
        "player_state": parser.player,
        "game_state": parser.game,
        "context": _short_history(context),
        "emotion": emotion,
        "triggers": triggers,
        "is_valid": is_valid,
        "additional_trigger": additional_trigger,
        "rag_main_docs": main_docs,
        "rag_fallback_docs": bundle.get("fallback", []) + bundle.get("npc_persona", []),
        "trigger_meta": trigger_meta,
    }


async def preprocess_input(
    request: Request,
    session_id: str,
    npc_id: str,
    user_input: str,
    context: dict
) -> dict:
    """Build the preprocessed payload for one player utterance to an NPC.

    Steps:
      1. Extract the utterance's emotion via the fallback LLM.
      2. Load the NPC's RAG bundle for the current quest stage and location.
      3. First check: if the first ``trigger_def`` doc's conditions hold, return
         ``is_valid=True`` with that trigger and the trigger_def docs
         prepended to ``rag_main_docs``.
      4. Second check: otherwise compare the utterance with the first
         ``forbidden_trigger_list`` doc using embeddings (plus an LLM check
         in the grey zone), and attach the matching ``trigger_meta`` entry,
         if any, with its confidence.

    Returns:
        A dict with keys session_id, player_utterance, npc_id, tags,
        player_state, game_state, context, emotion, triggers, is_valid,
        additional_trigger, rag_main_docs, rag_fallback_docs and trigger_meta.
    """
    parser = ContextParser(context)
    emotion = await extract_emotion_via_fallback(request, user_input)

    require = context.get("require") or {}
    require_items = require.get("items", [])
    require_actions = require.get("actions", [])
    require_game_state = require.get("game_state", [])
    require_delta = require.get("delta", {})

    quest_stage = parser.game.get("quest_stage", "default")
    location = parser.game.get("location", context.get("location", "unknown"))

    # Load the RAG bundle for this NPC / quest stage / location.
    agent = agent_manager.get_agent(npc_id)
    bundle = agent.load_rag_bundle(quest_stage, location)

    common = dict(
        session_id=session_id,
        user_input=user_input,
        npc_id=npc_id,
        parser=parser,
        context=context,
        emotion=emotion,
        bundle=bundle,
    )

    # First check: trigger_def conditions.
    td_docs = bundle.get("trigger_def", [])
    if td_docs:
        trig = td_docs[0].get("trigger", {})
        if _trigger_def_satisfied(
            trig, user_input, require_items, require_actions, require_game_state, require_delta
        ):
            return _build_preprocess_result(
                **common,
                triggers=trig,
                is_valid=True,
                additional_trigger=None,
                leading_docs=td_docs,
                trigger_meta={},
            )

    # Second check: forbidden-trigger similarity.
    forbidden_data = _first_doc_as_dict(bundle.get("forbidden_trigger_list"))
    matched_key, confidence, preferred_labels = await _match_forbidden_trigger(
        request, user_input, forbidden_data
    )
    trigger_meta = (
        _resolve_trigger_meta(bundle.get("trigger_meta", []), preferred_labels, confidence)
        if matched_key
        else {}
    )

    return _build_preprocess_result(
        **common,
        triggers=[],
        is_valid=False,
        additional_trigger=bool(trigger_meta),
        leading_docs=(),
        trigger_meta=trigger_meta,
    )