|
|
|
|
|
import socket |
|
|
import struct |
|
|
import json |
|
|
import msgpack |
|
|
|
|
|
import zlib |
|
|
import re |
|
|
from util import calculate_duration_from_bytes, update_motion_generator_duration,load_yaml |
|
|
|
|
|
from typing import Dict, Any, List, Tuple |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from aligner import align_words, setup_aligner |
|
|
|
|
|
|
|
|
|
|
|
# Connection settings are looked up once, at import time, from whatever
# mapping util.load_yaml() returns (presumably the project's YAML config).
config = load_yaml()
HOST, PORT = config["HOST"], config["PORT"]

print(f"Connecting to {HOST}:{PORT}")

# First header field of every frame this client sends (see send_frame).
MAGIC = 0x2333
|
|
|
|
|
def patch_socket_keepalive(sock: socket.socket) -> None:
    """Put *sock* in blocking mode and enable TCP keepalive probing.

    The timeout is removed entirely (``settimeout(None)``), so reads block
    until data arrives or the connection is dropped. SO_KEEPALIVE is always
    switched on; the finer-grained TCP_KEEP* options are applied only when
    the running platform's ``socket`` module exposes them.
    """
    sock.settimeout(None)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

    # (option name, value) pairs, applied in this order when available.
    tuning = (
        ("TCP_KEEPIDLE", 10),
        ("TCP_KEEPINTVL", 5),
        ("TCP_KEEPCNT", 3),
    )
    for option_name, value in tuning:
        option = getattr(socket, option_name, None)
        if option is not None:
            sock.setsockopt(socket.IPPROTO_TCP, option, value)
|
|
|
|
|
def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly *n* bytes from *sock* and return them.

    ``recv`` may return fewer bytes than requested, so it is called
    repeatedly until the full amount has arrived.

    Raises:
        EOFError: the peer closed the connection before *n* bytes arrived.
    """
    pieces = []
    remaining = n
    while remaining > 0:
        piece = sock.recv(remaining)
        if not piece:
            # An empty read means the other side shut the connection down.
            raise EOFError("Connection closed prematurely")
        pieces.append(piece)
        remaining -= len(piece)
    return b"".join(pieces)
|
|
|
|
|
def send_frame(sock: socket.socket, event: str, payload: Any) -> None:
    """Serialize one event and write it to *sock* as a single frame.

    The body is ``{"event": event, "payload": payload}`` packed with msgpack
    and then zlib-compressed. It is preceded by a 12-byte header of three
    little-endian unsigned 32-bit integers: MAGIC, the packed (uncompressed)
    length, and the compressed length.
    """
    message = {"event": event, "payload": payload}
    packed = msgpack.packb(message, use_bin_type=True)
    deflated = zlib.compress(packed)

    frame_header = struct.pack("<III", MAGIC, len(packed), len(deflated))
    sock.sendall(frame_header + deflated)
|
|
|
|
|
# First header field expected on every frame this client receives.
MAGIC_JSON = 0xDEADBEEF


def recv_frame(sock: socket.socket) -> Dict[str, Any]:
    """Read one frame from *sock* and return its decoded JSON body.

    A frame starts with a 12-byte header of three little-endian unsigned
    32-bit integers (magic, raw length, compressed length), followed by
    ``compressed length`` bytes of zlib-compressed JSON text. The raw-length
    field is read but not checked.

    Raises:
        EOFError: the connection closed mid-frame (raised by recv_exact).
        RuntimeError: the header's magic value is not MAGIC_JSON.
    """
    magic, _raw_len, comp_len = struct.unpack("<III", recv_exact(sock, 12))
    if magic != MAGIC_JSON:
        raise RuntimeError("Bad magic number – protocol mismatch")

    inflated = zlib.decompress(recv_exact(sock, comp_len))
    return json.loads(inflated.decode())
|
|
|
|
|
def strip_tags(text: str) -> str:
    """Reduce *text* to its plain words, separated by single spaces.

    Anything of the form ``<...>`` is removed first. From what is left, only
    runs of ASCII letters and apostrophes that sit on word boundaries are
    kept; digits, punctuation and other characters are dropped.
    """
    without_markup = re.sub(r"<[^>]+>", "", text)
    tokens = re.findall(r"\b[a-zA-Z']+\b", without_markup)
    joined = " ".join(tokens)
    return joined.strip()
|
|
|
|
|
def align_audio(audio_bytes: bytes, scene_text: str) -> Tuple:
    """Return the word alignment of *scene_text* against *audio_bytes*.

    Thin wrapper that forwards both arguments unchanged to
    ``aligner.align_words`` and returns its result as-is.

    NOTE(review): the text is passed through without tag stripping, whereas
    handle_connection calls ``strip_tags`` before ``align_words`` — confirm
    which form the aligner expects before using this helper.
    """
    return align_words(audio_bytes, scene_text)
|
|
|
|
|
def generate_audio(
    scene: Dict[str, Any], dummy_path: str = "output_0.wav"
) -> Tuple[bytes, str]:
    """Return placeholder audio for *scene* as ``(raw_bytes, base64_text)``.

    Previously both candidate implementations in this function were wrapped
    in string literals, so the final ``return`` always raised ``NameError``.
    This version re-enables the placeholder path, which only needs the
    standard library: it reads the file at *dummy_path* and returns its bytes
    together with their base64 encoding (as UTF-8 text).

    *scene* is currently unused.

    TODO(review): the real synthesis call, ``synthesize_for_scene(prompt=
    scene["txt"], voice=..., temperature=..., top_p=..., repetition_penalty=
    ..., max_tokens=...)``, is still disabled because that name is not
    imported in this file. Re-enable it once the backend is available.

    Args:
        scene: scene description received from the server.
        dummy_path: location of the placeholder audio file.

    Returns:
        A ``(audio_bytes, audio_base64)`` tuple.

    Raises:
        FileNotFoundError: the placeholder file does not exist.
    """
    import base64  # function-scope: not among the imports at the top of the file

    try:
        with open(dummy_path, "rb") as f:
            audio_bytes = f.read()
    except FileNotFoundError:
        raise FileNotFoundError(f"Dummy file {dummy_path!r} not found.") from None

    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64
|
|
|
|
|
def _generate_scene_audio(scenes: List[dict]) -> List[Dict[str, Any]]:
    """Run generate_audio in parallel for every scene that has text.

    Returns one ``{"scene", "audio_bytes", "audio_base64"}`` dict per scene
    that succeeded, in completion order. Failures are logged and skipped.
    """
    generated: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_scene = {
            executor.submit(generate_audio, scene): scene
            for scene in scenes
            if scene.get("txt")
        }

        for future in as_completed(future_to_scene):
            scene = future_to_scene[future]
            try:
                scene_id = scene["sceneId"]
                motion_index = scene.get("motionIndex", 0)

                audio_bytes, audio_base64 = future.result()
                print("")
                print(f'[ "Generated Audio {scene_id}, Motion: {motion_index}" ]')

                duration = calculate_duration_from_bytes(audio_bytes)
                if duration > 0:
                    update_motion_generator_duration(scene_id, motion_index, duration)

                generated.append({
                    "scene": scene,
                    "audio_bytes": audio_bytes,
                    "audio_base64": audio_base64,
                })
            except Exception as e:
                # Use .get(): the failure may itself be a missing "sceneId",
                # and the handler must not raise while reporting it.
                print(f"Error during audio generation for {scene.get('sceneId')}: {e}")
    return generated


def _align_generated_audio(generated: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Run align_words in parallel and group the results by scene id.

    Returns a mapping ``scene_id -> {"sceneId": ..., "audioEvents": [...]}``
    where each audio event holds the motion index, the base64 audio and its
    alignment. Items whose alignment fails are logged and left out.
    """
    response_by_scene: Dict[str, Any] = {}
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_data = {
            executor.submit(
                align_words, data["audio_bytes"], strip_tags(data["scene"]["txt"])
            ): data
            for data in generated
        }

        for future in as_completed(future_to_data):
            data = future_to_data[future]
            scene = data["scene"]
            # Safe to index: only scenes with a "sceneId" reach this stage.
            scene_id = scene["sceneId"]
            motion_index = scene.get("motionIndex", 0)

            try:
                alignment = future.result()
                print("")
                print(f'[ "Aligned {scene_id}, Motion: {motion_index}" ]')

                voice_audio = {
                    "motion": motion_index,
                    "audio_base64": data["audio_base64"],
                    "alignment": alignment,
                }
                entry = response_by_scene.setdefault(
                    scene_id, {"sceneId": scene_id, "audioEvents": []}
                )
                entry["audioEvents"].append(voice_audio)
            except Exception as e:
                print(f"Error during alignment for scene {scene_id}: {e}")
    return response_by_scene


