Spaces:
Running
Running
| import abc | |
| import logging | |
| import torch | |
| from feature_retrieval import FaissRetrievableFeatureIndex | |
# Module-level logger named after this module; the retrieval classes in this
# file emit their debug messages through it.
logger: logging.Logger = logging.getLogger(__name__)
class IRetrieval(abc.ABC):
    """Interface for feature retrieval.

    The interface has separate entry points for Whisper and HuBERT features.
    Both methods are abstract, so a subclass must implement both before it
    can be instantiated. (Previously the class inherited ``abc.ABC`` without
    declaring any abstract methods, so that guarantee did not hold.)
    """

    @abc.abstractmethod
    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for the Whisper feature tensor ``vec``.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for the HuBERT feature tensor ``vec``.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError
class DummyRetrieval(IRetrieval):
    """Pass-through retrieval that performs no lookup.

    Both methods log a debug message and return a fresh copy of ``vec``
    placed on the CPU. The result never shares storage with the input.
    """

    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Return a CPU copy of the Whisper feature tensor ``vec``."""
        logger.debug("start dummy retriv whisper")
        return self._cpu_copy(vec)

    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Return a CPU copy of the HuBERT feature tensor ``vec``."""
        logger.debug("start dummy retriv hubert")
        return self._cpu_copy(vec)

    @staticmethod
    def _cpu_copy(vec: torch.Tensor) -> torch.Tensor:
        """Return a new CPU tensor with the contents of ``vec``."""
        # ``copy=True`` guarantees a fresh tensor even when ``vec`` is already
        # on the CPU, matching the previous ``clone().to(cpu)``. For a tensor on
        # another device it copies straight to the host instead of first
        # allocating a clone on that device.
        return vec.to(torch.device("cpu"), copy=True)
class FaissIndexRetrieval(IRetrieval):
    """Retrieval backed by two ``FaissRetrievableFeatureIndex`` objects.

    Each method converts ``vec`` to a NumPy array and passes it to the
    matching index's ``retriv`` method. The returned array is then wrapped
    back into a CPU tensor.

    Args:
        hubert_index: index queried by ``retriv_hubert``.
        whisper_index: index queried by ``retriv_whisper``.
    """

    def __init__(self, hubert_index: FaissRetrievableFeatureIndex, whisper_index: FaissRetrievableFeatureIndex) -> None:
        self._hubert_index = hubert_index
        self._whisper_index = whisper_index

    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Query the Whisper index with ``vec`` and return the result as a tensor."""
        logger.debug("start retriv whisper")
        return self._retriv(self._whisper_index, vec)

    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Query the HuBERT index with ``vec`` and return the result as a tensor."""
        logger.debug("start retriv hubert")
        return self._retriv(self._hubert_index, vec)

    @staticmethod
    def _retriv(index: FaissRetrievableFeatureIndex, vec: torch.Tensor) -> torch.Tensor:
        """Run ``index.retriv`` on ``vec`` and wrap the NumPy result in a tensor."""
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a
        # non-CPU device, so detach and move to the host first. For a plain
        # CPU tensor neither call copies data, so that path is unchanged.
        np_vec = index.retriv(vec.detach().cpu().numpy())
        # ``torch.from_numpy`` shares memory with the array returned by the index.
        return torch.from_numpy(np_vec)