Vaishnav Muraleedharan committed
Commit 9b39b7c · 1 Parent(s): 88c6024

chore: format code to maintain consistent style
Files changed:
- __init__.py +0 -1
- app.py +31 -11
- auth.py +53 -47
- cache.py +19 -11
- evaluation.py +21 -21
- gemini.py +16 -21
- gemini_tts.py +63 -28
- interview_simulator.py +134 -110
- medgemma.py +26 -18
 
    	
__init__.py CHANGED
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
    	
app.py CHANGED
@@ -12,18 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
-
+import json
+import os
+import re
+import time
+
+from flask import (
+    Flask,
+    Response,
+    jsonify,
+    request,
+    send_file,
+    send_from_directory,
+    stream_with_context,
+)
 from flask_cors import CORS
-
+
+from cache import create_cache_zip
+from evaluation import evaluate_report, evaluation_prompt
 from gemini import gemini_get_text_response
 from interview_simulator import stream_interview
-from cache import create_cache_zip
 from medgemma import medgemma_get_text_response

-app = Flask(
+app = Flask(
+    __name__,
+    static_folder=os.environ.get("FRONTEND_BUILD", "frontend/build"),
+    static_url_path="/",
+)
 CORS(app, resources={r"/api/*": {"origins": "http://localhost:3000"}})

+
 @app.route("/")
 def serve():
     """Serves the main index.html file."""
@@ -35,7 +53,7 @@ def stream_conversation():
     """Streams the conversation with the interview simulator."""
     patient = request.args.get("patient", "Patient")
     condition = request.args.get("condition", "unknown condition")
-
+
     def generate():
         try:
             for message in stream_interview(patient, condition):
@@ -43,9 +61,10 @@ def stream_conversation():
         except Exception as e:
             yield f"data: Error: {str(e)}\n\n"
             raise e
-
+
     return Response(stream_with_context(generate()), mimetype="text/event-stream")

+
 @app.route("/api/evaluate_report", methods=["POST"])
 def evaluate_report_call():
     """Evaluates the provided medical report."""
@@ -55,10 +74,10 @@ def evaluate_report_call():
         return jsonify({"error": "Report is required"}), 400
     condition = data.get("condition", "")
     if not condition:
-        return jsonify({"error": "Condition is required"}), 400 
-
+        return jsonify({"error": "Condition is required"}), 400
+
     evaluation_text = evaluate_report(report, condition)
-
+
     return jsonify({"evaluation": evaluation_text})


@@ -81,6 +100,7 @@ def static_proxy(path):
         return send_from_directory(app.static_folder, path)
     else:
         return send_from_directory(app.static_folder, "index.html")
-
+
+
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860, threaded=True)
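For reference, a minimal client sketch for the /api/evaluate_report route above. The URL path, method, JSON field names, and port 7860 come from the diff; the report text and condition value are hypothetical placeholders.

import requests

resp = requests.post(
    "http://localhost:7860/api/evaluate_report",
    json={
        "report": "Patient reports intermittent chest pain on exertion.",  # hypothetical
        "condition": "angina",  # hypothetical
    },
)
resp.raise_for_status()  # a missing "report" or "condition" yields a 400 with an "error" key
print(resp.json()["evaluation"])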
    	
auth.py CHANGED
@@ -12,66 +12,72 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import datetime
-
+import json
+
 import google.auth.transport.requests
+from google.oauth2 import service_account
+

 def create_credentials(secret_key_json) -> service_account.Credentials:
-
+    """Creates Google Cloud credentials from the provided service account key.

-
-
+    Returns:
+        service_account.Credentials: The created credentials object.

-
-
-
-
+    Raises:
+        ValueError: If the environment variable is not set or is empty, or if the
+            JSON format is invalid.
+    """

-  if not secret_key_json:
-    raise ValueError("Userdata variable 'GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY' is not set or is empty.")
-  try:
-    service_account_info = json.loads(secret_key_json)
-  except (SyntaxError, ValueError) as e:
-    raise ValueError("Invalid service account key JSON format.") from e
-  return service_account.Credentials.from_service_account_info(
-    service_account_info,
-    scopes=['https://www.googleapis.com/auth/cloud-platform']
-  )
+    if not secret_key_json:
+        raise ValueError(
+            "Userdata variable 'GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY' is not set or is empty."
+        )
+    try:
+        service_account_info = json.loads(secret_key_json)
+    except (SyntaxError, ValueError) as e:
+        raise ValueError("Invalid service account key JSON format.") from e
+    return service_account.Credentials.from_service_account_info(
+        service_account_info, scopes=["https://www.googleapis.com/auth/cloud-platform"]
+    )

-def refresh_credentials(
-
-
+
+def refresh_credentials(
+    credentials: service_account.Credentials,
+) -> service_account.Credentials:
+    """Refreshes the provided Google Cloud credentials if they are about to expire
+      (within 5 minutes) or if they don't have an expiry time set.

-
-
+    Args:
+        credentials: The credentials object to refresh.

-
-
-
-
-
-
-
-
-
+    Returns:
+        service_account.Credentials: The refreshed credentials object.
+    """
+    if credentials.expiry:
+        expiry_time = credentials.expiry.replace(tzinfo=datetime.timezone.utc)
+        # Calculate the time remaining until expiration
+        time_remaining = expiry_time - datetime.datetime.now(datetime.timezone.utc)
+        # Check if the token is about to expire (e.g., within 5 minutes)
+        if time_remaining < datetime.timedelta(minutes=5):
+            request = google.auth.transport.requests.Request()
+            credentials.refresh(request)
+    else:
+        # If no expiry is set, always attempt to refresh (e.g., for certain credential types)
         request = google.auth.transport.requests.Request()
         credentials.refresh(request)
-
-    # If no expiry is set, always attempt to refresh (e.g., for certain credential types)
-    request = google.auth.transport.requests.Request()
-    credentials.refresh(request)
-  return credentials
+    return credentials

-def get_access_token_refresh_if_needed(credentials: service_account.Credentials) -> str:
-  """Gets the access token from the credentials, refreshing them if needed.
+
+def get_access_token_refresh_if_needed(credentials: service_account.Credentials) -> str:
+    """Gets the access token from the credentials, refreshing them if needed.

-
-
+    Args:
+        credentials: The credentials object.

-
-
-  """
-  credentials = refresh_credentials(credentials)
-  return credentials.token
+    Returns:
+        str: The access token.
+    """
+    credentials = refresh_credentials(credentials)
+    return credentials.token
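A minimal usage sketch for the auth helpers above, assuming the service account key JSON is supplied through a variable named after the 'GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY' userdata variable mentioned in the error message.

import os

from auth import create_credentials, get_access_token_refresh_if_needed

creds = create_credentials(os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY"))
# Refreshes automatically when less than 5 minutes of validity remain.
token = get_access_token_refresh_if_needed(creds)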
    	
cache.py CHANGED
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import logging
 import os
 import shutil
 import tempfile
 import zipfile
-
+
+from diskcache import Cache

 cache = Cache(os.environ.get("CACHE_DIR", "/cache"))
 # Print cache statistics after loading
@@ -28,16 +29,17 @@ try:
 except Exception as e:
     print(f"Could not retrieve cache statistics: {e}")

+
 def create_cache_zip():
     temp_dir = tempfile.gettempdir()
-    base_name = os.path.join(temp_dir, "cache_archive") 
+    base_name = os.path.join(temp_dir, "cache_archive")  # A more descriptive name
     archive_path = base_name + ".zip"
     cache_directory = os.environ.get("CACHE_DIR", "/cache")
-
+
     if not os.path.isdir(cache_directory):
         logging.error(f"Cache directory not found at {cache_directory}")
         return None, f"Cache directory not found on server: {cache_directory}"
-
+
     logging.info("Forcing a cache checkpoint for safe backup...")
     try:
         # Open and immediately close a connection.
@@ -45,15 +47,19 @@ def create_cache_zip():
         # into the main .db file, ensuring the on-disk files are consistent.
         with Cache(cache_directory) as temp_cache:
             temp_cache.close()
-
+
         # Clean up temporary files before archiving.
-        tmp_path = os.path.join(cache_directory, 
+        tmp_path = os.path.join(cache_directory, "tmp")
         if os.path.isdir(tmp_path):
             logging.info(f"Removing temporary cache directory: {tmp_path}")
             shutil.rmtree(tmp_path)

-        logging.info(
-
+        logging.info(
+            f"Checkpoint complete. Creating zip archive of {cache_directory} to {archive_path}"
+        )
+        with zipfile.ZipFile(
+            archive_path, "w", zipfile.ZIP_DEFLATED, compresslevel=9
+        ) as zipf:
            for root, _, files in os.walk(cache_directory):
                for file in files:
                    file_path = os.path.join(root, file)
@@ -61,7 +67,9 @@ def create_cache_zip():
                 zipf.write(file_path, arcname)
         logging.info("Zip archive created successfully.")
         return archive_path, None
-
+
     except Exception as e:
-        logging.error(
+        logging.error(
+            f"Error creating zip archive of cache directory: {e}", exc_info=True
+        )
         return None, f"Error creating zip archive: {e}"
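A small sketch of the (archive_path, error) return convention that create_cache_zip follows above; exactly one of the two values is expected to be None.

from cache import create_cache_zip

archive_path, error = create_cache_zip()
if error:
    print(f"Backup failed: {error}")
else:
    print(f"Cache archived at {archive_path}")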
    	
evaluation.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.

 import re
+
 from medgemma import medgemma_get_text_response


@@ -40,30 +41,29 @@ REPORT TEMPLATE START
 REPORT TEMPLATE END
 """

+
 def evaluate_report(report, condition):
     """Evaluate the pre-visit report based on the condition using MedGemma LLM."""
-    evaluation_text = medgemma_get_text_response(
-
-
-
-
-            "type": "text",
-
-
-
-
-
-
-
-
-
-           }
-       ]
-    },        
-  ])
+    evaluation_text = medgemma_get_text_response(
+        [
+            {
+                "role": "system",
+                "content": [
+                    {"type": "text", "text": f"{evaluation_prompt(condition)}"}
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": f"Here is the report text:\n{report}"}
+                ],
+            },
+        ]
+    )

     # Remove any LLM "thinking" blocks (special tokens sometimes present in output)
-    evaluation_text = re.sub(
+    evaluation_text = re.sub(
+        r"<unused94>.*?<unused95>", "", evaluation_text, flags=re.DOTALL
+    )

     return evaluation_text
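A quick illustration of the "thinking"-block stripping introduced above, assuming the model brackets its hidden reasoning between the <unused94> and <unused95> tokens.

import re

raw = "<unused94>model reasoning...<unused95>Overall the report is complete."
clean = re.sub(r"<unused94>.*?<unused95>", "", raw, flags=re.DOTALL)
print(clean)  # -> "Overall the report is complete."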
    	
gemini.py CHANGED
@@ -13,47 +13,42 @@
 # limitations under the License.

 import os
+
 import requests
+
 from cache import cache  # new import replacing duplicate cache initialization

 GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

+
 # Decorate the function to cache its results indefinitely.
 @cache.memoize()
-def gemini_get_text_response(
-
-
-
-
-
+def gemini_get_text_response(
+    prompt: str,
+    stop_sequences: list = None,
+    temperature: float = 0.1,
+    max_output_tokens: int = 4000,
+    top_p: float = 0.8,
+    top_k: int = 10,
+):
     """
     Makes a text generation request to the Gemini API.
     """

     api_url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={GEMINI_API_KEY}"
-    headers = {
-        'Content-Type': 'application/json'
-    }
+    headers = {"Content-Type": "application/json"}

     data = {
-        "contents": [
-            {
-                "parts": [
-                    {
-                        "text": prompt
-                    }
-                ]
-            }
-        ],
+        "contents": [{"parts": [{"text": prompt}]}],
         "generationConfig": {
             "stopSequences": stop_sequences or ["Title"],
             "temperature": temperature,
             "maxOutputTokens": max_output_tokens,
             "topP": top_p,
-            "topK": top_k
-        }
+            "topK": top_k,
+        },
     }

     response = requests.post(api_url, headers=headers, json=data)
     response.raise_for_status()  # Raise an exception for bad status codes
-    return response.json()["candidates"][0]["content"]["parts"][0]["text"]
+    return response.json()["candidates"][0]["content"]["parts"][0]["text"]
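A minimal call sketch for the reformatted helper above; the parameter names and defaults are taken from the new signature, and the prompt text is a placeholder.

from gemini import gemini_get_text_response

text = gemini_get_text_response(
    "Summarize the purpose of a pre-visit report in one sentence.",  # placeholder prompt
    temperature=0.2,  # overrides the 0.1 default; top_p and top_k keep their defaults
)
print(text)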
    	
gemini_tts.py CHANGED
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import 
+
+
 import os
-import struct
 import re
-import 
-
+
+
+

 # Add these imports for MP3 conversion
 from pydub import AudioSegment
-
+
+

 # --- Constants ---
 GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
@@ -30,12 +32,16 @@ TTS_MODEL = "gemini-2.5-flash-preview-tts"
 DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

 # --- Configuration ---
-logging.basicConfig(
+
+
+

 genai.configure(api_key=GEMINI_API_KEY)

+
 class TTSGenerationError(Exception):
     """Custom exception for TTS generation failures."""
+
     pass


@@ -46,7 +52,7 @@ def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
     e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
     """
     bits_per_sample = 16  # Default
-    rate = 24000 
+

     parts = mime_type.split(";")
     for param in parts:
@@ -56,15 +62,16 @@ def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
             rate_str = param.split("=", 1)[1]
             rate = int(rate_str)
         except (ValueError, IndexError):
-            pass 
-        elif re.match(r"audio/l\d+", param): 
-
-            bits_str = param.split("l",1)[1]
+
+
+
+
             bits_per_sample = int(bits_str)
-
-            pass 
+
+
     return {"bits_per_sample": bits_per_sample, "rate": rate}

+
 def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
     """
     Generates a WAV file header for the given raw audio data and parameters.
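A quick check of parse_audio_mime_type on the default MIME type; the expected mapping is stated in the function's docstring above.

from gemini_tts import parse_audio_mime_type

print(parse_audio_mime_type("audio/L16;rate=24000"))
# -> {"bits_per_sample": 16, "rate": 24000}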
         @@ -82,13 +89,26 @@ def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes: 
     | 
|
| 82 | 
         | 
| 83 | 
         
             
                header = struct.pack(
         
     | 
| 84 | 
         
             
                    "<4sI4s4sIHHIIHH4sI",
         
     | 
| 85 | 
         
            -
                    b"RIFF", 
     | 
| 86 | 
         
            -
                     
     | 
| 87 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 88 | 
         
             
                )
         
     | 
| 89 | 
         
             
                return header + audio_data
         
     | 
| 
         | 
|
| 
         | 
|
| 90 | 
         
             
            # --- End of helper functions ---
         
     | 
| 91 | 
         | 
| 
         | 
|
| 92 | 
         
             
            def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
         
     | 
| 93 | 
         
             
                """
         
     | 
| 94 | 
         
             
                Synthesizes English text using the Gemini API via the google-genai library.
         
     | 
| 
         @@ -109,11 +129,9 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte 
     | 
|
| 109 | 
         
             
                        "response_modalities": ["AUDIO"],
         
     | 
| 110 | 
         
             
                        "speech_config": {
         
     | 
| 111 | 
         
             
                            "voice_config": {
         
     | 
| 112 | 
         
            -
                                "prebuilt_voice_config": {
         
     | 
| 113 | 
         
            -
                                    "voice_name": gemini_voice_name
         
     | 
| 114 | 
         
            -
                                }
         
     | 
| 115 | 
         
             
                            }
         
     | 
| 116 | 
         
            -
                        }
         
     | 
| 117 | 
         
             
                    }
         
     | 
| 118 | 
         | 
| 119 | 
         
             
                    response = model.generate_content(
         
     | 
| 
         @@ -137,8 +155,11 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[byte 
     | 
|
| 137 | 
         
             
                # --- Audio processing ---
         
     | 
| 138 | 
         
             
                if final_mime_type:
         
     | 
| 139 | 
         
             
                    final_mime_type_lower = final_mime_type.lower()
         
     | 
| 140 | 
         
            -
                    needs_wav_conversion = any( 
     | 
| 141 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 142 | 
         | 
| 143 | 
         
             
                    if needs_wav_conversion:
         
     | 
| 144 | 
         
             
                        processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)
         
     | 
@@ -147,7 +168,10 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
            processed_audio_data = audio_data_bytes
            processed_audio_mime = final_mime_type
        else:
-            logging.warning("MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).", DEFAULT_RAW_AUDIO_MIME)
            processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
            processed_audio_mime = "audio/wav"
@@ -155,7 +179,9 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
    if processed_audio_data:
        try:
            # Load audio into AudioSegment
-            audio_segment = AudioSegment.from_file(io.BytesIO(processed_audio_data), format="wav")
            mp3_buffer = io.BytesIO()
            audio_segment.export(mp3_buffer, format="mp3")
            mp3_bytes = mp3_buffer.getvalue()
@@ -169,11 +195,15 @@ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
        logging.error(error_message)
        raise TTSGenerationError(error_message)

# Always create the memoized function first, so we can access its .key() method
_memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)

if GENERATE_SPEECH:
-    def synthesize_gemini_tts_with_error_handling(*args, **kwargs) -> tuple[bytes | None, str | None]:
        """
        A wrapper for the memoized TTS function that catches errors and returns (None, None).
        This makes the audio generation more resilient to individual failures.
@@ -183,7 +213,10 @@ if GENERATE_SPEECH:
            return _memoized_tts_func(*args, **kwargs)
        except TTSGenerationError as e:
            # If generation fails, log the error and return None, None.
-            logging.error("Handled TTS Generation Error: %s. Continuing without audio for this segment.", e)
            return None, None

    synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
@@ -206,7 +239,9 @@ else:
            return result  # Cache hit

        # Cache miss
-        logging.info("GENERATE_SPEECH is false and no cached result found for key: %s", key)
        return None, None

-    synthesize_gemini_tts = read_only_synthesize_gemini_tts
# See the License for the specific language governing permissions and
# limitations under the License.

+import io
+import logging
import os
import re
+import struct
+
+import google.generativeai as genai

# Add these imports for MP3 conversion
from pydub import AudioSegment
+
+from cache import cache

# --- Constants ---
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# --- Configuration ---
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)

genai.configure(api_key=GEMINI_API_KEY)
+
class TTSGenerationError(Exception):
    """Custom exception for TTS generation failures."""

    pass

    e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
    """
    bits_per_sample = 16  # Default
    rate = 24000  # Default

    parts = mime_type.split(";")
    for param in parts:

            rate_str = param.split("=", 1)[1]
            rate = int(rate_str)
        except (ValueError, IndexError):
+            pass  # Keep default if parsing fails
+        elif re.match(r"audio/l\d+", param):  # Matches audio/L<digits>
+            try:
+                bits_str = param.split("l", 1)[1]
                bits_per_sample = int(bits_str)
+            except (ValueError, IndexError):
+                pass  # Keep default
    return {"bits_per_sample": bits_per_sample, "rate": rate}
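For orientation, the hunk above parses MIME parameters into raw-PCM settings. Here is a self-contained sketch of the same logic (a local copy with a hypothetical name, since the helper's def line is not visible in this hunk), checked against the example from its docstring:

import re

def parse_audio_mime(mime_type: str) -> dict:
    # Mirrors the hunk above: split on ";", pick up rate= and audio/L<bits>.
    bits_per_sample, rate = 16, 24000  # Same defaults as above
    for param in mime_type.split(";"):
        param = param.strip().lower()
        if param.startswith("rate="):
            try:
                rate = int(param.split("=", 1)[1])
            except (ValueError, IndexError):
                pass  # Keep default if parsing fails
        elif re.match(r"audio/l\d+", param):
            try:
                bits_per_sample = int(param.split("l", 1)[1])
            except (ValueError, IndexError):
                pass  # Keep default
    return {"bits_per_sample": bits_per_sample, "rate": rate}

assert parse_audio_mime("audio/L16;rate=24000") == {"bits_per_sample": 16, "rate": 24000}
assert parse_audio_mime("audio/L24;rate=48000") == {"bits_per_sample": 24, "rate": 48000}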

def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """
    Generates a WAV file header for the given raw audio data and parameters.

    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
+        b"RIFF",
+        chunk_size,
+        b"WAVE",
+        b"fmt ",
+        16,
+        1,
+        num_channels,
+        sample_rate,
+        byte_rate,
+        block_align,
+        bits_per_sample,
+        b"data",
+        data_size,
    )
    return header + audio_data
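A small sanity check for the header layout above (a standalone sketch assuming 16-bit mono PCM at 24 kHz; the derived fields follow the standard 44-byte PCM WAV header, which is what the struct format string encodes):

import struct

sample_rate, bits_per_sample, num_channels = 24000, 16, 1
audio_data = b"\x00\x00" * sample_rate           # One second of silence
data_size = len(audio_data)
block_align = num_channels * bits_per_sample // 8
byte_rate = sample_rate * block_align
chunk_size = 36 + data_size                      # 4 + (8 + 16) + (8 + data_size)

header = struct.pack(
    "<4sI4s4sIHHIIHH4sI",
    b"RIFF", chunk_size, b"WAVE", b"fmt ", 16, 1,
    num_channels, sample_rate, byte_rate, block_align, bits_per_sample,
    b"data", data_size,
)
wav_bytes = header + audio_data
assert wav_bytes[:4] == b"RIFF" and wav_bytes[8:12] == b"WAVE"
assert len(header) == 44                         # Canonical PCM WAV header size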


# --- End of helper functions ---


def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
    """
    Synthesizes English text using the Gemini API via the google-genai library.

            "response_modalities": ["AUDIO"],
            "speech_config": {
                "voice_config": {
+                    "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                }
+            },
        }

        response = model.generate_content(
    # --- Audio processing ---
    if final_mime_type:
        final_mime_type_lower = final_mime_type.lower()
+        needs_wav_conversion = any(
+            p in final_mime_type_lower for p in ("audio/l16", "audio/l24", "audio/l8")
+        ) or not final_mime_type_lower.startswith(
+            ("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus")
+        )

        if needs_wav_conversion:
            processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)

            processed_audio_data = audio_data_bytes
            processed_audio_mime = final_mime_type
        else:
+            logging.warning(
+                "MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).",
+                DEFAULT_RAW_AUDIO_MIME,
+            )
            processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
            processed_audio_mime = "audio/wav"

    if processed_audio_data:
        try:
            # Load audio into AudioSegment
+            audio_segment = AudioSegment.from_file(
+                io.BytesIO(processed_audio_data), format="wav"
+            )
            mp3_buffer = io.BytesIO()
            audio_segment.export(mp3_buffer, format="mp3")
            mp3_bytes = mp3_buffer.getvalue()

        logging.error(error_message)
        raise TTSGenerationError(error_message)

# Always create the memoized function first, so we can access its .key() method
_memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)

if GENERATE_SPEECH:
+
+    def synthesize_gemini_tts_with_error_handling(
+        *args, **kwargs
+    ) -> tuple[bytes | None, str | None]:
        """
        A wrapper for the memoized TTS function that catches errors and returns (None, None).
        This makes the audio generation more resilient to individual failures.

            return _memoized_tts_func(*args, **kwargs)
        except TTSGenerationError as e:
            # If generation fails, log the error and return None, None.
+            logging.error(
+                "Handled TTS Generation Error: %s. Continuing without audio for this segment.",
+                e,
+            )
            return None, None

    synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
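The net effect of this wrapper is that callers never need their own try/except around TTS: on failure it logs and the turn degrades to text-only. A hypothetical caller (sketch; the greeting string and voice choice are illustrative, not from this commit):

audio_data, mime_type = synthesize_gemini_tts("Hello, how are you feeling today?", "Aoede")
if audio_data is None:
    # A TTSGenerationError was caught inside the wrapper; fall back to text-only output.
    print("No audio for this segment; continuing with text only.")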
            return result  # Cache hit

        # Cache miss
+        logging.info(
+            "GENERATE_SPEECH is false and no cached result found for key: %s", key
+        )
        return None, None

+    synthesize_gemini_tts = read_only_synthesize_gemini_tts
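The read-only branch above is the standard "serve hits, never compute" degradation: with GENERATE_SPEECH off, a cache miss yields (None, None) instead of triggering a Gemini call. A library-agnostic sketch of the same pattern (a plain dict stands in for the project's flask-caching setup, and the helper name is hypothetical):

from typing import Callable, Optional, Tuple

def make_read_only_tts(cache_store: dict, fn: Callable) -> Callable:
    """Wrap fn so lookups hit cache_store but a miss never triggers computation."""
    def read_only(*args) -> Tuple[Optional[bytes], Optional[str]]:
        key = (fn.__name__,) + args
        if key in cache_store:
            return cache_store[key]   # Cache hit: reuse previously generated audio
        return None, None             # Cache miss: refuse to synthesize
    return read_only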
    	
interview_simulator.py
CHANGED

@@ -12,70 +12,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
-import re
import os
-import base64

from gemini import gemini_get_text_response
-from medgemma import medgemma_get_text_response
from gemini_tts import synthesize_gemini_tts

INTERVIEWER_VOICE = "Aoede"
def read_symptoms_json():
    # Load the list of symptoms for each condition from a JSON file
-    with open("symptoms.json", "r") as f:
        return json.load(f)
def read_patient_and_conditions_json():
    # Load all patient and condition data from the frontend assets
-    with open(os.path.join(os.environ.get("FRONTEND_BUILD", "frontend/build"), "assets", "patients_and_conditions.json"), "r") as f:
        return json.load(f)

def get_patient(patient_name):
    """Helper function to locate a patient record by name. Raises StopIteration if not found."""
    return next(p for p in PATIENTS if p["name"] == patient_name)
def read_fhir_json(patient):
    # Load the FHIR (EHR) JSON file for a given patient
-    with open(os.path.join(os.environ.get("FRONTEND_BUILD", "frontend/build"), patient["fhirFile"].lstrip("/")), "r") as f:
        return json.load(f)
def get_ehr_summary_per_patient(patient_name):
    # Returns a concise EHR summary for the patient, using LLM if not already cached
    patient = get_patient(patient_name)
    if patient.get("ehr_summary"):
        return patient["ehr_summary"]
    # Use MedGemma to summarize the EHR for the patient
-    ehr_summary = medgemma_get_text_response([
-        {
-            "role": "system",
-            "content": [
-                {
-                    "type": "text",
-                    "text": f"""You are a medical assistant summarizing the EHR (FHIR) records for the patient {patient_name}.
                Provide a concise summary of the patient's medical history, including any existing conditions, medications, and relevant past treatments.
-                Do not include personal opinions or assumptions, only factual information."""
-                }
-            ]
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": json.dumps(read_fhir_json(patient))}
-            ]
-        }
-    ])
    patient["ehr_summary"] = ehr_summary
    return ehr_summary

PATIENTS = read_patient_and_conditions_json()["patients"]
SYMPTOMS = read_symptoms_json()
-

def patient_roleplay_instructions(patient_name, condition_name, previous_answers):
    """
    Generates structured instructions for the LLM to roleplay as a patient, including persona, scenario, and symptom logic.
@@ -120,6 +139,7 @@ def patient_roleplay_instructions(patient_name, condition_name, previous_answers):
    ---
    """

def interviewer_roleplay_instructions(patient_name):
    # Returns detailed instructions for the LLM to roleplay as the interviewer/clinical assistant
    return f"""
@@ -153,6 +173,7 @@ def interviewer_roleplay_instructions(patient_name):
    3.  **End Interview:** You MUST continue the interview until you have asked 20 questions OR the patient is unable to provide more information. When the interview is complete, you MUST conclude by printing this exact phrase: "Thank you for answering my questions. I have everything needed to prepare a report for your visit. End interview."
    """

def report_writer_instructions(patient_name: str) -> str:
    """
    Generates the system prompt with clear instructions, role, and constraints for the LLM.
@@ -202,7 +223,10 @@ The final output MUST be ONLY the full, updated Markdown medical report.
DO NOT include any introductory phrases, explanations, or any text other than the report itself.
</output_format>"""

-def write_report(patient_name: str, interview_text: str, existing_report: str = None) -> str:
    """
    Constructs the full prompt, sends it to the LLM, and processes the response.
    This function handles both the initial creation and subsequent updates of a report.
@@ -212,7 +236,7 @@ def write_report(patient_name: str, interview_text: str, existing_report: str = None) -> str:

    # If no existing report is provided, load a default template from a string.
    if not existing_report:
-        with open("report_template.txt", "r") as f:
            existing_report = f.read()

    # Construct the user prompt with the specific task and data
@@ -237,149 +261,149 @@ Now, generate the complete and updated medical report based on all system and user messages.

    # Assemble the full message payload for the LLM API
    messages = [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": instructions}]
-        },
-        {
-            "role": "user",
-            "content": [{"type": "text", "text": user_prompt}]
-        }
    ]

    report = medgemma_get_text_response(messages)
-    cleaned_report = re.sub(r"<unused94>.*?</unused95>", "", report, flags=re.DOTALL)
    cleaned_report = cleaned_report.strip()

    # The LLM sometimes wraps the markdown report in a markdown code block.
    # This regex checks if the entire string is a code block and extracts the content.
-    match = re.match(r"^\s*```(?:markdown)?\s*(.*?)\s*```\s*$", cleaned_report, re.DOTALL | re.IGNORECASE)
    if match:
        cleaned_report = match.group(1)

    return cleaned_report.strip()
def stream_interview(patient_name, condition_name):
-    print(f"Starting interview simulation for patient: {patient_name}, condition: {condition_name}")
    # Prepare roleplay instructions and initial dialog (using existing helper functions)
    interviewer_instructions = interviewer_roleplay_instructions(patient_name)

    # Determine voices for TTS
    patient = get_patient(patient_name)
    patient_voice = patient["voice"]

    dialog = [
        {
            "role": "system",
-            "content": [
-                {
-                    "type": "text",
-                    "text": interviewer_instructions
-                }
-            ]
        },
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "text",
-                    "text": "start interview"
-                }
-            ]
-        }
    ]
    write_report_text = ""
    full_interview_q_a = ""
    number_of_questions_limit = 30
    for i in range(number_of_questions_limit):
        # Get the next interviewer question from MedGemma
        interviewer_question_text = medgemma_get_text_response(
-            messages=dialog,
-            temperature=0.1,
-            max_tokens=2048,
-            stream=False
        )
        # Process optional "thinking" text (if present in the LLM output)
-        thinking_search = re.search("<unused94>(.+?)<unused95>", interviewer_question_text, re.DOTALL)
        if thinking_search:
            thinking_text = thinking_search.group(1)
-            interviewer_question_text = interviewer_question_text.replace(f"<unused94>{thinking_text}<unused95>", "")
            if i == 0:
                # Only yield the "thinking" summary for the first question
                thinking_text = gemini_get_text_response(
                    f"""Provide a summary of up to 100 words containing only the reasoning and planning from this text,
-                    do not include instructions, use first person: {thinking_text}""")
-                yield json.dumps({
-                    "speaker": "interviewer thinking",
-                    "text": thinking_text
-                })

        # Clean up the text for TTS and display
-        clean_interviewer_text = interviewer_question_text.replace("End interview.", "").strip()
        # Generate audio for the interviewer's question using Gemini TTS
-        audio_data, mime_type = synthesize_gemini_tts(f"Speak in a slightly upbeat and brisk manner, as a friendly clinician: {clean_interviewer_text}", INTERVIEWER_VOICE)
        audio_b64 = None
        if audio_data and mime_type:
            audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"

        # Yield interviewer message (text and audio)
-        yield json.dumps({
-            "speaker": "interviewer",
-            "audio": audio_b64,
-            "text": clean_interviewer_text
-        })

        if "End interview" in interviewer_question_text:
            # End the interview loop if the LLM signals completion
            break
        # Get the patient's response from Gemini (roleplay LLM)
-        patient_response_text = gemini_get_text_response(f"""
        {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}\n\n
-        Question: {interviewer_question_text}""")

        # Generate audio for the patient's response
-        audio_data, mime_type = synthesize_gemini_tts(patient_response_text, patient_voice)
        audio_b64 = None
        if audio_data and mime_type:
            audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"

        # Yield patient message (text and audio)
-        yield json.dumps({
-            "speaker": "patient",
-            "audio": audio_b64,
-            "text": patient_response_text
-        })
        # Track the full Q&A for context in future LLM calls
-        most_recent_q_a = f"Question: {interviewer_question_text}\n" \
-            f"Answer: {patient_response_text}\n"
        # Update the report after each Q&A
-        write_report_text = write_report(patient_name, most_recent_q_a, write_report_text)
        full_interview_q_a += most_recent_q_a
-        yield json.dumps({
-            "speaker": "report",
-            "text": write_report_text
-        })

-    print(f"""
      Patient profile used:
-      {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}""")
    # Add this at the end to signal end of stream
-    yield json.dumps({"event": "end"})
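As an aside on the audio_b64 lines in the hunks above: they build a self-contained data URI that the frontend can hand straight to an audio element, with no separate file endpoint. In isolation (a sketch with dummy bytes standing in for real MP3 data):

import base64

audio_data = b"\x00\x01\x02"                     # Stand-in for real MP3 bytes
mime_type = "audio/mpeg"
audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
assert audio_b64 == "data:audio/mpeg;base64,AAEC"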
# See the License for the specific language governing permissions and
# limitations under the License.

+import base64
import json
import os
+import re

from gemini import gemini_get_text_response
from gemini_tts import synthesize_gemini_tts
+from medgemma import medgemma_get_text_response

INTERVIEWER_VOICE = "Aoede"
+
def read_symptoms_json():
    # Load the list of symptoms for each condition from a JSON file
+    with open("symptoms.json", "r") as f:
        return json.load(f)

+
def read_patient_and_conditions_json():
    # Load all patient and condition data from the frontend assets
+    with open(
+        os.path.join(
+            os.environ.get("FRONTEND_BUILD", "frontend/build"),
+            "assets",
+            "patients_and_conditions.json",
+        ),
+        "r",
+    ) as f:
        return json.load(f)

+
def get_patient(patient_name):
    """Helper function to locate a patient record by name. Raises StopIteration if not found."""
    return next(p for p in PATIENTS if p["name"] == patient_name)

+
def read_fhir_json(patient):
    # Load the FHIR (EHR) JSON file for a given patient
+    with open(
+        os.path.join(
+            os.environ.get("FRONTEND_BUILD", "frontend/build"),
+            patient["fhirFile"].lstrip("/"),
+        ),
+        "r",
+    ) as f:
        return json.load(f)

+
def get_ehr_summary_per_patient(patient_name):
    # Returns a concise EHR summary for the patient, using LLM if not already cached
    patient = get_patient(patient_name)
    if patient.get("ehr_summary"):
        return patient["ehr_summary"]
    # Use MedGemma to summarize the EHR for the patient
+    ehr_summary = medgemma_get_text_response(
+        [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": f"""You are a medical assistant summarizing the EHR (FHIR) records for the patient {patient_name}.
                Provide a concise summary of the patient's medical history, including any existing conditions, medications, and relevant past treatments.
+                Do not include personal opinions or assumptions, only factual information.""",
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": json.dumps(read_fhir_json(patient))}
+                ],
+            },
+        ]
+    )
    patient["ehr_summary"] = ehr_summary
    return ehr_summary
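Because the summary is stashed on the in-memory patient record, only the first call per patient pays for a MedGemma round trip; later calls return the cached string. A behavioral sketch (the patient name is illustrative, not taken from this diff):

summary_first = get_ehr_summary_per_patient("Patient")   # Calls MedGemma, then caches on the dict
summary_again = get_ehr_summary_per_patient("Patient")   # Served from patient["ehr_summary"]
assert summary_first == summary_again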

+
PATIENTS = read_patient_and_conditions_json()["patients"]
SYMPTOMS = read_symptoms_json()
+
+
def patient_roleplay_instructions(patient_name, condition_name, previous_answers):
    """
    Generates structured instructions for the LLM to roleplay as a patient, including persona, scenario, and symptom logic.

    ---
    """

+
def interviewer_roleplay_instructions(patient_name):
    # Returns detailed instructions for the LLM to roleplay as the interviewer/clinical assistant
    return f"""

    3.  **End Interview:** You MUST continue the interview until you have asked 20 questions OR the patient is unable to provide more information. When the interview is complete, you MUST conclude by printing this exact phrase: "Thank you for answering my questions. I have everything needed to prepare a report for your visit. End interview."
    """

+
def report_writer_instructions(patient_name: str) -> str:
    """
    Generates the system prompt with clear instructions, role, and constraints for the LLM.

DO NOT include any introductory phrases, explanations, or any text other than the report itself.
</output_format>"""

+
+def write_report(
+    patient_name: str, interview_text: str, existing_report: str = None
+) -> str:
    """
    Constructs the full prompt, sends it to the LLM, and processes the response.
    This function handles both the initial creation and subsequent updates of a report.

    # If no existing report is provided, load a default template from a string.
    if not existing_report:
+        with open("report_template.txt", "r") as f:
            existing_report = f.read()

    # Construct the user prompt with the specific task and data

    # Assemble the full message payload for the LLM API
    messages = [
+        {"role": "system", "content": [{"type": "text", "text": instructions}]},
+        {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
    ]

    report = medgemma_get_text_response(messages)
+    cleaned_report = re.sub(r"<unused94>.*?</unused95>", "", report, flags=re.DOTALL)
    cleaned_report = cleaned_report.strip()

    # The LLM sometimes wraps the markdown report in a markdown code block.
    # This regex checks if the entire string is a code block and extracts the content.
+    match = re.match(
+        r"^\s*```(?:markdown)?\s*(.*?)\s*```\s*$",
+        cleaned_report,
+        re.DOTALL | re.IGNORECASE,
+    )
    if match:
        cleaned_report = match.group(1)

    return cleaned_report.strip()
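To see what the two-stage cleanup above does, here is the same regex pair applied to a typical wrapped response (a standalone sketch; the sample text is invented):

import re

raw = "<unused94>internal planning...</unused95>\n```markdown\n# Medical Report\n- Stable\n```"
step1 = re.sub(r"<unused94>.*?</unused95>", "", raw, flags=re.DOTALL).strip()
match = re.match(
    r"^\s*```(?:markdown)?\s*(.*?)\s*```\s*$", step1, re.DOTALL | re.IGNORECASE
)
report = match.group(1) if match else step1
assert report == "# Medical Report\n- Stable"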


def stream_interview(patient_name, condition_name):
+    print(
+        f"Starting interview simulation for patient: {patient_name}, condition: {condition_name}"
+    )
    # Prepare roleplay instructions and initial dialog (using existing helper functions)
    interviewer_instructions = interviewer_roleplay_instructions(patient_name)
+
    # Determine voices for TTS
    patient = get_patient(patient_name)
    patient_voice = patient["voice"]
+
    dialog = [
        {
            "role": "system",
+            "content": [{"type": "text", "text": interviewer_instructions}],
        },
+        {"role": "user", "content": [{"type": "text", "text": "start interview"}]},
    ]
+
    write_report_text = ""
    full_interview_q_a = ""
    number_of_questions_limit = 30
    for i in range(number_of_questions_limit):
        # Get the next interviewer question from MedGemma
        interviewer_question_text = medgemma_get_text_response(
+            messages=dialog, temperature=0.1, max_tokens=2048, stream=False
        )
        # Process optional "thinking" text (if present in the LLM output)
+        thinking_search = re.search(
+            "<unused94>(.+?)<unused95>", interviewer_question_text, re.DOTALL
+        )
| 316 | 
         
             
                    if thinking_search:
         
     | 
| 317 | 
         
             
                        thinking_text = thinking_search.group(1)
         
     | 
| 318 | 
         
            +
                        interviewer_question_text = interviewer_question_text.replace(
         
     | 
| 319 | 
         
            +
                            f"<unused94>{thinking_text}<unused95>", ""
         
     | 
| 320 | 
         
            +
                        )
         
     | 
| 321 | 
         
             
                        if i == 0:
         
     | 
| 322 | 
         
             
                            # Only yield the "thinking" summary for the first question
         
     | 
| 323 | 
         
             
                            thinking_text = gemini_get_text_response(
         
     | 
| 324 | 
         
             
                                f"""Provide a summary of up to 100 words containing only the reasoning and planning from this text,
         
     | 
| 325 | 
         
            +
                                do not include instructions, use first person: {thinking_text}"""
         
     | 
| 326 | 
         
            +
                            )
         
     | 
| 327 | 
         
            +
                            yield json.dumps(
         
     | 
| 328 | 
         
            +
                                {"speaker": "interviewer thinking", "text": thinking_text}
         
     | 
| 329 | 
         
            +
                            )
         
     | 
| 330 | 
         | 
| 331 | 
         
             
                    # Clean up the text for TTS and display
         
     | 
| 332 | 
         
            +
                    clean_interviewer_text = interviewer_question_text.replace(
         
     | 
| 333 | 
         
            +
                        "End interview.", ""
         
     | 
| 334 | 
         
            +
                    ).strip()
         
     | 
| 335 | 
         | 
| 336 | 
         
             
                    # Generate audio for the interviewer's question using Gemini TTS
         
     | 
| 337 | 
         
            +
                    audio_data, mime_type = synthesize_gemini_tts(
         
     | 
| 338 | 
         
            +
                        f"Speak in a slightly upbeat and brisk manner, as a friendly clinician: {clean_interviewer_text}",
         
     | 
| 339 | 
         
            +
                        INTERVIEWER_VOICE,
         
     | 
| 340 | 
         
            +
                    )
         
     | 
| 341 | 
         
             
                    audio_b64 = None
         
     | 
| 342 | 
         
             
                    if audio_data and mime_type:
         
     | 
| 343 | 
         
             
                        audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
         
     | 
| 344 | 
         | 
| 345 | 
         
             
                    # Yield interviewer message (text and audio)
         
     | 
| 346 | 
         
            +
                    yield json.dumps(
         
     | 
| 347 | 
         
            +
                        {
         
     | 
| 348 | 
         
            +
                            "speaker": "interviewer",
         
     | 
| 349 | 
         
            +
                            "text": clean_interviewer_text,
         
     | 
| 350 | 
         
            +
                            "audio": audio_b64,
         
     | 
| 351 | 
         
            +
                        }
         
     | 
| 352 | 
         
            +
                    )
         
     | 
| 353 | 
         
            +
                    dialog.append(
         
     | 
| 354 | 
         
            +
                        {
         
     | 
| 355 | 
         
            +
                            "role": "assistant",
         
     | 
| 356 | 
         
            +
                            "content": [{"type": "text", "text": interviewer_question_text}],
         
     | 
| 357 | 
         
            +
                        }
         
     | 
| 358 | 
         
            +
                    )
         
     | 
| 359 | 
         
             
                    if "End interview" in interviewer_question_text:
         
     | 
| 360 | 
         
             
                        # End the interview loop if the LLM signals completion
         
     | 
| 361 | 
         
             
                        break
         
     | 
| 362 | 
         | 
| 363 | 
         
             
                    # Get the patient's response from Gemini (roleplay LLM)
         
     | 
| 364 | 
         
            +
                    patient_response_text = gemini_get_text_response(
         
     | 
| 365 | 
         
            +
                        f"""
         
     | 
| 366 | 
         
             
                    {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}\n\n
         
     | 
| 367 | 
         
            +
                    Question: {interviewer_question_text}"""
         
     | 
| 368 | 
         
            +
                    )
         
     | 
| 369 | 
         | 
| 370 | 
         
             
                    # Generate audio for the patient's response
         
     | 
| 371 | 
         
            +
                    audio_data, mime_type = synthesize_gemini_tts(
         
     | 
| 372 | 
         
            +
                        f"Say this in faster speed, using a sick tone: {patient_response_text}",
         
     | 
| 373 | 
         
            +
                        patient_voice,
         
     | 
| 374 | 
         
            +
                    )
         
     | 
| 375 | 
         
             
                    audio_b64 = None
         
     | 
| 376 | 
         
             
                    if audio_data and mime_type:
         
     | 
| 377 | 
         
             
                        audio_b64 = f"data:{mime_type};base64,{base64.b64encode(audio_data).decode('utf-8')}"
         
     | 
| 378 | 
         | 
| 379 | 
         
             
                    # Yield patient message (text and audio)
         
     | 
| 380 | 
         
            +
                    yield json.dumps(
         
     | 
| 381 | 
         
            +
                        {"speaker": "patient", "text": patient_response_text, "audio": audio_b64}
         
     | 
| 382 | 
         
            +
                    )
         
     | 
| 383 | 
         
            +
                    dialog.append(
         
     | 
| 384 | 
         
            +
                        {
         
     | 
| 385 | 
         
            +
                            "role": "user",
         
     | 
| 386 | 
         
            +
                            "content": [{"type": "text", "text": patient_response_text}],
         
     | 
| 387 | 
         
            +
                        }
         
     | 
| 388 | 
         
            +
                    )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 389 | 
         
             
                    # Track the full Q&A for context in future LLM calls
         
     | 
| 390 | 
         
            +
                    most_recent_q_a = (
         
     | 
| 391 | 
         
            +
                        f"Q: {interviewer_question_text}\nA: {patient_response_text}\n"
         
     | 
| 392 | 
         
            +
                    )
         
     | 
| 393 | 
         
            +
                    full_interview_q_a_with_new_q_a = (
         
     | 
| 394 | 
         
            +
                        "PREVIOUS Q&A:\n" + full_interview_q_a + "\nNEW Q&A:\n" + most_recent_q_a
         
     | 
| 395 | 
         
            +
                    )
         
     | 
| 396 | 
         
             
                    # Update the report after each Q&A
         
     | 
| 397 | 
         
            +
                    write_report_text = write_report(
         
     | 
| 398 | 
         
            +
                        patient_name, full_interview_q_a_with_new_q_a, write_report_text
         
     | 
| 399 | 
         
            +
                    )
         
     | 
| 400 | 
         
             
                    full_interview_q_a += most_recent_q_a
         
     | 
| 401 | 
         
            +
                    yield json.dumps({"speaker": "report", "text": write_report_text})
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 402 | 
         | 
| 403 | 
         
            +
                print(
         
     | 
| 404 | 
         
            +
                    f"""Interview simulation completed for patient: {patient_name}, condition: {condition_name}.
         
     | 
| 405 | 
         
             
                      Patient profile used:
         
     | 
| 406 | 
         
            +
                      {patient_roleplay_instructions(patient_name, condition_name, full_interview_q_a)}"""
         
     | 
| 407 | 
         
            +
                )
         
     | 
| 408 | 
         
             
                # Add this at the end to signal end of stream
         
     | 
| 409 | 
         
            +
                yield json.dumps({"event": "end"})
         
     | 
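Each value yielded by stream_interview is a JSON-encoded object: progress messages carry a "speaker" field ("interviewer thinking", "interviewer", "patient", or "report"), the spoken messages also carry an "audio" field with a base64 data URI (or None), and the stream finishes with {"event": "end"}. As a rough, hypothetical sketch of how a caller might consume the generator (the loop below and the placeholder patient/condition names are illustrative and not part of this commit):

import json

from interview_simulator import stream_interview

# Placeholder arguments for illustration only.
for chunk in stream_interview("Patient", "Condition"):
    msg = json.loads(chunk)
    if msg.get("event") == "end":
        # Final sentinel yielded after the interview loop completes.
        break
    speaker = msg["speaker"]
    if speaker == "report":
        print("--- updated report ---")
        print(msg["text"])
    elif speaker == "interviewer thinking":
        print(f"[thinking] {msg['text']}")
    else:
        # "interviewer" / "patient" messages also include msg.get("audio"),
        # a data: URI containing base64-encoded TTS audio (or None).
        print(f"{speaker}: {msg['text']}")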
    	
medgemma.py
CHANGED

@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 # MedGemma endpoint
 import requests
+
 from auth import create_credentials, get_access_token_refresh_if_needed
-import os
 from cache import cache

-_endpoint_url = os.environ.get(
+_endpoint_url = os.environ.get("GCP_MEDGEMMA_ENDPOINT")

 # Create credentials
-secret_key_json = os.environ.get(
+secret_key_json = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
 medgemma_credentials = create_credentials(secret_key_json)

+
 # https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints.chat/completions
 @cache.memoize()
 def medgemma_get_text_response(
@@ -36,7 +39,7 @@ def medgemma_get_text_response(
     stop: list[str] | str | None = None,
     frequency_penalty: float | None = None,
     presence_penalty: float | None = None,
-    model: str="tgi"
+    model: str = "tgi",
 ):
     """
     Makes a chat completion request to the configured LLM API (OpenAI-compatible).
@@ -47,26 +50,31 @@ def medgemma_get_text_response(
     }

     # Based on the openai format
-    payload = {
-                "messages": messages,
-                "max_tokens": max_tokens
-              }
-
-
-    if temperature is not None: payload["temperature"] = temperature
-    if top_p is not None: payload["top_p"] = top_p
-    if seed is not None: payload["seed"] = seed
-    if stop is not None: payload["stop"] = stop
-    if frequency_penalty is not None: payload["frequency_penalty"] = frequency_penalty
-    if presence_penalty is not None: payload["presence_penalty"] = presence_penalty
+    payload = {"messages": messages, "max_tokens": max_tokens}

+    if temperature is not None:
+        payload["temperature"] = temperature
+    if top_p is not None:
+        payload["top_p"] = top_p
+    if seed is not None:
+        payload["seed"] = seed
+    if stop is not None:
+        payload["stop"] = stop
+    if frequency_penalty is not None:
+        payload["frequency_penalty"] = frequency_penalty
+    if presence_penalty is not None:
+        payload["presence_penalty"] = presence_penalty

-    response = requests.post(
+    response = requests.post(
+        _endpoint_url, headers=headers, json=payload, stream=stream, timeout=60
+    )
     try:
         response.raise_for_status()
         return response.json()["choices"][0]["message"]["content"]
     except requests.exceptions.JSONDecodeError:
         # Log the problematic response for easier debugging in the future.
-        print(
+        print(
+            f"Error: Failed to decode JSON from MedGemma. Status: {response.status_code}, Response: {response.text}"
+        )
         # Re-raise the exception so the caller knows something went wrong.
         raise
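Taken together, the helper builds an OpenAI-style chat-completions payload, adds the optional sampling parameters only when they are not None, and POSTs it to the Vertex AI endpoint read from GCP_MEDGEMMA_ENDPOINT. A minimal usage sketch follows; the message text is illustrative, and it assumes the two environment variables above are set so the module-level credentials can be created at import time.

from medgemma import medgemma_get_text_response

# Messages follow the OpenAI chat format used elsewhere in this repo.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a clinician."}]},
    {"role": "user", "content": [{"type": "text", "text": "start interview"}]},
]

# Optional parameters (top_p, seed, stop, penalties) are only added to the
# request payload when not None; results are memoized via @cache.memoize().
reply = medgemma_get_text_response(
    messages=messages,
    temperature=0.1,
    max_tokens=2048,
    stream=False,
)
print(reply)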