File size: 3,883 Bytes
478eeb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# tts_engine.py

import wave
import asyncio
import uuid  # Import uuid to generate unique IDs
import threading
import queue
import base64
from io import BytesIO
from util import load_yaml

from orpheus_tts.engine_class import OrpheusModel
from vllm.outputs import RequestOutput
from vllm import SamplingParams

# --- Background loop to keep vLLM stable across requests ---
# This class is correct and does not need changes.
class BackgroundEventLoop:
    """Run a single asyncio event loop on a daemon thread for the process lifetime.

    vLLM's async engine must be driven from one stable event loop; creating a
    fresh loop per request breaks it. This class keeps one loop alive and lets
    synchronous callers consume async generators scheduled on it.
    """

    def __init__(self):
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self):
        # Thread entry point: pin the loop to this thread and spin forever.
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def run_generator(self, async_gen):
        """Synchronously iterate *async_gen* on the background loop.

        Items are forwarded through a thread-safe queue; an exception raised
        by the generator is re-raised in the caller's thread after all items
        produced before it have been yielded.

        Yields:
            Each item produced by *async_gen*, in order.
        """
        q = queue.Queue()
        sentinel = object()
        # Fix: the original re-raised any *yielded* Exception instance too
        # (a plain isinstance check conflated data with errors). Tag real
        # errors with a private marker so yielded Exceptions pass through.
        error_marker = object()

        async def producer():
            try:
                async for item in async_gen:
                    q.put(item)
            except Exception as e:  # forwarded to the consumer thread below
                q.put((error_marker, e))
            finally:
                q.put(sentinel)

        asyncio.run_coroutine_threadsafe(producer(), self._loop)

        while True:
            item = q.get()
            if item is sentinel:
                break
            if isinstance(item, tuple) and len(item) == 2 and item[0] is error_marker:
                raise item[1]
            yield item

# --- Patched Orpheus model using background loop ---
# Module-level singleton: one persistent loop/daemon thread shared by every
# request, so vLLM is always driven from the same event loop.
tts_event_loop = BackgroundEventLoop()

class PatchedOrpheusModel(OrpheusModel):
    """OrpheusModel whose token generation runs on the shared background loop.

    Guarantees a unique vLLM request_id per call, which fixes the
    "request id already running" error under concurrent requests.
    """

    def generate_tokens_sync(
        self,
        prompt,
        voice=None,
        request_id=None,
        temperature=0.6,
        top_p=0.8,
        max_tokens=1200,
        stop_token_ids=None,
        repetition_penalty=1.3,
    ):
        """Synchronously yield generated text chunks for *prompt*.

        Args:
            prompt: Text to synthesize tokens for.
            voice: Optional voice name passed to the prompt formatter.
            request_id: vLLM request id; a fresh UUID is generated when None.
            temperature / top_p / max_tokens / repetition_penalty:
                Sampling controls forwarded to vLLM's SamplingParams.
            stop_token_ids: Token ids that end generation; defaults to [49158].

        Yields:
            str: cumulative generated text from each engine step.
        """
        # Fix: avoid the mutable-default-argument pitfall ([49158] would be a
        # single list object shared across all calls); resolve it in-body.
        if stop_token_ids is None:
            stop_token_ids = [49158]

        # A fresh UUID per call prevents vLLM's "id already running" error
        # when multiple requests run concurrently.
        if request_id is None:
            request_id = str(uuid.uuid4())

        prompt_string = self._format_prompt(prompt, voice)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop_token_ids=stop_token_ids,
            repetition_penalty=repetition_penalty,
        )
        async_gen = self.engine.generate(
            prompt=prompt_string,
            sampling_params=sampling_params,
            request_id=request_id,  # unique per call
        )
        for result in tts_event_loop.run_generator(async_gen):
            if not isinstance(result, RequestOutput):
                raise TypeError(f"Unexpected result type: {type(result)}")
            yield result.outputs[0].text

# --- Persistent global model ---
# Lazily-initialized process-wide model instance; populated by setup_model()
# and read by synthesize_for_scene().
model = None



def setup_model():
    """Load the global TTS model once; subsequent calls are no-ops."""
    global model
    if model is not None:
        return
    print("Loading TTS model...")
    cfg = load_yaml()
    model = PatchedOrpheusModel(model_name=cfg["tts"]["model_name"])
    print("✅ Model loaded and ready.")

def synthesize_for_scene(
    prompt: str,
    voice: str = "miko",
    temperature: float = 0.6,
    top_p: float = 0.9,
    repetition_penalty: float = 1.3,
    max_tokens: int = 1200,
):
    """Synthesize speech for *prompt* and return it as a mono 16-bit WAV.

    Args:
        prompt: Text to speak.
        voice: Voice preset forwarded to the model.
        temperature / top_p / repetition_penalty / max_tokens:
            Sampling controls forwarded to the model.

    Returns:
        tuple[bytes, str]: (raw WAV bytes, base64-encoded WAV string).

    Raises:
        RuntimeError: if setup_model() has not been called yet.
    """
    # Robustness fix: fail with a clear message instead of an opaque
    # AttributeError on None when the model was never loaded.
    if model is None:
        raise RuntimeError("TTS model not loaded; call setup_model() first.")

    # Safe to call concurrently: each generation gets a unique request_id
    # inside PatchedOrpheusModel, so parallel requests do not collide.
    pcm = bytearray()
    for chunk in model.generate_speech(
        prompt=prompt,
        voice=voice,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
    ):
        pcm.extend(chunk)

    buffer = BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)      # mono
        wf.setsampwidth(2)      # 16-bit samples
        wf.setframerate(24000)  # assumed Orpheus output rate — TODO confirm
        wf.writeframes(pcm)

    audio_bytes = buffer.getvalue()
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64