Seokju Cho committed
Commit: 6b9382c · 1 Parent(s): e11cc45
improve speed
Browse files
- .gitattributes                               +2  -0
- app.py                                       +6  -2
- locotrack_pytorch/models/cmdtop.py           +4  -2
- locotrack_pytorch/models/locotrack_model.py  +46 -81
- locotrack_pytorch/models/utils.py            +1  -44
- requirements.txt                             +1  -1
- weights/locotrack_base.ckpt                  +3  -0
- weights/locotrack_small.ckpt                 +3  -0
.gitattributes CHANGED
@@ -42,3 +42,5 @@ examples/libby.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/motocross-jump.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/bmx-trees.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/parkour.mp4 filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_base.ckpt filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_small.ckpt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -19,6 +19,10 @@ PREVIEW_WIDTH = 768 # Width of the preview video
 VIDEO_INPUT_RESO = (256, 256) # Resolution of the input video
 POINT_SIZE = 4 # Size of the query point in the preview video
 FRAME_LIMIT = 300 # Limit the number of frames to process
+WEIGHTS_PATH = {
+    "small": "./weights/locotrack_small.ckpt",
+    "base": "./weights/locotrack_base.ckpt",
+}


 def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
@@ -120,7 +124,7 @@ def extract_feature(video_input, model_size="small"):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.bfloat16 if device == "cuda" else torch.float16

-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)

     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
@@ -223,7 +227,7 @@ def track(
     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)

-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)
     with torch.autocast(device_type=device, dtype=dtype):
         with torch.no_grad():
             output = model(video_input, query_points_tensor, feature_grids=video_feature)
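Note: the app now loads the checkpoints from the local LFS-tracked files in WEIGHTS_PATH instead of fetching them elsewhere, so a missing `git lfs pull` fails in a confusing way inside load_model. The sketch below is a hypothetical pre-flight check, not part of the repo; it only verifies that the files referenced by WEIGHTS_PATH were actually materialized by Git LFS (pointer stubs are only ~130 bytes) before calling load_model as in the diff above.

from pathlib import Path

import torch

WEIGHTS_PATH = {
    "small": "./weights/locotrack_small.ckpt",
    "base": "./weights/locotrack_base.ckpt",
}

def check_checkpoint(model_size: str) -> Path:
    """Hypothetical helper: fail early if the LFS weights were not pulled."""
    path = Path(WEIGHTS_PATH[model_size])
    if not path.exists():
        raise FileNotFoundError(f"{path} is missing; run `git lfs pull` first.")
    if path.stat().st_size < 1024:
        raise RuntimeError(f"{path} looks like an LFS pointer stub; run `git lfs pull`.")
    return path

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = check_checkpoint("small")
# model = load_model(ckpt, model_size="small").to(device)   # as in app.py above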
locotrack_pytorch/models/cmdtop.py CHANGED
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+
 from models import utils


@@ -29,8 +31,8 @@ class CMDTop(nn.Module):
        """
        x: (b, h, w, i, j)
        """
-        out1 = ...
-        out2 = ...
+        out1 = rearrange(x, 'b h w i j -> b (i j) h w')
+        out2 = rearrange(x, 'b h w i j -> b (h w) i j')

        for i in range(len(self.out_channels)):
            out1 = self.conv[i](out1)
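The einshape-to-einops switch here is a notation change, not a behavioral one. A small self-contained check (arbitrary toy shapes, not the model's real dimensions) shows that the two rearrange patterns used in CMDTop match the equivalent permute/reshape calls:

import torch
from einops import rearrange

b, h, w, i, j = 2, 4, 5, 3, 3
x = torch.randn(b, h, w, i, j)

# 'b h w i j -> b (i j) h w': fold the (i, j) correlation axes into channels.
out1 = rearrange(x, 'b h w i j -> b (i j) h w')
assert torch.equal(out1, x.permute(0, 3, 4, 1, 2).reshape(b, i * j, h, w))

# 'b h w i j -> b (h w) i j': fold the spatial grid into channels instead.
out2 = rearrange(x, 'b h w i j -> b (h w) i j')
assert torch.equal(out2, x.reshape(b, h * w, i, j))

print(out1.shape, out2.shape)  # torch.Size([2, 9, 4, 5]) torch.Size([2, 20, 3, 3])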
locotrack_pytorch/models/locotrack_model.py CHANGED
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 import numpy as np
+from einops import rearrange

 from models import nets, utils
 from models.cmdtop import CMDTop
@@ -57,15 +58,15 @@ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
     return torch.cat([x] + [four_feat], dim=-1)


-def get_relative_positions(seq_len, reverse=False):
-    x = torch.arange(seq_len)[None, :]
-    y = torch.arange(seq_len)[:, None]
+def get_relative_positions(seq_len, reverse=False, device='cuda'):
+    x = torch.arange(seq_len, device=device)[None, :]
+    y = torch.arange(seq_len, device=device)[:, None]
     return torch.tril(x - y) if not reverse else torch.triu(y - x)


-def get_alibi_slope(num_heads):
+def get_alibi_slope(num_heads, device='cuda'):
     x = (24) ** (1 / num_heads)
-    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], dtype=torch.float32).view(-1, 1, 1)
+    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(-1, 1, 1)


 class MultiHeadAttention(nn.Module):
@@ -92,31 +93,22 @@ class MultiHeadAttention(nn.Module):
         key_heads = self._linear_projection(key, self.key_size, self.key_proj)  # [T, H, K]
         value_heads = self._linear_projection(value, self.value_size, self.value_proj)  # [T, H, V]

-        ...
+        device = query.device
+        bias_forward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, device=device)
         bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
-        bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length, reverse=True)
+        bias_backward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, reverse=True, device=device)
         bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
-        attn_bias = torch.cat([bias_forward, bias_backward], dim=0)
+        attn_bias = torch.cat([bias_forward, bias_backward], dim=0)

-        ...
-        ...
-        ...
-        if mask is not None:
-            if mask.ndim != attn_logits.ndim:
-                raise ValueError(f"Mask dimensionality {mask.ndim} must match logits dimensionality {attn_logits.ndim}.")
-            attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30))
-
-        attn_weights = F.softmax(attn_logits, dim=-1)  # [H, T', T]
-
-        attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
-        attn = attn.reshape(batch_size, sequence_length, -1)  # [T', H*V]
+        attn = F.scaled_dot_product_attention(query_heads, key_heads, value_heads, attn_mask=attn_bias, scale=1 / np.sqrt(self.key_size))
+        attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)

         return self.final_proj(attn)  # [T', D']

     def _linear_projection(self, x, head_size, proj_layer):
         y = proj_layer(x)
-        ...
-        return y.reshape((...
+        batch_size, sequence_length, _ = x.shape
+        return y.reshape((batch_size, sequence_length, self.num_heads, head_size)).permute(0, 2, 1, 3)


 class Transformer(nn.Module):
@@ -495,25 +487,25 @@ class LocoTrack(nn.Module):
         ctx = torch.reshape(ctx, [-1, 3]).to(video.device)  # s*s 3

         position_support = position_in_grid[..., None, :] + ctx[None, None, ...]  # b n s*s 3
-        position_support = ...
+        position_support = rearrange(position_support, 'b n s c -> b (n s) c')
         interp_supp = utils.map_coordinates_3d(
             feature_grid[i], position_support
         )
-        interp_supp = ...
+        interp_supp = rearrange(interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)

         position_support_hires = position_in_grid_hires[..., None, :] + ctx[None, None, ...]
-        position_support_hires = ...
+        position_support_hires = rearrange(position_support_hires, 'b n s c -> b (n s) c')
         hires_interp_supp = utils.map_coordinates_3d(
             hires_feats[i], position_support_hires
         )
-        hires_interp_supp = ...
+        hires_interp_supp = rearrange(hires_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)

         position_support_highest = position_in_grid_highest[..., None, :] + ctx[None, None, ...]
-        position_support_highest = ...
+        position_support_highest = rearrange(position_support_highest, 'b n s c -> b (n s) c')
         highest_interp_supp = utils.map_coordinates_3d(
             highest_feats[i], position_support_highest
         )
-        highest_interp_supp = ...
+        highest_interp_supp = rearrange(highest_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)

         interp_features = interp_supp[..., support_size // 2, support_size // 2, :]
         hires_interp = hires_interp_supp[..., support_size // 2, support_size // 2, :]
@@ -559,7 +551,7 @@ class LocoTrack(nn.Module):
             video.shape[2:4], self.initial_resolution
         )

-        all_required_resolutions = [...
+        all_required_resolutions = []
         all_required_resolutions.extend(refinement_resolutions)

         feature_grid = []
@@ -715,30 +707,14 @@ class LocoTrack(nn.Module):
         )

         num_queries = query_features.lowres[0].shape[1]
-        if causal_context is None:
-            perm = torch.randperm(num_queries)
-        else:
-            perm = torch.arange(num_queries)
-
-        inv_perm = torch.zeros_like(perm)
-        inv_perm[perm] = torch.arange(num_queries)

         for ch in range(0, num_queries, query_chunk_size):
-            ...
-            ...
-            chunk_hires = query_features.hires[0][:, perm_chunk]
-
-            cc_chunk = []
-            if causal_context is not None:
-                for d in range(len(causal_context)):
-                    tmp_dict = {}
-                    for k, v in causal_context[d].items():
-                        tmp_dict[k] = v[:, perm_chunk]
-                    cc_chunk.append(tmp_dict)
+            chunk = query_features.lowres[0][:, ch:ch + query_chunk_size]
+            chunk_hires = query_features.hires[0][:, ch:ch + query_chunk_size]

             if query_points_in_video is not None:
                 infer_query_points = query_points_in_video[
-                    :, ...
+                    :, ch : ch + query_chunk_size
                 ]
                 num_frames = feature_grids.lowres[0].shape[1]
                 infer_query_points = utils.convert_grid_coordinates(
@@ -765,14 +741,14 @@ class LocoTrack(nn.Module):
            for i in range(num_iters):
                feature_level = -1
                queries = [
-                    query_features.hires[feature_level][:, ...
-                    query_features.lowres[feature_level][:, ...
-                    query_features.highest[feature_level][:, ...
+                    query_features.hires[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.lowres[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.highest[feature_level][:, ch:ch + query_chunk_size],
                ]
                supports = [
-                    query_features.hires_supp[feature_level][:, ...
-                    query_features.lowres_supp[feature_level][:, ...
-                    query_features.highest_supp[feature_level][:, ...
+                    query_features.hires_supp[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.lowres_supp[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.highest_supp[feature_level][:, ch:ch + query_chunk_size],
                ]
                for _ in range(self.pyramid_level):
                    queries.append(queries[-1])
@@ -790,7 +766,7 @@ class LocoTrack(nn.Module):
                        padding=0,
                    )
                )
-                ...
+
                refined = self.refine_pips(
                    queries,
                    supports,
@@ -803,7 +779,6 @@ class LocoTrack(nn.Module):
                    last_iter=mixer_feats,
                    mixer_iter=i,
                    resize_hw=feature_grids.resolutions[feature_level],
-                    causal_context=cc,
                    get_causal_context=get_causal_context,
                    cost_volume=cost_volume
                )
@@ -822,9 +797,9 @@ class LocoTrack(nn.Module):
         points = []
         expd = []
         for i, _ in enumerate(occ_iters):
-            occlusion.append(torch.cat(occ_iters[i], dim=1)...
-            points.append(torch.cat(pts_iters[i], dim=1)...
-            expd.append(torch.cat(expd_iters[i], dim=1)...
+            occlusion.append(torch.cat(occ_iters[i], dim=1))
+            points.append(torch.cat(pts_iters[i], dim=1))
+            expd.append(torch.cat(expd_iters[i], dim=1))

         out = dict(
             occlusion=occlusion,
@@ -874,11 +849,11 @@ class LocoTrack(nn.Module):
            coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
            neighborhood = utils.map_coordinates_2d(grid, coords2)

-            neighborhood = ...
+            neighborhood = rearrange(neighborhood, 'b n t (h w) c -> b n t h w c', h=support_size, w=support_size)
            patches_input = torch.einsum('bnthwc,bnijc->bnthwij', neighborhood, supp)
-            patches_input = ...
+            patches_input = rearrange(patches_input, 'b n t h w i j -> (b n t) h w i j')
            patches_emb = self.cmdtop[pyridx](patches_input)
-            patches = ...
+            patches = rearrange(patches_emb, '(b n t) c -> b n t c', b=neighborhood.shape[0], n=neighborhood.shape[1])

            corrs_pyr.append(patches)
        corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
@@ -913,14 +888,10 @@ class LocoTrack(nn.Module):
        mlp_input_list.append(rel_pos_emb_input)
        mlp_input = torch.cat(mlp_input_list, axis=-1)

-        x = ...
-
-        if causal_context is not None:
-            for k, v in causal_context.items():
-                causal_context[k] = utils.einshape('bn...->(bn)...', v)
+        x = rearrange(mlp_input, 'b n f c -> (b n) f c')
        res = self.torch_pips_mixer(x)

-        res = ...
+        res = rearrange(res, '(b n) f c -> b n f c', b=mlp_input.shape[0])

        pos_update = utils.convert_grid_coordinates(
            res[..., :2],
@@ -983,20 +954,18 @@ class LocoTrack(nn.Module):
        shape = cost_volume.shape
        batch_size, num_points = cost_volume.shape[1:3]

-        interp_cost = ...
+        interp_cost = rearrange(cost_volume, 't b n h w -> (t b n) () h w')
        interp_cost = F.interpolate(interp_cost, cost_volume_hires.shape[3:], mode='bilinear', align_corners=False)
-
-        interp_cost = utils.einshape('(tbn)1hw->tbnhw', interp_cost, b=batch_size, n=num_points)
+        interp_cost = rearrange(interp_cost, '(t b n) () h w -> t b n h w', b=batch_size, n=num_points)
        cost_volume_stack = torch.stack(
            [
-                # jax.image.resize(cost_volume, cost_volume_hires.shape, method='bilinear'),
                interp_cost,
                cost_volume_hires,
            ], dim=-1
        )
-        pos = ...
+        pos = rearrange(cost_volume_stack, 't b n h w c -> (t b n) c h w')
        pos = self.cost_conv(pos)
-        pos = ...
+        pos = rearrange(pos, '(t b n) () h w -> b n t h w', b=batch_size, n=num_points)

        pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
        softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
@@ -1012,14 +981,10 @@ class LocoTrack(nn.Module):
            ], dim=-1
        )
        occlusion = self.occ_linear(occlusion)
-        expected_dist = ...
-            ...
-        )
-        occlusion = utils.einshape(
-            'tbn1->bnt', occlusion[..., 0:1]
-        )
+        expected_dist = rearrange(occlusion[..., 1:2], 't b n () -> b n t', t=shape[0])
+        occlusion = rearrange(occlusion[..., 0:1], 't b n () -> b n t', t=shape[0])

-        return points, occlusion, expected_dist, ...
+        return points, occlusion, expected_dist, rearrange(cost_volume, 't b n h w -> b n t h w')

    def construct_initial_causal_state(self, num_points, num_resolutions=1):
        """Construct initial causal state."""
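The main speedup in this file comes from replacing the hand-rolled softmax attention with F.scaled_dot_product_attention while keeping the bidirectional ALiBi bias: half the heads are masked to attend only to past frames, the other half only to future frames. Below is a minimal standalone sketch of that path, reusing the helper definitions from the diff; the device default is switched to 'cpu' here (the diff uses 'cuda') so the snippet runs anywhere, and the scale argument of scaled_dot_product_attention requires PyTorch >= 2.1.

import numpy as np
import torch
import torch.nn.functional as F

def get_relative_positions(seq_len, reverse=False, device='cpu'):
    x = torch.arange(seq_len, device=device)[None, :]
    y = torch.arange(seq_len, device=device)[:, None]
    return torch.tril(x - y) if not reverse else torch.triu(y - x)

def get_alibi_slope(num_heads, device='cpu'):
    x = (24) ** (1 / num_heads)
    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(-1, 1, 1)

B, H, T, K = 2, 4, 6, 8          # batch, heads, sequence length, head size (toy values)
q, k, v = (torch.randn(B, H, T, K) for _ in range(3))

# Forward-looking half of the heads: ALiBi slopes plus a causal mask on future tokens.
bias_forward = get_alibi_slope(H // 2) * get_relative_positions(T)
bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
# Backward-looking half: mirror image, masking past tokens instead.
bias_backward = get_alibi_slope(H // 2) * get_relative_positions(T, reverse=True)
bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
attn_bias = torch.cat([bias_forward, bias_backward], dim=0)   # (H, T, T), broadcast over batch

# Fused attention kernel; the additive attn_mask carries both the ALiBi bias and the masks.
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, scale=1 / np.sqrt(K))
print(attn.shape)  # torch.Size([2, 4, 6, 8])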
locotrack_pytorch/models/utils.py CHANGED
@@ -16,8 +16,6 @@
 """Pytorch model utilities."""
 import math
 from typing import Any, Sequence, Union
-from einshape.src import abstract_ops
-from einshape.src import backend
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -101,7 +99,7 @@ def map_coordinates_2d(

     n, p, t, s, xy = coordinates.shape
     y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
-    y = 2 * (y / h) - 1
+    y = 2 * (y / torch.tensor([h, w], device=feats.device)) - 1
     y = torch.flip(y, dims=(-1,)).float()

     out = F.grid_sample(
@@ -231,47 +229,6 @@ def convert_grid_coordinates(
     return position_in_grid


-class _JaxBackend(backend.Backend[torch.Tensor]):
-    """Einshape implementation for PyTorch."""
-
-    # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
-
-    def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
-        return x.reshape(op.shape)
-
-    def transpose(
-        self, x: torch.Tensor, op: abstract_ops.Transpose
-    ) -> torch.Tensor:
-        return x.permute(op.perm)
-
-    def broadcast(
-        self, x: torch.Tensor, op: abstract_ops.Broadcast
-    ) -> torch.Tensor:
-        shape = op.transform_shape(x.shape)
-        for axis_position in sorted(op.axis_sizes.keys()):
-            x = x.unsqueeze(axis_position)
-        return x.expand(shape)
-
-
-def einshape(
-    equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
-) -> torch.Tensor:
-    """Reshapes `value` according to the given Shape Equation.
-
-    Args:
-        equation: The Shape Equation specifying the index regrouping and reordering.
-        value: Input tensor, or tensor-like object.
-        **index_sizes: Sizes of indices, where they cannot be inferred from
-            `input_shape`.
-
-    Returns:
-        Tensor derived from `value` by reshaping as specified by `equation`.
-    """
-    if not isinstance(value, torch.Tensor):
-        value = torch.tensor(value)
-    return _JaxBackend().exec(equation, value, value.shape, **index_sizes)
-
-
 def generate_default_resolutions(full_size, train_size, num_levels=None):
     """Generate a list of logarithmically-spaced resolutions.

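The one-line change in map_coordinates_2d normalizes row and column coordinates by their own extents (h and w) instead of dividing both by h, which matters whenever the feature map is not square. The toy example below (not the repo's helper, just the same normalization convention) shows how (row, col) pixel coordinates are mapped into the [-1, 1] range and flipped to (x, y) order before F.grid_sample:

import torch
import torch.nn.functional as F

h, w = 4, 8                                             # non-square feature map
feats = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)

# Sample points given as (row, col) pairs.
coords = torch.tensor([[0.0, 0.0], [3.0, 7.0]])

# Normalize each axis by its own size, as in the fixed line, then flip the last
# dimension because grid_sample expects (x, y), i.e. width first.
grid = 2 * (coords / torch.tensor([h, w], dtype=torch.float32)) - 1
grid = torch.flip(grid, dims=(-1,)).reshape(1, 1, -1, 2)

out = F.grid_sample(feats, grid, mode='bilinear', align_corners=False)
print(out.reshape(-1))   # values sampled toward the top-left and lower-right of the map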
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-...
+einops==0.8.0
 gradio==4.40.0
 mediapy==1.2.2
 opencv-python==4.10.0.84
weights/locotrack_base.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5adbaeb610d1f06adfbc7c9076b66f727d674c0fd1d668890201cf3339736c
+size 46139570
weights/locotrack_small.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da023594e6d6c05ecad9644efc1467545481cfa899e20730bd9fdce778ffa5ac
+size 33001026