# miko-tts / temp.py
import base64
import json
import msgpack
import os
import re
import socket
import struct
import time
import zlib
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

from aligner import align_words, setup_aligner
from util import calculate_duration_from_bytes, update_motion_generator_duration, load_yaml

config = load_yaml()
HOST = config["HOST"]
PORT = config["PORT"]
print(f"Connecting to {HOST}:{PORT}")
# Frame magic numbers. The protocol is asymmetric: frames we send are
# msgpack-encoded, frames we receive are JSON-encoded, each with its own magic.
MAGIC = 0x2333           # outgoing frames (msgpack + zlib)
MAGIC_JSON = 0xDEADBEEF  # incoming frames (JSON + zlib)
def patch_socket_keepalive(sock: socket.socket) -> None:
    """Set keepalive + no recv timeout to prevent halts on idle connections."""
    sock.settimeout(None)  # Never time out on recv
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    # Platform-specific tuning (these constants exist only on some platforms)
    if hasattr(socket, "TCP_KEEPIDLE"):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 10)
    if hasattr(socket, "TCP_KEEPINTVL"):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5)
    if hasattr(socket, "TCP_KEEPCNT"):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
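# Usage sketch: apply right after connecting, before any blocking recv
# (this mirrors what main() does below):
#   sock = socket.create_connection((HOST, PORT), timeout=60)
#   patch_socket_keepalive(sock)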
def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes; recv() may return fewer, so loop until done."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("Connection closed prematurely")
        buf.extend(chunk)
    return bytes(buf)
def send_frame(sock: socket.socket, event: str, payload: Any) -> None:
    # Outgoing frames use msgpack (not JSON) + zlib compression.
    raw = msgpack.packb({"event": event, "payload": payload}, use_bin_type=True)
    comp = zlib.compress(raw)
    # Header layout: <MAGIC><raw_len><comp_len>, three little-endian uint32s.
    header = struct.pack("<III", MAGIC, len(raw), len(comp))
    sock.sendall(header + comp)
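# Wire-format sketch (illustrative; actual bytes depend on the payload):
#   raw    = msgpack.packb({"event": "hello", "payload": {"role": "tts"}}, use_bin_type=True)
#   comp   = zlib.compress(raw)
#   header = struct.pack("<III", 0x2333, len(raw), len(comp))  # 12-byte header
#   frame  = header + comp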
def recv_frame(sock: socket.socket) -> Dict[str, Any]:
    header = recv_exact(sock, 12)
    magic, _raw_len, comp_len = struct.unpack("<III", header)
    if magic != MAGIC_JSON:
        raise RuntimeError("Bad magic number - protocol mismatch")
    comp_bytes = recv_exact(sock, comp_len)
    raw_bytes = zlib.decompress(comp_bytes)
    return json.loads(raw_bytes.decode())
def strip_tags(text: str) -> str:
    """Remove <tag> markup and keep only the spoken words."""
    no_tags = re.sub(r"<[^>]+>", "", text)
    words = re.findall(r"\b[a-zA-Z']+\b", no_tags)
    return " ".join(words).strip()
def align_audio(audio_bytes: bytes, scene_text: str) -> Any:
    """
    Run word alignment for a single scene's audio.
    Kept as a thin wrapper around `align_words`; Stage 2 of the handler
    below calls `align_words` directly, so this helper is currently unused.
    """
    return align_words(audio_bytes, scene_text)
def generate_audio(scene: Dict[str, Any]) -> Tuple[bytes, str]:
    """
    Produce audio for one scene and return (raw bytes, base64 string).

    In a real scenario this would call the TTS engine, e.g.:
        audio_bytes, audio_base64 = synthesize_for_scene(
            prompt=scene["txt"],
            voice=scene.get("voice", "miko"),
            temperature=scene.get("temperature", 0.6),
            top_p=scene.get("top_p", 0.8),
            repetition_penalty=scene.get("repetition_penalty", 1.3),
            max_tokens=scene.get("max_tokens", 1200),
        )
    For now, a pre-rendered dummy WAV file is read instead.
    """
    dummy_path = "output_0.wav"
    if not os.path.exists(dummy_path):
        raise FileNotFoundError("Dummy file 'output_0.wav' not found.")
    with open(dummy_path, "rb") as f:
        audio_bytes = f.read()
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64
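# Usage sketch (assumes output_0.wav sits next to this script):
#   audio_bytes, audio_b64 = generate_audio({"sceneId": "s1", "txt": "Hello"})
#   print(calculate_duration_from_bytes(audio_bytes))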
def handle_connection(sock: socket.socket) -> None:
    send_frame(sock, "hello", {"role": "tts"})
    print("→ hello (role=tts) sent")
    while True:
        try:
            frame = recv_frame(sock)
        except EOFError:
            print('[ "Connection closed by the other side" ]')
            break
        event = frame.get("event")
        payload = frame.get("payload")
        if event != "generate-voice":
            print(f"⚠️ unknown event {event}, ignored")
            continue
        scenes: List[dict] = payload.get("scenes", [])
        # --- STAGE 1: FAST Audio Generation & Duration Notification ---
        # The goal here is to get durations to the motion generator ASAP.
        generated_audio_data = []
        print("")
        print("--- Generating Audios Thread ---")
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Submit all the fast audio generation tasks
            future_to_scene = {
                executor.submit(generate_audio, scene): scene
                for scene in scenes if scene.get("txt")
            }
            # As each fast audio generation task completes...
            for future in as_completed(future_to_scene):
                scene = future_to_scene[future]
                try:
                    scene_id = scene["sceneId"]
                    motion_index = scene.get("motionIndex", 0)
                    # 1. Get the generated audio
                    audio_bytes, audio_base64 = future.result()
                    print("")
                    print(f'[ "Generated Audio {scene_id}, Motion: {motion_index}" ]')
                    # 2. Calculate the duration instantly
                    duration = calculate_duration_from_bytes(audio_bytes)
                    # 3. Notify the motion generator immediately
                    if duration > 0:
                        update_motion_generator_duration(scene_id, motion_index, duration)
                    # 4. Store the results for the next (slow) stage
                    generated_audio_data.append({
                        "scene": scene,
                        "audio_bytes": audio_bytes,
                        "audio_base64": audio_base64,
                    })
                except Exception as e:
                    # Use .get() so a scene missing "sceneId" cannot crash the handler
                    print(f"Error during audio generation for {scene.get('sceneId', '?')}: {e}")
        # --- STAGE 2: SLOW Word Alignment in Parallel ---
        # Now that all notifications are sent, we can perform the slow alignment work.
        response_by_scene: Dict[str, Any] = {}
        print("")
        print("--- Word Alignments Thread ---")
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Use the data from Stage 1 to submit the slow alignment tasks.
            # We call `align_words` directly (the `align_audio` wrapper is unused here).
            future_to_data = {
                executor.submit(align_words, data["audio_bytes"], strip_tags(data["scene"]["txt"])): data
                for data in generated_audio_data
            }
            # As each slow alignment task completes...
            for future in as_completed(future_to_data):
                data = future_to_data[future]
                scene = data["scene"]
                scene_id = scene["sceneId"]
                motion_index = scene.get("motionIndex", 0)
                try:
                    # 1. Get the alignment result
                    alignment = future.result()
                    print("")
                    print(f'[ "Aligned {scene_id}, Motion: {motion_index}" ]')
                    # 2. Build the final response object with all the data
                    voice_audio = {
                        "motion": motion_index,
                        "audio_base64": data["audio_base64"],  # From Stage 1
                        "alignment": alignment,                # From Stage 2
                    }
                    if scene_id not in response_by_scene:
                        response_by_scene[scene_id] = {"sceneId": scene_id, "audioEvents": []}
                    response_by_scene[scene_id]["audioEvents"].append(voice_audio)
                except Exception as e:
                    print(f"Error during alignment for scene {scene_id}: {e}")
        if response_by_scene:
            send_frame(sock, "voice-generated", list(response_by_scene.values()))
            print("")
            print(f"[ ← Audios ({len(response_by_scene)}) sent ]")
def main() -> None:
    # Setup the Orpheus TTS model on startup.
    # setup_model()
    # Setup the aligner (a no-op for aeneas, but keeps the pattern consistent)
    setup_aligner()
    while True:
        try:
            with socket.create_connection((HOST, PORT), timeout=60) as sock:
                patch_socket_keepalive(sock)
                print(f'["Connected to server at {HOST}:{PORT}"]')
                handle_connection(sock)
        except (ConnectionRefusedError, OSError) as e:
            print(f"Connection error: {e}, retrying in 5s")
        except Exception as e:
            print(f"Unhandled error: {e}, reconnecting in 5s")
        finally:
            time.sleep(5)


if __name__ == "__main__":
    main()