# tmj-predictor / src/streamlit_app.py
# Author: drkareemkamal
# Last commit: "Update src/streamlit_app.py" (ce08beb, verified)
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import json
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime
# Page config -- issued before any other Streamlit element is rendered.
# (The icon literal was restored from UTF-8 mojibake to the syringe emoji.)
st.set_page_config(
    page_title="TMJ Injection Success Predictor",
    page_icon="💉",
    layout="wide",
)
# Load model and materials
def _resolve_artifact(filename, legacy_path):
    """Return the path (as ``str``) from which to load the artifact ``filename``.

    ``filename`` is looked up in the directory containing this script first,
    so the app works regardless of the working directory it is launched from.
    Otherwise ``legacy_path`` -- the working-directory-relative path this app
    originally used -- is returned unchanged, so a missing file still raises
    the same ``FileNotFoundError`` as before.
    """
    from pathlib import Path

    try:
        script_dir = Path(__file__).resolve().parent
    except NameError:  # __file__ not defined (script exec'd without a file)
        return legacy_path
    beside_script = script_dir / filename
    if beside_script.is_file():
        return str(beside_script)
    return legacy_path


@st.cache_resource
def load_artifacts():
    """Load the trained model, the materials list and optional model metadata.

    Returns:
        tuple: ``(model, materials, metadata)`` where

        * ``model`` is the object stored in the model pickle (the app uses its
          ``predict_proba`` method);
        * ``materials`` is a non-empty list of material names. A built-in
          default list is used, with a warning, when the materials file is
          missing, unreadable, not valid JSON, or contains no materials;
        * ``metadata`` is a dict read from ``model_metadata.json``. It is empty
          when that file is missing, unreadable, not valid JSON, or not a JSON
          object.

    If the model itself cannot be loaded, an error is shown and the script
    run is halted via ``st.stop()``.
    """
    # Model (required). NOTE: joblib.load unpickles the file, which can run
    # arbitrary code -- only ship model files from a trusted source.
    try:
        model = joblib.load(_resolve_artifact(
            'best_tmj_success_classifier_without_fe.pkl',
            'src/best_tmj_success_classifier_without_fe.pkl'))
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()  # raises, so execution never continues without a model

    # Materials (optional). Accept {"materials": [...]} or a bare JSON list.
    materials = []
    try:
        with open(_resolve_artifact('material_list.json', 'src/material_list.json'),
                  'r', encoding='utf-8') as f:
            materials_data = json.load(f)
        if isinstance(materials_data, dict):
            materials_data = materials_data.get('materials')
        if isinstance(materials_data, list):
            materials = list(materials_data)
    except (OSError, ValueError):
        # Missing/unreadable file (OSError) or invalid JSON (JSONDecodeError
        # subclasses ValueError): use the defaults below instead of aborting.
        materials = []
    if not materials:
        # An empty list would leave the material selectbox without options.
        materials = ['Local Anaesthesia', 'Dry Needle', 'Botox',
                     'Saline', 'Magnesium', 'PRF']
        st.warning("Using default materials list. Train the model to generate actual materials from your data.")

    # Metadata (optional). Only a JSON object is usable by the UI (.get calls).
    metadata = {}
    try:
        with open(_resolve_artifact('model_metadata.json', 'model_metadata.json'),
                  'r', encoding='utf-8') as f:
            loaded_metadata = json.load(f)
        if isinstance(loaded_metadata, dict):
            metadata = loaded_metadata
    except (OSError, ValueError):
        pass  # metadata is purely informational; run without it

    return model, materials, metadata
# Initialize. load_artifacts is wrapped in st.cache_resource, so script reruns
# reuse the already-loaded objects.
model, materials, metadata = load_artifacts()