def handle_connection(sock: socket.socket) -> None:
    """Serve "generate-voice" requests on *sock* until the peer disconnects.

    Sends a ``hello`` frame announcing the ``tts`` role, then loops: read a
    frame, ignore anything that is not ``generate-voice``, generate audio for
    the requested scenes, align it, and reply with one ``voice-generated``
    frame. No reply is sent when nothing succeeded. Returns when recv_frame
    raises EOFError; any other exception propagates to the caller.
    """
    send_frame(sock, "hello", {"role": "tts"})
    print("→ hello (role=tts) sent")

    while True:
        try:
            frame = recv_frame(sock)
        except EOFError:
            print('[ "Connection closed by the other side" ]')
            break

        event = frame.get("event")
        if event != "generate-voice":
            print(f"⚠️ unknown event {event}, ignored")
            continue

        # A frame without a payload (or with a null one) is treated as an
        # empty request instead of crashing the whole connection.
        payload = frame.get("payload") or {}
        scenes: List[dict] = payload.get("scenes") or []

        print("")
        print("--- Generating Audios Thread ---")
        generated_audio_data = _generate_scene_audio(scenes)

        print("")
        print("--- Word Alignments Thread ---")
        response_by_scene = _align_generated_audio(generated_audio_data)

        if response_by_scene:
            send_frame(sock, "voice-generated", list(response_by_scene.values()))
            print("")
            print(f"[ ← Audios ({len(response_by_scene)}) sent ]")
|
|
|
|
|
def main() -> None:
    """Initialise the aligner, then connect to the server in an endless loop.

    Each iteration opens a connection to ``(HOST, PORT)``, configures
    keepalive, and hands the socket to handle_connection. Whether the session
    ends cleanly or with an error, the loop waits five seconds and reconnects.

    The wait sits after the ``try`` statement rather than in a ``finally``
    clause, so ``KeyboardInterrupt`` and ``SystemExit`` leave immediately
    instead of being held back by the sleep.
    """
    import time  # function-scope: not among the imports at the top of the file

    setup_aligner()

    while True:
        try:
            with socket.create_connection((HOST, PORT), timeout=60) as sock:
                patch_socket_keepalive(sock)
                print(f'["Connected to server at {HOST}:{PORT}"]')
                handle_connection(sock)
        except OSError as e:
            # ConnectionRefusedError is a subclass of OSError.
            print(f"Connection error: {e}, retrying in 5s")
        except Exception as e:
            print(f"Unhandled error: {e}, reconnecting in 5s")

        time.sleep(5)


if __name__ == "__main__":
    main()
|
|
|
|
|
|