File size: 9,110 Bytes
478eeb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5dcccd
 
478eeb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4d3a23
478eeb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# main.py

import socket
import struct
import json
import msgpack

import zlib
import re
from util import calculate_duration_from_bytes, update_motion_generator_duration,load_yaml

import base64
from typing import Dict, Any, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from aligner import align_words, setup_aligner

from orpheus_engine import synthesize_for_scene, setup_model


# Settings are loaded once, at import time, via util.load_yaml().
config = load_yaml()
HOST, PORT = config["HOST"], config["PORT"]

print(f"Connecting to {HOST}:{PORT}")

# Magic number written into the header of every outgoing frame (send_frame).
MAGIC = 0x2333
# Magic number required on every incoming frame (recv_frame); comes from config.
# NOTE(review): this differs from MAGIC whenever config["MAGIC"] != 0x2333 —
# confirm the asymmetric send/receive magic is intentional.
MAGIC_JSON = config["MAGIC"]

def patch_socket_keepalive(sock: socket.socket) -> None:
    """Put *sock* in fully blocking mode and turn on TCP keepalive.

    ``recv`` waits indefinitely (timeout ``None``) so an idle connection never
    raises a timeout. Keepalive probing is enabled via ``SO_KEEPALIVE``. Where
    the platform exposes them, the options are also tuned:
    ``TCP_KEEPIDLE=10``, ``TCP_KEEPINTVL=5``, ``TCP_KEEPCNT=3``.
    """
    sock.settimeout(None)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

    # The TCP_KEEP* constants are platform-specific; set only those that exist.
    tcp_tuning = (
        ("TCP_KEEPIDLE", 10),
        ("TCP_KEEPINTVL", 5),
        ("TCP_KEEPCNT", 3),
    )
    for option_name, option_value in tcp_tuning:
        option = getattr(socket, option_name, None)
        if option is not None:
            sock.setsockopt(socket.IPPROTO_TCP, option, option_value)

