#!/usr/bin/env python3
"""
Derm Foundation Neural Network Classifier Training Script - Fixed Version

PURPOSE:
This script trains a multi-output neural network to predict dermatological
conditions and their associated metadata from pre-computed embeddings. It
addresses the challenging problem of multi-label medical diagnosis where:
1. Multiple conditions can co-exist (multi-label classification)
2. Each diagnosis has an associated confidence level (regression)
3. Each diagnosis has a weight indicating relative importance (regression)

WHY NEURAL NETWORKS FOR THIS TASK:
Neural networks are the optimal choice for this problem for several reasons:
1. **Non-linear Relationship Learning**: The relationship between image
   embeddings and skin conditions is highly non-linear. Neural networks excel
   at learning complex, non-linear mappings that simpler models (like logistic
   regression) cannot capture.
2. **Multi-task Learning**: This problem requires predicting three related but
   distinct outputs (conditions, confidence, weights). Neural networks can
   share learned representations across these tasks through shared layers,
   improving generalization and efficiency.
3. **High-dimensional Input**: Embeddings are typically 512-1024 dimensional
   vectors. Neural networks are designed to handle high-dimensional inputs
   effectively through dimensionality reduction in hidden layers.
4. **Multi-label Classification**: Medical diagnosis often involves multiple
   co-existing conditions. Neural networks with sigmoid activation can model
   the independent probability of each condition, unlike single-label methods.
5. **Flexibility**: The architecture can be customized with task-specific
   heads (branches) for different prediction types, allowing specialized
   processing for classification vs regression outputs.

WHY HAMMING LOSS IS VALID:
Hamming loss is an appropriate metric for multi-label classification because:
1. **Accounts for Partial Correctness**: Unlike exact match accuracy, Hamming
   loss gives credit for partially correct predictions. Predicting 3 out of 4
   conditions correctly is better than 0 out of 4.
2. **Label-wise Evaluation**: It measures the fraction of incorrectly predicted
   labels, treating each label independently - appropriate when conditions can
   co-occur independently.
3. **Bounded and Interpretable**: Ranges from 0 (perfect) to 1 (completely
   wrong). A Hamming loss of 0.1 means 10% of label predictions were incorrect.
4. **Balanced for Sparse Labels**: In medical diagnosis, most samples have few
   positive labels (sparse multi-label). Hamming loss naturally handles this
   imbalance by computing the fraction across all labels.
5. **Clinically Relevant**: In medical applications, both false positives and
   false negatives matter. Hamming loss penalizes both equally, unlike metrics
   that focus on one type of error.

MATHEMATICAL JUSTIFICATION:
For a sample with true labels y and predicted labels ŷ:
    Hamming Loss = (1/n_labels) × Σ(y_i XOR ŷ_i)
This averages the disagreement across all possible labels, making it suitable
for scenarios where:
- The label space is large (many possible conditions)
- Label correlations exist but aren't perfectly predictable
- Clinical accuracy matters at the individual label level
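
A quick numeric illustration (hypothetical labels, not from the dataset): with
4 possible conditions, a true vector y = [1, 0, 1, 0] and a prediction
ŷ = [1, 0, 0, 0] disagree on exactly one label, so
    Hamming Loss = (1/4) × 1 = 0.25
while exact match accuracy for this sample is 0, even though 3 of 4 labels
are correct.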

FIXES APPLIED IN THIS VERSION:
- Changed confidence activation from ReLU to softplus (prevents zero outputs)
- Improved confidence scaler fitting (uses only non-zero values)
- Increased confidence loss weight (1.5x for better learning signal)
- Enhanced data validation and preprocessing
- Better handling of sparse confidence/weight matrices

REQUIREMENTS:
- pandas
- numpy
- tensorflow>=2.13.0
- scikit-learn
- matplotlib
- pickle, os, ast (standard library)

INPUT:
- derm_foundation_embeddings.npz: Pre-computed embeddings from images
- dataset_scin_labels.csv: Ground truth labels with conditions, confidences, weights

OUTPUT:
- Trained neural network model (.keras file)
- Preprocessing components (scalers, label encoder) in .pkl file
- Training history plots showing convergence
- Evaluation metrics on test set
"""
import ast
import os
import pickle
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import hamming_loss, mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')
| """ | |
| Main class implementing the multi-output neural network classifier. | |
| ARCHITECTURE OVERVIEW: | |
| 1. **Shared Feature Extraction**: 3 dense layers (512→256→128) with batch | |
| normalization and dropout. These layers learn a shared representation | |
| useful for all prediction tasks. | |
| 2. **Task-Specific Heads**: Three separate output branches: | |
| - Condition classification: Sigmoid activation for multi-label prediction | |
| - Confidence regression: Softplus activation for positive continuous values | |
| - Weight regression: Sigmoid activation for [0,1] bounded values | |
| WHY MULTI-TASK LEARNING: | |
| - Conditions, confidence, and weights are related but distinct | |
| - Sharing early layers allows the model to learn features useful for all tasks | |
| - Task-specific heads allow specialized processing for each output type | |
| - Improves generalization by preventing overfitting to any single task | |
| TRAINING STRATEGY: | |
| - Multi-task loss: Weighted combination of classification and regression losses | |
| - Early stopping: Prevents overfitting by monitoring validation loss | |
| - Learning rate reduction: Adapts learning rate when progress plateaus | |
| - Batch normalization: Stabilizes training and allows higher learning rates | |
| """ | |
| class DermFoundationNeuralNetwork: | |
| """ | |
| Initialize the classifier with preprocessing components. | |
| PREPROCESSING COMPONENTS: | |
| - mlb (MultiLabelBinarizer): Converts condition names to binary vectors | |
| Example: ['Eczema', 'Psoriasis'] → [0,1,0,1,0,...,0] | |
| - embedding_scaler (StandardScaler): Normalizes embeddings to mean=0, std=1 | |
| Why: Neural networks train faster with normalized inputs | |
| - confidence_scaler (StandardScaler): Normalizes confidence values | |
| Why: Brings continuous values to similar scale as other outputs | |
| - weighted_scaler (StandardScaler): Normalizes weight values | |
| Why: Ensures balanced gradient contributions during training | |
| DESIGN DECISION: | |
| Separate scalers for each output type allow independent normalization, | |
| which is crucial when outputs have different scales and distributions. | |
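
    A minimal sketch of the MultiLabelBinarizer behaviour relied on here
    (hypothetical two-class example):

    ```python
    mlb = MultiLabelBinarizer()
    mlb.fit_transform([['Eczema'], ['Eczema', 'Psoriasis']])
    # -> array([[1, 0],
    #           [1, 1]])   with mlb.classes_ == ['Eczema', 'Psoriasis']
    ```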
| """ | |
| def __init__(self): | |
| self.model = None | |
| self.mlb = MultiLabelBinarizer() | |
| self.embedding_scaler = StandardScaler() | |
| self.confidence_scaler = StandardScaler() | |
| self.weighted_scaler = StandardScaler() | |
| self.history = None | |
| """ | |
| Load pre-computed Derm Foundation embeddings from NPZ file. | |
| WHAT ARE EMBEDDINGS: | |
| Embeddings are dense vector representations of images extracted from a | |
| pre-trained vision model (Derm Foundation model). They capture high-level | |
| visual features learned from large-scale dermatology image datasets. | |
| WHY USE PRE-COMPUTED EMBEDDINGS: | |
| 1. **Efficiency**: Computing embeddings is expensive. Pre-computing them | |
| allows rapid experimentation with different classifier architectures. | |
| 2. **Transfer Learning**: Derm Foundation was trained on massive dermatology | |
| datasets. Its embeddings encode domain-specific visual patterns. | |
| 3. **Separation of Concerns**: Image processing and classification are | |
| separated, allowing independent optimization of each component. | |
| FORMAT: | |
| NPZ file contains a dictionary where: | |
| - Keys: case_id (string identifiers) | |
| - Values: embedding vectors (typically 512 or 768 dimensions) | |
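
    A sketch of how such a file could be produced (hypothetical writer side;
    the actual embedding pipeline is outside this script):

    ```python
    np.savez_compressed('derm_foundation_embeddings.npz',
                        **{case_id: vector for case_id, vector in embedding_pairs})
    ```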
| """ | |
| def load_embeddings(self, npz_file_path): | |
| """Load embeddings from NPZ file""" | |
| print(f"Loading embeddings from {npz_file_path}...") | |
| if not os.path.exists(npz_file_path): | |
| print(f"ERROR: Embeddings file not found: {npz_file_path}") | |
| return None | |
| embeddings_data = {} | |
| with open(npz_file_path, 'rb') as f: | |
| npz_file = np.load(f, allow_pickle=True) | |
| for key in npz_file.files: | |
| embeddings_data[key] = npz_file[key] | |
| print(f"Loaded {len(embeddings_data)} embeddings") | |
| # Print info about first embedding for debugging | |
| first_key = list(embeddings_data.keys())[0] | |
| first_embedding = embeddings_data[first_key] | |
| print(f"Embedding shape: {first_embedding.shape}") | |
| return embeddings_data | |
| """ | |
| Load ground truth labels from CSV file. | |
| REQUIRED COLUMNS: | |
| 1. case_id: Unique identifier matching embedding keys | |
| 2. dermatologist_skin_condition_on_label_name: List of condition names | |
| 3. dermatologist_skin_condition_confidence: Confidence scores (typically 1-5) | |
| 4. weighted_skin_condition_label: Importance weights (0-1 range) | |
| DATA TYPES: | |
| - case_id must be string to match embedding keys | |
| - Lists stored as strings (e.g., "['Eczema', 'Psoriasis']") are evaluated | |
| - Handles various formats: lists, dicts, single values | |
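
    For example, parsing a stringified list:

    ```python
    ast.literal_eval("['Eczema', 'Psoriasis']")  # -> ['Eczema', 'Psoriasis']
    ```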
| """ | |
| def load_dataset(self, csv_path): | |
| """Load dataset from CSV file""" | |
| print(f"Loading dataset from {csv_path}...") | |
| if not os.path.exists(csv_path): | |
| print(f"ERROR: Dataset file not found: {csv_path}") | |
| return None | |
| try: | |
| df = pd.read_csv(csv_path, dtype={'case_id': str}) | |
| df['case_id'] = df['case_id'].astype(str) | |
| print(f"Loaded dataset: {len(df)} samples") | |
| # Verify required columns | |
| required_columns = [ | |
| 'case_id', | |
| 'dermatologist_skin_condition_on_label_name', | |
| 'dermatologist_skin_condition_confidence', | |
| 'weighted_skin_condition_label' | |
| ] | |
| missing_columns = [col for col in required_columns if col not in df.columns] | |
| if missing_columns: | |
| print(f"ERROR: Missing required columns: {missing_columns}") | |
| return None | |
| return df | |
| except Exception as e: | |
| print(f"Error loading dataset: {e}") | |
| return None | |
| """ | |
| Prepare training data with comprehensive validation and preprocessing. | |
| COMPLEXITY HANDLING: | |
| This method handles several challenging data characteristics: | |
| 1. **SPARSE MULTI-LABEL MATRICES**: Most samples have few positive labels | |
| Solution: Track and report sparsity statistics for validation | |
| 2. **VARIABLE-LENGTH LISTS**: Different samples have different numbers of | |
| conditions, confidences, and weights | |
| Solution: Parse and align lists carefully, use mean values for mismatches | |
| 3. **RARE CONDITIONS**: Some conditions appear in very few samples | |
| Solution: Filter to top N conditions and minimum sample requirements | |
| 4. **ZERO VALUES**: Confidence/weight matrices are mostly zeros (sparse) | |
| Solution: Track zero vs non-zero ratios, fit scalers only on non-zeros | |
| FILTERING STRATEGY: | |
| - min_condition_samples: Removes rare conditions with insufficient data | |
| - max_conditions: Limits to most frequent conditions to prevent overfitting | |
| - Both filters ensure model focuses on well-represented, learnable patterns | |
| WHY FILTER CONDITIONS: | |
| 1. **Statistical Validity**: Need sufficient examples to learn patterns | |
| 2. **Generalization**: Rare conditions lead to overfitting | |
| 3. **Computational Efficiency**: Fewer output nodes = faster training | |
| 4. **Clinical Relevance**: Common conditions are higher priority | |
| MULTI-LABEL MATRIX STRUCTURE: | |
| Shape: (n_samples, n_conditions) | |
| - Rows: Individual patient cases | |
| - Columns: Binary indicators for each condition (1=present, 0=absent) | |
| - Multiple 1s per row: Multi-label (multiple conditions co-exist) | |
| CONFIDENCE/WEIGHT MATRICES: | |
| Shape: (n_samples, n_conditions) | |
| - Values at (i,j): Confidence/weight for condition j in sample i | |
| - Zero when condition j not present in sample i (sparse structure) | |
| - Non-zero only where corresponding multi-label entry is 1 | |
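
    A tiny illustration of the aligned matrices (hypothetical values, classes
    [A, B, C], two samples):

        y_conditions  = [[1, 0, 1],    y_confidences = [[4.0, 0.0, 3.0],
                         [0, 1, 0]]                     [0.0, 5.0, 0.0]]

    Non-zero confidence/weight entries occur only where y_conditions is 1.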

    DATA VALIDATION:
    Extensive logging of:
    - Sample counts (processed vs skipped)
    - Value ranges (min/max/mean)
    - Sparsity statistics (% non-zero)
    - Top conditions by frequency

    This validation is crucial for:
    - Detecting data quality issues early
    - Understanding model input characteristics
    - Debugging training problems
    """
    def prepare_training_data(self, df, embeddings, min_condition_samples=5, max_conditions=30):
        """Prepare training data with improved confidence and weight handling"""
        print("Preparing training data with enhanced validation...")
        X = []                        # Embeddings
        condition_labels = []         # For multi-label classification
        individual_confidences = []   # Individual confidence per condition
        individual_weights = []       # Individual weight per condition
        skipped_count = 0
        processed_count = 0
        confidence_stats = []         # Track confidence values for validation
        weight_stats = []             # Track weight values for validation
        for idx, row in df.iterrows():
            try:
                case_id = row['case_id']
                if case_id not in embeddings:
                    skipped_count += 1
                    continue
                # Parse the label data
                try:
                    # Parse condition names (literal_eval is a safe replacement for eval)
                    if isinstance(row['dermatologist_skin_condition_on_label_name'], str):
                        condition_names = ast.literal_eval(row['dermatologist_skin_condition_on_label_name'])
                    else:
                        condition_names = row['dermatologist_skin_condition_on_label_name']
                    # Ensure condition_names is a list
                    if not isinstance(condition_names, list):
                        condition_names = [condition_names] if condition_names else []
                    # Parse confidence scores
                    if isinstance(row['dermatologist_skin_condition_confidence'], str):
                        confidences = ast.literal_eval(row['dermatologist_skin_condition_confidence'])
                    else:
                        confidences = row['dermatologist_skin_condition_confidence']
                    # Ensure confidences is a list and matches conditions
                    if not isinstance(confidences, list):
                        confidences = [confidences] if confidences is not None else []
                    # Match confidence length to conditions
                    if len(confidences) != len(condition_names):
                        if len(confidences) == 1:
                            confidences = confidences * len(condition_names)
                        else:
                            print(f"Warning: Confidence length mismatch for {case_id}, using mean")
                            mean_conf = np.mean(confidences) if confidences else 3.0
                            confidences = [mean_conf] * len(condition_names)
                    # Parse weighted labels
                    if isinstance(row['weighted_skin_condition_label'], str):
                        weighted_labels = ast.literal_eval(row['weighted_skin_condition_label'])
                    else:
                        weighted_labels = row['weighted_skin_condition_label']
                    # Handle different weight formats
                    if isinstance(weighted_labels, dict):
                        # Convert dict to list matching condition order
                        weights = []
                        for condition in condition_names:
                            weights.append(weighted_labels.get(condition, 0.0))
                    elif isinstance(weighted_labels, list):
                        weights = weighted_labels
                        if len(weights) != len(condition_names):
                            if len(weights) == 1:
                                weights = weights * len(condition_names)
                            else:
                                mean_weight = np.mean(weights) if weights else 0.3
                                weights = [mean_weight] * len(condition_names)
                    else:
                        # Single value
                        weights = [weighted_labels] * len(condition_names) if weighted_labels else [0.3] * len(condition_names)
                except Exception as e:
                    print(f"Error parsing data for {case_id}: {e}")
                    skipped_count += 1
                    continue
                # Validate data ranges
                try:
                    confidences = [max(0.0, float(c)) for c in confidences]    # Ensure non-negative
                    weights = [max(0.0, min(1.0, float(w))) for w in weights]  # Clamp to [0,1]
                except (TypeError, ValueError):
                    print(f"Error converting values for {case_id}, skipping")
                    skipped_count += 1
                    continue
                # Add to training data
                X.append(embeddings[case_id])
                condition_labels.append(condition_names)
                # Store individual confidences and weights
                individual_confidences.append({
                    'conditions': condition_names,
                    'confidences': confidences
                })
                individual_weights.append({
                    'conditions': condition_names,
                    'weights': weights
                })
                # Track statistics
                confidence_stats.extend(confidences)
                weight_stats.extend(weights)
                processed_count += 1
            except Exception as e:
                print(f"Error processing row {idx}: {e}")
                skipped_count += 1
                continue
| print(f"Training data prepared: {processed_count} samples, {skipped_count} skipped") | |
| if len(X) == 0: | |
| print("ERROR: No training samples found!") | |
| return None, None, None, None | |
| # Print data statistics | |
| print(f"\nData validation:") | |
| print(f" Confidence values - min: {min(confidence_stats):.3f}, max: {max(confidence_stats):.3f}, mean: {np.mean(confidence_stats):.3f}") | |
| print(f" Weight values - min: {min(weight_stats):.3f}, max: {max(weight_stats):.3f}, mean: {np.mean(weight_stats):.3f}") | |
| print(f" Non-zero confidences: {sum(1 for c in confidence_stats if c > 0.001)}/{len(confidence_stats)} ({100*sum(1 for c in confidence_stats if c > 0.001)/len(confidence_stats):.1f}%)") | |
| print(f" Non-zero weights: {sum(1 for w in weight_stats if w > 0.001)}/{len(weight_stats)} ({100*sum(1 for w in weight_stats if w > 0.001)/len(weight_stats):.1f}%)") | |
| # Convert to numpy arrays | |
| X = np.array(X) | |
| # Prepare condition labels - focus on top conditions only | |
| y_conditions_raw = self.mlb.fit_transform(condition_labels) | |
| condition_counts = y_conditions_raw.sum(axis=0) | |
| # Get top conditions by frequency | |
| sorted_indices = np.argsort(condition_counts)[::-1] | |
| top_condition_indices = sorted_indices[:max_conditions] | |
| # Also ensure minimum samples | |
| frequent_conditions = condition_counts >= min_condition_samples | |
| final_indices = np.intersect1d(top_condition_indices, np.where(frequent_conditions)[0]) | |
| print(f"Total condition classes: {len(self.mlb.classes_)}") | |
| print(f"Top {max_conditions} most frequent conditions selected") | |
| print(f"Conditions with at least {min_condition_samples} examples: {frequent_conditions.sum()}") | |
| # Keep only selected conditions | |
| selected_classes = self.mlb.classes_[final_indices] | |
| y_conditions = y_conditions_raw[:, final_indices] | |
| # Update MultiLabelBinarizer | |
| self.mlb = MultiLabelBinarizer() | |
| self.mlb.classes_ = selected_classes | |
| print(f"Final condition classes: {len(selected_classes)}") | |
| print(f"Multi-label matrix shape: {y_conditions.shape}") | |
| # Create individual confidence and weight matrices | |
| y_confidences = np.zeros((len(X), len(selected_classes))) | |
| y_weights = np.zeros((len(X), len(selected_classes))) | |
| for i, (conf_data, weight_data) in enumerate(zip(individual_confidences, individual_weights)): | |
| # Map confidences to selected conditions | |
| for condition, confidence in zip(conf_data['conditions'], conf_data['confidences']): | |
| if condition in selected_classes: | |
| condition_idx = np.where(selected_classes == condition)[0] | |
| if len(condition_idx) > 0: | |
| y_confidences[i, condition_idx[0]] = confidence | |
| # Map weights to selected conditions | |
| for condition, weight in zip(weight_data['conditions'], weight_data['weights']): | |
| if condition in selected_classes: | |
| condition_idx = np.where(selected_classes == condition)[0] | |
| if len(condition_idx) > 0: | |
| y_weights[i, condition_idx[0]] = weight | |
| # Print matrix statistics | |
| nonzero_conf = (y_confidences > 0.001).sum() | |
| nonzero_weight = (y_weights > 0.001).sum() | |
| total_elements = y_confidences.size | |
| print(f"\nMatrix statistics:") | |
| print(f" Confidence matrix - non-zero: {nonzero_conf}/{total_elements} ({100*nonzero_conf/total_elements:.1f}%)") | |
| print(f" Weight matrix - non-zero: {nonzero_weight}/{total_elements} ({100*nonzero_weight/total_elements:.1f}%)") | |
| print(f" Confidence range: {y_confidences[y_confidences > 0].min():.3f} - {y_confidences[y_confidences > 0].max():.3f}") | |
| print(f" Weight range: {y_weights[y_weights > 0].min():.3f} - {y_weights[y_weights > 0].max():.3f}") | |
| # Print top conditions | |
| condition_counts_filtered = y_conditions.sum(axis=0) | |
| print("\nTop conditions selected:") | |
| for i, (condition, count) in enumerate(zip(selected_classes, condition_counts_filtered)): | |
| print(f" {i+1:2d}. {condition}: {count} samples") | |
| return X, y_conditions, y_confidences, y_weights | |
| """ | |
| Build multi-output neural network architecture. | |
| ARCHITECTURE RATIONALE: | |
| **SHARED LAYERS (512→256→128)**: | |
| - Purpose: Learn general features useful for all prediction tasks | |
| - Size progression: Gradual dimensionality reduction (embeddings→features) | |
| - Batch Normalization: Stabilizes training, allows higher learning rates | |
| - Dropout (0.3, 0.3, 0.2): Prevents overfitting, forces robust features | |
| Why this depth: | |
| - 3 layers balances capacity (can learn complex patterns) vs simplicity | |
| - Too shallow: Can't learn complex patterns | |
| - Too deep: Overfits, slower training, harder to optimize | |
| **TASK-SPECIFIC BRANCHES**: | |
| Each branch has 2 layers (64→output) for specialized processing: | |
| 1. **CONDITION CLASSIFICATION BRANCH**: | |
| - Activation: Sigmoid (outputs independent probabilities per condition) | |
| - Why sigmoid: Allows multiple conditions to be predicted simultaneously | |
| - Loss: Binary cross-entropy (standard for multi-label classification) | |
| 2. **CONFIDENCE REGRESSION BRANCH**: | |
| - Activation: Softplus (ensures positive outputs, smooth gradients) | |
| - Why softplus not ReLU: ReLU outputs exactly zero for negative inputs, | |
| causing gradient issues. Softplus outputs small positive values instead. | |
| - Formula: softplus(x) = log(1 + exp(x)) | |
| - Loss: MSE (Mean Squared Error for continuous values) | |
| - Loss weight: 1.5x (increased to prioritize confidence learning) | |
| 3. **WEIGHT REGRESSION BRANCH**: | |
| - Activation: Sigmoid (ensures [0,1] bounded output) | |
| - Why sigmoid: Weights represent proportions/probabilities, must be 0-1 | |
| - Loss: MSE (Mean Squared Error for continuous values) | |
| - Loss weight: 1.2x (slightly increased priority) | |
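
    A quick numeric comparison of the two activations (values rounded; numpy
    used purely for illustration):

    ```python
    x = np.array([-4.0, 0.0, 2.0])
    np.log1p(np.exp(x))  # softplus -> [0.018, 0.693, 2.127], never exactly 0
    np.maximum(0.0, x)   # relu     -> [0.000, 0.000, 2.000], zero gradient for x < 0
    ```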

    **LOSS WEIGHTING**:
    Different loss scales require weighting for balanced training:
    - Condition loss: Binary cross-entropy, typically ~0.3-0.7
    - Confidence loss: MSE on scaled values, typically ~0.01-0.1
    - Weight loss: MSE on scaled values, typically ~0.01-0.1

    Weights (1.0, 1.5, 1.2) ensure:
    - All tasks contribute meaningfully to total loss
    - Confidence gets extra emphasis (was underfitting in previous versions)
    - Gradient magnitudes are balanced across tasks

    **WHY ADAM OPTIMIZER**:
    - Adaptive learning rates per parameter (handles different loss scales)
    - Momentum for faster convergence
    - Robust to hyperparameter choices
    - Industry standard for multi-task learning

    **MODEL COMPILATION**:
    The model uses a dictionary output format allowing:
    - Clear separation of different predictions
    - Easy access to specific outputs during inference
    - Flexible loss and metric assignment per output
    """
    def build_model(self, input_dim, num_conditions, learning_rate=0.001):
        """Build neural network with improved confidence and weight outputs"""
        print("Building improved neural network model...")
        # Input layer
        inputs = keras.Input(shape=(input_dim,), name='embeddings')
        # Shared feature extraction layers
        x = layers.Dense(512, activation='relu', name='dense1')(inputs)  # Increased capacity
        x = layers.BatchNormalization(name='bn1')(x)
        x = layers.Dropout(0.3, name='dropout1')(x)
        x = layers.Dense(256, activation='relu', name='dense2')(x)
        x = layers.BatchNormalization(name='bn2')(x)
        x = layers.Dropout(0.3, name='dropout2')(x)
        x = layers.Dense(128, activation='relu', name='dense3')(x)
        x = layers.BatchNormalization(name='bn3')(x)
        x = layers.Dropout(0.2, name='dropout3')(x)
        # Multi-label condition classification head
        condition_branch = layers.Dense(64, activation='relu', name='condition_dense')(x)
        condition_branch = layers.Dropout(0.2, name='condition_dropout')(condition_branch)
        condition_output = layers.Dense(num_conditions, activation='sigmoid',
                                        name='conditions')(condition_branch)
        # Individual confidence regression head - FIXED ACTIVATION
        confidence_branch = layers.Dense(64, activation='relu', name='confidence_dense1')(x)
        confidence_branch = layers.Dropout(0.2, name='confidence_dropout1')(confidence_branch)
        confidence_branch = layers.Dense(32, activation='relu', name='confidence_dense2')(confidence_branch)
        confidence_branch = layers.Dropout(0.1, name='confidence_dropout2')(confidence_branch)
        # Changed from ReLU to softplus - ensures positive, non-zero outputs
        confidence_output = layers.Dense(num_conditions, activation='softplus',
                                         name='individual_confidences')(confidence_branch)
        # Individual weight regression head
        weighted_branch = layers.Dense(64, activation='relu', name='weighted_dense1')(x)
        weighted_branch = layers.Dropout(0.2, name='weighted_dropout1')(weighted_branch)
        weighted_branch = layers.Dense(32, activation='relu', name='weighted_dense2')(weighted_branch)
        weighted_branch = layers.Dropout(0.1, name='weighted_dropout2')(weighted_branch)
        # Use sigmoid to ensure 0-1 range
        weighted_output = layers.Dense(num_conditions, activation='sigmoid',
                                       name='individual_weights')(weighted_branch)
        # Create model
        model = keras.Model(
            inputs=inputs,
            outputs={
                'conditions': condition_output,
                'individual_confidences': confidence_output,
                'individual_weights': weighted_output
            }
        )
        # Compile model with improved loss weights
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss={
                'conditions': 'binary_crossentropy',
                'individual_confidences': 'mse',
                'individual_weights': 'mse'
            },
            loss_weights={
                'conditions': 1.0,
                'individual_confidences': 1.5,  # Increased weight for confidence
                'individual_weights': 1.2       # Increased weight for weights
            },
            metrics={
                'conditions': ['accuracy'],
                'individual_confidences': ['mae'],
                'individual_weights': ['mae']
            }
        )
        return model
| """ | |
| Main training orchestration method with improved confidence handling. | |
| TRAINING PIPELINE: | |
| 1. Load data (embeddings + labels) | |
| 2. Prepare training matrices (parse, filter, validate) | |
| 3. Scale features and outputs | |
| 4. Split train/validation sets | |
| 5. Build neural network architecture | |
| 6. Train with callbacks (early stopping, LR reduction, checkpointing) | |
| 7. Evaluate performance | |
| 8. Save trained model | |
| IMPROVED SCALING STRATEGY (KEY FIX): | |
| Problem: Previous version scaled all values including zeros | |
| Solution: Fit scalers only on non-zero values | |
| Why this matters: | |
| - Sparse matrices have many structural zeros (condition not present) | |
| - Including zeros in scaler fitting shifts mean artificially low | |
| - Model learns to predict near-zero for everything | |
| - Confidence predictions collapsed to ~0 (major bug) | |
| New approach: | |
| ```python | |
| conf_nonzero = y_confidences[y_confidences > 0.001] | |
| self.confidence_scaler.fit(conf_nonzero) | |
| Only non-zero values determine scale | |
| Model learns actual confidence distribution (1-5 range) | |
| Predictions are meaningful positive values | |
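
    To see the effect numerically (hypothetical sparse column): fitting a
    StandardScaler on [0, 0, 0, 4, 5] yields a mean of 1.8, dragged far below
    the true value range, whereas fitting on the non-zero entries [4, 5]
    yields a mean of 4.5 and a scale that reflects the actual 1-5
    distribution.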

    FALLBACK HANDLING:
    If too few non-zero values exist:
    - Use sensible dummy values (1-5 for confidence, 0-1 for weights)
    - Prevents scaler failure on edge cases
    - Ensures training can proceed

    TRAIN/TEST SPLIT:
    - 80/20 split is standard for medical ML
    - Stratification not used (multi-label makes it complex)
    - Random state fixed for reproducibility

    CALLBACKS:
    Early Stopping (patience=12):
    - Monitors validation loss
    - Stops if no improvement for 12 epochs
    - Restores best weights (not final weights)
    - Why: Prevents overfitting to the training set

    ReduceLROnPlateau (factor=0.5, patience=5):
    - Monitors confidence loss specifically (was problematic)
    - Reduces LR by 50% if loss plateaus
    - Allows fine-tuning in late training
    - Min LR: 1e-7 prevents excessive reduction

    ModelCheckpoint:
    - Saves best model weights during training
    - Insurance against training divergence
    - Cleaned up after successful training

    TRAINING DURATION:
    - 60 epochs maximum (increased from 50)
    - Early stopping typically triggers around epoch 30-40
    - Batch size 32 balances memory vs convergence speed

    HYPERPARAMETERS:
    - Learning rate: 0.001 (standard for Adam)
    - Batch size: 32 (good for datasets of this size)
    - Test split: 0.2 (20% validation, standard practice)

    POST-TRAINING:
    - Comprehensive evaluation on test set
    - Detailed metrics for all three outputs
    - Analysis of confidence prediction quality
    """
    def train(self, npz_file_path="derm_foundation_embeddings.npz",
              csv_file_path="dataset_scin_labels.csv",
              test_size=0.2, random_state=42, epochs=50, batch_size=32,
              learning_rate=0.001):
        """Train the neural network with improved confidence handling"""
        # Load data
        embeddings = self.load_embeddings(npz_file_path)
        if embeddings is None:
            return False
        df = self.load_dataset(csv_file_path)
        if df is None:
            return False
        # Prepare training data
        X, y_conditions, y_confidences, y_weights = self.prepare_training_data(df, embeddings)
        if X is None:
            return False
        # IMPROVED SCALING - fit only on non-zero values
        print("\nFitting scalers...")
        X_scaled = self.embedding_scaler.fit_transform(X)
        # Fit confidence scaler on non-zero values only
        conf_nonzero = y_confidences[y_confidences > 0.001]
        if len(conf_nonzero) > 50:  # Ensure we have enough data
            print(f"Fitting confidence scaler on {len(conf_nonzero)} non-zero values")
            self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
        else:
            print("WARNING: Too few non-zero confidence values, using default scaling")
            # Use a reasonable range for confidence values (e.g., 1-5 scale)
            dummy_conf = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1)
            self.confidence_scaler.fit(dummy_conf)
        # Fit weight scaler on non-zero values only
        weight_nonzero = y_weights[y_weights > 0.001]
        if len(weight_nonzero) > 50:
            print(f"Fitting weight scaler on {len(weight_nonzero)} non-zero values")
            self.weighted_scaler.fit(weight_nonzero.reshape(-1, 1))
        else:
            print("WARNING: Too few non-zero weight values, using default scaling")
            # Use a reasonable range for weight values (0-1 scale)
            dummy_weight = np.array([0.1, 0.3, 0.5, 0.7, 0.9]).reshape(-1, 1)
            self.weighted_scaler.fit(dummy_weight)
        # Apply scaling to the matrices (preserve sparse structure: zeros stay zero)
        y_confidences_scaled = np.zeros_like(y_confidences)
        y_weights_scaled = np.zeros_like(y_weights)
        # Scale only non-zero values (vectorized over boolean masks)
        conf_mask = y_confidences > 0.001
        weight_mask = y_weights > 0.001
        y_confidences_scaled[conf_mask] = self.confidence_scaler.transform(
            y_confidences[conf_mask].reshape(-1, 1)).ravel()
        y_weights_scaled[weight_mask] = self.weighted_scaler.transform(
            y_weights[weight_mask].reshape(-1, 1)).ravel()
        print(f"Scaled confidence range: {y_confidences_scaled[y_confidences_scaled != 0].min():.3f} - {y_confidences_scaled[y_confidences_scaled != 0].max():.3f}")
        print(f"Scaled weight range: {y_weights_scaled[y_weights_scaled != 0].min():.3f} - {y_weights_scaled[y_weights_scaled != 0].max():.3f}")
        # Split data
        X_train, X_test, y_cond_train, y_cond_test, y_conf_train, y_conf_test, y_weight_train, y_weight_test = train_test_split(
            X_scaled, y_conditions, y_confidences_scaled, y_weights_scaled,
            test_size=test_size, random_state=random_state
        )
| print(f"\nTraining/test split:") | |
| print(f" Training samples: {X_train.shape[0]}") | |
| print(f" Test samples: {X_test.shape[0]}") | |
| # Build model | |
| self.model = self.build_model( | |
| input_dim=X_scaled.shape[1], | |
| num_conditions=y_conditions.shape[1], | |
| learning_rate=learning_rate | |
| ) | |
| print(f"\nModel architecture:") | |
| self.model.summary() | |
| # Prepare training data | |
| train_data = { | |
| 'conditions': y_cond_train, | |
| 'individual_confidences': y_conf_train, | |
| 'individual_weights': y_weight_train | |
| } | |
| val_data = { | |
| 'conditions': y_cond_test, | |
| 'individual_confidences': y_conf_test, | |
| 'individual_weights': y_weight_test | |
| } | |
| # Enhanced callbacks | |
| early_stopping = keras.callbacks.EarlyStopping( | |
| monitor='val_loss', | |
| patience=12, # Increased patience | |
| restore_best_weights=True, | |
| verbose=1 | |
| ) | |
| reduce_lr = keras.callbacks.ReduceLROnPlateau( | |
| monitor='val_individual_confidences_loss', # Monitor confidence loss specifically | |
| factor=0.5, | |
| patience=5, | |
| min_lr=1e-7, | |
| mode='min', # We want to minimize the loss | |
| verbose=1 | |
| ) | |
| model_checkpoint = keras.callbacks.ModelCheckpoint( | |
| filepath='best_model_fixed.weights.h5', | |
| monitor='val_loss', | |
| save_best_only=True, | |
| save_weights_only=True, | |
| verbose=1 | |
| ) | |
| print(f"\nStarting training for {epochs} epochs...") | |
| # Train model | |
| self.history = self.model.fit( | |
| X_train, train_data, | |
| validation_data=(X_test, val_data), | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| callbacks=[early_stopping, reduce_lr, model_checkpoint], | |
| verbose=1 | |
| ) | |
| # Evaluate model | |
| self.evaluate_model(X_test, y_cond_test, y_conf_test, y_weight_test) | |
| return True | |
| """ | |
| Comprehensive model evaluation with enhanced confidence analysis. | |
| EVALUATION METRICS: | |
| 1. MULTI-LABEL CLASSIFICATION (Conditions): | |
| Hamming Loss: | |
| Definition: Fraction of incorrectly predicted labels | |
| Range: [0, 1] where 0 is perfect | |
| Formula: (1/n_labels) × Σ|y_true ⊕ y_pred| | |
| Example: If 2 out of 30 labels are wrong, hamming loss = 0.067 | |
| Clinical interpretation: Lower is better, <0.1 is excellent | |
| Exact Match Accuracy: | |
| Strictest metric: Requires ALL labels perfectly correct | |
| Range: [0, 1] where 1 is perfect | |
| Why include: Shows complete prediction correctness | |
| Medical context: Exact match is ideal but rarely achievable | |
| (even expert dermatologists disagree on some cases) | |
| Average Conditions per Sample: | |
| Describes label distribution complexity | |
| Higher values → harder multi-label problem | |
| Typical range: 1-3 conditions per sample | |
| 2. CONFIDENCE REGRESSION: | |
| Why evaluate only non-zero targets: | |
| Zeros are structural (condition not present) | |
| Including zeros conflates two problems: | |
| a) Predicting which conditions exist (classification task) | |
| b) Predicting confidence for existing conditions (regression task) | |
| We want to evaluate (b) separately | |
| Inverse Transform: | |
| Converts scaled predictions back to original scale | |
| Necessary for interpretable metrics | |
| Example: Scaled 0.3 → Original 3.2 (on 1-5 scale) | |
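
       A minimal round-trip sketch (hypothetical fitted values 1, 3, 5, so
       mean 3 and population std ≈ 1.633):

       ```python
       s = StandardScaler().fit(np.array([[1.0], [3.0], [5.0]]))
       s.transform([[4.0]])            # -> [[0.612]]  ((4 - 3) / 1.633)
       s.inverse_transform([[0.612]])  # -> approximately [[4.0]]
       ```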

       MSE (Mean Squared Error):
       - Sensitive to large errors (squared penalty)
       - Unit: (confidence units)²
       - Lower is better

       MAE (Mean Absolute Error):
       - Average absolute difference from ground truth
       - Same units as original values
       - More robust to outliers than MSE
       - Clinical interpretation: If MAE=0.5, average error is ±0.5 points

       RMSE (Root Mean Squared Error):
       - Square root of MSE
       - Same units as original values (easier to interpret than MSE)
       - Emphasizes larger errors more than MAE (see the worked example
         after this list)

       Prediction Range Analysis:
       - Verifies predictions are in a sensible range
       - Example: If ground truth is 1-5, predictions should be similar
       - Out-of-range predictions indicate scaling or activation issues
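
       Worked example (hypothetical absolute errors 0.5, 0.5, 2.0):
       - MAE  = (0.5 + 0.5 + 2.0) / 3 = 1.0
       - MSE  = (0.25 + 0.25 + 4.0) / 3 = 1.5
       - RMSE = √1.5 ≈ 1.22, higher than MAE because the single large error
         dominates the squared average.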

    3. WEIGHT REGRESSION:
       Same metrics as confidence but for weight values (0-1 range)

    DIAGNOSTIC CHECKS:
    - "Predictions > 0.1" percentage: Ensures model isn't predicting near-zero
    - Range comparison: Truth vs prediction ranges should align
    - Non-zero count: Verifies sparse structure is respected

    WHY THIS EVALUATION IS COMPREHENSIVE:
    - Multiple metrics cover different aspects (classification + regression)
    - Separate evaluation of sparse vs dense regions
    - Original scale metrics (clinically interpretable)
    - Diagnostic checks for common failure modes
    - Both aggregate (MSE) and per-sample (MAE) metrics
    """
    def evaluate_model(self, X_test, y_cond_test, y_conf_test, y_weight_test):
        """Evaluate the trained model with enhanced confidence analysis"""
        print("\n" + "="*70)
        print("MODEL EVALUATION - ENHANCED CONFIDENCE ANALYSIS")
        print("="*70)
        # Make predictions
        predictions = self.model.predict(X_test)
        y_cond_pred = predictions['conditions']
        y_conf_pred = predictions['individual_confidences']
        y_weight_pred = predictions['individual_weights']
        # Condition classification evaluation
        y_cond_pred_binary = (y_cond_pred > 0.5).astype(int)
        hamming = hamming_loss(y_cond_test, y_cond_pred_binary)
        exact_match = (y_cond_pred_binary == y_cond_test).all(axis=1).mean()
        print(f"Multi-label Condition Classification:")
        print(f"  Hamming Loss: {hamming:.4f}")
        print(f"  Exact Match Accuracy: {exact_match:.4f}")
        print(f"  Average conditions per sample: {y_cond_test.sum(axis=1).mean():.2f}")
        # ENHANCED confidence evaluation
        print(f"\nConfidence Prediction Analysis:")
        print(f"  Raw prediction range: {y_conf_pred.min():.6f} - {y_conf_pred.max():.6f}")
        print(f"  Non-zero predictions: {(y_conf_pred > 0.001).sum()}/{y_conf_pred.size}")
        # Inverse transform and evaluate confidence (the mask operates on the
        # scaled test targets; entries at or below 0.001 are treated as zeros)
        conf_mask = y_conf_test > 0.001
        if conf_mask.sum() > 0:
            y_conf_test_orig = np.zeros_like(y_conf_test)
            y_conf_pred_orig = np.zeros_like(y_conf_pred)
            # Inverse transform non-zero entries (vectorized over boolean masks)
            y_conf_test_orig[conf_mask] = self.confidence_scaler.inverse_transform(
                y_conf_test[conf_mask].reshape(-1, 1)).ravel()
            conf_pred_mask = y_conf_pred > 0.001
            y_conf_pred_orig[conf_pred_mask] = self.confidence_scaler.inverse_transform(
                y_conf_pred[conf_pred_mask].reshape(-1, 1)).ravel()
            # Calculate metrics only on positions where ground truth is non-zero
            conf_test_nonzero = y_conf_test_orig[conf_mask]
            conf_pred_nonzero = y_conf_pred_orig[conf_mask]
            conf_mse = mean_squared_error(conf_test_nonzero, conf_pred_nonzero)
            conf_mae = mean_absolute_error(conf_test_nonzero, conf_pred_nonzero)
            print(f"  Individual Confidence Regression (on {conf_mask.sum()} non-zero targets):")
            print(f"    MSE: {conf_mse:.4f}")
            print(f"    MAE: {conf_mae:.4f}")
            print(f"    RMSE: {np.sqrt(conf_mse):.4f}")
            print(f"    Prediction range (orig scale): {conf_pred_nonzero.min():.3f} - {conf_pred_nonzero.max():.3f}")
            print(f"    Ground truth range (orig scale): {conf_test_nonzero.min():.3f} - {conf_test_nonzero.max():.3f}")
            # Check if predictions are reasonable
            reasonable_predictions = (conf_pred_nonzero > 0.1).sum()
            print(f"    Predictions > 0.1: {reasonable_predictions}/{len(conf_pred_nonzero)} ({100*reasonable_predictions/len(conf_pred_nonzero):.1f}%)")
        # Individual weight evaluation
        weight_mask = y_weight_test > 0.001
        if weight_mask.sum() > 0:
            y_weight_test_orig = np.zeros_like(y_weight_test)
            y_weight_pred_orig = np.zeros_like(y_weight_pred)
            y_weight_test_orig[weight_mask] = self.weighted_scaler.inverse_transform(
                y_weight_test[weight_mask].reshape(-1, 1)).ravel()
            weight_pred_mask = y_weight_pred > 0.001
            y_weight_pred_orig[weight_pred_mask] = self.weighted_scaler.inverse_transform(
                y_weight_pred[weight_pred_mask].reshape(-1, 1)).ravel()
            weight_test_nonzero = y_weight_test_orig[weight_mask]
            weight_pred_nonzero = y_weight_pred_orig[weight_mask]
            weight_mse = mean_squared_error(weight_test_nonzero, weight_pred_nonzero)
            weight_mae = mean_absolute_error(weight_test_nonzero, weight_pred_nonzero)
            print(f"\nIndividual Weight Regression (on {weight_mask.sum()} non-zero targets):")
            print(f"    MSE: {weight_mse:.4f}")
            print(f"    MAE: {weight_mae:.4f}")
            print(f"    RMSE: {np.sqrt(weight_mse):.4f}")
            print(f"    Prediction range (orig scale): {weight_pred_nonzero.min():.3f} - {weight_pred_nonzero.max():.3f}")
            print(f"    Ground truth range (orig scale): {weight_test_nonzero.min():.3f} - {weight_test_nonzero.max():.3f}")
| """ | |
| Make predictions on new embeddings with comprehensive output formatting. | |
| PREDICTION PIPELINE: | |
| Scale input embedding (using training-fitted scaler) | |
| Forward pass through neural network | |
| Process raw outputs: | |
| Condition probabilities: Sigmoid outputs [0,1] | |
| Confidence values: Softplus outputs [0,∞) | |
| Weight values: Sigmoid outputs [0,1] | |
| Inverse transform regression outputs to original scale | |
| Apply threshold to select predicted conditions | |
| Return structured dictionary with multiple views of predictions | |
| THRESHOLD STRATEGY: | |
| Condition threshold: 0.3 (lower than typical 0.5) | |
| Why lower: Medical diagnosis prefers sensitivity (catch more conditions) | |
| False positives less harmful than false negatives in screening | |
| Can be adjusted based on clinical requirements | |
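
    For instance, with hypothetical probabilities [0.62, 0.41, 0.18], the 0.3
    threshold selects the first two conditions, while the conventional 0.5
    threshold would keep only the first:

    ```python
    probs = np.array([0.62, 0.41, 0.18])
    np.where(probs > 0.3)[0]  # -> array([0, 1])
    np.where(probs > 0.5)[0]  # -> array([0])
    ```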

    OUTPUT STRUCTURE:

    Primary Predictions (conditions above threshold):
    - dermatologist_skin_condition_on_label_name: List of predicted conditions
    - dermatologist_skin_condition_confidence: Confidence per predicted condition
    - weighted_skin_condition_label: Weight dict for predicted conditions

    Comprehensive View (all conditions):
    - all_condition_probabilities: Probability for every possible condition
    - all_individual_confidences: Confidence for every possible condition
    - all_individual_weights: Weight for every possible condition

    Debugging Information:
    - raw_confidence_outputs: Pre-transform neural network outputs
    - raw_weight_outputs: Pre-transform neural network outputs
    - condition_threshold: Threshold used for filtering

    Why provide multiple views:
    - Primary predictions: For direct clinical use
    - Comprehensive view: For ranking, uncertainty quantification
    - Debug info: For model validation and troubleshooting

    MINIMUM VALUE CLAMPING:
    ```python
    confidence_orig = max(0.1, confidence_orig)
    weight_orig = max(0.01, weight_orig)
    ```
    - Ensures predictions are never exactly zero
    - Confidence ≥0.1: Even the lowest predictions are meaningful
    - Weight ≥0.01: Prevents division-by-zero in downstream processing

    SOFTPLUS ADVANTAGE:
    With softplus activation, even very negative inputs produce small positive
    outputs, so confidence predictions naturally avoid zero. The max(0.1, x)
    provides an additional safety margin.

    RETURN FORMAT:
    The dictionary structure allows:
    - Easy access to specific prediction types
    - Clear semantic meaning (key names describe contents)
    - Extensibility (new keys can be added without breaking existing code)
    - JSON serialization for API deployment
| """ | |
| def predict(self, embedding): | |
| """Make predictions on a single embedding with individual outputs""" | |
| if self.model is None: | |
| print("ERROR: Model not trained. Call train() first.") | |
| return None | |
| if len(embedding.shape) == 1: | |
| embedding = embedding.reshape(1, -1) | |
| # Scale embedding | |
| embedding_scaled = self.embedding_scaler.transform(embedding) | |
| # Make predictions | |
| predictions = self.model.predict(embedding_scaled, verbose=0) | |
| # Process condition predictions | |
| condition_probs = predictions['conditions'][0] | |
| individual_confidences = predictions['individual_confidences'][0] | |
| individual_weights = predictions['individual_weights'][0] | |
| # Get predicted conditions (above threshold) | |
| condition_threshold = 0.3 # Lower threshold | |
| predicted_condition_indices = np.where(condition_probs > condition_threshold)[0] | |
| # Build results | |
| predicted_conditions = [] | |
| predicted_confidences = [] | |
| predicted_weights_dict = {} | |
| for idx in predicted_condition_indices: | |
| condition_name = self.mlb.classes_[idx] | |
| condition_prob = float(condition_probs[idx]) | |
| # Inverse transform individual outputs with better handling | |
| confidence_raw = individual_confidences[idx] | |
| weight_raw = individual_weights[idx] | |
| # Always inverse transform, even small values (softplus ensures non-zero) | |
| confidence_orig = self.confidence_scaler.inverse_transform([[confidence_raw]])[0, 0] | |
| weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0] | |
| predicted_conditions.append(condition_name) | |
| predicted_confidences.append(max(0.1, confidence_orig)) # Minimum confidence of 0.1 | |
| predicted_weights_dict[condition_name] = max(0.01, weight_orig) # Minimum weight of 0.01 | |
| # Also provide all condition probabilities for reference | |
| all_condition_probs = {} | |
| all_confidences = {} | |
| all_weights = {} | |
| for i, class_name in enumerate(self.mlb.classes_): | |
| all_condition_probs[class_name] = float(condition_probs[i]) | |
| # Always inverse transform all outputs | |
| conf_raw = individual_confidences[i] | |
| weight_raw = individual_weights[i] | |
| conf_orig = self.confidence_scaler.inverse_transform([[conf_raw]])[0, 0] | |
| weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0] | |
| all_confidences[class_name] = max(0.0, conf_orig) | |
| all_weights[class_name] = max(0.0, weight_orig) | |
| return { | |
| # Main predicted results (above threshold) | |
| 'dermatologist_skin_condition_on_label_name': predicted_conditions, | |
| 'dermatologist_skin_condition_confidence': predicted_confidences, | |
| 'weighted_skin_condition_label': predicted_weights_dict, | |
| # Additional information for analysis | |
| 'all_condition_probabilities': all_condition_probs, | |
| 'all_individual_confidences': all_confidences, | |
| 'all_individual_weights': all_weights, | |
| 'condition_threshold': condition_threshold, | |
| # Debug information | |
| 'raw_confidence_outputs': individual_confidences.tolist(), | |
| 'raw_weight_outputs': individual_weights.tolist() | |
| } | |
    def plot_training_history(self):
        """Plot per-output loss and metric curves and save them to a PNG file."""
        if self.history is None:
            print("No training history available")
            return
        # Use a non-interactive backend so plotting works in headless environments
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        # Loss
        axes[0, 0].plot(self.history.history['loss'], label='Training Loss')
        axes[0, 0].plot(self.history.history['val_loss'], label='Validation Loss')
        axes[0, 0].set_title('Model Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        # Condition accuracy
        axes[0, 1].plot(self.history.history['conditions_accuracy'], label='Training Accuracy')
        axes[0, 1].plot(self.history.history['val_conditions_accuracy'], label='Validation Accuracy')
        axes[0, 1].set_title('Condition Classification Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        # Individual Confidence MAE
        axes[0, 2].plot(self.history.history['individual_confidences_mae'], label='Training MAE')
        axes[0, 2].plot(self.history.history['val_individual_confidences_mae'], label='Validation MAE')
        axes[0, 2].set_title('Individual Confidence MAE')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('MAE')
        axes[0, 2].legend()
        # Individual Weight MAE
        axes[1, 0].plot(self.history.history['individual_weights_mae'], label='Training MAE')
        axes[1, 0].plot(self.history.history['val_individual_weights_mae'], label='Validation MAE')
        axes[1, 0].set_title('Individual Weight MAE')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('MAE')
        axes[1, 0].legend()
        # Individual confidence loss
        axes[1, 1].plot(self.history.history['individual_confidences_loss'], label='Training Loss')
        axes[1, 1].plot(self.history.history['val_individual_confidences_loss'], label='Validation Loss')
        axes[1, 1].set_title('Individual Confidence Loss')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Loss')
        axes[1, 1].legend()
        # Individual weight loss
        axes[1, 2].plot(self.history.history['individual_weights_loss'], label='Training Loss')
        axes[1, 2].plot(self.history.history['val_individual_weights_loss'], label='Validation Loss')
        axes[1, 2].set_title('Individual Weight Loss')
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].set_ylabel('Loss')
        axes[1, 2].legend()
        plt.tight_layout()
        plt.savefig('training_history_fixed.png', dpi=300, bbox_inches='tight')
        print("Training history plot saved as: training_history_fixed.png")
        plt.close()
| """ | |
| Persist trained model and preprocessing components to disk. | |
| SAVED COMPONENTS: | |
| 1. **Keras Model (.keras file)**: | |
| - Neural network architecture | |
| - Trained weights for all layers | |
| - Optimizer state (for resuming training) | |
| - Compilation settings (loss functions, metrics) | |
| 2. **Preprocessing Data (.pkl file)**: | |
| - MultiLabelBinarizer: Maps condition names ↔ indices | |
| - embedding_scaler: Normalizes input embeddings | |
| - confidence_scaler: Normalizes confidence values | |
| - weighted_scaler: Normalizes weight values | |
| - Path to .keras file (for loading) | |
| WHY SEPARATE FILES: | |
| - Keras models save to modern .keras format | |
| - Scikit-learn components need pickle serialization | |
| - Separation allows independent updates of each component | |
| LOADING REQUIREMENT: | |
| Both files are needed for inference: | |
| - .keras: Neural network for making predictions | |
| - .pkl: Preprocessors for transforming inputs/outputs | |
| FILE ORGANIZATION: | |
| easi_severity_model_derm_foundation_individual_fixed.pkl (main file) | |
| easi_severity_model_derm_foundation_individual_fixed.keras (model) | |
| User loads .pkl file, which contains path to .keras file | |
| CLEANUP: | |
| Removes temporary checkpoint file (best_model_fixed.weights.h5) | |
| created during training to avoid confusion with final model. | |
| ERROR HANDLING: | |
| Checks if model exists before saving, provides clear error messages | |
| and file paths for debugging. | |
| """ | |
    def save_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
        """Save the trained model"""
        if self.model is None:
            print("ERROR: No trained model to save.")
            return False
        # Get current directory
        current_dir = os.getcwd()
        # Save Keras model with proper extension
        model_filename = os.path.splitext(filepath)[0]
        keras_model_path = os.path.join(current_dir, f"{model_filename}.keras")
        print(f"Saving Keras model to: {keras_model_path}")
        self.model.save(keras_model_path)
        # Save preprocessing components
        pkl_filepath = os.path.join(current_dir, filepath)
        model_data = {
            'mlb': self.mlb,
            'embedding_scaler': self.embedding_scaler,
            'confidence_scaler': self.confidence_scaler,
            'weighted_scaler': self.weighted_scaler,
            'keras_model_path': keras_model_path
        }
        print(f"Saving preprocessing data to: {pkl_filepath}")
        with open(pkl_filepath, 'wb') as f:
            pickle.dump(model_data, f)
        print(f"Model saved successfully!")
        print(f"  - Main file: {pkl_filepath}")
        print(f"  - Keras model: {keras_model_path}")
        # Clean up temporary checkpoint file
        checkpoint_file = os.path.join(current_dir, 'best_model_fixed.weights.h5')
        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)
            print(f"  - Cleaned up temporary checkpoint file")
        return True
    def load_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
        """Load trained model"""
        if not os.path.exists(filepath):
            print(f"ERROR: Model file not found: {filepath}")
            return False
        try:
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)
            # Load preprocessing components
            self.mlb = model_data['mlb']
            self.embedding_scaler = model_data['embedding_scaler']
            self.confidence_scaler = model_data['confidence_scaler']
            self.weighted_scaler = model_data['weighted_scaler']
            # Load Keras model
            keras_model_path = model_data['keras_model_path']
            if os.path.exists(keras_model_path):
                self.model = keras.models.load_model(keras_model_path)
                print(f"Model loaded from {filepath}")
                print(f"Available condition classes: {len(self.mlb.classes_)}")
                return True
            else:
                print(f"ERROR: Keras model not found at {keras_model_path}")
                return False
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
| """ | |
| WORKFLOW: | |
| 1. Print configuration and fixes applied (user visibility) | |
| 2. Initialize classifier | |
| 3. Validate input files exist | |
| 4. Train model with improved confidence handling | |
| 5. Plot training history | |
| 6. Test model predictions (validate fix effectiveness) | |
| 7. Save trained model | |
| MODEL TESTING (NEW): | |
| After training completes, runs a sample prediction to verify: | |
| Model produces non-zero confidence values (fix validation) | |
| Predictions are in expected ranges | |
| Output structure is correct | |
| This immediate validation catches issues before deployment. | |
| WHY TEST WITH SAMPLE: | |
| Confirms confidence scaling fix worked | |
| Provides immediate feedback on model quality | |
| Demonstrates expected output format | |
| Catches activation function issues (like ReLU→0 bug) | |
| SUCCESS CRITERIA: | |
| ✅ Non-zero confidences in reasonable range (e.g., 1-5) | |
| ✅ Multiple conditions predicted with varying probabilities | |
| ✅ Weights sum to reasonable values | |
| ⚠️ Warning if confidence outputs still mostly zero | |
| """ | |

def main():
    """Main training function with enhanced confidence handling"""
    print("Derm Foundation Neural Network Classifier Training - FIXED VERSION")
    print("="*70)
    print("FIXES APPLIED:")
    print("- Changed confidence activation from ReLU to softplus")
    print("- Improved confidence scaler fitting (non-zero values only)")
    print("- Increased confidence loss weight (1.5x)")
    print("- Enhanced data validation and preprocessing")
    print("- Better handling of sparse confidence/weight matrices")
    print("="*70)
    print("Training neural network to predict:")
    print("1. Skin conditions (multi-label classification)")
    print("2. Individual confidence scores per condition (regression)")
    print("3. Individual weight scores per condition (regression)")
    print("="*70)
    # Initialize classifier
    classifier = DermFoundationNeuralNetwork()
    # File paths
    npz_file = "derm_foundation_embeddings.npz"
    csv_file = "dataset_scin_labels.csv"
    model_output = "easi_severity_model_derm_foundation_individual_fixed.pkl"
    # Check if files exist
    missing_files = []
    if not os.path.exists(npz_file):
        missing_files.append(npz_file)
    if not os.path.exists(csv_file):
        missing_files.append(csv_file)
    if missing_files:
        print(f"ERROR: Missing required files:")
        for file in missing_files:
            print(f"  - {file}")
        return
    try:
        # Train the model
        success = classifier.train(
            npz_file_path=npz_file,
            csv_file_path=csv_file,
            epochs=60,  # Increased epochs
            batch_size=32,
            learning_rate=0.001
        )
        if not success:
            print("Training failed!")
            return
        # Plot training history
        try:
            classifier.plot_training_history()
        except Exception as e:
            print(f"Could not plot training history: {e}")
        # Test the model with a sample prediction to verify confidence outputs
        print("\n" + "="*70)
        print("TESTING MODEL OUTPUTS")
        print("="*70)
        # Get a sample embedding for testing
        try:
            embeddings = classifier.load_embeddings(npz_file)
            if embeddings:
                sample_key = list(embeddings.keys())[0]
                sample_embedding = embeddings[sample_key]
                print(f"Testing with sample embedding: {sample_key}")
                test_result = classifier.predict(sample_embedding)
                if test_result:
                    print("✅ Model prediction successful!")
                    print(f"Predicted conditions: {len(test_result['dermatologist_skin_condition_on_label_name'])}")
                    # Check confidence outputs
                    all_confidences = list(test_result['all_individual_confidences'].values())
                    nonzero_conf = sum(1 for c in all_confidences if c > 0.01)
                    print(f"Confidence range: {min(all_confidences):.4f} - {max(all_confidences):.4f}")
                    print(f"Non-zero confidences: {nonzero_conf}/{len(all_confidences)}")
                    if nonzero_conf > 0:
                        print("✅ CONFIDENCE ISSUE APPEARS TO BE FIXED!")
                    else:
                        print("⚠️ Confidence outputs still mostly zero - may need further investigation")
                    # Show top predictions
                    if test_result['dermatologist_skin_condition_on_label_name']:
                        print(f"\nSample predictions:")
                        for i, condition in enumerate(test_result['dermatologist_skin_condition_on_label_name'][:3]):
                            prob = test_result['all_condition_probabilities'][condition]
                            conf = test_result['dermatologist_skin_condition_confidence'][i]
                            weight = test_result['weighted_skin_condition_label'][condition]
                            print(f"  {condition}: prob={prob:.3f}, conf={conf:.3f}, weight={weight:.3f}")
                else:
                    print("❌ Model prediction failed")
        except Exception as e:
            print(f"Could not test model: {e}")
        # Save the model
        classifier.save_model(model_output)
        print(f"\n{'='*70}")
        print("TRAINING COMPLETE!")
        print(f"{'='*70}")
        print(f"Model saved as: {model_output}")
        print(f"Training history plot saved as: training_history_fixed.png")
        print(f"\nTo use the trained model:")
        print(f"```python")
        print(f"classifier = DermFoundationNeuralNetwork()")
        print(f"classifier.load_model('{model_output}')")
        print(f"result = classifier.predict(embedding)")
        print(f"print(result['dermatologist_skin_condition_on_label_name'])")
        print(f"print(result['dermatologist_skin_condition_confidence'])")
        print(f"print(result['weighted_skin_condition_label'])")
        print(f"```")
        # Example prediction output format
        print(f"\nExpected prediction output format:")
        print(f"{{")
        print(f"  'dermatologist_skin_condition_on_label_name': ['Eczema', 'Irritant Contact Dermatitis'],")
        print(f"  'dermatologist_skin_condition_confidence': [4.2, 3.1],")
        print(f"  'weighted_skin_condition_label': {{'Eczema': 0.65, 'Irritant Contact Dermatitis': 0.35}}")
        print(f"}}")
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()