Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| from typing import Optional, List, Tuple | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer | |
| from parler_tts import ParlerTTSForConditionalGeneration | |
| from .config import GenerationConfig | |
| from .presets import PRESETS | |
| from .audio import normalize_audio, save_wav, shorten_long_silences | |
| from .description import build_advanced_description | |
logger = logging.getLogger(__name__)


class ParlerVoiceInference:
    """ParlerVoice inference engine with enhanced generation options.

    Holds a ``ParlerTTSForConditionalGeneration`` model together with the
    prompt tokenizer and the description tokenizer, and offers helpers for
    single, preset-driven and batch speech generation.
    """

    def __init__(
        self,
        checkpoint_path: str,
        base_model_path: str = "parler-tts/parler-tts-mini-v1.1",
        device: Optional[str] = None,
    ) -> None:
        """Load the model and both tokenizers.

        Args:
            checkpoint_path: Location handed to
                ``ParlerTTSForConditionalGeneration.from_pretrained``.
            base_model_path: Location the prompt tokenizer is loaded from.
            device: Torch device string. When omitted, ``"cuda:0"`` is used
                if CUDA is available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        logger.info("Loading model from %s", checkpoint_path)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(checkpoint_path).to(
            self.device
        )
        self.model.eval()

        logger.info("Loading tokenizers from %s", base_model_path)
        # The prompt tokenizer comes from the base model, while the description
        # tokenizer is the one recorded in the checkpoint's text-encoder config.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            self.model.config.text_encoder._name_or_path
        )

        self.sampling_rate = int(self.model.config.sampling_rate)
        logger.info("Model loaded. Sampling rate: %d Hz", self.sampling_rate)

    def build_advanced_description(
        self,
        speaker: str,
        pace: str = "moderate speed",
        noise: str = "very clear",
        reverberation: str = "very close-sounding",
        monotony: str = "expressive and animated",
        pitch: str = "moderate pitch",
        emotion: str = "neutral",
        tone: str = "neutral",
        add_context: bool = True,
    ) -> str:
        """Build a voice description string for ``speaker``.

        Thin convenience wrapper: every argument is forwarded unchanged to the
        module-level ``build_advanced_description`` helper, whose result is
        returned as is.
        """
        return build_advanced_description(
            speaker=speaker,
            pace=pace,
            noise=noise,
            reverberation=reverberation,
            monotony=monotony,
            pitch=pitch,
            emotion=emotion,
            tone=tone,
            add_context=add_context,
        )

    def generate_audio(
        self,
        prompt: str,
        description: str,
        config: Optional[GenerationConfig] = None,
        output_path: Optional[str] = None,
    ) -> Tuple[np.ndarray, str]:
        """Synthesize speech for ``prompt`` in the voice given by ``description``.

        Args:
            prompt: Text to be spoken; encoded with the prompt tokenizer.
            description: Voice description; encoded with the description
                tokenizer.
            config: Generation settings. A default ``GenerationConfig()`` is
                used when omitted.
            output_path: Where to write the WAV file. When empty or ``None``
                nothing is written to disk.

        Returns:
            ``(audio_array, output_path)``. NOTE: if no ``output_path`` was
            given, the second element is the placeholder ``"output.wav"`` even
            though no file has been saved.
        """
        if config is None:
            config = GenerationConfig()

        input_ids = self.description_tokenizer(
            description, return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(self.device)

        # Gradients are not needed for inference.
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                prompt_input_ids=prompt_input_ids,
                temperature=config.temperature,
                do_sample=config.do_sample,
                top_k=config.top_k,
                top_p=config.top_p,
                repetition_penalty=config.repetition_penalty,
                max_length=config.max_length,
                min_length=config.min_length,
                num_beams=config.num_beams,
                early_stopping=config.early_stopping,
            )

        # Move to host memory and drop singleton dimensions before
        # post-processing.
        audio_array = generation_output.cpu().numpy().squeeze()
        audio_array = normalize_audio(audio_array)

        # Post-process: collapse long silences (>2s) down to 800ms
        audio_array = shorten_long_silences(
            audio_array,
            samplerate=self.sampling_rate,
            silence_threshold_db=-40.0,
            max_silence_ms=800,
            collapse_trigger_ms=2000,
        )

        if output_path:
            save_wav(output_path, audio_array, samplerate=self.sampling_rate)
            logger.info("Audio saved to: %s", output_path)
        else:
            # Placeholder only: keeps the return type a ``str`` for callers,
            # but no file is written on this branch.
            output_path = "output.wav"

        return audio_array, output_path

    def generate_with_speaker_preset(
        self,
        prompt: str,
        speaker: str,
        preset: str = "natural",
        config: Optional[GenerationConfig] = None,
        output_path: Optional[str] = None,
    ) -> Tuple[np.ndarray, str]:
        """Generate audio using a named entry of ``PRESETS`` for the voice.

        An unknown ``preset`` is not an error: a warning is logged and the
        ``"natural"`` preset is used instead. The preset's entries are passed
        as keyword arguments to :meth:`build_advanced_description`, and the
        resulting description is handed to :meth:`generate_audio`.

        Returns:
            Whatever :meth:`generate_audio` returns for the built description.
        """
        if preset not in PRESETS:
            logger.warning("Unknown preset '%s', using 'natural'", preset)
            preset = "natural"

        preset_config = PRESETS[preset]
        description = self.build_advanced_description(speaker=speaker, **preset_config)
        return self.generate_audio(prompt, description, config, output_path)

    def batch_generate(
        self,
        prompts: List[str],
        descriptions: List[str],
        config: Optional[GenerationConfig] = None,
        output_dir: str = "outputs",
    ) -> List[Tuple[np.ndarray, str]]:
        """Generate one audio file per ``(prompt, description)`` pair.

        Files are written to ``output_dir`` (created if missing) and named
        ``output_000.wav``, ``output_001.wav``, ... in input order.

        Args:
            prompts: Texts to synthesize.
            descriptions: Voice descriptions, one per prompt.
            config: Generation settings shared by every item.
            output_dir: Directory that receives the WAV files.

        Returns:
            The ``(audio_array, saved_path)`` results of
            :meth:`generate_audio`, in input order.

        Raises:
            ValueError: If ``prompts`` and ``descriptions`` differ in length.
        """
        import os

        prompt_list = list(prompts)
        description_list = list(descriptions)
        # A bare zip() would silently drop the unmatched tail, producing fewer
        # files than requested; reject the mismatch before touching the disk.
        if len(prompt_list) != len(description_list):
            raise ValueError(
                "prompts and descriptions must have the same length "
                f"(got {len(prompt_list)} and {len(description_list)})"
            )

        os.makedirs(output_dir, exist_ok=True)

        results: List[Tuple[np.ndarray, str]] = []
        for idx, (prompt, description) in enumerate(zip(prompt_list, description_list)):
            output_path = os.path.join(output_dir, f"output_{idx:03d}.wav")
            audio_array, saved_path = self.generate_audio(
                prompt, description, config, output_path
            )
            results.append((audio_array, saved_path))

        logger.info("Batch generation complete. Generated %d audio files.", len(results))
        return results