Spaces:
Running
Running
voice-agent-examples
/
examples
/voice_agent_webrtc_langgraph
/agents
/rbc-fees-agent
/react_agent.py
| import os | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| from langgraph.func import entrypoint, task | |
| from langgraph.graph import add_messages | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import ( | |
| SystemMessage, | |
| HumanMessage, | |
| AIMessage, | |
| BaseMessage, | |
| ToolCall, | |
| ToolMessage, | |
| ) | |
| # ---- Tools (wrap existing logic/fixtures) ---- | |
| try: | |
| from .tools import ( | |
| get_customer_profile, | |
| find_account_by_last4, | |
| parse_date_range, | |
| find_customer, | |
| check_upgrade_options, | |
| verify_identity, | |
| list_accounts, | |
| fetch_activity, | |
| detect_fees, | |
| explain_fee, | |
| check_dispute_eligibility, | |
| create_dispute, | |
| ) | |
| except ImportError: | |
| import sys | |
| sys.path.append(os.path.dirname(__file__)) | |
| from tools import ( # type: ignore | |
| get_customer_profile, | |
| find_account_by_last4, | |
| parse_date_range, | |
| find_customer, | |
| check_upgrade_options, | |
| verify_identity, | |
| list_accounts, | |
| fetch_activity, | |
| detect_fees, | |
| explain_fee, | |
| check_dispute_eligibility, | |
| create_dispute, | |
| ) | |
| # Also import a direct finder for name→customer_id in case the LLM doesn't call the tool before verification | |
| try: | |
| from .logic import find_customer_by_name # type: ignore | |
| except Exception: | |
| try: | |
| import sys as _sys, os as _os | |
| _sys.path.append(os.path.dirname(__file__)) | |
| from logic import find_customer_by_name # type: ignore | |
| except Exception: | |
| find_customer_by_name = None # type: ignore | |
| """ReAct agent entrypoint and system prompt.""" | |
# System prompt for the fees agent. Adjacent string literals are concatenated
# by the parser, so the pieces below form one single-line prompt; every piece
# except the last carries its own trailing space.
SYSTEM_PROMPT = (
    # Persona and greeting.
    "You are a warm, cheerful banking assistant on a phone call. "
    "Start with a brief greeting and small talk. "
    "If the caller's identity is unknown, politely ask for their full name (first and last). "
    "If they provide only a single given name, ask for their last name next. "
    # Identity verification comes before any account work.
    "Before any account lookups or actions, you MUST verify the caller's identity using the verify_identity tool. "
    "Ask for date of birth (do not specify a format) and either last-4 of account or a secret answer. "
    "If the tool returns a secret question, read it back verbatim and ask for the answer. "
    "Normalize any provided DOB to YYYY-MM-DD before calling verify_identity. "
    "If the user provides first and last name, FIRST call find_customer to resolve customer_id and then include that customer_id in subsequent tool calls. "
    # Account selection once verified.
    "Only after verified=true, re-use the authenticated account if the customer confirms it's the same; "
    "otherwise, ask for the last 4 of the other account and use find_account_by_last4. "
    # Timeframe handling and fee detection.
    "Then ASK THE CUSTOMER for the specific fee date or a date range (e.g., last 30/90 days). "
    "Do not assume a default window. "
    "After the customer provides a timeframe, first call parse_date_range. "
    "If it returns an error, ask for clarification and DO NOT proceed. "
    "Then call detect_fees. "
    "If detect_fees returns an error (invalid/future/no_fees), ask for clarification or suggest a wider range (e.g., last 90 days) and DO NOT invent a fee. "
    "Only once there are fee events, continue. "
    # Step 1: explain the fee.
    "FIRST, explain the relevant fee clearly (what it is and why it happened) using simple language. "
    "Do not mention your training data cutoff; rely on the provided tools and fixtures to answer. "
    # Step 2: confirm understanding and handle refund/relief requests.
    "SECOND, confirm understanding or offer a brief clarification if needed. "
    "If the customer asks about a refund or relief, call check_dispute_eligibility; if eligible, ask permission and then call create_dispute; otherwise, suggest preventive tips. "
    # Step 3: upgrade proposal.
    "THIRD, ONLY AFTER explanation and any refund/relief handling, you MUST proactively consider upgrades: call check_upgrade_options with the recent fee events and propose ONE concise package (the highest estimated net benefit) even if the user doesn't ask. "
    "If net benefit is positive, emphasize savings; if not, present as optional convenience. "
    # Output style constraints (the reply is spoken aloud).
    "Keep messages short (1–3 sentences), empathetic, and helpful. "
    "TTS SAFETY: Output must be plain text suitable for text-to-speech. "
    "Do not use markdown, bullets, asterisks, emojis, or special typography. "
    "Use only ASCII punctuation and straight quotes."
)
# Model selection: REACT_MODEL takes precedence, then CLARIFY_MODEL, then the
# "gpt-4o" default.
_MODEL_NAME = os.environ.get("REACT_MODEL", os.environ.get("CLARIFY_MODEL", "gpt-4o"))
_LLM = ChatOpenAI(model=_MODEL_NAME, temperature=0.3)

