Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import lightgbm as lgb | |
| import xgboost as xgb | |
| import gradio as gr | |
| import joblib | |
| import os | |
| from obesity_rp import config as cfg | |
# Process-wide caches filled in lazily by predict_obesity_risk:
#   loaded_models            - model name -> deserialized model object
#   loaded_model_columns_map - model name -> feature columns saved for that model
#   label_encoder            - encoder from the first successful artifact load
loaded_models: dict = {}
loaded_model_columns_map: dict = {}
label_encoder = None
def load_model_artifacts(model_name):
    """Load the saved model, its feature columns, and the label encoder.

    All three files are looked up inside ``cfg.MODEL_DIR``.

    Args:
        model_name: Identifier interpolated into the model and column file names.

    Returns:
        A ``(model, feature_columns, label_encoder)`` tuple, each element being
        whatever ``joblib.load`` returns for the corresponding file.

    Raises:
        FileNotFoundError: If any one of the three artifact files is missing.
    """
    artifact_dir = cfg.MODEL_DIR
    # Order matters: it is the order in which the files are checked and loaded.
    artifact_paths = (
        os.path.join(artifact_dir, f"obesity_{model_name}_model.joblib"),
        os.path.join(artifact_dir, f"{model_name}_model_columns.joblib"),
        os.path.join(artifact_dir, "label_encoder.joblib"),
    )

    # Refuse to load anything unless the complete artifact set is present.
    for path in artifact_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Model artifacts for '{model_name}' not found. Please ensure all required files exist."
            )

    model, feature_columns, encoder = (joblib.load(path) for path in artifact_paths)
    print(
        f"{model_name} Model, feature columns, and label encoder loaded for prediction."
    )
    return model, feature_columns, encoder
def _build_feature_frame(columns, numeric_values, categorical_values):
    """Return a one-row DataFrame whose columns are exactly ``columns``, in order."""
    # Assemble the row as a plain dict and build the frame once. Writing floats
    # cell by cell into an all-integer frame relies on implicit dtype upcasting,
    # which pandas deprecated in 2.1 (PDEP-6).
    row = dict.fromkeys(columns, 0)

    for name, value in numeric_values.items():
        if name in row:
            row[name] = value

    # One-hot columns are named "<feature>_<value>"; a combination with no
    # matching column is left at 0.
    for prefix, value in categorical_values.items():
        dummy_column = f"{prefix}_{value}"
        if dummy_column in row:
            row[dummy_column] = 1

    return pd.DataFrame([row], columns=columns)


def predict_obesity_risk(
    model_choice,
    Gender,
    Age,
    Height,
    Weight,
    family_history_with_overweight,
    FAVC,
    FCVC,
    NCP,
    CAEC,
    SMOKE,
    CH2O,
    SCC,
    FAF,
    TUE,
    CALC,
    MTRANS,
):
    """Predict the obesity risk category with the chosen model.

    The model, its feature columns, and the label encoder are loaded on first
    use of a given ``model_choice`` and cached in the module-level
    ``loaded_models`` / ``loaded_model_columns_map`` / ``label_encoder``.

    Args:
        model_choice: Name of the model to use; passed to
            ``load_model_artifacts`` when not cached yet.
        Age, Height, Weight, FCVC, NCP, CH2O, FAF, TUE: Values copied into the
            same-named feature columns when those columns exist.
        Gender, family_history_with_overweight, FAVC, CAEC, SMOKE, SCC, CALC,
            MTRANS: Categorical values; each sets the ``<name>_<value>`` column
            to 1 when such a column exists.

    Returns:
        A multi-line text report with the predicted label and one probability
        line per encoder class, or an ``"Error: ..."`` string when the model
        artifacts cannot be found.
    """
    global label_encoder

    if model_choice in loaded_models:
        model = loaded_models[model_choice]
        columns = loaded_model_columns_map[model_choice]
        le = label_encoder
    else:
        # Keep the try body to the single call that can raise.
        try:
            model, columns, le = load_model_artifacts(model_choice)
        except FileNotFoundError as e:
            return f"Error: {e}. Model '{model_choice}' not found. Please train the model first."
        loaded_models[model_choice] = model
        loaded_model_columns_map[model_choice] = columns
        if label_encoder is None:
            label_encoder = le

    numeric_inputs = {
        "Age": Age,
        "Height": Height,
        "Weight": Weight,
        "FCVC": FCVC,
        "NCP": NCP,
        "CH2O": CH2O,
        "FAF": FAF,
        "TUE": TUE,
    }
    categorical_inputs = {
        "Gender": Gender,
        "family_history_with_overweight": family_history_with_overweight,
        "FAVC": FAVC,
        "CAEC": CAEC,
        "SMOKE": SMOKE,
        "SCC": SCC,
        "CALC": CALC,
        "MTRANS": MTRANS,
    }
    input_df = _build_feature_frame(columns, numeric_inputs, categorical_inputs)

    prediction_proba = model.predict_proba(input_df)[0]
    prediction_encoded = model.predict(input_df)[0]
    prediction_label = le.inverse_transform([prediction_encoded])[0]

    # Probabilities are paired with the encoder's classes by position.
    report_lines = [
        f"Using {model_choice} Model:",
        f"Prediction: {prediction_label}",
        "",
        "--- Prediction Probabilities ---",
    ]
    report_lines.extend(
        f"{class_name}: {probability * 100:.2f}%"
        for class_name, probability in zip(le.classes_, prediction_proba)
    )
    return "\n".join(report_lines) + "\n"
