import math

import torch
from typing import Optional, Union, Tuple

# @torch.jit.script
def get_similarity(mk: torch.Tensor,
                   ms: Optional[torch.Tensor],
                   qk: torch.Tensor,
                   qe: Optional[torch.Tensor],
                   add_batch_dim: bool = False) -> torch.Tensor:
    # used for training/inference and memory reading/memory potentiation
    # mk: B x CK x [N]    - Memory keys
    # ms: B x  1 x [N]    - Memory shrinkage
    # qk: B x CK x [HW/P] - Query keys
    # qe: B x CK x [HW/P] - Query selection
    # Dimensions in [] are flattened
    if add_batch_dim:
        mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
        qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)

    CK = mk.shape[1]
    mk = mk.flatten(start_dim=2)
    ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
    qk = qk.flatten(start_dim=2)
    qe = qe.flatten(start_dim=2) if qe is not None else None

    if qe is not None:
        # See XMem's appendix for derivation
        mk = mk.transpose(1, 2)
        a_sq = (mk.pow(2) @ qe)
        two_ab = 2 * (mk @ (qk * qe))
        b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
        similarity = (-a_sq + two_ab - b_sq)
    else:
        # similar to STCN if we don't have the selection term
        a_sq = mk.pow(2).sum(1).unsqueeze(2)
        two_ab = 2 * (mk.transpose(1, 2) @ qk)
        similarity = (-a_sq + two_ab)

    if ms is not None:
        similarity = similarity * ms / math.sqrt(CK)  # B*N*HW
    else:
        similarity = similarity / math.sqrt(CK)  # B*N*HW

    return similarity
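

# A minimal sanity-check sketch (not part of the original module; shapes are
# assumptions chosen for illustration). It verifies that the expanded form used
# above, -a_sq + two_ab - b_sq, matches a direct evaluation of the
# selection-weighted negative squared distance
#   S[b, i, j] = -sum_c qe[b, c, j] * (mk[b, c, i] - qk[b, c, j]) ** 2
# after the 1/sqrt(CK) scaling (shrinkage disabled).
def _check_similarity_expansion():
    B, CK, N, P = 1, 4, 6, 5
    mk = torch.randn(B, CK, N)
    qk = torch.randn(B, CK, P)
    qe = torch.rand(B, CK, P)
    fast = get_similarity(mk, None, qk, qe)  # B x N x P
    direct = -((mk.unsqueeze(3) - qk.unsqueeze(2)).pow(2) * qe.unsqueeze(2)).sum(1)
    direct = direct / math.sqrt(CK)
    assert torch.allclose(fast, direct, atol=1e-4)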


def do_softmax(similarity: torch.Tensor,
               top_k: Optional[int] = None,
               inplace: bool = False,
               return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # normalize similarity with top-k softmax
    # similarity: B x N x [HW/P]
    # use inplace with care
    if top_k is not None:
        values, indices = torch.topk(similarity, k=top_k, dim=1)

        x_exp = values.exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)  # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum
        indices = None

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity
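

# Illustrative sketch (not part of the original module; shapes are assumptions):
# with top_k set, do_softmax keeps only the k largest memory responses per query
# position and renormalizes, so each query column of the affinity still sums to
# 1 while being sparse over the memory dimension.
def _check_topk_softmax():
    B, N, P, k = 2, 128, 32, 16
    sim = torch.randn(B, N, P)
    aff = do_softmax(sim, top_k=k)
    assert torch.allclose(aff.sum(dim=1), torch.ones(B, P), atol=1e-5)
    assert int((aff > 0).sum(dim=1).max()) <= k  # at most k nonzero entries per query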


def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
                 qe: torch.Tensor) -> torch.Tensor:
    # shorthand used in training with no top-k
    similarity = get_similarity(mk, ms, qk, qe)
    affinity = do_softmax(similarity)
    return affinity


def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
    # affinity: B x THW x HW, mv (memory value): B x CV x T x H x W
    B, CV, T, H, W = mv.shape

    mo = mv.view(B, CV, T * H * W)
    mem = torch.bmm(mo, affinity)  # weighted sum of memory values per query position
    mem = mem.view(B, CV, H, W)

    return mem
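

# End-to-end usage sketch (shapes and the top_k value are assumptions for
# illustration, not taken from the original repository): compute a query-to-
# memory affinity and read out memory values for a single query frame.
if __name__ == '__main__':
    B, CK, CV = 1, 64, 512
    T, H, W = 3, 30, 54                 # number of memory frames, feature map size
    mk = torch.randn(B, CK, T, H, W)    # memory keys
    ms = torch.rand(B, 1, T, H, W) + 1  # memory shrinkage (positive)
    mv = torch.randn(B, CV, T, H, W)    # memory values
    qk = torch.randn(B, CK, H, W)       # query keys
    qe = torch.rand(B, CK, H, W)        # query selection

    similarity = get_similarity(mk, ms, qk, qe)  # B x THW x HW
    affinity = do_softmax(similarity, top_k=30)  # B x THW x HW
    out = readout(affinity, mv)                  # B x CV x H x W
    print(out.shape)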