|
|
|
|
|
|
|
|
import wave |
|
|
import asyncio |
|
|
import uuid |
|
|
import threading |
|
|
import queue |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from util import load_yaml |
|
|
|
|
|
from orpheus_tts.engine_class import OrpheusModel |
|
|
from vllm.outputs import RequestOutput |
|
|
from vllm import SamplingParams |
|
|
|
|
|
|
|
|
|
|
|
class BackgroundEventLoop:
    """Run an asyncio event loop on a daemon thread so synchronous code can
    consume async generators (e.g. vLLM's streaming engine output).

    The loop is started once at construction and runs forever; the owning
    thread is a daemon so it does not block interpreter shutdown.
    """

    # Tags used to multiplex producer messages over a single queue.  The
    # original implementation put raw exceptions on the queue and used
    # ``isinstance(item, Exception)`` to detect failures — which would
    # misinterpret an Exception instance legitimately yielded by the
    # async generator.  Tagged envelopes remove that ambiguity.
    _ITEM = "item"
    _ERROR = "error"
    _DONE = "done"

    def __init__(self):
        # Dedicated loop, driven by a daemon thread started immediately.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self):
        """Thread target: bind the loop to this thread and run it forever."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def run_generator(self, async_gen):
        """Bridge *async_gen* into a synchronous generator.

        Schedules a producer coroutine on the background loop that drains
        *async_gen* into a thread-safe queue; this method then yields items
        from the queue on the caller's thread.

        Yields:
            Each item produced by *async_gen*, in order.

        Raises:
            Whatever exception *async_gen* raised while producing, re-raised
            on the consumer thread.
        """
        q = queue.Queue()

        async def producer():
            try:
                async for item in async_gen:
                    q.put((self._ITEM, item))
            except Exception as exc:
                # Ship the failure to the consumer thread; it is re-raised
                # there so callers see the original traceback cause.
                q.put((self._ERROR, exc))
            finally:
                q.put((self._DONE, None))

        asyncio.run_coroutine_threadsafe(producer(), self._loop)

        while True:
            tag, payload = q.get()
            if tag is self._DONE:
                break
            if tag is self._ERROR:
                raise payload
            yield payload
|
|
|
|
|
|
|
|
# Single shared background loop used to drive async vLLM generation from
# synchronous callers.  NOTE: module-level side effect — importing this
# module starts a daemon thread.
tts_event_loop = BackgroundEventLoop()
|
|
|
|
|
class PatchedOrpheusModel(OrpheusModel):
    """OrpheusModel variant that drives vLLM's async engine synchronously.

    Uses the module-level ``tts_event_loop`` to consume ``engine.generate``
    (an async generator) from a regular synchronous generator.
    """

    def generate_tokens_sync(
        self,
        prompt,
        voice=None,
        request_id=None,
        temperature=0.6,
        top_p=0.8,
        max_tokens=1200,
        stop_token_ids=None,
        repetition_penalty=1.3,
    ):
        """Synchronously stream generated text for *prompt*.

        Args:
            prompt: Text to synthesize (formatted via ``_format_prompt``).
            voice: Optional voice name forwarded to the prompt formatter.
            request_id: vLLM request id; a fresh UUID is used when ``None``.
            temperature / top_p / max_tokens / repetition_penalty:
                Standard vLLM sampling parameters.
            stop_token_ids: Token ids that terminate generation; defaults
                to ``[49158]``.

        Yields:
            The cumulative output text of each streamed ``RequestOutput``.

        Raises:
            TypeError: If the engine yields something other than a
                ``RequestOutput``.
        """
        # BUG FIX: the original signature used a mutable default
        # (stop_token_ids=[49158]) shared across every call — build a fresh
        # list per call instead.
        if stop_token_ids is None:
            stop_token_ids = [49158]
        if request_id is None:
            request_id = str(uuid.uuid4())

        prompt_string = self._format_prompt(prompt, voice)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop_token_ids=stop_token_ids,
            repetition_penalty=repetition_penalty,
        )
        async_gen = self.engine.generate(
            prompt=prompt_string,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        # Drain the async generator on the shared background loop.
        for result in tts_event_loop.run_generator(async_gen):
            if not isinstance(result, RequestOutput):
                raise TypeError(f"Unexpected result type: {type(result)}")
            yield result.outputs[0].text
|
|
|
|
|
|
|
|
|
|
|
# Global singleton TTS model; populated lazily by setup_model().
model = None
|
|
|
|
|
|
|
|
|
|
|
def setup_model():
    """Lazily construct the global TTS model (idempotent).

    Reads the model name from the project YAML config and assigns the
    module-level ``model`` singleton on first call; later calls are no-ops.
    """
    global model
    if model is not None:
        return
    print("Loading TTS model...")
    cfg = load_yaml()
    model = PatchedOrpheusModel(model_name=cfg["tts"]["model_name"])
    print("✅ Model loaded and ready.")
|
|
|
|
|
def synthesize_for_scene(
    prompt: str,
    voice: str = "miko",
    temperature: float = 0.6,
    top_p: float = 0.9,
    repetition_penalty: float = 1.3,
    max_tokens: int = 1200,
):
    """Synthesize *prompt* to a complete WAV file in memory.

    Streams PCM chunks from the global model, wraps them in a WAV container
    (mono, 16-bit, 24 kHz) and returns both raw bytes and base64 text.

    Args:
        prompt: Text to speak.
        voice: Voice preset name.
        temperature / top_p / repetition_penalty / max_tokens:
            Sampling parameters forwarded to ``generate_speech``.

    Returns:
        Tuple ``(audio_bytes, audio_base64)`` — the WAV file bytes and the
        same payload base64-encoded as UTF-8 text.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    # BUG FIX: the original declared ``global model`` (unnecessary — the
    # global is only read here) and crashed with an opaque AttributeError
    # when the model was never loaded; fail fast with a clear message.
    if model is None:
        raise RuntimeError("TTS model is not loaded; call setup_model() first.")

    # Accumulate streamed PCM chunks into one contiguous buffer.
    pcm = bytearray()
    for chunk in model.generate_speech(
        prompt=prompt,
        voice=voice,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
    ):
        pcm.extend(chunk)

    buffer = BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)      # mono
        wf.setsampwidth(2)      # 16-bit samples
        wf.setframerate(24000)  # presumably the Orpheus output rate — TODO confirm
        wf.writeframes(pcm)

    audio_bytes = buffer.getvalue()
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64