Spaces:
Sleeping
Sleeping
| import json, torch | |
| from fastapi import Request | |
| from manager.agent_manager import agent_manager | |
| from models.fallback_model import generate_fallback_response | |
| from utils.context_parser import ContextParser | |
| from sentence_transformers import util | |
| def _short_history(context: dict, max_turns: int = 3) -> list: | |
| short_history = [] | |
| for h in context.get("dialogue_history", [])[-max_turns:]: | |
| if "player" in h and "npc" in h: | |
| short_history.append({"role": "player", "text": h["player"]}) | |
| short_history.append({"role": "npc", "text": h["npc"]}) | |
| return short_history | |
| # def _load_forbidden_trigger_data(npc_id: str) -> dict: | |
| # docs = retrieve(f"{npc_id}:forbidden_trigger_list", filters={"npc_id": npc_id}, top_k=1) | |
| # if not docs: | |
| # return {} | |
| # try: | |
| # return json.loads(docs[0]) if isinstance(docs[0], str) else docs[0] | |
| # except Exception: | |
| # return {} | |
| def _semantic_match_embedder(embedder, user_input: str, trigger_texts: list, threshold: float = 0.75): | |
| if not trigger_texts: | |
| return (False, 0.0, None) | |
| inp_emb = embedder.encode(user_input, convert_to_tensor=True) | |
| trg_embs = embedder.encode(trigger_texts, convert_to_tensor=True) | |
| cos_scores = util.cos_sim(inp_emb, trg_embs).squeeze(0) | |
| max_score, idx = torch.max(cos_scores, dim=0) | |
| score_val = float(max_score.item()) | |
| matched_text = trigger_texts[int(idx.item())] | |
| return (score_val >= threshold, score_val, matched_text) | |
| async def extract_emotion_via_fallback(request: Request, user_input: str) -> str: | |
| prompt = ( | |
| "๋ค์ ๋ฌธ์ฅ์ ํ์ ๊ฐ์ ์ ํ ๋จ์ด ๋๋ ์งง์ ๋ฌธ์ฅ์ผ๋ก ์ค๋ช ํ์์ค.\n\n" | |
| f"[๋ฌธ์ฅ]\n{user_input}\n\n" | |
| "์ง์:\n- ๊ฐ์ ์ ์ง์ ์ ์ผ๋ก ํํํ์ง ์์๋ ๋ฌธ๋งฅ์ ํตํด ์ถ๋ก ํ์์ค.\n" | |
| "- ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๊ฐ์ ์ ๊ฐ๋๋ ๋์์ค๋ ๋ฐ์ํ์์ค.\n" | |
| "- ์: ๋ถ๋ ธ, ์ฌํ, ํผ๋, ๊ธฐ๋, ๋ฌด๊ด์ฌ, ์ด์กฐํจ ๋ฑ\n" | |
| "- ๋จ์ด ํ๋ ๋๋ ์งง์ ๋ฌธ์ฅ์ผ๋ก๋ง ์ถ๋ ฅํ์์ค.\n\n" | |
| "์ ๋ต:" | |
| ) | |
| response = await generate_fallback_response(request, prompt) | |
| return response.strip() | |
| async def _llm_trigger_check(request: Request, user_input: str, label_list: list) -> bool: | |
| if not label_list: | |
| return False | |
| criteria_block = "\n".join(f"- {c}" for c in label_list) | |
| prompt = ( | |
| "๋ค์์ ์๋ฏธ ๋น๊ต๋ฅผ ์ํ ํ๋จ ๊ธฐ์ค๊ณผ ๊ฒ์ฌ ๋์์ ๋๋ค.\n\n" | |
| "[CRITERIA]\n" | |
| f"{criteria_block}\n" | |
| "[/CRITERIA]\n\n" | |
| "[INPUT]\n" | |
| f"{user_input}\n" | |
| "[/INPUT]\n\n" | |
| "์ง์:\n" | |
| "- [INPUT] ๋ด์ฉ์ด [CRITERIA] ํญ๋ชฉ ์ค ํ๋์ ์๋ฏธ๊ฐ ๊ฐ๊ฑฐ๋ ์ ์ฌํ๋ฉด YES, ๊ทธ๋ ์ง ์์ผ๋ฉด NO๋ง ์ถ๋ ฅํ์์ค.\n" | |
| "- ๋จ์ด ๊ทธ๋๋ก ํฌํจ๋์ง ์์๋ ์๋ฏธ๊ฐ ์ ์ฌํ๋ฉด YES๋ก ๊ฐ์ฃผํ์์ค.\n" | |
| "- ํ์ ์ด ์๊ฑฐ๋ ํ๋จ์ด ์ ๋งคํ๋ฉด NO๋ฅผ ์ถ๋ ฅํ์์ค.\n\n" | |
| "์ ๋ต:" | |
| ) | |
| txt = await generate_fallback_response(request, prompt) | |
| ans = txt.strip().upper() | |
| normalized = ans.replace(".", "").replace("!", "").strip() | |
| return ( | |
| normalized == "YES" or | |
| normalized == "Y" or | |
| normalized.startswith("YES") or | |
| normalized.startswith("Y") or | |
| normalized.startswith("์") or | |
| normalized.startswith("๋ค") | |
| ) | |
| async def preprocess_input( | |
| request: Request, | |
| session_id: str, | |
| npc_id: str, | |
| user_input: str, | |
| context: dict | |
| ) -> dict: | |
| parser = ContextParser(context) | |
| emotion = await extract_emotion_via_fallback(request, user_input) | |
| require_items = context.get("require", {}).get("items", []) | |
| require_actions = context.get("require", {}).get("actions", []) | |
| require_game_state = context.get("require", {}).get("game_state", []) | |
| require_delta = context.get("require", {}).get("delta", {}) | |
| quest_stage = parser.game.get("quest_stage", "default") | |
| location = parser.game.get("location", context.get("location", "unknown")) | |
| # --- RAG bundle ๋ก๋ --- | |
| agent = agent_manager.get_agent(npc_id) | |
| bundle = agent.load_rag_bundle(quest_stage, location) | |
| # === 1์ฐจ ๊ฒ์ฌ: trigger_def ๊ธฐ๋ฐ === | |
| td_docs = bundle.get("trigger_def", []) | |
| if td_docs: | |
| td = td_docs[0] | |
| trig = td.get("trigger", {}) | |
| text_ok = not trig.get("required_text") or any(t in user_input for t in trig["required_text"]) | |
| items_ok = not trig.get("required_items", {}).get("mandatory") or set(trig["required_items"]["mandatory"]).issubset(set(require_items)) | |
| actions_ok = not trig.get("required_actions", {}).get("mandatory") or set(trig["required_actions"]["mandatory"]).issubset(set(require_actions)) | |
| gs_ok = not trig.get("required_game_state", {}).get("mandatory") or set(trig["required_game_state"]["mandatory"]).issubset(set(require_game_state)) | |
| delta_ok = all(require_delta.get(k, 0) >= v for k, v in trig.get("required_delta", {}).get("mandatory", {}).items()) | |
| if text_ok and items_ok and actions_ok and gs_ok and delta_ok: | |
| return { | |
| "session_id": session_id, | |
| "player_utterance": user_input, | |
| "npc_id": npc_id, | |
| "tags": parser.npc, | |
| "player_state": parser.player, | |
| "game_state": parser.game, | |
| "context": _short_history(context), | |
| "emotion": emotion, | |
| "triggers": trig, | |
| "is_valid": True, | |
| "additional_trigger": None, | |
| "rag_main_docs": ( | |
| td_docs | |
| + bundle.get("lore", []) | |
| + bundle.get("description", []) | |
| + bundle.get("npc_persona", []) | |
| + bundle.get("dialogue_turn", []) | |
| + bundle.get("flag_def", []) | |
| + bundle.get("main_res_validate", []) | |
| ), | |
| "rag_fallback_docs": bundle.get("fallback", []) + bundle.get("npc_persona", []), | |
| "trigger_meta": {} | |
| } | |
| # === 2์ฐจ ๊ฒ์ฌ: forbidden-trigger ๊ธฐ๋ฐ === | |
| forbidden_data = bundle.get("forbidden_trigger_list", [{}])[0] | |
| keywords = forbidden_data.get("triggers", {}).get("keywords", []) | |
| trigger_texts = forbidden_data.get("triggers", {}).get("text", []) | |
| embedder = request.app.state.embedder | |
| matched_key = None | |
| confidence = 0.0 | |
| kw_match = None | |
| txt_match = None | |
| # 1. keyword ์ ์ฌ๋ ๊ฒ์ฌ | |
| kw_hit, kw_score, kw_match = _semantic_match_embedder(embedder, user_input, keywords, threshold=0.75) | |
| # 2. text ์ ์ฌ๋ ๊ฒ์ฌ | |
| txt_hit, txt_score, txt_match = _semantic_match_embedder(embedder, user_input, trigger_texts, threshold=0.75) | |
| # 3. ์ ์ฌ๋ ๋์ ์ชฝ ์ ํ | |
| if kw_hit and (kw_score >= txt_score): | |
| matched_key = "keyword_match" | |
| confidence = kw_score | |
| elif txt_hit: | |
| matched_key = "text_match" | |
| confidence = txt_score | |
| elif max(kw_score, txt_score) >= 0.65: | |
| # ๊ฐ์ฅ ๊ฐ๊น์ด keyword์ text๋ง label ํ๋ณด๋ก ์ ๋ฌ | |
| label_candidates = [] | |
| if kw_match: | |
| label_candidates.append(kw_match) | |
| if txt_match: | |
| label_candidates.append(txt_match) | |
| if await _llm_trigger_check(request, user_input, label_candidates): | |
| matched_key = "semantic_match_llm" | |
| confidence = max(kw_score, txt_score) | |
| # === trigger_meta ๋งค์นญ ๋ณด์ === | |
| actual_trigger = None | |
| if matched_key: | |
| # kw_match๋ txt_match ๊ฐ์ด ์ค์ trigger_meta.trigger ๊ฐ๊ณผ ์ผ์นํ๋์ง ํ์ธ | |
| for tm in bundle.get("trigger_meta", []): | |
| if tm.get("trigger") in (kw_match, txt_match): | |
| actual_trigger = tm.get("trigger") | |
| break | |
| trigger_meta = {} | |
| if actual_trigger: | |
| trigger_meta = next((tm for tm in bundle.get("trigger_meta", []) if tm.get("trigger") == actual_trigger), {}) | |
| trigger_meta["confidence"] = confidence | |
| additional_trigger = bool(actual_trigger) | |
| return { | |
| "session_id": session_id, | |
| "player_utterance": user_input, | |
| "npc_id": npc_id, | |
| "tags": parser.npc, | |
| "player_state": parser.player, | |
| "game_state": parser.game, | |
| "context": _short_history(context), | |
| "emotion": emotion, | |
| "triggers": [], | |
| "is_valid": False, | |
| "additional_trigger": additional_trigger, | |
| "rag_main_docs": ( | |
| bundle.get("lore", []) | |
| + bundle.get("description", []) | |
| + bundle.get("npc_persona", []) | |
| + bundle.get("dialogue_turn", []) | |
| + bundle.get("flag_def", []) | |
| + bundle.get("main_res_validate", []) | |
| ), | |
| "rag_fallback_docs": bundle.get("fallback", []) + bundle.get("npc_persona", []), | |
| "trigger_meta": trigger_meta | |
| } | |