Spaces:
Build error
Build error
Jiatao Gu
commited on
Commit
·
368dc9b
1
Parent(s):
22790a0
fix some errors. update code
Browse files- .gitignore +3 -0
- app.py +31 -35
- gradio_queue.db +0 -0
.gitignore
CHANGED
|
@@ -23,3 +23,6 @@ scripts/research/
|
|
| 23 |
.ipynb_checkpoints/
|
| 24 |
_screenshots/
|
| 25 |
flagged
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
.ipynb_checkpoints/
|
| 24 |
_screenshots/
|
| 25 |
flagged
|
| 26 |
+
|
| 27 |
+
*.db
|
| 28 |
+
gradio_queue.db
|
app.py
CHANGED
|
@@ -20,13 +20,20 @@ from huggingface_hub import hf_hub_download
|
|
| 20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 21 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def set_random_seed(seed):
|
| 25 |
torch.manual_seed(seed)
|
| 26 |
np.random.seed(seed)
|
| 27 |
|
| 28 |
|
| 29 |
-
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=
|
| 30 |
gen = model.synthesis
|
| 31 |
range_u, range_v = gen.C.range_u, gen.C.range_v
|
| 32 |
if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
|
|
@@ -41,22 +48,10 @@ def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512
|
|
| 41 |
return cam
|
| 42 |
|
| 43 |
|
| 44 |
-
def check_name(model_name
|
| 45 |
"""Gets model by name."""
|
| 46 |
-
if model_name
|
| 47 |
-
network_pkl = hf_hub_download(
|
| 48 |
-
|
| 49 |
-
# TODO: checkpoint to be updated!
|
| 50 |
-
# elif model_name == 'FFHQ512v2':
|
| 51 |
-
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
| 52 |
-
# elif model_name == 'AFHQ512':
|
| 53 |
-
# network_pkl = "./pretrained/afhq_512.pkl"
|
| 54 |
-
# elif model_name == 'MetFaces512':
|
| 55 |
-
# network_pkl = "./pretrained/metfaces_512.pkl"
|
| 56 |
-
# elif model_name == 'CompCars256':
|
| 57 |
-
# network_pkl = "./pretrained/cars_256.pkl"
|
| 58 |
-
# elif model_name == 'FFHQ1024':
|
| 59 |
-
# network_pkl = "./pretrained/ffhq_1024.pkl"
|
| 60 |
else:
|
| 61 |
if os.path.isdir(model_name):
|
| 62 |
network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
|
|
@@ -85,7 +80,7 @@ def get_model(network_pkl, render_option=None):
|
|
| 85 |
return G2, res, imgs
|
| 86 |
|
| 87 |
|
| 88 |
-
global_states = list(get_model(check_name()))
|
| 89 |
wss = [None, None]
|
| 90 |
|
| 91 |
def proc_seed(history, seed):
|
|
@@ -98,7 +93,8 @@ def proc_seed(history, seed):
|
|
| 98 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
| 99 |
history = history or {}
|
| 100 |
seeds = []
|
| 101 |
-
|
|
|
|
| 102 |
if model_find != "":
|
| 103 |
model_name = model_find
|
| 104 |
|
|
@@ -124,7 +120,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 124 |
set_random_seed(seed)
|
| 125 |
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
|
| 126 |
ws = model.mapping(z=z, c=None, truncation_psi=trunc)
|
| 127 |
-
img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0), render_option=render_option)
|
| 128 |
ws = ws.detach().cpu().numpy()
|
| 129 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
| 130 |
|
|
@@ -178,26 +174,26 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 178 |
image = (image * 255).astype('uint8')
|
| 179 |
return image, history
|
| 180 |
|
| 181 |
-
model_name = gr.inputs.Dropdown(
|
| 182 |
-
model_find = gr.inputs.Textbox(label="
|
| 183 |
-
render_option = gr.inputs.Textbox(label="rendering options", default='steps:
|
| 184 |
-
trunc = gr.inputs.Slider(default=
|
| 185 |
-
seed1 = gr.inputs.Number(default=1, label="seed1")
|
| 186 |
-
seed2 = gr.inputs.Number(default=9, label="seed2")
|
| 187 |
-
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="
|
| 188 |
-
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="
|
| 189 |
-
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='
|
| 190 |
-
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
| 191 |
-
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
| 192 |
-
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
| 193 |
-
fov = gr.inputs.Slider(minimum=
|
| 194 |
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
|
| 195 |
|
| 196 |
gr.Interface(fn=f_synthesis,
|
| 197 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
| 198 |
-
title="
|
| 199 |
-
description="
|
| 200 |
outputs=["image", "state"],
|
| 201 |
layout='unaligned',
|
| 202 |
-
css=css, theme='dark-
|
| 203 |
live=True).launch(enable_queue=True)
|
|
|
|
| 20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 21 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
| 22 |
|
| 23 |
+
model_lists = {
|
| 24 |
+
'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
|
| 25 |
+
'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
|
| 26 |
+
'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
|
| 27 |
+
}
|
| 28 |
+
model_names = [name for name in model_lists]
|
| 29 |
+
|
| 30 |
|
| 31 |
def set_random_seed(seed):
|
| 32 |
torch.manual_seed(seed)
|
| 33 |
np.random.seed(seed)
|
| 34 |
|
| 35 |
|
| 36 |
+
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
|
| 37 |
gen = model.synthesis
|
| 38 |
range_u, range_v = gen.C.range_u, gen.C.range_v
|
| 39 |
if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
|
|
|
|
| 48 |
return cam
|
| 49 |
|
| 50 |
|
| 51 |
+
def check_name(model_name):
|
| 52 |
"""Gets model by name."""
|
| 53 |
+
if model_name in model_lists:
|
| 54 |
+
network_pkl = hf_hub_download(**model_lists[model_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
else:
|
| 56 |
if os.path.isdir(model_name):
|
| 57 |
network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
|
|
|
|
| 80 |
return G2, res, imgs
|
| 81 |
|
| 82 |
|
| 83 |
+
global_states = list(get_model(check_name(model_names[0])))
|
| 84 |
wss = [None, None]
|
| 85 |
|
| 86 |
def proc_seed(history, seed):
|
|
|
|
| 93 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
| 94 |
history = history or {}
|
| 95 |
seeds = []
|
| 96 |
+
trunc = trunc / 100
|
| 97 |
+
|
| 98 |
if model_find != "":
|
| 99 |
model_name = model_find
|
| 100 |
|
|
|
|
| 120 |
set_random_seed(seed)
|
| 121 |
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
|
| 122 |
ws = model.mapping(z=z, c=None, truncation_psi=trunc)
|
| 123 |
+
img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name), render_option=render_option)
|
| 124 |
ws = ws.detach().cpu().numpy()
|
| 125 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
| 126 |
|
|
|
|
| 174 |
image = (image * 255).astype('uint8')
|
| 175 |
return image, history
|
| 176 |
|
| 177 |
+
model_name = gr.inputs.Dropdown(model_names)
|
| 178 |
+
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
|
| 179 |
+
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
|
| 180 |
+
trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
|
| 181 |
+
seed1 = gr.inputs.Number(default=1, label="Random seed1")
|
| 182 |
+
seed2 = gr.inputs.Number(default=9, label="Random seed2")
|
| 183 |
+
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
|
| 184 |
+
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (apparence)")
|
| 185 |
+
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermedia output')
|
| 186 |
+
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
|
| 187 |
+
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
|
| 188 |
+
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
|
| 189 |
+
fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
|
| 190 |
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
|
| 191 |
|
| 192 |
gr.Interface(fn=f_synthesis,
|
| 193 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
| 194 |
+
title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
|
| 195 |
+
description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
|
| 196 |
outputs=["image", "state"],
|
| 197 |
layout='unaligned',
|
| 198 |
+
css=css, theme='dark-seafoam',
|
| 199 |
live=True).launch(enable_queue=True)
|
gradio_queue.db
CHANGED
|
Binary files a/gradio_queue.db and b/gradio_queue.db differ
|
|
|