KangLiao committed on
Commit
cc8a5f7
·
1 Parent(s): f18fdea
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. app.py +34 -11
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  import numpy as np
5
  import spaces # Import spaces for ZeroGPU compatibility
@@ -47,6 +48,26 @@ checkpoint_path = "checkpoints/Puffin-Base.pth"
47
  checkpoint = torch.load(checkpoint_path)
48
  info = model.load_state_dict(checkpoint, strict=False)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  @torch.inference_mode()
52
  @spaces.GPU(duration=120)
@@ -88,23 +109,23 @@ def camera_understanding(image_src, question, seed, progress=gr.Progress(track_t
88
  single_batch["latitude_field"] = cam[2:].unsqueeze(0)
89
 
90
  figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
91
- imgs = []
92
  for k, fig in figs.items():
93
- fig.canvas.draw()
94
- img = np.array(fig.canvas.renderer.buffer_rgba())
95
- imgs.append(img)
 
96
  plt.close(fig)
97
- merged_imgs = np.concatenate(imgs, axis=1)
98
 
99
- return text, merged_imgs
100
 
101
 
102
  @torch.inference_mode()
103
  @spaces.GPU(duration=120) # Specify a duration to avoid timeout
104
  def generate_image(prompt_scene,
105
  seed=42,
106
- roll=3,
107
- pitch=1.0,
108
  fov=1.0,
109
  progress=gr.Progress(track_tqdm=True)):
110
  # Clear CUDA cache and avoid tracking gradients
@@ -126,6 +147,7 @@ def generate_image(prompt_scene,
126
  cam_map = cam_map / (math.pi / 2)
127
 
128
  prompt = prompt_scene + " " + prompt_camera
 
129
 
130
  bsz = 4
131
  with torch.no_grad():
@@ -167,7 +189,7 @@ with gr.Blocks(css=css) as demo:
167
  roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
168
  pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
169
  fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
170
- seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
171
 
172
  generation_button = gr.Button("Generate Images")
173
 
@@ -192,7 +214,8 @@ with gr.Blocks(css=css) as demo:
192
  understanding_button = gr.Button("Chat")
193
  understanding_output = gr.Textbox(label="Response")
194
 
195
- camera_output = gr.Gallery(label="Camera Maps", columns=1, rows=1)
 
196
 
197
  with gr.Accordion("Advanced options", open=False):
198
  und_seed_input = gr.Number(label="Seed", precision=0, value=42)
@@ -215,7 +238,7 @@ with gr.Blocks(css=css) as demo:
215
  understanding_button.click(
216
  camera_understanding,
217
  inputs=[image_input, und_seed_input],
218
- outputs=[understanding_output, camera_output]
219
  )
220
 
221
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ import io
4
  from PIL import Image
5
  import numpy as np
6
  import spaces # Import spaces for ZeroGPU compatibility
 
48
  checkpoint = torch.load(checkpoint_path)
49
  info = model.load_state_dict(checkpoint, strict=False)
50
 
51
+ def fig_to_image(fig):
52
+ buf = io.BytesIO()
53
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
54
+ buf.seek(0)
55
+ img = Image.open(buf).convert('RGB')
56
+ buf.close()
57
+ return img
58
+
59
+ def extract_up_lat_figs(fig_dict):
60
+ fig_up, fig_lat = None, None
61
+ others = {}
62
+ for k, fig in fig_dict.items():
63
+ if ("up_field" in k) and (fig_up is None):
64
+ fig_up = fig
65
+ elif ("latitude_field" in k) and (fig_lat is None):
66
+ fig_lat = fig
67
+ else:
68
+ others[k] = fig
69
+ return fig_up, fig_lat, others
70
+
71
 
72
  @torch.inference_mode()
73
  @spaces.GPU(duration=120)
 
109
  single_batch["latitude_field"] = cam[2:].unsqueeze(0)
110
 
111
  figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
112
+ up_img = lat_img = None
113
  for k, fig in figs.items():
114
+ if "up_field" in k:
115
+ up_img = fig_to_image(fig)
116
+ elif "latitude_field" in k:
117
+ lat_img = fig_to_image(fig)
118
  plt.close(fig)
 
119
 
120
+ return text, up_img, lat_img
121
 
122
 
123
  @torch.inference_mode()
124
  @spaces.GPU(duration=120) # Specify a duration to avoid timeout
125
  def generate_image(prompt_scene,
126
  seed=42,
127
+ roll=0.1,
128
+ pitch=0.1,
129
  fov=1.0,
130
  progress=gr.Progress(track_tqdm=True)):
131
  # Clear CUDA cache and avoid tracking gradients
 
147
  cam_map = cam_map / (math.pi / 2)
148
 
149
  prompt = prompt_scene + " " + prompt_camera
150
+ print("prompt:", prompt)
151
 
152
  bsz = 4
153
  with torch.no_grad():
 
189
  roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
190
  pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
191
  fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
192
+ seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)
193
 
194
  generation_button = gr.Button("Generate Images")
195
 
 
214
  understanding_button = gr.Button("Chat")
215
  understanding_output = gr.Textbox(label="Response")
216
 
217
+ camera1 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
218
+ camera2 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
219
 
220
  with gr.Accordion("Advanced options", open=False):
221
  und_seed_input = gr.Number(label="Seed", precision=0, value=42)
 
238
  understanding_button.click(
239
  camera_understanding,
240
  inputs=[image_input, und_seed_input],
241
+ outputs=[understanding_output, camera1, camera2]
242
  )
243
 
244
  demo.launch(share=True)