# PicoAudio2: models/autoencoder/autoencoder_base.py
# (Hugging Face upload by rookie9, commit f582ec6)
from abc import abstractmethod, ABC
from typing import Sequence
import torch
import torch.nn as nn
class AutoEncoderBase(ABC):
def __init__(
self, downsampling_ratio: int, sample_rate: int,
latent_shape: Sequence[int | None]
):
self.downsampling_ratio = downsampling_ratio
self.sample_rate = sample_rate
self.latent_token_rate = sample_rate // downsampling_ratio
self.latent_shape = latent_shape
self.time_dim = latent_shape.index(None) + 1 # the first dim is batch
@abstractmethod
def encode(
self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
...