import streamlit as st import pandas as pd import numpy as np import joblib import json import plotly.express as px import plotly.graph_objects as go from datetime import datetime # Page config st.set_page_config( page_title="TMJ Injection Success Predictor", page_icon="đ", layout="wide" ) # Load model and materials @st.cache_resource def load_artifacts(): """Load the trained model and materials list""" try: # Load model model = joblib.load('src/best_tmj_success_classifier_without_fe.pkl') # Load materials list try: with open('src/material_list.json', 'r') as f: materials_data = json.load(f) materials = materials_data.get('materials', []) except FileNotFoundError: # Fallback to default materials materials = ['Local Anaesthesia', 'Dry Needle', 'Botox', 'Saline', 'Magnesium', 'PRF'] st.warning("Using default materials list. Train the model to generate actual materials from your data.") # Load metadata if available metadata = {} try: with open('model_metadata.json', 'r') as f: metadata = json.load(f) except FileNotFoundError: pass return model, materials, metadata except Exception as e: st.error(f"Error loading model: {str(e)}") st.stop() # Initialize model, materials, metadata = load_artifacts() # Title and description st.title("đώ TMJ Injection Success Predictor") st.markdown(""" This tool predicts the 3-month treatment success probability for TMJ injections based on patient baseline characteristics. Enter the patient information below to see predictions for different injection materials. """) # Display model info if available if metadata: with st.expander("âšī¸ Model Information"): col1, col2, col3 = st.columns(3) with col1: st.metric("Model Type", metadata.get('model_type', 'Unknown')) with col2: st.metric("Test ROC-AUC", f"{metadata.get('test_roc_auc', 0):.3f}") with col3: st.metric("Training Date", metadata.get('training_date', 'Unknown')[:10]) st.write(f"**Success Definition:** {metadata.get('success_definition', 'Unknown')}") if metadata.get('simplified_version', False): st.info("This model uses the simplified feature set without text analysis.") st.divider() # Create form with st.form("patient_form"): st.subheader("Patient Information") # Required fields col1, col2 = st.columns(2) with col1: st.markdown("**Required Fields**") sex = st.selectbox("Sex", options=['Male', 'Female'], help="Patient's biological sex") age = st.number_input("Age", min_value=10, max_value=100, value=45, help="Patient age in years") pain_m0 = st.slider("Baseline Pain (M0)", min_value=0, max_value=10, value=7, help="Pain score at baseline (0-10 scale)") with col2: st.markdown("** **") # Empty space to align with "Required Fields" mmo_m0 = st.slider("Baseline MMO (M0)", min_value=0, max_value=80, value=35, help="Maximum mouth opening at baseline (mm)") ohip_14_m0 = st.slider("Baseline OHIP-14 (M0)", min_value=0, max_value=56, value=28, help="Oral Health Impact Profile score at baseline (0-56)") st.divider() # Optional fields st.markdown("**Optional Fields**") col3, col4 = st.columns(2) with col3: location = st.text_input("Location", placeholder="e.g., Right TMJ", help="Injection location") muscle_injected = st.text_input("Muscle Injected", placeholder="e.g., Masseter", help="Specific muscle targeted") adjunctive_treatment = st.text_input("Adjunctive Treatment", placeholder="e.g., Physical therapy", help="Additional treatments") with col4: previous_injection = st.selectbox("Previous Injection", options=['No', 'Yes'], help="Has the patient had previous TMJ injections?") if previous_injection == 'Yes': material_in_previous_injection = st.selectbox("Previous Material", options=[''] + materials, help="Material used in previous injection") else: material_in_previous_injection = '' st.divider() # Material selection for primary prediction st.markdown("**Primary Prediction**") selected_material = st.selectbox("Select Material for Prediction", options=materials, help="Choose the material you're considering for this patient") # Compare all materials option compare_all = st.checkbox("Compare all available materials", value=True, help="Show predictions for all materials to help with decision making") # Submit button submitted = st.form_submit_button("đŽ Predict Success", use_container_width=True, type="primary") # Process form submission if submitted: # Create input dataframe input_data = pd.DataFrame({ 'sex': [sex], 'age': [age], 'pain_m0': [pain_m0], 'mmo_m0': [mmo_m0], 'ohip_14_m0': [ohip_14_m0], 'location': [location if location else np.nan], 'muscle_injected': [muscle_injected if muscle_injected else np.nan], 'adjunctive_treatment': [adjunctive_treatment if adjunctive_treatment else np.nan], 'previous_injection': [1 if previous_injection == 'Yes' else 0], 'material_in_previous_injection': [material_in_previous_injection if material_in_previous_injection else np.nan], 'material_injected': [selected_material] }) # Make prediction for selected material try: prediction_proba = model.predict_proba(input_data)[0, 1] # Display primary prediction st.divider() st.subheader("Prediction Results") # Create a visual indicator col1, col2, col3 = st.columns([1, 2, 1]) with col2: # Success probability gauge fig = go.Figure(go.Indicator( mode = "gauge+number+delta", value = prediction_proba * 100, domain = {'x': [0, 1], 'y': [0, 1]}, title = {'text': f"Success Probability with {selected_material}"}, number = {'suffix': "%", 'font': {'size': 40}}, gauge = { 'axis': {'range': [None, 100]}, 'bar': {'color': "darkblue"}, 'steps': [ {'range': [0, 30], 'color': "lightgray"}, {'range': [30, 70], 'color': "gray"}, {'range': [70, 100], 'color': "lightgreen"} ], 'threshold': { 'line': {'color': "red", 'width': 4}, 'thickness': 0.75, 'value': 50 } } )) fig.update_layout(height=400) st.plotly_chart(fig, use_container_width=True) # Interpretation if prediction_proba >= 0.7: st.success(f"â High likelihood of success ({prediction_proba:.1%}) with {selected_material}") elif prediction_proba >= 0.5: st.warning(f"â ī¸ Moderate likelihood of success ({prediction_proba:.1%}) with {selected_material}") else: st.error(f"â Low likelihood of success ({prediction_proba:.1%}) with {selected_material}") # Compare all materials if requested if compare_all: st.divider() st.subheader("đ Material Comparison") # Predict for all materials material_results = [] for material in materials: temp_data = input_data.copy() temp_data['material_injected'] = material prob = model.predict_proba(temp_data)[0, 1] material_results.append({ 'Material': material, 'Success Probability': prob, 'Success %': f"{prob:.1%}" }) # Sort by probability material_df = pd.DataFrame(material_results) material_df = material_df.sort_values('Success Probability', ascending=False) # Display results col1, col2 = st.columns([1, 1]) with col1: # Table view st.markdown("**Ranked Materials**") display_df = material_df[['Material', 'Success %']].reset_index(drop=True) display_df.index += 1 # Start index at 1 st.dataframe(display_df, use_container_width=True) # Highlight best option best_material = material_df.iloc[0]['Material'] best_prob = material_df.iloc[0]['Success Probability'] if best_material != selected_material: st.info(f"đĄ Consider using **{best_material}** for potentially better outcomes ({best_prob:.1%} vs {prediction_proba:.1%})") with col2: # Bar chart st.markdown("**Visual Comparison**") fig = px.bar(material_df, x='Success Probability', y='Material', orientation='h', color='Success Probability', color_continuous_scale='RdYlGn', range_color=[0, 1], text='Success %') fig.update_traces(textposition='outside') fig.update_layout( xaxis_title="Success Probability", yaxis_title="", showlegend=False, xaxis=dict(range=[0, 1.1]), height=400 ) # Add vertical line at 50% fig.add_vline(x=0.5, line_dash="dash", line_color="gray", annotation_text="50% threshold") st.plotly_chart(fig, use_container_width=True) # Additional insights st.divider() with st.expander("đ Patient Summary"): st.write("**Baseline Characteristics:**") summary_cols = st.columns(3) with summary_cols[0]: st.write(f"- Age: {age} years") st.write(f"- Sex: {sex}") st.write(f"- Previous injection: {previous_injection}") with summary_cols[1]: st.write(f"- Pain score: {pain_m0}/10") st.write(f"- MMO: {mmo_m0} mm") st.write(f"- OHIP-14: {ohip_14_m0}/56") with summary_cols[2]: if location: st.write(f"- Location: {location}") if muscle_injected: st.write(f"- Muscle: {muscle_injected}") if adjunctive_treatment: st.write(f"- Adjunctive: {adjunctive_treatment}") except Exception as e: st.error(f"Error making prediction: {str(e)}") st.info("Please ensure the model was trained with all the necessary features.") # Footer st.divider() st.markdown("""