# Title and description (title emoji restored from UTF-8 mojibake).
st.title("🦷 TMJ Injection Success Predictor")
st.markdown("""
This tool predicts the 3-month treatment success probability for TMJ injections based on patient baseline characteristics.
Enter the patient information below to see predictions for different injection materials.
""")
# Display model info if available. Guard the type as well: every lookup
# below uses dict.get.
if isinstance(metadata, dict) and metadata:
    with st.expander("ℹ️ Model Information"):
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Model Type", metadata.get('model_type', 'Unknown'))
        with col2:
            # Show "N/A" rather than a misleading 0.000 when the score is
            # missing, and don't crash the page on a non-numeric value.
            try:
                roc_auc_text = f"{float(metadata['test_roc_auc']):.3f}"
            except (KeyError, TypeError, ValueError):
                roc_auc_text = "N/A"
            st.metric("Test ROC-AUC", roc_auc_text)
        with col3:
            # Keep the first 10 characters (presumably the YYYY-MM-DD part of
            # a timestamp); str() tolerates non-string / null values.
            training_date = metadata.get('training_date') or 'Unknown'
            st.metric("Training Date", str(training_date)[:10])
        st.write(f"**Success Definition:** {metadata.get('success_definition', 'Unknown')}")
        if metadata.get('simplified_version', False):
            st.info("This model uses the simplified feature set without text analysis.")
st.divider()
# Create form. Widget values are only sent to the script when the submit
# button is pressed; changing a widget inside the form does not rerun it.
with st.form("patient_form"):
    st.subheader("Patient Information")

    # Required fields
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Required Fields**")
        sex = st.selectbox("Sex", options=['Male', 'Female'], help="Patient's biological sex")
        age = st.number_input("Age", min_value=10, max_value=100, value=45, help="Patient age in years")
        pain_m0 = st.slider("Baseline Pain (M0)", min_value=0, max_value=10, value=7,
                            help="Pain score at baseline (0-10 scale)")
    with col2:
        # Blank spacer so this column lines up with the "Required Fields"
        # heading. ("** **" is not valid bold markdown and showed literal
        # asterisks.)
        st.markdown("&nbsp;")
        mmo_m0 = st.slider("Baseline MMO (M0)", min_value=0, max_value=80, value=35,
                           help="Maximum mouth opening at baseline (mm)")
        ohip_14_m0 = st.slider("Baseline OHIP-14 (M0)", min_value=0, max_value=56, value=28,
                               help="Oral Health Impact Profile score at baseline (0-56)")
    st.divider()

    # Optional fields
    st.markdown("**Optional Fields**")
    col3, col4 = st.columns(2)
    with col3:
        location = st.text_input("Location", placeholder="e.g., Right TMJ",
                                 help="Injection location")
        muscle_injected = st.text_input("Muscle Injected", placeholder="e.g., Masseter",
                                        help="Specific muscle targeted")
        adjunctive_treatment = st.text_input("Adjunctive Treatment", placeholder="e.g., Physical therapy",
                                             help="Additional treatments")
    with col4:
        previous_injection = st.selectbox("Previous Injection", options=['No', 'Yes'],
                                          help="Has the patient had previous TMJ injections?")
        # Always render this widget. Inside a form, choosing 'Yes' above does
        # not rerun the script, so a widget shown only when 'Yes' is selected
        # never appeared before submission. Its value is ignored for 'No'.
        previous_material_choice = st.selectbox(
            "Previous Material",
            options=[''] + materials,
            help="Material used in previous injection (ignored when Previous Injection is 'No')")
        material_in_previous_injection = (
            previous_material_choice if previous_injection == 'Yes' else '')
    st.divider()

    # Material selection for primary prediction
    st.markdown("**Primary Prediction**")
    selected_material = st.selectbox("Select Material for Prediction",
                                     options=materials,
                                     help="Choose the material you're considering for this patient")

    # Compare all materials option
    compare_all = st.checkbox("Compare all available materials", value=True,
                              help="Show predictions for all materials to help with decision making")

    # Submit button (emoji restored from UTF-8 mojibake)
    submitted = st.form_submit_button("🔮 Predict Success", use_container_width=True, type="primary")
