| from abc import ABC, abstractmethod | |
| from typing import Generic, TypeVar | |
| from jaxtyping import Float | |
| from torch import Tensor, nn | |
| from src.dataset.types import BatchedViews | |
# Type of the configuration object a concrete backbone is constructed with.
T = TypeVar("T")


class Backbone(nn.Module, ABC, Generic[T]):
    """Abstract base class for backbones that map context views to features.

    The class is generic over ``T``, the type of its configuration object,
    which is stored unchanged on ``self.cfg``. Concrete subclasses must
    implement :meth:`forward` and :meth:`d_out`; because both are declared
    with ``@abstractmethod``, instantiating ``Backbone`` itself (or a subclass
    missing either method) raises ``TypeError``.
    """

    # Configuration passed to ``__init__``; its concrete type is ``T``.
    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialize the module and keep a reference to ``cfg``.

        Args:
            cfg: Backbone configuration object of type ``T``.
        """
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute per-view features for a batch of context views.

        Args:
            context: The batched input views to encode.

        Returns:
            A float tensor annotated with the dimensions
            ``(batch, view, d_out, height, width)``.
        """

    @abstractmethod
    def d_out(self) -> int:
        """Return the size of the ``d_out`` dimension of :meth:`forward`'s output.

        NOTE(review): presumably the number of output feature channels;
        confirm against concrete implementations.
        """