import gradio as gr
import pandas as pd
from datasets import load_dataset
from openai import OpenAI
from PIL import Image
import io
import base64
import logging
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
import os
import spaces
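# `spaces` is the Hugging Face Spaces helper package (it provides e.g. the
# @spaces.GPU decorator on ZeroGPU hardware). No decorator is applied below,
# so the import is inert but harmless wherever the package is installed.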

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# App version
APP_VERSION = "1.0.0"
logger.info(f"Starting Radiology Teaching App v{APP_VERSION}")

# Model configuration
MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 5000
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# HF pipelines take a device index (0 = first CUDA device) or the string "cpu"
device = 0 if torch.cuda.is_available() else "cpu"

# Initialize the LLM (only loaded when a GPU is available)
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.use_default_system_prompt = False
else:
    # Guard so CPU-only runs fail with a clear error instead of a NameError
    # when the Llama analysis is requested
    llm = None
    tokenizer = None

# Initialize the transcription pipeline; chunk_length_s=30 makes the pipeline
# transcribe long recordings in 30-second windows
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)

try:
    # Load only 10 rows from the dataset
    logger.info("Loading MIMIC-CXR dataset...")
    dataset = load_dataset("itsanmolgupta/mimic-cxr-dataset", split="train").select(range(10))
    df = pd.DataFrame(dataset)
    logger.info(f"Successfully loaded {len(df)} cases")
except Exception as e:
    logger.error(f"Error loading dataset: {str(e)}")
    raise
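
# Each row is expected to expose 'image' (a PIL image), 'findings', and
# 'impression' columns; load_random_case below relies on these names.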

def encode_image_to_base64(image_bytes):
    """Encode raw image bytes as a base64 string (currently unused helper)."""
    return base64.b64encode(image_bytes).decode('utf-8')
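
# Hypothetical use of the helper (and of the io/PIL imports above): inline a
# PIL image as a data URI for a vision-capable chat API.
#   buf = io.BytesIO()
#   pil_image.save(buf, format="PNG")
#   data_uri = f"data:image/png;base64,{encode_image_to_base64(buf.getvalue())}"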

def analyze_report(user_findings, ground_truth_findings, ground_truth_impression, api_key):
    if not api_key:
        return "Please provide a DeepSeek API key to analyze the report."
    try:
        # DeepSeek exposes an OpenAI-compatible endpoint, so the standard
        # OpenAI client works with a custom base_url
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
        logger.info("Analyzing report with DeepSeek...")
        prompt = f"""You are an expert radiologist reviewing a trainee's chest X-ray report.

Trainee's Findings:
{user_findings}

Ground Truth Findings:
{ground_truth_findings}

Ground Truth Impression:
{ground_truth_impression}

Please provide:
1. Number of important findings missed by the trainee (list them)
2. Quality assessment of the trainee's report (structure, completeness, accuracy)
3. Constructive feedback for improvement

Format your response in clear sections."""
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are an expert radiologist providing constructive feedback."},
                {"role": "user", "content": prompt},
            ],
            stream=False,
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error in report analysis: {str(e)}")
        return f"Error analyzing report: {str(e)}"

def transcribe(inputs, task="transcribe"):
    """Transcribe audio using Whisper"""
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    try:
        logger.info("Transcribing audio...")
        # Timestamps are predicted during chunked decoding, but only the merged
        # "text" field is surfaced to the UI
        text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
        return text
    except Exception as e:
        logger.error(f"Error in transcription: {str(e)}")
        raise gr.Error(f"Transcription failed: {str(e)}")

def analyze_with_llama(
    transcribed_text: str,
    ground_truth_findings: str,
    ground_truth_impression: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Analyze transcribed report against ground truth using Llama"""
    if llm is None or tokenizer is None:
        raise gr.Error("Local Llama analysis requires a GPU; no model is loaded on this instance.")
    task_prompt = f"""You are an expert radiologist. Compare the following transcribed radiology report with the ground truth and provide very concise feedback.

Transcribed Report:
{transcribed_text}

Ground Truth Findings:
{ground_truth_findings}

Ground Truth Impression:
{ground_truth_impression}

Please analyze:
1. Accuracy of findings. Only comment on how the user's transcribed report compares to the ground truth.
2. Completeness of the user report compared to the ground truth.
3. Structure and clarity of the user report compared to the ground truth.
4. Areas for improvement for the user report compared to the ground truth.

Provide a concise analysis in a clear, structured format."""
    conversation = [
        {"role": "system", "content": "You are an expert radiologist providing detailed feedback."},
        {"role": "user", "content": task_prompt},
    ]

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens if the prompt exceeds the context budget
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)

    # Run generation in a background thread and stream decoded tokens back
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=llm.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

def load_random_case(hide_ground_truth):
    try:
        # Randomly select a case from our dataset
        random_case = df.sample(n=1).iloc[0]
        logger.info("Loading random case...")

        # Get the image, findings, and impression
        image = random_case['image']

        # Store full findings and impression regardless of hide_ground_truth
        findings = random_case['findings']
        impression = random_case['impression']

        # Only hide display if hide_ground_truth is True
        display_findings = "" if hide_ground_truth else findings
        display_impression = "" if hide_ground_truth else impression

        # Return both display values and actual values
        return image, display_findings, display_impression, findings, impression
    except Exception as e:
        logger.error(f"Error loading random case: {str(e)}")
        return None, "Error loading case", "Error loading case", "", ""

def process_case(image, user_findings, hide_ground_truth, api_key, current_findings="", current_impression="", actual_findings="", actual_impression=""):
    # `image` is unused but keeps the signature aligned with the Gradio inputs list.
    # Use the stored findings/impression for analysis if they exist, otherwise
    # fall back to the currently displayed values.
    findings_for_analysis = actual_findings if actual_findings else current_findings
    impression_for_analysis = actual_impression if actual_impression else current_impression
    analysis = analyze_report(user_findings, findings_for_analysis, impression_for_analysis, api_key)

    # Return display values based on hide_ground_truth
    display_findings = "" if hide_ground_truth else findings_for_analysis
    display_impression = "" if hide_ground_truth else impression_for_analysis
    return display_findings, display_impression, analysis

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(f"# Radiology Report Training System v{APP_VERSION}")
    gr.Markdown("### Practice your chest X-ray reading and reporting skills")

    # Add state variables to store actual findings and impression
    actual_findings_state = gr.State("")
    actual_impression_state = gr.State("")

    with gr.Tab("DeepSeek Analysis"):
        with gr.Row():
            with gr.Column():
                image_display = gr.Image(label="Chest X-ray Image", type="pil")
                api_key_input = gr.Textbox(label="DeepSeek API Key", type="password")
                hide_truth = gr.Checkbox(label="Hide Ground Truth", value=False)
                load_btn = gr.Button("Load Random Case")
            with gr.Column():
                user_findings_input = gr.Textbox(label="Your Findings", lines=10, placeholder="Type or dictate your findings here...")
                ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
                ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
                analysis_output = gr.Textbox(label="Analysis and Feedback", lines=10, interactive=False)
                submit_btn = gr.Button("Submit Report")

    with gr.Tab("Local Inference"):
        gr.Markdown("### Use Local Models for Transcription and Analysis")
        with gr.Row():
            with gr.Column():
                # Transcription interface
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
                task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
                transcribe_button = gr.Button("Transcribe Audio")
                transcription_output = gr.Textbox(label="Transcription Output", lines=5)

                # Load case for comparison
                load_case_btn = gr.Button("Load Random Case for Comparison")
                local_image_display = gr.Image(label="Chest X-ray Image", type="pil")
                local_ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
                local_ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
            with gr.Column():
                # Editable transcription and analysis interface
                edited_transcription = gr.Textbox(label="Edit Transcription", lines=10)
                temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
                top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
                top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
                max_tokens_input = gr.Slider(label="Max Tokens", minimum=256, maximum=2048, value=1024, step=128)
                repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)
                analyze_btn = gr.Button("Analyze with Llama")
                llama_analysis_output = gr.Textbox(
                    label="Llama Analysis Output",
                    lines=8,
                    max_lines=8,
                    show_copy_button=True,
                    interactive=False,
                    autoscroll=False,
                )

    # Event handlers for the Local Inference tab
    transcribe_button.click(
        fn=transcribe,
        inputs=[audio_input, task_input],
        outputs=transcription_output,
    )

    # Copy transcription to the editable box
    transcription_output.change(
        fn=lambda x: x,
        inputs=[transcription_output],
        outputs=[edited_transcription],
    )

    # Load case for local analysis; the ground truth is always shown here, so a
    # hidden dummy checkbox supplies hide_ground_truth=False, and the two extra
    # return values are discarded into throwaway gr.State components
    load_case_btn.click(
        fn=load_random_case,
        inputs=[gr.Checkbox(value=False, visible=False)],
        outputs=[
            local_image_display,
            local_ground_truth_findings,
            local_ground_truth_impression,
            gr.State(),  # discarded full findings
            gr.State(),  # discarded full impression
        ],
    )

    # Analyze with Llama
    analyze_btn.click(
        fn=analyze_with_llama,
        inputs=[
            edited_transcription,
            local_ground_truth_findings,
            local_ground_truth_impression,
            max_tokens_input,
            temperature_input,
            top_p_input,
            top_k_input,
            repetition_penalty_input,
        ],
        outputs=llama_analysis_output,
    )

    # Event handlers for the DeepSeek Analysis tab
    load_btn.click(
        fn=load_random_case,
        inputs=[hide_truth],
        outputs=[
            image_display,
            ground_truth_findings,
            ground_truth_impression,
            actual_findings_state,
            actual_impression_state,
        ],
    )

    submit_btn.click(
        fn=process_case,
        inputs=[
            image_display,
            user_findings_input,
            hide_truth,
            api_key_input,
            ground_truth_findings,
            ground_truth_impression,
            actual_findings_state,
            actual_impression_state,
        ],
        outputs=[
            ground_truth_findings,
            ground_truth_impression,
            analysis_output,
        ],
    )

    # Toggling "Hide Ground Truth" re-derives the displayed text from the stored
    # state and clears any stale analysis output
    hide_truth.change(
        fn=lambda x, f, i: ("" if x else f, "" if x else i, ""),
        inputs=[hide_truth, actual_findings_state, actual_impression_state],
        outputs=[ground_truth_findings, ground_truth_impression, analysis_output],
    )
| logger.info("Starting Gradio interface...") | |
| demo.queue().launch() |