PicoAudio2 / app.py
wsntxxn's picture
Update app.py
67e83bd verified
from textwrap import indent
import spaces
import gradio as gr
import os
import json
import torch
import soundfile as sf
import numpy as np
from pathlib import Path
from transformers import AutoModel
from utils.llm_xiapi import get_time_info
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE: trust_remote_code=True executes the model code shipped in the
# "rookie9/PicoAudio2" Hub repository.
model = AutoModel.from_pretrained("rookie9/PicoAudio2", trust_remote_code=True)
model = model.to(device)
def is_tdc_format_valid(tdc_str):
    """Return True if *tdc_str* follows the internal TDC format.

    The expected shape is ``event1__s1-e1_s2-e2--event2__s1-e1``: events are
    separated by ``--``, an event name is separated from its time segments by
    a single ``__``, segments are separated by ``_`` and every segment is
    ``start-end`` where both endpoints parse as numbers.

    Args:
        tdc_str: Candidate TDC string.

    Returns:
        True when every event and every segment is well-formed; False
        otherwise, including for empty or non-string input.
    """
    if not isinstance(tdc_str, str) or not tdc_str:
        return False
    try:
        for event_onset in tdc_str.split('--'):
            # Tuple unpacking raises ValueError unless there is exactly one
            # "__" separator between the event name and its segments.
            _event, instance = event_onset.split('__')
            for start_end in instance.split('_'):
                start, end = start_end.split('-')
                # Counting the pieces is not enough: "dog__a-b" or "dog__3.0-"
                # split cleanly but are not timestamps, so parse both ends.
                float(start)
                float(end)
    except ValueError:
        return False
    return True
# Event handlers
def clear_json():
    """Return the text of an empty JSON object (used to reset the JSON editor)."""
    empty_object = "{}"
    return empty_object
def convert_tdc_to_tcc(b_str):
    """Build a coarse caption from an internal-format TDC string.

    Each ``--``-separated chunk that contains ``__`` contributes the text
    before its first ``__``, with underscores turned into spaces. Chunks
    without ``__`` are ignored. The collected names are joined with " and ".
    """
    event_names = [
        chunk.split('__', 1)[0].replace('_', ' ')
        for chunk in b_str.split('--')
        if '__' in chunk
    ]
    return ' and '.join(event_names)
def json_to_tdc(json_str):
    """Turn the event-timestamp JSON into the internal TDC string.

    The JSON is expected to map an event description to a list of
    ``[start, end]`` pairs. The result has the shape
    ``event1__s1-e1_s2-e2--event2__s1-e1``.

    Returns:
        The TDC string; ``""`` for empty/blank input or when no usable
        segment is found; or a string starting with ``"Error: "`` when the
        text is not valid JSON or is not a JSON object.
    """
    if not (json_str and json_str.strip()):
        return ""

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as e:
        return f"Error: Invalid JSON format - {str(e)}"

    if not isinstance(parsed, dict):
        return "Error: JSON should be a dictionary/object"

    encoded_events = []
    for label, spans in parsed.items():
        # Values that are not lists carry no segments and are skipped.
        if not isinstance(spans, list):
            continue
        # Keep only list segments with at least two entries whose first two
        # entries are both non-null; extra entries are ignored.
        ranges = [
            f"{span[0]}-{span[1]}"
            for span in spans
            if isinstance(span, list)
            and len(span) >= 2
            and span[0] is not None
            and span[1] is not None
        ]
        if not ranges:
            continue
        # Spaces in the event name become underscores in the internal format.
        label_clean = label.strip().replace(' ', '_')
        encoded_events.append(f"{label_clean}__{'_'.join(ranges)}")

    return '--'.join(encoded_events)
def generate_audio(tcc, json_data, length, time_ctrl):
    """Convert the UI's JSON timestamps to the internal TDC string and run inference.

    Returns whatever ``infer`` returns for the converted inputs.
    """
    return infer(tcc, json_to_tdc(json_data), length, time_ctrl)
