Spaces:
Runtime error
Runtime error
lmoss
commited on
Commit
·
fb968f9
1
Parent(s):
42ff5a0
added a dummy to not have to write pythreejs html to disk
Browse files
app.py
CHANGED
|
@@ -7,16 +7,25 @@ import requests
|
|
| 7 |
import numpy as np
|
| 8 |
import numpy.typing as npt
|
| 9 |
from dcgan import DCGAN3D_G
|
| 10 |
-
import os
|
| 11 |
import pathlib
|
| 12 |
pv.start_xvfb()
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
|
| 15 |
|
| 16 |
DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
|
| 17 |
if not DOWNLOADS_PATH.is_dir():
|
| 18 |
DOWNLOADS_PATH.mkdir()
|
| 19 |
|
|
|
|
| 20 |
def download_checkpoint(url: str, path: str) -> None:
|
| 21 |
resp = requests.get(url)
|
| 22 |
|
|
@@ -24,15 +33,22 @@ def download_checkpoint(url: str, path: str) -> None:
|
|
| 24 |
f.write(resp.content)
|
| 25 |
|
| 26 |
|
| 27 |
-
|
|
|
|
| 28 |
image_size: int = 64,
|
| 29 |
z_dim: int = 512,
|
| 30 |
n_channels: int = 1,
|
| 31 |
n_features: int = 32,
|
| 32 |
-
ngpu: int = 1,
|
| 33 |
-
latent_size: int = 3) -> npt.ArrayLike:
|
| 34 |
netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
|
| 35 |
netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
|
| 37 |
with torch.no_grad():
|
| 38 |
X = netG(z)
|
|
@@ -68,6 +84,7 @@ def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
|
|
| 68 |
a.set_axis_off()
|
| 69 |
return fig
|
| 70 |
|
|
|
|
| 71 |
def main():
|
| 72 |
st.title("Generating Porous Media with GANs")
|
| 73 |
|
|
@@ -106,34 +123,37 @@ def main():
|
|
| 106 |
|
| 107 |
if not (DOWNLOADS_PATH / model_fname).exists():
|
| 108 |
download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
|
|
|
|
| 109 |
|
| 110 |
latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
|
| 111 |
-
img = generate_image(
|
| 112 |
slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
|
| 113 |
|
| 114 |
pv.set_plot_theme("document")
|
| 115 |
pl = pv.Plotter(shape=(1, 1),
|
| 116 |
window_size=(view_width, view_height))
|
| 117 |
_ = pl.add_mesh(slices, cmap="gray")
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
|
| 120 |
pl = pv.Plotter(shape=(1, 1),
|
| 121 |
window_size=(view_width, view_height))
|
| 122 |
_ = pl.add_mesh(mesh, scalars=dist)
|
| 123 |
-
|
|
|
|
| 124 |
|
| 125 |
st.header("2D Cross-Section of Generated Volume")
|
| 126 |
fig = create_matplotlib_figure(img, img.shape[0]//2)
|
| 127 |
st.pyplot(fig=fig)
|
| 128 |
|
| 129 |
-
|
| 130 |
-
source_code = HtmlFile.read()
|
| 131 |
st.header("3D Intersections")
|
| 132 |
components.html(source_code, width=view_width, height=view_height)
|
| 133 |
st.markdown("_Click and drag to spin, right click to shift._")
|
| 134 |
|
| 135 |
-
|
| 136 |
-
source_code =
|
| 137 |
st.header("3D Pore Space Mesh")
|
| 138 |
components.html(source_code, width=view_width, height=view_height)
|
| 139 |
st.markdown("_Click and drag to spin, right click to shift._")
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import numpy.typing as npt
|
| 9 |
from dcgan import DCGAN3D_G
|
|
|
|
| 10 |
import pathlib
|
| 11 |
pv.start_xvfb()
|
| 12 |
|
| 13 |
+
|
| 14 |
+
class DummyWriteable(object):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.html = None
|
| 17 |
+
|
| 18 |
+
def write(self, html):
|
| 19 |
+
self.html = html
|
| 20 |
+
|
| 21 |
+
|
| 22 |
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
|
| 23 |
|
| 24 |
DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
|
| 25 |
if not DOWNLOADS_PATH.is_dir():
|
| 26 |
DOWNLOADS_PATH.mkdir()
|
| 27 |
|
| 28 |
+
|
| 29 |
def download_checkpoint(url: str, path: str) -> None:
|
| 30 |
resp = requests.get(url)
|
| 31 |
|
|
|
|
| 33 |
f.write(resp.content)
|
| 34 |
|
| 35 |
|
| 36 |
+
@st.cache(persist=True, allow_output_mutation=True)
|
| 37 |
+
def load_model(path: str,
|
| 38 |
image_size: int = 64,
|
| 39 |
z_dim: int = 512,
|
| 40 |
n_channels: int = 1,
|
| 41 |
n_features: int = 32,
|
| 42 |
+
ngpu: int = 1,) -> torch.nn.Module:
|
|
|
|
| 43 |
netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
|
| 44 |
netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
|
| 45 |
+
return netG
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@st.cache()
|
| 49 |
+
def generate_image(netG: torch.nn.Module,
|
| 50 |
+
z_dim: int = 512,
|
| 51 |
+
latent_size: int = 3) -> npt.ArrayLike:
|
| 52 |
z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
|
| 53 |
with torch.no_grad():
|
| 54 |
X = netG(z)
|
|
|
|
| 84 |
a.set_axis_off()
|
| 85 |
return fig
|
| 86 |
|
| 87 |
+
|
| 88 |
def main():
|
| 89 |
st.title("Generating Porous Media with GANs")
|
| 90 |
|
|
|
|
| 123 |
|
| 124 |
if not (DOWNLOADS_PATH / model_fname).exists():
|
| 125 |
download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
|
| 126 |
+
netG = load_model((DOWNLOADS_PATH / model_fname))
|
| 127 |
|
| 128 |
latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
|
| 129 |
+
img = generate_image(netG, latent_size=latent_size)
|
| 130 |
slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
|
| 131 |
|
| 132 |
pv.set_plot_theme("document")
|
| 133 |
pl = pv.Plotter(shape=(1, 1),
|
| 134 |
window_size=(view_width, view_height))
|
| 135 |
_ = pl.add_mesh(slices, cmap="gray")
|
| 136 |
+
|
| 137 |
+
slices_html = DummyWriteable()
|
| 138 |
+
pl.export_html(slices_html)
|
| 139 |
|
| 140 |
pl = pv.Plotter(shape=(1, 1),
|
| 141 |
window_size=(view_width, view_height))
|
| 142 |
_ = pl.add_mesh(mesh, scalars=dist)
|
| 143 |
+
mesh_html = DummyWriteable()
|
| 144 |
+
pl.export_html(mesh_html)
|
| 145 |
|
| 146 |
st.header("2D Cross-Section of Generated Volume")
|
| 147 |
fig = create_matplotlib_figure(img, img.shape[0]//2)
|
| 148 |
st.pyplot(fig=fig)
|
| 149 |
|
| 150 |
+
source_code = slices_html.html
|
|
|
|
| 151 |
st.header("3D Intersections")
|
| 152 |
components.html(source_code, width=view_width, height=view_height)
|
| 153 |
st.markdown("_Click and drag to spin, right click to shift._")
|
| 154 |
|
| 155 |
+
|
| 156 |
+
source_code = mesh_html.html
|
| 157 |
st.header("3D Pore Space Mesh")
|
| 158 |
components.html(source_code, width=view_width, height=view_height)
|
| 159 |
st.markdown("_Click and drag to spin, right click to shift._")
|