import spaces
import gradio as gr
import os
import json
import torch
import soundfile as sf
import numpy as np
from pathlib import Path
from transformers import AutoModel
from utils.llm_xiapi import get_time_info

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("rookie9/PicoAudio2", trust_remote_code=True).to(device)


def is_tdc_format_valid(tdc_str):
    """Check that a string parses as the internal time-controlled format:
    event1__s1-e1_s2-e2--event2__s1-e1 (validation happens via the unpacking
    errors raised on malformed input)."""
    try:
        for event_onset in tdc_str.split('--'):
            event, instance = event_onset.split('__')
            for start_end in instance.split('_'):
                start, end = start_end.split('-')
        return True
    except Exception:
        return False


# Event handlers
def clear_json():
    return "{}"


def convert_tdc_to_tcc(tdc_str):
    """Derive a plain caption from a time-controlled string,
    e.g. 'a_dog_barks__3.0-4.0' -> 'a dog barks'."""
    events = tdc_str.split('--')
    names = []
    for e in events:
        if '__' not in e:
            continue
        name, _ = e.split('__', 1)
        name = name.replace('_', ' ')
        names.append(name)
    return ' and '.join(names)


def json_to_tdc(json_str):
    """Convert JSON data to the internal format: event1__s1-e1_s2-e2--event2__s1-e1"""
    if not json_str or not json_str.strip():
        return ""
    try:
        events_dict = json.loads(json_str)
    except json.JSONDecodeError as e:
        return f"Error: Invalid JSON format - {str(e)}"
    if not isinstance(events_dict, dict):
        return "Error: JSON should be a dictionary/object"

    # Build the internal format string
    result = []
    for event_name, time_segments in events_dict.items():
        if not isinstance(time_segments, list):
            continue
        valid_segments = []
        for segment in time_segments:
            if isinstance(segment, list) and len(segment) >= 2:
                start, end = segment[0], segment[1]
                if start is not None and end is not None:
                    valid_segments.append(f"{start}-{end}")
        if valid_segments:
            # event_name__time1_time2, with spaces in the name replaced by underscores
            event_name_clean = event_name.strip().replace(' ', '_')
            times_str = '_'.join(valid_segments)
            result.append(f"{event_name_clean}__{times_str}")
    return '--'.join(result)


def generate_audio(tcc, json_data, length, time_ctrl):
    tdc = json_to_tdc(json_data)
    return infer(tcc, tdc, length, time_ctrl)


@spaces.GPU(duration=60)
def infer(input_text, input_onset, input_length, time_control):
    # Fall back to a caption derived from the onset string, then to a default.
    if not input_text and input_onset and is_tdc_format_valid(input_onset):
        input_text = convert_tdc_to_tcc(input_onset)
    elif not input_text:
        input_text = "a dog barks"

    # Discard malformed onset strings.
    if input_onset and not is_tdc_format_valid(input_onset):
        input_onset = "random"

    if time_control:
        # With timing control on, let the LLM fill in any missing timing info.
        if not input_onset or not input_length:
            input_json = json.loads(get_time_info(input_text))
            input_onset, input_length = input_json["onset"], input_json["length"]
    else:
        input_onset = input_onset if input_onset else "random"
        input_length = input_length if input_length else "10.0"

    content = {
        "caption": input_text,
        "onset": input_onset,
        "length": input_length,
    }
    with torch.no_grad():
        waveform = model(content)

    output_wav = "output.wav"
    sf.write(
        output_wav,
        waveform[0, 0].cpu().numpy(),
        samplerate=24000,
    )
    return output_wav, str(input_onset)


with gr.Blocks(title="PicoAudio2 Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎵 PicoAudio2 Online Inference")
    gr.Markdown("""
```json
{
    "a dog barks": [
        [3.0, 4.0],
        [6.0, 7.0]
    ],
    "a man speaks": [
        [5.0, 6.0]
    ]
}
```
This means the event `a dog barks` occurs from 3.0 to 4.0 seconds and again from 6.0 to 7.0 seconds, and the event `a man speaks` occurs from 5.0 to 6.0 seconds.
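For reference, `json_to_tdc` converts the JSON above into the internal onset string passed to the model:

```
a_dog_barks__3.0-4.0_6.0-7.0--a_man_speaks__5.0-6.0
```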