Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import json | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from datetime import datetime | |
# Page config — keep this before any other st.* call in the script.
st.set_page_config(
    page_title="TMJ Injection Success Predictor",
    # The scraped source had a mis-encoded emoji here; use the same tooth icon as the title.
    page_icon="🦷",
    layout="wide",
)
# Load model and materials

# Used when src/material_list.json is missing, unreadable, or yields no materials.
DEFAULT_MATERIALS = ['Local Anaesthesia', 'Dry Needle', 'Botox',
                     'Saline', 'Magnesium', 'PRF']


def _read_optional_json(path):
    """Return the parsed JSON content of *path*, or None if it is missing or unreadable."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError (both subclasses).
        return None


@st.cache_resource
def _load_artifacts_cached():
    """Load the model, materials list and metadata once and reuse them across reruns.

    Streamlit re-executes the whole script on every widget interaction, so without
    caching the pickled model would be reloaded each time. This function makes no
    st.* output calls; user-facing messages are emitted by load_artifacts().

    Returns:
        tuple: (model, materials, metadata, used_default_materials)

    Raises:
        Whatever joblib.load raises when the model file is missing or cannot be loaded.
    """
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code —
    # only ever point this at a model file you trust.
    model = joblib.load('src/best_tmj_success_classifier_without_fe.pkl')

    # The materials file is expected to be {"materials": [...]}; a bare list is accepted too.
    materials_data = _read_optional_json('src/material_list.json')
    if isinstance(materials_data, dict):
        materials = materials_data.get('materials', [])
    else:
        materials = materials_data
    if not isinstance(materials, list):
        materials = []
    used_default_materials = not materials
    if used_default_materials:
        materials = list(DEFAULT_MATERIALS)

    # Metadata is optional and only used for display.
    # NOTE(review): read from the working directory, unlike the src/ files above — confirm.
    metadata = _read_optional_json('model_metadata.json')
    if not isinstance(metadata, dict):
        metadata = {}

    return model, materials, metadata, used_default_materials


def load_artifacts():
    """Load the trained model and materials list.

    Returns:
        tuple: (model, materials, metadata) where materials is a non-empty list of
        material names and metadata is a dict (empty if no metadata file was found).

    If the model cannot be loaded, an error is shown and the script run is halted
    via st.stop().
    """
    try:
        model, materials, metadata, used_default_materials = _load_artifacts_cached()
    except Exception as e:
        # Nothing useful can be shown without a model: report and halt this run.
        st.error(f"Error loading model: {str(e)}")
        st.stop()

    if used_default_materials:
        st.warning("Using default materials list. Train the model to generate actual materials from your data.")

    # Hand out copies so callers cannot mutate the cached objects shared across sessions.
    return model, list(materials), dict(metadata)
# Initialize
model, materials, metadata = load_artifacts()

# Title and description
st.title("🦷 TMJ Injection Success Predictor")
st.markdown("""
This tool predicts the 3-month treatment success probability for TMJ injections based on patient baseline characteristics.