# Tools bound to the model, in the order they are passed to bind_tools.
_TOOLS = [
    get_customer_profile,
    find_account_by_last4,
    parse_date_range,
    find_customer,
    check_upgrade_options,
    verify_identity,
    list_accounts,
    fetch_activity,
    detect_fees,
    explain_fee,
    check_dispute_eligibility,
    create_dispute,
]
_LLM_WITH_TOOLS = _LLM.bind_tools(_TOOLS)

# Name -> tool lookup used by call_tool to dispatch the model's tool calls.
_TOOLS_BY_NAME = {tool.name: tool for tool in _TOOLS}

# Per-run context, overwritten by agent() on every invocation and read by
# call_tool() when it fills in missing session/customer arguments.
# NOTE(review): these are process-wide globals, so overlapping runs would see
# each other's values — the original comment scopes this to a local dev worker.
_CURRENT_THREAD_ID: str | None = None
_CURRENT_CUSTOMER_ID: str | None = None
# ---- Logger ----
logger = logging.getLogger("RBC_ReActFeesAgent")
if not logger.handlers:
    # Only attach handlers once, so re-importing the module does not
    # duplicate every log line.
    _fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    _stream = logging.StreamHandler()
    _stream.setLevel(logging.INFO)
    _stream.setFormatter(_fmt)
    logger.addHandler(_stream)

    try:
        # Also log to app.log two directories above this file.
        _log_path = Path(__file__).resolve().parents[2] / "app.log"
        _file = logging.FileHandler(str(_log_path))
        _file.setLevel(logging.INFO)
        _file.setFormatter(_fmt)
        logger.addHandler(_file)
    except Exception:
        # File logging is best-effort (e.g. unwritable directory); the stream
        # handler above is still attached.
        pass
logger.setLevel(logging.INFO)

# Verbose tracing is on unless RBC_FEES_DEBUG is unset/empty, "0", "false" or
# "False".
_DEBUG = os.getenv("RBC_FEES_DEBUG", "0") not in ("", "0", "false", "False")
def _get_thread_id(config: Dict[str, Any] | None, messages: List[BaseMessage]) -> str:
    """Resolve an identifier for the current conversation.

    Lookup order:
      1. ``config["configurable"]`` under ``thread_id``, ``session_id`` or
         ``thread`` (first non-empty string wins);
      2. a non-empty string ``session_id`` in the ``additional_kwargs`` of the
         most recent message that carries one (objects or plain dicts);
      3. the literal ``"unknown"``.

    Never raises: any unexpected shape falls through to the next source.
    """

    def _read(source: Any, name: str, fallback: Any = None) -> Any:
        # Works for mappings (anything with .get) and for attribute-style objects.
        try:
            if isinstance(source, dict) or hasattr(source, "get"):
                return source.get(name, fallback)
            if hasattr(source, name):
                return getattr(source, name, fallback)
        except Exception:
            return fallback
        return fallback

    def _session_of(extra: Any) -> str | None:
        # Return the session id stored in an additional_kwargs dict, if usable.
        if not isinstance(extra, dict):
            return None
        candidate = extra.get("session_id")
        if isinstance(candidate, str) and candidate:
            return candidate
        return None

    # 1) Explicit configuration.
    try:
        configurable = _read(config or {}, "configurable", {}) or {}
        for name in ("thread_id", "session_id", "thread"):
            value = _read(configurable, name)
            if isinstance(value, str) and value:
                return value
    except Exception:
        pass

    # 2) Newest message first: look for a session_id in additional_kwargs.
    try:
        for msg in reversed(messages or []):
            found = _session_of(getattr(msg, "additional_kwargs", None))
            if found is None and isinstance(msg, dict):
                found = _session_of(msg.get("additional_kwargs") or {})
            if found is not None:
                return found
    except Exception:
        pass

    # 3) Nothing usable.
    return "unknown"
