Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Modified From https://github.com/XXXXRT666/GPT-SoVITS | |
| """ | |
| import os | |
| import time | |
| import traceback | |
| from typing import Dict, List, Optional, Tuple | |
| import flash_attn # type: ignore | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from AR.models.embedding import ( | |
| SinePositionalEmbeddingNested as SinePositionalEmbedding, | |
| ) | |
| from AR.models.embedding import TokenEmbedding | |
| from AR.models.structs import T2SRequest, T2SResult, T2SSession | |
| from AR.models.t2s_model_abc import ( | |
| AttentionABC, | |
| FeedForward, | |
| KVCacheABC, | |
| KVCacheNHD, | |
| T2SDecoderABC, | |
| TorchProfiler, | |
| TransformerBlockABC, | |
| TransformerDecoderABC, | |
| ) | |
# Module-level shorthand for torch.Tensor, used in the type annotations below.
Tensor = torch.Tensor
class Attention(AttentionABC):
    """Multi-head self-attention that reads/writes the KV cache via ``flash_attn.flash_attn_with_kvcache``."""

    def __init__(self, n_head: int, hidden_dim: int):
        super().__init__()
        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        # One fused projection yields q, k and v (split with chunk(3) in forward).
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
        """Attend ``x`` (batch, seq, hidden_dim) against the cache; return the same shape."""
        bsz, seqlen, _ = x.shape
        per_head = (bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = (part.view(per_head) for part in self.in_proj.forward(x).chunk(3, dim=-1))

        # The new k/v are handed to flash_attn together with the caches; `input_pos - 1` is
        # passed as cache_seqlens (the per-row count of entries already in the cache).
        out: Tensor = flash_attn.flash_attn_with_kvcache(
            q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
        )  # type: ignore

        out = self.dropout.forward(out).view(bsz, seqlen, self.hidden_dim)
        return self.out_proj.forward(out)
class TransformerBlock(TransformerBlockABC):
    """One decoder layer: flash-attention, a feed-forward network, and a LayerNorm for each."""

    def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        self.attention = Attention(n_head, hidden_dim)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)

        # Both norms normalise over the hidden dimension only.
        norm_shape = [self.hidden_dim]
        self.attention_norm = nn.LayerNorm(norm_shape)
        self.ffn_norm = nn.LayerNorm(norm_shape)
class TransformerDecoder(TransformerDecoderABC):
    """Stack of ``n_layer`` TransformerBlocks plus the batch/sequence limits stored for caching."""

    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size
        self.n_layer = n_layer

        blocks = [TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)]
        self.layers = nn.ModuleList(blocks)  # type: ignore

        self.max_seq_length: int = max_seq_length
        self.max_batch_size: int = max_batch_size
        self.setup_caches(self.max_batch_size, self.max_seq_length)

    def setup_caches(self, max_batch_size=10, max_seq_length=2500):
        """Record the cache limits; this implementation allocates no cache tensors here."""
        self.max_seq_length, self.max_batch_size = max_seq_length, max_batch_size
class T2SDecoder(T2SDecoderABC):
    """Flash-attention text-to-semantic decoder built from the ``config["model"]`` section."""

    def __init__(
        self,
        config,
        *args,
        norm_first=False,
        max_seq_length=2500,
        max_batch_size=10,
        **kwds,
    ) -> None:
        super().__init__()

        model_cfg = config["model"]
        hidden_dim = model_cfg["hidden_dim"]
        embedding_dim = model_cfg["embedding_dim"]
        n_head = model_cfg["head"]
        n_layer = model_cfg["n_layer"]
        vocab_size = model_cfg["vocab_size"]
        phoneme_vocab_size = model_cfg["phoneme_vocab_size"]
        p_dropout = model_cfg["dropout"]
        EOS = model_cfg["EOS"]

        ffn_dim = hidden_dim * 4

        self.norm_first = norm_first
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.phoneme_vocab_size = phoneme_vocab_size
        self.p_dropout = p_dropout
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.EOS = EOS
        # The EOS token must be the last entry of the audio vocabulary.
        assert self.EOS == self.vocab_size - 1

        def make_position_embedding():
            # Text and audio positions share identical settings.
            return SinePositionalEmbedding(
                self.embedding_dim,
                dropout=0.1,
                scale=False,
                alpha=True,
                max_batch_size=max_batch_size,
                max_seq_len=max_seq_length,
            )

        # BERT features (1024-dim) are projected into the text embedding space.
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
        self.ar_text_position = make_position_embedding()
        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
        self.ar_audio_position = make_position_embedding()
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)

        self.h: TransformerDecoderABC = TransformerDecoder(
            hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheNHD
        self._register_load_state_dict_pre_hook(self.load_hook)

    def embed(
        self,
        x: List[torch.Tensor],
        y: torch.Tensor,
        bert_features: List[torch.Tensor],
    ):
        """Build the per-sample prefill input: text positions followed by audio positions.

        Args:
            x: One phoneme-id tensor per sample (packed into a nested tensor).
            y: Batched audio-token prompt, split per sample along dim 0.
            bert_features: One BERT feature tensor per sample; each is transposed with
                ``transpose(0, 1)`` so its last dim feeds ``bert_proj`` (1024 inputs).

        Returns:
            Nested tensor whose i-th component is ``cat([x_pos[i], y_pos[i]])``.
        """
        x_nested = torch.nested.nested_tensor(x)
        assert x_nested.size(0) <= self.max_batch_size
        bert_nested = torch.nested.nested_tensor([feat.transpose(0, 1) for feat in bert_features])

        text_emb = self.ar_text_embedding.forward(x_nested) + self.bert_proj.forward(bert_nested)
        x_pos = self.ar_text_position.prefill(text_emb)

        audio_emb = self.ar_audio_embedding.forward(torch.nested.nested_tensor(list(y.unbind(0))))
        y_pos = self.ar_audio_position.prefill(audio_emb)

        segments = [torch.cat([x_pos[b], y_pos[b]]) for b in range(len(x))]
        return torch.nested.nested_tensor(segments)

    def post_forward(self, idx: int, session: T2SSession) -> None:
        """No per-step post-processing is needed for this implementation."""
        return None

    def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
        """No extra positional/keyword arguments are passed to the decoder stack."""
        return [], {}
| class CUDAGraphRunner: | |
| def __init__( | |
| self, | |
| decoder_model: T2SDecoderABC, | |
| device: torch.device = torch.device("cpu"), | |
| dtype: torch.dtype = torch.float32, | |
| ) -> None: | |
| assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"} | |
| assert dtype in {torch.float16, torch.bfloat16, torch.float32} | |
| self.device = device | |
| self.dtype = dtype | |
| self.decoder_path: os.PathLike | |
| self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype) | |
| self.graph: Optional[torch.cuda.CUDAGraph] = None | |
| self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype) | |
| self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype) | |
| self.kv_cache = decoder_model.init_cache(1) | |
| self.input_pos = torch.tensor([10]).int().cuda() | |
| def _handle_request(self, request: T2SRequest): | |
| with self.device: | |
| for i in self.kv_cache: | |
| i.empty() | |
| decoder = self.decoder_model | |
| session = T2SSession(decoder, request, device=self.device, dtype=self.dtype) | |
| self.input_pos.copy_(session.input_pos) | |
| t1 = 0.0 | |
| infer_speed = 0.0 | |
| y = session.y | |
| bsz = y.size(0) | |
| torch_profiler = TorchProfiler(request.debug) | |
| with torch_profiler.profiler(): | |
| for idx in tqdm(range(1500)): | |
| if idx == 0: | |
| xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, self.kv_cache) | |
| xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()]) | |
| else: | |
| if request.use_cuda_graph and self.graph is None and torch.cuda.is_available(): | |
| self.xy_pos_.copy_(session.xy_pos) | |
| args, kwds = decoder.pre_forward(session) | |
| self.graph = decoder.capture( | |
| self.input_pos, | |
| self.xy_pos_, | |
| self.xy_dec_, | |
| kv_caches=self.kv_cache, | |
| *args, | |
| **kwds, | |
| ) | |
| with torch_profiler.record("AR"): | |
| if self.graph: | |
| self.xy_pos_.copy_(session.xy_pos) | |
| self.graph.replay() | |
| xy_dec = self.xy_dec_.clone() | |
| else: | |
| args, kwds = decoder.pre_forward(session) | |
| xy_dec = decoder.h.forward( | |
| self.input_pos, | |
| session.xy_pos, | |
| self.kv_cache, | |
| *args, | |
| **kwds, | |
| ) | |
| decoder.post_forward(idx, session) | |
| logits = decoder.ar_predict_layer(xy_dec[:, -1]) | |
| self.input_pos.add_(1) | |
| if idx == 0: | |
| logits[:, -1] = float("-inf") | |
| with torch_profiler.record("Sampling"): | |
| samples = session.sampler.sample( | |
| logits=logits, | |
| previous_tokens=session.y, | |
| top_k=request.top_k, | |
| top_p=request.top_p, | |
| repetition_penalty=request.repetition_penalty, | |
| temperature=request.temperature, | |
| ) | |
| session.y = torch.cat([session.y, samples], dim=1) | |
| with torch_profiler.record("EOS"): | |
| argmax_token = torch.argmax(logits, dim=-1) | |
| sample_token = samples.squeeze(1) | |
| EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS) | |
| newly_done_mask = EOS_mask & (~session.completed) | |
| newly_done_indices = newly_done_mask.nonzero() | |
| if newly_done_indices.numel() > 0: | |
| session.y_results[newly_done_indices[0]] = session.y[ | |
| newly_done_indices[0], session.y_len : -1 | |
| ].squeeze(0) | |
| session.completed[newly_done_indices] = True | |
| if torch.all(session.completed).item(): | |
| if session.y.size(1) == 0: | |
| session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1) | |
| tqdm.write("Bad Zero Prediction") | |
| else: | |
| tqdm.write( | |
| f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}" | |
| ) | |
| tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s") | |
| infer_speed = (idx - 1) / (time.perf_counter() - t1) | |
| break | |
| if ( | |
| request.early_stop_num != -1 | |
| and (session.y.size(1) - session.y_len) > request.early_stop_num | |
| ) or idx == 1499: | |
| for i in range(bsz): | |
| if not session.completed[i].item(): | |
| session.y_results[i] = session.y[i, session.y_len :] | |
| session.completed[i] = True | |
| break | |
| with torch_profiler.record("NextPos"): | |
| y_emb = decoder.ar_audio_embedding(session.y[:, -1:]) | |
| session.xy_pos = decoder.ar_audio_position.forward(self.input_pos - session.x_lens, y_emb) | |
| if idx == 2: | |
| torch_profiler.start() | |
| t1 = time.perf_counter() | |
| if idx == 51: | |
| torch_profiler.end() | |
| if idx % 100 == 0: | |
| match session.device.type: | |
| case "cuda": | |
| torch.cuda.empty_cache() | |
| case "mps": | |
| torch.mps.empty_cache() | |
| case "xpu": | |
| torch.xpu.empty_cache() | |
| case "mtia": | |
| torch.mtia.empty_cache() | |
| match session.device.type: | |
| case "cuda": | |
| torch.cuda.empty_cache() | |
| case "mps": | |
| torch.mps.empty_cache() | |
| case "xpu": | |
| torch.xpu.empty_cache() | |
| case "mtia": | |
| torch.mtia.empty_cache() | |
| torch_profiler.end() | |
| return session.y_results[: request.valid_length], infer_speed | |
| def generate(self, request: T2SRequest): | |
| try: | |
| result, infer_speed = self._handle_request(request) | |
| t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success") | |
| except Exception as e: | |
| t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc()) | |
| return t2s_result | |
def load_decoder(weights_path: os.PathLike, implement: str = "flash_attn"):
    """Build a T2S decoder for the chosen attention implementation and load its weights.

    Args:
        weights_path: Checkpoint file holding a ``"config"`` entry (passed to the decoder
            constructor) and a ``"weight"`` state dict.
        implement: Suffix of the ``AR.models.t2s_model_<implement>`` module (lower-cased)
            that provides the ``T2SDecoder`` class.

    Returns:
        The decoder, built with ``max_batch_size=1``, with weights loaded, in eval mode.

    Raises:
        ImportError: If no module exists for ``implement``.
    """
    import importlib

    print(f"Loading Text2Semantic Weights from {weights_path} with {implement.replace('_', ' ').title()} Implement")

    module_path = f"AR.models.t2s_model_{implement.lower()}"
    cls_name = "T2SDecoder"
    # importlib.import_module is the documented replacement for __import__(..., fromlist=...).
    decoder_cls: type[T2SDecoderABC] = getattr(importlib.import_module(module_path), cls_name)

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
    config = dict_s1["config"]
    decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=1)
    state_dict = dict_s1["weight"]
    decoder.load_state_dict(state_dict)
    return decoder.eval()