Princeaka committed on
Commit
2368b12
·
verified ·
1 Parent(s): fa9ff2a

Update app.py

Files changed (1)
  1. app.py +413 -316
app.py CHANGED
@@ -1,28 +1,16 @@
  #!/usr/bin/env python3
  """
- JusticeAI Backend — app.py
-
- Features & behavior:
- - Two databases:
-   - DATABASE_URL (engine_user) stores personal user_memory ONLY and is never used to update global knowledge.
-   - KNOWLEDGEDATABASE_URL (engine_knowledge) stores global knowledge rows used for replies.
- - /chat accepts {"message": "..."} (or {"text": "..."}) and:
-   - infers topic (Ollama first if available, then embeddings, then keyword matching)
-   - retrieves ONLY from knowledge rows in that topic (strict topic isolation)
-   - composes a reply from topic-scoped knowledge (no automatic injection of user chats into knowledge)
-   - returns the reply in the user's detected language (translation via language.py if present or Helsinki fallback if transformers available)
-   - persists the user message and the reply into engine_user.user_memory and prunes to the last 10 messages per user
-   - blocks storing toxic messages using the moderator pipeline (if available)
- - All endpoints included: /chat, /response, /add, /add-bulk, /leaderboard, /reembed, /model-status,
-   /health, /metrics_stream, /metrics_recent, /verify-admin, /cleardatabase, / (frontend).
- - Ollama integration: uses HTTP (if ollama serve) or CLI (ollama run) to infer topic semantically if possible.
- - Optional models: SentenceTransformer for embeddings and transformers (Helsinki) for translation; code runs without them using fallbacks.
-
- Deployment notes:
- - Set DATABASE_URL and KNOWLEDGEDATABASE_URL environment variables.
- - Optionally install dependencies for better features:
-     pip install sentence-transformers transformers torch langdetect emoji hf-cli
- - To enable Ollama model auto-pull at startup set OLLAMA_AUTO_PULL=1 and ensure ollama CLI exists.
  """
 
  from sqlalchemy.pool import NullPool
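A minimal client sketch for the /chat contract the removed docstring describes (host and port are hypothetical; the server defaults to PORT=7860 at the bottom of this file):

    # sketch: post a message to /chat and read the topic-scoped reply
    import requests

    resp = requests.post(
        "http://localhost:7860/chat",
        json={"message": "What does the knowledge base say about contracts?"},
        timeout=30,
    )
    data = resp.json()
    print(data["reply"], data["topic"], data["confidence"])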
@@ -36,31 +24,45 @@ import subprocess
  import shutil
  import logging
  import random
  from datetime import datetime, timezone
  from collections import deque
  from typing import Optional, Dict, Any, List
 
- from fastapi import FastAPI, Request, Body, Query, Header
- from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
  from sqlalchemy import create_engine, text as sql_text
 
  # external helpers
  import requests
 
- # Optional ML libs
  try:
      from sentence_transformers import SentenceTransformer
  except Exception:
      SentenceTransformer = None
 
  try:
-     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline as hf_pipeline
  except Exception:
      AutoTokenizer = None
      AutoModelForSeq2SeqLM = None
-     AutoModelForCausalLM = None
      hf_pipeline = None
 
  # Optional local modules
  try:
      import language as language_module  # type: ignore
@@ -83,7 +85,7 @@ try:
  except Exception:
      detect_lang = None
 
- # Moderator pipeline (text-classification) - optional
  moderator = None
  try:
      if hf_pipeline is not None:
@@ -91,7 +93,7 @@ try:
  except Exception:
      moderator = None
 
- # Config
  ADMIN_KEY = os.environ.get("ADMIN_KEY")
  DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///justice_user.db")
  KNOWLEDGEDATABASE_URL = os.environ.get("KNOWLEDGEDATABASE_URL", DATABASE_URL)
@@ -104,15 +106,23 @@ OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "llama3")
  OLLAMA_HTTP_URL = os.environ.get("OLLAMA_HTTP_URL", "http://localhost:11434")
  OLLAMA_AUTO_PULL = os.environ.get("OLLAMA_AUTO_PULL", "0") in ("1", "true", "yes")
 
  # Logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger("justicebrain")
 
- # Early heartbeat & start time
  last_heartbeat = {"time": datetime.utcnow().replace(tzinfo=timezone.utc).isoformat(), "ok": True}
  app_start_time = time.time()
 
- # Engines (user memory and knowledge separate)
  engine_user = create_engine(
      DATABASE_URL,
      poolclass=NullPool,
@@ -128,7 +138,7 @@ app = FastAPI(title="Justice Brain — Backend")
 
  # --- Database schema setup ---
  def ensure_tables():
-     # knowledge table in knowledge DB
      dialect_k = engine_knowledge.dialect.name
      with engine_knowledge.begin() as conn:
          if dialect_k == "sqlite":
@@ -165,7 +175,7 @@ def ensure_tables():
                  updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
              );
          """))
-     # user memory table in user DB
      dialect_u = engine_user.dialect.name
      with engine_user.begin() as conn:
          if dialect_u == "sqlite":
@@ -209,7 +219,6 @@ def ensure_tables():
 
  ensure_tables()
 
- # add columns if missing (best-effort; uses engine_user but applied generically)
  def ensure_column_exists(table: str, column: str, col_def_sql: str):
      dialect = engine_user.dialect.name
      try:
@@ -292,7 +301,7 @@ def emoji_sentiment_score(emojis: List[str]) -> float:
              score += 0.1
      return max(-1.0, min(1.0, score / max(1, len(emojis))))
 
- # Language detection & translation (language.py preferred)
  _translation_model_cache: Dict[str, Any] = {}
 
  def detect_language_safe(text: str) -> str:
@@ -352,7 +361,6 @@ def translate_text(text: str, src: str, tgt: str) -> str:
              return out
      except Exception:
          pass
-     # Helsinki fallback
      src_code = (src or "und").split("-")[0].lower()
      tgt_code = (tgt or "und").split("-")[0].lower()
      if not re.fullmatch(r"[a-z]{2,3}", src_code) or not re.fullmatch(r"[a-z]{2,3}", tgt_code):
@@ -391,7 +399,7 @@ def translate_from_english(text: str, tgt_lang: str) -> str:
          return text
      return translate_text(text, "en", tgt)
 
- # Embedding utilities (optional)
  embed_model = None
  def try_load_embed():
      global embed_model
@@ -414,71 +422,13 @@ def embed_to_bytes(text: str) -> Optional[bytes]:
      except Exception:
          return None
 
- # Boilerplate detection + reply synthesis
- def is_boilerplate_candidate(s: str) -> bool:
-     s_low = (s or "").strip().lower()
-     generic = ["i don't know", "not sure", "maybe", "perhaps", "justiceai is a unified intelligence dashboard"]
-     if len(s_low) < 8:
-         return True
-     return any(g in s_low for g in generic)
-
- def generate_creative_reply(candidates: List[str]) -> str:
-     all_sent = []
-     seen = set()
-     for c in candidates:
-         for s in re.split(r'(?<=[.?!])\s+', c):
-             st = s.strip()
-             if not st or st in seen or is_boilerplate_candidate(st):
-                 continue
-             seen.add(st)
-             all_sent.append(st)
-     if not all_sent:
-         return "I don't have enough context yet — can you give more details?"
-     return "\n".join(all_sent[:5])
-
- # Duplicate detection within topic
- def knowledge_text_exists_in_topic(text: str, topic: str, threshold: float = 0.92) -> bool:
-     t = (text or "").strip()
-     if not t:
-         return False
-     try:
-         with engine_knowledge.begin() as conn:
-             rows = conn.execute(sql_text("SELECT id, text FROM knowledge WHERE topic = :topic LIMIT 200"), {"topic": topic}).fetchall()
-         for r in rows:
-             existing = (r[1] or "").strip()
-             if existing.lower() == t.lower():
-                 return True
-         if embed_model is not None and rows:
-             texts = [r[1] or "" for r in rows]
-             embs = embed_model.encode(texts, convert_to_tensor=True)
-             q_emb = embed_model.encode([t], convert_to_tensor=True)[0]
-             import torch
-             sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), embs)
-             if float(torch.max(sims).item()) >= threshold:
-                 return True
-     except Exception:
-         pass
-     return False
-
- # Topic inference fallback (embeddings/keywords)
- def infer_topic_from_message(msg: str, known_topics: List[str]) -> str:
-     msg_low = (msg or "").lower()
-     for topic in known_topics or []:
-         if topic and topic.lower() in msg_low:
-             return topic
-     if embed_model is not None and known_topics:
-         try:
-             import torch
-             topic_embs = embed_model.encode(known_topics, convert_to_tensor=True)
-             q_emb = embed_model.encode([msg], convert_to_tensor=True)[0]
-             sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), topic_embs)
-             best_idx = int(torch.argmax(sims).item())
-             return known_topics[best_idx]
-         except Exception:
-             pass
-     return "general"
 
- # Ollama helpers
  def ollama_cli_available() -> bool:
      return shutil.which("ollama") is not None
 
@@ -489,19 +439,18 @@ def ollama_http_available() -> bool:
      except Exception:
          return False
 
- def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 10) -> Optional[str]:
      try:
          url = f"{OLLAMA_HTTP_URL}/api/generate"
          payload = {"model": model, "prompt": prompt, "max_tokens": 256}
          headers = {"Content-Type": "application/json"}
-         r = requests.post(url, json=payload, headers=headers, timeout=timeout_s)
          if r.status_code == 200:
              try:
                  obj = r.json()
-                 if isinstance(obj, dict):
-                     for key in ("output", "text", "result", "generations"):
-                         if key in obj:
-                             return obj[key] if isinstance(obj[key], str) else json.dumps(obj[key])
                  return r.text
              except Exception:
                  return r.text
@@ -512,11 +461,11 @@ def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 10
          logger.debug(f"ollama HTTP call failed: {e}")
      return None
 
- def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 15) -> Optional[str]:
      if not ollama_cli_available():
          return None
      try:
-         proc = subprocess.run(["ollama", "run", model, "--prompt", prompt], capture_output=True, text=True, timeout=timeout_s)
          if proc.returncode == 0:
              return proc.stdout.strip() or proc.stderr.strip()
          else:
@@ -526,7 +475,7 @@ def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 15)
          logger.debug(f"ollama CLI call exception: {e}")
      return None
 
- def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MODEL, timeout_s: int = 8) -> Optional[str]:
      if not msg or not topics:
          return None
      topics_escaped = [t.replace('"','\\"') for t in topics]
@@ -580,12 +529,188 @@ def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MOD
          pass
      return None
 
- # Metrics & cache state
  recent_request_times = deque()
  recent_learning_timestamps = deque()
  response_time_ema: Optional[float] = None
  EMA_ALPHA = 0.2
- knowledge_version = 0
 
  def record_request(duration_s: float):
      global response_time_ema
@@ -604,14 +729,20 @@ def record_learn_event():
      while recent_learning_timestamps and recent_learning_timestamps[0] < ts - 3600:
          recent_learning_timestamps.popleft()
 
- # Startup tasks
  @app.on_event("startup")
  async def startup_event():
-     logger.info("[JusticeAI] startup: attempting to load optional components")
-     try:
-         try_load_embed()
-     except Exception as e:
-         logger.warning(f"[startup] embed load issue: {e}")
      if OLLAMA_AUTO_PULL and ollama_cli_available():
          try:
              subprocess.run(["ollama", "pull", OLLAMA_MODEL], timeout=300)
@@ -620,8 +751,7 @@ async def startup_event():
              logger.debug(f"[startup] ollama pull failed: {e}")
      logger.info("[JusticeAI] startup complete")
 
- # --- Endpoints ---
-
  @app.post("/add")
  async def add_knowledge(data: dict = Body(...)):
      if not isinstance(data, dict):
@@ -642,7 +772,10 @@ async def add_knowledge(data: dict = Body(...)):
          return JSONResponse(status_code=400, content={"error": "translation failed"})
      emb_bytes = None
      if embed_model is not None:
-         emb_bytes = embed_to_bytes(text_data)
      try:
          with engine_knowledge.begin() as conn:
              if emb_bytes:
@@ -655,13 +788,8 @@ async def add_knowledge(data: dict = Body(...)):
                      "INSERT INTO knowledge (text, reply, language, category, topic, confidence, meta) "
                      "VALUES (:t, :r, :lang, 'manual', :topic, :conf, :meta)"
                  ), {"t": text_data, "r": reply, "lang": detected, "topic": topic, "conf": 0.9, "meta": json.dumps({"manual": True})})
-         global knowledge_version
-         knowledge_version += 1
          record_learn_event()
-         res = {"status": "✅ Knowledge added", "text": text_data, "topic": topic, "language": detected}
-         if not emb_bytes:
-             res["note"] = "stored without embedding"
-         return res
      except Exception as e:
          logger.exception("add failed")
          return JSONResponse(status_code=500, content={"error": "failed to store knowledge", "details": str(e)})
@@ -684,32 +812,132 @@ async def add_bulk(data: List[dict] = Body(...)):
          detected = detect_language_safe(text_data) or "und"
          if detected not in ("en", "eng", "und"):
              errors.append({"index": i, "error": "non-english; skip"}); continue
-         emb_bytes = embed_to_bytes(text_data) if embed_model is not None else None
          with engine_knowledge.begin() as conn:
              if emb_bytes:
                  conn.execute(sql_text(
-                     "INSERT INTO knowledge (text, reply, language, embedding, category, topic) "
-                     "VALUES (:t, :r, :lang, :e, 'manual', :topic)"
                  ), {"t": text_data, "r": reply, "lang": "en", "e": emb_bytes, "topic": topic})
              else:
                  conn.execute(sql_text(
-                     "INSERT INTO knowledge (text, reply, language, category, topic) "
-                     "VALUES (:t, :r, :lang, 'manual', :topic)"
                  ), {"t": text_data, "r": reply, "lang": "en", "topic": topic})
          added += 1
      except Exception as e:
          logger.exception("add-bulk item error")
          errors.append({"index": i, "error": str(e)})
  if added:
-     global knowledge_version
-     knowledge_version += 1
      record_learn_event()
  return {"added": added, "errors": errors}
 
  @app.post("/chat")
710
  async def chat(request: Request, data: dict = Body(...)):
711
  t0 = time.time()
712
- # Accept both "message" and "text"
713
  if isinstance(data, dict):
714
  raw_msg = str(data.get("message", "") or data.get("text", "") or "").strip()
715
  else:
@@ -727,7 +955,7 @@ async def chat(request: Request, data: dict = Body(...)):
727
  detected_lang = detect_language_safe(raw_msg)
728
  reply_lang = detected_lang if detected_lang and detected_lang != "und" else "en"
729
 
730
- # Translate incoming to English for retrieval/synthesis if needed
731
  en_msg = raw_msg
732
  if detected_lang not in ("en", "eng", "", "und"):
733
  try:
@@ -735,7 +963,7 @@ async def chat(request: Request, data: dict = Body(...)):
735
  except Exception:
736
  en_msg = raw_msg
737
 
738
- # Infer topic (Ollama -> embed -> keyword)
739
  topic = "general"
740
  try:
741
  if not topic_hint:
@@ -758,7 +986,7 @@ async def chat(request: Request, data: dict = Body(...)):
758
  except Exception:
759
  topic = topic_hint or "general"
760
 
761
- # Moderation on incoming message
762
  flags = {}
763
  try:
764
  if moderator is not None:
@@ -771,43 +999,45 @@ async def chat(request: Request, data: dict = Body(...)):
771
  except Exception:
772
  pass
773
 
774
- # IMPORTANT: Do NOT auto-add incoming message to global knowledge.
775
- # We'll store into engine_user.user_memory only (personal).
776
-
777
- # Load knowledge entries for this topic only
778
  try:
779
  with engine_knowledge.begin() as conn:
780
- rows = conn.execute(sql_text(
781
- "SELECT id, text, reply, language, embedding FROM knowledge WHERE topic = :topic ORDER BY created_at DESC"
782
- ), {"topic": topic}).fetchall()
783
  except Exception as e:
784
  record_request(time.time() - t0)
785
  return JSONResponse(status_code=500, content={"error": "failed to read knowledge", "details": str(e)})
786
 
787
  knowledge_rows = [{"id": r[0], "text": r[1] or "", "reply": r[2] or "", "lang": r[3] or "und", "embedding": r[4]} for r in rows]
788
 
789
- # Retrieval (embedding-first then substring) restricted to topic
790
  matches: List[str] = []
791
  confidence = 0.0
792
  try:
793
  if embed_model is not None and knowledge_rows:
794
  texts = [kr["text"] for kr in knowledge_rows]
795
- embs = embed_model.encode(texts, convert_to_tensor=True)
796
- q_emb = embed_model.encode([en_msg], convert_to_tensor=True)[0]
797
- import torch
798
- scores = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), embs)
799
- cand = []
800
- for i in range(scores.shape[0]):
801
- s = float(scores[i])
802
- kr = knowledge_rows[i]
803
- candidate_text = (kr["reply"] or kr["text"]).strip()
804
- if is_boilerplate_candidate(candidate_text):
805
- continue
806
- if s >= 0.30:
807
- cand.append({"text": candidate_text, "lang": kr["lang"], "score": s})
808
- cand = sorted(cand, key=lambda x: -x["score"])
809
- matches = [c["text"] for c in cand]
810
- confidence = cand[0]["score"] if cand else 0.0
 
 
 
 
 
 
 
811
  else:
812
  cand = []
813
  for kr in knowledge_rows:
@@ -822,7 +1052,7 @@ async def chat(request: Request, data: dict = Body(...)):
822
  logger.warning(f"[retrieval] error: {e}")
823
  matches = []
824
 
825
- # Compose reply from topic-only knowledge
826
  if matches and confidence >= 0.6:
827
  reply_en = matches[0]
828
  elif matches:
@@ -835,7 +1065,6 @@ async def chat(request: Request, data: dict = Body(...)):
835
  except Exception:
836
  pass
837
  reply_final = base
838
- # Persist user memory (even when no confident match), skipping toxic
839
  try:
840
  if not flags.get('toxic', False):
841
  with engine_user.begin() as conn:
@@ -844,20 +1073,18 @@ async def chat(request: Request, data: dict = Body(...)):
844
  "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
845
  ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
846
  "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
847
- # prune to last 10 per user
848
  conn.execute(sql_text(
849
- "DELETE FROM user_memory WHERE id NOT IN ("
850
- "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
851
  ), {"uid": user_id})
852
  except Exception as e:
853
  logger.debug(f"user_memory store error: {e}")
854
  record_request(time.time() - t0)
855
- return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": "", "confidence": round(confidence, 2), "flags": flags}
856
 
857
- # Postprocess reply (dedupe)
858
  reply_en = dedupe_sentences(reply_en)
859
 
860
- # Ensure translation into user's language (robust)
861
  reply_final = reply_en
862
  lang_code = (reply_lang or "und").split("-")[0].lower()
863
  if lang_code not in ("en", "eng", "und", ""):
@@ -868,7 +1095,7 @@ async def chat(request: Request, data: dict = Body(...)):
868
  logger.warning(f"[translation] failed to translate reply_en -> {lang_code}: {exc}")
869
  reply_final = reply_en
870
 
871
- # Mood & emoji (non-intrusive)
872
  emoji = ""
873
  try:
874
  mood = detect_mood(raw_msg + " " + reply_final)
@@ -883,7 +1110,7 @@ async def chat(request: Request, data: dict = Body(...)):
883
  except Exception:
884
  emoji = ""
885
 
886
- # Persist user memory into DATABASE_URL only (engine_user) and prune to last 10
887
  try:
888
  if not flags.get('toxic', False):
889
  with engine_user.begin() as conn:
@@ -892,10 +1119,8 @@ async def chat(request: Request, data: dict = Body(...)):
892
  "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
893
  ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
894
  "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
895
- # prune to last 10 per user
896
  conn.execute(sql_text(
897
- "DELETE FROM user_memory WHERE id NOT IN ("
898
- "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
899
  ), {"uid": user_id})
900
  except Exception as e:
901
  logger.debug(f"user_memory persist error: {e}")
@@ -906,126 +1131,12 @@ async def chat(request: Request, data: dict = Body(...)):
906
  if include_steps:
907
  reply_final = f"{reply_final}\n\n[Debug: topic={topic} confidence={round(confidence,2)}]"
908
 
909
- return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": emoji, "confidence": round(confidence, 2), "flags": flags}
910
 
911
  @app.post("/response")
912
  async def response_wrapper(request: Request, data: dict = Body(...)):
913
  return await chat(request, data)
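The prune statement in /chat keeps only each user's 10 newest user_memory rows. A standalone sqlite3 sketch of the same DELETE ... NOT IN (SELECT ... LIMIT 10) pattern (toy schema, not the app's full table):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_memory (id INTEGER PRIMARY KEY, user_id TEXT, created_at INTEGER)")
    conn.executemany(
        "INSERT INTO user_memory (user_id, created_at) VALUES (?, ?)",
        [("u1", i) for i in range(15)],
    )
    # keep only the 10 newest rows for user u1
    conn.execute(
        "DELETE FROM user_memory WHERE id NOT IN ("
        "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid",
        {"uid": "u1"},
    )
    print(conn.execute("SELECT COUNT(*) FROM user_memory WHERE user_id = 'u1'").fetchone()[0])  # 10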
 
- @app.get("/leaderboard")
- async def leaderboard(topic: str = Query("general")):
-     t = str(topic or "general").strip() or "general"
-     try:
-         with engine_knowledge.begin() as conn:
-             rows = conn.execute(sql_text("""
-                 SELECT id, text, reply, language, category, confidence, created_at
-                 FROM knowledge
-                 WHERE topic = :topic
-                 ORDER BY confidence DESC, created_at DESC
-                 LIMIT 20
-             """), {"topic": t}).fetchall()
-         out = []
-         for r in rows:
-             text_en = r[1] or ""
-             lang = r[3] or "und"
-             display_text = text_en
-             if lang and lang not in ("en", "eng", "", "und"):
-                 try:
-                     display_text = translate_to_english(text_en, lang)
-                 except Exception:
-                     display_text = text_en
-             created_at = r[6]
-             out.append({
-                 "id": r[0],
-                 "text": display_text,
-                 "reply": r[2],
-                 "language": lang,
-                 "category": r[4],
-                 "confidence": round(r[5] or 0.0, 2),
-                 "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else str(created_at)
-             })
-         return {"topic": t, "top_20": out}
-     except Exception as e:
-         logger.exception("leaderboard failed")
-         return JSONResponse(status_code=500, content={"error": "failed to fetch leaderboard", "details": str(e)})
-
- @app.post("/reembed")
- async def reembed_all(data: dict = Body(...), x_admin_key: str = Header(None, alias="X-Admin-Key")):
-     if ADMIN_KEY is None:
-         return JSONResponse(status_code=403, content={"error": "Server not configured for admin operations."})
-     if x_admin_key != ADMIN_KEY:
-         return JSONResponse(status_code=403, content={"error": "Invalid admin key."})
-     if embed_model is None:
-         return JSONResponse(status_code=503, content={"error": "Embedding model not ready."})
-     confirm = str(data.get("confirm", "") or "").strip()
-     if confirm != "REEMBED":
-         return JSONResponse(status_code=400, content={"error": "confirm token required."})
-     batch_size = int(data.get("batch_size", 100))
-     try:
-         with engine_knowledge.begin() as conn:
-             rows = conn.execute(sql_text("SELECT id, text FROM knowledge ORDER BY id")).fetchall()
-         ids_texts = [(r[0], r[1]) for r in rows]
-         total = len(ids_texts)
-         updated = 0
-         for i in range(0, total, batch_size):
-             batch = ids_texts[i:i+batch_size]
-             texts = [t for _, t in batch]
-             embs = embed_model.encode(texts, convert_to_tensor=True)
-             for j, (kid, _) in enumerate(batch):
-                 emb_bytes = embs[j].cpu().numpy().tobytes()
-                 with engine_knowledge.begin() as conn:
-                     conn.execute(sql_text("UPDATE knowledge SET embedding = :e, updated_at = CURRENT_TIMESTAMP WHERE id = :id"), {"e": emb_bytes, "id": kid})
-                 updated += 1
-         return {"status": "✅ Re-embed complete", "total_rows": total, "updated": updated}
-     except Exception as e:
-         logger.exception("reembed failed")
-         return JSONResponse(status_code=500, content={"error": "reembed failed", "details": str(e)})
-
- @app.get("/model-status")
- async def model_status():
-     return {
-         "embed_loaded": embed_model is not None,
-         "ollama_cli": ollama_cli_available(),
-         "ollama_http": ollama_http_available(),
-         "moderator": moderator is not None,
-         "language_module": LANGUAGE_MODULE_AVAILABLE
-     }
-
- @app.get("/health")
- async def health():
-     try:
-         with engine_knowledge.connect() as c:
-             k = c.execute(sql_text("SELECT COUNT(*) FROM knowledge")).scalar() or 0
-     except Exception:
-         k = -1
-     try:
-         with engine_user.connect() as c:
-             u = c.execute(sql_text("SELECT COUNT(*) FROM user_memory")).scalar() or 0
-     except Exception:
-         u = -1
-     return {"ok": True, "knowledge_count": int(k), "user_memory_count": int(u), "uptime_s": round(time.time() - app_start_time, 2), "heartbeat": last_heartbeat}
-
- async def metrics_producer():
-     while True:
-         try:
-             import psutil
-             cpu = psutil.cpu_percent(interval=None)
-             mem = psutil.virtual_memory()
-             mem_percent = mem.percent
-         except Exception:
-             cpu = 0.0; mem_percent = 0.0
-         payload = {"time": datetime.utcnow().isoformat(), "cpu_percent": cpu, "memory_percent": mem_percent}
-         yield f"data: {json.dumps(payload)}\n\n"
-         await asyncio.sleep(1.0)
-
- @app.get("/metrics_stream")
- async def metrics_stream():
-     return StreamingResponse(metrics_producer(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
-
- @app.get("/metrics_recent")
- async def metrics_recent(limit: int = Query(100, ge=1, le=600)):
-     return {"count": 0, "metrics": []}
-
  @app.post("/verify-admin")
1030
  async def verify_admin(x_admin_key: str = Header(None, alias="X-Admin-Key")):
1031
  if ADMIN_KEY is None:
@@ -1085,33 +1196,19 @@ async def frontend_dashboard():
1085
  html = html.replace("%%STARTUP_TIME%%", str(startup_time_local))
1086
  return HTMLResponse(html)
1087
 
1088
- # small helpers referenced above
1089
- def detect_mood(text: str) -> str:
1090
- lower = (text or "").lower()
1091
- positive = ["great", "thanks", "awesome", "happy", "love", "excellent", "cool", "yes", "good"]
1092
- negative = ["sad", "bad", "problem", "angry", "hate", "fail", "no", "error", "issue"]
1093
- if any(w in lower for w in positive):
1094
- return "positive"
1095
- if any(w in lower for w in negative):
1096
- return "negative"
1097
- return "neutral"
1098
-
1099
- def should_append_emoji(user_text: str, reply_text: str, mood: str, flags: Dict) -> str:
1100
- if flags.get("toxic"):
1101
- return ""
1102
- if EMOJIS_AVAILABLE:
1103
  try:
1104
- cat = get_category_for_mood(mood)
1105
- return get_emoji(cat, 0.6)
1106
  except Exception:
1107
- return ""
1108
- return ""
1109
-
1110
- if __name__ == "__main__":
1111
- try:
1112
- try_load_embed()
1113
- except Exception:
1114
- pass
1115
  app_start_time = time.time()
1116
  import uvicorn
1117
  port = int(os.environ.get("PORT", 7860))
 
  #!/usr/bin/env python3
  """
+ JusticeAI Backend — merged app.py
+
+ This file:
+ - Consolidates the JusticeAI backend (knowledge DB, user DB, /chat and other endpoints)
+ - Integrates Ollama topic inference (HTTP/CLI optional)
+ - Integrates optional embeddings (SentenceTransformer) and optional Helsinki translation models
+ - Adds a TTS /speak endpoint (voice cloning) using TTS.api (Coqui TTS) with optimizations for speed
+ - Keeps strict separation: user chat stored only in DATABASE_URL.user_memory and never used to mutate the global knowledge DB
+ - Prunes user_memory to the last 10 messages per user
+ - Attempts to minimize TTS latency by preloading, using GPU if available, using inference_mode / autocast,
+   and caching identical speaker samples by file hash
  """
 
  from sqlalchemy.pool import NullPool
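A client sketch for the /speak endpoint this merged file introduces (the multipart form fields text, voice_wav, and language come from the endpoint signature later in this diff; host/port and filenames are hypothetical):

    import requests

    with open("speaker_sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/speak",
            data={"text": "Hello from JusticeAI.", "language": "en"},
            files={"voice_wav": ("speaker_sample.wav", f, "audio/wav")},
            timeout=120,
        )
    with open("reply.wav", "wb") as out:
        out.write(resp.content)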
 
  import shutil
  import logging
  import random
+ import tempfile
+ import uuid
+ import asyncio
  from datetime import datetime, timezone
  from collections import deque
  from typing import Optional, Dict, Any, List
 
+ from fastapi import FastAPI, Request, Body, Query, Header, BackgroundTasks, File, UploadFile, Form, HTTPException, status
+ from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
  from sqlalchemy import create_engine, text as sql_text
 
  # external helpers
  import requests
 
+ # ML libs (optional)
+ try:
+     import torch
+ except Exception:
+     torch = None
+
  try:
      from sentence_transformers import SentenceTransformer
  except Exception:
      SentenceTransformer = None
 
  try:
+     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline as hf_pipeline
  except Exception:
      AutoTokenizer = None
      AutoModelForSeq2SeqLM = None
      hf_pipeline = None
 
+ # Optional TTS library (Coqui TTS)
+ try:
+     from TTS.api import TTS
+     TTS_AVAILABLE = True
+ except Exception:
+     TTS_AVAILABLE = False
+
  # Optional local modules
  try:
      import language as language_module  # type: ignore
 
  except Exception:
      detect_lang = None
 
+ # Moderator pipeline (optional)
  moderator = None
  try:
      if hf_pipeline is not None:
 
  except Exception:
      moderator = None
 
+ # Config (env)
  ADMIN_KEY = os.environ.get("ADMIN_KEY")
  DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///justice_user.db")
  KNOWLEDGEDATABASE_URL = os.environ.get("KNOWLEDGEDATABASE_URL", DATABASE_URL)
 
  OLLAMA_HTTP_URL = os.environ.get("OLLAMA_HTTP_URL", "http://localhost:11434")
  OLLAMA_AUTO_PULL = os.environ.get("OLLAMA_AUTO_PULL", "0") in ("1", "true", "yes")
 
+ # TTS configuration and speed options
+ TTS_MODEL_NAME = os.environ.get("TTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
+ TTS_DEVICE = os.environ.get("TTS_DEVICE", "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu")
+ TTS_USE_HALF = os.environ.get("TTS_USE_HALF", "1") in ("1", "true", "yes")
+
+ # Non-TTS operation timeout (for blocking calls we choose to limit)
+ MODEL_TIMEOUT = float(os.environ.get("MODEL_TIMEOUT", "10"))
+
  # Logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger("justicebrain")
 
+ # Heartbeat & startup
  last_heartbeat = {"time": datetime.utcnow().replace(tzinfo=timezone.utc).isoformat(), "ok": True}
  app_start_time = time.time()
 
+ # Engines (separate DBs)
  engine_user = create_engine(
      DATABASE_URL,
      poolclass=NullPool,
 
 
  # --- Database schema setup ---
  def ensure_tables():
+     # knowledge table
      dialect_k = engine_knowledge.dialect.name
      with engine_knowledge.begin() as conn:
          if dialect_k == "sqlite":
 
                  updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
              );
          """))
+     # user memory table
      dialect_u = engine_user.dialect.name
      with engine_user.begin() as conn:
          if dialect_u == "sqlite":
 
  ensure_tables()
 
  def ensure_column_exists(table: str, column: str, col_def_sql: str):
      dialect = engine_user.dialect.name
      try:
 
          score += 0.1
      return max(-1.0, min(1.0, score / max(1, len(emojis))))
 
+ # --- Language detection & translation ---
  _translation_model_cache: Dict[str, Any] = {}
 
  def detect_language_safe(text: str) -> str:
 
              return out
      except Exception:
          pass
      src_code = (src or "und").split("-")[0].lower()
      tgt_code = (tgt or "und").split("-")[0].lower()
      if not re.fullmatch(r"[a-z]{2,3}", src_code) or not re.fullmatch(r"[a-z]{2,3}", tgt_code):
 
          return text
      return translate_text(text, "en", tgt)
 
+ # --- Embeddings utilities ---
  embed_model = None
  def try_load_embed():
      global embed_model
 
      except Exception:
          return None
 
+ # --- Helpers for running blocking code with a timeout (for non-TTS operations) ---
+ async def run_blocking_with_timeout(func, *args, timeout: float = MODEL_TIMEOUT):
+     loop = asyncio.get_running_loop()
+     fut = loop.run_in_executor(None, lambda: func(*args))
+     return await asyncio.wait_for(fut, timeout=timeout)
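A usage sketch for the helper above (the slow_job function and the 5-second timeout are illustrative, not part of this diff):

    import asyncio
    import time

    def slow_job(n: int) -> int:
        # stand-in for a blocking model call
        time.sleep(n)
        return n * 2

    async def main():
        try:
            result = await run_blocking_with_timeout(slow_job, 1, timeout=5.0)
            print(result)  # 2
        except asyncio.TimeoutError:
            print("timed out")

    asyncio.run(main())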
 
+ # --- Ollama helpers (HTTP & CLI) ---
  def ollama_cli_available() -> bool:
      return shutil.which("ollama") is not None
 
      except Exception:
          return False
 
+ def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
      try:
          url = f"{OLLAMA_HTTP_URL}/api/generate"
          payload = {"model": model, "prompt": prompt, "max_tokens": 256}
          headers = {"Content-Type": "application/json"}
+         r = requests.post(url, json=payload, headers=headers, timeout=min(timeout_s, MODEL_TIMEOUT))
          if r.status_code == 200:
              try:
                  obj = r.json()
+                 for key in ("output", "text", "result", "generations"):
+                     if key in obj:
+                         return obj[key] if isinstance(obj[key], str) else json.dumps(obj[key])
                  return r.text
              except Exception:
                  return r.text
 
          logger.debug(f"ollama HTTP call failed: {e}")
      return None
 
+ def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
      if not ollama_cli_available():
          return None
      try:
+         proc = subprocess.run(["ollama", "run", model, "--prompt", prompt], capture_output=True, text=True, timeout=min(timeout_s, MODEL_TIMEOUT))
          if proc.returncode == 0:
              return proc.stdout.strip() or proc.stderr.strip()
          else:
 
          logger.debug(f"ollama CLI call exception: {e}")
      return None
 
+ def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
      if not msg or not topics:
          return None
      topics_escaped = [t.replace('"','\\"') for t in topics]
 
          pass
      return None
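The body of infer_topic_with_ollama is elided by the hunk above; a sketch of how such a classification prompt could be driven through call_ollama_http (the prompt wording and topic list are illustrative, not the file's actual prompt):

    topics = ["contracts", "criminal-law", "general"]
    prompt = (
        "Pick the single best topic for the message below from this list: "
        + ", ".join(topics)
        + "\nMessage: Is a verbal agreement binding?\nAnswer with the topic name only."
    )
    raw = call_ollama_http(prompt, model=OLLAMA_MODEL, timeout_s=8)
    topic = raw.strip().lower() if raw else None
    print(topic if topic in topics else "general")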
 
+ # --- Boilerplate detection & reply synthesis helpers ---
+ def is_boilerplate_candidate(s: str) -> bool:
+     s_low = (s or "").strip().lower()
+     generic = ["i don't know", "not sure", "maybe", "perhaps", "justiceai is a unified intelligence dashboard"]
+     if len(s_low) < 8:
+         return True
+     return any(g in s_low for g in generic)
+
+ def generate_creative_reply(candidates: List[str]) -> str:
+     all_sent = []
+     seen = set()
+     for c in candidates:
+         for s in re.split(r'(?<=[.?!])\s+', c):
+             st = s.strip()
+             if not st or st in seen or is_boilerplate_candidate(st):
+                 continue
+             seen.add(st)
+             all_sent.append(st)
+     if not all_sent:
+         return "I don't have enough context yet — can you give more details?"
+     return "\n".join(all_sent[:5])
+
+ # --- TTS: optimized loader, caching speaker files ---
+ _tts_model = None
+ _tts_lock = threading.Lock()
+ _speaker_hash_cache: Dict[str, str] = {}
+ _tts_loaded_event = threading.Event()
+
+ def compute_file_sha256(path: str) -> str:
+     h = hashlib.sha256()
+     with open(path, "rb") as f:
+         while True:
+             b = f.read(8192)
+             if not b:
+                 break
+             h.update(b)
+     return h.hexdigest()
+
+ def get_tts_model_blocking():
+     global _tts_model
+     if not TTS_AVAILABLE:
+         raise RuntimeError("TTS.api not available on server")
+     with _tts_lock:
+         if _tts_model is None:
+             model_name = os.environ.get("TTS_MODEL_NAME", TTS_MODEL_NAME)
+             device = os.environ.get("TTS_DEVICE", TTS_DEVICE)
+             logger.info(f"[TTS] Loading model {model_name} on device {device}")
+             _tts_model = TTS(model_name)
+             try:
+                 if device and torch is not None:
+                     if device.startswith("cuda") and torch.cuda.is_available():
+                         try:
+                             _tts_model.to(device)
+                         except Exception:
+                             pass
+                         try:
+                             torch.backends.cudnn.benchmark = True
+                         except Exception:
+                             pass
+                         if TTS_USE_HALF:
+                             try:
+                                 if hasattr(_tts_model, "model") and hasattr(_tts_model.model, "half"):
+                                     _tts_model.model.half()
+                             except Exception:
+                                 pass
+                         try:
+                             torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
+                         except Exception:
+                             pass
+                     else:
+                         try:
+                             torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
+                         except Exception:
+                             pass
+             except Exception as e:
+                 logger.debug(f"[TTS] model device tuning warning: {e}")
+             logger.info("[TTS] model loaded")
+             _tts_loaded_event.set()
+         return _tts_model
+
+ def _save_upload_file_tmp(upload_file: UploadFile) -> str:
+     suffix = os.path.splitext(upload_file.filename)[1] or ".wav"
+     fd, tmp_path = tempfile.mkstemp(suffix=suffix, prefix="tts_speaker_")
+     os.close(fd)
+     with open(tmp_path, "wb") as f:
+         content = upload_file.file.read()
+         f.write(content)
+     return tmp_path
+
+ # Preload TTS in background at process start
+ if TTS_AVAILABLE:
+     threading.Thread(target=lambda: (get_tts_model_blocking()), daemon=True).start()
+
+ @app.post("/speak")
+ async def speak(
+     background_tasks: BackgroundTasks,
+     text: str = Form(...),
+     voice_wav: Optional[UploadFile] = File(None),
+     language: Optional[str] = Form(None),
+ ):
+     """
+     Generate speech for `text`. Optionally use an uploaded `voice_wav` (WAV) file as speaker sample.
+     This endpoint aims for speed by using a preloaded model and GPU/half precision if configured.
+     """
+     if not text or not text.strip():
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Field 'text' is required")
+     if not TTS_AVAILABLE:
+         raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="TTS engine not available")
+
+     speaker_path = None
+     speaker_hash = None
+     if voice_wav is not None:
+         try:
+             speaker_path = _save_upload_file_tmp(voice_wav)
+             speaker_hash = compute_file_sha256(speaker_path)
+             cached = _speaker_hash_cache.get(speaker_hash)
+             if cached and os.path.exists(cached):
+                 try:
+                     os.remove(speaker_path)
+                 except Exception:
+                     pass
+                 speaker_path = cached
+             else:
+                 _speaker_hash_cache[speaker_hash] = speaker_path
+         except Exception as e:
+             logger.exception("Failed to save uploaded voice sample")
+             raise HTTPException(status_code=500, detail="Failed to process uploaded voice sample")
+
+     out_fd, out_path = tempfile.mkstemp(suffix=".wav", prefix="tts_out_")
+     os.close(out_fd)
+     background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), out_path)
+
+     try:
+         tts = get_tts_model_blocking()
+     except Exception as e:
+         logger.exception("[TTS] model load failed")
+         try:
+             if os.path.exists(out_path):
+                 os.remove(out_path)
+         except Exception:
+             pass
+         raise HTTPException(status_code=500, detail="Failed to load TTS model")
+
+     kwargs = {}
+     if speaker_path:
+         kwargs["speaker_wav"] = speaker_path
+     if language:
+         kwargs["language"] = language
+
+     try:
+         if torch is not None and torch.cuda.is_available() and TTS_USE_HALF:
+             try:
+                 with torch.inference_mode():
+                     with torch.cuda.amp.autocast():
+                         tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+             except Exception as e:
+                 logger.debug(f"[TTS] autocast path failed: {e}, falling back")
+                 with torch.inference_mode():
+                     tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+         else:
+             if torch is not None:
+                 with torch.inference_mode():
+                     tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+             else:
+                 tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+     except Exception as e:
+         logger.exception("[TTS] synthesis failed")
+         try:
+             if os.path.exists(out_path):
+                 os.remove(out_path)
+         except Exception:
+             pass
+         raise HTTPException(status_code=500, detail="TTS synthesis failed")
+
+     filename = f"speech-{uuid.uuid4().hex}.wav"
+     return FileResponse(path=out_path, filename=filename, media_type="audio/wav", background=background_tasks)
+
+ # --- Metrics & caches ---
  recent_request_times = deque()
  recent_learning_timestamps = deque()
  response_time_ema: Optional[float] = None
  EMA_ALPHA = 0.2
 
  def record_request(duration_s: float):
      global response_time_ema
 
      while recent_learning_timestamps and recent_learning_timestamps[0] < ts - 3600:
          recent_learning_timestamps.popleft()
 
+ # --- Startup event: warm up optional components ---
  @app.on_event("startup")
  async def startup_event():
+     logger.info("[JusticeAI] startup event beginning")
+     # Try to warmup embedding model quickly in background
+     if SentenceTransformer is not None:
+         def _warm_embed():
+             try:
+                 try_load_embed()
+                 logger.info("[startup] embed model warmup complete")
+             except Exception as e:
+                 logger.debug(f"[startup] embed warmup issue: {e}")
+         threading.Thread(target=_warm_embed, daemon=True).start()
+     # Optionally attempt ollama pull (best-effort)
      if OLLAMA_AUTO_PULL and ollama_cli_available():
          try:
              subprocess.run(["ollama", "pull", OLLAMA_MODEL], timeout=300)
 
              logger.debug(f"[startup] ollama pull failed: {e}")
      logger.info("[JusticeAI] startup complete")
 
+ # --- Knowledge management endpoints ---
  @app.post("/add")
  async def add_knowledge(data: dict = Body(...)):
      if not isinstance(data, dict):
 
          return JSONResponse(status_code=400, content={"error": "translation failed"})
      emb_bytes = None
      if embed_model is not None:
+         try:
+             emb_bytes = await run_blocking_with_timeout(lambda: embed_to_bytes(text_data), timeout=MODEL_TIMEOUT)
+         except Exception:
+             emb_bytes = None
      try:
          with engine_knowledge.begin() as conn:
              if emb_bytes:
 
                  conn.execute(sql_text(
                      "INSERT INTO knowledge (text, reply, language, category, topic, confidence, meta) "
                      "VALUES (:t, :r, :lang, 'manual', :topic, :conf, :meta)"
                  ), {"t": text_data, "r": reply, "lang": detected, "topic": topic, "conf": 0.9, "meta": json.dumps({"manual": True})})
          record_learn_event()
+         return {"status": "✅ Knowledge added", "text": text_data, "topic": topic, "language": detected}
      except Exception as e:
          logger.exception("add failed")
          return JSONResponse(status_code=500, content={"error": "failed to store knowledge", "details": str(e)})
 
          detected = detect_language_safe(text_data) or "und"
          if detected not in ("en", "eng", "und"):
              errors.append({"index": i, "error": "non-english; skip"}); continue
+         emb_bytes = None
+         if embed_model is not None:
+             try:
+                 emb_bytes = await run_blocking_with_timeout(lambda: embed_to_bytes(text_data), timeout=MODEL_TIMEOUT)
+             except Exception:
+                 emb_bytes = None
          with engine_knowledge.begin() as conn:
              if emb_bytes:
                  conn.execute(sql_text(
+                     "INSERT INTO knowledge (text, reply, language, embedding, category, topic) VALUES (:t, :r, :lang, :e, 'manual', :topic)"
                  ), {"t": text_data, "r": reply, "lang": "en", "e": emb_bytes, "topic": topic})
              else:
                  conn.execute(sql_text(
+                     "INSERT INTO knowledge (text, reply, language, category, topic) VALUES (:t, :r, :lang, 'manual', :topic)"
                  ), {"t": text_data, "r": reply, "lang": "en", "topic": topic})
          added += 1
      except Exception as e:
          logger.exception("add-bulk item error")
          errors.append({"index": i, "error": str(e)})
  if added:
 
      record_learn_event()
  return {"added": added, "errors": errors}
 
+ @app.get("/leaderboard")
+ async def leaderboard(topic: str = Query("general")):
+     t = str(topic or "general").strip() or "general"
+     try:
+         with engine_knowledge.begin() as conn:
+             rows = conn.execute(sql_text("""
+                 SELECT id, text, reply, language, category, confidence, created_at
+                 FROM knowledge
+                 WHERE topic = :topic
+                 ORDER BY confidence DESC, created_at DESC
+                 LIMIT 20
+             """), {"topic": t}).fetchall()
+         out = []
+         for r in rows:
+             text_en = r[1] or ""
+             lang = r[3] or "und"
+             display_text = text_en
+             if lang and lang not in ("en", "eng", "", "und"):
+                 try:
+                     display_text = translate_to_english(text_en, lang)
+                 except Exception:
+                     display_text = text_en
+             created_at = r[6]
+             out.append({
+                 "id": r[0],
+                 "text": display_text,
+                 "reply": r[2],
+                 "language": lang,
+                 "category": r[4],
+                 "confidence": round(r[5] or 0.0, 2),
+                 "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else str(created_at)
+             })
+         return {"topic": t, "top_20": out}
+     except Exception as e:
+         logger.exception("leaderboard failed")
+         return JSONResponse(status_code=500, content={"error": "failed to fetch leaderboard", "details": str(e)})
+
+ @app.post("/reembed")
+ async def reembed_all(data: dict = Body(...), x_admin_key: str = Header(None, alias="X-Admin-Key")):
+     if ADMIN_KEY is None:
+         return JSONResponse(status_code=403, content={"error": "Server not configured for admin operations."})
+     if x_admin_key != ADMIN_KEY:
+         return JSONResponse(status_code=403, content={"error": "Invalid admin key."})
+     if embed_model is None:
+         return JSONResponse(status_code=503, content={"error": "Embedding model not ready."})
+     confirm = str(data.get("confirm", "") or "").strip()
+     if confirm != "REEMBED":
+         return JSONResponse(status_code=400, content={"error": "confirm token required."})
+     batch_size = int(data.get("batch_size", 100))
+     try:
+         with engine_knowledge.begin() as conn:
+             rows = conn.execute(sql_text("SELECT id, text FROM knowledge ORDER BY id")).fetchall()
+         ids_texts = [(r[0], r[1]) for r in rows]
+         total = len(ids_texts)
+         updated = 0
+         for i in range(0, total, batch_size):
+             batch = ids_texts[i:i+batch_size]
+             texts = [t for _, t in batch]
+             try:
+                 embs = await run_blocking_with_timeout(lambda: embed_model.encode(texts, convert_to_tensor=True), timeout=MODEL_TIMEOUT)
+             except Exception:
+                 embs = None
+             if embs is None:
+                 continue
+             for j, (kid, _) in enumerate(batch):
+                 emb_bytes = embs[j].cpu().numpy().tobytes()
+                 with engine_knowledge.begin() as conn:
+                     conn.execute(sql_text("UPDATE knowledge SET embedding = :e, updated_at = CURRENT_TIMESTAMP WHERE id = :id"), {"e": emb_bytes, "id": kid})
+                 updated += 1
+         return {"status": "✅ Re-embed complete", "total_rows": total, "updated": updated}
+     except Exception as e:
+         logger.exception("reembed failed")
+         return JSONResponse(status_code=500, content={"error": "reembed failed", "details": str(e)})
+
+ @app.get("/model-status")
+ async def model_status():
+     return {
+         "embed_loaded": embed_model is not None,
+         "ollama_cli": ollama_cli_available(),
+         "ollama_http": ollama_http_available(),
+         "moderator": moderator is not None,
+         "language_module": LANGUAGE_MODULE_AVAILABLE,
+         "tts_available": TTS_AVAILABLE
+     }
+
+ @app.get("/health")
+ async def health():
+     try:
+         with engine_knowledge.connect() as c:
+             k = c.execute(sql_text("SELECT COUNT(*) FROM knowledge")).scalar() or 0
+     except Exception:
+         k = -1
+     try:
+         with engine_user.connect() as c:
+             u = c.execute(sql_text("SELECT COUNT(*) FROM user_memory")).scalar() or 0
+     except Exception:
+         u = -1
+     return {"ok": True, "knowledge_count": int(k), "user_memory_count": int(u), "uptime_s": round(time.time() - app_start_time, 2), "heartbeat": last_heartbeat}
+
  @app.post("/chat")
  async def chat(request: Request, data: dict = Body(...)):
      t0 = time.time()
+     # Accept "message" or "text"
      if isinstance(data, dict):
          raw_msg = str(data.get("message", "") or data.get("text", "") or "").strip()
      else:
 
      detected_lang = detect_language_safe(raw_msg)
      reply_lang = detected_lang if detected_lang and detected_lang != "und" else "en"
 
+     # Translate incoming to English for retrieval if needed
      en_msg = raw_msg
      if detected_lang not in ("en", "eng", "", "und"):
          try:
 
          except Exception:
              en_msg = raw_msg
 
+     # Determine topic: Ollama first, then embedding, then keyword
      topic = "general"
      try:
          if not topic_hint:
 
      except Exception:
          topic = topic_hint or "general"
 
+     # Moderation
      flags = {}
      try:
          if moderator is not None:
 
      except Exception:
          pass
 
+     # Load topic-scoped knowledge
      try:
          with engine_knowledge.begin() as conn:
+             rows = conn.execute(sql_text("SELECT id, text, reply, language, embedding FROM knowledge WHERE topic = :topic ORDER BY created_at DESC"), {"topic": topic}).fetchall()
      except Exception as e:
          record_request(time.time() - t0)
          return JSONResponse(status_code=500, content={"error": "failed to read knowledge", "details": str(e)})
 
      knowledge_rows = [{"id": r[0], "text": r[1] or "", "reply": r[2] or "", "lang": r[3] or "und", "embedding": r[4]} for r in rows]
 
+     # Retrieval (embedding-first)
      matches: List[str] = []
      confidence = 0.0
      try:
          if embed_model is not None and knowledge_rows:
              texts = [kr["text"] for kr in knowledge_rows]
+             try:
+                 embs = await run_blocking_with_timeout(lambda: embed_model.encode(texts, convert_to_tensor=True), timeout=MODEL_TIMEOUT)
+                 q_emb = await run_blocking_with_timeout(lambda: embed_model.encode([en_msg], convert_to_tensor=True)[0], timeout=MODEL_TIMEOUT)
+                 import torch as _torch
+                 scores = _torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), embs)
+                 cand = []
+                 for i in range(scores.shape[0]):
+                     s = float(scores[i])
+                     kr = knowledge_rows[i]
+                     candidate_text = (kr["reply"] or kr["text"]).strip()
+                     if is_boilerplate_candidate(candidate_text):
+                         continue
+                     if s >= 0.30:
+                         cand.append({"text": candidate_text, "lang": kr["lang"], "score": s})
+                 cand = sorted(cand, key=lambda x: -x["score"])
+                 matches = [c["text"] for c in cand]
+                 confidence = cand[0]["score"] if cand else 0.0
+             except asyncio.TimeoutError:
+                 logger.warning("[retrieval] embedding encode timed out")
+                 matches = []
+             except Exception as e:
+                 logger.warning(f"[retrieval] embedding error: {e}")
+                 matches = []
          else:
              cand = []
              for kr in knowledge_rows:
 
          logger.warning(f"[retrieval] error: {e}")
          matches = []
 
+     # Compose reply strictly from topic matches
      if matches and confidence >= 0.6:
          reply_en = matches[0]
      elif matches:
 
          except Exception:
              pass
          reply_final = base
          try:
              if not flags.get('toxic', False):
                  with engine_user.begin() as conn:
 
                          "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
                      ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
                          "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
                      conn.execute(sql_text(
+                         "DELETE FROM user_memory WHERE id NOT IN (SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
                      ), {"uid": user_id})
          except Exception as e:
              logger.debug(f"user_memory store error: {e}")
          record_request(time.time() - t0)
+         return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": "", "confidence": round(confidence,2), "flags": flags}
 
+     # Postprocess reply_en
      reply_en = dedupe_sentences(reply_en)
 
+     # Translate to user's language if needed
      reply_final = reply_en
      lang_code = (reply_lang or "und").split("-")[0].lower()
      if lang_code not in ("en", "eng", "und", ""):
 
              logger.warning(f"[translation] failed to translate reply_en -> {lang_code}: {exc}")
              reply_final = reply_en
 
+     # Mood & emoji
      emoji = ""
      try:
          mood = detect_mood(raw_msg + " " + reply_final)
 
      except Exception:
          emoji = ""
 
+     # Persist user memory (only in user DB) and prune to last 10
      try:
          if not flags.get('toxic', False):
              with engine_user.begin() as conn:
 
                      "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
                  ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
                      "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
                  conn.execute(sql_text(
+                     "DELETE FROM user_memory WHERE id NOT IN (SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
                  ), {"uid": user_id})
      except Exception as e:
          logger.debug(f"user_memory persist error: {e}")
 
      if include_steps:
          reply_final = f"{reply_final}\n\n[Debug: topic={topic} confidence={round(confidence,2)}]"
 
+     return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": emoji, "confidence": round(confidence,2), "flags": flags}
 
  @app.post("/response")
  async def response_wrapper(request: Request, data: dict = Body(...)):
      return await chat(request, data)
 
  @app.post("/verify-admin")
  async def verify_admin(x_admin_key: str = Header(None, alias="X-Admin-Key")):
      if ADMIN_KEY is None:
 
      html = html.replace("%%STARTUP_TIME%%", str(startup_time_local))
      return HTMLResponse(html)
 
+ # --- Start app ---
+ if __name__ == "__main__":
+     # preload embed and TTS in background
+     if TTS_AVAILABLE:
          try:
+             threading.Thread(target=lambda: get_tts_model_blocking(), daemon=True).start()
          except Exception:
+             pass
+     if SentenceTransformer is not None:
+         try:
+             threading.Thread(target=try_load_embed, daemon=True).start()
+         except Exception:
+             pass
      app_start_time = time.time()
      import uvicorn
      port = int(os.environ.get("PORT", 7860))