| import torch | |
| from typeguard import check_argument_types | |
| class ForwardAdaptor(torch.nn.Module): | |
| """Wrapped module to parallelize specified method | |
| torch.nn.DataParallel parallelizes only "forward()" | |
| and, maybe, the method having the other name can't be applied | |
| except for wrapping the module just like this class. | |
| Examples: | |
| >>> class A(torch.nn.Module): | |
| ... def foo(self, x): | |
| ... ... | |
| >>> model = A() | |
| >>> model = ForwardAdaptor(model, "foo") | |
| >>> model = torch.nn.DataParallel(model, device_ids=[0, 1]) | |
| >>> x = torch.randn(2, 10) | |
| >>> model(x) | |
| """ | |
| def __init__(self, module: torch.nn.Module, name: str): | |
| assert check_argument_types() | |
| super().__init__() | |
| self.module = module | |
| self.name = name | |
| if not hasattr(module, name): | |
| raise ValueError(f"{module} doesn't have {name}") | |
| def forward(self, *args, **kwargs): | |
| func = getattr(self.module, self.name) | |
| return func(*args, **kwargs) | |