|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Callable, Dict, List, Tuple |
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
@dataclass
class MMArg:
    """Pair of per-branch values for a multimodal (video/text) module.

    When passed as an argument to code that routes through `get_args` /
    `get_kwargs`, the `vid` field is delivered to the video branch and the
    `txt` field to the text branch; non-`MMArg` arguments are shared as-is.
    """

    # Value routed to the video branch.
    vid: Any

    # Value routed to the text branch.
    txt: Any
|
|
|
|
|
|
|
|
def get_args(key: str, args: List[Any]) -> List[Any]:
    """Resolve positional args for one branch.

    For each entry: if it is an `MMArg`, take its attribute named *key*
    (e.g. ``"vid"`` or ``"txt"``); otherwise pass the entry through
    unchanged.
    """
    resolved: List[Any] = []
    for item in args:
        if isinstance(item, MMArg):
            resolved.append(getattr(item, key))
        else:
            resolved.append(item)
    return resolved
|
|
|
|
|
|
|
|
def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve keyword args for one branch.

    Mirrors `get_args`: any `MMArg` value is replaced by its attribute
    named *key*; all other values are kept as-is. Keys are preserved.
    """
    resolved: Dict[str, Any] = {}
    for name, value in kwargs.items():
        if isinstance(value, MMArg):
            resolved[name] = getattr(value, key)
        else:
            resolved[name] = value
    return resolved
|
|
|
|
|
|
|
|
class MMModule(nn.Module):
    """Builds parallel video/text instances of a module from one factory.

    Constructor arguments wrapped in `MMArg` are split per branch (`.vid`
    to the video module, `.txt` to the text module); plain arguments are
    passed to both. With ``shared_weights=True`` a single underlying
    module serves both branches, which requires the two resolved argument
    sets to be identical.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args: Any,
        shared_weights: bool = False,
        **kwargs: Any,
    ):
        """
        Args:
            module: Factory (class or function) that builds an ``nn.Module``.
            *args: Positional arguments for the factory; `MMArg` entries
                are split per branch.
            shared_weights: If True, construct one module used by both the
                video and text branches.
            **kwargs: Keyword arguments for the factory, split like *args*.

        Raises:
            ValueError: If ``shared_weights`` is True but the resolved
                vid/txt argument sets differ.
        """
        super().__init__()
        self.shared_weights = shared_weights
        if self.shared_weights:
            # Explicit raises instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this consistency check.
            if get_args("vid", args) != get_args("txt", args):
                raise ValueError(
                    "shared_weights=True requires identical vid/txt positional arguments"
                )
            if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs):
                raise ValueError(
                    "shared_weights=True requires identical vid/txt keyword arguments"
                )
            # vid/txt argument sets are equal here, so either resolves the same.
            self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
        else:
            self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
            self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run each branch on its input, splitting `MMArg` extras per branch.

        Args:
            vid: Video-branch input tensor.
            txt: Text-branch input tensor.
            *args: Extra positional arguments forwarded to the submodules;
                `MMArg` entries are split per branch.
            **kwargs: Extra keyword arguments, split like *args*.

        Returns:
            Tuple of (video output, text output).
        """
        # With shared weights both branches route through the single module.
        vid_module = self.all if self.shared_weights else self.vid
        txt_module = self.all if self.shared_weights else self.txt
        vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
        txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt
|
|
|