def _trim_messages(messages: List[BaseMessage], max_messages: int = 40) -> List[BaseMessage]:
    """Keep only the most recent ``max_messages`` entries of ``messages``.

    Args:
        messages: Conversation history, oldest first.
        max_messages: Maximum number of trailing messages to keep.

    Returns:
        ``messages`` itself when it already fits, otherwise a new list holding
        the last ``max_messages`` items. A non-positive limit yields an empty
        list.
    """
    if max_messages <= 0:
        # messages[-0:] is the *whole* list and a negative limit would slice
        # from the front, so a non-positive limit is handled explicitly.
        return []
    if len(messages) <= max_messages:
        return messages
    return messages[-max_messages:]
def _sanitize_conversation(messages: List[BaseMessage]) -> List[BaseMessage]:
    """Drop tool messages that do not answer the preceding assistant turn.

    A ``ToolMessage`` is kept only when the most recent kept non-tool message
    is an ``AIMessage`` whose ``tool_calls`` contain its ``tool_call_id``.
    Orphans (for example those left behind when trimming cut off the assistant
    turn) are removed because they can cause OpenAI 400 errors.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        A new list with orphan tool messages removed; all other messages keep
        their relative order.
    """
    kept: List[BaseMessage] = []
    # Tool-call ids announced by the latest assistant turn; None when the
    # previous kept message does not allow tool replies.
    open_ids: set[str] | None = None

    for msg in messages:
        try:
            if isinstance(msg, AIMessage):
                # Collect the ids BEFORE appending, so a failure while reading
                # tool_calls cannot leave the message appended twice (once
                # here and once in the except branch).
                announced: set[str] = set()
                for call in getattr(msg, "tool_calls", None) or []:
                    # A tool call can be mapping-like or object-like.
                    if isinstance(call, dict):
                        call_id = call.get("id") or call.get("tool_call_id")
                    else:
                        call_id = getattr(call, "id", None) or getattr(call, "tool_call_id", None)
                    if isinstance(call_id, str):
                        announced.add(call_id)
                kept.append(msg)
                open_ids = announced or None
            elif isinstance(msg, ToolMessage):
                answered = getattr(msg, "tool_call_id", None)
                if open_ids and isinstance(answered, str) and answered in open_ids:
                    # open_ids is left untouched so further replies to the
                    # same assistant turn are accepted too.
                    kept.append(msg)
                # Otherwise: orphan tool message, dropped.
            else:
                # Any other message ends the window for tool replies.
                kept.append(msg)
                open_ids = None
        except Exception:
            # Unexpected shape: keep the message as-is but stop pairing.
            kept.append(msg)
            open_ids = None

    # The conversation must not start with a ToolMessage (only reachable via
    # the except branch above).
    start = 0
    while start < len(kept) and isinstance(kept[start], ToolMessage):
        start += 1
    return kept[start:]
