add app and generation / model code
app.py
ADDED
@@ -0,0 +1,34 @@

import torch
import torchvision
import gradio as gr
from PIL import Image
from cli import iterative_refinement
from viz import grid_of_images_default
import subprocess

subprocess.call("download_models.sh", shell=True)
models = {
    "convae": torch.load("convae.th", map_location="cpu"),
    "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
}

def gen(model, seed, nb_iter, nb_samples, width, height):
    torch.manual_seed(int(seed))
    bs = 64
    model = models[model]
    samples = iterative_refinement(
        model,
        nb_iter=int(nb_iter),
        nb_examples=int(nb_samples),
        w=int(width), h=int(height), c=1,
        batch_size=bs,
    )
    grid = grid_of_images_default(samples.reshape((samples.shape[0]*samples.shape[1], int(height), int(width), 1)).numpy(), shape=(samples.shape[0], samples.shape[1]))
    grid = (grid*255).astype("uint8")
    return Image.fromarray(grid)

iface = gr.Interface(
    fn=gen,
    inputs=[gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)],
    outputs="image"
)
iface.launch()
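
Not part of the commit: a minimal sketch of what the Gradio callback consumes, assuming download_models.sh has already fetched convae.th locally. iterative_refinement returns a tensor indexed as (nb_iter, nb_examples, c, w, h), which is why gen flattens the first two axes before building the image grid.

    import torch
    from cli import iterative_refinement

    ae = torch.load("convae.th", map_location="cpu")   # assumes the checkpoint is present
    samples = iterative_refinement(ae, nb_examples=4, nb_iter=10, w=28, h=28, c=1, batch_size=4)
    print(samples.shape)  # torch.Size([10, 4, 1, 28, 28]): one grid row per refinement step
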
cli.py
ADDED
@@ -0,0 +1,320 @@

import os
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from functools import partial

from clize import run
import numpy as np
from skimage.io import imsave

from viz import grid_of_images_default

import torch.nn as nn
import torch

from model import DenseAE
from model import ConvAE
from model import DeepConvAE
from model import SimpleConvAE
from model import ZAE
from model import KAE
from data import load_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"


def plot_dataset(code_2d, categories):
    colors = [
        'r',
        'b',
        'g',
        'crimson',
        'gold',
        'yellow',
        'maroon',
        'm',
        'c',
        'orange'
    ]
    for cat in range(0, 10):
        g = (categories == cat)
        plt.scatter(
            code_2d[g, 0],
            code_2d[g, 1],
            marker='+',
            c=colors[cat],
            s=40,
            alpha=0.7,
            label="digit {}".format(cat)
        )


def plot_generated(code_2d, categories):
    g = (categories < 0)
    plt.scatter(
        code_2d[g, 0],
        code_2d[g, 1],
        marker='+',
        c='gray',
        s=30
    )


def grid_embedding(h):
    from lapjv import lapjv
    from scipy.spatial.distance import cdist
    assert int(np.sqrt(h.shape[0])) ** 2 == h.shape[0], 'Nb of examples must be a square number'
    size = int(np.sqrt(h.shape[0]))
    grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
    cost_matrix = cdist(grid, h, "sqeuclidean").astype('float32')
    cost_matrix = cost_matrix * (100000 / cost_matrix.max())
    _, rows, cols = lapjv(cost_matrix)
    return rows


def save_weights(m, folder='.'):
    if isinstance(m, nn.Linear):
        w = m.weight.data
        if w.size(1) == 28*28 or w.size(0) == 28*28:
            w0, w1 = w.size(0), w.size(1)
            if w0 == 28*28:
                w = w.transpose(0, 1)
                w = w.contiguous()
            w = w.view(w.size(0), 1, 28, 28)
            gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
            imsave('{}/feat_{}.png'.format(folder, w0), gr)
    elif isinstance(m, nn.ConvTranspose2d):
        w = m.weight.data
        if w.size(0) in (32, 64, 128, 256, 512) and w.size(1) in (1, 3):
            gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
            imsave('{}/feat.png'.format(folder), gr)


@torch.no_grad()
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None):
    if batch_size is None:
        batch_size = nb_examples
    x = torch.rand(nb_iter, nb_examples, c, w, h)
    for i in range(1, nb_iter):
        for j in range(0, nb_examples, batch_size):
            oldv = x[i-1][j:j + batch_size].to(device)
            newv = ae(oldv)
            newv = newv.data.cpu()
            x[i][j:j + batch_size] = newv
    return x


def build_model(name, w, h, c):
    if name == 'convae':
        ae = ConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
        )
    elif name == 'zae':
        ae = ZAE(
            w=w, h=h, c=c,
            theta=3,
            nb_hidden=1000,
        )
    elif name == 'kae':
        ae = KAE(
            w=w, h=h, c=c,
            nb_active=1000,
            nb_hidden=1000,
        )
    elif name == 'denseae':
        ae = DenseAE(
            w=w, h=h, c=c,
            encode_hidden=[1000],
            decode_hidden=[],
            ksparse=True,
            nb_active=50,
        )
    elif name == 'simple_convae':
        ae = SimpleConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
        )
    elif name == 'deep_convae':
        ae = DeepConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
            nb_layers=3,
        )
    else:
        raise ValueError('Unknown model')

    return ae


def salt_and_pepper(X, proba=0.5):
    a = (torch.rand(X.size()).to(device) <= (1 - proba)).float()
    b = (torch.rand(X.size()).to(device) <= 0.5).float()
    c = ((a == 0).float() * b)
    return X * a + c


def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walkback=False, denoise=False, epochs=100, batch_size=64, log_interval=100):
    gamma = 0.99
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    c, h, w = x0.size()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    if resume:
        ae = torch.load('{}/model.th'.format(folder))
        ae = ae.to(device)
    else:
        ae = build_model(model, w=w, h=h, c=c)
        ae = ae.to(device)
    optim = torch.optim.Adadelta(ae.parameters(), lr=0.1, eps=1e-7, rho=0.95, weight_decay=0)
    avg_loss = 0.
    nb_updates = 0
    _save_weights = partial(save_weights, folder=folder)

    for epoch in range(epochs):
        for X, y in dataloader:
            ae.zero_grad()
            X = X.to(device)
            if hasattr(ae, 'nb_active'):
                ae.nb_active = max(ae.nb_active - 1, 32)
            # walkback + denoise
            if walkback:
                loss = 0.
                x = X.data
                nb = 5
                for _ in range(nb):
                    x = salt_and_pepper(x, proba=0.3)  # denoise
                    x = x.to(device)
                    x = ae(x)  # reconstruct
                    Xr = x
                    loss += (((x - X) ** 2).view(X.size(0), -1).sum(1).mean()) / nb
                    x = (torch.rand(x.size()).to(device) <= x.data).float()  # sample
            # denoise only
            elif denoise:
                Xc = salt_and_pepper(X.data, proba=0.3)
                Xr = ae(Xc)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
            # normal training
            else:
                Xr = ae(X)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
            loss.backward()
            optim.step()
            avg_loss = avg_loss * gamma + loss.item() * (1 - gamma)
            if nb_updates % log_interval == 0:
                print('Epoch : {:05d} AvgTrainLoss: {:.6f}, Batch Loss : {:.6f}'.format(epoch, avg_loss, loss.item()))
                gr = grid_of_images_default(np.array(Xr.data.tolist()))
                imsave('{}/rec.png'.format(folder), gr)
                ae.apply(_save_weights)
                torch.save(ae, '{}/model.th'.format(folder))
            nb_updates += 1


def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_generate=100, tsne=False):
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    c, h, w = x0.size()
    nb = nb_generate
    print('Load model...')
    if model_path is None:
        model_path = os.path.join(folder, "model.th")
    ae = torch.load(model_path, map_location="cpu")
    ae = ae.to(device)

    def enc(X):
        batch_size = 64
        h_list = []
        for i in range(0, X.size(0), batch_size):
            x = X[i:i + batch_size]
            x = x.to(device)
            name = ae.__class__.__name__
            if name in ('ConvAE',):
                h = ae.encode(x)
                h, _ = h.max(2)
                h = h.view((h.size(0), -1))
            elif name in ('DenseAE',):
                x = x.view(x.size(0), -1)
                h = x
                #h = ae.encode(x)
            else:
                h = x.view(x.size(0), -1)
            h = h.data.cpu()
            h_list.append(h)
        return torch.cat(h_list, 0)

    print('iterative refinement...')
    g = iterative_refinement(
        ae,
        nb_iter=nb_iter,
        nb_examples=nb,
        w=w, h=h, c=c,
        batch_size=64
    )
    np.savez('{}/generated.npz'.format(folder), X=g.numpy())
    g_subset = g[:, 0:100]
    gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
    imsave('{}/gen_full_iters.png'.format(folder), gr)

    g = g[-1]  # last iter
    print(g.shape)
    gr = grid_of_images_default(g.numpy())
    imsave('{}/gen_full.png'.format(folder), gr)

    if tsne:
        from sklearn.manifold import TSNE
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=nb,
            shuffle=True,
            num_workers=1
        )
        print('Load data...')
        X, y = next(iter(dataloader))
        print('Encode data...')
        xh = enc(X)
        print('Encode generated...')
        gh = enc(g)
        X = X.numpy()
        g = g.numpy()
        xh = xh.numpy()
        gh = gh.numpy()

        a = np.concatenate((X, g), axis=0)
        ah = np.concatenate((xh, gh), axis=0)
        labels = np.array(y.tolist() + [-1] * len(g))
        sne = TSNE()
        print('fit tsne...')
        ah = sne.fit_transform(ah)
        print('grid embedding...')

        asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
        ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
        rows = grid_embedding(ahsmall)
        asmall = asmall[rows]
        gr = grid_of_images_default(asmall)
        imsave('{}/sne_grid.png'.format(folder), gr)

        fig = plt.figure(figsize=(10, 10))
        plot_dataset(ah, labels)
        plot_generated(ah, labels)
        plt.legend(loc='best')
        plt.axis('off')
        plt.savefig('{}/sne.png'.format(folder))
        plt.close(fig)


if __name__ == '__main__':
    run([train, test])
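
Not part of the commit: since run([train, test]) exposes both functions through clize, the same entry points can also be driven directly from Python. A hedged sketch, assuming the output folder already exists (train writes model.th, rec.png and the feature grids into it, and test reads model.th back from the same folder):

    from cli import train, test

    # roughly: python cli.py train --dataset=mnist --folder=mnist --model=deep_convae --epochs=1
    train(dataset='mnist', folder='mnist', model='deep_convae', epochs=1)
    # roughly: python cli.py test --dataset=mnist --folder=mnist --nb-iter=100 --nb-generate=100
    test(dataset='mnist', folder='mnist', nb_iter=100, nb_generate=100)
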
convert.py
ADDED
@@ -0,0 +1,52 @@

import numpy as np
import torch, h5py
from model import *

w, h, c = 28, 28, 1
model_new = DeepConvAE(
    w=w, h=h, c=c,
    nb_filters=128,
    spatial=True,
    channel=True,
    channel_stride=4,
    # total layers = nb_layers*2, where we have nb_layers for encoder and nb_layers for decoder
    nb_layers=3,
)
# model_old = h5py.File("mnist_deepconvae/model.h5")
model_old = h5py.File("/home/mehdi/work/code/out_of_class/ae/mnist/model.h5")


print(model_new)
print(model_old["model_weights"].keys())


for name, param in model_new.named_parameters():
    enc_or_decode, layer_id, bias_or_kernel = name.split(".")

    if enc_or_decode == "encode":
        layer_name = "conv2d"
    else:
        layer_name = "up_conv2d"

    layer_id = (int(layer_id)//2) + 1

    full_layer_name = f"{layer_name}_{layer_id}"
    print(full_layer_name)

    k = "kernel" if bias_or_kernel == "weight" else "bias"
    weights = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
    weights = np.array(weights)
    weights = torch.from_numpy(weights)
    print(name, layer_id, param.shape, weights.shape)
    inds = [4, 3, 2, 1, 0]
    if k == "kernel":
        if layer_name == "conv2d":
            weights = weights.permute((3, 2, 0, 1))
            weights = weights[:, :, inds]
            weights = weights[:, :, :, inds]
            print("W", weights.shape)
        elif layer_name == "up_conv2d":
            weights = weights.permute((2, 3, 0, 1))
    print(param.shape, weights.shape)
    param.data.copy_(weights)
    print((param - weights).sum())
torch.save(model_new, "mnist_deepconvae/model.th")
data.py
ADDED
@@ -0,0 +1,94 @@

import torch
import numpy as np                              # used by the 'quickdraw' branch below

import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.utils.data import TensorDataset      # used by the 'quickdraw' branch below


class Invert:
    def __call__(self, x):
        return 1 - x

class Gray:
    def __call__(self, x):
        return x[0:1]


def load_dataset(dataset_name, split='full'):
    if dataset_name == 'mnist':
        dataset = dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
            ])
        )
        return dataset
    elif dataset_name == 'coco':
        dataset = dset.ImageFolder(root='data/coco',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'quickdraw':
        X = (np.load('data/quickdraw/teapot.npy'))
        X = X.reshape((X.shape[0], 28, 28))
        X = X / 255.
        X = X.astype(np.float32)
        X = torch.from_numpy(X)
        dataset = TensorDataset(X, X)
        return dataset
    elif dataset_name == 'shoes':
        dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images/Shoes',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'footwear':
        dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'celeba':
        dataset = dset.ImageFolder(root='data/celeba',
                                   transform=transforms.Compose([
                                       transforms.Scale(32),
                                       transforms.CenterCrop(32),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'birds':
        dataset = dset.ImageFolder(root='data/birds/'+split,
                                   transform=transforms.Compose([
                                       transforms.Scale(32),
                                       transforms.CenterCrop(32),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'sketchy':
        dataset = dset.ImageFolder(root='data/sketchy/'+split,
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                       Gray()
                                   ]))
        return dataset

    elif dataset_name == 'fonts':
        dataset = dset.ImageFolder(root='data/fonts/'+split,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       Invert(),
                                       Gray(),
                                   ]))
        return dataset
    else:
        raise ValueError('Error : unknown dataset')
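
Not part of the commit: a quick sanity check of the MNIST loader (the other branches expect the corresponding image folders under data/). The first call downloads MNIST into data/mnist.

    from data import load_dataset

    ds = load_dataset('mnist', split='train')
    x, y = ds[0]
    print(x.shape, float(x.min()), float(x.max()), y)   # torch.Size([1, 28, 28]) 0.0 1.0 <label>
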
model.py
ADDED
@@ -0,0 +1,260 @@

import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform

class KAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, nb_active=16):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.nb_active = nb_active
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        thetas, _ = torch.sort(h, dim=1, descending=True)
        thetas = thetas[:, self.nb_active:self.nb_active+1]
        h = h * (h > thetas).float()
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = nn.Sigmoid()(Xr)
        return Xr, h


class ZAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, theta=1):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.theta = theta
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        h = h * (h > self.theta).float()
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = nn.Sigmoid()(Xr)
        return Xr, h


class DenseAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, encode_hidden=(300,), decode_hidden=(300,), ksparse=True, nb_active=10, denoise=None):
        super().__init__()
        self.encode_hidden = encode_hidden
        self.decode_hidden = decode_hidden
        self.ksparse = ksparse
        self.nb_active = nb_active
        self.denoise = denoise

        # encode layers
        layers = []
        hid_prev = w * h * c
        for hid in encode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        self.encode = nn.Sequential(*layers)

        # decode layers
        layers = []
        for hid in decode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        layers.extend([
            nn.Linear(hid_prev, w * h * c),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)

        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        if self.denoise is not None:
            X = X * ((torch.rand(X.size()) <= self.denoise).float()).to(X.device)
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        if self.ksparse:
            h = ksparse(h, nb_active=self.nb_active)
        Xr = self.decode(h)
        Xr = Xr.view(size)
        return Xr


def ksparse(x, nb_active=10):
    mask = torch.ones(x.size())
    for i, xi in enumerate(x.data.tolist()):
        inds = np.argsort(xi)
        inds = inds[::-1]
        inds = inds[nb_active:]
        if len(inds):
            inds = np.array(inds)
            inds = torch.from_numpy(inds).long()
            mask[i][inds] = 0
    return x * (mask).float().to(x.device)


class ConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h

class SimpleConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 13, 1, 0),
            nn.ReLU(True),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h

class DeepConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, nb_layers=3, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        layers = [
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
        ]
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        self.encode = nn.Sequential(*layers)
        layers = []
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.ConvTranspose2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(nb_filters, c, 5, 1, 0),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h


def spatial_sparsity(x):
    maxes = x.amax(dim=(2,3), keepdims=True)
    return x * equals(x, maxes)

def equals(x, y, eps=1e-8):
    return torch.abs(x-y) <= eps

def strided_channel_sparsity(x, stride=1):
    B, F = x.shape[0:2]
    h, w = x.shape[2:]
    x_ = x.view(B, F, h // stride, stride, w // stride, stride)
    mask = equals(x_, x_.amax(axis=(1, 3, 5), keepdims=True))
    mask = mask.view(x.shape).float()
    return x * mask


def _weights_init(m):
    if hasattr(m, 'weight'):
        xavier_uniform(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0)
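
Not part of the commit: a small sketch of the winner-take-all sparsity the conv autoencoders apply between encoder and decoder. spatial_sparsity keeps, for each example and each feature map, only the spatially maximal activation; strided_channel_sparsity keeps a single winner per stride-by-stride spatial block across all channels.

    import torch
    from model import spatial_sparsity, strided_channel_sparsity

    h = torch.rand(2, 8, 16, 16)                 # (batch, channels, height, width)
    hs = spatial_sparsity(h)
    print((hs != 0).sum(dim=(2, 3)))             # typically one surviving position per (example, channel)
    hc = strided_channel_sparsity(hs, stride=4)
    print(float((hc != 0).float().mean()))       # fraction of activations that survive both steps
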
test.py
ADDED
@@ -0,0 +1,21 @@

import torch
import numpy as np
from machinedesign.autoencoder.interface import load
from keras.models import Model

torch.use_deterministic_algorithms(True)
model = torch.load("mnist_deepconvae/model.th")
model_keras = load("/home/mehdi/work/code/out_of_class/ae/mnist")
print(model_keras.layers[8])

m = Model(model_keras.inputs, model_keras.layers[8].output)
X = torch.rand(1, 1, 28, 28)
with torch.no_grad():
    # X1 = model.sparsify(model.encode(X))
    X1 = model(X)
X2 = model_keras.predict(X)
X2 = torch.from_numpy(X2)
print(torch.abs(X1 - X2).sum())
# for i in range(128):
#     print(i, torch.abs(X1[0,i]-X2[0,i]).sum())
#     print(X1[0,i, 0, :])
#     print(X2[0,i,0, :])
viz.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains common visualization functions
|
| 3 |
+
used to report results of the models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from functools import partial
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def horiz_merge(left, right):
|
| 11 |
+
"""
|
| 12 |
+
merges two images, left and right horizontally to obtain
|
| 13 |
+
a bigger image containing both.
|
| 14 |
+
|
| 15 |
+
Parameters
|
| 16 |
+
---------
|
| 17 |
+
left: 2D or 3D numpy array
|
| 18 |
+
left image.
|
| 19 |
+
2D for grayscale.
|
| 20 |
+
3D for color.
|
| 21 |
+
right : numpy array array
|
| 22 |
+
right image.
|
| 23 |
+
2D for grayscale
|
| 24 |
+
3D for color.
|
| 25 |
+
|
| 26 |
+
Returns
|
| 27 |
+
-------
|
| 28 |
+
|
| 29 |
+
numpy array (2D or 3D depending on left and right)
|
| 30 |
+
"""
|
| 31 |
+
assert left.shape[0] == right.shape[0]
|
| 32 |
+
assert left.shape[2:] == right.shape[2:]
|
| 33 |
+
shape = (left.shape[0], left.shape[1] + right.shape[1],) + left.shape[2:]
|
| 34 |
+
im_merge = np.zeros(shape)
|
| 35 |
+
im_merge[:, 0:left.shape[1]] = left
|
| 36 |
+
im_merge[:, left.shape[1]:] = right
|
| 37 |
+
return im_merge
|
| 38 |
+
|
| 39 |
+
def vert_merge(top, bottom):
|
| 40 |
+
"""
|
| 41 |
+
merges two images, top and bottom vertically to obtain
|
| 42 |
+
a bigger image containing both.
|
| 43 |
+
|
| 44 |
+
Parameters
|
| 45 |
+
---------
|
| 46 |
+
top: 2D or 3D numpy array
|
| 47 |
+
top image.
|
| 48 |
+
2D for grayscale.
|
| 49 |
+
3D for color.
|
| 50 |
+
bottom : numpy array array
|
| 51 |
+
bottom image.
|
| 52 |
+
2D for grayscale
|
| 53 |
+
3D for color.
|
| 54 |
+
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
|
| 58 |
+
numpy array (2D or 3D depending on left and right)
|
| 59 |
+
"""
|
| 60 |
+
im = horiz_merge(top, bottom)
|
| 61 |
+
if len(im.shape) == 2:
|
| 62 |
+
im = im.transpose((1, 0))
|
| 63 |
+
elif len(im.shape) == 3:
|
| 64 |
+
im = im.transpose((1, 0, 2))
|
| 65 |
+
return im
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normalize=False):
|
| 69 |
+
"""
|
| 70 |
+
Draw a grid of images from M
|
| 71 |
+
The order in the grid which corresponds to the order in M
|
| 72 |
+
is starting from top to bottom then left to right.
|
| 73 |
+
|
| 74 |
+
Parameters
|
| 75 |
+
----------
|
| 76 |
+
|
| 77 |
+
M : numpy array
|
| 78 |
+
if 3D, convert it to 4D, the shape will be interpreted as (nb_images, h, w) and converted to (nb_images, 1, h, w).
|
| 79 |
+
if 4D, consider it as colored or grayscale
|
| 80 |
+
- if the shape is (nb_images, nb_colors, h, w), it is converted to (nb_images, h, w, nb_colors)
|
| 81 |
+
- otherwise, if it already (nb_images, h, w, nb_colors), use it as it is.
|
| 82 |
+
- nb_colors can be 1 (grayscale) or 3 (colors).
|
| 83 |
+
border: int
|
| 84 |
+
thickness of border(default=0)
|
| 85 |
+
shape: tuple (nb_cols, nb_rows)
|
| 86 |
+
shape of the grid
|
| 87 |
+
by default make a square shape
|
| 88 |
+
(in that case, it is possible that not all images from M will be part of the grid).
|
| 89 |
+
normalize: bool(default=False)
|
| 90 |
+
whether to normalize the pixel values of each image independently
|
| 91 |
+
by min and max. if False, clip the values of pixels to 0 and 1
|
| 92 |
+
without normalizing.
|
| 93 |
+
|
| 94 |
+
Returns
|
| 95 |
+
-------
|
| 96 |
+
|
| 97 |
+
3D numpy array of shape (h, w, 3)
|
| 98 |
+
(with a color channel regardless of whether the original images were grayscale or colored)
|
| 99 |
+
"""
|
| 100 |
+
if len(M.shape) == 3:
|
| 101 |
+
M = M[:, :, :, np.newaxis]
|
| 102 |
+
if M.shape[-1] not in (1, 3):
|
| 103 |
+
M = M.transpose((0, 2, 3, 1))
|
| 104 |
+
if M.shape[-1] == 1:
|
| 105 |
+
M = np.ones((1, 1, 1, 3)) * M
|
| 106 |
+
bordercolor = np.array(bordercolor)[None, None, :]
|
| 107 |
+
numimages = len(M)
|
| 108 |
+
M = M.copy()
|
| 109 |
+
|
| 110 |
+
if normalize:
|
| 111 |
+
for i in range(M.shape[0]):
|
| 112 |
+
M[i] -= M[i].flatten().min()
|
| 113 |
+
M[i] /= M[i].flatten().max()
|
| 114 |
+
else:
|
| 115 |
+
M = np.clip(M, 0, 1)
|
| 116 |
+
height, width, color = M[0].shape
|
| 117 |
+
assert color == 3, 'Nb of color channels are {}'.format(color)
|
| 118 |
+
if shape is None:
|
| 119 |
+
n0 = np.int(np.ceil(np.sqrt(numimages)))
|
| 120 |
+
n1 = np.int(np.ceil(np.sqrt(numimages)))
|
| 121 |
+
else:
|
| 122 |
+
n0 = shape[0]
|
| 123 |
+
n1 = shape[1]
|
| 124 |
+
|
| 125 |
+
im = np.array(bordercolor) * np.ones(
|
| 126 |
+
((height + border) * n1 + border, (width + border) * n0 + border, 1), dtype='<f8')
|
| 127 |
+
# shape = (n0, n1)
|
| 128 |
+
# j corresponds to rows in the grid, n1 should correspond to nb of rows
|
| 129 |
+
# i corresponds to columns in the grid, n0 should correspond to nb of cols
|
| 130 |
+
# M should be such that the first n1 examples correspond to row 1,
|
| 131 |
+
# next n1 examples correspond to row 2, etc. that is, M first axis
|
| 132 |
+
# can be reshaped to (n1, n0)
|
| 133 |
+
for i in range(n0):
|
| 134 |
+
for j in range(n1):
|
| 135 |
+
if i * n1 + j < numimages:
|
| 136 |
+
im[j * (height + border) + border:(j + 1) * (height + border) + border,
|
| 137 |
+
i * (width + border) + border:(i + 1) * (width + border) + border, :] = np.concatenate((
|
| 138 |
+
np.concatenate((M[i * n1 + j, :, :, :],
|
| 139 |
+
bordercolor * np.ones((height, border, 3), dtype=float)), 1),
|
| 140 |
+
bordercolor * np.ones((border, width + border, 3), dtype=float)
|
| 141 |
+
), 0)
|
| 142 |
+
return im
|
| 143 |
+
|
| 144 |
+
grid_of_images_default = partial(grid_of_images, border=1, bordercolor=(0.3, 0, 0))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def reshape_to_images(x, input_shape=None):
|
| 148 |
+
"""
|
| 149 |
+
a function that takes a numpy array and try to
|
| 150 |
+
reshape it to an array of images that would
|
| 151 |
+
be compatible with the function grid_of_images.
|
| 152 |
+
Two cases are considered.
|
| 153 |
+
|
| 154 |
+
if x is a 2D numpy array, it uses input_shape:
|
| 155 |
+
- x can either be (nb_examples, nb_features) or (nb_features, nb_examples)
|
| 156 |
+
- nb_features should be prod(input_shape)
|
| 157 |
+
- the nb_features dim is then expanded to have :
|
| 158 |
+
(nb_examples, h, w, nb_channels), sorted input_shape shoud
|
| 159 |
+
be (h, w, nb_channels).
|
| 160 |
+
|
| 161 |
+
if x is a 4D numpy array:
|
| 162 |
+
- if the first tensor dim is 1 or 3 like e.g. (1, a, b, c), then assume it is
|
| 163 |
+
color channel and transform to (a, 1, b, c)
|
| 164 |
+
- if the second tensor dim is 1 or 3, leave x it as it is
|
| 165 |
+
- if the third tensor dim is 1 or 3, like e.g. (a, b, 1, c), then assume it is
|
| 166 |
+
color channel and transform to (c, 1, a, b)
|
| 167 |
+
- if the fourth tensor dim is 1 or 3, like e.g. (a, b, c, 1), then assume it is
|
| 168 |
+
color channel and transform to (c, 1, a, b)
|
| 169 |
+
Parameters
|
| 170 |
+
----------
|
| 171 |
+
|
| 172 |
+
x : numpy array
|
| 173 |
+
input to be reshape
|
| 174 |
+
input_shape : tuple needed only when x is 2D numpy array
|
| 175 |
+
"""
|
| 176 |
+
if len(x.shape) == 2:
|
| 177 |
+
assert input_shape is not None
|
| 178 |
+
if x.shape[0] == np.prod(input_shape):
|
| 179 |
+
x = x.T
|
| 180 |
+
x = x.reshape((x.shape[0],) + input_shape)
|
| 181 |
+
x = x.transpose((0, 2, 3, 1))
|
| 182 |
+
return x
|
| 183 |
+
elif x.shape[1] == np.prod(input_shape):
|
| 184 |
+
x = x.reshape((x.shape[0],) + input_shape)
|
| 185 |
+
x = x.transpose((0, 2, 3, 1))
|
| 186 |
+
return x
|
| 187 |
+
else:
|
| 188 |
+
raise ValueError('Cant recognize this shape : {}'.format(x.shape))
|
| 189 |
+
elif len(x.shape) == 4:
|
| 190 |
+
if x.shape[0] in (1, 3):
|
| 191 |
+
x = x.transpose((1, 0, 2, 3))
|
| 192 |
+
return x
|
| 193 |
+
elif x.shape[1] in (1, 3):
|
| 194 |
+
return x
|
| 195 |
+
elif x.shape[2] in (1, 3):
|
| 196 |
+
x = x.transpose((3, 2, 0, 1))
|
| 197 |
+
return x
|
| 198 |
+
elif x.shape[3] in (1, 3):
|
| 199 |
+
x = x.transpose((2, 3, 0, 1))
|
| 200 |
+
return x
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError('Cant recognize a shape of size : {}'.format(len(x.shape)))
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError('Cant recognize a shape of size : {}'.format(len(x.shape)))
|
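
Not part of the commit: a self-contained sketch of grid_of_images_default, which returns an RGB float grid in [0, 1] (with a 1-pixel dark red border) regardless of whether the input images are grayscale.

    import numpy as np
    from skimage.io import imsave
    from viz import grid_of_images_default

    imgs = np.random.rand(16, 28, 28)            # (nb_images, h, w) grayscale in [0, 1]
    grid = grid_of_images_default(imgs)          # (H, W, 3) float array
    imsave('grid.png', (grid * 255).astype('uint8'))
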