Enter the patient information below to see predictions for different injection materials.
""")

# Display model info if available
if metadata:
    with st.expander("ℹ️ Model Information"):
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Model Type", str(metadata.get('model_type', 'Unknown')))
        with col2:
            # Show "Unknown" (not a misleading 0.000, and no crash) when the score
            # is absent or not numeric.
            try:
                roc_auc_text = f"{float(metadata.get('test_roc_auc')):.3f}"
            except (TypeError, ValueError):
                roc_auc_text = 'Unknown'
            st.metric("Test ROC-AUC", roc_auc_text)
        with col3:
            # str() guards against non-string values (e.g. null) before slicing
            # to the YYYY-MM-DD prefix of an ISO timestamp.
            training_date = metadata.get('training_date') or 'Unknown'
            st.metric("Training Date", str(training_date)[:10])
        st.write(f"**Success Definition:** {metadata.get('success_definition', 'Unknown')}")
        if metadata.get('simplified_version', False):
            st.info("This model uses the simplified feature set without text analysis.")

st.divider()
# Create form
with st.form("patient_form"):
    st.subheader("Patient Information")

    # Required fields
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Required Fields**")
        sex = st.selectbox("Sex", options=['Male', 'Female'], help="Patient's biological sex")
        age = st.number_input("Age", min_value=10, max_value=100, value=45, help="Patient age in years")
        pain_m0 = st.slider("Baseline Pain (M0)", min_value=0, max_value=10, value=7,
                            help="Pain score at baseline (0-10 scale)")
    with col2:
        # Invisible spacer so these sliders line up under the "Required Fields" header.
        # ("** **" is not valid Markdown emphasis and renders as literal asterisks.)
        st.markdown("&nbsp;")
        mmo_m0 = st.slider("Baseline MMO (M0)", min_value=0, max_value=80, value=35,
                           help="Maximum mouth opening at baseline (mm)")
        ohip_14_m0 = st.slider("Baseline OHIP-14 (M0)", min_value=0, max_value=56, value=28,
                               help="Oral Health Impact Profile score at baseline (0-56)")

    st.divider()

    # Optional fields
    st.markdown("**Optional Fields**")
    col3, col4 = st.columns(2)
    with col3:
        location = st.text_input("Location", placeholder="e.g., Right TMJ",
                                 help="Injection location")
        muscle_injected = st.text_input("Muscle Injected", placeholder="e.g., Masseter",
                                        help="Specific muscle targeted")
        adjunctive_treatment = st.text_input("Adjunctive Treatment", placeholder="e.g., Physical therapy",
                                             help="Additional treatments")
    with col4:
        previous_injection = st.selectbox("Previous Injection", options=['No', 'Yes'],
                                          help="Has the patient had previous TMJ injections?")
        # Widgets inside st.form do not rerun the script until the form is submitted,
        # so a selectbox rendered only when previous_injection == 'Yes' would never
        # appear in time to be filled in. Always render it; ignore it for 'No'.
        previous_material_choice = st.selectbox(
            "Previous Material",
            options=[''] + materials,
            help="Material used in previous injection (ignored if Previous Injection is 'No')",
        )
        material_in_previous_injection = previous_material_choice if previous_injection == 'Yes' else ''

    st.divider()

    # Material selection for primary prediction
    st.markdown("**Primary Prediction**")
    selected_material = st.selectbox("Select Material for Prediction",
                                     options=materials,
                                     help="Choose the material you're considering for this patient")

    # Compare all materials option
    compare_all = st.checkbox("Compare all available materials", value=True,
                              help="Show predictions for all materials to help with decision making")

    # Submit button
    submitted = st.form_submit_button("🔮 Predict Success", use_container_width=True, type="primary")
# Process form submission


def _text_or_nan(value):
    """Return *value* stripped of surrounding whitespace, or NaN if nothing is left."""
    value = value.strip() if isinstance(value, str) else value
    return value if value else np.nan


if submitted:
    # Build the single-row model input. Column names must match the features the
    # model was trained on; empty optional text fields are passed as NaN (missing).
    input_data = pd.DataFrame({
        'sex': [sex],
        'age': [age],
        'pain_m0': [pain_m0],
        'mmo_m0': [mmo_m0],
        'ohip_14_m0': [ohip_14_m0],
        'location': [_text_or_nan(location)],
        'muscle_injected': [_text_or_nan(muscle_injected)],
        'adjunctive_treatment': [_text_or_nan(adjunctive_treatment)],
        'previous_injection': [1 if previous_injection == 'Yes' else 0],
        'material_in_previous_injection': [_text_or_nan(material_in_previous_injection)],
        'material_injected': [selected_material],
    })

    try:
        # Column 1 of predict_proba is taken as P(success) — assumes the positive
        # class is the second entry of model.classes_ (TODO confirm for this model).
        prediction_proba = model.predict_proba(input_data)[0, 1]

        # Display primary prediction
        st.divider()
        st.subheader("Prediction Results")

        # Centre the gauge in the middle (wider) column
        col1, col2, col3 = st.columns([1, 2, 1])
        with col2:
            fig = go.Figure(go.Indicator(
                mode="gauge+number+delta",
                value=prediction_proba * 100,
                domain={'x': [0, 1], 'y': [0, 1]},
                title={'text': f"Success Probability with {selected_material}"},
                number={'suffix': "%", 'font': {'size': 40}},
                # Without an explicit reference Plotly compares the value with itself
                # (delta always 0); compare against the 50% decision threshold instead.
                delta={'reference': 50},
                gauge={
                    'axis': {'range': [None, 100]},
                    'bar': {'color': "darkblue"},
                    'steps': [
                        {'range': [0, 30], 'color': "lightgray"},
                        {'range': [30, 70], 'color': "gray"},
                        {'range': [70, 100], 'color': "lightgreen"},
                    ],
                    'threshold': {
                        'line': {'color': "red", 'width': 4},
                        'thickness': 0.75,
                        'value': 50,
                    },
                },
            ))
            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

        # Interpretation (bands match the sidebar's interpretation guide)
        if prediction_proba >= 0.7:
            st.success(f"✅ High likelihood of success ({prediction_proba:.1%}) with {selected_material}")
        elif prediction_proba >= 0.5:
            st.warning(f"⚠️ Moderate likelihood of success ({prediction_proba:.1%}) with {selected_material}")
        else:
            st.error(f"❌ Low likelihood of success ({prediction_proba:.1%}) with {selected_material}")

        # Compare all materials if requested
        if compare_all and materials:
            st.divider()
            st.subheader("📊 Material Comparison")

            # Score every material in one batched call: repeat the patient row once
            # per material and vary only 'material_injected'.
            comparison_input = pd.concat([input_data] * len(materials), ignore_index=True)
            comparison_input['material_injected'] = materials
            comparison_proba = model.predict_proba(comparison_input)[:, 1]

            material_df = pd.DataFrame({
                'Material': materials,
                'Success Probability': comparison_proba,
            })
            material_df['Success %'] = [f"{p:.1%}" for p in comparison_proba]
            # Stable sort keeps the materials-list order for tied probabilities.
            material_df = material_df.sort_values('Success Probability', ascending=False, kind='stable')

            col1, col2 = st.columns([1, 1])
            with col1:
                # Table view, ranked from 1
                st.markdown("**Ranked Materials**")
                display_df = material_df[['Material', 'Success %']].reset_index(drop=True)
                display_df.index += 1
                st.dataframe(display_df, use_container_width=True)

                # Highlight best option when it differs from the user's choice
                best_material = material_df.iloc[0]['Material']
                best_prob = material_df.iloc[0]['Success Probability']
                if best_material != selected_material:
                    st.info(f"💡 Consider using **{best_material}** for potentially better outcomes ({best_prob:.1%} vs {prediction_proba:.1%})")

            with col2:
                # Bar chart
                st.markdown("**Visual Comparison**")
                fig = px.bar(material_df,
                             x='Success Probability',
                             y='Material',
                             orientation='h',
                             color='Success Probability',
                             color_continuous_scale='RdYlGn',
                             range_color=[0, 1],
                             text='Success %')
                fig.update_traces(textposition='outside')
                fig.update_layout(
                    xaxis_title="Success Probability",
                    yaxis_title="",
                    showlegend=False,
                    # Extra room past 1.0 so the outside labels are not clipped
                    xaxis=dict(range=[0, 1.1]),
                    height=400,
                )
                # Vertical line at the 50% decision threshold
                fig.add_vline(x=0.5, line_dash="dash", line_color="gray",
                              annotation_text="50% threshold")
                st.plotly_chart(fig, use_container_width=True)

        # Additional insights
        st.divider()
        with st.expander("📋 Patient Summary"):
            st.write("**Baseline Characteristics:**")
            summary_cols = st.columns(3)
            with summary_cols[0]:
                st.write(f"- Age: {age} years")
                st.write(f"- Sex: {sex}")
                st.write(f"- Previous injection: {previous_injection}")
            with summary_cols[1]:
                st.write(f"- Pain score: {pain_m0}/10")
                st.write(f"- MMO: {mmo_m0} mm")
                st.write(f"- OHIP-14: {ohip_14_m0}/56")
            with summary_cols[2]:
                if location:
                    st.write(f"- Location: {location}")
                if muscle_injected:
                    st.write(f"- Muscle: {muscle_injected}")
                if adjunctive_treatment:
                    st.write(f"- Adjunctive: {adjunctive_treatment}")
                if material_in_previous_injection:
                    st.write(f"- Previous material: {material_in_previous_injection}")

    except Exception as e:
        # UI boundary: surface any model/feature mismatch to the user instead of a traceback.
        st.error(f"Error making prediction: {str(e)}")
        st.info("Please ensure the model was trained with all the necessary features.")
# Footer
# Static HTML only (no user input), so unsafe_allow_html is safe here. HTML collapses
# the newlines, so explicit " | " separators keep the three phrases distinct.
st.divider()
st.markdown("""
<div style='text-align: center; color: gray;'>
<small>
TMJ Injection Success Predictor |
Model trained on historical patient data |
Predictions are probabilistic and should be used alongside clinical judgment
</small>
</div>
""", unsafe_allow_html=True)
# Sidebar with instructions
# Blank lines inside the Markdown keep the lists, headings and "---" rules from being
# merged into neighbouring paragraphs.
with st.sidebar:
    st.header("📖 Instructions")
    st.markdown("""
1. **Enter patient baseline data** in the form
2. **Select the material** you're considering
3. **Click Predict** to see the success probability
4. **Compare materials** to find the optimal choice

---

### 🎯 Success Definition

Treatment success is typically defined as:

- Pain reduction > 2 points
- MMO increase > 5 mm
- OHIP-14 reduction > 5 points

---

### 📈 Interpretation Guide

- **70%+**: High success likelihood ✅
- **50-70%**: Moderate success ⚠️
- **<50%**: Low success likelihood ❌

---

### ⚕️ Clinical Note

These predictions are based on statistical models and should complement, not replace, clinical expertise and patient-specific considerations.
""")