from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
import os
ALIASES = {
    'en-us': 'a',
    'en-gb': 'b',
    'es': 'e',
    'fr-fr': 'f',
    'hi': 'h',
    'it': 'i',
    'pt-br': 'p',
    'ja': 'j',
    'zh': 'z',
}
LANG_CODES = dict(
    # pip install misaki[en]
    a='American English',
    b='British English',

    # espeak-ng
    e='es',
    f='fr-fr',
    h='hi',
    i='it',
    p='pt-br',

    # pip install misaki[ja]
    j='Japanese',

    # pip install misaki[zh]
    z='Mandarin Chinese',
)
class KPipeline:
    '''
    KPipeline is a language-aware support class with 2 main responsibilities:
    1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
    2. Manage and store voices, lazily downloaded from HF if needed

    You are expected to have one KPipeline per language. If you have multiple
    KPipelines, you should reuse one KModel instance across all of them.

    KPipeline is designed to work with a KModel, but this is not required.
    There are 2 ways to pass an existing model into a pipeline:
    1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
    2. On call: us_pipeline(text, voice, model=model)

    By default, KPipeline will automatically initialize its own KModel. To
    suppress this, construct a "quiet" KPipeline with model=False.

    A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
    any audio. You can use this to phonemize and chunk your text in advance.

    A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
    '''
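    # A minimal usage sketch of the loud/quiet split described above
    # ('af_heart' and the default repo come from this file; any voice under
    # the repo's voices/ directory behaves the same way):
    #
    #     pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')
    #     for result in pipeline("Hello world!", voice='af_heart'):
    #         audio = result.audio  # FloatTensor at 24000 Hz
    #
    #     quiet = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M', model=False)
    #     chunks = [(r.graphemes, r.phonemes) for r in quiet("Some long text...")]
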
    def __init__(
        self,
        lang_code: str,
        repo_id: Optional[str] = None,
        model: Union[KModel, bool] = True,
        trf: bool = False,
        en_callable: Optional[Callable[[str], str]] = None,
        device: Optional[str] = None
    ):
        """Initialize a KPipeline.

        Args:
            lang_code: Language code for G2P processing
            repo_id: Repo to load the model config and voices from; defaults to 'hexgrad/Kokoro-82M' with a warning when omitted
            model: KModel instance, True to create a new model, False for no model
            trf: Whether to use transformer-based G2P
            en_callable: Optional callable passed through to the Mandarin G2P (zh.ZHG2P) when lang_code='z'
            device: Override default device selection ('cuda' or 'cpu', or None for auto).
                If None, auto-selects cuda if available.
                If 'cuda' and not available, explicitly raises an error.
        """
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
            config = None
        else:
            config = os.path.join(repo_id, 'config.json')
        self.repo_id = repo_id
        lang_code = lang_code.lower()
        lang_code = ALIASES.get(lang_code, lang_code)
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
        self.lang_code = lang_code
        self.model = None
        if isinstance(model, KModel):
            self.model = model
        elif model:
            if device == 'cuda' and not torch.cuda.is_available():
                raise RuntimeError("CUDA requested but not available")
            if device == 'mps' and not torch.backends.mps.is_available():
                raise RuntimeError("MPS requested but not available")
            if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
                raise RuntimeError("MPS requested but fallback not enabled")
            if device is None:
                if torch.cuda.is_available():
                    device = 'cuda'
                elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
                    device = 'mps'
                else:
                    device = 'cpu'
            try:
                self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
            except RuntimeError as e:
                if device == 'cuda':
                    raise RuntimeError(f"Failed to initialize model on CUDA: {e}. "
                                       "Try setting device='cpu' or check CUDA installation.")
                raise
        self.voices = {}
        if lang_code in 'ab':
            try:
                fallback = espeak.EspeakFallback(british=lang_code=='b')
            except Exception as e:
                logger.warning("EspeakFallback not enabled: OOD words will be skipped")
                logger.warning(str(e))
                fallback = None
            self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
        elif lang_code == 'j':
            try:
                from misaki import ja
                self.g2p = ja.JAG2P()
            except ImportError:
                logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
                raise
        elif lang_code == 'z':
            try:
                from misaki import zh
                self.g2p = zh.ZHG2P(
                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
                    en_callable=en_callable
                )
            except ImportError:
                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                raise
        else:
            language = LANG_CODES[lang_code]
            logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
            self.g2p = espeak.EspeakG2P(language=language)
    def load_single_voice(self, voice: str):
        if voice in self.voices:
            return self.voices[voice]
        if voice.endswith('.pt'):
            f = voice
        else:
            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
            if not voice.startswith(self.lang_code):
                v = LANG_CODES.get(voice, voice)
                p = LANG_CODES.get(self.lang_code, self.lang_code)
                logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
        pack = torch.load(f, weights_only=True)
        self.voices[voice] = pack
        return pack
| """ | |
| load_voice is a helper function that lazily downloads and loads a voice: | |
| Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica'). | |
| If multiple voices are requested, they are averaged. | |
| Delimiter is optional and defaults to ','. | |
| """ | |
| def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor: | |
| if isinstance(voice, torch.FloatTensor): | |
| return voice | |
| if voice in self.voices: | |
| return self.voices[voice] | |
| logger.debug(f"Loading voice: {voice}") | |
| packs = [self.load_single_voice(v) for v in voice.split(delimiter)] | |
| if len(packs) == 1: | |
| return packs[0] | |
| self.voices[voice] = torch.mean(torch.stack(packs), dim=0) | |
| return self.voices[voice] | |
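    # Voice-blending sketch: the averaged pack is cached under the combined
    # key, so repeated calls are cheap. 'af_bella' and 'af_jessica' come from
    # the docstring above; any two same-language voices work the same way:
    #
    #     blend = pipeline.load_voice('af_bella,af_jessica')  # mean of both packs
    #     for result in pipeline("Blended voices!", voice='af_bella,af_jessica'):
    #         ...
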
    @staticmethod
    def tokens_to_ps(tokens: List[en.MToken]) -> str:
        return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()

    @staticmethod
    def waterfall_last(
        tokens: List[en.MToken],
        next_count: int,
        waterfall: List[str] = ['!.?…', ':;', ',—'],
        bumps: List[str] = [')', '”']
    ) -> int:
        # Scan for the last "strong" boundary token (sentence enders first,
        # then weaker punctuation), bumping past closing brackets/quotes,
        # such that the remainder fits within the 510-phoneme limit.
        for w in waterfall:
            z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
            if z is None:
                continue
            z += 1
            if z < len(tokens) and tokens[z].phonemes in bumps:
                z += 1
            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
                return z
        return len(tokens)

    @staticmethod
    def tokens_to_text(tokens: List[en.MToken]) -> str:
        return ''.join(t.text + t.whitespace for t in tokens).strip()
    def en_tokenize(
        self,
        tokens: List[en.MToken]
    ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
        tks = []
        pcount = 0
        for t in tokens:
            # American English: ɾ => T
            t.phonemes = '' if t.phonemes is None else t.phonemes  # .replace('ɾ', 'T')
            next_ps = t.phonemes + (' ' if t.whitespace else '')
            next_pcount = pcount + len(next_ps.rstrip())
            if next_pcount > 510:
                z = KPipeline.waterfall_last(tks, next_pcount)
                text = KPipeline.tokens_to_text(tks[:z])
                logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
                ps = KPipeline.tokens_to_ps(tks[:z])
                yield text, ps, tks[:z]
                tks = tks[z:]
                pcount = len(KPipeline.tokens_to_ps(tks))
                if not tks:
                    next_ps = next_ps.lstrip()
            tks.append(t)
            pcount += len(next_ps)
        if tks:
            text = KPipeline.tokens_to_text(tks)
            ps = KPipeline.tokens_to_ps(tks)
            yield text, ps, tks
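    # Sketch of the chunking contract: a token stream whose phonemes exceed
    # the 510-phoneme model limit is yielded as multiple (text, phonemes,
    # tokens) chunks, each cut at the strongest available punctuation boundary:
    #
    #     _, tokens = pipeline.g2p(long_text)  # works on "quiet" pipelines too
    #     for text, ps, tks in pipeline.en_tokenize(tokens):
    #         ...  # len(ps) <= 510 except in pathological cases; callers truncate
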
    @staticmethod
    def infer(
        model: KModel,
        ps: str,
        pack: torch.FloatTensor,
        speed: Union[float, Callable[[int], float]] = 1
    ) -> KModel.Output:
        # speed may be a callable of the phoneme length, so callers can vary
        # pace per chunk; the voice pack is indexed by phoneme length.
        if callable(speed):
            speed = speed(len(ps))
        return model(ps, pack[len(ps)-1], speed, return_output=True)
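    # A callable speed lets you taper pace with chunk length, e.g. slowing
    # long chunks slightly (the threshold below is arbitrary illustration):
    #
    #     def speed_callable(len_ps: int) -> float:
    #         return 1.0 if len_ps < 300 else 0.9
    #
    #     for result in pipeline(text, voice='af_heart', speed=speed_callable):
    #         ...
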
    def generate_from_tokens(
        self,
        tokens: Union[str, List[en.MToken]],
        voice: str,
        speed: float = 1,
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        """Generate audio from either raw phonemes or pre-processed tokens.

        Args:
            tokens: Either a phoneme string or list of pre-processed MTokens
            voice: The voice to use for synthesis
            speed: Speech speed modifier (default: 1)
            model: Optional KModel instance (uses pipeline's model if not provided)

        Yields:
            KPipeline.Result containing the input tokens and generated audio

        Raises:
            ValueError: If no voice is provided or token sequence exceeds model limits
        """
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
        pack = self.load_voice(voice).to(model.device) if model else None

        # Handle raw phoneme string
        if isinstance(tokens, str):
            logger.debug("Processing phonemes from raw string")
            if len(tokens) > 510:
                raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
            output = KPipeline.infer(model, tokens, pack, speed) if model else None
            yield self.Result(graphemes='', phonemes=tokens, output=output)
            return

        logger.debug("Processing MTokens")
        # Handle pre-processed tokens
        for gs, ps, tks in self.en_tokenize(tokens):
            if not ps:
                continue
            elif len(ps) > 510:
                logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                logger.warning("Truncating to 510 characters")
                ps = ps[:510]
            output = KPipeline.infer(model, ps, pack, speed) if model else None
            if output is not None and output.pred_dur is not None:
                KPipeline.join_timestamps(tks, output.pred_dur)
            yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
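    # Raw-phoneme sketch: bypass G2P entirely by passing a phoneme string
    # (the phoneme sequence below is illustrative, not verified G2P output):
    #
    #     for result in pipeline.generate_from_tokens('həlˈoʊ wˈɜɹld', voice='af_heart'):
    #         audio = result.audio
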
    @staticmethod
    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
        # Multiply by 600 to go from pred_dur frames to sample_rate 24000
        # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
        # We will count nice round half-frames, so the divisor is 80
        MAGIC_DIVISOR = 80
        if not tokens or len(pred_dur) < 3:
            # We expect at least 3: <bos>, token, <eos>
            return
        # We track 2 counts, measured in half-frames: (left, right)
        # This way we can cut space characters in half
        # TODO: Is -3 an appropriate offset?
        left = right = 2 * max(0, pred_dur[0].item() - 3)
        # Updates:
        # left = right + (2 * token_dur) + space_dur
        # right = left + space_dur
        i = 1
        for t in tokens:
            if i >= len(pred_dur)-1:
                break
            if not t.phonemes:
                if t.whitespace:
                    i += 1
                    left = right + pred_dur[i].item()
                    right = left + pred_dur[i].item()
                    i += 1
                continue
            j = i + len(t.phonemes)
            if j >= len(pred_dur):
                break
            t.start_ts = left / MAGIC_DIVISOR
            token_dur = pred_dur[i: j].sum().item()
            space_dur = pred_dur[j].item() if t.whitespace else 0
            left = right + (2 * token_dur) + space_dur
            t.end_ts = left / MAGIC_DIVISOR
            right = left + space_dur
            i = j + (1 if t.whitespace else 0)
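    # Worked example of the frame arithmetic above: each pred_dur frame is
    # 600 samples at 24000 Hz, so seconds = frames * 600 / 24000 = frames / 40.
    # Counting in half-frames doubles everything, hence the divisor of 80:
    # a token spanning pred_dur frames [10, 12, 8] has token_dur = 30 frames
    # = 60 half-frames, i.e. 30 / 40 = 0.75 seconds of audio.
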
    @dataclass
    class Result:
        graphemes: str
        phonemes: str
        tokens: Optional[List[en.MToken]] = None
        output: Optional[KModel.Output] = None
        text_index: Optional[int] = None

        @property
        def audio(self) -> Optional[torch.FloatTensor]:
            return None if self.output is None else self.output.audio

        @property
        def pred_dur(self) -> Optional[torch.LongTensor]:
            return None if self.output is None else self.output.pred_dur

        ### MARK: BEGIN BACKWARD COMPAT ###
        def __iter__(self):
            yield self.graphemes
            yield self.phonemes
            yield self.audio

        def __getitem__(self, index):
            return [self.graphemes, self.phonemes, self.audio][index]

        def __len__(self):
            return 3
        #### MARK: END BACKWARD COMPAT ####
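    # The backward-compat hooks above let a Result unpack like the old
    # (graphemes, phonemes, audio) 3-tuple API:
    #
    #     for graphemes, phonemes, audio in pipeline(text, voice='af_heart'):
    #         ...
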
    def __call__(
        self,
        text: Union[str, List[str]],
        voice: Optional[str] = None,
        speed: Union[float, Callable[[int], float]] = 1,
        split_pattern: Optional[str] = r'\n+',
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
        pack = self.load_voice(voice).to(model.device) if model else None

        # Convert input to list of segments
        if isinstance(text, str):
            text = re.split(split_pattern, text.strip()) if split_pattern else [text]

        # Process each segment
        for graphemes_index, graphemes in enumerate(text):
            if not graphemes.strip():  # Skip empty segments
                continue

            # English processing (unchanged)
            if self.lang_code in 'ab':
                logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
                _, tokens = self.g2p(graphemes)
                for gs, ps, tks in self.en_tokenize(tokens):
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                        ps = ps[:510]
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    if output is not None and output.pred_dur is not None:
                        KPipeline.join_timestamps(tks, output.pred_dur)
                    yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
            # Non-English processing with chunking
            else:
                # Split long text into smaller chunks (roughly 400 characters each),
                # using sentence boundaries when possible
                chunk_size = 400
                chunks = []
                # Try to split on sentence boundaries first
                sentences = re.split(r'([.!?]+)', graphemes)
                current_chunk = ""
                for i in range(0, len(sentences), 2):
                    sentence = sentences[i]
                    # Add the punctuation back if it exists
                    if i + 1 < len(sentences):
                        sentence += sentences[i + 1]
                    if len(current_chunk) + len(sentence) <= chunk_size:
                        current_chunk += sentence
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = sentence
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # If no chunks were created (no sentence boundaries), fall back to character-based chunking
                if not chunks:
                    chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
                # Process each chunk
                for chunk in chunks:
                    if not chunk.strip():
                        continue
                    ps, _ = self.g2p(chunk)
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
                        ps = ps[:510]
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
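
if __name__ == '__main__':
    # End-to-end smoke-test sketch. Assumptions: this module is run as part of
    # its package (e.g. `python -m <package>.pipeline`) so the relative KModel
    # import resolves, the hexgrad/Kokoro-82M repo is reachable, and the
    # third-party `soundfile` package is installed for writing WAV files.
    import soundfile as sf
    pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')
    for i, result in enumerate(pipeline("Hello world! This is Kokoro.", voice='af_heart')):
        print(i, result.graphemes, '->', result.phonemes)
        if result.audio is not None:
            sf.write(f'out_{i}.wav', result.audio.cpu().numpy(), 24000)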