def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly *n* bytes from *sock*, looping over short reads.

    Returns ``b""`` when *n* is zero or negative.

    Raises:
        EOFError: the peer closed the connection before *n* bytes arrived.
    """
    pieces = []
    remaining = n
    while remaining > 0:
        piece = sock.recv(remaining)
        if not piece:
            raise EOFError("Connection closed prematurely")
        pieces.append(piece)
        remaining -= len(piece)
    return b"".join(pieces)

def send_frame(sock: socket.socket, event: str, payload: Any) -> None:
    """Send one ``{"event": ..., "payload": ...}`` message over *sock*.

    The message is msgpack-encoded and then zlib-compressed. It is sent as a
    12-byte little-endian header of three unsigned 32-bit ints
    (``MAGIC``, uncompressed length, compressed length), followed by the
    compressed body, all in a single ``sendall`` call.
    """
    envelope = {"event": event, "payload": payload}
    packed = msgpack.packb(envelope, use_bin_type=True)
    body = zlib.compress(packed)
    frame = struct.pack("<III", MAGIC, len(packed), len(body)) + body
    sock.sendall(frame)


def recv_frame(sock: socket.socket) -> Dict[str, Any]:
    """Read one zlib-compressed JSON frame from *sock* and decode it.

    Wire format: a 12-byte little-endian header of three unsigned 32-bit ints
    ``(magic, raw_len, comp_len)``. It is followed by ``comp_len`` bytes of
    zlib data, which must inflate to exactly ``raw_len`` bytes of UTF-8 JSON.

    Raises:
        EOFError: the peer closed the connection mid-frame.
        RuntimeError: the magic number does not match ``MAGIC_JSON``, or the
            inflated body does not match the header's ``raw_len``.
    """
    header = recv_exact(sock, 12)
    magic, raw_len, comp_len = struct.unpack("<III", header)
    if magic != MAGIC_JSON:
        raise RuntimeError("Bad magic number – protocol mismatch")
    comp_bytes = recv_exact(sock, comp_len)

    # Inflate at most raw_len + 1 bytes. A corrupt or hostile frame then cannot
    # expand without bound, and the extra byte reveals an oversized body.
    inflater = zlib.decompressobj()
    raw_bytes = inflater.decompress(comp_bytes, raw_len + 1)
    if len(raw_bytes) != raw_len or not inflater.eof:
        raise RuntimeError(
            f"Frame length mismatch: header declares {raw_len} bytes of payload"
        )
    return json.loads(raw_bytes.decode())

def strip_tags(text: str) -> str:
    """Return just the plain words of *text*, with ``<...>`` markup removed.

    Tags are deleted first. Then each ASCII-letter/apostrophe run bounded by
    word boundaries is kept, and the runs are joined with single spaces.
    Digits and other characters are dropped.
    """
    tag_pattern = r"<[^>]+>"
    word_pattern = r"\b[a-zA-Z']+\b"
    untagged = re.sub(tag_pattern, "", text)
    return " ".join(re.findall(word_pattern, untagged))

def align_audio(audio_bytes: bytes, scene_text: str) -> Tuple:
    """Run word alignment for one scene's audio.

    This performs alignment only; no TTS happens here. *scene_text* is passed
    to ``align_words`` unchanged, so any ``<...>`` tags it contains are passed
    through as well.

    Args:
        audio_bytes: Raw audio bytes to align.
        scene_text: Transcript to align against *audio_bytes*.

    Returns:
        Whatever ``align_words`` returns. NOTE(review): the annotation says
        ``Tuple``, but the concrete type comes from ``aligner.align_words`` —
        confirm.
    """
    # NOTE(review): handle_connection calls align_words directly on
    # strip_tags(...) output, whereas this helper does not strip tags. Confirm
    # which behavior external callers expect.
    return align_words(audio_bytes, scene_text)

def generate_audio(scene: Dict[str, Any]) -> Tuple[bytes, str]:
    """Synthesize speech for one scene with the Orpheus engine.

    ``scene["txt"]`` is the prompt and is required. The other sampling options
    are read from the scene when present and fall back to the defaults below.

    Returns:
        ``(audio_bytes, audio_base64)`` as produced by ``synthesize_for_scene``.

    Raises:
        KeyError: if *scene* has no ``"txt"`` entry. Errors raised by
            ``synthesize_for_scene`` propagate unchanged.
    """
    audio_bytes, audio_base64 = synthesize_for_scene(
        prompt=scene["txt"],
        voice=scene.get("voice", "miko"),
        temperature=scene.get("temperature", 0.6),
        top_p=scene.get("top_p", 0.8),
        repetition_penalty=scene.get("repetition_penalty", 1.3),
        max_tokens=scene.get("max_tokens", 1200),
    )
    return audio_bytes, audio_base64

def _generate_all_audio(scenes: List[dict]) -> List[Dict[str, Any]]:
    """Stage 1: synthesize every scene that has text, in parallel.

    As each synthesis completes, its duration is computed. When the duration
    is positive it is pushed to the motion generator immediately, so that
    notification does not wait on the slower alignment stage.

    Any exception for a scene is logged and the scene is left out of the
    result. This covers synthesis, duration computation, notification, or a
    missing ``"sceneId"``.

    Returns:
        Dicts with keys ``"scene"``, ``"audio_bytes"`` and ``"audio_base64"``,
        in completion order.
    """
    generated_audio_data: List[Dict[str, Any]] = []
    print("")
    print("--- Generating Audios Thread ---")
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_scene = {
            executor.submit(generate_audio, scene): scene
            for scene in scenes
            if scene.get("txt")
        }
        for future in as_completed(future_to_scene):
            scene = future_to_scene[future]
            try:
                scene_id = scene["sceneId"]
                motion_index = scene.get("motionIndex", 0)
                audio_bytes, audio_base64 = future.result()
                print("")
                print(f'[ "Generated Audio {scene_id}, Motion: {motion_index}" ]')

                duration = calculate_duration_from_bytes(audio_bytes)
                if duration > 0:
                    update_motion_generator_duration(scene_id, motion_index, duration)

                generated_audio_data.append({
                    "scene": scene,
                    "audio_bytes": audio_bytes,
                    "audio_base64": audio_base64,
                })
            except Exception as e:
                # Use .get(): the failure may itself be a missing "sceneId".
                # Indexing here would raise again and escape handle_connection.
                print(f'[ Error during audio generation for {scene.get("sceneId")}: {e} ]')
    return generated_audio_data


def _align_all_audio(generated_audio_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Stage 2: word-align every synthesized scene, in parallel.

    Tags are stripped from each scene's text before alignment. Results are
    grouped by ``sceneId`` into ``{"sceneId": ..., "audioEvents": [...]}``
    entries. Scenes whose alignment fails are logged and omitted.
    """
    response_by_scene: Dict[str, Any] = {}
    print("")
    print("--- Word Alignments Thread ---")
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_data = {
            executor.submit(
                align_words, data["audio_bytes"], strip_tags(data["scene"]["txt"])
            ): data
            for data in generated_audio_data
        }
        for future in as_completed(future_to_data):
            data = future_to_data[future]
            scene = data["scene"]
            scene_id = scene["sceneId"]
            motion_index = scene.get("motionIndex", 0)
            try:
                alignment = future.result()
                print("")
                print(f'[ "Aligned {scene_id}, Motion: {motion_index}" ]')
                entry = response_by_scene.setdefault(
                    scene_id, {"sceneId": scene_id, "audioEvents": []}
                )
                entry["audioEvents"].append({
                    "motion": motion_index,
                    "audio_base64": data["audio_base64"],  # from stage 1
                    "alignment": alignment,                # from stage 2
                })
            except Exception as e:
                print(f"Error during alignment for scene {scene_id}: {e}")
    return response_by_scene