def _today_string() -> str:
    """Return today's date as ``YYYY-MM-DD``.

    If the ``RBC_FEES_TODAY_OVERRIDE`` environment variable holds a date that
    parses with ``%Y-%m-%d``, that date is returned (zero-padded) instead.
    A missing, blank or unparseable override falls back to the current UTC
    date.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    override = (os.getenv("RBC_FEES_TODAY_OVERRIDE") or "").strip()
    if override:
        try:
            # strptime accepts unpadded fields ("2024-3-5"); re-serialize so
            # the result is always zero-padded.
            return datetime.strptime(override, "%Y-%m-%d").date().isoformat()
        except ValueError:
            # Malformed override: ignore it and use the real clock.
            pass
    # datetime.utcnow() is deprecated; use an aware UTC timestamp instead.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d")
def _system_messages() -> List[BaseMessage]:
    """Build the system messages that are prepended to every LLM call.

    Returns two ``SystemMessage`` objects: the static ``SYSTEM_PROMPT``, then
    a note carrying the date from ``_today_string()`` plus date-handling rules.
    """
    date_note = (
        f"Today is {_today_string()} (UTC). When the user mentions any date or timeframe, first call parse_date_range. "
        "Do not claim a date is in the future unless it is strictly after today. "
        "Rely on tools/fixtures and do not mention training data cutoffs."
    )
    return [SystemMessage(content=text) for text in (SYSTEM_PROMPT, date_note)]
# agent() calls ``call_llm(...).result()``, i.e. it expects a future, so this
# function has to be a LangGraph task rather than a plain function.
@task
def call_llm(messages: List[BaseMessage]) -> AIMessage:
    """Ask the tool-bound LLM for the next assistant turn.

    The system messages from ``_system_messages()`` are prepended to
    ``messages`` on every call. The model decides whether to answer directly
    or to request tool calls.

    Args:
        messages: Conversation so far (without system messages).

    Returns:
        The value returned by ``_LLM_WITH_TOOLS.invoke``.
    """
    if _DEBUG:
        try:
            # NOTE(review): the preview logs the first 80 characters of the
            # last six messages, which may include caller-provided details.
            preview = [
                f"{getattr(m, 'type', getattr(m, 'role', ''))}:{str(getattr(m, 'content', m))[:80]}"
                for m in messages[-6:]
            ]
            logger.info("call_llm: messages_count=%s preview=%s", len(messages), preview)
        except Exception:
            logger.info("call_llm: messages_count=%s", len(messages))
    return _LLM_WITH_TOOLS.invoke(_system_messages() + messages)
# agent() collects the return values as futures and calls ``.result()`` on
# them, so this function has to be a LangGraph task.
@task
def call_tool(tool_call: ToolCall) -> ToolMessage:
    """Execute one tool call requested by the model.

    Missing ``session_id`` / ``customer_id`` arguments are filled in from the
    module-level run context for ``verify_identity`` (both) and
    ``list_accounts`` (customer only).

    Args:
        tool_call: Mapping with at least ``name`` and ``id``; ``args`` is
            optional.

    Returns:
        A ``ToolMessage`` answering ``tool_call["id"]``. Non-string tool
        results are JSON-encoded. A tool name that is not registered produces
        an ``unknown_tool`` error payload instead of raising, so the model can
        recover on its next turn.
    """
    requested = tool_call["name"]
    tool = _TOOLS_BY_NAME.get(requested)
    if tool is None:
        # The name comes from model output and may not match a registered tool.
        logger.warning("call_tool: unknown tool requested: %s", requested)
        return ToolMessage(
            content=json.dumps({"error": "unknown_tool", "tool": str(requested)}),
            tool_call_id=tool_call["id"],
            name=str(requested),
        )

    # Copy before injecting context: the original dict belongs to the
    # AIMessage that gets saved with the conversation.
    args = dict(tool_call.get("args") or {})

    # Auto-inject session/customer context when the model left it out.
    if tool.name == "verify_identity":
        if "session_id" not in args and _CURRENT_THREAD_ID:
            args["session_id"] = _CURRENT_THREAD_ID
        if "customer_id" not in args and _CURRENT_CUSTOMER_ID:
            args["customer_id"] = _CURRENT_CUSTOMER_ID
    if tool.name == "list_accounts":
        if "customer_id" not in args and _CURRENT_CUSTOMER_ID:
            args["customer_id"] = _CURRENT_CUSTOMER_ID

    if _DEBUG:
        # Only the argument names are logged, never their values.
        logger.info("call_tool: name=%s args_keys=%s", tool.name, list(args.keys()))

    result = tool.invoke(args)

    # ToolMessage content must be a string. default=str is deliberate: values
    # that are not JSON-native are stringified rather than failing the turn.
    content = result if isinstance(result, str) else json.dumps(result, default=str)
    return ToolMessage(content=content, tool_call_id=tool_call["id"], name=tool.name)
def _infer_customer_id(messages: List[BaseMessage]) -> str | None:
    """Best-effort guess of a customer_id from a name in recent human turns.

    Takes the first non-empty text among the three most recent human messages
    and tries every adjacent word pair as (first name, last name) against
    ``find_customer_by_name``. Returns ``None`` when nothing matches, when the
    lookup helper is unavailable, or when anything goes wrong.
    """
    if find_customer_by_name is None:
        return None
    try:
        def _is_human(m: Any) -> bool:
            return (
                getattr(m, "type", None) == "human"
                or getattr(m, "role", None) == "user"
                or (isinstance(m, dict) and m.get("type") == "human")
            )

        def _content(m: Any) -> Any:
            return (m.get("content") if isinstance(m, dict) else getattr(m, "content", None)) or ""

        recent = [m for m in reversed(messages) if _is_human(m)][:3]
        text = next((c for c in map(_content, recent) if isinstance(c, str) and c.strip()), "")

        # Strip edge punctuation so "Alice Stone." still yields the token
        # "Stone"; tokens that are still non-alphabetic are ignored.
        words = (w.strip(".,!?;:\"'()") for w in text.replace(",", " ").split())
        tokens = [w for w in words if w.isalpha()]

        # Try adjacent pairs as first/last name.
        for first, last in zip(tokens, tokens[1:]):
            found = find_customer_by_name(first, last)  # type: ignore
            if isinstance(found, dict) and found.get("customer_id"):
                return found.get("customer_id")
    except Exception:
        # Inference is optional; record the failure instead of hiding it.
        logger.debug("customer inference failed", exc_info=True)
    return None


# The body relies on entrypoint features (task futures, entrypoint.final and
# the injected ``previous``/``config`` parameters), so it must be registered
# as a LangGraph entrypoint.
@entrypoint()
def agent(messages: List[BaseMessage], previous: List[BaseMessage] | None, config: Dict[str, Any] | None = None):
    """Run one conversational turn of the ReAct fees agent.

    Args:
        messages: New messages for this turn.
        previous: Conversation saved by the previous run, if any.
        config: Run configuration; ``configurable`` may carry a thread/session
            id and a ``customer_id`` or ``user_email``.

    Returns:
        ``entrypoint.final`` whose value is an ``AIMessage`` with the final
        assistant text and whose saved state is the merged conversation.
    """
    global _CURRENT_THREAD_ID, _CURRENT_CUSTOMER_ID

    # Start from the full history (previous + new).
    prev_list = list(previous or [])
    new_list = list(messages or [])

    # Trim to avoid context bloat; an unparseable limit falls back to 40.
    try:
        max_msgs = int(os.getenv("RBC_FEES_MAX_MSGS", "40"))
    except ValueError:
        max_msgs = 40
    convo: List[BaseMessage] = _trim_messages(prev_list + new_list, max_messages=max_msgs)
    # Trimming can cut an assistant turn away from its tool replies.
    convo = _sanitize_conversation(convo)

    thread_id = _get_thread_id(config, new_list)
    logger.info(
        "agent start: thread_id=%s total_in=%s (prev=%s, new=%s)",
        thread_id, len(convo), len(prev_list), len(new_list),
    )

    # Default customer from config, falling back to the cust_test fixture.
    # ``or {}`` also covers an explicit ``configurable: None``.
    conf = (config.get("configurable") or {}) if isinstance(config, dict) else {}
    default_customer = conf.get("customer_id") or conf.get("user_email") or "cust_test"

    # Heuristic: a name in the latest human turns overrides the default.
    inferred_customer = _infer_customer_id(new_list)

    # Publish the run context for call_tool.
    # NOTE(review): module globals are shared by every run in the process, and
    # the inferred customer is recomputed from this turn's messages only —
    # confirm both are acceptable outside a single-session dev worker.
    _CURRENT_THREAD_ID = thread_id
    _CURRENT_CUSTOMER_ID = inferred_customer or default_customer

    llm_response = call_llm(convo).result()
    # NOTE(review): this loop has no iteration cap; it ends only when the
    # model stops requesting tools.
    while True:
        tool_calls = getattr(llm_response, "tool_calls", None) or []
        if not tool_calls:
            break
        # Launch every requested tool, then gather the results in order.
        futures = [call_tool(tc) for tc in tool_calls]
        tool_results = [f.result() for f in futures]
        if _DEBUG:
            try:
                logger.info(
                    "tool_results: count=%s names=%s",
                    len(tool_results), [tr.name for tr in tool_results],
                )
            except Exception:
                pass
        convo = add_messages(convo, [llm_response, *tool_results])
        llm_response = call_llm(convo).result()

    # Append the final assistant turn.
    convo = add_messages(convo, [llm_response])
    final_text = getattr(llm_response, "content", "") or ""
    ai = AIMessage(content=final_text if isinstance(final_text, str) else str(final_text))
    logger.info(
        "agent done: thread_id=%s total_messages=%s final_len=%s",
        thread_id, len(convo), len(ai.content),
    )
    # Save only the merged conversation (avoids duplicating ``previous``).
    return entrypoint.final(value=ai, save=convo)