Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import csv | |
| import json | |
| import logging | |
| import random | |
| import re | |
| from dataclasses import replace | |
| from pathlib import Path | |
| from config import BASE_DIR, RANDOMIZE_CHOICES | |
| from models import Case, CaseSummary, AnswerLog, ConversationTurn, QuestionOutcome, ClinicalMCQ | |
# --- Configuration ---
# Install a default logging setup for the demo: INFO level, each record
# prefixed with timestamp, logger name and severity. Adjust as needed.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the helpers defined in this file.
logger = logging.getLogger(__name__)
def fetch_report(report_path: Path):
    """Load a JSON report file and return its decoded content.

    Args:
        report_path: Location of the JSON report on disk.

    Returns:
        The decoded JSON document, or an empty string when the file does not
        exist (the miss is logged as an error rather than raised).

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid JSON.
        UnicodeDecodeError: If the file exists but is not valid UTF-8.
    """
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default text encoding, which differs between operating systems.
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
    except FileNotFoundError:
        logger.error(f"ERROR: Could not find report file: {report_path}")
        return ""
    logger.info(f"Successfully loaded '{report_path}' into memory.")
    return report
def get_available_reports(reports_csv_path: Path):
    """Build the catalogue of demo cases described by a manifest CSV.

    The CSV must provide the columns ``case_id``, ``case_condition_name``,
    ``report_path`` (relative to ``BASE_DIR``), ``download_image_url`` and
    ``findings``. A row is skipped with a warning when its report file does
    not exist, its download URL is empty, or its report cannot be read.

    Args:
        reports_csv_path: Path of the manifest CSV.

    Returns:
        A dict mapping ``str(case_id)`` to the ``Case`` built from that row.
        It is empty when the manifest is missing, lacks a required column, or
        cannot be processed; problems are logged instead of raised.
    """
    available_reports: dict[str, Case] = {}

    if not reports_csv_path.is_file():
        logger.warning(f"Manifest CSV file not found at {reports_csv_path}. AVAILABLE_REPORTS will be empty.")
        return available_reports

    required_headers = {'case_id', 'case_condition_name', 'report_path', 'download_image_url', 'findings'}
    try:
        # newline='' lets the csv module handle line endings itself.
        with open(reports_csv_path, mode='r', encoding='utf-8', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            # fieldnames is None for a completely empty file.
            missing_headers = required_headers - set(reader.fieldnames or ())
            if missing_headers:
                logger.error(
                    f"CSV file {reports_csv_path} is missing one or more required headers: {missing_headers}"
                )
                return available_reports

            for row in reader:
                case_id = row['case_id']
                download_image_url = row['download_image_url']

                # Paths in the CSV are relative to BASE_DIR. A short row yields
                # None here; treat it like an empty path so the row is skipped.
                abs_report_path = BASE_DIR / (row['report_path'] or '')
                if not abs_report_path.is_file():
                    logger.warning(
                        f"Report file not found for case '{case_id}' at '{abs_report_path}'. Skipping this entry.")
                    continue
                if not download_image_url:
                    logger.warning(
                        f"Download image url not found for case '{case_id}'. Skipping this entry.")
                    continue

                try:
                    # Read the exact file that was validated above, not the
                    # raw relative string (which depends on the working dir).
                    ground_truth_labels = fetch_report(abs_report_path)
                except (OSError, ValueError) as e:
                    # ValueError covers malformed JSON and bad encodings; one
                    # broken report must not discard the remaining rows.
                    logger.warning(
                        f"Could not read report for case '{case_id}' at '{abs_report_path}': {e}. Skipping this entry.")
                    continue

                available_reports[str(case_id)] = Case(
                    id=case_id,
                    condition_name=row['case_condition_name'],
                    ground_truth_labels=ground_truth_labels,
                    download_image_url=download_image_url,
                    potential_findings=row['findings'],
                )
            logger.info(f"Loaded {len(available_reports)} report/image pairs from CSV.")
    except Exception as e:
        # Best-effort loader: report the failure and return what was collected.
        logger.error(f"Error reading or processing CSV file {reports_csv_path}: {e}", exc_info=True)
    return available_reports
def get_json_from_model_response(response_text: str) -> dict:
    """Extract and decode the JSON object contained in a model response.

    A fenced markdown block tagged ``json`` is tried first. If there is none,
    the first JSON object that starts at the first ``{`` of the text is
    decoded, and anything after it (prose, a repeated object, stray braces)
    is ignored.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        Exception: If a fenced block is present but is not valid JSON, or if
            the text contains no ``{...}`` span at all.
        json.JSONDecodeError: If there is no fenced block and the text
            starting at the first ``{`` is not a valid JSON object.
    """
    # Preferred form: a JSON object starting with { and ending with } inside
    # a fenced json block.
    json_match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL)
    if json_match:
        json_str = json_match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode JSON after extraction: {e}")
            raise Exception(f"Could not parse JSON from extracted block: {json_str}") from e

    # Fallback if the model misses the markdown block.
    logger.warning("Could not find a ```json block. Falling back to raw search.")
    start = response_text.find("{")
    if start != -1 and response_text.rfind("}") > start:
        # raw_decode stops at the end of the first complete object, so text
        # after it that happens to contain braces no longer breaks parsing.
        parsed, _ = json.JSONDecoder().raw_decode(response_text, start)
        return parsed
    raise Exception(f"Could not find or parse JSON object in the API response: {response_text}")