@spaces.GPU(duration=60)
def infer(input_text, input_onset, input_length, time_control):
    """Generate audio for a caption and optional event timestamps.

    Args:
        input_text: Coarse caption (TCC). When empty it is derived from a
            valid ``input_onset``, otherwise it defaults to "a dog barks".
        input_onset: Internal-format TDC string. A non-empty value that fails
            ``is_tdc_format_valid`` is replaced by "random".
        input_length: Requested length as entered in the UI.
        time_control: When true and either onset or length is missing, both
            are taken from the JSON returned by ``get_time_info``. When
            false, missing values default to "random" and "10.0".

    Returns:
        A tuple ``(wav_path, onset)``: the path of the written WAV file and
        the onset value that was actually passed to the model, as a string.
    """
    # Local import keeps this block self-contained; tempfile is stdlib.
    import tempfile

    if not input_text and input_onset and is_tdc_format_valid(input_onset):
        input_text = convert_tdc_to_tcc(input_onset)
    elif not input_text:
        input_text = "a dog barks"

    if input_onset and not is_tdc_format_valid(input_onset):
        input_onset = "random"

    if time_control:
        if not input_onset or not input_length:
            # Both values are replaced by the helper's answer, even if one of
            # them was supplied by the caller.
            input_json = json.loads(get_time_info(input_text))
            input_onset, input_length = input_json["onset"], input_json["length"]
    else:
        input_onset = input_onset if input_onset else "random"
        input_length = input_length if input_length else "10.0"

    content = {
        "caption": input_text,
        "onset": input_onset,
        "length": input_length
    }

    with torch.no_grad():
        waveform = model(content)

    # Write to a unique file instead of a fixed "output.wav" in the working
    # directory: a shared name can be overwritten by an overlapping request
    # and requires the working directory to be writable. The file is closed
    # first and left on disk (delete=False) so it can be read after return.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        output_wav = tmp_file.name

    sf.write(
        output_wav,
        waveform[0, 0].cpu().numpy(),
        samplerate=24000,
    )

    return output_wav, str(input_onset)
# Gradio UI: inputs on the left, outputs on the right, examples at the bottom.
# Emoji in labels/Markdown were mis-decoded (mojibake) and are restored here.
with gr.Blocks(title="PicoAudio2 Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎵 PicoAudio2 Online Inference")

    # Blank lines separate the raw HTML wrappers from the Markdown inside
    # them; without a blank line after an opening <div>, the following
    # Markdown is treated as part of the HTML block and is not rendered.
    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 📖 Definition

<div style="background-color: #f0f8ff; padding: 15px; border-radius: 8px; margin: 10px 0;">

**TCC (Temporal Coarse Caption)**

A brief text description for the overall audio scene.

> *Example*: `a dog barks`

**TDC (Temporal Detailed Caption)**

**Event descriptions with timestamps**. It allows precise temporal control over when events happen in the generated audio.

> *Example*: See the format below.

</div>

## 📊 TDC Event Timestamp JSON (Optional)

<div style="background-color: #fffef0; padding: 12px; border-radius: 6px; border-left: 4px solid #ffa500;">

**💡 JSON Format:**

- Each **event description** is a key in the JSON object
- The value is an **array of [start, end] timestamp pairs** (in seconds)

**Example:**

<pre style="background-color: #f5f5f5; padding: 8px; border-radius: 4px; margin: 5px 0; font-size: 13px;">
{
"a dog barks": [
[3.0, 4.0],
[6.0, 7.0]
],
"a man speaks": [
[5.0, 6.0]
]
}</pre>

It means the event `a dog barks` happens from 3.0 to 4.0 seconds, and the event `a man speaks` happens from 5.0 to 6.0 seconds.

</div>

---

