import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

from omni_speech.constants import IGNORE_INDEX

torch.autograd.set_detect_anomaly(True)

try:
    import sys
    sys.path.append('/mnt/lzy/LLaMA-Omni/CosyVoice/')
    from cosyvoice.cli.cosyvoice import CosyVoice
except Exception:
    print('CosyVoice not found')

# Optional temporal-convolution front-end: kernel size read from the environment,
# disabled (-1) when the variable is unset.
if 'SPEECH_GEN_CONV_KERNEL' in os.environ:
    SPEECH_GEN_CONV_KERNEL = int(os.environ['SPEECH_GEN_CONV_KERNEL'])
    print(f'Using SPEECH_GEN_CONV_KERNEL={SPEECH_GEN_CONV_KERNEL}')
else:
    SPEECH_GEN_CONV_KERNEL = -1

if 'DISTILL_EMBEDDING' in os.environ:
    DISTILL_EMBEDDING = True
    print('DISTILL_EMBEDDING is set.')
else:
    DISTILL_EMBEDDING = False
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask
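# Illustrative example (not part of the original module): True marks padded positions.
#   lengths_to_padding_mask(torch.tensor([2, 3]))
#   -> tensor([[False, False,  True],
#              [False, False, False]])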
def _uniform_assignment(src_lens, tgt_lens):
    tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device)
    ratio = tgt_lens / src_lens
    index_t = (tgt_indices / ratio.view(-1, 1)).long()
    return index_t
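# Illustrative example (not part of the original module): each target position is
# mapped back to a source index, so every source frame is repeated roughly
# tgt_len / src_len times.
#   _uniform_assignment(torch.tensor([2]), torch.tensor([4]))
#   -> tensor([[0, 0, 1, 1]])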
class SpeechGeneratorCTC(nn.Module):
    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
    def upsample(self, reps, tgt_units=None):
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(
            padding_mask, 0
        )
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), reps.size(-1)
            ),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device)
        return copied_reps, ~padding_mask, position_ids
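    # Shape note (added for clarity): upsample() returns
    #   copied_reps    [B, max(up_lens), hidden]  -- each source frame repeated ~upsample_factor times,
    #   attention_mask [B, max(up_lens)]          -- True at valid (non-padded) positions,
    #   position_ids   [B, max(up_lens)],
    # where up_lens = src_lens * upsample_factor, lower-bounded per sample by the target
    # unit lengths during training.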
    def forward(self, tgt_reps, labels, tgt_units):
        tgt_label_reps = []
        for tgt_rep, label in zip(tgt_reps, labels):
            tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
        ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size
        )
        ctc_loss /= ctc_tgt_lens.sum().item()
        return ctc_loss
    def predict(self, tgt_reps):
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
        return ctc_pred
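    # predict() returns per-frame argmax IDs with the blank class (= unit_vocab_size)
    # at padded positions; a caller would typically apply standard CTC collapsing,
    # e.g. (sketch, not part of the original module):
    #   pred = torch.unique_consecutive(ctc_pred[0])      # merge repeated frames
    #   units = pred[pred != generator.unit_vocab_size]   # drop blanks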
class SpeechGeneratorCTCQwen(nn.Module):
    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
        if SPEECH_GEN_CONV_KERNEL > 0:
            self.temporal_conv = nn.Conv1d(n_dims, n_dims, SPEECH_GEN_CONV_KERNEL, padding=0)
            self.learnable_pad_left = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims))
            self.learnable_pad_right = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims))
        # self.conv_layer_id = n_layers // 2  # Insert temporal conv layer in the middle of the decoder layers
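    # Note on the conv front-end (added for clarity): forward() pads each sequence with
    # SPEECH_GEN_CONV_KERNEL // 2 learnable vectors on both sides before the unpadded
    # Conv1d, so the output length is L + 2*(K//2) - K + 1, i.e. the sequence length is
    # preserved exactly when the kernel size K is odd.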
    def upsample(self, reps, tgt_units=None):
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(
            padding_mask, 0
        )
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), reps.size(-1)
            ),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device)
        return copied_reps, ~padding_mask, position_ids
    def forward(self, tgt_reps, labels, tgt_units):
        tgt_label_reps = []
        for tgt_rep, label in zip(tgt_reps, labels):
            if SPEECH_GEN_CONV_KERNEL > 0:
                now_rep = tgt_rep[label != IGNORE_INDEX]
                now_rep = torch.cat([self.learnable_pad_left, now_rep, self.learnable_pad_right], dim=0)
                now_rep = self.input_proj(now_rep)[None]
                now_rep = self.temporal_conv(now_rep.transpose(1, 2)).transpose(1, 2)[0]
                tgt_label_reps.append(now_rep)
            else:
                tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        if SPEECH_GEN_CONV_KERNEL < 0:
            hidden_states = self.input_proj(hidden_states)
        for layer_id, layer in enumerate(self.layers):
            # if SPEECH_GEN_CONV_KERNEL:
            #     if layer_id == self.conv_layer_id:
            #         hidden_states = self.temporal_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
        ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size
        )
        ctc_loss /= ctc_tgt_lens.sum().item()
        return ctc_loss
    def predict(self, tgt_reps):
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
        return ctc_pred
class SpeechGeneratorCEQwen(nn.Module):
    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = 1
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
    def upsample(self, reps, tgt_units=None):
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(
            padding_mask, 0
        )
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), reps.size(-1)
            ),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device)
        return copied_reps, ~padding_mask, position_ids
    def forward(self, tgt_reps, labels, tgt_units):
        tgt_label_reps = []
        for tgt_rep, label in zip(tgt_reps, labels):
            # if SPEECH_GEN_CONV_KERNEL > 0:
            #     now_rep = tgt_rep[label != IGNORE_INDEX]
            #     now_rep = torch.cat([self.learnable_pad_left, now_rep, self.learnable_pad_right], dim=0)
            #     now_rep = self.input_proj(now_rep)[None]
            #     now_rep = self.temporal_conv(now_rep.transpose(1, 2)).transpose(1, 2)[0]
            #     tgt_label_reps.append(now_rep)
            # else:
            tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        # if SPEECH_GEN_CONV_KERNEL < 0:
        hidden_states = self.input_proj(hidden_states)
        for layer_id, layer in enumerate(self.layers):
            # if SPEECH_GEN_CONV_KERNEL:
            #     if layer_id == self.conv_layer_id:
            #         hidden_states = self.temporal_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_states.size(-1))
        logits = self.output_proj(shift_hidden_states)
        shift_labels = tgt_units[..., 1:].contiguous().reshape(-1)
        assert shift_labels.size(0) == shift_hidden_states.size(0)
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.float()
        loss = loss_fct(logits, shift_labels)
        # loss = (loss / 1.0).sum().item()
        # loss = loss.sum().item()
        return loss
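    # Training objective note (added for clarity): this is a next-token cross-entropy
    # over speech units -- the hidden state at position t predicts tgt_units[t + 1].
    # Padded label positions are assumed to carry IGNORE_INDEX, which is only skipped
    # by nn.CrossEntropyLoss if IGNORE_INDEX matches its default ignore_index of -100.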
    # def predict(self, tgt_reps):
    #     hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
    #     hidden_states = self.input_proj(hidden_states)
    #     for layer in self.layers:
    #         layer_outputs = layer(
    #             hidden_states,
    #             attention_mask=attention_mask,
    #             position_ids=position_ids,
    #         )
    #         hidden_states = layer_outputs[0]
    #     ctc_logits = self.output_proj(hidden_states)
    #     ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
    #     ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
    #     return ctc_pred
# Earlier version of SpeechGeneratorCosyVoice, kept for reference:
# class SpeechGeneratorCosyVoice(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.input_proj = nn.Sequential(
#             nn.Linear(config.hidden_size, 1024),
#             nn.GELU(),
#             nn.Linear(1024, 512)
#         )
#         self.cosyvoice1 = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False)
#         self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
#         self.llm = self.cosyvoice1.model.llm
#         if DISTILL_EMBEDDING:
#             self.criterion = nn.CosineEmbeddingLoss()
#     def forward(self, tgt_reps, labels, answer):
#         tgt_label_reps = []
#         batch_speech_tokens = []
#         embeddings = []
#         target_embeddings = []
#         if DISTILL_EMBEDDING:
#             for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
#                 # map all label ids in [151644, 151645, 198] to IGNORE_INDEX
#                 label[label == 151644] = IGNORE_INDEX
#                 label[label == 151645] = IGNORE_INDEX
#                 label[label == 198] = IGNORE_INDEX
#                 tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
#                 normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True)
#                 tts_text_token_all = []
#                 for norm_text in normalized_text:
#                     tts_text_token, tts_text_token_len = self.cosyvoice1.frontend._extract_text_token(norm_text)
#                     tts_text_token_all.append(tts_text_token)
#                 tts_text_token_all = torch.cat(tts_text_token_all, dim=0)
#                 target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token)
#                 target_embeddings.append(target_embedding)
#             import pdb; pdb.set_trace()
#             tgt_label_reps = torch.stack(tgt_label_reps)
#             target_embeddings = torch.stack(target_embeddings).squeeze(1)
#             hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512)
#             target_embeddings = target_embeddings.reshape(-1, 512)
#             loss = self.criterion(hidden_states, target_embeddings, torch.ones(hidden_states.size(0)).to(hidden_states.device))
#         else:
#             for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
#                 # map all label ids in [151644, 151645, 198] to IGNORE_INDEX
#                 label[label == 151644] = IGNORE_INDEX
#                 label[label == 151645] = IGNORE_INDEX
#                 label[label == 198] = IGNORE_INDEX
#                 tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
#                 speech_token = self.cosyvoice.inference_label(ans, '英文女', stream=False)
#                 speech_tokens = []
#                 for i, j in enumerate(speech_token):
#                     speech_tokens.append(j['tts_speech_token'].squeeze(0))
#                     speech_tokens.append(torch.tensor([0]))
#                 speech_tokens = torch.cat(speech_tokens, dim=0)
#                 if speech_tokens.size(0) > 1:
#                     speech_tokens = speech_tokens[:-1]
#                 batch_speech_tokens.append(speech_tokens)
#                 embedding = self.cosyvoice.frontend.frontend_embedding('英文女')
#                 embeddings.append(embedding['llm_embedding'].squeeze(0))
#             tgt_label_reps = torch.stack(tgt_label_reps)
#             batch_speech_token = torch.stack(batch_speech_tokens)
#             embeddings = torch.stack(embeddings)
#             hidden_states = self.input_proj(tgt_label_reps)
#             batch = {'text_feature': hidden_states, 'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(hidden_states.size(0)),
#                      'speech_token': batch_speech_token, 'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(hidden_states.size(0)),
#                      'embedding': embeddings}
#             output = self.llm.forward_ours(batch, 'cuda')
#             loss = output['loss']
#         return loss
class SpeechGeneratorCosyVoice(nn.Module):
    def __init__(self, config):
        super().__init__()
        # forward() also relies on input_proj, cosyvoice1, llm and (when distilling)
        # criterion; they are restored here from the commented-out predecessor above.
        self.input_proj = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.GELU(),
            nn.Linear(1024, 512)
        )
        self.cosyvoice1 = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False)
        self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
        self.llm = self.cosyvoice1.model.llm
        if DISTILL_EMBEDDING:
            self.criterion = nn.CosineEmbeddingLoss()

    def forward(self, tgt_reps, labels, answer):
        tgt_label_reps = []
        batch_speech_tokens = []
        embeddings = []
        target_embeddings = []
        if DISTILL_EMBEDDING:
            for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
                # map Qwen special tokens (<|im_start|> = 151644, <|im_end|> = 151645)
                # and the newline id 198 to IGNORE_INDEX
                label[label == 151644] = IGNORE_INDEX
                label[label == 151645] = IGNORE_INDEX
                label[label == 198] = IGNORE_INDEX
                tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
                normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True)
                tts_text_token_all = []
                for norm_text in normalized_text:
                    tts_text_token, tts_text_token_len = self.cosyvoice1.frontend._extract_text_token(norm_text)
                    tts_text_token_all.append(tts_text_token)
                tts_text_token_all = torch.cat(tts_text_token_all, dim=0)
                # embed the full concatenated token sequence; the original passed only
                # the last chunk (tts_text_token), which looks like a bug
                target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token_all)
                target_embeddings.append(target_embedding)
            # import pdb; pdb.set_trace()  # debug breakpoint (disabled)
            tgt_label_reps = torch.stack(tgt_label_reps)
            target_embeddings = torch.stack(target_embeddings).squeeze(1)
            hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512)
            target_embeddings = target_embeddings.reshape(-1, 512)
            loss = self.criterion(hidden_states, target_embeddings, torch.ones(hidden_states.size(0)).to(hidden_states.device))
        else:
            for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
                # map Qwen special tokens (<|im_start|> = 151644, <|im_end|> = 151645)
                # and the newline id 198 to IGNORE_INDEX
                label[label == 151644] = IGNORE_INDEX
                label[label == 151645] = IGNORE_INDEX
                label[label == 198] = IGNORE_INDEX
                tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
                # '英文女' is CosyVoice's built-in "English female" SFT speaker
                speech_token = self.cosyvoice.inference_label(ans, '英文女', stream=False)
                speech_tokens = []
                for seg in speech_token:
                    speech_tokens.append(seg['tts_speech_token'].squeeze(0))
                    speech_tokens.append(torch.tensor([0]))
                speech_tokens = torch.cat(speech_tokens, dim=0)
                if speech_tokens.size(0) > 1:
                    speech_tokens = speech_tokens[:-1]  # drop the trailing separator
                batch_speech_tokens.append(speech_tokens)
                embedding = self.cosyvoice.frontend.frontend_embedding('英文女')
                embeddings.append(embedding['llm_embedding'].squeeze(0))
            tgt_label_reps = torch.stack(tgt_label_reps)
            batch_speech_token = torch.stack(batch_speech_tokens)
            embeddings = torch.stack(embeddings)
            hidden_states = self.input_proj(tgt_label_reps)
            batch = {
                'text_feature': hidden_states,
                'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(hidden_states.size(0)),
                'speech_token': batch_speech_token,
                'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(hidden_states.size(0)),
                'embedding': embeddings,
            }
            output = self.llm.forward_ours(batch, 'cuda')
            loss = output['loss']
        return loss
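# Config fields consumed by the generators above (inferred from usage, not an
# exhaustive contract): config.ctc_decoder_config, a string such as
# "(n_layers,n_dims,n_heads,n_inter_dims)" (with an extra n_kv_heads entry for the
# Qwen variants), config.ctc_upsample_factor, config.unit_vocab_size and
# config.hidden_size, on top of a full Llama/Qwen2 config object that the
# corresponding decoder layers expect.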