add ksparse
Browse files- .gitattributes +1 -0
- app.py +19 -7
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
convae.th filter=lfs diff=lfs merge=lfs -text
|
| 36 |
deep_convae.th filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
convae.th filter=lfs diff=lfs merge=lfs -text
|
| 36 |
deep_convae.th filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
fc_sparse.th filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torchvision
|
| 3 |
import gradio as gr
|
|
@@ -5,13 +6,16 @@ from PIL import Image
|
|
| 5 |
from cli import iterative_refinement
|
| 6 |
from viz import grid_of_images_default
|
| 7 |
models = {
|
| 8 |
-
"
|
| 9 |
-
"
|
|
|
|
| 10 |
}
|
| 11 |
-
def gen(md,
|
| 12 |
torch.manual_seed(int(seed))
|
| 13 |
bs = 64
|
| 14 |
-
model = models[
|
|
|
|
|
|
|
| 15 |
samples = iterative_refinement(
|
| 16 |
model,
|
| 17 |
nb_iter=int(nb_iter),
|
|
@@ -19,20 +23,28 @@ def gen(md, model, seed, nb_iter, nb_samples, width, height):
|
|
| 19 |
w=int(width), h=int(height), c=1,
|
| 20 |
batch_size=bs,
|
| 21 |
)
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
grid = (grid*255).astype("uint8")
|
| 24 |
return Image.fromarray(grid)
|
| 25 |
|
| 26 |
text = """
|
| 27 |
-
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
|
| 28 |
|
| 29 |
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
|
|
|
|
|
|
|
| 30 |
"""
|
| 31 |
iface = gr.Interface(
|
| 32 |
fn=gen,
|
| 33 |
inputs=[
|
| 34 |
gr.Markdown(text),
|
| 35 |
-
gr.Dropdown(list(models.keys()), value="
|
| 36 |
],
|
| 37 |
outputs="image"
|
| 38 |
)
|
|
|
|
| 1 |
+
import math
|
| 2 |
import torch
|
| 3 |
import torchvision
|
| 4 |
import gradio as gr
|
|
|
|
| 6 |
from cli import iterative_refinement
|
| 7 |
from viz import grid_of_images_default
|
| 8 |
models = {
|
| 9 |
+
"ConvAE": torch.load("convae.th", map_location="cpu"),
|
| 10 |
+
"Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
|
| 11 |
+
"Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
|
| 12 |
}
|
| 13 |
+
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
|
| 14 |
torch.manual_seed(int(seed))
|
| 15 |
bs = 64
|
| 16 |
+
model = models[model_name]
|
| 17 |
+
if model == "Dense K-Sparse":
|
| 18 |
+
model.nb_active = nb_active
|
| 19 |
samples = iterative_refinement(
|
| 20 |
model,
|
| 21 |
nb_iter=int(nb_iter),
|
|
|
|
| 23 |
w=int(width), h=int(height), c=1,
|
| 24 |
batch_size=bs,
|
| 25 |
)
|
| 26 |
+
if only_last:
|
| 27 |
+
s = int(math.sqrt((nb_samples)))
|
| 28 |
+
grid = grid_of_images_default(samples[-1].numpy(), shape=(s, s))
|
| 29 |
+
else:
|
| 30 |
+
grid = grid_of_images_default(samples.reshape((samples.shape[0]*samples.shape[1], int(height), int(width), 1)).numpy(), shape=(samples.shape[0], samples.shape[1]))
|
| 31 |
+
if not black_bg:
|
| 32 |
+
grid = 1 - grid
|
| 33 |
grid = (grid*255).astype("uint8")
|
| 34 |
return Image.fromarray(grid)
|
| 35 |
|
| 36 |
text = """
|
| 37 |
+
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`), Dense K-Sparse model (from [here](https://openreview.net/forum?id=r1QXQkSYg))
|
| 38 |
|
| 39 |
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
|
| 40 |
+
|
| 41 |
+
NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
|
| 42 |
"""
|
| 43 |
iface = gr.Interface(
|
| 44 |
fn=gen,
|
| 45 |
inputs=[
|
| 46 |
gr.Markdown(text),
|
| 47 |
+
gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background")
|
| 48 |
],
|
| 49 |
outputs="image"
|
| 50 |
)
|