# Process form submission
if submitted:
    # One-row input frame. The column names are assumed to match the model's
    # training features (see the hint shown on failure below). Blank optional
    # text fields are passed as NaN (missing) rather than "".
    input_data = pd.DataFrame({
        'sex': [sex],
        'age': [age],
        'pain_m0': [pain_m0],
        'mmo_m0': [mmo_m0],
        'ohip_14_m0': [ohip_14_m0],
        'location': [location if location else np.nan],
        'muscle_injected': [muscle_injected if muscle_injected else np.nan],
        'adjunctive_treatment': [adjunctive_treatment if adjunctive_treatment else np.nan],
        'previous_injection': [1 if previous_injection == 'Yes' else 0],
        'material_in_previous_injection': [material_in_previous_injection if material_in_previous_injection else np.nan],
        'material_injected': [selected_material],
    })

    try:
        # Column 1 of predict_proba is treated as P(success). This assumes the
        # success class sorts second in model.classes_ -- TODO confirm.
        prediction_proba = model.predict_proba(input_data)[0, 1]

        st.divider()
        st.subheader("Prediction Results")

        # Gauge in the wider middle column.
        col1, col2, col3 = st.columns([1, 2, 1])
        with col2:
            fig = go.Figure(go.Indicator(
                # No delta reference is configured, so show gauge + number only.
                mode="gauge+number",
                value=prediction_proba * 100,
                domain={'x': [0, 1], 'y': [0, 1]},
                title={'text': f"Success Probability with {selected_material}"},
                number={'suffix': "%", 'font': {'size': 40}},
                gauge={
                    'axis': {'range': [0, 100]},
                    'bar': {'color': "darkblue"},
                    # Bands use the same cut-offs (50% / 70%) as the
                    # low / moderate / high message below.
                    'steps': [
                        {'range': [0, 50], 'color': "lightgray"},
                        {'range': [50, 70], 'color': "gray"},
                        {'range': [70, 100], 'color': "lightgreen"},
                    ],
                    # Red marker at the 50% boundary.
                    'threshold': {
                        'line': {'color': "red", 'width': 4},
                        'thickness': 0.75,
                        'value': 50,
                    },
                },
            ))
            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

        # Interpretation: >=70% high, 50-70% moderate, <50% low.
        if prediction_proba >= 0.7:
            st.success(f"✅ High likelihood of success ({prediction_proba:.1%}) with {selected_material}")
        elif prediction_proba >= 0.5:
            st.warning(f"⚠️ Moderate likelihood of success ({prediction_proba:.1%}) with {selected_material}")
        else:
            st.error(f"❌ Low likelihood of success ({prediction_proba:.1%}) with {selected_material}")

        # Compare all materials if requested (nothing to rank without materials).
        if compare_all and materials:
            st.divider()
            st.subheader("📊 Material Comparison")

            # Score all materials in one batch. The patient row is repeated
            # once per material and only 'material_injected' varies. For
            # row-wise models this matches one predict_proba call per material.
            comparison_input = pd.concat([input_data] * len(materials), ignore_index=True)
            comparison_input['material_injected'] = list(materials)
            comparison_probs = model.predict_proba(comparison_input)[:, 1]

            material_df = pd.DataFrame({
                'Material': list(materials),
                'Success Probability': comparison_probs,
            })
            material_df['Success %'] = [f"{p:.1%}" for p in comparison_probs]
            # Stable sort: tied materials keep their original list order.
            material_df = material_df.sort_values('Success Probability', ascending=False, kind='stable')

            col1, col2 = st.columns([1, 1])
            with col1:
                # Table view with 1-based ranks.
                st.markdown("**Ranked Materials**")
                display_df = material_df[['Material', 'Success %']].reset_index(drop=True)
                display_df.index += 1
                st.dataframe(display_df, use_container_width=True)

                # Suggest another material only when it is strictly better. On
                # a tie the top row is whichever sorted first, and a hint like
                # "60.0% vs 60.0%" would be misleading.
                best_material = material_df.iloc[0]['Material']
                best_prob = material_df.iloc[0]['Success Probability']
                if (best_material != selected_material
                        and best_prob > prediction_proba
                        and not np.isclose(best_prob, prediction_proba)):
                    st.info(f"💡 Consider using **{best_material}** for potentially better outcomes ({best_prob:.1%} vs {prediction_proba:.1%})")

            with col2:
                st.markdown("**Visual Comparison**")
                fig = px.bar(material_df,
                             x='Success Probability',
                             y='Material',
                             orientation='h',
                             color='Success Probability',
                             color_continuous_scale='RdYlGn',
                             range_color=[0, 1],
                             text='Success %')
                fig.update_traces(textposition='outside')
                fig.update_layout(
                    xaxis_title="Success Probability",
                    yaxis_title="",
                    showlegend=False,
                    xaxis=dict(range=[0, 1.1]),
                    height=400,
                )
                # Horizontal bars are laid out bottom-up. Reverse the category
                # axis so the top-ranked material is on top, like the table.
                fig.update_yaxes(autorange='reversed')
                fig.add_vline(x=0.5, line_dash="dash", line_color="gray",
                              annotation_text="50% threshold")
                st.plotly_chart(fig, use_container_width=True)

        # Additional insights
        st.divider()
        with st.expander("📋 Patient Summary"):
            st.write("**Baseline Characteristics:**")
            summary_cols = st.columns(3)
            with summary_cols[0]:
                st.write(f"- Age: {age} years")
                st.write(f"- Sex: {sex}")
                st.write(f"- Previous injection: {previous_injection}")
            with summary_cols[1]:
                st.write(f"- Pain score: {pain_m0}/10")
                st.write(f"- MMO: {mmo_m0} mm")
                st.write(f"- OHIP-14: {ohip_14_m0}/56")
            with summary_cols[2]:
                if location:
                    st.write(f"- Location: {location}")
                if muscle_injected:
                    st.write(f"- Muscle: {muscle_injected}")
                if adjunctive_treatment:
                    st.write(f"- Adjunctive: {adjunctive_treatment}")
    except Exception as e:
        # UI boundary: report any model/feature mismatch instead of crashing.
        st.error(f"Error making prediction: {str(e)}")
        st.info("Please ensure the model was trained with all the necessary features.")
