Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import matplotlib.pyplot as plt | |
| import pyvista as pv | |
| import torch | |
| import requests | |
| import numpy as np | |
| import numpy.typing as npt | |
| from dcgan import DCGAN3D_G | |
| import pathlib | |
| import time | |
import os

# PyVista needs an X display to render. Streamlit re-executes this whole script
# on every widget interaction. ``pv.start_xvfb`` launches an Xvfb process and
# then waits for it each time it is called. Once it has run, it sets
# ``DISPLAY`` in this process's environment. Only start it when no display is
# available yet, so reruns do not pay the startup wait again or spawn
# redundant Xvfb processes.
if not os.environ.get("DISPLAY"):
    pv.start_xvfb()
class DummyWriteable:
    """Minimal in-memory file-like sink that remembers the last thing written.

    It is handed to ``Plotter.export_html`` in place of a real file, so the
    exported HTML can be read back from ``.html``.
    """

    def __init__(self):
        # Stays None until write() has been called at least once.
        self.html = None

    def write(self, html):
        """Keep *html*, overwriting anything stored by a previous call."""
        self.html = html
# "static" folder inside the installed Streamlit package. The model checkpoint
# is cached in its "downloads" subfolder (see main()).
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
DOWNLOADS_PATH = STREAMLIT_STATIC_PATH / "downloads"
# The flags make this idempotent:
# - exist_ok avoids the check-then-create race when several sessions start
#   at once;
# - parents also creates "static" if it is missing.
# A non-directory already at this path still raises, as before.
DOWNLOADS_PATH.mkdir(parents=True, exist_ok=True)
def download_checkpoint(url: str, path: "str | pathlib.Path", timeout: float = 60.0) -> None:
    """Download *url* and store the response body at *path*.

    The body goes to a sibling ``<name>.part`` file first. That file is moved
    onto *path* only after a successful response has been fully written.
    Callers skip the download when *path* already exists, so writing an HTTP
    error page or a truncated body straight to *path* would poison the cache
    for every later run.

    Args:
        url: Address of the file to fetch.
        path: Destination file (``str`` or ``pathlib.Path``).
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Other network failures (timeouts,
            connection errors).
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    target = pathlib.Path(path)
    partial = target.with_name(target.name + ".part")
    with open(partial, 'wb') as f:
        f.write(resp.content)
    # Same directory, hence same filesystem: the rename replaces the target in
    # one step and never exposes a half-written checkpoint.
    partial.replace(target)
def load_model(path: str,
               image_size: int = 64,
               z_dim: int = 512,
               n_channels: int = 1,
               n_features: int = 32,
               ngpu: int = 1,) -> torch.nn.Module:
    """Build a ``DCGAN3D_G`` generator and fill it with weights from *path*.

    The checkpoint is loaded with ``map_location`` set to the CPU device, so
    it can be restored on a machine without a GPU. The architecture arguments
    are passed positionally to ``DCGAN3D_G`` in the order shown and must match
    the ones the checkpoint was saved with.
    """
    generator = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
    cpu = torch.device('cpu')
    weights = torch.load(path, map_location=cpu)
    generator.load_state_dict(weights)
    return generator
def generate_image(netG: torch.nn.Module,
                   z_dim: int = 512,
                   latent_size: int = 3) -> npt.ArrayLike:
    """Sample one volume from *netG* and rescale it for display.

    Draws standard-normal noise of shape
    ``(1, z_dim, latent_size, latent_size, latent_size)`` and runs the
    generator on it without gradient tracking. The first channel of the first
    sample is then converted to NumPy and mapped through ``1 - (x + 1) / 2``.
    If the output lies in [-1, 1] (presumably a tanh head; confirm), the
    result lies in [0, 1] with the intensities inverted.
    """
    spatial_dims = (latent_size,) * 3
    noise = torch.randn(1, z_dim, *spatial_dims)
    with torch.no_grad():
        sample = netG(noise)
    volume = sample[0, 0].numpy()
    return 1 - (volume + 1) / 2
def create_uniform_mesh_marching_cubes(img: npt.ArrayLike):
    """Build the PyVista geometry shown in the 3D views for a volume.

    Args:
        img: 3D array of voxel values. Its shape is used as the point
            dimensions of a unit-spaced grid anchored at the origin.

    Returns:
        A tuple ``(slices, mesh, dist)``:
        - ``slices``: the three orthogonal slices through the grid;
        - ``mesh``: the marching-cubes contour surface;
        - ``dist``: the Euclidean distance of every contour point from the
          grid origin, used by main() to colour the mesh.
    """
    img = np.asarray(img)
    grid_kwargs = dict(spacing=(1, 1, 1), origin=(0, 0, 0))
    if hasattr(pv, "ImageData"):
        # PyVista >= 0.43 renamed UniformGrid to ImageData, which takes
        # ``dimensions=``. Newer releases no longer provide UniformGrid/dims.
        grid = pv.ImageData(dimensions=img.shape, **grid_kwargs)
    else:
        grid = pv.UniformGrid(dims=img.shape, **grid_kwargs)
    # VTK orders image points with x varying fastest (Fortran order). A C-order
    # flatten would reverse the axes of the 3D views relative to the 2D
    # ``img[...]`` cross-sections.
    values = img.flatten(order="F")
    grid.point_data['my_array'] = values
    slices = grid.slice_orthogonal()
    # NOTE(review): with a single isosurface VTK generates the value from
    # ``rng`` (likely its midpoint, 0.5 here). Newer PyVista versions may
    # reject an unsorted range such as [1, 0]. Confirm against the installed
    # version.
    mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
    dist = np.linalg.norm(mesh.points, axis=1)
    return slices, mesh, dist
def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
    """Plot three orthogonal grey-scale cross-sections of *img* side by side.

    Index ``midpoint`` is taken along axis 0, 1 and 2 in turn. The panels are
    titled "Front", "Right" and "Top", use a fixed [0, 1] colour range, and
    have their axes hidden.

    Returns:
        The matplotlib figure holding the three panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    sections = (img[midpoint], img[:, midpoint], img[..., midpoint])
    titles = ("Front", "Right", "Top")
    for panel, section, title in zip(axes, sections, titles):
        panel.imshow(section, cmap="gray", vmin=0, vmax=1)
        panel.set_title(title, fontsize=18)
        panel.set_axis_off()
    return fig