</div>
""")

    with gr.Row():
        # Left column: caption, timestamp JSON, length and controls.
        with gr.Column():
            tcc_input = gr.Textbox(
                label="🎯 TCC (Temporal Coarse Caption) - Required",
                value="a dog barks and a man speaks",
                placeholder="e.g., a dog barks and a man speaks",
                lines=2
            )
            event_json = gr.Code(
                label="📋 TDC (Event Timestamp JSON) - Optional",
                value="""{
"a dog barks": [
[3.0, 4.0],
[6.0, 7.0]
],
"a man speaks": [
[8.0, 9.5]
]
}""",
                language="json",
                lines=10,
                interactive=True,
            )
            clear_btn = gr.Button("🗑️ Clear JSON", size="sm")
            gr.Markdown("---")
            with gr.Row():
                length_input = gr.Textbox(
                    label="⏱️ Length (seconds)",
                    value="10.0",
                    placeholder="e.g., 10.0 (optional but recommended)",
                    scale=2)
                time_control = gr.Checkbox(
                    label="⚙️ Enable Time Control",
                    value=True,
                    scale=1,
                )
            generate_btn = gr.Button("🎵 Generate Audio", variant="primary", size="lg")

        # Right column: generated audio and the TDC string that was used.
        with gr.Column():
            audio_output = gr.Audio(label="🔊 Generated Audio")
            tdc_used = gr.Textbox(label="📋 Final TDC Used (Internal Format)", lines=3)
            # NOTE(review): the original nesting of this panel was lost with
            # the file's indentation; it is placed under the outputs here --
            # confirm against the intended layout.
            gr.Markdown("""
<div style="text-align: left; padding: 10px;">

---

## 📝 Input Requirements

<div style="background-color: #fff5e6; padding: 15px; border-radius: 8px; margin: 10px 0;">

1. **TCC** is **required** for audio generation.
2. **TDC (JSON)** is **optional** for precise temporal control of events.
3. **Length** (in seconds) is optional, but recommended for temporal control. Defaults to 10.0 seconds.
4. **Enable Time Control**: Tick to use TDC and length for precise event timing.

</div>

---

## 💡 Notes

<div style="background-color: #f0fff0; padding: 15px; border-radius: 8px; margin: 10px 0;">

1. Currently events with overlapped timestamps will not get good results, so we recommend generating audio **without temporal overlaps between events**.
2. If TDC format is incorrect or length is missing, the model will generate audio **without precise temporal control**.
3. For general audio generation without precise timing, you can leave the JSON empty.
4. You may leave TDC blank to let the LLM generate timestamps automatically (subject to API quota).

</div>

</div>
""")

    # Reset the JSON editor to an empty object.
    clear_btn.click(
        fn=clear_json,
        inputs=[],
        outputs=[event_json]
    )

    # Run generation with the current inputs.
    generate_btn.click(
        fn=generate_audio,
        inputs=[tcc_input, event_json, length_input, time_control],
        outputs=[audio_output, tdc_used]
    )

    # Examples
    gr.Markdown("## 🎯 Quick Examples")
    gr.Examples(
        examples=[
            [
                "a dog barks",
                """{
"a dog barks": [
[3.0, 4.0],
[6.0, 7.0]
]
}""",
                "8.0",
                True
            ],
            [
                "door closes then car engine starts",
                """{
"door closes": [
[1.0, 1.5]
],
"car engine starts": [
[2.0, 7.0]
]
}""",
                "8.0",
                True
            ],
            [
                "birds chirping and water flowing",
                """{
"birds chirping": [
[0.0, 5.0]
],
"water flowing": [
[6.0, 9.8]
]
}""",
                "10.0",
                True
            ],
            [
                "heavy rain is falling with thunder",
                "",
                "15.0",
                False
            ],
            [
                "a gun shoots twice then a man speaks",
                "",
                "",
                True
            ]
        ],
        inputs=[tcc_input, event_json, length_input, time_control],
        outputs=[audio_output, tdc_used],
        fn=generate_audio,
        cache_examples=False,
        label="Click examples below to try"
    )

if __name__ == "__main__":
    demo.launch()