def get_potential_findings(case: Case) -> str:
    """Return the ``potential_findings`` value stored on the given case."""
    findings = case.potential_findings
    return findings
def build_summary_template(case: Case, rag_cache: dict) -> CaseSummary:
    """Create the static part of a case summary.

    The interpretation is left empty and the rationale list is empty. The
    condition and potential findings come from ``case``; the guideline
    resource is the comma-separated ``"citations"`` list found under
    ``case.id`` in ``rag_cache`` (an empty string when there is none).
    """
    cached_entry = rag_cache.get(case.id, {})
    sources = cached_entry.get("citations", [])
    guideline_text = ', '.join(str(source) for source in sources) if sources else ""

    return CaseSummary(
        med_gemma_interpretation="",
        potential_findings=get_potential_findings(case),
        rationale=[],
        guideline_specific_resource=guideline_text,
        condition=case.condition_name,
    )
def populate_rationale(summary_template: CaseSummary, conversation_history: list[ConversationTurn]) -> CaseSummary:
    """Fill in the rationale and the score message from the user's answers.

    A question counts as correct when either ``attempt1`` or ``attempt2``
    equals the question's answer key. For an incorrect question the logged
    text is the choice selected last (``attempt2`` if set, else ``attempt1``).

    Args:
        summary_template: Summary whose ``potential_findings``,
            ``guideline_specific_resource`` and ``condition`` are copied over.
        conversation_history: One turn per question, each carrying the MCQ
            and the user's attempts.

    Returns:
        A new ``CaseSummary`` with ``med_gemma_interpretation`` and
        ``rationale`` populated.
    """
    correct_count = 0
    total_questions = len(conversation_history)
    rationale_logs = []

    for turn in conversation_history:
        mcq = turn.clinicalMcq
        choices = mcq.choices
        model_answer_key = mcq.answer
        user_attempt1_key = turn.userResponse.attempt1
        user_attempt2_key = turn.userResponse.attempt2

        if user_attempt1_key != model_answer_key and user_attempt2_key != model_answer_key:
            user_attempt_key = user_attempt2_key if user_attempt2_key else user_attempt1_key
            # Use .get like the correct-answer lookup below: an unanswered
            # question or an unknown key must not raise KeyError and abort
            # the whole summary.
            incorrect_text = choices.get(
                user_attempt_key, f"N/A - User Answer Key '{user_attempt_key}' not found.")
            outcome = QuestionOutcome(type="Incorrect", text=incorrect_text)
        else:
            correct_count += 1
            correct_answer_text = choices.get(
                model_answer_key, f"N/A - Model Answer Key '{model_answer_key}' not found.")
            outcome = QuestionOutcome(type="Correct", text=correct_answer_text)
        rationale_logs.append(AnswerLog(question=mcq.question, outcomes=[outcome]))

    # NOTE(review): an empty history yields 0% and therefore the
    # "challenging case" message; confirm that is the intended wording.
    accuracy = (correct_count / total_questions) * 100 if total_questions > 0 else 0

    if accuracy == 100:
        # NOTE(review): a question answered correctly on the second attempt
        # also counts as correct, so "on your first attempt" may not hold.
        interpretation = f"Wonderful job! You achieved a perfect score of {accuracy:.0f}%, correctly identifying all key findings on your first attempt."
    elif accuracy >= 50:
        interpretation = f"Good job. You scored {accuracy:.0f}%, showing a solid understanding of the key findings for this case."
    else:
        interpretation = f"This was a challenging case, and you scored {accuracy:.0f}%. More preparation is needed. Review the rationale below for details."

    return CaseSummary(
        med_gemma_interpretation=interpretation,
        potential_findings=summary_template.potential_findings,
        rationale=rationale_logs,
        guideline_specific_resource=summary_template.guideline_specific_resource,
        condition=summary_template.condition,
    )
def randomize_mcqs(original_mcqs: list[ClinicalMCQ]) -> list[ClinicalMCQ]:
    """
    Shuffle the answer choices of every clinical MCQ in the list.

    When RANDOMIZE_CHOICES is falsy the input list is returned as is.
    Otherwise each question is replaced by a copy whose choice texts are
    shuffled across the sorted choice keys and whose answer key points at the
    new position of the correct text. A question that cannot be randomized is
    kept unchanged and a warning is logged.
    """
    if not RANDOMIZE_CHOICES:
        return original_mcqs

    shuffled_mcqs = []
    for mcq in original_mcqs:
        try:
            # Remember the text of the correct answer before anything moves.
            answer_text = mcq.choices[mcq.answer]

            # Shuffle the choice texts in place.
            texts = list(mcq.choices.values())
            random.shuffle(texts)

            # Re-attach the shuffled texts to the sorted keys; the new answer
            # key is the first key that now carries the correct text.
            labels = sorted(mcq.choices.keys())
            relabelled = dict(zip(labels, texts))
            answer_label = labels[texts.index(answer_text)]

            # Build a modified copy rather than mutating the input question.
            shuffled_mcqs.append(replace(mcq, choices=relabelled, answer=answer_label))
        except Exception as e:
            # Best effort: on any failure (e.g. KeyError from a bad answer
            # key) keep the original, unmodified question.
            logger.warning(f"Warning: Could not randomize question '{mcq.id}'. Returning original. Error: {e}")
            shuffled_mcqs.append(mcq)

    return shuffled_mcqs