def handle_connection(sock: socket.socket) -> None:
    """Service one server connection until the peer closes it.

    Sends a ``hello`` frame announcing ``role=tts``, then loops over incoming
    frames. Only ``generate-voice`` events are handled. For those, each scene
    with text is synthesized (stage 1, which pushes durations to the motion
    generator as they arrive) and then word-aligned (stage 2). The grouped
    results are sent back as one ``voice-generated`` frame. Other events are
    logged and ignored.

    Returns normally when ``recv_frame`` raises ``EOFError``. Any other
    exception from the framing functions propagates to the caller.
    """
    send_frame(sock, "hello", {"role": "tts"})
    print("→ hello (role=tts) sent")

    while True:
        try:
            frame = recv_frame(sock)
        except EOFError:
            print('[ "Connection closed by the other side" ]')
            break

        event = frame.get("event")
        # Treat a missing or null payload as empty instead of crashing on .get().
        payload = frame.get("payload") or {}

        if event != "generate-voice":
            print(f"⚠️  unknown event {event}, ignored")
            continue

        # A null "scenes" value is treated as no scenes.
        scenes: List[dict] = payload.get("scenes") or []

        # Stage 1 notifies the motion generator ASAP; stage 2 does the slow alignment.
        generated_audio_data = _generate_all_audio(scenes)
        response_by_scene = _align_all_audio(generated_audio_data)

        # NOTE(review): if no scene survives both stages, no reply frame is
        # sent. Confirm the server does not block waiting for one.
        if response_by_scene:
            send_frame(sock, "voice-generated", list(response_by_scene.values()))
            print("")
            print(f"[ ← Audios ({len(response_by_scene)}) sent ]")

def main() -> None:
    """Set up the models once, then keep reconnecting to the server forever.

    Each connection is handled by ``handle_connection``. Whenever a connection
    ends, the worker prints a message where applicable, waits 5 seconds and
    reconnects. A connection can end because the peer closed it, an
    ``OSError`` occurred, or any other ``Exception`` was raised.
    ``KeyboardInterrupt``/``SystemExit`` are not caught and stop the process
    immediately.
    """
    import time

    # Setup the Orpheus TTS model on startup.
    setup_model()
    # Setup the aligner (does nothing for aeneas, but keeps pattern consistent)
    setup_aligner()

    while True:
        try:
            # timeout=60 bounds the connect; patch_socket_keepalive then
            # switches the socket to fully blocking reads.
            with socket.create_connection((HOST, PORT), timeout=60) as sock:
                patch_socket_keepalive(sock)
                print(f'["Connected to server at {HOST}:{PORT}"]')
                handle_connection(sock)
        except OSError as e:  # also covers ConnectionRefusedError
            print(f"Connection error: {e}, retrying in 5s")
        except Exception as e:
            print(f"Unhandled error: {e}, reconnecting in 5s")
        # Sleep here rather than in `finally`, so Ctrl+C exits without a 5 s stall.
        time.sleep(5)

if __name__ == "__main__":
    # Enter the reconnecting worker loop only when executed as a script.
    main()