import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
import time
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain.tools import StructuredTool
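
# Model input resolution; uploads are resized to this before inference.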
IMG_HEIGHT = 256
IMG_WIDTH = 256

model_path = "unet_model.h5"
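# Fetch the pretrained UNet++ weights from Hugging Face on first run; chunked
# streaming avoids holding the whole .h5 file in memory.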
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    with requests.get(hf_url, stream=True) as r:
        r.raise_for_status()
        with open(model_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
| print("Loading model...") | |
| model = tf.keras.models.load_model(model_path, compile=False) | |
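

# Segment one MRI slice: returns a red tumor overlay plus summary statistics.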
def classify_image_and_stats(image_input):
    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)
    prediction = model.predict(img_batch)[0]
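    # Threshold the probability map at 0.5 (assumes a sigmoid output head).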
    mask = (prediction > 0.5).astype(np.uint8)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)
    tumor_area = np.sum(mask)
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
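    # The 0.00385 area-ratio cutoff looks empirically tuned; treat it as a heuristic.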
    tumor_label = "Tumor Detected" if tumor_ratio > 0.00385 else "No Tumor Detected"
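    # Blend the resized slice with a red mask (60/40) to highlight the tumor region.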
    overlay = np.array(img)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255
    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)
    stats = {
        "tumor_area": int(tumor_area),
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label
    }
    return overlay_img, stats
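

# Gradio handler: segments the uploaded slice, then streams an LLM-written,
# patient-friendly explanation character by character.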
def rishigpt_handler(image_input, groq_api_key):
    os.environ["GROQ_API_KEY"] = groq_api_key
    overlay_img, stats = classify_image_and_stats(image_input)
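
    # The tool closes over `stats` so the agent can report this image's numbers.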
    def segment_brain_tool(input_text: str) -> str:
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image."
    )
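
    # ChatGroq picks up GROQ_API_KEY from the environment set above.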
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False
    )

    user_query = "Give me the segmentation details"
    classification = agent.run(user_query)
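
    # Rewrite the raw stats as a calm, patient-facing summary.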
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, "
            "in a calm, clear tone, with next steps."
        )
    )
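    # LCEL: piping the prompt into the chat model composes the chain, so no
    # explicit RunnableSequence import is needed.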
    chain = prompt | llm
    final_text = chain.invoke({"result": classification}).content.strip()
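
    # Typewriter effect: each yield re-renders the textbox with one more character.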
    displayed_text = ""
    for char in final_text:
        displayed_text += char
        time.sleep(0.015)
        yield overlay_img, displayed_text


inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key")
]

outputs = [
    gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask"),
    gr.Textbox(label="Doctor's Explanation")
]
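
# Gradio streams each yield from the generator handler straight to the UI.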
if __name__ == "__main__":
    gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ brain tumor segmentation with a live mask overlay, detailed stats, and a human-like typing explanation."
    ).launch()