update for binary
Browse files- bias_auc.py +6 -6
bias_auc.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import evaluate
|
| 2 |
import datasets
|
| 3 |
from datasets.features import Sequence, Value, ClassLabel
|
|
@@ -30,7 +32,6 @@ Args:
|
|
| 30 |
label list[int]: list containing label index for each item
|
| 31 |
output list[list[float]]: list of model output values for each
|
| 32 |
subgroup list[str] (optional): list of subgroups that appear in target to compute metric over
|
| 33 |
-
|
| 34 |
Returns (for each subgroup in target):
|
| 35 |
'Subgroup' : Subgroup AUC score,
|
| 36 |
'BPSN' : BPSN (Background Positive, Subgroup Negative) AUC,
|
|
@@ -49,7 +50,6 @@ Example:
|
|
| 49 |
... [0.4341845214366913, 0.5658154487609863],
|
| 50 |
... [0.400595098733902, 0.5994048714637756],
|
| 51 |
... [0.3840397894382477, 0.6159601807594299]]
|
| 52 |
-
|
| 53 |
>>> metric = load('Intel/bias_auc')
|
| 54 |
>>> metric.add_batch(target=target,
|
| 55 |
label=label,
|
|
@@ -67,7 +67,7 @@ class BiasAUC(evaluate.Metric):
|
|
| 67 |
features=datasets.Features(
|
| 68 |
{
|
| 69 |
'target': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
| 70 |
-
'label':
|
| 71 |
'output': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
|
| 72 |
}
|
| 73 |
),
|
|
@@ -78,7 +78,7 @@ class BiasAUC(evaluate.Metric):
|
|
| 78 |
"""Returns label and output score from `targets` and `labels`
|
| 79 |
if `subgroup` is in list of targeted groups found in `targets`
|
| 80 |
"""
|
| 81 |
-
target_class = target_class if target_class is not None else
|
| 82 |
for target, label, result in zip(targets, labels, outputs):
|
| 83 |
if subgroup in target:
|
| 84 |
yield label, result[target_class]
|
|
@@ -89,7 +89,7 @@ class BiasAUC(evaluate.Metric):
|
|
| 89 |
label is not the same as `target_class`; or (2) `subgroup` is not in list of
|
| 90 |
targeted groups found in `targets` and label is the same as `target_class`
|
| 91 |
"""
|
| 92 |
-
target_class = target_class if target_class is not None else
|
| 93 |
for target, label, result in zip(targets, labels, outputs):
|
| 94 |
if not target:
|
| 95 |
continue
|
|
@@ -107,7 +107,7 @@ class BiasAUC(evaluate.Metric):
|
|
| 107 |
targeted groups found in `targets` and label is not the same as `target_class`
|
| 108 |
"""
|
| 109 |
# get the index from class
|
| 110 |
-
target_class = target_class if target_class is not None else
|
| 111 |
for target, label, result in zip(targets, labels, outputs):
|
| 112 |
if not target:
|
| 113 |
continue
|
|
|
|
| 1 |
+
%%writefile test_metric/test_metric.py
|
| 2 |
+
|
| 3 |
import evaluate
|
| 4 |
import datasets
|
| 5 |
from datasets.features import Sequence, Value, ClassLabel
|
|
|
|
| 32 |
label list[int]: list containing label index for each item
|
| 33 |
output list[list[float]]: list of model output values for each
|
| 34 |
subgroup list[str] (optional): list of subgroups that appear in target to compute metric over
|
|
|
|
| 35 |
Returns (for each subgroup in target):
|
| 36 |
'Subgroup' : Subgroup AUC score,
|
| 37 |
'BPSN' : BPSN (Background Positive, Subgroup Negative) AUC,
|
|
|
|
| 50 |
... [0.4341845214366913, 0.5658154487609863],
|
| 51 |
... [0.400595098733902, 0.5994048714637756],
|
| 52 |
... [0.3840397894382477, 0.6159601807594299]]
|
|
|
|
| 53 |
>>> metric = load('Intel/bias_auc')
|
| 54 |
>>> metric.add_batch(target=target,
|
| 55 |
label=label,
|
|
|
|
| 67 |
features=datasets.Features(
|
| 68 |
{
|
| 69 |
'target': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
| 70 |
+
'label': Value(dtype='int64', id=None),
|
| 71 |
'output': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
|
| 72 |
}
|
| 73 |
),
|
|
|
|
| 78 |
"""Returns label and output score from `targets` and `labels`
|
| 79 |
if `subgroup` is in list of targeted groups found in `targets`
|
| 80 |
"""
|
| 81 |
+
target_class = target_class if target_class is not None else 0
|
| 82 |
for target, label, result in zip(targets, labels, outputs):
|
| 83 |
if subgroup in target:
|
| 84 |
yield label, result[target_class]
|
|
|
|
| 89 |
label is not the same as `target_class`; or (2) `subgroup` is not in list of
|
| 90 |
targeted groups found in `targets` and label is the same as `target_class`
|
| 91 |
"""
|
| 92 |
+
target_class = target_class if target_class is not None else 1
|
| 93 |
for target, label, result in zip(targets, labels, outputs):
|
| 94 |
if not target:
|
| 95 |
continue
|
|
|
|
| 107 |
targeted groups found in `targets` and label is not the same as `target_class`
|
| 108 |
"""
|
| 109 |
# get the index from class
|
| 110 |
+
target_class = target_class if target_class is not None else 1
|
| 111 |
for target, label, result in zip(targets, labels, outputs):
|
| 112 |
if not target:
|
| 113 |
continue
|