Vaishnav Muraleedharan committed on
Commit 9b39b7c · 1 Parent(s): 88c6024

chore: format code to maintain consistent style

Files changed (9)
  1. __init__.py +0 -1
  2. app.py +31 -11
  3. auth.py +53 -47
  4. cache.py +19 -11
  5. evaluation.py +21 -21
  6. gemini.py +16 -21
  7. gemini_tts.py +63 -28
  8. interview_simulator.py +134 -110
  9. medgemma.py +26 -18
__init__.py CHANGED
@@ -11,4 +11,3 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
app.py CHANGED
@@ -12,18 +12,36 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from evaluation import evaluate_report, evaluation_prompt
16
- from flask import Flask, send_from_directory, request, jsonify, Response, stream_with_context, send_file
17
  from flask_cors import CORS
18
- import os, time, json, re
 
 
19
  from gemini import gemini_get_text_response
20
  from interview_simulator import stream_interview
21
- from cache import create_cache_zip
22
  from medgemma import medgemma_get_text_response
23
 
24
- app = Flask(__name__, static_folder=os.environ.get("FRONTEND_BUILD", "frontend/build"), static_url_path="/")
25
  CORS(app, resources={r"/api/*": {"origins": "http://localhost:3000"}})
26
 
 
27
  @app.route("/")
28
  def serve():
29
  """Serves the main index.html file."""
@@ -35,7 +53,7 @@ def stream_conversation():
35
  """Streams the conversation with the interview simulator."""
36
  patient = request.args.get("patient", "Patient")
37
  condition = request.args.get("condition", "unknown condition")
38
-
39
  def generate():
40
  try:
41
  for message in stream_interview(patient, condition):
@@ -43,9 +61,10 @@ def stream_conversation():
43
  except Exception as e:
44
  yield f"data: Error: {str(e)}\n\n"
45
  raise e
46
-
47
  return Response(stream_with_context(generate()), mimetype="text/event-stream")
48
 
 
49
  @app.route("/api/evaluate_report", methods=["POST"])
50
  def evaluate_report_call():
51
  """Evaluates the provided medical report."""
@@ -55,10 +74,10 @@ def evaluate_report_call():
55
  return jsonify({"error": "Report is required"}), 400
56
  condition = data.get("condition", "")
57
  if not condition:
58
- return jsonify({"error": "Condition is required"}), 400
59
-
60
  evaluation_text = evaluate_report(report, condition)
61
-
62
  return jsonify({"evaluation": evaluation_text})
63
 
64
 
@@ -81,6 +100,7 @@ def static_proxy(path):
81
  return send_from_directory(app.static_folder, path)
82
  else:
83
  return send_from_directory(app.static_folder, "index.html")
84
-
 
85
  if __name__ == "__main__":
86
  app.run(host="0.0.0.0", port=7860, threaded=True)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import json
16
+ import os
17
+ import re
18
+ import time
19
+
20
+ from flask import (
21
+ Flask,
22
+ Response,
23
+ jsonify,
24
+ request,
25
+ send_file,
26
+ send_from_directory,
27
+ stream_with_context,
28
+ )
29
  from flask_cors import CORS
30
+
31
+ from cache import create_cache_zip
32
+ from evaluation import evaluate_report, evaluation_prompt
33
  from gemini import gemini_get_text_response
34
  from interview_simulator import stream_interview
 
35
  from medgemma import medgemma_get_text_response
36
 
37
+ app = Flask(
38
+ __name__,
39
+ static_folder=os.environ.get("FRONTEND_BUILD", "frontend/build"),
40
+ static_url_path="/",
41
+ )
42
  CORS(app, resources={r"/api/*": {"origins": "http://localhost:3000"}})
43
 
44
+
45
  @app.route("/")
46
  def serve():
47
  """Serves the main index.html file."""
 
53
  """Streams the conversation with the interview simulator."""
54
  patient = request.args.get("patient", "Patient")
55
  condition = request.args.get("condition", "unknown condition")
56
+
57
  def generate():
58
  try:
59
  for message in stream_interview(patient, condition):
 
61
  except Exception as e:
62
  yield f"data: Error: {str(e)}\n\n"
63
  raise e
64
+
65
  return Response(stream_with_context(generate()), mimetype="text/event-stream")
66
 
67
+
68
  @app.route("/api/evaluate_report", methods=["POST"])
69
  def evaluate_report_call():
70
  """Evaluates the provided medical report."""
 
74
  return jsonify({"error": "Report is required"}), 400
75
  condition = data.get("condition", "")
76
  if not condition:
77
+ return jsonify({"error": "Condition is required"}), 400
78
+
79
  evaluation_text = evaluate_report(report, condition)
80
+
81
  return jsonify({"evaluation": evaluation_text})
82
 
83
 
 
100
  return send_from_directory(app.static_folder, path)
101
  else:
102
  return send_from_directory(app.static_folder, "index.html")
103
+
104
+
105
  if __name__ == "__main__":
106
  app.run(host="0.0.0.0", port=7860, threaded=True)
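For reviewers trying the reformatted endpoint: a minimal client sketch for the `/api/evaluate_report` route shown above. The host and port follow `app.run(host="0.0.0.0", port=7860)`; the report and condition values are invented placeholders.

```python
# Illustrative client for the /api/evaluate_report endpoint above.
import requests

resp = requests.post(
    "http://localhost:7860/api/evaluate_report",
    json={"report": "Patient reports intermittent headaches...", "condition": "migraine"},
    timeout=120,
)
resp.raise_for_status()  # 400 if "report" or "condition" is missing
print(resp.json()["evaluation"])
```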
auth.py CHANGED
@@ -12,66 +12,72 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import json
16
  import datetime
17
- from google.oauth2 import service_account
 
18
  import google.auth.transport.requests
 
 
19
 
20
  def create_credentials(secret_key_json) -> service_account.Credentials:
21
- """Creates Google Cloud credentials from the provided service account key.
 
 
 
22
 
23
- Returns:
24
- service_account.Credentials: The created credentials object.
 
 
25
 
26
- Raises:
27
- ValueError: If the environment variable is not set or is empty, or if the
28
- JSON format is invalid.
29
- """
30
 
31
- if not secret_key_json:
32
- raise ValueError("Userdata variable 'GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY' is not set or is empty.")
33
- try:
34
- service_account_info = json.loads(secret_key_json)
35
- except (SyntaxError, ValueError) as e:
36
- raise ValueError("Invalid service account key JSON format.") from e
37
- return service_account.Credentials.from_service_account_info(
38
- service_account_info,
39
- scopes=['https://www.googleapis.com/auth/cloud-platform']
40
- )
41
 
42
- def refresh_credentials(credentials: service_account.Credentials) -> service_account.Credentials:
43
- """Refreshes the provided Google Cloud credentials if they are about to expire
44
- (within 5 minutes) or if they don't have an expiry time set.
 
 
45
 
46
- Args:
47
- credentials: The credentials object to refresh.
48
 
49
- Returns:
50
- service_account.Credentials: The refreshed credentials object.
51
- """
52
- if credentials.expiry:
53
- expiry_time = credentials.expiry.replace(tzinfo=datetime.timezone.utc)
54
- # Calculate the time remaining until expiration
55
- time_remaining = expiry_time - datetime.datetime.now(datetime.timezone.utc)
56
- # Check if the token is about to expire (e.g., within 5 minutes)
57
- if time_remaining < datetime.timedelta(minutes=5):
58
  request = google.auth.transport.requests.Request()
59
  credentials.refresh(request)
60
- else:
61
- # If no expiry is set, always attempt to refresh (e.g., for certain credential types)
62
- request = google.auth.transport.requests.Request()
63
- credentials.refresh(request)
64
- return credentials
65
 
66
- def get_access_token_refresh_if_needed(credentials: service_account.Credentials) -> str:
67
- """Gets the access token from the credentials, refreshing them if needed.
68
 
69
- Args:
70
- credentials: The credentials object.
71
 
72
- Returns:
73
- str: The access token.
74
- """
75
- credentials = refresh_credentials(credentials)
76
- return credentials.token
77
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import datetime
16
+ import json
17
+
18
  import google.auth.transport.requests
19
+ from google.oauth2 import service_account
20
+
21
 
22
  def create_credentials(secret_key_json) -> service_account.Credentials:
23
+ """Creates Google Cloud credentials from the provided service account key.
24
+
25
+ Returns:
26
+ service_account.Credentials: The created credentials object.
27
 
28
+ Raises:
29
+ ValueError: If the environment variable is not set or is empty, or if the
30
+ JSON format is invalid.
31
+ """
32
 
33
+ if not secret_key_json:
34
+ raise ValueError(
35
+ "Userdata variable 'GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY' is not set or is empty."
36
+ )
37
+ try:
38
+ service_account_info = json.loads(secret_key_json)
39
+ except (SyntaxError, ValueError) as e:
40
+ raise ValueError("Invalid service account key JSON format.") from e
41
+ return service_account.Credentials.from_service_account_info(
42
+ service_account_info, scopes=["https://www.googleapis.com/auth/cloud-platform"]
43
+ )
44
45
 
46
+ def refresh_credentials(
47
+ credentials: service_account.Credentials,
48
+ ) -> service_account.Credentials:
49
+ """Refreshes the provided Google Cloud credentials if they are about to expire
50
+ (within 5 minutes) or if they don't have an expiry time set.
51
 
52
+ Args:
53
+ credentials: The credentials object to refresh.
54
 
55
+ Returns:
56
+ service_account.Credentials: The refreshed credentials object.
57
+ """
58
+ if credentials.expiry:
59
+ expiry_time = credentials.expiry.replace(tzinfo=datetime.timezone.utc)
60
+ # Calculate the time remaining until expiration
61
+ time_remaining = expiry_time - datetime.datetime.now(datetime.timezone.utc)
62
+ # Check if the token is about to expire (e.g., within 5 minutes)
63
+ if time_remaining < datetime.timedelta(minutes=5):
64
+ request = google.auth.transport.requests.Request()
65
+ credentials.refresh(request)
66
+ else:
67
+ # If no expiry is set, always attempt to refresh (e.g., for certain credential types)
68
  request = google.auth.transport.requests.Request()
69
  credentials.refresh(request)
70
+ return credentials
71
 
 
 
72
 
73
+ def get_access_token_refresh_if_needed(credentials: service_account.Credentials) -> str:
74
+ """Gets the access token from the credentials, refreshing them if needed.
75
 
76
+ Args:
77
+ credentials: The credentials object.
 
 
 
78
 
79
+ Returns:
80
+ str: The access token.
81
+ """
82
+ credentials = refresh_credentials(credentials)
83
+ return credentials.token
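A minimal sketch of how these helpers chain together, mirroring the wiring in medgemma.py later in this commit. The Bearer-header shape is an assumption; the actual request headers are not shown in the diff.

```python
# Sketch: wiring the auth helpers, as medgemma.py does elsewhere in this commit.
import os

from auth import create_credentials, get_access_token_refresh_if_needed

secret_key_json = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
credentials = create_credentials(secret_key_json)  # raises ValueError if unset or malformed
token = get_access_token_refresh_if_needed(credentials)  # refreshes when <5 minutes remain
headers = {"Authorization": f"Bearer {token}"}  # assumed header shape, not shown in the diff
```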
cache.py CHANGED
@@ -12,12 +12,13 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from diskcache import Cache
16
  import os
17
  import shutil
18
  import tempfile
19
  import zipfile
20
- import logging
 
21
 
22
  cache = Cache(os.environ.get("CACHE_DIR", "/cache"))
23
  # Print cache statistics after loading
@@ -28,16 +29,17 @@ try:
28
  except Exception as e:
29
  print(f"Could not retrieve cache statistics: {e}")
30
 
 
31
  def create_cache_zip():
32
  temp_dir = tempfile.gettempdir()
33
- base_name = os.path.join(temp_dir, "cache_archive") # A more descriptive name
34
  archive_path = base_name + ".zip"
35
  cache_directory = os.environ.get("CACHE_DIR", "/cache")
36
-
37
  if not os.path.isdir(cache_directory):
38
  logging.error(f"Cache directory not found at {cache_directory}")
39
  return None, f"Cache directory not found on server: {cache_directory}"
40
-
41
  logging.info("Forcing a cache checkpoint for safe backup...")
42
  try:
43
  # Open and immediately close a connection.
@@ -45,15 +47,19 @@ def create_cache_zip():
45
  # into the main .db file, ensuring the on-disk files are consistent.
46
  with Cache(cache_directory) as temp_cache:
47
  temp_cache.close()
48
-
49
  # Clean up temporary files before archiving.
50
- tmp_path = os.path.join(cache_directory, 'tmp')
51
  if os.path.isdir(tmp_path):
52
  logging.info(f"Removing temporary cache directory: {tmp_path}")
53
  shutil.rmtree(tmp_path)
54
 
55
- logging.info(f"Checkpoint complete. Creating zip archive of {cache_directory} to {archive_path}")
56
- with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED, compresslevel=9) as zipf:
57
  for root, _, files in os.walk(cache_directory):
58
  for file in files:
59
  file_path = os.path.join(root, file)
@@ -61,7 +67,9 @@ def create_cache_zip():
61
  zipf.write(file_path, arcname)
62
  logging.info("Zip archive created successfully.")
63
  return archive_path, None
64
-
65
  except Exception as e:
66
- logging.error(f"Error creating zip archive of cache directory: {e}", exc_info=True)
 
 
67
  return None, f"Error creating zip archive: {e}"
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import logging
16
  import os
17
  import shutil
18
  import tempfile
19
  import zipfile
20
+
21
+ from diskcache import Cache
22
 
23
  cache = Cache(os.environ.get("CACHE_DIR", "/cache"))
24
  # Print cache statistics after loading
 
29
  except Exception as e:
30
  print(f"Could not retrieve cache statistics: {e}")
31
 
32
+
33
  def create_cache_zip():
34
  temp_dir = tempfile.gettempdir()
35
+ base_name = os.path.join(temp_dir, "cache_archive") # A more descriptive name
36
  archive_path = base_name + ".zip"
37
  cache_directory = os.environ.get("CACHE_DIR", "/cache")
38
+
39
  if not os.path.isdir(cache_directory):
40
  logging.error(f"Cache directory not found at {cache_directory}")
41
  return None, f"Cache directory not found on server: {cache_directory}"
42
+
43
  logging.info("Forcing a cache checkpoint for safe backup...")
44
  try:
45
  # Open and immediately close a connection.
 
47
  # into the main .db file, ensuring the on-disk files are consistent.
48
  with Cache(cache_directory) as temp_cache:
49
  temp_cache.close()
50
+
51
  # Clean up temporary files before archiving.
52
+ tmp_path = os.path.join(cache_directory, "tmp")
53
  if os.path.isdir(tmp_path):
54
  logging.info(f"Removing temporary cache directory: {tmp_path}")
55
  shutil.rmtree(tmp_path)
56
 
57
+ logging.info(
58
+ f"Checkpoint complete. Creating zip archive of {cache_directory} to {archive_path}"
59
+ )
60
+ with zipfile.ZipFile(
61
+ archive_path, "w", zipfile.ZIP_DEFLATED, compresslevel=9
62
+ ) as zipf:
63
  for root, _, files in os.walk(cache_directory):
64
  for file in files:
65
  file_path = os.path.join(root, file)
 
67
  zipf.write(file_path, arcname)
68
  logging.info("Zip archive created successfully.")
69
  return archive_path, None
70
+
71
  except Exception as e:
72
+ logging.error(
73
+ f"Error creating zip archive of cache directory: {e}", exc_info=True
74
+ )
75
  return None, f"Error creating zip archive: {e}"
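`create_cache_zip()` keeps its `(archive_path, error)` return contract through the reformat; a caller sketch:

```python
# Sketch of consuming create_cache_zip()'s (path, error) tuple.
from cache import create_cache_zip

archive_path, error = create_cache_zip()
if error:
    print(f"Cache backup failed: {error}")
else:
    print(f"Cache archived at {archive_path}")
```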
evaluation.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
 
15
  import re
 
16
  from medgemma import medgemma_get_text_response
17
 
18
 
@@ -40,30 +41,29 @@ REPORT TEMPLATE START
40
  REPORT TEMPLATE END
41
  """
42
 
 
43
  def evaluate_report(report, condition):
44
  """Evaluate the pre-visit report based on the condition using MedGemma LLM."""
45
- evaluation_text = medgemma_get_text_response([
46
- {
47
- "role": "system",
48
- "content": [
49
- {
50
- "type": "text",
51
- "text": f"{evaluation_prompt(condition)}"
52
- }
53
- ]
54
- },
55
- {
56
- "role": "user",
57
- "content": [
58
- {
59
- "type": "text",
60
- "text": f"Here is the report text:\n{report}"
61
- }
62
- ]
63
- },
64
- ])
65
 
66
  # Remove any LLM "thinking" blocks (special tokens sometimes present in output)
67
- evaluation_text = re.sub(r'<unused94>.*?<unused95>', '', evaluation_text, flags=re.DOTALL)
 
 
68
 
69
  return evaluation_text
 
13
  # limitations under the License.
14
 
15
  import re
16
+
17
  from medgemma import medgemma_get_text_response
18
 
19
 
 
41
  REPORT TEMPLATE END
42
  """
43
 
44
+
45
  def evaluate_report(report, condition):
46
  """Evaluate the pre-visit report based on the condition using MedGemma LLM."""
47
+ evaluation_text = medgemma_get_text_response(
48
+ [
49
+ {
50
+ "role": "system",
51
+ "content": [
52
+ {"type": "text", "text": f"{evaluation_prompt(condition)}"}
53
+ ],
54
+ },
55
+ {
56
+ "role": "user",
57
+ "content": [
58
+ {"type": "text", "text": f"Here is the report text:\n{report}"}
59
+ ],
60
+ },
61
+ ]
62
+ )
 
 
 
 
63
 
64
  # Remove any LLM "thinking" blocks (special tokens sometimes present in output)
65
+ evaluation_text = re.sub(
66
+ r"<unused94>.*?<unused95>", "", evaluation_text, flags=re.DOTALL
67
+ )
68
 
69
  return evaluation_text
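The `<unused94>…<unused95>` stripping is unchanged by the reformat; on a toy string (the sample text is invented) it behaves like this:

```python
import re

raw = "<unused94>model reasoning, hidden from users<unused95>Score: 4/5. Clear history section."
clean = re.sub(r"<unused94>.*?<unused95>", "", raw, flags=re.DOTALL)
print(clean)  # -> "Score: 4/5. Clear history section."
```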
gemini.py CHANGED
@@ -13,47 +13,42 @@
13
  # limitations under the License.
14
 
15
  import os
 
16
  import requests
 
17
  from cache import cache # new import replacing duplicate cache initialization
18
 
19
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
20
 
 
21
  # Decorate the function to cache its results indefinitely.
22
  @cache.memoize()
23
- def gemini_get_text_response(prompt: str,
24
- stop_sequences: list = None,
25
- temperature: float = 0.1,
26
- max_output_tokens: int = 4000,
27
- top_p: float = 0.8,
28
- top_k: int = 10):
 
 
29
  """
30
  Makes a text generation request to the Gemini API.
31
  """
32
 
33
  api_url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={GEMINI_API_KEY}"
34
- headers = {
35
- 'Content-Type': 'application/json'
36
- }
37
 
38
  data = {
39
- "contents": [
40
- {
41
- "parts": [
42
- {
43
- "text": prompt
44
- }
45
- ]
46
- }
47
- ],
48
  "generationConfig": {
49
  "stopSequences": stop_sequences or ["Title"],
50
  "temperature": temperature,
51
  "maxOutputTokens": max_output_tokens,
52
  "topP": top_p,
53
- "topK": top_k
54
- }
55
  }
56
 
57
  response = requests.post(api_url, headers=headers, json=data)
58
  response.raise_for_status() # Raise an exception for bad status codes
59
- return response.json()["candidates"][0]["content"]["parts"][0]["text"]
 
13
  # limitations under the License.
14
 
15
  import os
16
+
17
  import requests
18
+
19
  from cache import cache # new import replacing duplicate cache initialization
20
 
21
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
22
 
23
+
24
  # Decorate the function to cache its results indefinitely.
25
  @cache.memoize()
26
+ def gemini_get_text_response(
27
+ prompt: str,
28
+ stop_sequences: list = None,
29
+ temperature: float = 0.1,
30
+ max_output_tokens: int = 4000,
31
+ top_p: float = 0.8,
32
+ top_k: int = 10,
33
+ ):
34
  """
35
  Makes a text generation request to the Gemini API.
36
  """
37
 
38
  api_url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={GEMINI_API_KEY}"
39
+ headers = {"Content-Type": "application/json"}
 
 
40
 
41
  data = {
42
+ "contents": [{"parts": [{"text": prompt}]}],
43
  "generationConfig": {
44
  "stopSequences": stop_sequences or ["Title"],
45
  "temperature": temperature,
46
  "maxOutputTokens": max_output_tokens,
47
  "topP": top_p,
48
+ "topK": top_k,
49
+ },
50
  }
51
 
52
  response = requests.post(api_url, headers=headers, json=data)
53
  response.raise_for_status() # Raise an exception for bad status codes
54
+ return response.json()["candidates"][0]["content"]["parts"][0]["text"]
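Since `gemini_get_text_response` stays wrapped in `@cache.memoize()`, repeated identical calls are served from diskcache. A sketch, assuming `GEMINI_API_KEY` is set and the cache directory is writable:

```python
from gemini import gemini_get_text_response

prompt = "In one sentence, what is a FHIR resource?"
first = gemini_get_text_response(prompt)   # hits the Gemini API
second = gemini_get_text_response(prompt)  # served from the disk cache
assert first == second
```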
gemini_tts.py CHANGED
@@ -12,16 +12,18 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import google.generativeai as genai
 
16
  import os
17
- import struct
18
  import re
19
- import logging
20
- from cache import cache
 
21
 
22
  # Add these imports for MP3 conversion
23
  from pydub import AudioSegment
24
- import io
 
25
 
26
  # --- Constants ---
27
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
@@ -30,12 +32,16 @@ TTS_MODEL = "gemini-2.5-flash-preview-tts"
30
  DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
31
 
32
  # --- Configuration ---
33
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
34
 
35
  genai.configure(api_key=GEMINI_API_KEY)
36
 
 
37
  class TTSGenerationError(Exception):
38
  """Custom exception for TTS generation failures."""
 
39
  pass
40
 
41
 
@@ -46,7 +52,7 @@ def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
46
  e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
47
  """
48
  bits_per_sample = 16 # Default
49
- rate = 24000 # Default
50
 
51
  parts = mime_type.split(";")
52
  for param in parts:
@@ -56,15 +62,16 @@ def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
56
  rate_str = param.split("=", 1)[1]
57
  rate = int(rate_str)
58
  except (ValueError, IndexError):
59
- pass # Keep default if parsing fails
60
- elif re.match(r"audio/l\d+", param): # Matches audio/L<digits>
61
- try:
62
- bits_str = param.split("l",1)[1]
63
  bits_per_sample = int(bits_str)
64
- except (ValueError, IndexError):
65
- pass # Keep default
66
  return {"bits_per_sample": bits_per_sample, "rate": rate}
67
 
 
68
  def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
69
  """
70
  Generates a WAV file header for the given raw audio data and parameters.
@@ -82,13 +89,26 @@ def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
82
 
83
  header = struct.pack(
84
  "<4sI4s4sIHHIIHH4sI",
85
- b"RIFF", chunk_size, b"WAVE", b"fmt ",
86
- 16, 1, num_channels, sample_rate, byte_rate, block_align,
87
- bits_per_sample, b"data", data_size
  )
89
  return header + audio_data
 
 
90
  # --- End of helper functions ---
91
 
 
92
  def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
93
  """
94
  Synthesizes English text using the Gemini API via the google-genai library.
@@ -109,11 +129,9 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte
109
  "response_modalities": ["AUDIO"],
110
  "speech_config": {
111
  "voice_config": {
112
- "prebuilt_voice_config": {
113
- "voice_name": gemini_voice_name
114
- }
115
  }
116
- }
117
  }
118
 
119
  response = model.generate_content(
@@ -137,8 +155,11 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte
137
  # --- Audio processing ---
138
  if final_mime_type:
139
  final_mime_type_lower = final_mime_type.lower()
140
- needs_wav_conversion = any(p in final_mime_type_lower for p in ("audio/l16", "audio/l24", "audio/l8")) or \
141
- not final_mime_type_lower.startswith(("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus"))
 
 
 
142
 
143
  if needs_wav_conversion:
144
  processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)
@@ -147,7 +168,10 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte
147
  processed_audio_data = audio_data_bytes
148
  processed_audio_mime = final_mime_type
149
  else:
150
- logging.warning("MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).", DEFAULT_RAW_AUDIO_MIME)
 
 
 
151
  processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
152
  processed_audio_mime = "audio/wav"
153
 
@@ -155,7 +179,9 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte
155
  if processed_audio_data:
156
  try:
157
  # Load audio into AudioSegment
158
- audio_segment = AudioSegment.from_file(io.BytesIO(processed_audio_data), format="wav")
 
 
159
  mp3_buffer = io.BytesIO()
160
  audio_segment.export(mp3_buffer, format="mp3")
161
  mp3_bytes = mp3_buffer.getvalue()
@@ -169,11 +195,15 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte
169
  logging.error(error_message)
170
  raise TTSGenerationError(error_message)
171
 
 
172
  # Always create the memoized function first, so we can access its .key() method
173
  _memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)
174
 
175
  if GENERATE_SPEECH:
176
- def synthesize_gemini_tts_with_error_handling(*args, **kwargs) -> tuple[bytes | None, str | None]:
 
 
 
177
  """
178
  A wrapper for the memoized TTS function that catches errors and returns (None, None).
179
  This makes the audio generation more resilient to individual failures.
@@ -183,7 +213,10 @@ if GENERATE_SPEECH:
183
  return _memoized_tts_func(*args, **kwargs)
184
  except TTSGenerationError as e:
185
  # If generation fails, log the error and return None, None.
186
- logging.error("Handled TTS Generation Error: %s. Continuing without audio for this segment.", e)
 
 
 
187
  return None, None
188
 
189
  synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
@@ -206,7 +239,9 @@ else:
206
  return result # Cache hit
207
 
208
  # Cache miss
209
- logging.info("GENERATE_SPEECH is false and no cached result found for key: %s", key)
 
 
210
  return None, None
211
 
212
- synthesize_gemini_tts = read_only_synthesize_gemini_tts
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import io
16
+ import logging
17
  import os
 
18
  import re
19
+ import struct
20
+
21
+ import google.generativeai as genai
22
 
23
  # Add these imports for MP3 conversion
24
  from pydub import AudioSegment
25
+
26
+ from cache import cache
27
 
28
  # --- Constants ---
29
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
32
  DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
33
 
34
  # --- Configuration ---
35
+ logging.basicConfig(
36
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
37
+ )
38
 
39
  genai.configure(api_key=GEMINI_API_KEY)
40
 
41
+
42
  class TTSGenerationError(Exception):
43
  """Custom exception for TTS generation failures."""
44
+
45
  pass
46
 
47
 
 
52
  e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
53
  """
54
  bits_per_sample = 16 # Default
55
+ rate = 24000 # Default
56
 
57
  parts = mime_type.split(";")
58
  for param in parts:
 
62
  rate_str = param.split("=", 1)[1]
63
  rate = int(rate_str)
64
  except (ValueError, IndexError):
65
+ pass # Keep default if parsing fails
66
+ elif re.match(r"audio/l\d+", param): # Matches audio/L<digits>
67
+ try:
68
+ bits_str = param.split("l", 1)[1]
69
  bits_per_sample = int(bits_str)
70
+ except (ValueError, IndexError):
71
+ pass # Keep default
72
  return {"bits_per_sample": bits_per_sample, "rate": rate}
73
 
74
+
75
  def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
76
  """
77
  Generates a WAV file header for the given raw audio data and parameters.
 
89
 
90
  header = struct.pack(
91
  "<4sI4s4sIHHIIHH4sI",
92
+ b"RIFF",
93
+ chunk_size,
94
+ b"WAVE",
95
+ b"fmt ",
96
+ 16,
97
+ 1,
98
+ num_channels,
99
+ sample_rate,
100
+ byte_rate,
101
+ block_align,
102
+ bits_per_sample,
103
+ b"data",
104
+ data_size,
105
  )
106
  return header + audio_data
107
+
108
+
109
  # --- End of helper functions ---
110
 
111
+
112
  def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
113
  """
114
  Synthesizes English text using the Gemini API via the google-genai library.
 
129
  "response_modalities": ["AUDIO"],
130
  "speech_config": {
131
  "voice_config": {
132
+ "prebuilt_voice_config": {"voice_name": gemini_voice_name}
 
 
133
  }
134
+ },
135
  }
136
 
137
  response = model.generate_content(
 
155
  # --- Audio processing ---
156
  if final_mime_type:
157
  final_mime_type_lower = final_mime_type.lower()
158
+ needs_wav_conversion = any(
159
+ p in final_mime_type_lower for p in ("audio/l16", "audio/l24", "audio/l8")
160
+ ) or not final_mime_type_lower.startswith(
161
+ ("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus")
162
+ )
163
 
164
  if needs_wav_conversion:
165
  processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)
 
168
  processed_audio_data = audio_data_bytes
169
  processed_audio_mime = final_mime_type
170
  else:
171
+ logging.warning(
172
+ "MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).",
173
+ DEFAULT_RAW_AUDIO_MIME,
174
+ )
175
  processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
176
  processed_audio_mime = "audio/wav"
177
 
 
179
  if processed_audio_data:
180
  try:
181
  # Load audio into AudioSegment
182
+ audio_segment = AudioSegment.from_file(
183
+ io.BytesIO(processed_audio_data), format="wav"
184
+ )
185
  mp3_buffer = io.BytesIO()
186
  audio_segment.export(mp3_buffer, format="mp3")
187
  mp3_bytes = mp3_buffer.getvalue()
 
195
  logging.error(error_message)
196
  raise TTSGenerationError(error_message)
197
 
198
+
199
  # Always create the memoized function first, so we can access its .key() method
200
  _memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)
201
 
202
  if GENERATE_SPEECH:
203
+
204
+ def synthesize_gemini_tts_with_error_handling(
205
+ *args, **kwargs
206
+ ) -> tuple[bytes | None, str | None]:
207
  """
208
  A wrapper for the memoized TTS function that catches errors and returns (None, None).
209
  This makes the audio generation more resilient to individual failures.
 
213
  return _memoized_tts_func(*args, **kwargs)
214
  except TTSGenerationError as e:
215
  # If generation fails, log the error and return None, None.
216
+ logging.error(
217
+ "Handled TTS Generation Error: %s. Continuing without audio for this segment.",
218
+ e,
219
+ )
220
  return None, None
221
 
222
  synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
 
239
  return result # Cache hit
240
 
241
  # Cache miss
242
+ logging.info(
243
+ "GENERATE_SPEECH is false and no cached result found for key: %s", key
244
+ )
245
  return None, None
246
 
247
+ synthesize_gemini_tts = read_only_synthesize_gemini_tts
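A quick sanity check of `convert_to_wav`, which prepends a standard 44-byte RIFF/WAVE header to raw PCM. Importing the module assumes its environment (`GEMINI_API_KEY`, a writable cache directory) is configured, since configuration runs at import time.

```python
from gemini_tts import convert_to_wav

raw_pcm = b"\x00\x00" * 240  # 10 ms of 16-bit mono silence at 24 kHz
wav_bytes = convert_to_wav(raw_pcm, "audio/L16;rate=24000")
assert wav_bytes[:4] == b"RIFF" and wav_bytes[8:12] == b"WAVE"
assert len(wav_bytes) == 44 + len(raw_pcm)  # struct "<4sI4s4sIHHIIHH4sI" packs 44 bytes
```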
interview_simulator.py CHANGED
@@ -12,70 +12,89 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import json
16
- import re
17
  import os
18
- import base64
19
 
20
  from gemini import gemini_get_text_response
21
- from medgemma import medgemma_get_text_response
22
  from gemini_tts import synthesize_gemini_tts
 
23
 
24
  INTERVIEWER_VOICE = "Aoede"
25
 
 
26
  def read_symptoms_json():
27
  # Load the list of symptoms for each condition from a JSON file
28
- with open("symptoms.json", 'r') as f:
29
  return json.load(f)
30
 
 
31
  def read_patient_and_conditions_json():
32
  # Load all patient and condition data from the frontend assets
33
- with open(os.path.join(os.environ.get("FRONTEND_BUILD", "frontend/build"), "assets", "patients_and_conditions.json"), 'r') as f:
34
  return json.load(f)
35
 
 
36
  def get_patient(patient_name):
37
  """Helper function to locate a patient record by name. Raises StopIteration if not found."""
38
  return next(p for p in PATIENTS if p["name"] == patient_name)
39
 
 
40
  def read_fhir_json(patient):
41
  # Load the FHIR (EHR) JSON file for a given patient
42
- with open(os.path.join(os.environ.get("FRONTEND_BUILD", "frontend/build"), patient["fhirFile"].lstrip("/")), 'r') as f:
43
  return json.load(f)
44
 
 
45
  def get_ehr_summary_per_patient(patient_name):
46
  # Returns a concise EHR summary for the patient, using LLM if not already cached
47
  patient = get_patient(patient_name)
48
  if patient.get("ehr_summary"):
49
  return patient["ehr_summary"]
50
  # Use MedGemma to summarize the EHR for the patient
51
- ehr_summary = medgemma_get_text_response([
52
- {
53
- "role": "system",
54
- "content": [
55
- {
56
- "type": "text",
57
- "text": f"""You are a medical assistant summarizing the EHR (FHIR) records for the patient {patient_name}.
 
58
  Provide a concise summary of the patient's medical history, including any existing conditions, medications, and relevant past treatments.
59
- Do not include personal opinions or assumptions, only factual information."""
60
- }
61
- ]
62
- },
63
- {
64
- "role": "user",
65
- "content": [
66
- {
67
- "type": "text",
68
- "text": json.dumps(read_fhir_json(patient))
69
- }
70
- ]
71
- }
72
- ])
73
  patient["ehr_summary"] = ehr_summary
74
  return ehr_summary
75
 
 
76
  PATIENTS = read_patient_and_conditions_json()["patients"]
77
  SYMPTOMS = read_symptoms_json()
78
-
 
79
  def patient_roleplay_instructions(patient_name, condition_name, previous_answers):
80
  """
81
  Generates structured instructions for the LLM to roleplay as a patient, including persona, scenario, and symptom logic.
@@ -120,6 +139,7 @@ def patient_roleplay_instructions(patient_name, condition_name, previous_answers
120
  ---
121
  """
122
 
 
123
  def interviewer_roleplay_instructions(patient_name):
124
  # Returns detailed instructions for the LLM to roleplay as the interviewer/clinical assistant
125
  return f"""
@@ -153,6 +173,7 @@ def interviewer_roleplay_instructions(patient_name):
153
  3. **End Interview:** You MUST continue the interview until you have asked 20 questions OR the patient is unable to provide more information. When the interview is complete, you MUST conclude by printing this exact phrase: "Thank you for answering my questions. I have everything needed to prepare a report for your visit. End interview."
154
  """
155
 
 
156
  def report_writer_instructions(patient_name: str) -> str:
157
  """
158
  Generates the system prompt with clear instructions, role, and constraints for the LLM.
@@ -202,7 +223,10 @@ The final output MUST be ONLY the full, updated Markdown medical report.
202
  DO NOT include any introductory phrases, explanations, or any text other than the report itself.
203
  </output_format>"""
204
 
205
- def write_report(patient_name: str, interview_text: str, existing_report: str = None) -> str:
 
 
 
206
  """
207
  Constructs the full prompt, sends it to the LLM, and processes the response.
208
  This function handles both the initial creation and subsequent updates of a report.
@@ -212,7 +236,7 @@ def write_report(patient_name: str, interview_text: str, existing_report: str =
212
 
213
  # If no existing report is provided, load a default template from a string.
214
  if not existing_report:
215
- with open("report_template.txt", 'r') as f:
216
  existing_report = f.read()
217
 
218
  # Construct the user prompt with the specific task and data
@@ -237,149 +261,149 @@ Now, generate the complete and updated medical report based on all system and us
237
 
238
  # Assemble the full message payload for the LLM API
239
  messages = [
240
- {
241
- "role": "system",
242
- "content": [{"type": "text", "text": instructions}]
243
- },
244
- {
245
- "role": "user",
246
- "content": [{"type": "text", "text": user_prompt}]
247
- }
248
  ]
249
 
250
  report = medgemma_get_text_response(messages)
251
- cleaned_report = re.sub(r'<unused94>.*?</unused95>', '', report, flags=re.DOTALL)
252
  cleaned_report = cleaned_report.strip()
253
 
254
  # The LLM sometimes wraps the markdown report in a markdown code block.
255
  # This regex checks if the entire string is a code block and extracts the content.
256
- match = re.match(r'^\s*```(?:markdown)?\s*(.*?)\s*```\s*$', cleaned_report, re.DOTALL | re.IGNORECASE)
257
  if match:
258
  cleaned_report = match.group(1)
259
 
260
  return cleaned_report.strip()
261
 
262
 
263
-
264
  def stream_interview(patient_name, condition_name):
265
- print(f"Starting interview simulation for patient: {patient_name}, condition: {condition_name}")
 
 
266
  # Prepare roleplay instructions and initial dialog (using existing helper functions)
267
  interviewer_instructions = interviewer_roleplay_instructions(patient_name)
268
-
269
  # Determine voices for TTS
270
  patient = get_patient(patient_name)
271
  patient_voice = patient["voice"]
272
-
273
  dialog = [
274
  {
275
  "role": "system",
276
- "content": [
277
- {
278
- "type": "text",
279
- "text": interviewer_instructions
280
- }
281
- ]
282
  },
283
- {
284
- "role": "user",
285
- "content": [
286
- {
287
- "type": "text",
288
- "text": "start interview"
289
- }
290
- ]
291
- }
292
  ]
293
-
294
  write_report_text = ""
295
  full_interview_q_a = ""
296
  number_of_questions_limit = 30
297
  for i in range(number_of_questions_limit):
298
  # Get the next interviewer question from MedGemma
299
  interviewer_question_text = medgemma_get_text_response(
300
- messages=dialog,
301
- temperature=0.1,
302
- max_tokens=2048,
303
- stream=False
304
  )
305
  # Process optional "thinking" text (if present in the LLM output)
306
- thinking_search = re.search('<unused94>(.+?)<unused95>', interviewer_question_text, re.DOTALL)
 
 
307
  if thinking_search:
308
  thinking_text = thinking_search.group(1)
309
- interviewer_question_text = interviewer_question_text.replace(f'<unused94>{thinking_text}<unused95>', "")
 
 
310
  if i == 0:
311
  # Only yield the "thinking" summary for the first question
312
  thinking_text = gemini_get_text_response(
313
  f"""Provide a summary of up to 100 words containing only the reasoning and planning from this text,
314
- do not include instructions, use first person: {thinking_text}""")
315
- yield json.dumps({
316
- "speaker": "interviewer thinking",
317
- "text": thinking_text
318
- })
319
 
320
  # Clean up the text for TTS and display
321
- clean_interviewer_text = interviewer_question_text.replace("End interview.", "").strip()
 
 
322
 
323
  # Generate audio for the interviewer's question using Gemini TTS
324
- audio_data, mime_type = synthesize_gemini_tts(f"Speak in a slightly upbeat and brisk manner, as a friendly clinician: {clean_interviewer_text}", INTERVIEWER_VOICE)
 
 
 
325
  audio_b64 = None
326
  if audio_data and mime_type:
327
  audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
328
 
329
  # Yield interviewer message (text and audio)
330
- yield json.dumps({
331
- "speaker": "interviewer",
332
- "text": clean_interviewer_text,
333
- "audio": audio_b64
334
- })
335
- dialog.append({
336
- "role": "assistant",
337
- "content": [{
338
- "type": "text",
339
- "text": interviewer_question_text
340
- }]
341
- })
 
342
  if "End interview" in interviewer_question_text:
343
  # End the interview loop if the LLM signals completion
344
  break
345
 
346
  # Get the patient's response from Gemini (roleplay LLM)
347
- patient_response_text = gemini_get_text_response(f"""
 
348
  {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}\n\n
349
- Question: {interviewer_question_text}""")
 
350
 
351
  # Generate audio for the patient's response
352
- audio_data, mime_type = synthesize_gemini_tts(f"Say this in faster speed, using a sick tone: {patient_response_text}", patient_voice)
 
 
 
353
  audio_b64 = None
354
  if audio_data and mime_type:
355
  audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
356
 
357
  # Yield patient message (text and audio)
358
- yield json.dumps({
359
- "speaker": "patient",
360
- "text": patient_response_text,
361
- "audio": audio_b64
362
- })
363
- dialog.append({
364
- "role": "user",
365
- "content": [{
366
- "type": "text",
367
- "text": patient_response_text
368
- }]
369
- })
370
  # Track the full Q&A for context in future LLM calls
371
- most_recent_q_a = f"Q: {interviewer_question_text}\nA: {patient_response_text}\n"
372
- full_interview_q_a_with_new_q_a = "PREVIOUS Q&A:\n" + full_interview_q_a + "\nNEW Q&A:\n" + most_recent_q_a
373
  # Update the report after each Q&A
374
- write_report_text = write_report(patient_name, full_interview_q_a_with_new_q_a, write_report_text)
 
 
375
  full_interview_q_a += most_recent_q_a
376
- yield json.dumps({
377
- "speaker": "report",
378
- "text": write_report_text
379
- })
380
 
381
- print(f"""Interview simulation completed for patient: {patient_name}, condition: {condition_name}.
 
382
  Patient profile used:
383
- {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}""")
 
384
  # Add this at the end to signal end of stream
385
- yield json.dumps({"event": "end"})
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import base64
16
  import json
 
17
  import os
18
+ import re
19
 
20
  from gemini import gemini_get_text_response
 
21
  from gemini_tts import synthesize_gemini_tts
22
+ from medgemma import medgemma_get_text_response
23
 
24
  INTERVIEWER_VOICE = "Aoede"
25
 
26
+
27
  def read_symptoms_json():
28
  # Load the list of symptoms for each condition from a JSON file
29
+ with open("symptoms.json", "r") as f:
30
  return json.load(f)
31
 
32
+
33
  def read_patient_and_conditions_json():
34
  # Load all patient and condition data from the frontend assets
35
+ with open(
36
+ os.path.join(
37
+ os.environ.get("FRONTEND_BUILD", "frontend/build"),
38
+ "assets",
39
+ "patients_and_conditions.json",
40
+ ),
41
+ "r",
42
+ ) as f:
43
  return json.load(f)
44
 
45
+
46
  def get_patient(patient_name):
47
  """Helper function to locate a patient record by name. Raises StopIteration if not found."""
48
  return next(p for p in PATIENTS if p["name"] == patient_name)
49
 
50
+
51
  def read_fhir_json(patient):
52
  # Load the FHIR (EHR) JSON file for a given patient
53
+ with open(
54
+ os.path.join(
55
+ os.environ.get("FRONTEND_BUILD", "frontend/build"),
56
+ patient["fhirFile"].lstrip("/"),
57
+ ),
58
+ "r",
59
+ ) as f:
60
  return json.load(f)
61
 
62
+
63
  def get_ehr_summary_per_patient(patient_name):
64
  # Returns a concise EHR summary for the patient, using LLM if not already cached
65
  patient = get_patient(patient_name)
66
  if patient.get("ehr_summary"):
67
  return patient["ehr_summary"]
68
  # Use MedGemma to summarize the EHR for the patient
69
+ ehr_summary = medgemma_get_text_response(
70
+ [
71
+ {
72
+ "role": "system",
73
+ "content": [
74
+ {
75
+ "type": "text",
76
+ "text": f"""You are a medical assistant summarizing the EHR (FHIR) records for the patient {patient_name}.
77
  Provide a concise summary of the patient's medical history, including any existing conditions, medications, and relevant past treatments.
78
+ Do not include personal opinions or assumptions, only factual information.""",
79
+ }
80
+ ],
81
+ },
82
+ {
83
+ "role": "user",
84
+ "content": [
85
+ {"type": "text", "text": json.dumps(read_fhir_json(patient))}
86
+ ],
87
+ },
88
+ ]
89
+ )
 
 
90
  patient["ehr_summary"] = ehr_summary
91
  return ehr_summary
92
 
93
+
94
  PATIENTS = read_patient_and_conditions_json()["patients"]
95
  SYMPTOMS = read_symptoms_json()
96
+
97
+
98
  def patient_roleplay_instructions(patient_name, condition_name, previous_answers):
99
  """
100
  Generates structured instructions for the LLM to roleplay as a patient, including persona, scenario, and symptom logic.
 
139
  ---
140
  """
141
 
142
+
143
  def interviewer_roleplay_instructions(patient_name):
144
  # Returns detailed instructions for the LLM to roleplay as the interviewer/clinical assistant
145
  return f"""
 
173
  3. **End Interview:** You MUST continue the interview until you have asked 20 questions OR the patient is unable to provide more information. When the interview is complete, you MUST conclude by printing this exact phrase: "Thank you for answering my questions. I have everything needed to prepare a report for your visit. End interview."
174
  """
175
 
176
+
177
  def report_writer_instructions(patient_name: str) -> str:
178
  """
179
  Generates the system prompt with clear instructions, role, and constraints for the LLM.
 
223
  DO NOT include any introductory phrases, explanations, or any text other than the report itself.
224
  </output_format>"""
225
 
226
+
227
+ def write_report(
228
+ patient_name: str, interview_text: str, existing_report: str = None
229
+ ) -> str:
230
  """
231
  Constructs the full prompt, sends it to the LLM, and processes the response.
232
  This function handles both the initial creation and subsequent updates of a report.
 
236
 
237
  # If no existing report is provided, load a default template from a string.
238
  if not existing_report:
239
+ with open("report_template.txt", "r") as f:
240
  existing_report = f.read()
241
 
242
  # Construct the user prompt with the specific task and data
 
261
 
262
  # Assemble the full message payload for the LLM API
263
  messages = [
264
+ {"role": "system", "content": [{"type": "text", "text": instructions}]},
265
+ {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
266
  ]
267
 
268
  report = medgemma_get_text_response(messages)
269
+ cleaned_report = re.sub(r"<unused94>.*?</unused95>", "", report, flags=re.DOTALL)
270
  cleaned_report = cleaned_report.strip()
271
 
272
  # The LLM sometimes wraps the markdown report in a markdown code block.
273
  # This regex checks if the entire string is a code block and extracts the content.
274
+ match = re.match(
275
+ r"^\s*```(?:markdown)?\s*(.*?)\s*```\s*$",
276
+ cleaned_report,
277
+ re.DOTALL | re.IGNORECASE,
278
+ )
279
  if match:
280
  cleaned_report = match.group(1)
281
 
282
  return cleaned_report.strip()
283
 
284
 
 
285
  def stream_interview(patient_name, condition_name):
286
+ print(
287
+ f"Starting interview simulation for patient: {patient_name}, condition: {condition_name}"
288
+ )
289
  # Prepare roleplay instructions and initial dialog (using existing helper functions)
290
  interviewer_instructions = interviewer_roleplay_instructions(patient_name)
291
+
292
  # Determine voices for TTS
293
  patient = get_patient(patient_name)
294
  patient_voice = patient["voice"]
295
+
296
  dialog = [
297
  {
298
  "role": "system",
299
+ "content": [{"type": "text", "text": interviewer_instructions}],
300
  },
301
+ {"role": "user", "content": [{"type": "text", "text": "start interview"}]},
302
  ]
303
+
304
  write_report_text = ""
305
  full_interview_q_a = ""
306
  number_of_questions_limit = 30
307
  for i in range(number_of_questions_limit):
308
  # Get the next interviewer question from MedGemma
309
  interviewer_question_text = medgemma_get_text_response(
310
+ messages=dialog, temperature=0.1, max_tokens=2048, stream=False
 
 
 
311
  )
312
  # Process optional "thinking" text (if present in the LLM output)
313
+ thinking_search = re.search(
314
+ "<unused94>(.+?)<unused95>", interviewer_question_text, re.DOTALL
315
+ )
316
  if thinking_search:
317
  thinking_text = thinking_search.group(1)
318
+ interviewer_question_text = interviewer_question_text.replace(
319
+ f"<unused94>{thinking_text}<unused95>", ""
320
+ )
321
  if i == 0:
322
  # Only yield the "thinking" summary for the first question
323
  thinking_text = gemini_get_text_response(
324
  f"""Provide a summary of up to 100 words containing only the reasoning and planning from this text,
325
+ do not include instructions, use first person: {thinking_text}"""
326
+ )
327
+ yield json.dumps(
328
+ {"speaker": "interviewer thinking", "text": thinking_text}
329
+ )
330
 
331
  # Clean up the text for TTS and display
332
+ clean_interviewer_text = interviewer_question_text.replace(
333
+ "End interview.", ""
334
+ ).strip()
335
 
336
  # Generate audio for the interviewer's question using Gemini TTS
337
+ audio_data, mime_type = synthesize_gemini_tts(
338
+ f"Speak in a slightly upbeat and brisk manner, as a friendly clinician: {clean_interviewer_text}",
339
+ INTERVIEWER_VOICE,
340
+ )
341
  audio_b64 = None
342
  if audio_data and mime_type:
343
  audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
344
 
345
  # Yield interviewer message (text and audio)
346
+ yield json.dumps(
347
+ {
348
+ "speaker": "interviewer",
349
+ "text": clean_interviewer_text,
350
+ "audio": audio_b64,
351
+ }
352
+ )
353
+ dialog.append(
354
+ {
355
+ "role": "assistant",
356
+ "content": [{"type": "text", "text": interviewer_question_text}],
357
+ }
358
+ )
359
  if "End interview" in interviewer_question_text:
360
  # End the interview loop if the LLM signals completion
361
  break
362
 
363
  # Get the patient's response from Gemini (roleplay LLM)
364
+ patient_response_text = gemini_get_text_response(
365
+ f"""
366
  {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}\n\n
367
+ Question: {interviewer_question_text}"""
368
+ )
369
 
370
  # Generate audio for the patient's response
371
+ audio_data, mime_type = synthesize_gemini_tts(
372
+ f"Say this in faster speed, using a sick tone: {patient_response_text}",
373
+ patient_voice,
374
+ )
375
  audio_b64 = None
376
  if audio_data and mime_type:
377
  audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
378
 
379
  # Yield patient message (text and audio)
380
+ yield json.dumps(
381
+ {"speaker": "patient", "text": patient_response_text, "audio": audio_b64}
382
+ )
383
+ dialog.append(
384
+ {
385
+ "role": "user",
386
+ "content": [{"type": "text", "text": patient_response_text}],
387
+ }
388
+ )
 
 
 
389
  # Track the full Q&A for context in future LLM calls
390
+ most_recent_q_a = (
391
+ f"Q: {interviewer_question_text}\nA: {patient_response_text}\n"
392
+ )
393
+ full_interview_q_a_with_new_q_a = (
394
+ "PREVIOUS Q&A:\n" + full_interview_q_a + "\nNEW Q&A:\n" + most_recent_q_a
395
+ )
396
  # Update the report after each Q&A
397
+ write_report_text = write_report(
398
+ patient_name, full_interview_q_a_with_new_q_a, write_report_text
399
+ )
400
  full_interview_q_a += most_recent_q_a
401
+ yield json.dumps({"speaker": "report", "text": write_report_text})
 
 
 
402
 
403
+ print(
404
+ f"""Interview simulation completed for patient: {patient_name}, condition: {condition_name}.
405
  Patient profile used:
406
+ {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}"""
407
+ )
408
  # Add this at the end to signal end of stream
409
+ yield json.dumps({"event": "end"})
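`stream_interview` still yields one JSON string per event plus a final `{"event": "end"}`; a consumer sketch matching what app.py wraps in SSE (the patient/condition defaults are taken from app.py, and importing the module assumes the Space's assets and credentials are in place):

```python
import json

from interview_simulator import stream_interview

for raw_event in stream_interview("Patient", "unknown condition"):
    event = json.loads(raw_event)
    if event.get("event") == "end":
        break
    # speakers seen above: "interviewer thinking", "interviewer", "patient", "report"
    print(f'{event["speaker"]}: {event.get("text", "")[:80]}')
```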
medgemma.py CHANGED
@@ -12,18 +12,21 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  # MedGemma endpoint
16
  import requests
 
17
  from auth import create_credentials, get_access_token_refresh_if_needed
18
- import os
19
  from cache import cache
20
 
21
- _endpoint_url = os.environ.get('GCP_MEDGEMMA_ENDPOINT')
22
 
23
  # Create credentials
24
- secret_key_json = os.environ.get('GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY')
25
  medgemma_credentials = create_credentials(secret_key_json)
26
 
 
27
  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints.chat/completions
28
  @cache.memoize()
29
  def medgemma_get_text_response(
@@ -36,7 +39,7 @@ def medgemma_get_text_response(
36
  stop: list[str] | str | None = None,
37
  frequency_penalty: float | None = None,
38
  presence_penalty: float | None = None,
39
- model: str="tgi"
40
  ):
41
  """
42
  Makes a chat completion request to the configured LLM API (OpenAI-compatible).
@@ -47,26 +50,31 @@ def medgemma_get_text_response(
47
  }
48
 
49
  # Based on the openai format
50
- payload = {
51
- "messages": messages,
52
- "max_tokens": max_tokens
53
- }
54
-
55
-
56
- if temperature is not None: payload["temperature"] = temperature
57
- if top_p is not None: payload["top_p"] = top_p
58
- if seed is not None: payload["seed"] = seed
59
- if stop is not None: payload["stop"] = stop
60
- if frequency_penalty is not None: payload["frequency_penalty"] = frequency_penalty
61
- if presence_penalty is not None: payload["presence_penalty"] = presence_penalty
62
 
63
 
64
- response = requests.post(_endpoint_url, headers=headers, json=payload, stream=stream, timeout=60)
 
 
65
  try:
66
  response.raise_for_status()
67
  return response.json()["choices"][0]["message"]["content"]
68
  except requests.exceptions.JSONDecodeError:
69
  # Log the problematic response for easier debugging in the future.
70
- print(f"Error: Failed to decode JSON from MedGemma. Status: {response.status_code}, Response: {response.text}")
 
 
71
  # Re-raise the exception so the caller knows something went wrong.
72
  raise
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import os
16
+
17
  # MedGemma endpoint
18
  import requests
19
+
20
  from auth import create_credentials, get_access_token_refresh_if_needed
 
21
  from cache import cache
22
 
23
+ _endpoint_url = os.environ.get("GCP_MEDGEMMA_ENDPOINT")
24
 
25
  # Create credentials
26
+ secret_key_json = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
27
  medgemma_credentials = create_credentials(secret_key_json)
28
 
29
+
30
  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints.chat/completions
31
  @cache.memoize()
32
  def medgemma_get_text_response(
 
39
  stop: list[str] | str | None = None,
40
  frequency_penalty: float | None = None,
41
  presence_penalty: float | None = None,
42
+ model: str = "tgi",
43
  ):
44
  """
45
  Makes a chat completion request to the configured LLM API (OpenAI-compatible).
 
50
  }
51
 
52
  # Based on the openai format
53
+ payload = {"messages": messages, "max_tokens": max_tokens}
54
 
55
+ if temperature is not None:
56
+ payload["temperature"] = temperature
57
+ if top_p is not None:
58
+ payload["top_p"] = top_p
59
+ if seed is not None:
60
+ payload["seed"] = seed
61
+ if stop is not None:
62
+ payload["stop"] = stop
63
+ if frequency_penalty is not None:
64
+ payload["frequency_penalty"] = frequency_penalty
65
+ if presence_penalty is not None:
66
+ payload["presence_penalty"] = presence_penalty
67
 
68
+ response = requests.post(
69
+ _endpoint_url, headers=headers, json=payload, stream=stream, timeout=60
70
+ )
71
  try:
72
  response.raise_for_status()
73
  return response.json()["choices"][0]["message"]["content"]
74
  except requests.exceptions.JSONDecodeError:
75
  # Log the problematic response for easier debugging in the future.
76
+ print(
77
+ f"Error: Failed to decode JSON from MedGemma. Status: {response.status_code}, Response: {response.text}"
78
+ )
79
  # Re-raise the exception so the caller knows something went wrong.
80
  raise
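The call shape is unchanged by the reformat: a list of role/content messages plus optional OpenAI-style sampling knobs. A minimal sketch, with the message structure taken from evaluation.py and interview_simulator.py in this same commit (importing the module requires `GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY` and `GCP_MEDGEMMA_ENDPOINT` to be set):

```python
from medgemma import medgemma_get_text_response

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a concise medical assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "List two red-flag headache symptoms."}]},
]
print(medgemma_get_text_response(messages, temperature=0.1, max_tokens=256))
```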