from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn, Tensor

from .perceiver import Perceiver
from .t3_config import T3Config


@dataclass
class T3Cond:
    """
    Dataclass container for the non-text conditioning info consumed by T3CondEnc:
    speaker embedding, optional CLAP embedding, optional speech-prompt tokens /
    embeddings, and an emotion scalar.
    TODO: serialization methods aren't used, keeping them around for convenience.
    """

    speaker_emb: Tensor
    clap_emb: Optional[Tensor] = None
    cond_prompt_speech_tokens: Optional[Tensor] = None
    cond_prompt_speech_emb: Optional[Tensor] = None
    emotion_adv: Optional[Tensor] = 0.5

    def to(self, *, device=None, dtype=None):
        "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                # Only cast floating-point tensors to the requested dtype; integer
                # tensors (e.g. token ids) keep their dtype and only change device.
                is_fp = v.is_floating_point()
                setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
        return self

    def save(self, fpath):
        torch.save(self.__dict__, fpath)

    @staticmethod
    def load(fpath, map_location="cpu"):
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return T3Cond(**kwargs)
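
# Illustrative round trip (a sketch: the 256-dim speaker embedding is an assumption,
# the real size is T3Config.speaker_embed_size, and emotion_adv should already be a
# tensor by the time it reaches T3CondEnc.forward):
#
#   cond = T3Cond(
#       speaker_emb=torch.randn(1, 256),
#       emotion_adv=0.5 * torch.ones(1, 1, 1),
#   )
#   cond.save("t3_cond.pt")
#   cond = T3Cond.load("t3_cond.pt").to(device="cpu", dtype=torch.float32)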


class T3CondEnc(nn.Module):
    """
    Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
    """

    def __init__(self, hp: T3Config):
        super().__init__()
        self.hp = hp
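        # Speaker encoder: project the speaker embedding to the model width (n_channels).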
        if hp.encoder_type == "voice_encoder":
            self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
        else:
            raise NotImplementedError(str(hp.encoder_type))

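        # Emotion conditioning: a single scalar projected to the model width.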
        self.emotion_adv_fc = None
        if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

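        # Optional perceiver resampler applied to the speech-prompt embeddings.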
        self.perceiver = None
        if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()

    def forward(self, cond: T3Cond):
        # Prompt tokens and prompt embeddings must be provided together, or both omitted.
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "cond_prompt_speech_tokens and cond_prompt_speech_emb must be provided together"

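        # Speaker embedding becomes a single conditioning token of shape (B, 1, n_channels);
        # `empty` is a zero-length (B, 0, n_channels) placeholder for absent streams.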
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]
        empty = torch.zeros_like(cond_spkr[:, :0])

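        # CLAP conditioning is not implemented yet; it contributes a zero-length slot.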
        assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty

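        # Speech-prompt embeddings; run them through the perceiver resampler when enabled.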
        cond_prompt_speech_emb = cond.cond_prompt_speech_emb
        if cond_prompt_speech_emb is None:
            cond_prompt_speech_emb = empty
        elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

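        # Emotion conditioning: an emotion_adv value is required whenever hp.emotion_adv is set.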
        cond_emotion_adv = empty
        if self.hp.emotion_adv:
            assert cond.emotion_adv is not None
            cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))

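        # Concatenate all conditioning streams along the sequence dimension: (B, len_cond, n_channels).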
        cond_embeds = torch.cat((
            cond_spkr,
            cond_clap,
            cond_prompt_speech_emb,
            cond_emotion_adv,
        ), dim=1)
        return cond_embeds
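
# Illustrative use of the encoder (a sketch; assumes `hp` is a T3Config with
# encoder_type == "voice_encoder" and `cond` is a T3Cond on the module's device):
#
#   cond_enc = T3CondEnc(hp)
#   cond_embeds = cond_enc(cond)  # (B, len_cond, hp.n_channels)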