def _export_html(plotter) -> "str | None":
    """Export *plotter*'s scene as HTML and close the plotter.

    Returns the captured HTML string. Returns ``None`` when ``export_html``
    raised ``RuntimeError``; that error is printed to the server log.
    """
    writer = DummyWriteable()
    try:
        plotter.export_html(writer)
    except RuntimeError as e:
        print(e)
    finally:
        # Release the render window: new plotters are built on every rerun.
        plotter.close()
    return writer.html


def _show_html(html: "str | None", width: int, height: int) -> None:
    """Embed exported *html* in the page, or show a notice if export failed."""
    if html is None:
        st.warning("The interactive 3D view could not be rendered.")
    else:
        components.html(html, width=width, height=height)


def main():
    """Streamlit entry point.

    Steps:
    1. Download the checkpoint if it is not cached yet.
    2. Sample a volume at the latent size chosen with the slider.
    3. Render the volume as 2D cross-sections, interactive orthogonal slices
       and a pore-space mesh.
    """
    st.title("Generating Porous Media with GANs")
    st.markdown(
        """
### Author
_[Lukas Mosser](https://scholar.google.com/citations?user=y0R9snMAAAAJ&hl=en&oi=ao) (2022)_ - :bird:[porestar](https://twitter.com/porestar)
## Description
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
published in Physical Review E ([Mosser et. al 2017](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.96.043309))
The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
## Intent
I hope this encourages others to create interactive demos of their research for knowledge sharing and validation.
## The Demo
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
### Interactive Model Parameters
The GAN used here in this study is fully convolutional "_Look Ma' no MLP's_": Changing the spatial extent of the latent space vector _z_
allows one to generate larger synthetic images.
""",
        unsafe_allow_html=True,
    )

    view_width = 400
    view_height = 400

    # Fetch the pretrained generator once; later runs reuse the cached file.
    model_fname = "berea_generator_epoch_24.pth"
    checkpoint_url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/{0:}?raw=true".format(model_fname)
    checkpoint_path = DOWNLOADS_PATH / model_fname
    if not checkpoint_path.exists():
        download_checkpoint(checkpoint_url, checkpoint_path)
    netG = load_model(checkpoint_path)

    latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
    img = generate_image(netG, latent_size=latent_size)
    slices, mesh, dist = create_uniform_mesh_marching_cubes(img)

    pv.set_plot_theme("document")
    pl1 = pv.Plotter(shape=(1, 1), window_size=(view_width, view_height))
    _ = pl1.add_mesh(slices, cmap="gray")
    slices_html = _export_html(pl1)

    pl2 = pv.Plotter(shape=(1, 1), window_size=(view_width, view_height))
    _ = pl2.add_mesh(mesh, scalars=dist)
    mesh_html = _export_html(pl2)

    st.header("2D Cross-Section of Generated Volume")
    fig = create_matplotlib_figure(img, img.shape[0]//2)
    st.pyplot(fig=fig)
    # pyplot keeps every figure alive until it is closed. Without this, each
    # rerun (every slider move) would leak another figure.
    plt.close(fig)

    st.header("3D Intersections")
    _show_html(slices_html, view_width, view_height)
    st.markdown("_Click and drag to spin, right click to shift._")

    st.header("3D Pore Space Mesh")
    _show_html(mesh_html, view_width, view_height)
    st.markdown("_Click and drag to spin, right click to shift._")

    st.markdown("""
## Citation
If you use our code for your own research, we would be grateful if you cite our publication:
```
@article{pmgan2017,
title={Reconstruction of three-dimensional porous media using generative adversarial neural networks},
author={Mosser, Lukas and Dubrule, Olivier and Blunt, Martin J.},
journal={arXiv preprint arXiv:1704.03225},
year={2017}
}```
""")
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()