fciannella commited on
Commit
c63b8b6
·
1 Parent(s): 59ba41d

Added the healthcare agent

Browse files
examples/voice_agent_webrtc_langgraph/Dockerfile CHANGED
@@ -55,6 +55,7 @@ WORKDIR /app/examples/voice_agent_webrtc_langgraph
55
 
56
  # Dependencies
57
  RUN uv sync --frozen
 
58
  # Install all agent requirements recursively into the project's virtual environment
59
  # RUN if [ -d "agents" ]; then \
60
  # find agents -type f -name "requirements.txt" -print0 | xargs -0 -I {} uv pip install -r "{}"; \
 
55
 
56
  # Dependencies
57
  RUN uv sync --frozen
58
+ # RUN uv sync
59
  # Install all agent requirements recursively into the project's virtual environment
60
  # RUN if [ -d "agents" ]; then \
61
  # find agents -type f -name "requirements.txt" -print0 | xargs -0 -I {} uv pip install -r "{}"; \
examples/voice_agent_webrtc_langgraph/agents/healthcare-agent/logic.py CHANGED
@@ -620,3 +620,244 @@ def authenticate_user(session_id: str, name: Optional[str], dob_yyyy_mm_dd: Opti
620
  return resp
621
 
622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
  return resp
621
 
622
 
623
+
624
+ # --- Healthcare demo logic (patients, triage, providers, pharmacies) ---
625
+
626
+ _HC_SESSIONS: Dict[str, Dict[str, Any]] = {}
627
+ _HC_APPOINTMENTS: List[Dict[str, Any]] = []
628
+ _HC_CALL_LOG: List[Dict[str, Any]] = []
629
+
630
+
631
+ def _hc_fixtures_dir() -> Path:
632
+ return Path(__file__).parent / "mock_data"
633
+
634
+
635
+ def _hc_load_fixture(name: str) -> Any:
636
+ # Use a separate cache key namespace to avoid collisions
637
+ key = f"hc::{name}"
638
+ if key in _FIXTURE_CACHE:
639
+ return _FIXTURE_CACHE[key]
640
+ p = _hc_fixtures_dir() / name
641
+ with p.open("r", encoding="utf-8") as f:
642
+ data = json.load(f)
643
+ _FIXTURE_CACHE[key] = data
644
+ return data
645
+
646
+
647
+ def _hc_get_patient_blob(patient_id: str) -> Dict[str, Any]:
648
+ data = _hc_load_fixture("patients.json")
649
+ return dict((data.get("patients") or {}).get(patient_id, {}))
650
+
651
+
652
+ def find_patient_by_name(first_name: str, last_name: str) -> Dict[str, Any]:
653
+ data = _hc_load_fixture("patients.json")
654
+ patients = data.get("patients", {})
655
+ fn = (first_name or "").strip().lower()
656
+ ln = (last_name or "").strip().lower()
657
+ for pid, blob in patients.items():
658
+ prof = blob.get("profile") if isinstance(blob, dict) else None
659
+ if isinstance(prof, dict):
660
+ pfn = str(prof.get("first_name") or "").strip().lower()
661
+ pln = str(prof.get("last_name") or "").strip().lower()
662
+ if fn == pfn and ln == pln:
663
+ return {"patient_id": pid, "profile": prof}
664
+ return {}
665
+
666
+
667
+ def find_patient_by_full_name(full_name: str) -> Dict[str, Any]:
668
+ data = _hc_load_fixture("patients.json")
669
+ patients = data.get("patients", {})
670
+ target = (full_name or "").strip().lower()
671
+ for pid, blob in patients.items():
672
+ prof = blob.get("profile") if isinstance(blob, dict) else None
673
+ if isinstance(prof, dict):
674
+ fn = f"{str(prof.get('first_name') or '').strip()} {str(prof.get('last_name') or '').strip()}".strip().lower()
675
+ ff = str(prof.get("full_name") or "").strip().lower()
676
+ if target and (target == fn or target == ff):
677
+ return {"patient_id": pid, "profile": prof}
678
+ return {}
679
+
680
+
681
+ def get_patient_profile(patient_id: str) -> Dict[str, Any]:
682
+ blob = _hc_get_patient_blob(patient_id)
683
+ if not blob:
684
+ return {}
685
+ prof = dict(blob.get("profile", {}))
686
+ return {
687
+ "profile": prof,
688
+ "allergies": list(blob.get("allergies", [])),
689
+ "medications": list(blob.get("medications", [])),
690
+ "conditions": list(blob.get("conditions", [])),
691
+ "recent_visits": list(blob.get("recent_visits", [])),
692
+ "vitals": dict(blob.get("vitals", {})),
693
+ }
694
+
695
+
696
+ def authenticate_patient(session_id: str, patient_id: Optional[str], full_name: Optional[str], dob_yyyy_mm_dd: Optional[str], mrn_last4: Optional[str], secret_answer: Optional[str]) -> Dict[str, Any]:
697
+ session = _HC_SESSIONS.get(session_id) or {"verified": False, "patient_id": patient_id, "name": full_name}
698
+ if isinstance(patient_id, str) and patient_id:
699
+ session["patient_id"] = patient_id
700
+ if isinstance(full_name, str) and full_name:
701
+ session["name"] = full_name
702
+ if isinstance(dob_yyyy_mm_dd, str) and dob_yyyy_mm_dd:
703
+ session["dob"] = _normalize_dob(dob_yyyy_mm_dd) or dob_yyyy_mm_dd
704
+ if isinstance(mrn_last4, str) and mrn_last4:
705
+ session["mrn_last4"] = mrn_last4
706
+ if isinstance(secret_answer, str) and secret_answer:
707
+ session["secret"] = secret_answer
708
+
709
+ ok = False
710
+ pid = session.get("patient_id")
711
+ if isinstance(pid, str):
712
+ prof = get_patient_profile(pid).get("profile", {})
713
+ user_dob_norm = _normalize_dob(session.get("dob"))
714
+ prof_dob_norm = _normalize_dob(prof.get("dob"))
715
+ dob_ok = (user_dob_norm is not None) and (user_dob_norm == prof_dob_norm)
716
+ mrn_ok = str(session.get("mrn_last4") or "") == str(prof.get("mrn_last4") or "")
717
+ def _norm(x: Optional[str]) -> str:
718
+ return (x or "").strip().lower()
719
+ secret_ok = _norm(session.get("secret")) == _norm(prof.get("secret_answer"))
720
+ if dob_ok and (mrn_ok or secret_ok):
721
+ ok = True
722
+ session["verified"] = ok
723
+ _HC_SESSIONS[session_id] = session
724
+ need: List[str] = []
725
+ if _normalize_dob(session.get("dob")) is None:
726
+ need.append("dob")
727
+ if not session.get("mrn_last4") and not session.get("secret"):
728
+ need.append("mrn_last4_or_secret")
729
+ if not session.get("patient_id"):
730
+ need.append("patient")
731
+ resp: Dict[str, Any] = {"session_id": session_id, "verified": ok, "needs": need, "profile": {"name": session.get("name")}}
732
+ try:
733
+ if isinstance(session.get("patient_id"), str):
734
+ prof = get_patient_profile(session.get("patient_id")).get("profile", {})
735
+ if isinstance(prof, dict) and prof.get("secret_question"):
736
+ resp["question"] = prof.get("secret_question")
737
+ except Exception:
738
+ pass
739
+ return resp
740
+
741
+
742
+ def get_preferred_pharmacy(patient_id: str) -> Dict[str, Any]:
743
+ prof = get_patient_profile(patient_id).get("profile", {})
744
+ ph_id = prof.get("preferred_pharmacy_id")
745
+ if not ph_id:
746
+ return {}
747
+ data = _hc_load_fixture("pharmacies.json")
748
+ ph = (data.get("pharmacies") or {}).get(ph_id) or {}
749
+ return {"pharmacy_id": ph_id, **ph}
750
+
751
+
752
+ def list_providers(specialty: Optional[str] = None) -> List[Dict[str, Any]]:
753
+ data = _hc_load_fixture("providers.json")
754
+ providers = data.get("providers", {})
755
+ out: List[Dict[str, Any]] = []
756
+ for pid, p in providers.items():
757
+ if specialty and str(p.get("specialty", "")).lower() != specialty.strip().lower():
758
+ continue
759
+ out.append({"provider_id": pid, **p})
760
+ return out
761
+
762
+
763
+ def get_provider_slots(provider_id: str, count: int = 3) -> List[str]:
764
+ data = _hc_load_fixture("providers.json")
765
+ providers = data.get("providers", {})
766
+ p = providers.get(provider_id) or {}
767
+ return list((p.get("next_available") or [])[:count])
768
+
769
+
770
+ def schedule_appointment(provider_id: str, slot_iso: str, patient_id: Optional[str]) -> Dict[str, Any]:
771
+ appt = {
772
+ "appointment_id": f"A-{uuid.uuid4().hex[:8]}",
773
+ "provider_id": provider_id,
774
+ "slot": slot_iso,
775
+ "patient_id": patient_id,
776
+ "created_at": datetime.utcnow().isoformat() + "Z",
777
+ "status": "booked",
778
+ }
779
+ _HC_APPOINTMENTS.append(appt)
780
+ return appt
781
+
782
+
783
+ def _patient_age_years(patient_id: Optional[str]) -> Optional[int]:
784
+ try:
785
+ if not patient_id:
786
+ return None
787
+ prof = get_patient_profile(patient_id).get("profile", {})
788
+ dob = _normalize_dob(prof.get("dob"))
789
+ if not dob:
790
+ return None
791
+ y, m, d = [int(x) for x in dob.split("-")]
792
+ today = datetime.utcnow().date()
793
+ age = today.year - y - ((today.month, today.day) < (m, d))
794
+ return age
795
+ except Exception:
796
+ return None
797
+
798
+
799
+ def triage_symptoms(patient_id: Optional[str], symptoms_text: str) -> Dict[str, Any]:
800
+ txt = (symptoms_text or "").lower()
801
+ rules = _hc_load_fixture("triage_rules.json").get("rules", [])
802
+ age = _patient_age_years(patient_id) or 0
803
+
804
+ def contains_any(needles: List[str]) -> bool:
805
+ for n in needles:
806
+ if n.lower() in txt:
807
+ return True
808
+ return False
809
+
810
+ chosen: Dict[str, Any] | None = None
811
+ red_flags_hit: List[str] = []
812
+
813
+ for r in rules:
814
+ matches = r.get("match", [])
815
+ if matches and not contains_any(matches):
816
+ continue
817
+ rflags = r.get("red_flags", [])
818
+ if rflags:
819
+ red_flags_hit = [rf for rf in rflags if rf.lower() in txt]
820
+ if red_flags_hit:
821
+ chosen = r
822
+ break
823
+ crit = r.get("criteria", [])
824
+ if crit:
825
+ if "age_over_50" in crit and age > 50:
826
+ chosen = r
827
+ break
828
+ if not r.get("red_flags") and not r.get("criteria"):
829
+ chosen = r
830
+ # do not break; prefer a more specific rule if later
831
+
832
+ if not chosen and rules:
833
+ chosen = rules[-1]
834
+
835
+ if not chosen:
836
+ return {"risk": "self_care", "advice": "If symptoms persist or worsen, contact us or seek care.", "red_flags": []}
837
+
838
+ return {
839
+ "risk": chosen.get("escalate", "self_care"),
840
+ "advice": chosen.get("advice", ""),
841
+ "red_flags": red_flags_hit,
842
+ "rule": chosen.get("name", "")
843
+ }
844
+
845
+
846
+ def log_call(session_id: str, patient_id: Optional[str], notes: Optional[str], triage: Optional[Dict[str, Any]]) -> Dict[str, Any]:
847
+ entry = {
848
+ "log_id": f"L-{uuid.uuid4().hex[:8]}",
849
+ "session_id": session_id,
850
+ "patient_id": patient_id,
851
+ "notes": notes or "",
852
+ "triage": triage or {},
853
+ "timestamp": datetime.utcnow().isoformat() + "Z",
854
+ }
855
+ _HC_CALL_LOG.append(entry)
856
+ try:
857
+ # Also mirror to app.log for visibility
858
+ logging.getLogger("HealthcareAgent").info("call_log: %s", json.dumps(entry)[:500])
859
+ except Exception:
860
+ pass
861
+ return {"logged": True, "log_id": entry["log_id"]}
862
+
863
+
examples/voice_agent_webrtc_langgraph/agents/healthcare-agent/react_agent.py CHANGED
@@ -18,90 +18,73 @@ from langchain_core.messages import (
18
  )
19
 
20
 
21
- # ---- Tools (wire-transfer) ----
22
 
23
  try:
24
- from . import tools as wire_tools # type: ignore
25
  except Exception:
26
  import importlib.util as _ilu
27
  _dir = os.path.dirname(__file__)
28
  _tools_path = os.path.join(_dir, "tools.py")
29
- _spec = _ilu.spec_from_file_location("wire_transfer_agent_tools", _tools_path)
30
- wire_tools = _ilu.module_from_spec(_spec) # type: ignore
31
  assert _spec and _spec.loader
32
- _spec.loader.exec_module(wire_tools) # type: ignore
33
 
34
  # Aliases for tool functions
35
- list_accounts = wire_tools.list_accounts
36
- get_customer_profile = wire_tools.get_customer_profile
37
- find_customer = wire_tools.find_customer
38
- find_account_by_last4 = wire_tools.find_account_by_last4
39
- verify_identity = wire_tools.verify_identity
40
- get_account_balance_tool = wire_tools.get_account_balance_tool
41
- get_exchange_rate_tool = wire_tools.get_exchange_rate_tool
42
- calculate_wire_fee_tool = wire_tools.calculate_wire_fee_tool
43
- check_wire_limits_tool = wire_tools.check_wire_limits_tool
44
- get_cutoff_and_eta_tool = wire_tools.get_cutoff_and_eta_tool
45
- get_country_requirements_tool = wire_tools.get_country_requirements_tool
46
- validate_beneficiary_tool = wire_tools.validate_beneficiary_tool
47
- save_beneficiary_tool = wire_tools.save_beneficiary_tool
48
- quote_wire_tool = wire_tools.quote_wire_tool
49
- generate_otp_tool = wire_tools.generate_otp_tool
50
- verify_otp_tool = wire_tools.verify_otp_tool
51
- wire_transfer_domestic = wire_tools.wire_transfer_domestic
52
- wire_transfer_international = wire_tools.wire_transfer_international
53
-
54
- find_customer_by_name = None # not used for wire agent; tools expose find_customer
55
 
56
 
57
  """ReAct agent entrypoint and system prompt."""
58
 
59
 
60
  SYSTEM_PROMPT = (
61
- "You are a warm, cheerful banking assistant helping a customer send a wire transfer (domestic or international). "
62
- "Start with a brief greeting and very short small talk. Then ask for the caller's full name. "
63
- "IDENTITY IS MANDATORY: Before ANY account lookups or wire questions, you MUST call verify_identity. Ask for date of birth (customer can use any format; you normalize) and EITHER SSN last-4 OR the secret answer. If verify_identity returns a secret question, read it verbatim and collect the answer. "
64
- "NEVER claim the customer is verified unless the verify_identity tool returned verified=true. If not verified, ask ONLY for the next missing field and call verify_identity again. Do NOT proceed to wire details until verified=true. "
65
- "AFTER VERIFIED: Ask ONE question at a time, in this order, waiting for the user's answer each time: (1) wire type (DOMESTIC or INTERNATIONAL); (2) source account (last-4 or picker); (3) amount (with source currency); (4) destination country/state; (5) destination currency preference; (6) who pays fees (OUR/SHA/BEN). Keep each turn to a single, concise prompt. Do NOT re-ask for fields already provided; instead, briefly summarize known details and ask only for the next missing field. "
66
- "If destination currency differs from source, call get_exchange_rate_tool and state the applied rate and converted amount. "
67
- "Collect beneficiary details next. Use get_country_requirements_tool and validate_beneficiary_tool; if fields are missing, ask for ONLY the next missing field (one per turn). "
68
- "Then check balance/limits via get_account_balance_tool and check_wire_limits_tool. Provide a pre-transfer quote using quote_wire_tool showing: FX rate, total fees, who pays what, net sent, net received, and ETA from get_cutoff_and_eta_tool. "
69
- "Before executing, generate an OTP (generate_otp_tool), collect it, verify via verify_otp_tool, then execute the appropriate transfer: wire_transfer_domestic or wire_transfer_international. Offer to save the beneficiary afterward. "
70
- "STYLE: Keep messages short (1–2 sentences), empathetic, and strictly ask one question per turn."
71
  )
72
 
73
 
74
  _MODEL_NAME = os.getenv("REACT_MODEL", os.getenv("CLARIFY_MODEL", "gpt-4o"))
75
- _LLM = ChatOpenAI(model=_MODEL_NAME, temperature=0.3)
 
 
76
  _TOOLS = [
77
- list_accounts,
78
- get_customer_profile,
79
- find_customer,
80
- find_account_by_last4,
81
  verify_identity,
82
- get_account_balance_tool,
83
- get_exchange_rate_tool,
84
- calculate_wire_fee_tool,
85
- check_wire_limits_tool,
86
- get_cutoff_and_eta_tool,
87
- get_country_requirements_tool,
88
- validate_beneficiary_tool,
89
- save_beneficiary_tool,
90
- quote_wire_tool,
91
- generate_otp_tool,
92
- verify_otp_tool,
93
- wire_transfer_domestic,
94
- wire_transfer_international,
95
  ]
96
  _LLM_WITH_TOOLS = _LLM.bind_tools(_TOOLS)
97
  _TOOLS_BY_NAME = {t.name: t for t in _TOOLS}
98
 
99
  # Simple per-run context storage (thread-safe enough for local dev worker)
100
  _CURRENT_THREAD_ID: str | None = None
101
- _CURRENT_CUSTOMER_ID: str | None = None
102
 
103
  # ---- Logger ----
104
- logger = logging.getLogger("WireTransferAgent")
105
  if not logger.handlers:
106
  _stream = logging.StreamHandler()
107
  _stream.setLevel(logging.INFO)
@@ -116,7 +99,7 @@ if not logger.handlers:
116
  except Exception:
117
  pass
118
  logger.setLevel(logging.INFO)
119
- _DEBUG = os.getenv("RBC_FEES_DEBUG", "0") not in ("", "0", "false", "False")
120
 
121
  def _get_thread_id(config: Dict[str, Any] | None, messages: List[BaseMessage]) -> str:
122
  cfg = config or {}
@@ -256,24 +239,18 @@ def call_tool(tool_call: ToolCall) -> ToolMessage:
256
  """Execute a tool call and wrap result in a ToolMessage."""
257
  tool = _TOOLS_BY_NAME[tool_call["name"]]
258
  args = tool_call.get("args") or {}
259
- # Auto-inject session/customer context if missing for identity and other tools
260
  if tool.name == "verify_identity":
261
  if "session_id" not in args and _CURRENT_THREAD_ID:
262
  args["session_id"] = _CURRENT_THREAD_ID
263
- if "customer_id" not in args and _CURRENT_CUSTOMER_ID:
264
- args["customer_id"] = _CURRENT_CUSTOMER_ID
265
- if tool.name == "list_accounts":
266
- if "customer_id" not in args and _CURRENT_CUSTOMER_ID:
267
- args["customer_id"] = _CURRENT_CUSTOMER_ID
268
- # Gate non-identity tools until verified=true
269
- try:
270
- if tool.name not in ("verify_identity", "find_customer"):
271
- # Look back through recent messages for the last verify_identity result
272
- # The runtime passes messages separately; we cannot access here, so rely on LLM prompt discipline.
273
- # As an extra guard, if the tool is attempting a wire action before identity, return a friendly error.
274
- pass
275
- except Exception:
276
- pass
277
  if _DEBUG:
278
  try:
279
  logger.info("call_tool: name=%s args_keys=%s", tool.name, list(args.keys()))
@@ -328,12 +305,12 @@ def agent(messages: List[BaseMessage], previous: List[BaseMessage] | None, confi
328
  convo = _sanitize_conversation(convo)
329
  thread_id = _get_thread_id(config, new_list)
330
  logger.info("agent start: thread_id=%s total_in=%s (prev=%s, new=%s)", thread_id, len(convo), len(prev_list), len(new_list))
331
- # Establish default customer from config (or fallback to cust_test)
332
  conf = (config or {}).get("configurable", {}) if isinstance(config, dict) else {}
333
- default_customer = conf.get("customer_id") or conf.get("user_email") or "cust_test"
334
 
335
- # Heuristic: infer customer_id from latest human name if provided (e.g., "I am Alice Stone")
336
- inferred_customer: str | None = None
337
  try:
338
  recent_humans = [m for m in reversed(new_list) if (getattr(m, "type", None) == "human" or getattr(m, "role", None) == "user" or (isinstance(m, dict) and m.get("type") == "human"))]
339
  text = None
@@ -343,22 +320,15 @@ def agent(messages: List[BaseMessage], previous: List[BaseMessage] | None, confi
343
  break
344
  if isinstance(text, str):
345
  tokens = [t for t in text.replace(',', ' ').split() if t.isalpha()]
346
- if len(tokens) >= 2 and find_customer_by_name is not None:
347
- # Try adjacent pairs as first/last
348
- for i in range(len(tokens) - 1):
349
- fn = tokens[i]
350
- ln = tokens[i + 1]
351
- found = find_customer_by_name(fn, ln) # type: ignore
352
- if isinstance(found, dict) and found.get("customer_id"):
353
- inferred_customer = found.get("customer_id")
354
- break
355
  except Exception:
356
  pass
357
 
358
  # Update module context
359
- global _CURRENT_THREAD_ID, _CURRENT_CUSTOMER_ID
360
  _CURRENT_THREAD_ID = thread_id
361
- _CURRENT_CUSTOMER_ID = inferred_customer or default_customer
362
 
363
  llm_response = call_llm(convo).result()
364
 
 
18
  )
19
 
20
 
21
+ # ---- Tools (healthcare) ----
22
 
23
  try:
24
+ from . import tools as hc_tools # type: ignore
25
  except Exception:
26
  import importlib.util as _ilu
27
  _dir = os.path.dirname(__file__)
28
  _tools_path = os.path.join(_dir, "tools.py")
29
+ _spec = _ilu.spec_from_file_location("healthcare_agent_tools", _tools_path)
30
+ hc_tools = _ilu.module_from_spec(_spec) # type: ignore
31
  assert _spec and _spec.loader
32
+ _spec.loader.exec_module(hc_tools) # type: ignore
33
 
34
  # Aliases for tool functions
35
+ find_patient = hc_tools.find_patient
36
+ get_patient_profile_tool = hc_tools.get_patient_profile_tool
37
+ verify_identity = hc_tools.verify_identity
38
+ get_preferred_pharmacy_tool = hc_tools.get_preferred_pharmacy_tool
39
+ list_providers_tool = hc_tools.list_providers_tool
40
+ get_provider_slots_tool = hc_tools.get_provider_slots_tool
41
+ schedule_appointment_tool = hc_tools.schedule_appointment_tool
42
+ triage_symptoms_tool = hc_tools.triage_symptoms_tool
43
+ log_call_tool = hc_tools.log_call_tool
44
+
45
+ find_customer_by_name = None # not used
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  """ReAct agent entrypoint and system prompt."""
49
 
50
 
51
  SYSTEM_PROMPT = (
52
+ "You are a compassionate 24/7 telehealth nurse for existing patients. "
53
+ "Begin with a warm, concise greeting and ask for the caller's full name. "
54
+ "IDENTITY IS MANDATORY: Before accessing any records, verify identity using date of birth (any format; you normalize) and EITHER MRN last-4 OR the secret answer. If a secret question is available, read it verbatim and collect the answer. "
55
+ "NEVER claim the caller is verified unless verification returns verified=true. If not verified, ask ONLY for the next missing field and verify again. "
56
+ "AFTER VERIFIED: Ask ONE question at a time. Gather chief complaint and symptoms in plain language. Screen for common red flags (severe/worst-ever, head injury, weakness/numbness, vision changes, seizure, stiff neck, high fever). If any red flag is present, clearly advise urgent evaluation. "
57
+ "Use a calm, empathetic tone and keep responses short (1–2 sentences). "
58
+ "If no red flags, provide brief self-care guidance (hydration, rest, acetaminophen dose guidance when appropriate) and offer to book a telehealth appointment with available providers. "
59
+ "Confirm preferred pharmacy for prescriptions if needed. "
60
+ "Always speak clearly and avoid medical jargon."
 
61
  )
62
 
63
 
64
  _MODEL_NAME = os.getenv("REACT_MODEL", os.getenv("CLARIFY_MODEL", "gpt-4o"))
65
+ _OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
66
+ _OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
67
+ _LLM = ChatOpenAI(model=_MODEL_NAME, temperature=0.3, base_url=_OPENAI_BASE_URL, api_key=_OPENAI_API_KEY)
68
  _TOOLS = [
69
+ find_patient,
70
+ get_patient_profile_tool,
 
 
71
  verify_identity,
72
+ triage_symptoms_tool,
73
+ list_providers_tool,
74
+ get_provider_slots_tool,
75
+ schedule_appointment_tool,
76
+ get_preferred_pharmacy_tool,
77
+ log_call_tool,
 
 
 
 
 
 
 
78
  ]
79
  _LLM_WITH_TOOLS = _LLM.bind_tools(_TOOLS)
80
  _TOOLS_BY_NAME = {t.name: t for t in _TOOLS}
81
 
82
  # Simple per-run context storage (thread-safe enough for local dev worker)
83
  _CURRENT_THREAD_ID: str | None = None
84
+ _CURRENT_PATIENT_ID: str | None = None
85
 
86
  # ---- Logger ----
87
+ logger = logging.getLogger("HealthcareAgent")
88
  if not logger.handlers:
89
  _stream = logging.StreamHandler()
90
  _stream.setLevel(logging.INFO)
 
99
  except Exception:
100
  pass
101
  logger.setLevel(logging.INFO)
102
+ _DEBUG = os.getenv("HC_DEBUG", "0") not in ("", "0", "false", "False")
103
 
104
  def _get_thread_id(config: Dict[str, Any] | None, messages: List[BaseMessage]) -> str:
105
  cfg = config or {}
 
239
  """Execute a tool call and wrap result in a ToolMessage."""
240
  tool = _TOOLS_BY_NAME[tool_call["name"]]
241
  args = tool_call.get("args") or {}
242
+ # Auto-inject session/patient context for identity and profile tools
243
  if tool.name == "verify_identity":
244
  if "session_id" not in args and _CURRENT_THREAD_ID:
245
  args["session_id"] = _CURRENT_THREAD_ID
246
+ if "patient_id" not in args and _CURRENT_PATIENT_ID:
247
+ args["patient_id"] = _CURRENT_PATIENT_ID
248
+ if tool.name in ("get_patient_profile_tool", "get_preferred_pharmacy_tool"):
249
+ if "patient_id" not in args and _CURRENT_PATIENT_ID:
250
+ args["patient_id"] = _CURRENT_PATIENT_ID
251
+ if tool.name == "triage_symptoms_tool":
252
+ if "patient_id" not in args:
253
+ args["patient_id"] = _CURRENT_PATIENT_ID
 
 
 
 
 
 
254
  if _DEBUG:
255
  try:
256
  logger.info("call_tool: name=%s args_keys=%s", tool.name, list(args.keys()))
 
305
  convo = _sanitize_conversation(convo)
306
  thread_id = _get_thread_id(config, new_list)
307
  logger.info("agent start: thread_id=%s total_in=%s (prev=%s, new=%s)", thread_id, len(convo), len(prev_list), len(new_list))
308
+ # Establish default patient from config (or fallback to pt_jmarshall)
309
  conf = (config or {}).get("configurable", {}) if isinstance(config, dict) else {}
310
+ default_patient = conf.get("patient_id") or conf.get("user_email") or "pt_jmarshall"
311
 
312
+ # Heuristic: infer patient_id from latest human name if provided (e.g., "I am John Marshall")
313
+ inferred_patient: str | None = None
314
  try:
315
  recent_humans = [m for m in reversed(new_list) if (getattr(m, "type", None) == "human" or getattr(m, "role", None) == "user" or (isinstance(m, dict) and m.get("type") == "human"))]
316
  text = None
 
320
  break
321
  if isinstance(text, str):
322
  tokens = [t for t in text.replace(',', ' ').split() if t.isalpha()]
323
+ if len(tokens) >= 2 and False:
324
+ pass
 
 
 
 
 
 
 
325
  except Exception:
326
  pass
327
 
328
  # Update module context
329
+ global _CURRENT_THREAD_ID, _CURRENT_PATIENT_ID
330
  _CURRENT_THREAD_ID = thread_id
331
+ _CURRENT_PATIENT_ID = inferred_patient or default_patient
332
 
333
  llm_response = call_llm(convo).result()
334
 
examples/voice_agent_webrtc_langgraph/agents/healthcare-agent/tools.py CHANGED
@@ -1,167 +1,93 @@
1
  import os
2
  import sys
3
  import json
4
- from typing import Any, Dict
5
 
6
  from langchain_core.tools import tool
7
 
8
- # Robust logic import to avoid crossing into other agent modules during hot reloads
9
  try:
10
- from . import logic as wt_logic # type: ignore
11
  except Exception:
12
  import importlib.util as _ilu
13
  _dir = os.path.dirname(__file__)
14
  _logic_path = os.path.join(_dir, "logic.py")
15
- _spec = _ilu.spec_from_file_location("wire_transfer_agent_logic", _logic_path)
16
- wt_logic = _ilu.module_from_spec(_spec) # type: ignore
17
  assert _spec and _spec.loader
18
- _spec.loader.exec_module(wt_logic) # type: ignore
19
-
20
- get_accounts = wt_logic.get_accounts
21
- get_profile = wt_logic.get_profile
22
- find_customer_by_name = wt_logic.find_customer_by_name
23
- find_customer_by_full_name = getattr(wt_logic, "find_customer_by_full_name", wt_logic.find_customer_by_name)
24
- get_account_balance = wt_logic.get_account_balance
25
- get_exchange_rate = wt_logic.get_exchange_rate
26
- calculate_wire_fee = wt_logic.calculate_wire_fee
27
- check_wire_limits = wt_logic.check_wire_limits
28
- get_cutoff_and_eta = wt_logic.get_cutoff_and_eta
29
- get_country_requirements = wt_logic.get_country_requirements
30
- validate_beneficiary = wt_logic.validate_beneficiary
31
- save_beneficiary = wt_logic.save_beneficiary
32
- generate_otp = wt_logic.generate_otp
33
- verify_otp = wt_logic.verify_otp
34
- authenticate_user_wire = wt_logic.authenticate_user_wire
35
- quote_wire = wt_logic.quote_wire
36
- wire_transfer_domestic_logic = wt_logic.wire_transfer_domestic
37
- wire_transfer_international_logic = wt_logic.wire_transfer_international
38
 
39
-
40
- @tool
41
- def list_accounts(customer_id: str) -> str:
42
- """List customer's accounts with masked numbers, balances, currency, and wire eligibility. Returns JSON string."""
43
- return json.dumps(get_accounts(customer_id))
44
-
45
-
46
- @tool
47
- def get_customer_profile(customer_id: str) -> str:
48
- """Fetch basic customer profile (full_name, dob, ssn_last4, secret question). Returns JSON string."""
49
- return json.dumps(get_profile(customer_id))
50
 
51
 
52
  @tool
53
- def find_customer(first_name: str | None = None, last_name: str | None = None, full_name: str | None = None) -> str:
54
- """Find a customer_id by name. Prefer full_name; otherwise use first and last name. Returns JSON with customer_id or {}."""
55
  if isinstance(full_name, str) and full_name.strip():
56
- return json.dumps(find_customer_by_full_name(full_name))
57
- return json.dumps(find_customer_by_name(first_name or "", last_name or ""))
58
 
59
 
60
  @tool
61
- def find_account_by_last4(customer_id: str, last4: str) -> str:
62
- """Find a customer's account by last 4 digits. Returns JSON with account or {} if not found."""
63
- accts = get_accounts(customer_id)
64
- for a in accts:
65
- num = str(a.get("account_number") or "")
66
- if num.endswith(str(last4)):
67
- return json.dumps(a)
68
- return json.dumps({})
69
 
70
 
71
  @tool
72
- def verify_identity(session_id: str, customer_id: str | None = None, full_name: str | None = None, dob_yyyy_mm_dd: str | None = None, ssn_last4: str | None = None, secret_answer: str | None = None) -> str:
73
- """Verify user identity before wires. Provide any of: full_name, dob (YYYY-MM-DD), ssn_last4, secret_answer. Returns JSON with verified flag, needed fields, and optional secret question."""
74
- res = authenticate_user_wire(session_id, customer_id, full_name, dob_yyyy_mm_dd, ssn_last4, secret_answer)
75
  return json.dumps(res)
76
 
77
 
78
  @tool
79
- def get_account_balance_tool(account_id: str) -> str:
80
- """Get balance, currency, and wire limits for an account. Returns JSON."""
81
- return json.dumps(get_account_balance(account_id))
82
-
83
-
84
- @tool
85
- def get_exchange_rate_tool(from_currency: str, to_currency: str, amount: float) -> str:
86
- """Get exchange rate and converted amount for a given amount. Returns JSON."""
87
- return json.dumps(get_exchange_rate(from_currency, to_currency, amount))
88
 
89
 
90
  @tool
91
- def calculate_wire_fee_tool(kind: str, amount: float, from_currency: str, to_currency: str, payer: str) -> str:
92
- """Calculate wire fee breakdown and who pays (OUR/SHA/BEN). Returns JSON."""
93
- return json.dumps(calculate_wire_fee(kind, amount, from_currency, to_currency, payer))
94
 
95
 
96
  @tool
97
- def check_wire_limits_tool(account_id: str, amount: float) -> str:
98
- """Check sufficient funds and daily wire limit on an account. Returns JSON."""
99
- return json.dumps(check_wire_limits(account_id, amount))
100
 
101
 
102
  @tool
103
- def get_cutoff_and_eta_tool(kind: str, country: str) -> str:
104
- """Get cutoff time and estimated arrival window by type and country. Returns JSON."""
105
- return json.dumps(get_cutoff_and_eta(kind, country))
106
-
107
-
108
- @tool
109
- def get_country_requirements_tool(country_code: str) -> str:
110
- """Get required beneficiary fields for a country. Returns JSON array."""
111
- return json.dumps(get_country_requirements(country_code))
112
-
113
-
114
- @tool
115
- def validate_beneficiary_tool(country_code: str, beneficiary_json: str) -> str:
116
- """Validate beneficiary fields for a given country. Input is JSON dict string; returns {ok, missing}."""
117
- try:
118
- beneficiary = json.loads(beneficiary_json)
119
- except Exception:
120
- beneficiary = {}
121
- return json.dumps(validate_beneficiary(country_code, beneficiary))
122
 
123
 
124
  @tool
125
- def save_beneficiary_tool(customer_id: str, beneficiary_json: str) -> str:
126
- """Save a beneficiary for future use. Input is JSON dict string; returns {beneficiary_id}."""
127
- try:
128
- beneficiary = json.loads(beneficiary_json)
129
- except Exception:
130
- beneficiary = {}
131
- return json.dumps(save_beneficiary(customer_id, beneficiary))
132
 
133
 
134
  @tool
135
- def quote_wire_tool(kind: str, from_account_id: str, beneficiary_json: str, amount: float, from_currency: str, to_currency: str, payer: str) -> str:
136
- """Create a wire quote including FX, fees, limits, sanctions, eta; returns JSON with quote_id and totals."""
 
137
  try:
138
- beneficiary = json.loads(beneficiary_json)
139
  except Exception:
140
- beneficiary = {}
141
- return json.dumps(quote_wire(kind, from_account_id, beneficiary, amount, from_currency, to_currency, payer))
142
-
143
-
144
- @tool
145
- def generate_otp_tool(customer_id: str) -> str:
146
- """Generate a one-time passcode for wire authorization. Returns masked destination info."""
147
- return json.dumps(generate_otp(customer_id))
148
-
149
-
150
- @tool
151
- def verify_otp_tool(customer_id: str, otp: str) -> str:
152
- """Verify the one-time passcode for wire authorization. Returns {verified}."""
153
- return json.dumps(verify_otp(customer_id, otp))
154
-
155
-
156
- @tool
157
- def wire_transfer_domestic(quote_id: str, otp: str) -> str:
158
- """Execute a domestic wire with a valid quote_id and OTP. Returns confirmation."""
159
- return json.dumps(wire_transfer_domestic_logic(quote_id, otp))
160
-
161
-
162
- @tool
163
- def wire_transfer_international(quote_id: str, otp: str) -> str:
164
- """Execute an international wire with a valid quote_id and OTP. Returns confirmation."""
165
- return json.dumps(wire_transfer_international_logic(quote_id, otp))
166
 
167
 
 
1
  import os
2
  import sys
3
  import json
4
+ from typing import Any, Dict, Optional
5
 
6
  from langchain_core.tools import tool
7
 
8
+ # Robust logic import isolated to this agent
9
  try:
10
+ from . import logic as hc_logic # type: ignore
11
  except Exception:
12
  import importlib.util as _ilu
13
  _dir = os.path.dirname(__file__)
14
  _logic_path = os.path.join(_dir, "logic.py")
15
+ _spec = _ilu.spec_from_file_location("healthcare_agent_logic", _logic_path)
16
+ hc_logic = _ilu.module_from_spec(_spec) # type: ignore
17
  assert _spec and _spec.loader
18
+ _spec.loader.exec_module(hc_logic) # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ find_patient_by_name = hc_logic.find_patient_by_name
21
+ find_patient_by_full_name = hc_logic.find_patient_by_full_name
22
+ get_patient_profile = hc_logic.get_patient_profile
23
+ authenticate_patient = hc_logic.authenticate_patient
24
+ get_preferred_pharmacy = hc_logic.get_preferred_pharmacy
25
+ list_providers = hc_logic.list_providers
26
+ get_provider_slots = hc_logic.get_provider_slots
27
+ schedule_appointment = hc_logic.schedule_appointment
28
+ triage_symptoms = hc_logic.triage_symptoms
29
+ log_call = hc_logic.log_call
 
30
 
31
 
32
  @tool
33
+ def find_patient(first_name: str | None = None, last_name: str | None = None, full_name: str | None = None) -> str:
34
+ """Find a patient_id by name. Prefer full_name; otherwise use first+last. Returns JSON with patient_id or {}."""
35
  if isinstance(full_name, str) and full_name.strip():
36
+ return json.dumps(find_patient_by_full_name(full_name))
37
+ return json.dumps(find_patient_by_name(first_name or "", last_name or ""))
38
 
39
 
40
  @tool
41
+ def get_patient_profile_tool(patient_id: str) -> str:
42
+ """Fetch patient profile, allergies, meds, visits, and vitals. Returns JSON string."""
43
+ return json.dumps(get_patient_profile(patient_id))
 
 
 
 
 
44
 
45
 
46
  @tool
47
+ def verify_identity(session_id: str, patient_id: str | None = None, full_name: str | None = None, dob_yyyy_mm_dd: str | None = None, mrn_last4: str | None = None, secret_answer: str | None = None) -> str:
48
+ """Verify identity before accessing records. Provide any of: full_name, dob (YYYY-MM-DD or free-form), MRN last-4, secret answer. Returns JSON with verified flag, needed fields, and optional secret question."""
49
+ res = authenticate_patient(session_id, patient_id, full_name, dob_yyyy_mm_dd, mrn_last4, secret_answer)
50
  return json.dumps(res)
51
 
52
 
53
  @tool
54
+ def get_preferred_pharmacy_tool(patient_id: str) -> str:
55
+ """Get the patient's preferred pharmacy details. Returns JSON."""
56
+ return json.dumps(get_preferred_pharmacy(patient_id))
 
 
 
 
 
 
57
 
58
 
59
  @tool
60
+ def list_providers_tool(specialty: str | None = None) -> str:
61
+ """List available providers. Optional filter by specialty. Returns JSON array."""
62
+ return json.dumps(list_providers(specialty))
63
 
64
 
65
  @tool
66
+ def get_provider_slots_tool(provider_id: str, count: int = 3) -> str:
67
+ """Get upcoming appointment slots for a provider. Returns JSON array of ISO datetimes."""
68
+ return json.dumps(get_provider_slots(provider_id, count))
69
 
70
 
71
  @tool
72
+ def schedule_appointment_tool(provider_id: str, slot_iso: str, patient_id: str | None = None) -> str:
73
+ """Schedule an appointment slot with a provider for a patient. Returns JSON with appointment_id."""
74
+ return json.dumps(schedule_appointment(provider_id, slot_iso, patient_id))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
  @tool
78
+ def triage_symptoms_tool(patient_id: str | None, symptoms_text: str) -> str:
79
+ """Run symptoms through triage rules. Returns {risk, advice, red_flags, rule}."""
80
+ return json.dumps(triage_symptoms(patient_id, symptoms_text))
 
 
 
 
81
 
82
 
83
  @tool
84
+ def log_call_tool(session_id: str, patient_id: str | None = None, notes: str | None = None, triage_json: str | None = None) -> str:
85
+ """Log the call details and triage outcome. triage_json is a JSON dict string. Returns {logged, log_id}."""
86
+ triage: Dict[str, Any] | None
87
  try:
88
+ triage = json.loads(triage_json or "null") if triage_json else None
89
  except Exception:
90
+ triage = None
91
+ return json.dumps(log_call(session_id, patient_id, notes, triage))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
examples/voice_agent_webrtc_langgraph/agents/requirements.txt CHANGED
@@ -10,4 +10,5 @@ pytz
10
  docling
11
  pymongo
12
  yt_dlp
13
- requests
 
 
10
  docling
11
  pymongo
12
  yt_dlp
13
+ requests
14
+ protobuf==6.31.1