"""Utilities for manipulating variable-length (packed) token sequences.

Sequences of differing shapes are stored flattened along the first dim,
together with per-sample shape records, so they can be concatenated,
rearranged, windowed, and restored without padding.
"""

from itertools import chain
from typing import Callable, Dict, List, Tuple

import einops
import torch


def flatten(
    hid: List[torch.FloatTensor],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """
    Flatten a list of variable-shaped tensors (sharing the last, channel dim
    and the same number of dims) into a single packed tensor plus a
    per-sample shape record.
    """
    assert len(hid) > 0
    shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
) -> List[torch.Tensor]:
    """
    Inverse of flatten: split the packed tensor back into per-sample tensors
    with their original shapes.
    """
    hid_len = hid_shape.prod(-1)
    hid = hid.split(hid_len.tolist())
    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
    return hid
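

# Hypothetical usage sketch (not part of the original API): round-trip two
# variable-sized feature maps through flatten/unflatten. Shapes are made up.
def _example_flatten_roundtrip() -> None:
    a, b = torch.randn(2, 4, 8), torch.randn(3, 5, 8)
    hid, hid_shape = flatten([a, b])
    assert hid.shape == (2 * 4 + 3 * 5, 8)
    assert hid_shape.tolist() == [[2, 4], [3, 5]]
    out = unflatten(hid, hid_shape)
    assert torch.equal(out[0], a) and torch.equal(out[1], b)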


def concat(
    vid: torch.FloatTensor,
    txt: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> torch.FloatTensor:
    """
    Interleave per-sample video and text token sequences:
    [vid_0, txt_0, vid_1, txt_1, ...].
    """
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    return torch.cat(list(chain(*zip(vid, txt))))


def concat_idx(
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> Tuple[
    Callable,
    Callable,
]:
    """
    Precompute gather indices so that concat and its inverse each become a
    single index_select. Returns (concat_fn, unconcat_fn).
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
        lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
    )
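

# Hypothetical usage sketch (not part of the original API): the returned
# closures (named cat_fn / uncat_fn here for illustration) reproduce
# concat/unconcat as pure index_select gathers. Sizes are made up.
def _example_concat_idx() -> None:
    vid_len, txt_len = torch.tensor([3, 3]), torch.tensor([2, 1])
    cat_fn, uncat_fn = concat_idx(vid_len, txt_len)
    vid, txt = torch.randn(6, 4), torch.randn(3, 4)
    merged = cat_fn(vid, txt)  # layout: vid_0 (3), txt_0 (2), vid_1 (3), txt_1 (1)
    vid_out, txt_out = uncat_fn(merged)
    assert torch.equal(vid_out, vid) and torch.equal(txt_out, txt)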


def unconcat(
    all: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> Tuple[
    torch.FloatTensor,
    torch.FloatTensor,
]:
    """
    Inverse of concat: split the interleaved sequence back into the packed
    video and text streams.
    """
    interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
    all = all.split(interleave_len)
    vid = torch.cat(all[0::2])
    txt = torch.cat(all[1::2])
    return vid, txt
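

# Hypothetical round-trip sketch (not part of the original API): interleave
# two samples' video and text tokens, then recover the two streams. Token
# values are integers so the interleaved layout is easy to see.
def _example_concat_unconcat() -> None:
    vid = torch.arange(6).unsqueeze(-1)  # two clips of 3 tokens: [0 1 2] [3 4 5]
    txt = torch.arange(6, 9).unsqueeze(-1)  # prompts of 2 and 1 tokens: [6 7] [8]
    vid_len, txt_len = torch.tensor([3, 3]), torch.tensor([2, 1])
    merged = concat(vid, txt, vid_len, txt_len)
    assert merged.squeeze(-1).tolist() == [0, 1, 2, 6, 7, 3, 4, 5, 8]
    vid_out, txt_out = unconcat(merged, vid_len, txt_len)
    assert torch.equal(vid_out, vid) and torch.equal(txt_out, txt)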


def repeat_concat(
    vid: torch.FloatTensor,
    txt: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
    txt_repeat: List,
) -> torch.FloatTensor:
    """
    Interleave video and text tokens where each text sequence is shared by
    (repeated for) several video sequences. Requires
    len(vid_len) == sum(txt_repeat).
    """
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
    txt = list(chain(*txt))
    return torch.cat(list(chain(*zip(vid, txt))))
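

# Hypothetical usage sketch (not part of the original API): three video clips
# share a single two-token prompt, repeated once per clip. Values are made up.
def _example_repeat_concat() -> None:
    vid = torch.arange(9).unsqueeze(-1).float()  # [0 1 2] [3 4 5] [6 7 8]
    txt = torch.tensor([[9.0], [10.0]])  # one prompt of 2 tokens
    out = repeat_concat(vid, txt, torch.tensor([3, 3, 3]), torch.tensor([2]), [3])
    assert out.squeeze(-1).tolist() == [0, 1, 2, 9, 10, 3, 4, 5, 9, 10, 6, 7, 8, 9, 10]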


def repeat_concat_idx(
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
    txt_repeat: torch.LongTensor,
) -> Tuple[
    Callable,
    Callable,
]:
    """
    Precompute gather indices for repeat_concat and its inverse. The inverse
    also coalesces (averages) the repeated text tokens back into one copy.
    Returns (repeat_concat_fn, unconcat_coalesce_fn).
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
    src_idx = torch.argsort(tgt_idx)
    txt_idx_len = len(tgt_idx) - len(vid_idx)
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def unconcat_coalesce(all):
        """
        Un-concat vid & txt, and coalesce the repeated txt.
        e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
             txt [9 10]
             repeat_concat    ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
          1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
             split            ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
          2. reshape & mean for each sample to coalesce the repeated txt.
        """
        vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
        txt_out_coalesced = []
        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
            txt_out_coalesced.append(txt)
        return vid_out, torch.cat(txt_out_coalesced)

    return (
        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
        unconcat_coalesce,
    )
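

# Hypothetical round-trip sketch (not part of the original API), mirroring the
# example in the unconcat_coalesce docstring: 3 clips of 3 tokens share one
# 2-token prompt. The closure names to_all / from_all are illustrative.
def _example_repeat_concat_idx() -> None:
    vid_len, txt_len, txt_repeat = torch.tensor([3, 3, 3]), torch.tensor([2]), torch.tensor([3])
    to_all, from_all = repeat_concat_idx(vid_len, txt_len, txt_repeat)
    vid, txt = torch.randn(9, 4), torch.randn(2, 4)
    merged = to_all(vid, txt)  # (9 + 2 * 3, 4)
    vid_out, txt_out = from_all(merged)
    assert torch.equal(vid_out, vid)
    assert torch.allclose(txt_out, txt)  # mean of identical copies, up to rounding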


def rearrange(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: int,
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """
    Apply an einops rearrange pattern to each sample of a packed sequence,
    then re-flatten.
    """
    return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])
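

# Hypothetical usage sketch (not part of the original API): transpose the two
# leading axes of each sample while staying in the packed layout. Shapes and
# the pattern are made up.
def _example_rearrange() -> None:
    hid, hid_shape = flatten([torch.randn(2, 4, 8), torch.randn(3, 6, 8)])
    out, out_shape = rearrange(hid, hid_shape, "t s c -> s t c")
    assert out_shape.tolist() == [[4, 2], [6, 3]]
    assert out.shape == hid.shape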


def rearrange_idx(
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: int,
) -> Tuple[Callable, Callable, torch.LongTensor]:
    """
    Precompute gather indices that apply / undo a per-sample einops rearrange
    on a packed sequence. Returns (rearrange_fn, inverse_fn, tgt_shape).
    """
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
    )
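

# Hypothetical usage sketch (not part of the original API): the forward and
# inverse closures (fwd / inv here for illustration) are exact permutations,
# so they round-trip. Shapes and the pattern are made up.
def _example_rearrange_idx() -> None:
    _, hid_shape = flatten([torch.randn(2, 4, 8)])
    fwd, inv, tgt_shape = rearrange_idx(hid_shape, "t s c -> (s t) c")
    hid = torch.randn(8, 16)
    assert torch.equal(inv(fwd(hid)), hid)
    assert tgt_shape.tolist() == [[8]]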


def repeat(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: torch.LongTensor,
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """
    Apply an einops repeat pattern to each sample of a packed sequence, with
    per-sample repeat counts given as LongTensors, then re-flatten.
    """
    hid = unflatten(hid, hid_shape)
    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
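

# Hypothetical usage sketch (not part of the original API): repeat the first
# sample's tokens twice and the second sample's three times. Shapes, the
# pattern, and the per-sample counts are made up.
def _example_repeat() -> None:
    hid, hid_shape = flatten([torch.randn(2, 8), torch.randn(3, 8)])
    out, out_shape = repeat(hid, hid_shape, "l c -> (r l) c", r=torch.tensor([2, 3]))
    assert out_shape.tolist() == [[4], [9]]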


def pack(
    samples: List[torch.Tensor],
) -> Tuple[
    List[torch.Tensor],
    List[List[int]],
]:
    """
    Group same-shaped samples into stacked batches. Returns the batches and,
    per batch, the original sample indices (for unpack).
    """
    batches = {}
    indices = {}
    for i, sample in enumerate(samples):
        shape = sample.shape
        batches.setdefault(shape, []).append(sample)
        indices.setdefault(shape, []).append(i)

    batches = list(map(torch.stack, batches.values()))
    indices = list(indices.values())
    return batches, indices


def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    """
    Inverse of pack: scatter batched samples back to their original order.
    """
    samples = [None] * (max(chain(*indices)) + 1)
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples
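

# Hypothetical round-trip sketch (not part of the original API): samples 0 and
# 2 share a shape and get stacked together; unpack restores the input order.
def _example_pack_unpack() -> None:
    samples = [torch.randn(4, 8), torch.randn(2, 8), torch.randn(4, 8)]
    batches, indices = pack(samples)
    assert [b.shape for b in batches] == [(2, 4, 8), (1, 2, 8)]
    out = unpack(batches, indices)
    assert all(torch.equal(a, b) for a, b in zip(out, samples))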


def window(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """
    Split each sample of a packed sequence into windows using window_fn, then
    re-flatten. Also returns the number of windows per sample.
    """
    hid = unflatten(hid, hid_shape)
    hid = list(map(window_fn, hid))
    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
    hid, hid_shape = flatten(list(chain(*hid)))
    return hid, hid_shape, hid_windows
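

# Hypothetical usage sketch (not part of the original API): split one sample
# into two halves along its first (e.g. temporal) axis. The window_fn and
# shapes are made up; any callable returning a list of tensors works.
def _example_window() -> None:
    hid, hid_shape = flatten([torch.randn(4, 2, 8)])
    out, out_shape, out_windows = window(hid, hid_shape, lambda x: list(x.chunk(2, dim=0)))
    assert out_shape.tolist() == [[2, 2], [2, 2]]
    assert out_windows.tolist() == [2]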


def window_idx(
    hid_shape: torch.LongTensor,
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """
    Precompute gather indices that apply / undo a windowing of a packed
    sequence. Returns (window_fn, inverse_fn, tgt_shape, tgt_windows).
    """
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
        tgt_windows,
    )
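

# Hypothetical round-trip sketch (not part of the original API): windowing is
# a permutation of token rows, so the inverse closure restores the original
# packed order. The column-wise window_fn here is made up.
def _example_window_idx() -> None:
    _, hid_shape = flatten([torch.randn(4, 2, 8)])
    fwd, inv, _, _ = window_idx(hid_shape, lambda x: list(x.chunk(2, dim=1)))
    hid = torch.randn(8, 16)
    assert torch.equal(inv(fwd(hid)), hid)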