def launch_gradio_app(share=False):
    """Build the prediction form and start the Gradio interface.

    Args:
        share: Passed through unchanged to ``Interface.launch(share=...)``.
    """
    print("\n--- Starting Gradio App ---")

    # Input widgets, listed in the positional order predict_obesity_risk expects.
    form_inputs = [
        gr.Dropdown(
            choices=cfg.MODEL_CHOICES, label="Select Model", value=cfg.RANDOM_FOREST
        ),
        gr.Dropdown(choices=["Female", "Male"], label="Gender"),
        gr.Slider(minimum=1, maximum=100, step=1, label="Age"),
        gr.Slider(minimum=1.0, maximum=2.2, step=0.01, label="Height (m)"),
        gr.Slider(minimum=30.0, maximum=200.0, step=0.1, label="Weight (kg)"),
        gr.Radio(choices=["yes", "no"], label="Family History with Overweight"),
        gr.Radio(
            choices=["yes", "no"],
            label="Frequent consumption of high caloric food (FAVC)",
        ),
        gr.Slider(
            minimum=1,
            maximum=3,
            step=1,
            label="Frequency of consumption of vegetables (FCVC)",
        ),
        gr.Slider(minimum=1, maximum=4, step=1, label="Number of main meals (NCP)"),
        gr.Dropdown(
            choices=["no", "Sometimes", "Frequently", "Always"],
            label="Consumption of food between meals (CAEC)",
        ),
        gr.Radio(choices=["yes", "no"], label="SMOKE"),
        gr.Slider(
            minimum=1, maximum=3, step=1, label="Consumption of water daily (CH2O)"
        ),
        gr.Radio(choices=["yes", "no"], label="Calories consumption monitoring (SCC)"),
        gr.Slider(
            minimum=0, maximum=3, step=1, label="Physical activity frequency (FAF)"
        ),
        gr.Slider(
            minimum=0, maximum=2, step=1, label="Time using technology devices (TUE)"
        ),
        gr.Dropdown(
            choices=["no", "Sometimes", "Frequently", "Always"],
            label="Consumption of alcohol (CALC)",
        ),
        gr.Dropdown(
            choices=[
                "Automobile",
                "Motorbike",
                "Bike",
                "Public_Transportation",
                "Walking",
            ],
            label="Transportation used (MTRANS)",
        ),
    ]
    result_box = gr.Textbox(label="Obesity Risk Prediction Result", lines=10)

    # One sample row per model, in the same field order as form_inputs.
    sample_rows = [
        [
            cfg.RANDOM_FOREST, "Male", 25, 1.8, 85, "yes", "yes", 2, 3,
            "Sometimes", "no", 2, "no", 1, 1, "Frequently", "Public_Transportation",
        ],
        [
            cfg.LIGHTGBM, "Female", 30, 1.65, 70, "yes", "yes", 3, 3,
            "Frequently", "no", 3, "yes", 2, 0, "Sometimes", "Automobile",
        ],
        [
            cfg.XGBOOST, "Female", 21, 1.52, 56, "yes", "no", 3, 3,
            "Sometimes", "yes", 3, "yes", 3, 0, "Sometimes", "Public_Transportation",
        ],
    ]

    app = gr.Interface(
        fn=predict_obesity_risk,
        inputs=form_inputs,
        outputs=result_box,
        title="Obesity Risk Prediction (Multi-Model)",
        description="Select a machine learning model and enter patient details to predict the obesity risk category.",
        examples=sample_rows,
    )
    app.launch(share=share)
    # NOTE(review): if launch() blocks until the server stops, this line is only
    # printed at shutdown - confirm against the Gradio version in use.
    print("--- Gradio App Launched ---")
# Script entry point: when run directly (not imported), start the app with
# share disabled.
if __name__ == "__main__":
    launch_gradio_app(share=False)