# Footer: a centred, muted disclaimer below a divider. The HTML is rendered
# as-is, which is why unsafe_allow_html is enabled for this static markup.
_footer_html = """
<div style='text-align: center; color: gray;'>
<small>
TMJ Injection Success Predictor |
Model trained on historical patient data |
Predictions are probabilistic and should be used alongside clinical judgment
</small>
</div>
"""
st.divider()
st.markdown(body=_footer_html, unsafe_allow_html=True)
# Sidebar with usage instructions and an interpretation guide. The guide's
# bands mirror the 0.7 / 0.5 thresholds used for the result message.
# (Emoji were restored from UTF-8 mojibake.)
with st.sidebar:
    st.header("📖 Instructions")
    st.markdown("""
1. **Enter patient baseline data** in the form
2. **Select the material** you're considering
3. **Click Predict** to see the success probability
4. **Compare materials** to find the optimal choice
---
### 🎯 Success Definition
Treatment success is typically defined as:
- Pain reduction > 2 points
- MMO increase > 5 mm
- OHIP-14 reduction > 5 points
---
### 📊 Interpretation Guide
- **70%+**: High success likelihood ✅
- **50-70%**: Moderate success ⚠️
- **<50%**: Low success likelihood ❌
---
### ⚕️ Clinical Note
These predictions are based on statistical models and should complement, not replace, clinical expertise and patient-specific considerations.
""")