import re
import time

import inflect
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from torchaudio.transforms import Resample

from vui.model import Vui
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k

# Resample codec-rate audio (22.05 kHz) down to the 16 kHz Whisper expects.
resample = Resample(22050, 16000).cuda()


def ensure_spaces_around_tags(text: str):
    # Add a space before '[' when it is not already preceded by whitespace,
    # '<', or '[' (preserving a newline if one is already there)
    text = re.sub(
        r"(?<![<\[\s])(\[)",
        lambda m: (
            f"\n{m.group(1)}"
            if m.start() > 0 and text[m.start() - 1] == "\n"
            else f" {m.group(1)}"
        ),
        text,
    )
    # Add a space after ']' when it is not preceded by digit + ']' and not
    # followed by whitespace, '>', or ']'
    text = re.sub(
        r"(?<!\d\])(\])(?![>\]\s])",
        lambda m: (
            f"{m.group(1)}\n"
            if m.end() < len(text) and text[m.end()] == "\n"
            else f"{m.group(1)} "
        ),
        text,
    )
    text = text.strip()
    return text
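

# Illustrative check (not part of the original file): a non-speech marker
# glued to the surrounding words gets spaced out.
def _demo_ensure_spaces_around_tags():
    assert ensure_spaces_around_tags("hello[breath]world") == "hello [breath] world"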


REPLACE = [
    ("—", ","),
    ("’", "'"),
    (":", ","),
    (";", ","),
]
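

# REPLACE is an (old, new) substitution table that this file never applies
# directly; a minimal sketch of using it (the helper name below is ours, not
# the repo's):
def _apply_replacements(text: str) -> str:
    for old, new in REPLACE:
        text = text.replace(old, new)
    return text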


engine = None
wm = None


def asr(chunk, model=None, prefix=None):
    """Pad or trim an audio chunk and transcribe it with Whisper."""
    import whisper

    global wm
    if model is not None:
        wm = model
    elif wm is None:
        wm = whisper.load_model("turbo", "cuda")

    chunk = whisper.pad_or_trim(chunk)
    mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
    options = whisper.DecodingOptions(
        language="en", without_timestamps=True, prefix=prefix
    )
    result = whisper.decode(wm, mel[None], options)
    return result[0].text
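

# Illustrative usage sketch (not in the original file): the module-level
# `resample` transform converts 22.05 kHz audio to the 16 kHz Whisper expects.
# The file path is a placeholder.
def _demo_asr():
    import torchaudio

    waveform, sr = torchaudio.load("clip_22k.wav")  # assumed 22.05 kHz mono
    chunk = resample(waveform.cuda()).squeeze(0)
    print(asr(chunk))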


def replace_numbers_with_words(text):
    global engine
    if engine is None:
        engine = inflect.engine()

    # Convert a matched number to its word form
    def number_to_words(match):
        number = match.group()
        return engine.number_to_words(number) + " "

    # Replace runs of digits with their word equivalents
    return re.sub(r"\d+", number_to_words, text)


valid_non_speech = ["breath", "sigh", "laugh", "tut", "hesitate", "clearthroat"]
valid_non_speech = [f"[{v}]" for v in valid_non_speech]


def remove_all_invalid_non_speech(txt):
    """
    Remove all non-speech markers that are not in the valid_non_speech list.
    Only keeps valid non-speech markers like [breath], [sigh], etc.
    ([pause] is also allowed through.)
    """
    # Find all text within square brackets
    bracket_pattern = r"\[([^\]]+)\]"
    brackets = re.findall(bracket_pattern, txt)

    # For each bracketed text, check whether it is in our valid list
    for bracket in brackets:
        bracket_with_brackets = f"[{bracket}]"
        if bracket_with_brackets not in valid_non_speech and bracket != "pause":
            # If not valid, remove it from the text
            txt = txt.replace(bracket_with_brackets, "")

    return txt
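

# Illustrative check (not part of the original file): "[cough]" is not a valid
# marker, so it is stripped (its surrounding spaces remain); "[laugh]" stays.
def _demo_remove_all_invalid_non_speech():
    assert remove_all_invalid_non_speech("ha [laugh] oh [cough] no") == "ha [laugh] oh  no"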


def simple_clean(text):
    text = re.sub(r"(\d+)am", r"\1 AM", text)
    text = re.sub(r"(\d+)pm", r"\1 PM", text)
    text = replace_numbers_with_words(text)
    text = ensure_spaces_around_tags(text)
    text = remove_all_invalid_non_speech(text)
    text = text.replace('"', "")
    text = text.replace("”", "")
    text = text.replace("“", "")
    text = text.replace("’", "'")
    text = text.replace("%", " percent")
    text = text.replace("*", "")
    text = text.replace("(", "")
    text = text.replace(")", "")
    text = text.replace(";", "")
    text = text.replace("–", " ")
    text = text.replace("—", "")
    text = text.replace(":", "")
    text = text.replace("…", "...")
    text = text.replace("s...", "s")
    # Collapse repeated newlines into one
    text = re.sub(r"\n+", "\n", text)
    ntxt = re.sub(r" +", " ", text)
    # Ensure that ntxt ends with '.' or '?'
    ntxt = ntxt.strip()
    if not ntxt.endswith((".", "?")):
        ntxt += "."
    ntxt += " [pause]"
    return ntxt
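

# Illustrative check (not part of the original file): digits become words and
# the cleaned text is terminated with "." plus a trailing [pause] marker.
def _demo_simple_clean():
    assert simple_clean("I have 2 cats") == "I have two cats. [pause]"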


def generate(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 150,
    top_p: float | None = None,
    max_gen_len: int = int(120 * 21.53),
):
    text = simple_clean(text)

    with (
        torch.autocast("cuda", torch.bfloat16, True),
        sdpa_kernel([SDPBackend.MATH]),
    ):
        t1 = time.perf_counter()
        batch_size = 1
        device = self.device
        self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)

        texts = [text]
        encoded = self.tokenizer(
            texts,
            padding="longest",
            return_tensors="pt",
        )
        input_ids = encoded.input_ids.to(device)
        text_embeddings = self.token_emb(input_ids)

        B = batch_size
        Q = self.config.model.n_quantizers

        if prompt_codes is None:
            prompt_codes = torch.zeros(
                (batch_size, Q, 0), dtype=torch.int64, device=device
            )
        else:
            prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)

        start_offset = prompt_codes.size(-1)
        pattern = self.pattern_provider.get_pattern(max_gen_len)

        # This token is used as the default value for codes that are not generated yet
        unknown_token = -1
        special_token_id = self.config.model.special_token_id

        # We generate codes up to max_gen_len that will be mapped to the pattern sequence
        codes = torch.full(
            (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
        )
        codes[:, :, :start_offset] = prompt_codes
        sequence, indexes, mask = pattern.build_pattern_sequence(
            codes, special_token_id
        )

        # Retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        prev_offset = 0
        S = sequence.size(-1)
        do_prefill = True
        eos = self.config.model.audio_eos_id

        for offset in range(start_offset_sequence, S):
            curr_sequence = sequence[..., prev_offset:offset]
            # Average the per-quantizer embeddings for this step
            audio_embeddings = (
                sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
                / Q
            )

            if do_prefill:
                # First step: prefill the cache with text + prompt audio embeddings
                embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
                T = embeddings.size(1)
                input_pos = torch.arange(0, T, device=device)
                do_prefill = False
            else:
                # Subsequent steps: feed only the newest audio embedding
                embeddings = audio_embeddings
                input_pos = torch.tensor([T], device=device)
                T += 1

            out = self.decoder(embeddings, input_pos)

            if offset == 15:
                print("TTFB", time.perf_counter() - t1)

            logits = torch.stack(
                [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
            )

            repetition_penalty = 1.4
            history_window = 12

            # Penalize tokens each quantizer generated recently
            for q in range(Q):
                # Extract the history window for this quantizer
                history_start = max(0, offset - history_window)
                token_history = sequence[0, q, history_start:offset]

                # Only apply the penalty to real tokens that appear in the history
                unique_tokens = torch.unique(token_history)
                unique_tokens = unique_tokens[unique_tokens != special_token_id]
                unique_tokens = unique_tokens[unique_tokens != eos]
                unique_tokens = unique_tokens[unique_tokens != unknown_token]

                if len(unique_tokens) > 0:
                    # Divide the logits of recently seen tokens by the penalty
                    logits[0, q, unique_tokens] = (
                        logits[0, q, unique_tokens] / repetition_penalty
                    )

            # Forbid EOS during the first few seconds of generation
            if offset < 24.53 * 4:
                logits[..., eos] = -float("inf")

            probs = F.softmax(logits / temperature, dim=-1)

            if top_p is not None and top_k is not None:
                next_codes = sample_top_p_top_k(probs, top_p, top_k)
            elif top_p is not None and top_p > 0:
                next_codes = sample_top_p(probs, top_p)
            elif top_k is not None and top_k > 0:
                next_codes = sample_top_k(probs, top_k)
            else:
                next_codes = multinomial(probs, num_samples=1)

            next_codes = next_codes.repeat(batch_size, 1, 1)

            if (probs[..., eos] > 0.95).any():
                print("breaking at", offset)
                break

            # Only write positions that the pattern marks as valid at this step
            valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
            next_codes[~valid_mask] = special_token_id
            sequence[..., offset : offset + 1] = torch.where(
                sequence[..., offset : offset + 1] == unknown_token,
                next_codes,
                sequence[..., offset : offset + 1],
            )
            prev_offset = offset

        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
            sequence, special_token=unknown_token
        )

        # Sanity checks over the returned codes and corresponding masks
        # assert (out_codes[..., :max_gen_len] != unknown_token).all()
        # assert (out_mask[..., :max_gen_len] == 1).all()

        out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
        return out_codes[[0]]


def render(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 100,
    top_p: float | None = None,
    max_secs: int = 100,
):
    """
    Render audio from text. Uses a single generate call for text under 1400
    characters; otherwise breaks the text into lines and chains generations,
    feeding each line's codes and text back in as context for the next.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    text = remove_all_invalid_non_speech(text)
    text = simple_clean(text)

    SR = self.codec.config.sample_rate
    HZ = self.codec.hz
    max_gen_len = int(HZ * max_secs)
    t1 = time.perf_counter()

    if len(text) < 1400:
        codes = generate(
            self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
        )
        codes = codes[..., :-10]
        audio = self.codec.from_indices(codes)
        print("RTF", (audio.numel() / SR) / (time.perf_counter() - t1))
        return audio.cpu()

    # Otherwise we have to do some clever chaining!
    orig_codes = prompt_codes
    lines = text.split("\n")
    audios = []
    prev_codes = prompt_codes
    prev_text = ""

    for i, line in enumerate(lines):
        run = True
        while run:
            current_text = prev_text + "\n" + line if prev_text else line
            current_text = current_text.strip()
            current_text = current_text.replace("...", "")
            current_text = current_text + " [pause]"

            # Budget the generation length based on the text length
            maxlen = int(HZ * int(60 * len(current_text) / 500))
            try:
                print("rendering", current_text)
                codes = generate(
                    self,
                    current_text,
                    prompt_codes=prev_codes,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_gen_len=maxlen,
                )
                codes = codes[..., :-10]
                paudio = self.codec.from_indices(codes)
                prev_text = line
                prev_codes = codes
                audios.append(paudio)
                run = False  # success: move on to the next line
            except KeyboardInterrupt:
                break
            except RuntimeError as e:
                # Reset the context and retry this line
                prev_codes = orig_codes
                prev_text = ""
                print(e)

    return torch.cat(audios, dim=-1)
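

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original file). We assume
    # `Vui.from_pretrained()` is the repo's checkpoint loader; substitute
    # however you normally instantiate the model. The output shape indexing
    # may need adjusting for your codec's decode output.
    import torchaudio

    model = Vui.from_pretrained().cuda()  # assumed loader
    audio = render(model, "Hey, how are you doing today? [laugh]")
    torchaudio.save("out.wav", audio[0].cpu(), model.codec.config.sample_rate)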