added the local bayes library, removed bayes from req.txt
- bayes/__init__.py +0 -0
- bayes/__pycache__/__init__.cpython-39.pyc +0 -0
- bayes/__pycache__/data_routines.cpython-39.pyc +0 -0
- bayes/__pycache__/explanations.cpython-39.pyc +0 -0
- bayes/__pycache__/models.cpython-39.pyc +0 -0
- bayes/__pycache__/regression.cpython-39.pyc +0 -0
- bayes/data_routines.py +218 -0
- bayes/explanations.py +701 -0
- bayes/models.py +163 -0
- bayes/regression.py +148 -0
- requirements.txt +0 -1
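With `bayes` dropped from requirements.txt, imports now resolve against the vendored `bayes/` package added in this commit. A minimal sketch of how app code in this Space can import it (illustrative only, not part of the diff):

```python
# Hypothetical usage; the module paths come from the files added in this commit.
from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.regression import BayesianLinearRegression
from bayes.data_routines import get_dataset_by_name
```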
bayes/__init__.py ADDED (empty file)
bayes/__pycache__/__init__.cpython-39.pyc ADDED (binary file, 189 Bytes)
bayes/__pycache__/data_routines.cpython-39.pyc ADDED (binary file, 6.18 kB)
bayes/__pycache__/explanations.cpython-39.pyc ADDED (binary file, 17.9 kB)
bayes/__pycache__/models.cpython-39.pyc ADDED (binary file, 5.28 kB)
bayes/__pycache__/regression.cpython-39.pyc ADDED (binary file, 4.26 kB)
bayes/data_routines.py ADDED
@@ -0,0 +1,218 @@
"""Routines for processing data."""
import numpy as np
import os
import pandas as pd
from PIL import Image
from skimage.segmentation import slic, mark_boundaries

import torch
from torchvision import datasets, transforms

# The number of segments to use for the images
NSEGMENTS = 20
PARAMS = {
    'protected_class': 1,
    'unprotected_class': 0,
    'positive_outcome': 1,
    'negative_outcome': 0
}
IMAGENET_LABELS = {
    'french_bulldog': 245,
    'scuba_diver': 983,
    'corn': 987,
    'broccoli': 927
}

def get_and_preprocess_compas_data():
    """Handle processing of COMPAS according to: https://github.com/propublica/compas-analysis

    Parameters
    ----------
    params : Params
    Returns
    ----------
    Pandas data frame X of processed data, np.ndarray y, and list of column names
    """
    PROTECTED_CLASS = PARAMS['protected_class']
    UNPROTECTED_CLASS = PARAMS['unprotected_class']
    POSITIVE_OUTCOME = PARAMS['positive_outcome']
    NEGATIVE_OUTCOME = PARAMS['negative_outcome']

    compas_df = pd.read_csv("../data/compas-scores-two-years.csv", index_col=0)
    compas_df = compas_df.loc[(compas_df['days_b_screening_arrest'] <= 30) &
                              (compas_df['days_b_screening_arrest'] >= -30) &
                              (compas_df['is_recid'] != -1) &
                              (compas_df['c_charge_degree'] != "O") &
                              (compas_df['score_text'] != "NA")]

    compas_df['length_of_stay'] = (pd.to_datetime(compas_df['c_jail_out']) - pd.to_datetime(compas_df['c_jail_in'])).dt.days
    X = compas_df[['age', 'two_year_recid', 'c_charge_degree', 'race', 'sex', 'priors_count', 'length_of_stay']]

    # if person has high score give them the _negative_ model outcome
    y = np.array([NEGATIVE_OUTCOME if score == 'High' else POSITIVE_OUTCOME for score in compas_df['score_text']])
    sens = X.pop('race')

    # assign African-American as the protected class
    X = pd.get_dummies(X)
    sensitive_attr = np.array(pd.get_dummies(sens).pop('African-American'))
    X['race'] = sensitive_attr

    # make sure everything is lining up
    assert all((sens == 'African-American') == (X['race'] == PROTECTED_CLASS))
    cols = [col for col in X]

    categorical_features = [1, 4, 5, 6, 7, 8]

    output = {
        "X": X.values,
        "y": y,
        "column_names": cols,
        "cat_indices": categorical_features
    }

    return output

def get_and_preprocess_german():
    """Handle processing of German. We use a preprocessed version of German from Ustun et. al.
    https://arxiv.org/abs/1809.06514. Thanks Berk!

    Parameters:
    ----------
    params : Params
    Returns:
    ----------
    Pandas data frame X of processed data, np.ndarray y, and list of column names
    """
    PROTECTED_CLASS = PARAMS['protected_class']
    UNPROTECTED_CLASS = PARAMS['unprotected_class']
    POSITIVE_OUTCOME = PARAMS['positive_outcome']
    NEGATIVE_OUTCOME = PARAMS['negative_outcome']

    X = pd.read_csv("../data/german_processed.csv")
    y = X["GoodCustomer"]

    X = X.drop(["GoodCustomer", "PurposeOfLoan"], axis=1)
    X['Gender'] = [1 if v == "Male" else 0 for v in X['Gender'].values]

    y = np.array([POSITIVE_OUTCOME if p == 1 else NEGATIVE_OUTCOME for p in y.values])
    categorical_features = [0, 1, 2] + list(range(9, X.shape[1]))

    output = {
        "X": X.values,
        "y": y,
        "column_names": [c for c in X],
        "cat_indices": categorical_features,
    }

    return output

def get_PIL_transf():
    """Gets the PIL image transformation."""
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224)
    ])
    return transf

def load_image(path):
    """Loads an image by path."""
    with open(os.path.abspath(path), 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def get_imagenet(name, get_label=True):
    """Gets the imagenet data.

    Arguments:
        name: The name of the imagenet dataset
    """
    images_paths = []

    # Store all the paths of the images
    data_dir = os.path.join("../data", name)
    for (dirpath, dirnames, filenames) in os.walk(data_dir):
        for fn in filenames:
            if fn != ".DS_Store":
                images_paths.append(os.path.join(dirpath, fn))

    # Load & do transforms for the images
    pill_transf = get_PIL_transf()
    images, segs = [], []
    for img_path in images_paths:
        img = load_image(img_path)
        PIL_transformed_image = np.array(pill_transf(img))
        segments = slic(PIL_transformed_image, n_segments=NSEGMENTS, compactness=100, sigma=1)

        images.append(PIL_transformed_image)
        segs.append(segments)

    images = np.array(images)

    if get_label:
        assert name in IMAGENET_LABELS, "Get label set to True but name not in known imagenet labels"
        y = np.ones(images.shape[0]) * IMAGENET_LABELS[name]
    else:
        y = np.ones(images.shape[0]) * -1

    segs = np.array(segs)

    output = {
        "X": images,
        "y": y,
        "segments": segs
    }

    return output


def get_mnist(num):
    """Gets the MNIST data for a certain digit.

    Arguments:
        num: The mnist digit to get
    """

    # Get the mnist data
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist',
                                                             train=False,
                                                             download=True,
                                                             transform=transforms.Compose([transforms.ToTensor(),
                                                                                           transforms.Normalize((0.1307,), (0.3081,))
                                                             ])),
                                              batch_size=1,
                                              shuffle=False)

    all_test_mnist_of_label_num, all_test_segments_of_label_num = [], []

    # Get all instances of label num
    for data, y in test_loader:
        if y[0] == num:
            # Apply segmentation
            sample = np.squeeze(data.numpy().astype('double'), axis=0)
            segments = slic(sample.reshape(28, 28, 1), n_segments=NSEGMENTS, compactness=1, sigma=0.1).reshape(1, 28, 28)
            all_test_mnist_of_label_num.append(sample)
            all_test_segments_of_label_num.append(segments)

    all_test_mnist_of_label_num = np.array(all_test_mnist_of_label_num)
    all_test_segments_of_label_num = np.array(all_test_segments_of_label_num)

    output = {
        "X": all_test_mnist_of_label_num,
        "y": np.ones(all_test_mnist_of_label_num.shape[0]) * num,
        "segments": all_test_segments_of_label_num
    }

    return output

def get_dataset_by_name(name, get_label=True):
    if name == "compas":
        d = get_and_preprocess_compas_data()
    elif name == "german":
        d = get_and_preprocess_german()
    elif "mnist" in name:
        d = get_mnist(int(name[-1]))
    elif "imagenet" in name:
        d = get_imagenet(name[9:], get_label=get_label)
    else:
        raise NameError(f"Unknown dataset {name}")
    d['name'] = name
    return d
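For reference, a short sketch of how `get_dataset_by_name` is meant to be called (illustrative only; it assumes the CSV files live under `../data/`, as hard-coded in the functions above):

```python
# Illustrative only: loads the processed COMPAS table built by get_and_preprocess_compas_data.
compas = get_dataset_by_name("compas")
X, y = compas["X"], compas["y"]
print(compas["column_names"], compas["cat_indices"])

# Names like "mnist3" route to get_mnist(3); names like "imagenet_corn"
# route to get_imagenet("corn"), since the first 9 characters are stripped.
```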
bayes/explanations.py ADDED
@@ -0,0 +1,701 @@
"""Bayesian Local Explanations.

This code implements bayesian local explanations. The code supports the LIME & SHAP
kernels. Along with the LIME & SHAP feature importances, bayesian local explanations
also support uncertainty expression over the feature importances.
"""
import logging

from copy import deepcopy
from functools import reduce
from multiprocessing import Pool
import numpy as np
import operator as op
from tqdm import tqdm

import sklearn
import sklearn.preprocessing
from sklearn.linear_model import Ridge, Lasso
from lime import lime_image, lime_tabular

from bayes.regression import BayesianLinearRegression

LDATA, LINVERSE, LSCALED, LDISTANCES, LY = list(range(5))
SDATA, SINVERSE, SY = list(range(3))

class BayesLocalExplanations:
    """Bayesian Local Explanations.

    This class implements the bayesian local explanations.
    """
    def __init__(self,
                 training_data,
                 data="image",
                 kernel="lime",
                 credible_interval=95,
                 mode="classification",
                 categorical_features=[],
                 discretize_continuous=True,
                 save_logs=False,
                 log_file_name="bayes.log",
                 width=0.75,
                 verbose=False):
        """Initialize the local explanations.

        Arguments:
            training_data: The training data.
            data: The type of data, either "image" or "tabular"
            kernel: The kernel to use, either "lime" or "shap"
            credible_interval: The % credible interval to use for the feature importance
                               uncertainty.
            mode: Whether to run with classification or regression.
            categorical_features: The indices of the categorical features, if in regression mode.
            save_logs: Whether to save logs from the run.
            log_file_name: The name of log file.
        """

        assert kernel in ["lime", "shap"], f"Kernel must be one of lime or shap, not {kernel}"
        assert data in ["image", "tabular"], f"Data must be one of image or tabular, not {data}"
        assert mode in ["classification"], "Other modes like regression are not implemented"

        if save_logs:
            logging.basicConfig(filename=log_file_name,
                                filemode='a',
                                level=logging.INFO)

        logging.info("==============================================")
        logging.info("Initializing Bayes%s %s explanations", kernel, data)
        logging.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

        self.cred_int = credible_interval
        self.data = data
        self.kernel = kernel
        self.mode = mode
        self.categorical_features = categorical_features
        self.discretize_continuous = discretize_continuous
        self.verbose = verbose
        self.width = width * np.sqrt(training_data.shape[1])

        logging.info("Setting mode to %s", mode)
        logging.info("Credible interval set to %s", self.cred_int)

        if kernel == "shap" and data == "tabular":
            logging.info("Setting discretize_continuous to True, due to shapley sampling")
            discretize_continuous = True

        self.training_data = training_data
        self._run_init(training_data)

    def _run_init(self, training_data):
        if self.kernel == "lime":
            lime_tab_exp = lime_tabular.LimeTabularExplainer(training_data,
                                                             mode=self.mode,
                                                             categorical_features=self.categorical_features,
                                                             discretize_continuous=self.discretize_continuous)
            self.lime_info = lime_tab_exp
        elif self.kernel == "shap":
            # Discretization forcibly set to true for shap sampling on initialization
            shap_tab_exp = lime_tabular.LimeTabularExplainer(training_data,
                                                             mode=self.mode,
                                                             categorical_features=self.categorical_features,
                                                             discretize_continuous=self.discretize_continuous)
            self.shap_info = shap_tab_exp
        else:
            raise NotImplementedError

    def _log_args(self, args):
        """Logs arguments to function."""
        logging.info(args)

    def _shap_tabular_perturb_n_samples(self,
                                        data,
                                        n_samples,
                                        max_coefs=None):
        """Generates n shap perturbations"""
        if max_coefs is None:
            max_coefs = np.arange(data.shape[0])
        pre_rdata, pre_inverse = self.shap_info._LimeTabularExplainer__data_inverse(data_row=data,
                                                                                    num_samples=n_samples)
        rdata = pre_rdata[:, max_coefs]
        inverse = np.tile(data, (n_samples, 1))
        inverse[:, max_coefs] = pre_inverse[:, max_coefs]
        return rdata, inverse

    def _lime_tabular_perturb_n_samples(self,
                                        data,
                                        n_samples):
        """Generates n_perturbations for LIME."""
        rdata, inverse = self.lime_info._LimeTabularExplainer__data_inverse(data_row=data,
                                                                            num_samples=n_samples)
        scaled_data = (rdata - self.lime_info.scaler.mean_) / self.lime_info.scaler.scale_
        distances = sklearn.metrics.pairwise_distances(
            scaled_data,
            scaled_data[0].reshape(1, -1),
            metric='euclidean'
        ).ravel()
        return rdata, inverse, scaled_data, distances

    def _stack_tabular_return(self, existing_return, perturb_return):
        """Stacks data from new tabular return to existing return."""
        if len(existing_return) == 0:
            return perturb_return
        new_return = []
        for i, item in enumerate(existing_return):
            new_return.append(np.concatenate((item, perturb_return[i]), axis=0))
        return new_return

    def _select_indices_from_data(self, perturb_return, indices, predictions):
        """Gets each element from the perturb return according to indices, then appends the predictions."""
        # Previously had this set to range(4)
        temp = [perturb_return[i][indices] for i in range(len(perturb_return))]
        temp.append(predictions)
        return temp

    def shap_tabular_focus_sample(self,
                                  data,
                                  classifier_f,
                                  label,
                                  n_samples,
                                  focus_sample_batch_size,
                                  focus_sample_initial_points,
                                  to_consider=10_000,
                                  tempurature=1e-2,
                                  enumerate_initial=True):
        """Focus sample n_samples perturbations for shap tabular."""
        assert focus_sample_initial_points > 0, "Initial focusing sample points cannot be <= 0"
        current_n_perturbations = 0

        # Get 1's coalitions, if requested
        if enumerate_initial:
            enumerate_init_p = self._enumerate_initial_shap(data)
            current_n_perturbations += enumerate_init_p[0].shape[0]
        else:
            enumerate_init_p = None

        if self.verbose:
            pbar = tqdm(total=n_samples)
            pbar.update(current_n_perturbations)

        # Get initial points
        if current_n_perturbations < focus_sample_initial_points:
            initial_perturbations = self._shap_tabular_perturb_n_samples(data, focus_sample_initial_points - current_n_perturbations)

            if enumerate_init_p is not None:
                current_perturbations = self._stack_tabular_return(enumerate_init_p, initial_perturbations)
            else:
                current_perturbations = initial_perturbations

            current_n_perturbations += initial_perturbations[0].shape[0]
        else:
            current_perturbations = enumerate_init_p

        current_perturbations = list(current_perturbations)

        # Store initial predictions
        current_perturbations.append(classifier_f(current_perturbations[SINVERSE])[:, label])
        if self.verbose:
            pbar.update(initial_perturbations[0].shape[0])

        while current_n_perturbations < n_samples:
            current_batch_size = min(focus_sample_batch_size, n_samples - current_n_perturbations)

            # Init current BLR
            blr = BayesianLinearRegression(percent=self.cred_int)
            weights = self._get_shap_weights(current_perturbations[SDATA], current_perturbations[SDATA].shape[1])
            blr.fit(current_perturbations[SDATA], current_perturbations[-1], weights, compute_creds=False)

            candidate_perturbations = self._shap_tabular_perturb_n_samples(data, to_consider)
            _, var = blr.predict(candidate_perturbations[SINVERSE])

            # Get sampling weighting
            var /= tempurature
            exp_var = np.exp(var)
            all_exp = np.sum(exp_var)
            tempurature_scaled_weights = exp_var / all_exp

            # Get sampled indices
            least_confident_sample = np.random.choice(len(var), size=current_batch_size, p=tempurature_scaled_weights, replace=True)

            # Get predictions
            cy = classifier_f(candidate_perturbations[SINVERSE][least_confident_sample])[:, label]

            new_perturbations = self._select_indices_from_data(candidate_perturbations, least_confident_sample, cy)
            current_perturbations = self._stack_tabular_return(current_perturbations, new_perturbations)
            current_n_perturbations += new_perturbations[0].shape[0]

            if self.verbose:
                pbar.update(new_perturbations[0].shape[0])

        return current_perturbations

    def lime_tabular_focus_sample(self,
                                  data,
                                  classifier_f,
                                  label,
                                  n_samples,
                                  focus_sample_batch_size,
                                  focus_sample_initial_points,
                                  to_consider=10_000,
                                  tempurature=5e-4,
                                  existing_data=[]):
        """Focus sample n_samples perturbations for lime tabular."""
        current_n_perturbations = 0

        # Get initial focus sampling batch
        if len(existing_data) < focus_sample_initial_points:
            # If there's existing data, make sure we only sample up to existing_data points
            initial_perturbations = self._lime_tabular_perturb_n_samples(data, focus_sample_initial_points - len(existing_data))
            current_perturbations = self._stack_tabular_return(existing_data, initial_perturbations)
        else:
            current_perturbations = existing_data

        if self.verbose:
            pbar = tqdm(total=n_samples)

        current_perturbations = list(current_perturbations)
        current_n_perturbations += initial_perturbations[0].shape[0]

        # Store predictions on initial data
        current_perturbations.append(classifier_f(current_perturbations[LINVERSE])[:, label])
        if self.verbose:
            pbar.update(initial_perturbations[0].shape[0])

        # Sample up to n_samples
        while current_n_perturbations < n_samples:

            # If batch size would exceed n_samples, only sample enough to reach n_samples
            current_batch_size = min(focus_sample_batch_size, n_samples - current_n_perturbations)

            # Init current BLR
            blr = BayesianLinearRegression(percent=self.cred_int)
            # Get weights on current distances
            weights = self._lime_kernel(current_perturbations[LDISTANCES], self.width)
            # Fit blr on current perturbations & data
            blr.fit(current_perturbations[LDATA], current_perturbations[LY], weights)

            # Get set of perturbations to consider labeling
            candidate_perturbations = self._lime_tabular_perturb_n_samples(data, to_consider)
            _, var = blr.predict(candidate_perturbations[LDATA])

            # Reweight
            var /= tempurature
            exp_var = np.exp(var)
            all_exp = np.sum(exp_var)
            tempurature_scaled_weights = exp_var / all_exp

            # Get sampled indices
            least_confident_sample = np.random.choice(len(var), size=current_batch_size, p=tempurature_scaled_weights, replace=False)

            # Get predictions
            cy = classifier_f(candidate_perturbations[LINVERSE][least_confident_sample])[:, label]

            new_perturbations = self._select_indices_from_data(candidate_perturbations, least_confident_sample, cy)
            current_perturbations = self._stack_tabular_return(current_perturbations, new_perturbations)
            current_n_perturbations += new_perturbations[0].shape[0]

            if self.verbose:
                pbar.update(new_perturbations[0].shape[0])

        return current_perturbations

    def _lime_kernel(self, d, kernel_width):
        return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

    def _explain_bayes_lime(self,
                            data,
                            classifier_f,
                            label,
                            focus_sample,
                            cred_width,
                            n_samples,
                            max_n_samples,
                            focus_sample_batch_size,
                            focus_sample_initial_points,
                            ptg_initial_points,
                            to_consider):
        """Computes the bayeslime tabular explanations."""

        # Case where only n_samples is specified and not focused sampling
        if n_samples is not None and not focus_sample:
            logging.info("Generating bayeslime explanation with %s samples", n_samples)

            # Generate perturbations
            rdata, inverse, scaled_data, distances = self._lime_tabular_perturb_n_samples(data, n_samples)
            weights = self._lime_kernel(distances, self.width)
            y = classifier_f(inverse)[:, label]
            blr = BayesianLinearRegression(percent=self.cred_int)
            blr.fit(rdata, y, weights)
        # Focus sampling
        elif focus_sample:
            logging.info("Starting focused sampling")
            if n_samples:
                logging.info("n_samples preset, running focused sampling up to %s samples", n_samples)
                logging.info("using batch size %s with %s initial points", focus_sample_batch_size, focus_sample_initial_points)
                focused_sampling_output = self.lime_tabular_focus_sample(data,
                                                                         classifier_f,
                                                                         label,
                                                                         n_samples,
                                                                         focus_sample_batch_size,
                                                                         focus_sample_initial_points,
                                                                         to_consider=to_consider,
                                                                         existing_data=[])
                rdata = focused_sampling_output[LDATA]
                distances = focused_sampling_output[LDISTANCES]
                y = focused_sampling_output[LY]

                blr = BayesianLinearRegression(percent=self.cred_int)
                weights = self._lime_kernel(distances, self.width)
                blr.fit(rdata, y, weights)
            else:
                # Use ptg to get the number of samples, then focus sample
                # Note, this isn't used in the paper, this case currently isn't implemented
                raise NotImplementedError

        else:
            # PTG Step 1, get initial
            rdata, inverse, scaled_data, distances = self._lime_tabular_perturb_n_samples(data, ptg_initial_points)
            weights = self._lime_kernel(distances, self.width)
            y = classifier_f(inverse)[:, label]
            blr = BayesianLinearRegression(percent=self.cred_int)
            blr.fit(rdata, y, weights)

            # PTG Step 2, get additional points needed
            n_needed = int(np.ceil(blr.get_ptg(cred_width)))
            if self.verbose:
                tqdm.write(f"Additional number of perturbations needed is {n_needed}")
            ptg_rdata, ptg_inverse, ptg_scaled_data, ptg_distances = self._lime_tabular_perturb_n_samples(data, n_needed - ptg_initial_points)
            ptg_weights = self._lime_kernel(ptg_distances, self.width)

            rdata = np.concatenate((rdata, ptg_rdata), axis=0)
            inverse = np.concatenate((inverse, ptg_inverse), axis=0)
            scaled_data = np.concatenate((scaled_data, ptg_scaled_data), axis=0)
            distances = np.concatenate((distances, ptg_distances), axis=0)

            # Run final model
            ptgy = classifier_f(ptg_inverse)[:, label]
            y = np.concatenate((y, ptgy), axis=0)
            blr = BayesianLinearRegression(percent=self.cred_int)
            blr.fit(rdata, y, self._lime_kernel(distances, self.width))

        # Format output for returning
        output = {
            "data": rdata,
            "y": y,
            "distances": distances,
            "blr": blr,
            "coef": blr.coef_,
            "max_coefs": None  # Included for consistency purposes w/ bayesshap
        }

        return output

    def _get_shap_weights(self, data, M):
        """Gets shap weights. This assumes data is binary."""
        nonzero = np.count_nonzero(data, axis=1)
        weights = []
        for nz in nonzero:
            denom = (nCk(M, nz) * nz * (M - nz))
            # Stabilize kernel
            if denom == 0:
                weight = 1.0
            else:
                weight = ((M - 1) / denom)
            weights.append(weight)
        return weights

    def _enumerate_initial_shap(self, data, max_coefs=None):
        """Enumerate 1's for stability."""
        if max_coefs is None:
            data = np.eye(data.shape[0])
            inverse = self.shap_info.discretizer.undiscretize(data)
            return data, inverse
        else:
            data = np.zeros((max_coefs.shape[0], data.shape[0]))
            for i in range(max_coefs.shape[0]):
                data[i, max_coefs[i]] = 1
            inverse = self.shap_info.discretizer.undiscretize(data)
            return data[:, max_coefs], inverse

    def _explain_bayes_shap(self,
                            data,
                            classifier_f,
                            label,
                            focus_sample,
                            cred_width,
                            n_samples,
                            max_n_samples,
                            focus_sample_batch_size,
                            focus_sample_initial_points,
                            ptg_initial_points,
                            to_consider,
                            feature_select_num_points=1_000,
                            n_features=10,
                            l2=True,
                            enumerate_initial=True,
                            feature_selection=True,
                            max_coefs=None):
        """Computes the bayesshap tabular explanations."""
        if feature_selection and max_coefs is None:
            n_features = min(n_features, data.shape[0])
            _, feature_select_inverse = self._shap_tabular_perturb_n_samples(data, feature_select_num_points)
            lr = Ridge().fit(feature_select_inverse, classifier_f(feature_select_inverse)[:, label])
            max_coefs = np.argsort(np.abs(lr.coef_))[-1 * n_features:]
        elif feature_selection and max_coefs is not None:
            pass
        else:
            max_coefs = None

        # Case without focused sampling
        if n_samples is not None and not focus_sample:
            logging.info("Generating bayesshap explanation with %s samples", n_samples)

            # Enumerate single coalitions, if requested
            if enumerate_initial:
                data_init, inverse_init = self._enumerate_initial_shap(data, max_coefs)
                n_more = n_samples - inverse_init.shape[0]
            else:
                n_more = n_samples

            rdata, inverse = self._shap_tabular_perturb_n_samples(data, n_more, max_coefs)

            if enumerate_initial:
                rdata = np.concatenate((data_init, rdata), axis=0)
                inverse = np.concatenate((inverse_init, inverse), axis=0)

            y = classifier_f(inverse)[:, label]
            weights = self._get_shap_weights(rdata, M=rdata.shape[1])

            blr = BayesianLinearRegression(percent=self.cred_int)
            blr.fit(rdata, y, weights)
        elif focus_sample:
            if feature_selection:
                raise NotImplementedError

            logging.info("Starting focused sampling")
            if n_samples:
                logging.info("n_samples preset, running focused sampling up to %s samples", n_samples)
                logging.info("using batch size %s with %s initial points", focus_sample_batch_size, focus_sample_initial_points)
                focused_sampling_output = self.shap_tabular_focus_sample(data,
                                                                         classifier_f,
                                                                         label,
                                                                         n_samples,
                                                                         focus_sample_batch_size,
                                                                         focus_sample_initial_points,
                                                                         to_consider=to_consider,
                                                                         enumerate_initial=enumerate_initial)
                rdata = focused_sampling_output[SDATA]
                y = focused_sampling_output[SY]
                weights = self._get_shap_weights(rdata, rdata.shape[1])
                blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
                blr.fit(rdata, y, weights)
            else:
                # Use ptg to get the number of samples, then focus sample
                # Note, this case isn't used in the paper and currently isn't implemented
                raise NotImplementedError
        else:
            # Use PTG to get initial samples

            # Enumerate initial points if requested
            if enumerate_initial:
                data_init, inverse_init = self._enumerate_initial_shap(data, max_coefs)
                n_more = ptg_initial_points - inverse_init.shape[0]
            else:
                n_more = ptg_initial_points

            # Perturb using initial samples
            rdata, inverse = self._shap_tabular_perturb_n_samples(data, n_more, max_coefs)
            if enumerate_initial:
                rdata = np.concatenate((data_init, rdata), axis=0)
                inverse = np.concatenate((inverse_init, inverse), axis=0)

            # Get labels
            y = classifier_f(inverse)[:, label]

            # Fit BLR
            weights = self._get_shap_weights(rdata, M=rdata.shape[1])
            blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
            blr.fit(rdata, y, weights)

            # Compute PTG number needed
            n_needed = int(np.ceil(blr.get_ptg(cred_width)))
            ptg_rdata, ptg_inverse = self._shap_tabular_perturb_n_samples(data,
                                                                          n_needed - ptg_initial_points,
                                                                          max_coefs)

            if self.verbose:
                tqdm.write(f"{n_needed} more samples needed")

            rdata = np.concatenate((rdata, ptg_rdata), axis=0)
            inverse = np.concatenate((inverse, ptg_inverse), axis=0)
            ptgy = classifier_f(ptg_inverse)[:, label]
            weights = self._get_shap_weights(rdata, M=rdata.shape[1])

            # Run final model
            ptgy = classifier_f(ptg_inverse)[:, label]
            y = np.concatenate((y, ptgy), axis=0)
            blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
            blr.fit(rdata, y, weights)

        # Format output for returning
        output = {
            "data": rdata,
            "y": y,
            "distances": weights,
            "blr": blr,
            "coef": blr.coef_,
            "max_coefs": max_coefs
        }

        return output

    def explain(self,
                data,
                classifier_f,
                label,
                cred_width=1e-2,
                focus_sample=True,
                n_samples=None,
                max_n_samples=10_000,
                focus_sample_batch_size=2_500,
                focus_sample_initial_points=100,
                ptg_initial_points=200,
                to_consider=10_000,
                feature_selection=True,
                n_features=15,
                tag=None,
                only_coef=False,
                only_blr=False,
                enumerate_initial=True,
                max_coefs=None,
                l2=True):
        """Explain an instance.

        As opposed to other model agnostic explanations, the bayes explanations
        accept a credible interval width instead of a number of perturbations
        value.

        If the credible interval is set to 95% (as is the default), the bayesian
        explanations will generate feature importances that are +/- width/2
        95% of the time.

        Arguments:
            data: The data instance to explain
            classifier_f: The classification function. This function should return
                          probabilities for each label, where if there are M labels
                          and N instances, the output is of shape (N, M).
            label: The label index to explain.
            cred_width: The width of the credible interval of the resulting explanation. Note,
                        this serves as an upper bound in the implementation; the final credible
                        intervals may be tighter, because PTG is a bit approximate. Also, be
                        aware that for kernelshap, we may be able to compute the kernelshap
                        values exactly by enumerating all the coalitions.
            focus_sample: Whether to use uncertainty sampling.
            n_samples: If specified, n_samples will override the width setting
                       and compute the explanation with n_samples.
            max_n_samples: The maximum number of samples to use. If the width is set to
                           a very small value and many samples are required, this serves
                           as a point to stop sampling.
            focus_sample_batch_size: The batch size of focus sampling.
            focus_sample_initial_points: The number of perturbations to collect before starting
                                         focused sampling.
            ptg_initial_points: The number of perturbations to collect before computing the ptg estimate.
            to_consider: The number of perturbations to consider in focused sampling.
            feature_selection: Whether to do feature selection using Ridge regression. Note, currently
                               only implemented for BayesSHAP.
            n_features: The number of features to use in feature selection.
            tag: A tag to add to the explanation.
            only_coef: Only return the explanation means.
            only_blr: Only return the bayesian regression object.
            enumerate_initial: Whether to enumerate a set of initial shap coalitions.
            l2: Whether to fit with l2 regression. Turning off the l2 regression can be useful for the shapley value estimation.
        Returns:
            explanation: The resulting feature importances, credible intervals, and bayes regression
                         object.
        """
        assert isinstance(data, np.ndarray), "Data must be numpy array. Note, this means that classifier_f \
                                              must accept numpy arrays."
        self._log_args(locals())

        if self.kernel == "lime" and self.data in ["tabular", "image"]:
            output = self._explain_bayes_lime(data,
                                              classifier_f,
                                              label,
                                              focus_sample,
                                              cred_width,
                                              n_samples,
                                              max_n_samples,
                                              focus_sample_batch_size,
                                              focus_sample_initial_points,
                                              ptg_initial_points,
                                              to_consider)
        elif self.kernel == "shap" and self.data in ["tabular", "image"]:
            output = self._explain_bayes_shap(data,
                                              classifier_f,
                                              label,
                                              focus_sample,
                                              cred_width,
                                              n_samples,
                                              max_n_samples,
                                              focus_sample_batch_size,
                                              focus_sample_initial_points,
                                              ptg_initial_points,
                                              to_consider,
                                              feature_selection=feature_selection,
                                              n_features=n_features,
                                              enumerate_initial=enumerate_initial,
                                              max_coefs=max_coefs,
                                              l2=l2)
        else:
            pass

        output['tag'] = tag

        if only_coef:
            return output['coef']

        if only_blr:
            return output['blr']

        return output


def nCk(n, r):
    """n choose r

    From: https://stackoverflow.com/questions/4941753/is-there-a-math-ncr-function-in-python"""
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    return numer / denom


def do_exp(args):
    """Supporting function for the explanations."""
    i, data, init_kwargs, exp_kwargs, labels, max_coefs, pass_args = args
    def do(data_i, label):

        if pass_args is not None and pass_args.balance_background_dataset:
            init_kwargs['training_data'] = np.concatenate((data_i[None, :], np.zeros((1, data_i.shape[0]))), axis=0)

        exp = BayesLocalExplanations(**init_kwargs)
        exp_kwargs['tag'] = i
        exp_kwargs['label'] = label
        if max_coefs is not None:
            exp_kwargs['max_coefs'] = max_coefs[i]
        e = deepcopy(exp.explain(data_i, **exp_kwargs))
        return e
    if labels is not None:
        return do(data[i], labels[i])
    else:
        return do(data[i], exp_kwargs['label'])


def explain_many(all_data, init_kwargs, exp_kwargs, pool_size=1, verbose=False, labels=None, max_coefs=None, args=None):
    """Parallel explanations."""
    with Pool(pool_size) as p:
        if verbose:
            results = list(tqdm(p.imap(do_exp, [(i, all_data, init_kwargs, exp_kwargs, labels, max_coefs, args) for i in range(all_data.shape[0])])))
        else:
            results = p.map(do_exp, [(i, all_data, init_kwargs, exp_kwargs, labels, max_coefs, args) for i in range(all_data.shape[0])])
    return results
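A minimal end-to-end sketch of the tabular workflow this module supports (illustrative only; the synthetic data and random forest are stand-ins and not part of this commit, and it assumes the `lime` package from requirements.txt is installed):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from bayes.explanations import BayesLocalExplanations

# Stand-in data and model.
X, y = make_classification(n_samples=500, n_features=8, random_state=0)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

explainer = BayesLocalExplanations(training_data=X,
                                   data="tabular",
                                   kernel="lime",
                                   categorical_features=[],
                                   verbose=True)
exp = explainer.explain(data=X[0],
                        classifier_f=rf.predict_proba,  # returns (N, M) probabilities
                        label=1,
                        n_samples=1_000,
                        focus_sample=False)
print(exp["coef"])       # posterior mean feature importances
print(exp["blr"].creds)  # credible interval widths at the configured percent
```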
bayes/models.py ADDED
@@ -0,0 +1,163 @@
"""Routines that implement processing data & getting models.

This file includes various routines for processing & acquiring models, for
later use in the code. The table data preprocessing is straightforward. We
first apply scaling to the data and fit a random forest classifier.

The processing of the image data is a bit more complex. To simplify the construction
of the explanations, the explanations don't accept images. Instead, for image explanations,
it is necessary to define a function that accepts an array of 0's and 1's corresponding to
segments for a particular image being either excluded or included respectively. The explanation
is performed on this array.
"""
import numpy as np
from copy import deepcopy

from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import torch
from torchvision import models, transforms

from data.mnist.mnist_model import Net

def get_xtrain(segs):
    """A function to get the mock training data to use in the image explanations.

    This function returns a dataset containing a single instance of ones and
    another of zeros to represent the training data for the explanation. The idea
    is that the explanation will use this data to compute the perturbations, which
    will then be fed into the wrapped model.

    Arguments:
        segs: The current segments array
    """
    n_segs = len(np.unique(segs))
    xtrain = np.concatenate((np.ones((1, n_segs)), np.zeros((1, n_segs))), axis=0)
    return xtrain

def process_imagenet_get_model(data):
    """Gets wrapped imagenet model."""

    # Get the vgg16 model, used in the experiments
    model = models.vgg16(pretrained=True)
    model.eval()
    # model.cuda()

    xtest = data['X']
    ytest = data['y'].astype(int)
    xtest_segs = data['segments']

    softmax = torch.nn.Softmax(dim=1)

    # Transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    t_xtest = transf(xtest[0])[None, :]  # .cuda()

    # Define the wrapped model
    def get_wrapped_model(instance, segments, background=0, batch_size=64):
        def wrapped_model(data):
            perturbed_images = []
            for d in data:
                perturbed_image = deepcopy(instance)
                for i, is_on in enumerate(d):
                    if is_on == 0:
                        perturbed_image[segments==i, 0] = background
                        perturbed_image[segments==i, 1] = background
                        perturbed_image[segments==i, 2] = background
                perturbed_images.append(transf(perturbed_image)[None, :])
            perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float()
            predictions = []
            for q in range(0, perturbed_images.shape[0], batch_size):
                predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
            predictions = np.concatenate(predictions, axis=0)
            return predictions
        return wrapped_model

    output = {
        "model": get_wrapped_model,
        "xtest": xtest,
        "ytest": ytest,
        "xtest_segs": xtest_segs,
        "label": data['y'][0]
    }

    return output

def process_mnist_get_model(data):
    """Gets wrapped mnist model."""
    xtest = data['X']
    ytest = data['y'].astype(int)
    xtest_segs = data['segments']

    model = Net()
    model.load_state_dict(torch.load("../data/mnist/mnist_cnn.pt"))
    model.eval()
    model.cuda()

    softmax = torch.nn.Softmax(dim=1)
    def get_wrapped_model(instance, segments, background=-0.4242, batch_size=100):
        def wrapped_model(data):
            perturbed_images = []
            data = torch.from_numpy(data).float().cuda()
            for d in data:
                perturbed_image = deepcopy(instance)
                for i, is_on in enumerate(d):
                    if is_on == 0:
                        a = segments==i
                        perturbed_image[0, segments[0]==i] = background
                perturbed_images.append(perturbed_image[:, None])
            perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float().cuda()

            # Batch predictions if necessary
            if perturbed_images.shape[0] > batch_size:
                predictions = []
                for q in range(0, perturbed_images.shape[0], batch_size):
                    predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
                predictions = np.concatenate(predictions, axis=0)
            else:
                predictions = softmax(model(perturbed_images)).cpu().detach().numpy()
            return np.array(predictions)
        return wrapped_model

    output = {
        "model": get_wrapped_model,
        "xtest": xtest,
        "ytest": ytest,
        "xtest_segs": xtest_segs,
        "label": data['y'][0],
    }

    return output

def process_tabular_data_get_model(data):
    """Processes tabular data + trains random forest classifier."""
    X = data['X']
    y = data['y']

    xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2)
    ss = StandardScaler().fit(xtrain)
    xtrain = ss.transform(xtrain)
    xtest = ss.transform(xtest)
    rf = RandomForestClassifier(n_estimators=100).fit(xtrain, ytrain)

    output = {
        "model": rf,
        "xtrain": xtrain,
        "xtest": xtest,
        "ytrain": ytrain,
        "ytest": ytest,
        "label": 1,
        "model_score": rf.score(xtest, ytest)
    }

    print(f"Model Score: {output['model_score']}")

    return output
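A sketch of how the wrapped image model and `get_xtrain` are meant to fit together with the explainer, following the segment on/off convention described in the module docstring (illustrative only; the "imagenet_corn" name and its `../data/imagenet_corn/` directory are assumptions for the example):

```python
import numpy as np
from bayes.data_routines import get_dataset_by_name
from bayes.models import process_imagenet_get_model, get_xtrain
from bayes.explanations import BayesLocalExplanations

image_data = get_dataset_by_name("imagenet_corn")      # assumes images under ../data/imagenet_corn/
wrapped = process_imagenet_get_model(image_data)

instance = wrapped["xtest"][0]
segments = wrapped["xtest_segs"][0]
classifier_f = wrapped["model"](instance, segments)    # maps 0/1 segment vectors to class probabilities

xtrain = get_xtrain(segments)                          # mock binary "training data" for the explainer
explainer = BayesLocalExplanations(training_data=xtrain, data="image", kernel="lime")
exp = explainer.explain(data=np.ones(xtrain.shape[1]), # start from all segments switched on
                        classifier_f=classifier_f,
                        label=int(wrapped["label"]),
                        n_samples=500,
                        focus_sample=False)
```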
bayes/regression.py ADDED
@@ -0,0 +1,148 @@
"""Bayesian regression.

A class that implements the Bayesian Regression.
"""
import operator as op
from functools import reduce
import copy
import collections

import numpy as np
from scipy.stats import invgamma
from scipy.stats import multivariate_normal

class BayesianLinearRegression:
    def __init__(self, percent=95, l2=True, prior=None):
        if prior is not None:
            raise NameError("Currently only support uninformative prior, set to None please.")

        self.percent = percent
        self.l2 = l2

    def fit(self, xtrain, ytrain, sample_weight, compute_creds=True):
        """
        Fit the bayesian linear regression.

        Arguments:
            xtrain: the training data
            ytrain: the training labels
            sample_weight: the weights for fitting the regression
        """

        # store weights
        weights = sample_weight

        # add intercept
        xtrain = np.concatenate((np.ones(xtrain.shape[0])[:, None], xtrain), axis=1)
        diag_pi_z = np.zeros((len(weights), len(weights)))
        np.fill_diagonal(diag_pi_z, weights)

        if self.l2:
            V_Phi = np.linalg.inv(xtrain.transpose().dot(diag_pi_z).dot(xtrain) \
                                  + np.eye(xtrain.shape[1]))
        else:
            V_Phi = np.linalg.inv(xtrain.transpose().dot(diag_pi_z).dot(xtrain))

        Phi_hat = V_Phi.dot(xtrain.transpose()).dot(diag_pi_z).dot(ytrain)

        N = xtrain.shape[0]
        Y_m_Phi_hat = ytrain - xtrain.dot(Phi_hat)

        s_2 = (1.0 / N) * (Y_m_Phi_hat.dot(diag_pi_z).dot(Y_m_Phi_hat) \
                           + Phi_hat.transpose().dot(Phi_hat))

        self.score = s_2

        self.s_2 = s_2
        self.N = N
        self.V_Phi = V_Phi
        self.Phi_hat = Phi_hat
        self.coef_ = Phi_hat[1:]
        self.intercept_ = Phi_hat[0]
        self.weights = weights

        if compute_creds:
            self.creds = self.get_creds(percent=self.percent)
        else:
            self.creds = "NA"

        self.crit_params = {
            "s_2": self.s_2,
            "N": self.N,
            "V_Phi": self.V_Phi,
            "Phi_hat": self.Phi_hat,
            "creds": self.creds
        }

        return self

    def predict(self, data):
        """
        The predictive distribution.

        Arguments:
            data: The data to predict
        """
        q_1 = np.eye(data.shape[0])
        data_ones = np.concatenate((np.ones(data.shape[0])[:, None], data), axis=1)

        # Get response
        response = np.matmul(data, self.coef_)
        response += self.intercept_

        # Compute var
        temp = np.matmul(data_ones, self.V_Phi)
        mat = np.matmul(temp, data_ones.transpose())
        var = self.s_2 * (q_1 + mat)
        diag = np.diagonal(var)

        return response, np.sqrt(diag)

    def get_ptg(self, desired_width):
        """
        Compute the ptg perturbations.
        """
        cert = (desired_width / 1.96) ** 2
        S = self.coef_.shape[0] * self.s_2
        T = np.mean(self.weights)
        return 4 * S / (self.coef_.shape[0] * T * cert)

    def get_creds(self, percent=95, n_samples=10_000, get_intercept=False):
        """
        Get the credible intervals.

        Arguments:
            percent: the percent cutoff for the credible interval, i.e., 95 is 95% credible interval
            n_samples: the number of samples to compute the credible interval
            get_intercept: whether to include the intercept in the credible interval
        """
        samples = self.draw_posterior_samples(n_samples, get_intercept=get_intercept)
        creds = np.percentile(np.abs(samples - (self.Phi_hat if get_intercept else self.coef_)),
                              percent,
                              axis=0)
        return creds

    def draw_posterior_samples(self, num_samples, get_intercept=False):
        """
        Sample from the posterior.

        Arguments:
            num_samples: number of samples to draw from the posterior
            get_intercept: whether to include the intercept
        """

        sigma_2 = invgamma.rvs(self.N / 2, scale=(self.N * self.s_2) / 2, size=num_samples)

        phi_samples = []
        for sig in sigma_2:
            sample = multivariate_normal.rvs(mean=self.Phi_hat,
                                             cov=self.V_Phi * sig,
                                             size=1)
            phi_samples.append(sample)

        phi_samples = np.vstack(phi_samples)

        if get_intercept:
            return phi_samples
        else:
            return phi_samples[:, 1:]
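For completeness, a small self-contained sketch of the regression class used on its own, with uniform sample weights and synthetic data (illustrative only, not part of this commit):

```python
import numpy as np
from bayes.regression import BayesianLinearRegression

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

blr = BayesianLinearRegression(percent=95)
blr.fit(X, y, sample_weight=np.ones(200))
mean_pred, std_pred = blr.predict(X[:5])   # predictive mean and standard deviation
print(blr.coef_, blr.creds)                # posterior means and 95% credible widths
```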
requirements.txt CHANGED
@@ -6,7 +6,6 @@ astor
 astunparse
 attrs
 backcall
-bayes
 beautifulsoup4
 BHClustering
 bleach