# miko-tts / orpheus_engine.py
# Repository page header: zirobtc, "Uploading DART folder into model repo"
# (commit 478eeb0, verified).
# tts_engine.py
import wave
import asyncio
import uuid # Import uuid to generate unique IDs
import threading
import queue
import base64
from io import BytesIO
from util import load_yaml
from orpheus_tts.engine_class import OrpheusModel
from vllm.outputs import RequestOutput
from vllm import SamplingParams
# --- Background loop to keep vLLM stable across requests ---
# One long-lived asyncio event loop runs in a daemon thread, and every
# generation request is scheduled onto that same loop.
class BackgroundEventLoop:
    """Own an asyncio event loop that runs forever in a daemon thread.

    Synchronous callers use :meth:`run_generator` to consume an async
    generator that is driven on this background loop.
    """

    def __init__(self):
        self._loop = asyncio.new_event_loop()
        # Daemon thread: the loop must not keep the interpreter alive at exit.
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self):
        """Thread target: bind the loop to this thread and run it forever."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def run_generator(self, async_gen):
        """Iterate ``async_gen`` on the background loop and yield its items.

        Items are handed from the loop thread to the calling thread through
        a queue. An ``Exception`` raised by ``async_gen`` is re-raised in the
        calling thread. Any queued item that is itself an ``Exception``
        instance is raised rather than yielded.

        If the caller stops iterating early (``close()``, ``break``, or an
        exception in the consumer), the producer task is cancelled so the
        async generator does not keep running with nobody reading its output.
        """
        q = queue.Queue()
        sentinel = object()  # unique end-of-stream marker

        async def producer():
            try:
                async for item in async_gen:
                    q.put(item)
            except Exception as e:
                # Forward the failure to the consuming thread.
                q.put(e)
            finally:
                # Always unblock the consumer, even on error or cancellation.
                q.put(sentinel)

        future = asyncio.run_coroutine_threadsafe(producer(), self._loop)
        try:
            while True:
                item = q.get()
                if item is sentinel:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # No-op if the producer already finished; otherwise this requests
            # cancellation of the task on the background loop.
            future.cancel()
# --- Patched Orpheus model using background loop ---
# Single shared background loop on which every generation request runs.
tts_event_loop = BackgroundEventLoop()


class PatchedOrpheusModel(OrpheusModel):
    """OrpheusModel whose token generation runs on the shared background loop."""

    def generate_tokens_sync(
        self,
        prompt,
        voice=None,
        request_id=None,
        temperature=0.6,
        top_p=0.8,
        max_tokens=1200,
        stop_token_ids=(49158,),
        repetition_penalty=1.3,
    ):
        """Yield generated text for ``prompt`` as the engine produces it.

        Args:
            prompt: Text to synthesize; formatted via ``self._format_prompt``.
            voice: Voice passed to ``self._format_prompt``.
            request_id: Engine request id. When ``None`` a fresh UUID4 string
                is generated, which avoids the "id already running" error
                when several requests overlap.
            temperature, top_p, max_tokens, repetition_penalty: Forwarded
                unchanged to ``SamplingParams``.
            stop_token_ids: Token ids that end generation. The default is an
                immutable tuple so it cannot be shared and mutated across
                calls; a fresh list is handed to ``SamplingParams`` each time.
                An explicit ``None`` is forwarded as ``None``.

        Yields:
            ``result.outputs[0].text`` for every ``RequestOutput`` received.

        Raises:
            TypeError: If the engine yields something other than a
                ``RequestOutput``.
        """
        # A unique id per call lets concurrent requests coexist in the engine.
        if request_id is None:
            request_id = str(uuid.uuid4())

        prompt_string = self._format_prompt(prompt, voice)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop_token_ids=(
                None if stop_token_ids is None else list(stop_token_ids)
            ),
            repetition_penalty=repetition_penalty,
        )
        async_gen = self.engine.generate(
            prompt=prompt_string,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        for result in tts_event_loop.run_generator(async_gen):
            if not isinstance(result, RequestOutput):
                raise TypeError(f"Unexpected result type: {type(result)}")
            yield result.outputs[0].text
# --- Persistent global model ---
# Populated once by setup_model() and reused by every synthesis call.
# Module-wide model instance; stays None until setup_model() has run.
model = None


def setup_model():
    """Load the TTS model into the module-level ``model`` exactly once.

    The model name is read from ``config["tts"]["model_name"]`` as returned
    by ``load_yaml()``. Calls made after the model is loaded do nothing.
    """
    global model
    if model is not None:
        return

    print("Loading TTS model...")
    tts_settings = load_yaml()["tts"]
    model = PatchedOrpheusModel(model_name=tts_settings["model_name"])
    print("✅ Model loaded and ready.")
def synthesize_for_scene(
    prompt: str,
    voice: str = "miko",
    temperature: float = 0.6,
    top_p: float = 0.9,
    repetition_penalty: float = 1.3,
    max_tokens: int = 1200,
):
    """Synthesize speech for ``prompt`` and return it as a WAV file.

    All audio chunks yielded by ``model.generate_speech`` are collected and
    wrapped in a mono, 16-bit, 24 kHz WAV container.

    Args:
        prompt: Text to speak.
        voice: Voice name forwarded to the model.
        temperature, top_p, repetition_penalty, max_tokens: Sampling
            settings forwarded unchanged to ``model.generate_speech``.

    Returns:
        A ``(audio_bytes, audio_base64)`` tuple: the WAV file contents and
        the same bytes base64-encoded as a UTF-8 string.

    Raises:
        RuntimeError: If the module-level model has not been loaded yet.
    """
    # Fail with an actionable message instead of an AttributeError on None.
    if model is None:
        raise RuntimeError(
            "TTS model is not loaded; call setup_model() before "
            "synthesize_for_scene()."
        )

    # NOTE(review): chunks are assumed to be raw PCM bytes matching the WAV
    # parameters below -- confirm against OrpheusModel.generate_speech.
    pcm = bytearray()
    for chunk in model.generate_speech(
        prompt=prompt,
        voice=voice,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
    ):
        pcm.extend(chunk)

    buffer = BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(24000)
        wf.writeframes(pcm)

    # Read the buffer only after the writer is closed so the header is final.
    audio_bytes = buffer.getvalue()
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64