Akis Giannoukos committed on
Commit
17f0761
·
1 Parent(s): d3feaf4

Add UI controls for switching models

Browse files
Files changed (1) hide show
  1. app.py +70 -2
app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
20
  )
21
  from gtts import gTTS
22
  import spaces
 
23
 
24
 
25
  # ---------------------------
@@ -30,6 +31,22 @@ DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
30
  CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
31
  MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
32
  USE_TTS_DEFAULT = os.getenv("USE_TTS", "false").strip().lower() == "true"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  # ---------------------------
@@ -61,14 +78,48 @@ def get_textgen_pipeline():
61
  # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
62
  _gen_pipe = pipeline(
63
  task="text-generation",
64
- model=DEFAULT_CHAT_MODEL_ID,
65
- tokenizer=DEFAULT_CHAT_MODEL_ID,
66
  device=_hf_device(),
67
  torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
68
  )
69
  return _gen_pipe
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # ---------------------------
73
  # Utilities
74
  # ---------------------------
@@ -531,6 +582,11 @@ def create_demo():
531
  threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
532
  tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
533
  tts_audio = gr.Audio(label="Clinician voice", interactive=False)
 
 
 
 
 
534
 
535
  with gr.Row():
536
  audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
@@ -561,6 +617,18 @@ def create_demo():
561
 
562
  reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
563
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  return demo
565
 
566
  demo = create_demo()
 
20
  )
21
  from gtts import gTTS
22
  import spaces
23
+ import threading
24
 
25
 
26
  # ---------------------------
 
31
  CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
32
  MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
33
  USE_TTS_DEFAULT = os.getenv("USE_TTS", "false").strip().lower() == "true"
34
+ CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
35
+
36
+
37
def _load_model_id_from_config() -> str:
    """Return the chat model id persisted in CONFIG_PATH, or the default.

    Expects a JSON file of the shape ``{"model_id": "..."}``. Any read or
    parse problem (missing file, unreadable file, malformed JSON, wrong
    shape) falls back to DEFAULT_CHAT_MODEL_ID so app startup never fails
    on an absent or corrupt config.
    """
    try:
        # EAFP: open directly instead of an os.path.exists() pre-check,
        # which is racy (file may vanish between check and open) and costs
        # an extra stat(). OSError covers missing/unreadable files;
        # ValueError covers json.JSONDecodeError (its subclass).
        with open(CONFIG_PATH, "r") as f:
            data = json.load(f)
    except (OSError, ValueError):
        return DEFAULT_CHAT_MODEL_ID
    if isinstance(data, dict) and data.get("model_id"):
        return str(data["model_id"])
    return DEFAULT_CHAT_MODEL_ID
47
+
48
+
49
+ current_model_id = _load_model_id_from_config()
50
 
51
 
52
  # ---------------------------
 
78
  # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
79
  _gen_pipe = pipeline(
80
  task="text-generation",
81
+ model=current_model_id,
82
+ tokenizer=current_model_id,
83
  device=_hf_device(),
84
  torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
85
  )
86
  return _gen_pipe
87
 
88
 
89
def set_current_model_id(new_model_id: str) -> str:
    """Switch the active chat model id and invalidate the cached pipeline.

    Returns a human-readable status string for the UI. The heavy text-gen
    pipeline is NOT rebuilt here; clearing ``_gen_pipe`` defers the reload
    to the next call of get_textgen_pipeline().
    """
    global current_model_id, _gen_pipe
    candidate = (new_model_id or "").strip()
    if not candidate:
        return "Model id is empty; keeping current model."
    if candidate == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    current_model_id = candidate
    # Drop the lazily-built pipeline so it is recreated with the new id.
    _gen_pipe = None  # force reload on next use
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."
99
+
100
+
101
def persist_model_id(new_model_id: str) -> None:
    """Best-effort write of the chosen model id to CONFIG_PATH as JSON.

    Failures (read-only filesystem, bad path, serialization error) are
    deliberately swallowed: persistence is a convenience so the id
    survives a restart, not a hard requirement.
    """
    try:
        payload = json.dumps({"model_id": new_model_id})
        with open(CONFIG_PATH, "w") as f:
            f.write(payload)
    except Exception:
        pass
107
+
108
+
109
def apply_model_and_restart(new_model_id: str) -> str:
    """Persist *new_model_id*, switch to it, and schedule a process restart.

    Intended for hosted Spaces: the supervisor relaunches the app, which
    then reads the persisted model id at import time. The exit runs on a
    daemon thread after a short delay so this function can return its
    status message and the HTTP response can flush first.

    NOTE(review): relies on ``time`` being imported at module level — the
    visible diff only adds ``import threading``; confirm ``time`` is
    already imported near the top of app.py.
    """
    candidate = (new_model_id or "").strip()
    if not candidate:
        return "Model id is empty; not restarting."
    persist_model_id(candidate)
    set_current_model_id(candidate)

    def _shutdown() -> None:
        # os._exit skips atexit handlers/finalizers on purpose: we want a
        # hard, immediate exit so the supervisor restarts the process.
        time.sleep(0.25)
        os._exit(0)

    threading.Thread(target=_shutdown, daemon=True).start()
    return f"Restarting with model `{candidate}`..."
121
+
122
+
123
  # ---------------------------
124
  # Utilities
125
  # ---------------------------
 
582
  threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
583
  tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
584
  tts_audio = gr.Audio(label="Clinician voice", interactive=False)
585
+ model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
586
+ with gr.Row():
587
+ apply_model_btn = gr.Button("Apply model (no restart)")
588
+ apply_model_restart_btn = gr.Button("Apply model and restart")
589
+ model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")
590
 
591
  with gr.Row():
592
  audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
 
617
 
618
  reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
619
 
620
+ # Model switch handlers
621
+ def _on_apply_model(mid: str):
622
+ msg = set_current_model_id(mid)
623
+ return f"Current model: `{current_model_id}`\n\n{msg}"
624
+
625
+ def _on_apply_model_restart(mid: str):
626
+ msg = apply_model_and_restart(mid)
627
+ return f"{msg}"
628
+
629
+ apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
630
+ apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
631
+
632
  return demo
633
 
634
  demo = create_demo()