Spaces:
Running
Running
Mathis Petrovich
committed on
Commit
·
5e4fa5e
1
Parent(s):
bdb661d
device
Browse files
app.py
CHANGED
|
@@ -56,7 +56,7 @@ EXAMPLES = [
|
|
| 56 |
"A person is taking the stairs",
|
| 57 |
"Someone is doing jumping jacks",
|
| 58 |
"The person walked forward and is picking up his toolbox",
|
| 59 |
-
"The person angrily punching the air"
|
| 60 |
]
|
| 61 |
|
| 62 |
# Show closest text in the training
|
|
@@ -94,6 +94,7 @@ CSS = """
|
|
| 94 |
|
| 95 |
DEFAULT_TEXT = "A person is "
|
| 96 |
|
|
|
|
| 97 |
def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
| 98 |
# Don't show the mirrored version of HumanMl3D
|
| 99 |
if "M" in keyid:
|
|
@@ -128,13 +129,15 @@ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
|
| 128 |
"text": text,
|
| 129 |
"keyid": keyid,
|
| 130 |
"babel_id": babel_id,
|
| 131 |
-
"path": path
|
| 132 |
}
|
| 133 |
|
| 134 |
return data
|
| 135 |
|
| 136 |
|
| 137 |
-
def retrieve(
|
|
|
|
|
|
|
| 138 |
unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
|
| 139 |
keyids = np.concatenate([all_keyids[s] for s in splits])
|
| 140 |
|
|
@@ -169,7 +172,7 @@ def get_video_html(data, video_id, width=700, height=700):
|
|
| 169 |
path = data["path"]
|
| 170 |
|
| 171 |
trim = f"#t={start},{end}"
|
| 172 |
-
title = f
|
| 173 |
|
| 174 |
Corresponding text: {text}
|
| 175 |
|
|
@@ -177,18 +180,18 @@ HumanML3D keyid: {keyid}
|
|
| 177 |
|
| 178 |
BABEL keyid: {babel_id}
|
| 179 |
|
| 180 |
-
AMASS path: {path}
|
| 181 |
|
| 182 |
# class="wrap default svelte-gjihhp hide"
|
| 183 |
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
| 184 |
# width="{width}" height="{height}"
|
| 185 |
-
video_html = f
|
| 186 |
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
| 187 |
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
|
| 188 |
<source src="{url}{trim}" type="video/mp4">
|
| 189 |
Your browser does not support the video tag.
|
| 190 |
</video>
|
| 191 |
-
|
| 192 |
return video_html
|
| 193 |
|
| 194 |
|
|
@@ -208,16 +211,18 @@ def retrieve_component(retrieve_function, text, splits_choice, nvids, n_componen
|
|
| 208 |
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
| 209 |
# get n_component exactly if asked less
|
| 210 |
# pad with dummy blocks
|
| 211 |
-
htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
|
| 212 |
return htmls
|
| 213 |
|
| 214 |
|
| 215 |
if not os.path.exists("data"):
|
| 216 |
-
gdown.download_folder(
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
-
device = torch.device(
|
| 221 |
|
| 222 |
# LOADING
|
| 223 |
model = load_model(device)
|
|
@@ -229,7 +234,9 @@ h3d_index = load_json("amass-annotations/humanml3d.json")
|
|
| 229 |
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
| 230 |
|
| 231 |
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
|
| 232 |
-
retrieve_function = partial(
|
|
|
|
|
|
|
| 233 |
|
| 234 |
# DEMO
|
| 235 |
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
|
@@ -242,33 +249,48 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
|
|
| 242 |
with gr.Row():
|
| 243 |
with gr.Column(scale=3):
|
| 244 |
with gr.Column(scale=2):
|
| 245 |
-
text = gr.Textbox(
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
with gr.Column(scale=1):
|
| 248 |
-
btn = gr.Button("Retrieve", variant=
|
| 249 |
-
clear = gr.Button("Clear", variant=
|
| 250 |
|
| 251 |
with gr.Row():
|
| 252 |
with gr.Column(scale=1):
|
| 253 |
-
splits_choice = gr.Radio(
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
with gr.Column(scale=1):
|
| 258 |
# nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
|
| 259 |
-
nvideo_slider = gr.Radio(
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
with gr.Column(scale=2):
|
|
|
|
| 264 |
def retrieve_example(text, splits_choice, nvideo_slider):
|
| 265 |
return retrieve_and_show(text, splits_choice, nvideo_slider)
|
| 266 |
|
| 267 |
-
examples = gr.Examples(
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
i = -1
|
| 274 |
# should indent
|
|
@@ -294,16 +316,28 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
|
|
| 294 |
show_progress=False,
|
| 295 |
postprocess=False,
|
| 296 |
queue=False,
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
text.submit(
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
def clear_videos():
|
| 309 |
return [None for x in range(24)] + [DEFAULT_TEXT]
|
|
|
|
| 56 |
"A person is taking the stairs",
|
| 57 |
"Someone is doing jumping jacks",
|
| 58 |
"The person walked forward and is picking up his toolbox",
|
| 59 |
+
"The person angrily punching the air",
|
| 60 |
]
|
| 61 |
|
| 62 |
# Show closest text in the training
|
|
|
|
| 94 |
|
| 95 |
DEFAULT_TEXT = "A person is "
|
| 96 |
|
| 97 |
+
|
| 98 |
def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
| 99 |
# Don't show the mirrored version of HumanMl3D
|
| 100 |
if "M" in keyid:
|
|
|
|
| 129 |
"text": text,
|
| 130 |
"keyid": keyid,
|
| 131 |
"babel_id": babel_id,
|
| 132 |
+
"path": path,
|
| 133 |
}
|
| 134 |
|
| 135 |
return data
|
| 136 |
|
| 137 |
|
| 138 |
+
def retrieve(
|
| 139 |
+
model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8
|
| 140 |
+
):
|
| 141 |
unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
|
| 142 |
keyids = np.concatenate([all_keyids[s] for s in splits])
|
| 143 |
|
|
|
|
| 172 |
path = data["path"]
|
| 173 |
|
| 174 |
trim = f"#t={start},{end}"
|
| 175 |
+
title = f"""Score = {score}
|
| 176 |
|
| 177 |
Corresponding text: {text}
|
| 178 |
|
|
|
|
| 180 |
|
| 181 |
BABEL keyid: {babel_id}
|
| 182 |
|
| 183 |
+
AMASS path: {path}"""
|
| 184 |
|
| 185 |
# class="wrap default svelte-gjihhp hide"
|
| 186 |
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
| 187 |
# width="{width}" height="{height}"
|
| 188 |
+
video_html = f"""
|
| 189 |
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
| 190 |
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
|
| 191 |
<source src="{url}{trim}" type="video/mp4">
|
| 192 |
Your browser does not support the video tag.
|
| 193 |
</video>
|
| 194 |
+
"""
|
| 195 |
return video_html
|
| 196 |
|
| 197 |
|
|
|
|
| 211 |
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
| 212 |
# get n_component exactly if asked less
|
| 213 |
# pad with dummy blocks
|
| 214 |
+
htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
|
| 215 |
return htmls
|
| 216 |
|
| 217 |
|
| 218 |
if not os.path.exists("data"):
|
| 219 |
+
gdown.download_folder(
|
| 220 |
+
"https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
|
| 221 |
+
use_cookies=False,
|
| 222 |
+
)
|
| 223 |
|
| 224 |
|
| 225 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 226 |
|
| 227 |
# LOADING
|
| 228 |
model = load_model(device)
|
|
|
|
| 234 |
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
| 235 |
|
| 236 |
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
|
| 237 |
+
retrieve_function = partial(
|
| 238 |
+
retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids
|
| 239 |
+
)
|
| 240 |
|
| 241 |
# DEMO
|
| 242 |
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
|
|
|
| 249 |
with gr.Row():
|
| 250 |
with gr.Column(scale=3):
|
| 251 |
with gr.Column(scale=2):
|
| 252 |
+
text = gr.Textbox(
|
| 253 |
+
placeholder="Type the motion you want to search with a sentence",
|
| 254 |
+
show_label=True,
|
| 255 |
+
label="Text prompt",
|
| 256 |
+
value=DEFAULT_TEXT,
|
| 257 |
+
)
|
| 258 |
with gr.Column(scale=1):
|
| 259 |
+
btn = gr.Button("Retrieve", variant="primary")
|
| 260 |
+
clear = gr.Button("Clear", variant="secondary")
|
| 261 |
|
| 262 |
with gr.Row():
|
| 263 |
with gr.Column(scale=1):
|
| 264 |
+
splits_choice = gr.Radio(
|
| 265 |
+
["All motions", "Unseen motions"],
|
| 266 |
+
label="Gallery of motion",
|
| 267 |
+
value="All motions",
|
| 268 |
+
info="The motion gallery is coming from HumanML3D",
|
| 269 |
+
)
|
| 270 |
|
| 271 |
with gr.Column(scale=1):
|
| 272 |
# nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
|
| 273 |
+
nvideo_slider = gr.Radio(
|
| 274 |
+
[4, 8, 12, 16, 24],
|
| 275 |
+
label="Videos",
|
| 276 |
+
value=8,
|
| 277 |
+
info="Number of videos to display",
|
| 278 |
+
)
|
| 279 |
|
| 280 |
with gr.Column(scale=2):
|
| 281 |
+
|
| 282 |
def retrieve_example(text, splits_choice, nvideo_slider):
|
| 283 |
return retrieve_and_show(text, splits_choice, nvideo_slider)
|
| 284 |
|
| 285 |
+
examples = gr.Examples(
|
| 286 |
+
examples=[[x, None, None] for x in EXAMPLES],
|
| 287 |
+
inputs=[text, splits_choice, nvideo_slider],
|
| 288 |
+
examples_per_page=20,
|
| 289 |
+
run_on_click=False,
|
| 290 |
+
cache_examples=False,
|
| 291 |
+
fn=retrieve_example,
|
| 292 |
+
outputs=[],
|
| 293 |
+
)
|
| 294 |
|
| 295 |
i = -1
|
| 296 |
# should indent
|
|
|
|
| 316 |
show_progress=False,
|
| 317 |
postprocess=False,
|
| 318 |
queue=False,
|
| 319 |
+
).then(fn=retrieve_example, inputs=examples.inputs, outputs=videos)
|
| 320 |
+
|
| 321 |
+
btn.click(
|
| 322 |
+
fn=retrieve_and_show,
|
| 323 |
+
inputs=[text, splits_choice, nvideo_slider],
|
| 324 |
+
outputs=videos,
|
| 325 |
+
)
|
| 326 |
+
text.submit(
|
| 327 |
+
fn=retrieve_and_show,
|
| 328 |
+
inputs=[text, splits_choice, nvideo_slider],
|
| 329 |
+
outputs=videos,
|
| 330 |
+
)
|
| 331 |
+
splits_choice.change(
|
| 332 |
+
fn=retrieve_and_show,
|
| 333 |
+
inputs=[text, splits_choice, nvideo_slider],
|
| 334 |
+
outputs=videos,
|
| 335 |
+
)
|
| 336 |
+
nvideo_slider.change(
|
| 337 |
+
fn=retrieve_and_show,
|
| 338 |
+
inputs=[text, splits_choice, nvideo_slider],
|
| 339 |
+
outputs=videos,
|
| 340 |
+
)
|
| 341 |
|
| 342 |
def clear_videos():
|
| 343 |
return [None for x in range(24)] + [DEFAULT_TEXT]
|
load.py
CHANGED
|
@@ -20,10 +20,7 @@ def load_keyids(split):
|
|
| 20 |
|
| 21 |
|
| 22 |
def load_keyids_splits(splits):
|
| 23 |
-
return {
|
| 24 |
-
split: load_keyids(split)
|
| 25 |
-
for split in splits
|
| 26 |
-
}
|
| 27 |
|
| 28 |
|
| 29 |
def load_unit_motion_embs(split, device):
|
|
@@ -33,16 +30,17 @@ def load_unit_motion_embs(split, device):
|
|
| 33 |
|
| 34 |
|
| 35 |
def load_unit_motion_embs_splits(splits, device):
|
| 36 |
-
return {
|
| 37 |
-
split: load_unit_motion_embs(split, device)
|
| 38 |
-
for split in splits
|
| 39 |
-
}
|
| 40 |
|
| 41 |
|
| 42 |
def load_model(device):
|
| 43 |
text_params = {
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
}
|
| 47 |
"unit_motion_embs"
|
| 48 |
model = TMR_textencoder(**text_params)
|
|
@@ -50,4 +48,4 @@ def load_model(device):
|
|
| 50 |
# load values for the transformer only
|
| 51 |
model.load_state_dict(state_dict, strict=False)
|
| 52 |
model = model.eval()
|
| 53 |
-
return model
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def load_keyids_splits(splits):
|
| 23 |
+
return {split: load_keyids(split) for split in splits}
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def load_unit_motion_embs(split, device):
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def load_unit_motion_embs_splits(splits, device):
|
| 33 |
+
return {split: load_unit_motion_embs(split, device) for split in splits}
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def load_model(device):
|
| 37 |
text_params = {
|
| 38 |
+
"latent_dim": 256,
|
| 39 |
+
"ff_size": 1024,
|
| 40 |
+
"num_layers": 6,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"activation": "gelu",
|
| 43 |
+
"modelpath": "distilbert-base-uncased",
|
| 44 |
}
|
| 45 |
"unit_motion_embs"
|
| 46 |
model = TMR_textencoder(**text_params)
|
|
|
|
| 48 |
# load values for the transformer only
|
| 49 |
model.load_state_dict(state_dict, strict=False)
|
| 50 |
model = model.eval()
|
| 51 |
+
return model.to(device)
|