import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time
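# Load the model and tokenizer once at startup so every request reuses the same
# weights; device_map="auto" lets accelerate place the layers, and
# torch_dtype="auto" keeps the checkpoint's native precision.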
model_path = "Intelligent-Internet/II-Medical-8B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# This is our streaming generator function that yields partial results
@spaces.GPU(duration=60)
def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    if not user_message.strip():
        yield history, history
        return
    # ChatML-style tags used to build the prompt by hand. Note: "<|im_sep|>"
    # comes from Phi-4's template; for a Qwen-derived checkpoint such as
    # II-Medical-8B, tokenizer.apply_chat_template (sketched below) is the
    # safer route.
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"
    system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections: 
 In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step.  
In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.
Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$). 
Now, analyze the following case:"""
    # Build conversation history in the format the model expects
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    
    # Convert chat history format from the Gradio Chatbot format to prompt format
    for user_msg, bot_msg in history:
        if user_msg:
            prompt += f"{start_tag}user{sep_tag}{user_msg}{end_tag}"
        if bot_msg:
            prompt += f"{start_tag}assistant{sep_tag}{bot_msg}{end_tag}"
    
    # Add the current user message
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
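    # TextIteratorStreamer exposes generated tokens as a Python iterator;
    # skip_prompt=True keeps the echoed input out of the output stream.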
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Create a new history with the current user message
    new_history = history + [[user_message, ""]]  # list concatenation already copies
    
    # Collect the generated response
    assistant_response = ""
    for new_token in streamer:
        # Strip any chat-template tags that leak into the decoded stream
        cleaned_token = new_token
        for tag in (start_tag, sep_tag, end_tag):
            cleaned_token = cleaned_token.replace(tag, "")
        assistant_response += cleaned_token
        # Update the last message in history with the current response
        new_history[-1][1] = assistant_response.strip()
        yield new_history, new_history
        # Add a small sleep to control the streaming rate
        time.sleep(0.01)
    
    # Wait for the generation thread to finish, then emit the final state
    thread.join()
    yield new_history, new_history
# This is our non-streaming wrapper function for buttons that don't support streaming
def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)
    # Get the final result by exhausting the generator
    result = None
    for result in generator:
        pass
    return result
example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}
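# Enlarge KaTeX output so rendered formulas stay readable inside chat bubbles.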
css = """
.markdown-body .katex { 
    font-size: 1.2em; 
}
.markdown-body .katex-display { 
    margin: 1em 0; 
    overflow-x: auto;
    overflow-y: hidden;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
    gr.HTML("""
    
    """)
    chatbot = gr.Chatbot(
        label="Chat",
        render_markdown=True,
        show_copy_button=True,
        # Render both display ($$...$$) and inline ($...$) LaTeX, since the
        # system prompt asks the model for inline formulas like the BMI example
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
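    # gr.State keeps the full conversation as [user, assistant] pairs across events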
    history = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            # step must divide (value - minimum) so the 4096 default sits on the grid
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
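                # Sampling knobs: temperature rescales logits, top-k/top-p
                # restrict the candidate pool, and repetition_penalty > 1.0
                # discourages loops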
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")
    # Set up the streaming interface
    def on_submit(message, history):
        # Ignore empty submissions so no blank turn enters the state
        if not message.strip():
            return "", history, history
        # Echo the user turn immediately, with an empty assistant slot to fill in
        modified_history = history + [[message, ""]]
        return "", modified_history, modified_history
    def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        # Nothing to stream unless the last entry is a pending [user, ""] turn
        if not history or history[-1][1]:
            yield history, history
            return

        # Pop the pending turn; generate_streaming_response re-appends it
        user_message = history[-1][0]
        prev_history = history[:-1]

        # Yield each partial history to both the Chatbot and the State so the
        # finished assistant reply persists into the next turn
        for new_history, _ in generate_streaming_response(
            user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
        ):
            yield new_history, new_history
    # Connect the submission event: echo the user turn, then stream the reply
    submit_button.click(
        fn=on_submit,
        inputs=[user_input, history],
        outputs=[user_input, chatbot, history]
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[chatbot, history]
    )
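    # Mirror the Send button when Enter is pressed in the textbox
    # (Textbox.submit shares Button.click's event signature)
    user_input.submit(
        fn=on_submit,
        inputs=[user_input, history],
        outputs=[user_input, chatbot, history]
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[chatbot, history]
    )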
    # Handle the Clear button and the example shortcuts
    def set_example(example_text):
        return gr.update(value=example_text)
    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
    # The k=key default argument pins each key, avoiding lambda late binding
    for button, key in [(example1, "Headache case"), (example2, "Chest pain"),
                        (example3, "Abdominal pain"), (example4, "BMI calculation")]:
        button.click(fn=lambda k=key: set_example(example_messages[k]), inputs=None, outputs=user_input)
demo.launch(ssr_mode=False)