File size: 67,282 Bytes
dc04619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
#!/usr/bin/env python3

"""

Derm Foundation Neural Network Classifier Training Script - Fixed Version



PURPOSE:

This script trains a multi-output neural network to predict dermatological 

conditions and their associated metadata from pre-computed embeddings. It 

addresses the challenging problem of multi-label medical diagnosis where:

1. Multiple conditions can co-exist (multi-label classification)

2. Each diagnosis has an associated confidence level (regression)

3. Each diagnosis has a weight indicating relative importance (regression)



WHY NEURAL NETWORKS FOR THIS TASK:

Neural networks are a well-suited choice for this problem for several reasons:



1. **Non-linear Relationship Learning**: The relationship between image 

   embeddings and skin conditions is highly non-linear. Neural networks excel

   at learning complex, non-linear mappings that simpler models (like logistic

   regression) cannot capture.



2. **Multi-task Learning**: This problem requires predicting three related but

   distinct outputs (conditions, confidence, weights). Neural networks can

   share learned representations across these tasks through shared layers,

   improving generalization and efficiency.



3. **High-dimensional Input**: Embeddings are typically 512-1024 dimensional

   vectors. Neural networks are designed to handle high-dimensional inputs

   effectively through dimensionality reduction in hidden layers.



4. **Multi-label Classification**: Medical diagnosis often involves multiple

   co-existing conditions. Neural networks with sigmoid activation can model

   the independent probability of each condition, unlike single-label methods.



5. **Flexibility**: The architecture can be customized with task-specific

   heads (branches) for different prediction types, allowing specialized

   processing for classification vs regression outputs.



WHY HAMMING LOSS IS VALID:

Hamming loss is an appropriate metric for multi-label classification because:



1. **Accounts for Partial Correctness**: Unlike exact match accuracy, hamming

   loss gives credit for partially correct predictions. Predicting 3 out of 4

   conditions correctly is better than 0 out of 4.



2. **Label-wise Evaluation**: It measures the fraction of incorrectly predicted

   labels, treating each label independently - appropriate when conditions can

   co-occur independently.



3. **Bounded and Interpretable**: Ranges from 0 (perfect) to 1 (completely

   wrong). A hamming loss of 0.1 means 10% of label predictions were incorrect.



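4. **Well-defined for Sparse Labels**: In medical diagnosis, most samples have few

   positive labels (sparse multi-label). Hamming loss averages over every label

   slot, so it stays well defined here. Note, however, that true negatives dominate

   such data: a model that predicts no conditions at all can still score a low

   loss, so report it alongside per-label precision/recall or F1.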



5. **Clinically Relevant**: In medical applications, both false positives and

   false negatives matter. Hamming loss penalizes both equally, unlike metrics

   that focus on one type of error.



MATHEMATICAL JUSTIFICATION:

For a sample with true labels y and predicted labels ŷ:

    Hamming Loss = (1/n_labels) × Σ(y_i XOR ŷ_i)



This averages the disagreement across all possible labels, making it suitable

for scenarios where:

- The label space is large (many possible conditions)

- Label correlations exist but aren't perfectly predictable

- Clinical accuracy matters at the individual label level



FIXES APPLIED IN THIS VERSION:

- Changed confidence activation from ReLU to softplus (prevents zero outputs)

- Improved confidence scaler fitting (uses only non-zero values)

- Increased confidence loss weight (1.5x for better learning signal)

- Enhanced data validation and preprocessing

- Better handling of sparse confidence/weight matrices



Requirements:

- pandas

- numpy

- tensorflow>=2.13.0

- scikit-learn

- matplotlib

- pickle (standard library)

- os (standard library)

- derm_foundation_embeddings.npz: Pre-computed embeddings from images

- dataset_scin_labels.csv: Ground truth labels with conditions, confidences, weights



OUTPUT:

- Trained neural network model (.keras file)

- Preprocessing components (scalers, label encoder) in .pkl file

- Training history plots showing convergence

- Evaluation metrics on test set

"""

import ast
import os
import pickle
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import hamming_loss, mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')

"""

Main class implementing the multi-output neural network classifier.



ARCHITECTURE OVERVIEW:

1. **Shared Feature Extraction**: 3 dense layers (512→256→128) with batch

   normalization and dropout. These layers learn a shared representation

   useful for all prediction tasks.



2. **Task-Specific Heads**: Three separate output branches:

   - Condition classification: Sigmoid activation for multi-label prediction

   - Confidence regression: Softplus activation for positive continuous values

   - Weight regression: Sigmoid activation for [0,1] bounded values



WHY MULTI-TASK LEARNING:

- Conditions, confidence, and weights are related but distinct

- Sharing early layers allows the model to learn features useful for all tasks

- Task-specific heads allow specialized processing for each output type

- Improves generalization by preventing overfitting to any single task



TRAINING STRATEGY:

- Multi-task loss: Weighted combination of classification and regression losses

- Early stopping: Prevents overfitting by monitoring validation loss

- Learning rate reduction: Adapts learning rate when progress plateaus

- Batch normalization: Stabilizes training and allows higher learning rates

"""

class DermFoundationNeuralNetwork:
    """

    Initialize the classifier with preprocessing components.



    PREPROCESSING COMPONENTS:

    - mlb (MultiLabelBinarizer): Converts condition names to binary vectors

    Example: ['Eczema', 'Psoriasis'] → [0,1,0,1,0,...,0]

    

    - embedding_scaler (StandardScaler): Normalizes embeddings to mean=0, std=1

    Why: Neural networks train faster with normalized inputs

    

    - confidence_scaler (StandardScaler): Normalizes confidence values

    Why: Brings continuous values to similar scale as other outputs

    

    - weighted_scaler (StandardScaler): Normalizes weight values

    Why: Ensures balanced gradient contributions during training



    DESIGN DECISION:

    Separate scalers for each output type allow independent normalization,

    which is crucial when outputs have different scales and distributions.

    """
    def __init__(self):
        """Create unfitted preprocessing objects and empty model/history slots."""
        # Preprocessing: a label binarizer for condition names, plus one
        # StandardScaler each for embeddings, confidence targets and weights.
        self.mlb = MultiLabelBinarizer()
        self.embedding_scaler = StandardScaler()
        self.confidence_scaler = StandardScaler()
        self.weighted_scaler = StandardScaler()

        # Keras model and its training history; None until assigned elsewhere
        # in the class.
        self.model = None
        self.history = None
    """

    Load pre-computed Derm Foundation embeddings from NPZ file.



    WHAT ARE EMBEDDINGS:

    Embeddings are dense vector representations of images extracted from a

    pre-trained vision model (Derm Foundation model). They capture high-level

    visual features learned from large-scale dermatology image datasets.



    WHY USE PRE-COMPUTED EMBEDDINGS:

    1. **Efficiency**: Computing embeddings is expensive. Pre-computing them

    allows rapid experimentation with different classifier architectures.

    

    2. **Transfer Learning**: Derm Foundation was trained on massive dermatology

    datasets. Its embeddings encode domain-specific visual patterns.

    

    3. **Separation of Concerns**: Image processing and classification are

    separated, allowing independent optimization of each component.



    FORMAT:

    NPZ file contains a dictionary where:

    - Keys: case_id (string identifiers)

    - Values: embedding vectors (typically 512 or 768 dimensions)

    """
    def load_embeddings(self, npz_file_path):
        """Load per-case embeddings from an NPZ archive.

        Args:
            npz_file_path: Path to the ``.npz`` file. Every array name stored
                in the archive is used as a key (expected to be a case_id).

        Returns:
            dict mapping array name -> numpy array, or None if the file does
            not exist or the archive contains no arrays.
        """
        print(f"Loading embeddings from {npz_file_path}...")

        if not os.path.exists(npz_file_path):
            print(f"ERROR: Embeddings file not found: {npz_file_path}")
            return None

        # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code from the file. Only load
        # trusted archives; switch to allow_pickle=False if every stored array
        # is numeric.
        # Using the NpzFile as a context manager closes the archive handle.
        with np.load(npz_file_path, allow_pickle=True) as npz_file:
            # Materialize every array before the archive is closed.
            embeddings_data = {key: npz_file[key] for key in npz_file.files}

        print(f"Loaded {len(embeddings_data)} embeddings")

        # Guard: an empty archive has no "first" embedding to inspect below.
        if not embeddings_data:
            print(f"ERROR: No embeddings found in {npz_file_path}")
            return None

        # Print info about first embedding for debugging
        first_embedding = next(iter(embeddings_data.values()))
        print(f"Embedding shape: {first_embedding.shape}")

        return embeddings_data
    
    """

    Load ground truth labels from CSV file.



    REQUIRED COLUMNS:

    1. case_id: Unique identifier matching embedding keys

    2. dermatologist_skin_condition_on_label_name: List of condition names

    3. dermatologist_skin_condition_confidence: Confidence scores (typically 1-5)

    4. weighted_skin_condition_label: Importance weights (0-1 range)



    DATA TYPES:

    - case_id must be string to match embedding keys

    - Lists stored as strings (e.g., "['Eczema', 'Psoriasis']") are evaluated

    - Handles various formats: lists, dicts, single values

    """
    def load_dataset(self, csv_path):
        """Load the label table from a CSV file.

        Args:
            csv_path: Path to the CSV file.

        Returns:
            DataFrame with ``case_id`` converted to ``str``, or None when the
            file is missing, cannot be parsed, or lacks a required column.
        """
        print(f"Loading dataset from {csv_path}...")

        if not os.path.exists(csv_path):
            print(f"ERROR: Dataset file not found: {csv_path}")
            return None

        required_columns = [
            'case_id',
            'dermatologist_skin_condition_on_label_name',
            'dermatologist_skin_condition_confidence',
            'weighted_skin_condition_label'
        ]

        try:
            # Read case_id as text so IDs such as "007" keep their leading
            # zeros and can match string keys of the embeddings mapping.
            df = pd.read_csv(csv_path, dtype={'case_id': str})
        except Exception as e:
            print(f"Error loading dataset: {e}")
            return None

        print(f"Loaded dataset: {len(df)} samples")

        # Verify required columns *before* touching df['case_id'], so a CSV
        # without that column reports this message instead of a KeyError.
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            print(f"ERROR: Missing required columns: {missing_columns}")
            return None

        # NOTE: empty case_id cells (NaN) become the literal string 'nan' here.
        df['case_id'] = df['case_id'].astype(str)

        return df
    """

    Prepare training data with comprehensive validation and preprocessing.



    COMPLEXITY HANDLING:

    This method handles several challenging data characteristics:



    1. **SPARSE MULTI-LABEL MATRICES**: Most samples have few positive labels

    Solution: Track and report sparsity statistics for validation

    

    2. **VARIABLE-LENGTH LISTS**: Different samples have different numbers of

    conditions, confidences, and weights

    Solution: Parse and align lists carefully, use mean values for mismatches

    

    3. **RARE CONDITIONS**: Some conditions appear in very few samples

    Solution: Filter to top N conditions and minimum sample requirements

    

    4. **ZERO VALUES**: Confidence/weight matrices are mostly zeros (sparse)

    Solution: Track zero vs non-zero ratios, fit scalers only on non-zeros



    FILTERING STRATEGY:

    - min_condition_samples: Removes rare conditions with insufficient data

    - max_conditions: Limits to most frequent conditions to prevent overfitting

    - Both filters ensure model focuses on well-represented, learnable patterns



    WHY FILTER CONDITIONS:

    1. **Statistical Validity**: Need sufficient examples to learn patterns

    2. **Generalization**: Rare conditions lead to overfitting

    3. **Computational Efficiency**: Fewer output nodes = faster training

    4. **Clinical Relevance**: Common conditions are higher priority



    MULTI-LABEL MATRIX STRUCTURE:

    Shape: (n_samples, n_conditions)

    - Rows: Individual patient cases

    - Columns: Binary indicators for each condition (1=present, 0=absent)

    - Multiple 1s per row: Multi-label (multiple conditions co-exist)



    CONFIDENCE/WEIGHT MATRICES:

    Shape: (n_samples, n_conditions)

    - Values at (i,j): Confidence/weight for condition j in sample i

    - Zero when condition j not present in sample i (sparse structure)

    - Non-zero only where corresponding multi-label entry is 1



    DATA VALIDATION:

    Extensive logging of:

    - Sample counts (processed vs skipped)

    - Value ranges (min/max/mean)

    - Sparsity statistics (% non-zero)

    - Top conditions by frequency



    This validation is crucial for:

    - Detecting data quality issues early

    - Understanding model input characteristics

    - Debugging training problems

    """
    def prepare_training_data(self, df, embeddings, min_condition_samples=5, max_conditions=30):
        """Build model inputs and per-condition targets from labels and embeddings.

        Rows whose ``case_id`` is not in ``embeddings``, or whose label cells
        cannot be parsed or converted to floats, are skipped and counted.
        The condition vocabulary is limited to the ``max_conditions`` most
        frequent conditions that also occur in at least
        ``min_condition_samples`` kept rows. ``self.mlb`` is replaced by a
        binarizer whose ``classes_`` are exactly those conditions.

        Args:
            df: DataFrame containing ``case_id`` and the three label columns.
            embeddings: Mapping case_id -> embedding array.
            min_condition_samples: Minimum row count for a condition to be kept.
            max_conditions: Maximum number of most-frequent conditions kept.

        Returns:
            ``(X, y_conditions, y_confidences, y_weights)``. ``X`` is
            ``np.array`` of the kept embeddings. The three targets have shape
            ``(n_samples, n_selected_conditions)``: binary labels, confidences
            and weights (0 where a condition is absent). Returns
            ``(None, None, None, None)`` if no row survives.
        """
        print("Preparing training data with enhanced validation...")

        X = []  # Embeddings
        condition_labels = []  # For multi-label classification
        individual_confidences = []  # Individual confidence per condition
        individual_weights = []  # Individual weight per condition

        skipped_count = 0
        processed_count = 0
        confidence_stats = []  # Track confidence values for validation
        weight_stats = []  # Track weight values for validation

        for idx, row in df.iterrows():
            try:
                case_id = row['case_id']

                if case_id not in embeddings:
                    skipped_count += 1
                    continue

                # Parse the label data
                try:
                    # Condition names, normalized to a list (missing -> []).
                    condition_names = self._parse_label_cell(
                        row['dermatologist_skin_condition_on_label_name'])
                    if not isinstance(condition_names, list):
                        condition_names = [condition_names] if condition_names else []

                    # Confidence scores, one per condition.
                    confidences = self._parse_label_cell(
                        row['dermatologist_skin_condition_confidence'])
                    if not isinstance(confidences, list):
                        confidences = [confidences] if confidences is not None else []

                    if len(confidences) != len(condition_names):
                        if len(confidences) == 1:
                            # A single score applies to every listed condition.
                            confidences = confidences * len(condition_names)
                        else:
                            print(f"Warning: Confidence length mismatch for {case_id}, using mean")
                            mean_conf = np.mean(confidences) if confidences else 3.0
                            confidences = [mean_conf] * len(condition_names)

                    # Weights: dict keyed by condition, a list, or a single value.
                    weighted_labels = self._parse_label_cell(
                        row['weighted_skin_condition_label'])
                    if isinstance(weighted_labels, dict):
                        # Align dict entries with condition order; absent -> 0.0.
                        weights = [weighted_labels.get(condition, 0.0)
                                   for condition in condition_names]
                    elif isinstance(weighted_labels, list):
                        weights = weighted_labels
                        if len(weights) != len(condition_names):
                            if len(weights) == 1:
                                weights = weights * len(condition_names)
                            else:
                                mean_weight = np.mean(weights) if weights else 0.3
                                weights = [mean_weight] * len(condition_names)
                    else:
                        # Single value; falsy or missing values fall back to 0.3.
                        fill_value = weighted_labels if weighted_labels else 0.3
                        weights = [fill_value] * len(condition_names)

                except Exception as e:
                    print(f"Error parsing data for {case_id}: {e}")
                    skipped_count += 1
                    continue

                # Validate data ranges
                try:
                    confidences = [max(0.0, float(c)) for c in confidences]  # Ensure non-negative
                    weights = [max(0.0, min(1.0, float(w))) for w in weights]  # Clamp to [0,1]
                except (TypeError, ValueError, OverflowError):
                    print(f"Error converting values for {case_id}, skipping")
                    skipped_count += 1
                    continue

                # Add to training data
                X.append(embeddings[case_id])
                condition_labels.append(condition_names)

                # Store individual confidences and weights
                individual_confidences.append({
                    'conditions': condition_names,
                    'confidences': confidences
                })

                individual_weights.append({
                    'conditions': condition_names,
                    'weights': weights
                })

                # Track statistics
                confidence_stats.extend(confidences)
                weight_stats.extend(weights)

                processed_count += 1

            except Exception as e:
                # Best-effort: a single malformed row should not abort the run.
                print(f"Error processing row {idx}: {e}")
                skipped_count += 1
                continue

        print(f"Training data prepared: {processed_count} samples, {skipped_count} skipped")

        if len(X) == 0:
            print("ERROR: No training samples found!")
            return None, None, None, None

        # Print data statistics. Both stat lists are empty when every kept row
        # has no conditions; min()/max() would raise on them.
        print("\nData validation:")
        if confidence_stats and weight_stats:
            nonzero_conf_vals = sum(1 for c in confidence_stats if c > 0.001)
            nonzero_weight_vals = sum(1 for w in weight_stats if w > 0.001)
            print(f"  Confidence values - min: {min(confidence_stats):.3f}, max: {max(confidence_stats):.3f}, mean: {np.mean(confidence_stats):.3f}")
            print(f"  Weight values - min: {min(weight_stats):.3f}, max: {max(weight_stats):.3f}, mean: {np.mean(weight_stats):.3f}")
            print(f"  Non-zero confidences: {nonzero_conf_vals}/{len(confidence_stats)} ({100*nonzero_conf_vals/len(confidence_stats):.1f}%)")
            print(f"  Non-zero weights: {nonzero_weight_vals}/{len(weight_stats)} ({100*nonzero_weight_vals/len(weight_stats):.1f}%)")
        else:
            print("  No per-condition confidence/weight values (kept rows list no conditions)")

        # Convert to numpy arrays
        X = np.array(X)

        # Binarize all conditions, then keep only the most frequent ones.
        y_conditions_raw = self.mlb.fit_transform(condition_labels)
        condition_counts = y_conditions_raw.sum(axis=0)

        # Get top conditions by frequency
        sorted_indices = np.argsort(condition_counts)[::-1]
        top_condition_indices = sorted_indices[:max_conditions]

        # Also ensure minimum samples. intersect1d returns sorted indices, so
        # the kept columns follow the binarizer's class order, not frequency.
        frequent_conditions = condition_counts >= min_condition_samples
        final_indices = np.intersect1d(top_condition_indices, np.where(frequent_conditions)[0])

        print(f"Total condition classes: {len(self.mlb.classes_)}")
        print(f"Top {max_conditions} most frequent conditions selected")
        print(f"Conditions with at least {min_condition_samples} examples: {frequent_conditions.sum()}")

        # Keep only selected conditions
        selected_classes = self.mlb.classes_[final_indices]
        y_conditions = y_conditions_raw[:, final_indices]

        # Replace the binarizer with one restricted to the selected classes.
        self.mlb = MultiLabelBinarizer()
        self.mlb.classes_ = selected_classes

        print(f"Final condition classes: {len(selected_classes)}")
        print(f"Multi-label matrix shape: {y_conditions.shape}")

        # Dense confidence/weight targets aligned with y_conditions' columns.
        # A dict index replaces a linear scan of selected_classes per entry.
        class_index = {name: col for col, name in enumerate(selected_classes)}
        y_confidences = np.zeros((len(X), len(selected_classes)))
        y_weights = np.zeros((len(X), len(selected_classes)))

        for i, (conf_data, weight_data) in enumerate(zip(individual_confidences, individual_weights)):
            # Map confidences to selected conditions (unselected ones dropped)
            for condition, confidence in zip(conf_data['conditions'], conf_data['confidences']):
                col = class_index.get(condition)
                if col is not None:
                    y_confidences[i, col] = confidence

            # Map weights to selected conditions
            for condition, weight in zip(weight_data['conditions'], weight_data['weights']):
                col = class_index.get(condition)
                if col is not None:
                    y_weights[i, col] = weight

        # Print matrix statistics (guarded: matrices may have zero columns)
        nonzero_conf = (y_confidences > 0.001).sum()
        nonzero_weight = (y_weights > 0.001).sum()
        total_elements = y_confidences.size

        print("\nMatrix statistics:")
        if total_elements:
            print(f"  Confidence matrix - non-zero: {nonzero_conf}/{total_elements} ({100*nonzero_conf/total_elements:.1f}%)")
            print(f"  Weight matrix - non-zero: {nonzero_weight}/{total_elements} ({100*nonzero_weight/total_elements:.1f}%)")
        else:
            print("  Confidence/weight matrices are empty (no conditions selected)")
        print(f"  Confidence range: {self._format_positive_range(y_confidences)}")
        print(f"  Weight range: {self._format_positive_range(y_weights)}")

        # Print top conditions
        condition_counts_filtered = y_conditions.sum(axis=0)
        print("\nTop conditions selected:")
        for i, (condition, count) in enumerate(zip(selected_classes, condition_counts_filtered)):
            print(f"  {i+1:2d}. {condition}: {count} samples")

        return X, y_conditions, y_confidences, y_weights

    @staticmethod
    def _parse_label_cell(value):
        """Decode one label cell read from the CSV.

        Strings holding Python literals (e.g. "['Eczema', 'Psoriasis']" or
        "{'Eczema': 0.5}") are parsed with ast.literal_eval. Unlike eval(),
        it accepts only literals and cannot execute code embedded in the file.
        Missing cells (None, or the float NaN that pandas uses for empty
        fields) are returned as None so callers apply their missing-value
        defaults. Any other value is returned unchanged.
        """
        if isinstance(value, str):
            return ast.literal_eval(value)
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return None
        return value

    @staticmethod
    def _format_positive_range(matrix):
        """Return "min - max" over the strictly positive entries of *matrix*, or "n/a"."""
        positive = matrix[matrix > 0]
        if positive.size == 0:
            return "n/a (no positive entries)"
        return f"{positive.min():.3f} - {positive.max():.3f}"
    """

    Build multi-output neural network architecture.



    ARCHITECTURE RATIONALE:



    **SHARED LAYERS (512→256→128)**:

    - Purpose: Learn general features useful for all prediction tasks

    - Size progression: Gradual dimensionality reduction (embeddings→features)

    - Batch Normalization: Stabilizes training, allows higher learning rates

    - Dropout (0.3, 0.3, 0.2): Prevents overfitting, forces robust features



    Why this depth: 

    - 3 layers balances capacity (can learn complex patterns) vs simplicity

    - Too shallow: Can't learn complex patterns

    - Too deep: Overfits, slower training, harder to optimize



    **TASK-SPECIFIC BRANCHES**:

    Each branch has its own hidden layers for specialized processing: the condition
    branch is 64→output, while the confidence and weight branches are 64→32→output.



    1. **CONDITION CLASSIFICATION BRANCH**:

    - Activation: Sigmoid (outputs independent probabilities per condition)

    - Why sigmoid: Allows multiple conditions to be predicted simultaneously

    - Loss: Binary cross-entropy (standard for multi-label classification)

    

    2. **CONFIDENCE REGRESSION BRANCH**:

    - Activation: Softplus (ensures positive outputs, smooth gradients)

    - Why softplus not ReLU: ReLU outputs exactly zero for negative inputs,

        causing gradient issues. Softplus outputs small positive values instead.

    - Formula: softplus(x) = log(1 + exp(x))

    - Loss: MSE (Mean Squared Error for continuous values)

    - Loss weight: 1.5x (increased to prioritize confidence learning)

    

    3. **WEIGHT REGRESSION BRANCH**:

    - Activation: Sigmoid (ensures [0,1] bounded output)

    - Why sigmoid: Weights represent proportions/probabilities, must be 0-1

    - Loss: MSE (Mean Squared Error for continuous values)

    - Loss weight: 1.2x (slightly increased priority)



    **LOSS WEIGHTING**:

    Different loss scales require weighting for balanced training:

    - Condition loss: Binary cross-entropy, typically ~0.3-0.7

    - Confidence loss: MSE on scaled values, typically ~0.01-0.1

    - Weight loss: MSE on scaled values, typically ~0.01-0.1



    Weights (1.0, 1.5, 1.2) ensure:

    - All tasks contribute meaningfully to total loss

    - Confidence gets extra emphasis (was underfitting in previous versions)

    - Gradient magnitudes are balanced across tasks



    **WHY ADAM OPTIMIZER**:

    - Adaptive learning rates per parameter (handles different loss scales)

    - Momentum for faster convergence

    - Robust to hyperparameter choices

    - Industry standard for multi-task learning



    **MODEL COMPILATION**:

    The model uses a dictionary output format allowing:

    - Clear separation of different predictions

    - Easy access to specific outputs during inference

    - Flexible loss and metric assignment per output

    """
    def build_model(self, input_dim, num_conditions, learning_rate=0.001):
        """Build and compile the three-headed Keras model.

        A shared Dense/BatchNorm/Dropout trunk (512 -> 256 -> 128) feeds three
        heads, each with ``num_conditions`` output units.

        Args:
            input_dim: Length of each input embedding vector.
            num_conditions: Number of output units per head (one per condition).
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            A compiled ``keras.Model`` whose outputs are a dict keyed by
            ``'conditions'`` (sigmoid), ``'individual_confidences'`` (softplus)
            and ``'individual_weights'`` (sigmoid).
        """
        print("Building improved neural network model...")

        embeddings_in = keras.Input(shape=(input_dim,), name='embeddings')

        # Shared trunk: three Dense -> BatchNorm -> Dropout stages.
        trunk = embeddings_in
        trunk_spec = ((512, 0.3), (256, 0.3), (128, 0.2))
        for stage, (units, drop_rate) in enumerate(trunk_spec, start=1):
            trunk = layers.Dense(units, activation='relu', name=f'dense{stage}')(trunk)
            trunk = layers.BatchNormalization(name=f'bn{stage}')(trunk)
            trunk = layers.Dropout(drop_rate, name=f'dropout{stage}')(trunk)

        # Multi-label classification head: one independent sigmoid per condition.
        cls_hidden = layers.Dense(64, activation='relu', name='condition_dense')(trunk)
        cls_hidden = layers.Dropout(0.2, name='condition_dropout')(cls_hidden)
        conditions_out = layers.Dense(num_conditions, activation='sigmoid',
                                      name='conditions')(cls_hidden)

        def regression_head(prefix, activation, output_name):
            """Dense(64)+Dropout(0.2), Dense(32)+Dropout(0.1), then the output layer."""
            hidden = layers.Dense(64, activation='relu', name=f'{prefix}_dense1')(trunk)
            hidden = layers.Dropout(0.2, name=f'{prefix}_dropout1')(hidden)
            hidden = layers.Dense(32, activation='relu', name=f'{prefix}_dense2')(hidden)
            hidden = layers.Dropout(0.1, name=f'{prefix}_dropout2')(hidden)
            return layers.Dense(num_conditions, activation=activation,
                                name=output_name)(hidden)

        # Softplus keeps confidence outputs strictly positive (no hard zero like ReLU).
        # NOTE(review): train() scales these targets with confidence_scaler; if that
        # scaler can emit negative values, softplus cannot reach them — confirm.
        confidences_out = regression_head('confidence', 'softplus', 'individual_confidences')
        # Sigmoid bounds weight outputs to (0, 1).
        weights_out = regression_head('weighted', 'sigmoid', 'individual_weights')

        head_names = ('conditions', 'individual_confidences', 'individual_weights')
        model = keras.Model(
            inputs=embeddings_in,
            outputs=dict(zip(head_names, (conditions_out, confidences_out, weights_out))),
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss={'conditions': 'binary_crossentropy',
                  'individual_confidences': 'mse',
                  'individual_weights': 'mse'},
            # Regression heads are up-weighted relative to classification.
            loss_weights={'conditions': 1.0,
                          'individual_confidences': 1.5,
                          'individual_weights': 1.2},
            metrics={'conditions': ['accuracy'],
                     'individual_confidences': ['mae'],
                     'individual_weights': ['mae']},
        )

        return model
    
    """

    Main training orchestration method with improved confidence handling.



    TRAINING PIPELINE:

    1. Load data (embeddings + labels)

    2. Prepare training matrices (parse, filter, validate)

    3. Scale features and outputs

    4. Split train/validation sets

    5. Build neural network architecture

    6. Train with callbacks (early stopping, LR reduction, checkpointing)

    7. Evaluate performance

    8. Save trained model



    IMPROVED SCALING STRATEGY (KEY FIX):

    Problem: Previous version scaled all values including zeros

    Solution: Fit scalers only on non-zero values



    Why this matters:

    - Sparse matrices have many structural zeros (condition not present)

    - Including zeros in scaler fitting shifts mean artificially low

    - Model learns to predict near-zero for everything

    - Confidence predictions collapsed to ~0 (major bug)



    New approach:

    ```python
    conf_nonzero = y_confidences[y_confidences > 0.001]
    self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
    ```



    Only non-zero values determine scale

    Model learns actual confidence distribution (1-5 range)

    Predictions are meaningful positive values



    FALLBACK HANDLING:

    If too few non-zero values exist:



    Use sensible dummy values (1-5 for confidence, 0.1-0.9 for weights)

    Prevents scaler failure on edge cases

    Ensures training can proceed



    TRAIN/TEST SPLIT:



    80/20 split is standard for medical ML

    Stratification not used (multi-label makes it complex)

    Random state fixed for reproducibility



    CALLBACKS:



    Early Stopping (patience=12):



    Monitors validation loss

    Stops if no improvement for 12 epochs

    Restores best weights (not final weights)

    Why: Prevents overfitting to training set



    ReduceLROnPlateau (factor=0.5, patience=5):



    Monitors confidence loss specifically (was problematic)

    Reduces LR by 50% if loss plateaus

    Allows fine-tuning in late training

    Min LR: 1e-7 prevents excessive reduction



    ModelCheckpoint:



    Saves best model weights during training

    Insurance against training divergence

    Cleaned up after successful training



    TRAINING DURATION:



    `epochs` defaults to 50 (an upper bound; early stopping may end training sooner)

    Early stopping typically triggers around epoch 30-40

    Batch size 32 balances memory vs convergence speed



    HYPERPARAMETERS:



    Learning rate: 0.001 (standard for Adam)

    Batch size: 32 (good for datasets of this size)

    Test split: 0.2 (20% validation, standard practice)



    POST-TRAINING:



    Comprehensive evaluation on test set

    Detailed metrics for all three outputs

    Analysis of confidence prediction quality

    """
    def train(self, npz_file_path="derm_foundation_embeddings.npz",
              csv_file_path="dataset_scin_labels.csv",
              test_size=0.2, random_state=42, epochs=50, batch_size=32,
              learning_rate=0.001):
        """Train the multi-output network end to end and evaluate it.

        Loads embeddings and labels, builds the target matrices with
        ``prepare_training_data``, scales inputs and targets, splits into
        train/test sets, builds and fits the model, then calls
        ``evaluate_model`` on the held-out split.

        Args:
            npz_file_path: Path passed to ``self.load_embeddings``.
            csv_file_path: Path passed to ``self.load_dataset``.
            test_size: Fraction of samples held out for validation/testing.
            random_state: Seed for ``train_test_split``.
            epochs: Maximum number of epochs (early stopping may stop sooner).
            batch_size: Mini-batch size for ``model.fit``.
            learning_rate: Initial Adam learning rate.

        Returns:
            True on success; False if loading the embeddings, loading the CSV,
            or preparing the training matrices returned ``None``.

        Side effects:
            Fits ``embedding_scaler``, ``confidence_scaler`` and
            ``weighted_scaler`` in place, sets ``self.model`` and
            ``self.history``, and writes ``best_model_fixed.weights.h5`` via the
            ModelCheckpoint callback.
        """
        # Load data
        embeddings = self.load_embeddings(npz_file_path)
        if embeddings is None:
            return False

        df = self.load_dataset(csv_file_path)
        if df is None:
            return False

        # Prepare training data
        X, y_conditions, y_confidences, y_weights = self.prepare_training_data(df, embeddings)
        if X is None:
            return False

        # NOTE(review): every scaler below is fit on the full dataset before the
        # train/test split, so test-set statistics leak into preprocessing.
        # Consider fitting on the training split only.
        print("\nFitting scalers...")
        X_scaled = self.embedding_scaler.fit_transform(X)

        # Fit target scalers on non-zero entries only: zeros mark "condition
        # absent" and would otherwise drag the fitted statistics toward 0.
        conf_nonzero = y_confidences[y_confidences > 0.001]
        if len(conf_nonzero) > 50:  # Ensure we have enough data
            print(f"Fitting confidence scaler on {len(conf_nonzero)} non-zero values")
            self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
        else:
            print("WARNING: Too few non-zero confidence values, using default scaling")
            # Use a reasonable range for confidence values (e.g., 1-5 scale)
            dummy_conf = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1)
            self.confidence_scaler.fit(dummy_conf)

        weight_nonzero = y_weights[y_weights > 0.001]
        if len(weight_nonzero) > 50:
            print(f"Fitting weight scaler on {len(weight_nonzero)} non-zero values")
            self.weighted_scaler.fit(weight_nonzero.reshape(-1, 1))
        else:
            print("WARNING: Too few non-zero weight values, using default scaling")
            # Use a reasonable range for weight values (0-1 scale)
            dummy_weight = np.array([0.1, 0.3, 0.5, 0.7, 0.9]).reshape(-1, 1)
            self.weighted_scaler.fit(dummy_weight)

        # Scale only the non-zero entries (structural zeros stay exactly 0), using
        # one vectorized transform per matrix instead of one call per cell.
        y_confidences_scaled = self._scale_nonzero_entries(y_confidences, self.confidence_scaler)
        y_weights_scaled = self._scale_nonzero_entries(y_weights, self.weighted_scaler)

        # Guarded: an all-zero matrix would make .min()/.max() raise ValueError.
        print(f"Scaled confidence range: {self._format_nonzero_range(y_confidences_scaled)}")
        print(f"Scaled weight range: {self._format_nonzero_range(y_weights_scaled)}")

        # Split data
        X_train, X_test, y_cond_train, y_cond_test, y_conf_train, y_conf_test, y_weight_train, y_weight_test = train_test_split(
            X_scaled, y_conditions, y_confidences_scaled, y_weights_scaled,
            test_size=test_size, random_state=random_state
        )

        print(f"\nTraining/test split:")
        print(f"  Training samples: {X_train.shape[0]}")
        print(f"  Test samples: {X_test.shape[0]}")

        # Build model
        self.model = self.build_model(
            input_dim=X_scaled.shape[1],
            num_conditions=y_conditions.shape[1],
            learning_rate=learning_rate
        )

        print(f"\nModel architecture:")
        self.model.summary()

        # Targets keyed by the model's output names
        train_data = {
            'conditions': y_cond_train,
            'individual_confidences': y_conf_train,
            'individual_weights': y_weight_train
        }

        val_data = {
            'conditions': y_cond_test,
            'individual_confidences': y_conf_test,
            'individual_weights': y_weight_test
        }

        early_stopping = keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=12,
            restore_best_weights=True,
            verbose=1
        )

        reduce_lr = keras.callbacks.ReduceLROnPlateau(
            monitor='val_individual_confidences_loss',  # Monitor confidence loss specifically
            factor=0.5,
            patience=5,
            min_lr=1e-7,
            mode='min',
            verbose=1
        )

        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath='best_model_fixed.weights.h5',
            monitor='val_loss',
            save_best_only=True,
            save_weights_only=True,
            verbose=1
        )

        print(f"\nStarting training for {epochs} epochs...")

        self.history = self.model.fit(
            X_train, train_data,
            validation_data=(X_test, val_data),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=[early_stopping, reduce_lr, model_checkpoint],
            verbose=1
        )

        self.evaluate_model(X_test, y_cond_test, y_conf_test, y_weight_test)

        return True

    @staticmethod
    def _scale_nonzero_entries(matrix, scaler, threshold=0.001):
        """Return a copy of ``matrix`` with entries above ``threshold`` scaled by
        ``scaler`` (as a single column) and every other entry set to 0."""
        scaled = np.zeros_like(matrix)
        mask = matrix > threshold
        if mask.any():
            scaled[mask] = scaler.transform(matrix[mask].reshape(-1, 1)).ravel()
        return scaled

    @staticmethod
    def _format_nonzero_range(matrix):
        """Format the min/max of the non-zero entries of ``matrix``, or a placeholder if none."""
        nonzero = matrix[matrix != 0]
        if nonzero.size == 0:
            return "n/a (no non-zero values)"
        return f"{nonzero.min():.3f} - {nonzero.max():.3f}"
    """

    Comprehensive model evaluation with enhanced confidence analysis.

    EVALUATION METRICS:

    1. MULTI-LABEL CLASSIFICATION (Conditions):

    Hamming Loss:



    Definition: Fraction of incorrectly predicted labels

    Range: [0, 1] where 0 is perfect

    Formula: (1 / (n_samples × n_labels)) × Σ_i Σ_j [y_true[i, j] ≠ y_pred[i, j]]

    Example: If 2 out of 30 labels are wrong, hamming loss = 0.067

    Clinical interpretation: Lower is better, <0.1 is excellent



    Exact Match Accuracy:



    Strictest metric: Requires ALL labels perfectly correct

    Range: [0, 1] where 1 is perfect

    Why include: Shows complete prediction correctness

    Medical context: Exact match is ideal but rarely achievable

    (even expert dermatologists disagree on some cases)



    Average Conditions per Sample:



    Describes label distribution complexity

    Higher values → harder multi-label problem

    Typical range: 1-3 conditions per sample



    2. CONFIDENCE REGRESSION:

    Why evaluate only non-zero targets:



    Zeros are structural (condition not present)

    Including zeros conflates two problems:

    a) Predicting which conditions exist (classification task)

    b) Predicting confidence for existing conditions (regression task)

    We want to evaluate (b) separately



    Inverse Transform:



    Converts scaled predictions back to original scale

    Necessary for interpretable metrics

    Example: Scaled 0.3 → Original 3.2 (on 1-5 scale)



    MSE (Mean Squared Error):



    Sensitive to large errors (squared penalty)

    Unit: (confidence units)²

    Lower is better



    MAE (Mean Absolute Error):



    Average absolute difference from ground truth

    Same units as original values

    More robust to outliers than MSE

    Clinical interpretation: If MAE=0.5, average error is ±0.5 points



    RMSE (Root Mean Squared Error):



    Square root of MSE

    Same units as original values (easier to interpret than MSE)

    Emphasizes larger errors more than MAE



    Prediction Range Analysis:



    Verifies predictions are in sensible range

    Example: If ground truth is 1-5, predictions should be similar

    Out-of-range predictions indicate scaling or activation issues



    3. WEIGHT REGRESSION:

    Same metrics as confidence but for weight values (0-1 range)

    DIAGNOSTIC CHECKS:



    "Predictions > 0.1" percentage: Ensures model isn't predicting near-zero

    Range comparison: Truth vs prediction ranges should align

    Non-zero count: Verifies sparse structure is respected



    WHY THIS EVALUATION IS COMPREHENSIVE:



    Multiple metrics cover different aspects (classification + regression)

    Separate evaluation of sparse vs dense regions

    Original scale metrics (clinically interpretable)

    Diagnostic checks for common failure modes

    Both squared-error (MSE/RMSE) and absolute-error (MAE) metrics

    """
    def evaluate_model(self, X_test, y_cond_test, y_conf_test, y_weight_test):
        """Print evaluation metrics for all three model heads on a held-out set.

        Args:
            X_test: Scaled embeddings passed to ``self.model.predict``.
            y_cond_test: Binary condition indicator matrix.
            y_conf_test: Confidence targets in *scaled* space, where exact zeros
                mark absent conditions (the convention ``train`` produces).
            y_weight_test: Weight targets in scaled space, same convention.

        Regression metrics are computed after inverse-transforming to the
        original scale, and only at positions where the ground truth is
        non-zero. Results are printed; nothing is returned.
        """
        print("\n" + "="*70)
        print("MODEL EVALUATION - ENHANCED CONFIDENCE ANALYSIS")
        print("="*70)

        def to_original_scale(scaler, values):
            # One vectorized inverse transform instead of one scaler call per cell.
            column = np.asarray(values, dtype=float).reshape(-1, 1)
            return scaler.inverse_transform(column).ravel()

        # Make predictions
        predictions = self.model.predict(X_test)
        y_cond_pred = predictions['conditions']
        y_conf_pred = predictions['individual_confidences']
        y_weight_pred = predictions['individual_weights']

        # Condition classification evaluation (sigmoid outputs thresholded at 0.5)
        y_cond_pred_binary = (y_cond_pred > 0.5).astype(int)
        hamming = hamming_loss(y_cond_test, y_cond_pred_binary)
        exact_match = (y_cond_pred_binary == y_cond_test).all(axis=1).mean()

        print(f"Multi-label Condition Classification:")
        print(f"  Hamming Loss: {hamming:.4f}")
        print(f"  Exact Match Accuracy: {exact_match:.4f}")
        print(f"  Average conditions per sample: {y_cond_test.sum(axis=1).mean():.2f}")

        print(f"\nConfidence Prediction Analysis:")
        print(f"  Raw prediction range: {y_conf_pred.min():.6f} - {y_conf_pred.max():.6f}")
        print(f"  Non-zero predictions: {(y_conf_pred > 0.001).sum()}/{y_conf_pred.size}")

        # Absent conditions are exact zeros in the scaled targets. Present ones may
        # scale to zero-or-negative values (e.g. below the mean of a standardizing
        # scaler); a `> 0.001` test in scaled space would silently drop them, so
        # select targets with `!= 0` instead.
        # NOTE(review): a present value that scales to exactly 0 (e.g. the minimum
        # under a min-max scaler) is still indistinguishable from "absent".
        conf_mask = y_conf_test != 0
        if conf_mask.any():
            conf_test_nonzero = to_original_scale(self.confidence_scaler, y_conf_test[conf_mask])
            # Inverse-transform every prediction at those positions, including small
            # ones (as predict() does), rather than leaving them at 0.
            conf_pred_nonzero = to_original_scale(self.confidence_scaler, y_conf_pred[conf_mask])

            conf_mse = mean_squared_error(conf_test_nonzero, conf_pred_nonzero)
            conf_mae = mean_absolute_error(conf_test_nonzero, conf_pred_nonzero)

            print(f"  Individual Confidence Regression (on {conf_mask.sum()} non-zero targets):")
            print(f"    MSE: {conf_mse:.4f}")
            print(f"    MAE: {conf_mae:.4f}")
            print(f"    RMSE: {np.sqrt(conf_mse):.4f}")
            print(f"    Prediction range (orig scale): {conf_pred_nonzero.min():.3f} - {conf_pred_nonzero.max():.3f}")
            print(f"    Ground truth range (orig scale): {conf_test_nonzero.min():.3f} - {conf_test_nonzero.max():.3f}")

            # Sanity check: flag collapse of confidence predictions toward zero
            reasonable_predictions = (conf_pred_nonzero > 0.1).sum()
            print(f"    Predictions > 0.1: {reasonable_predictions}/{len(conf_pred_nonzero)} ({100*reasonable_predictions/len(conf_pred_nonzero):.1f}%)")

        # Individual weight evaluation (same masking convention as confidences)
        weight_mask = y_weight_test != 0
        if weight_mask.any():
            weight_test_nonzero = to_original_scale(self.weighted_scaler, y_weight_test[weight_mask])
            weight_pred_nonzero = to_original_scale(self.weighted_scaler, y_weight_pred[weight_mask])

            weight_mse = mean_squared_error(weight_test_nonzero, weight_pred_nonzero)
            weight_mae = mean_absolute_error(weight_test_nonzero, weight_pred_nonzero)

            print(f"\nIndividual Weight Regression (on {weight_mask.sum()} non-zero targets):")
            print(f"  MSE: {weight_mse:.4f}")
            print(f"  MAE: {weight_mae:.4f}")
            print(f"  RMSE: {np.sqrt(weight_mse):.4f}")
            print(f"  Prediction range (orig scale): {weight_pred_nonzero.min():.3f} - {weight_pred_nonzero.max():.3f}")
            print(f"  Ground truth range (orig scale): {weight_test_nonzero.min():.3f} - {weight_test_nonzero.max():.3f}")
    """

    Make predictions on new embeddings with comprehensive output formatting.

    PREDICTION PIPELINE:



    Scale input embedding (using training-fitted scaler)

    Forward pass through neural network

    Process raw outputs:



    Condition probabilities: Sigmoid outputs [0,1]

    Confidence values: Softplus outputs (0,∞)

    Weight values: Sigmoid outputs [0,1]



    Inverse transform regression outputs to original scale

    Apply threshold to select predicted conditions

    Return structured dictionary with multiple views of predictions



    THRESHOLD STRATEGY:



    Condition threshold: 0.3 (lower than typical 0.5)

    Why lower: Medical diagnosis prefers sensitivity (catch more conditions)

    False positives less harmful than false negatives in screening

    Can be adjusted based on clinical requirements



    OUTPUT STRUCTURE:

    Primary Predictions (conditions above threshold):



    dermatologist_skin_condition_on_label_name: List of predicted conditions

    dermatologist_skin_condition_confidence: Confidence per predicted condition

    weighted_skin_condition_label: Weight dict for predicted conditions



    Comprehensive View (all conditions):



    all_condition_probabilities: Probability for every possible condition

    all_individual_confidences: Confidence for every possible condition

    all_individual_weights: Weight for every possible condition



    Debugging Information:



    raw_confidence_outputs: Pre-transform neural network outputs

    raw_weight_outputs: Pre-transform neural network outputs

    condition_threshold: Threshold used for filtering



    Why provide multiple views:



    Primary predictions: For direct clinical use

    Comprehensive view: For ranking, uncertainty quantification

    Debug info: For model validation and troubleshooting



    MINIMUM VALUE CLAMPING:

    ```python
    confidence_orig = max(0.1, confidence_orig)
    weight_orig = max(0.01, weight_orig)
    ```



    Ensures above-threshold predictions are never exactly zero (the all_* views are only clamped at 0.0)

    Confidence ≥0.1: Even lowest predictions are meaningful

    Weight ≥0.01: Prevents division-by-zero in downstream processing



    SOFTPLUS ADVANTAGE:

    With softplus activation, even very negative inputs produce small positive

    outputs, so confidence predictions naturally avoid zero. The max(0.1, x)

    provides additional safety margin.

    RETURN FORMAT:

    Dictionary structure allows:



    Easy access to specific prediction types

    Clear semantic meaning (key names describe contents)

    Extensible (can add new keys without breaking existing code)

    JSON-serializable for API deployment

    """
    def predict(self, embedding):
        """Predict conditions, per-condition confidences and weights for one embedding.

        Args:
            embedding: A single embedding as a 1-D array-like (list, tuple or
                ndarray). A 2-D array is accepted, but only its first row's
                predictions are returned.

        Returns:
            ``None`` (after printing an error) if the model is not trained.
            Otherwise a dict with:

            - ``'dermatologist_skin_condition_on_label_name'``: conditions whose
              probability exceeds ``condition_threshold``.
            - ``'dermatologist_skin_condition_confidence'``: their inverse-scaled
              confidences, clamped to at least 0.1.
            - ``'weighted_skin_condition_label'``: condition -> inverse-scaled
              weight, clamped to at least 0.01.
            - ``'all_condition_probabilities'``, ``'all_individual_confidences'``,
              ``'all_individual_weights'``: values for every class (the latter
              two clamped at 0.0).
            - ``'condition_threshold'`` and the raw (scaled) head outputs.

            All numeric values are plain Python floats.
        """
        if self.model is None:
            print("ERROR: Model not trained. Call train() first.")
            return None

        # Accept any array-like; promote a single vector to a batch of one.
        embedding = np.asarray(embedding)
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        embedding_scaled = self.embedding_scaler.transform(embedding)
        predictions = self.model.predict(embedding_scaled, verbose=0)

        # Only the first sample of the batch is reported.
        condition_probs = predictions['conditions'][0]
        individual_confidences = predictions['individual_confidences'][0]
        individual_weights = predictions['individual_weights'][0]

        # Inverse-transform each head once (vectorized) instead of once per
        # condition per loop; every output is mapped back to the original scale.
        conf_orig_all = self.confidence_scaler.inverse_transform(
            np.asarray(individual_confidences, dtype=float).reshape(-1, 1)).ravel()
        weight_orig_all = self.weighted_scaler.inverse_transform(
            np.asarray(individual_weights, dtype=float).reshape(-1, 1)).ravel()

        # Below the usual 0.5 to favour sensitivity over specificity.
        condition_threshold = 0.3
        classes = self.mlb.classes_

        predicted_conditions = []
        predicted_confidences = []
        predicted_weights_dict = {}
        for idx in np.flatnonzero(condition_probs > condition_threshold):
            condition_name = classes[idx]
            predicted_conditions.append(condition_name)
            predicted_confidences.append(max(0.1, float(conf_orig_all[idx])))  # Minimum confidence of 0.1
            predicted_weights_dict[condition_name] = max(0.01, float(weight_orig_all[idx]))  # Minimum weight of 0.01

        # Full view over every class, for ranking / analysis.
        all_condition_probs = {}
        all_confidences = {}
        all_weights = {}
        for i, class_name in enumerate(classes):
            all_condition_probs[class_name] = float(condition_probs[i])
            all_confidences[class_name] = max(0.0, float(conf_orig_all[i]))
            all_weights[class_name] = max(0.0, float(weight_orig_all[i]))

        return {
            # Main predicted results (above threshold)
            'dermatologist_skin_condition_on_label_name': predicted_conditions,
            'dermatologist_skin_condition_confidence': predicted_confidences,
            'weighted_skin_condition_label': predicted_weights_dict,

            # Additional information for analysis
            'all_condition_probabilities': all_condition_probs,
            'all_individual_confidences': all_confidences,
            'all_individual_weights': all_weights,
            'condition_threshold': condition_threshold,

            # Debug information (scaled, pre-inverse-transform outputs)
            'raw_confidence_outputs': individual_confidences.tolist(),
            'raw_weight_outputs': individual_weights.tolist()
        }
    
    def plot_training_history(self):
        """Save a 2x3 grid of training/validation curves to training_history_fixed.png.

        Panels, in order: total loss, condition accuracy, confidence MAE,
        weight MAE, confidence loss and weight loss. Each reads
        ``self.history.history[key]`` and ``self.history.history['val_' + key]``.
        Prints a message and returns early when no training history exists.
        """
        if self.history is None:
            print("No training history available")
            return

        # Force the non-interactive Agg backend so plotting works headless.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        curves = self.history.history
        # (grid position, history key, title, y-axis label, train legend, val legend)
        panel_specs = (
            ((0, 0), 'loss', 'Model Loss', 'Loss',
             'Training Loss', 'Validation Loss'),
            ((0, 1), 'conditions_accuracy', 'Condition Classification Accuracy', 'Accuracy',
             'Training Accuracy', 'Validation Accuracy'),
            ((0, 2), 'individual_confidences_mae', 'Individual Confidence MAE', 'MAE',
             'Training MAE', 'Validation MAE'),
            ((1, 0), 'individual_weights_mae', 'Individual Weight MAE', 'MAE',
             'Training MAE', 'Validation MAE'),
            ((1, 1), 'individual_confidences_loss', 'Individual Confidence Loss', 'Loss',
             'Training Loss', 'Validation Loss'),
            ((1, 2), 'individual_weights_loss', 'Individual Weight Loss', 'Loss',
             'Training Loss', 'Validation Loss'),
        )

        _, axes = plt.subplots(2, 3, figsize=(18, 10))
        for position, key, title, y_label, train_label, val_label in panel_specs:
            ax = axes[position]
            ax.plot(curves[key], label=train_label)
            ax.plot(curves['val_' + key], label=val_label)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(y_label)
            ax.legend()

        plt.tight_layout()
        plt.savefig('training_history_fixed.png', dpi=300, bbox_inches='tight')
        print("Training history plot saved as: training_history_fixed.png")
        plt.close()
    """

    Persist trained model and preprocessing components to disk.



    SAVED COMPONENTS:



    1. **Keras Model (.keras file)**:

    - Neural network architecture

    - Trained weights for all layers

    - Optimizer state (for resuming training)

    - Compilation settings (loss functions, metrics)



    2. **Preprocessing Data (.pkl file)**:

    - MultiLabelBinarizer: Maps condition names ↔ indices

    - embedding_scaler: Normalizes input embeddings

    - confidence_scaler: Normalizes confidence values

    - weighted_scaler: Normalizes weight values

    - Path to .keras file (for loading)



    WHY SEPARATE FILES:

    - Keras models save to modern .keras format

    - Scikit-learn components need pickle serialization

    - Separation allows independent updates of each component



    LOADING REQUIREMENT:

    Both files are needed for inference:

    - .keras: Neural network for making predictions

    - .pkl: Preprocessors for transforming inputs/outputs



    FILE ORGANIZATION:

    easi_severity_model_derm_foundation_individual_fixed.pkl  (main file)

    easi_severity_model_derm_foundation_individual_fixed.keras (model)

    The user loads the .pkl file, which stores the absolute path of the .keras file



    CLEANUP:

    Removes temporary checkpoint file (best_model_fixed.weights.h5)

    created during training to avoid confusion with final model.



    ERROR HANDLING:

    Returns False with an error message if there is no trained model, and

    prints the full paths of the saved files for debugging.

    """

    def save_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
        """Save the trained Keras model and its preprocessing components.

        Two files are written. Relative paths are resolved against the current
        working directory.

        * ``<stem>.keras``: the Keras model, written via ``self.model.save``.
        * ``filepath``: a pickle holding ``mlb``, ``embedding_scaler``,
          ``confidence_scaler``, ``weighted_scaler`` and the absolute path of
          the ``.keras`` file under the key ``keras_model_path``.

        Missing parent directories of ``filepath`` are created first. After
        saving, the training checkpoint ``best_model_fixed.weights.h5`` in the
        current working directory is deleted if present.

        Args:
            filepath: Destination of the pickle file. Its extension is replaced
                with ``.keras`` to name the model file.

        Returns:
            True on success, False if there is no trained model to save.
        """
        if self.model is None:
            print("ERROR: No trained model to save.")
            return False

        current_dir = os.getcwd()

        # Both artefacts share the same stem. os.path.join discards
        # current_dir when filepath is absolute, so absolute paths also work.
        model_filename = os.path.splitext(filepath)[0]
        keras_model_path = os.path.join(current_dir, f"{model_filename}.keras")
        pkl_filepath = os.path.join(current_dir, filepath)

        # Create the output directory up front; neither open() nor the model
        # writer can be relied on to create missing parent directories.
        os.makedirs(os.path.dirname(pkl_filepath), exist_ok=True)

        print(f"Saving Keras model to: {keras_model_path}")
        self.model.save(keras_model_path)

        # Preprocessing components (scikit-learn objects) go into a pickle.
        model_data = {
            'mlb': self.mlb,
            'embedding_scaler': self.embedding_scaler,
            'confidence_scaler': self.confidence_scaler,
            'weighted_scaler': self.weighted_scaler,
            'keras_model_path': keras_model_path,
        }

        print(f"Saving preprocessing data to: {pkl_filepath}")
        with open(pkl_filepath, 'wb') as f:
            pickle.dump(model_data, f)

        print("Model saved successfully!")
        print(f"  - Main file: {pkl_filepath}")
        print(f"  - Keras model: {keras_model_path}")

        # Remove the checkpoint left over from training so it is not mistaken
        # for the final model.
        checkpoint_file = os.path.join(current_dir, 'best_model_fixed.weights.h5')
        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)
            print("  - Cleaned up temporary checkpoint file")

        return True
    
    def load_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
        """Load a model previously written by ``save_model``.

        The pickle at ``filepath`` is read, and then the Keras model whose path
        it records under ``keras_model_path`` is loaded. That path is absolute.
        If it no longer exists (for example because both files were moved
        together), a file with the same name next to ``filepath`` is tried
        instead.

        Instance attributes (``mlb``, ``embedding_scaler``,
        ``confidence_scaler``, ``weighted_scaler`` and ``model``) are replaced
        only after everything has loaded. A failed load therefore leaves the
        previous state untouched.

        Args:
            filepath: Path to the ``.pkl`` file produced by ``save_model``.

        Returns:
            True on success. False if a file is missing or loading fails; an
            error message is printed in that case.
        """
        if not os.path.exists(filepath):
            print(f"ERROR: Model file not found: {filepath}")
            return False

        try:
            # NOTE: pickle.load can execute arbitrary code; only load model
            # files from trusted sources.
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)

            # Stage components in locals; nothing on self changes until the
            # whole load has succeeded.
            mlb = model_data['mlb']
            embedding_scaler = model_data['embedding_scaler']
            confidence_scaler = model_data['confidence_scaler']
            weighted_scaler = model_data['weighted_scaler']

            keras_model_path = model_data['keras_model_path']
            if not os.path.exists(keras_model_path):
                # The stored absolute path may be stale; look for the .keras
                # file alongside the .pkl file instead.
                sibling_path = os.path.join(
                    os.path.dirname(os.path.abspath(filepath)),
                    os.path.basename(keras_model_path),
                )
                if os.path.exists(sibling_path):
                    keras_model_path = sibling_path

            if not os.path.exists(keras_model_path):
                print(f"ERROR: Keras model not found at {keras_model_path}")
                return False

            model = keras.models.load_model(keras_model_path)
            n_classes = len(mlb.classes_)
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

        # Commit all components together.
        self.mlb = mlb
        self.embedding_scaler = embedding_scaler
        self.confidence_scaler = confidence_scaler
        self.weighted_scaler = weighted_scaler
        self.model = model

        print(f"Model loaded from {filepath}")
        print(f"Available condition classes: {n_classes}")
        return True

"""

WORKFLOW:



1. Print configuration and fixes applied (user visibility)

2. Initialize classifier

3. Validate input files exist

4. Train model with improved confidence handling

5. Plot training history

6. Test model predictions (validate fix effectiveness)

7. Save trained model



MODEL TESTING (NEW):

After training completes, runs a sample prediction to verify:



- The model produces non-zero confidence values (validates the fix)

- Predictions are in the expected ranges

- The output structure is correct



This immediate validation catches issues before deployment.

WHY TEST WITH SAMPLE:



- Confirms that the confidence-scaling fix worked

- Provides immediate feedback on model quality

- Demonstrates the expected output format

- Catches activation-function issues (such as the ReLU→0 bug)



SUCCESS CRITERIA:

✅ Non-zero confidences in reasonable range (e.g., 1-5)

✅ Multiple conditions predicted with varying probabilities

✅ Weights sum to reasonable values (inspect manually; the script prints weights for the top predictions but does not check their sum)

⚠️ Warning if no confidence output exceeds 0.01 (all effectively zero)

"""

def _find_missing_files(paths):
    """Return the entries of ``paths`` that do not exist on disk, in order."""
    return [path for path in paths if not os.path.exists(path)]


def _smoke_test_prediction(classifier, npz_file):
    """Run one prediction on the first stored embedding and report the outputs.

    This is a quick post-training check that the per-condition confidence
    outputs are not all (near-)zero. Any exception is reported and swallowed,
    because a failed smoke test should not stop the model from being saved.
    """
    print("\n" + "="*70)
    print("TESTING MODEL OUTPUTS")
    print("="*70)

    try:
        embeddings = classifier.load_embeddings(npz_file)
        if not embeddings:
            return

        sample_key = list(embeddings.keys())[0]
        sample_embedding = embeddings[sample_key]

        print(f"Testing with sample embedding: {sample_key}")
        test_result = classifier.predict(sample_embedding)

        if not test_result:
            print("❌ Model prediction failed")
            return

        predicted = test_result['dermatologist_skin_condition_on_label_name']
        print("✅ Model prediction successful!")
        print(f"Predicted conditions: {len(predicted)}")

        # Confidences above this threshold count as "non-zero".
        all_confidences = list(test_result['all_individual_confidences'].values())
        if all_confidences:
            nonzero_conf = sum(1 for c in all_confidences if c > 0.01)

            print(f"Confidence range: {min(all_confidences):.4f} - {max(all_confidences):.4f}")
            print(f"Non-zero confidences: {nonzero_conf}/{len(all_confidences)}")

            if nonzero_conf > 0:
                print("✅ CONFIDENCE ISSUE APPEARS TO BE FIXED!")
            else:
                print("⚠️ Confidence outputs still mostly zero - may need further investigation")
        else:
            # min()/max() would raise on an empty sequence.
            print("⚠️ No individual confidence outputs returned")

        # Show up to three of the predicted conditions.
        if predicted:
            print("\nSample predictions:")
            confidences = test_result['dermatologist_skin_condition_confidence']
            for condition, conf in zip(predicted[:3], confidences):
                prob = test_result['all_condition_probabilities'][condition]
                weight = test_result['weighted_skin_condition_label'][condition]
                print(f"  {condition}: prob={prob:.3f}, conf={conf:.3f}, weight={weight:.3f}")
    except Exception as e:
        print(f"Could not test model: {e}")


def _print_usage_instructions(model_output):
    """Print a snippet showing how to reload the model and the output format."""
    print("\nTo use the trained model:")
    print("```python")
    print("classifier = DermFoundationNeuralNetwork()")
    print(f"classifier.load_model('{model_output}')")
    print("result = classifier.predict(embedding)")
    print("print(result['dermatologist_skin_condition_on_label_name'])")
    print("print(result['dermatologist_skin_condition_confidence'])")
    print("print(result['weighted_skin_condition_label'])")
    print("```")

    # Example prediction output format
    print("\nExpected prediction output format:")
    print("{")
    print("  'dermatologist_skin_condition_on_label_name': ['Eczema', 'Irritant Contact Dermatitis'],")
    print("  'dermatologist_skin_condition_confidence': [4.2, 3.1],")
    print("  'weighted_skin_condition_label': {'Eczema': 0.65, 'Irritant Contact Dermatitis': 0.35}")
    print("}")


def main():
    """Train, smoke-test and save the Derm Foundation classifier.

    Steps: print the configuration, build the classifier, check that the
    embedding (.npz) and label (.csv) files exist, train, plot the training
    history, run one sample prediction, then save the model. Completion and
    usage information is printed only if the model was actually saved.
    """
    print("Derm Foundation Neural Network Classifier Training - FIXED VERSION")
    print("="*70)
    print("FIXES APPLIED:")
    print("- Changed confidence activation from ReLU to softplus")
    print("- Improved confidence scaler fitting (non-zero values only)")
    print("- Increased confidence loss weight (1.5x)")
    print("- Enhanced data validation and preprocessing")
    print("- Better handling of sparse confidence/weight matrices")
    print("="*70)
    print("Training neural network to predict:")
    print("1. Skin conditions (multi-label classification)")
    print("2. Individual confidence scores per condition (regression)")
    print("3. Individual weight scores per condition (regression)")
    print("="*70)

    classifier = DermFoundationNeuralNetwork()

    npz_file = "derm_foundation_embeddings.npz"
    csv_file = "dataset_scin_labels.csv"
    model_output = "easi_severity_model_derm_foundation_individual_fixed.pkl"

    missing_files = _find_missing_files([npz_file, csv_file])
    if missing_files:
        print("ERROR: Missing required files:")
        for path in missing_files:
            print(f"  - {path}")
        return

    try:
        success = classifier.train(
            npz_file_path=npz_file,
            csv_file_path=csv_file,
            epochs=60,  # Increased epochs
            batch_size=32,
            learning_rate=0.001
        )

        if not success:
            print("Training failed!")
            return

        # Plotting is best-effort; remember whether it worked so the summary
        # does not claim a plot that was never written.
        plot_saved = True
        try:
            classifier.plot_training_history()
        except Exception as e:
            plot_saved = False
            print(f"Could not plot training history: {e}")

        _smoke_test_prediction(classifier, npz_file)

        if not classifier.save_model(model_output):
            print("ERROR: Model could not be saved.")
            return

        print(f"\n{'='*70}")
        print("TRAINING COMPLETE!")
        print(f"{'='*70}")
        print(f"Model saved as: {model_output}")
        if plot_saved:
            print("Training history plot saved as: training_history_fixed.png")

        _print_usage_instructions(model_output)

    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


# Script entry point: run the full training pipeline only when this file is
# executed directly, not when it is imported (e.g. to reuse the classifier).
if __name__ == "__main__":
    main()