import gradio as gr
import os
from pathlib import Path

from vogent_turn.inference import TurnDetector
import soundfile as sf
import numpy as np


def get_detector():
    """Lazy load the detector to avoid initialization during import."""
    detector = TurnDetector(compile_model=False, warmup=False)
    return detector


# Initialize the turn detector
detector = get_detector()


# Get all preset names from the samples folder
def get_presets():
    samples_dir = Path("samples")
    if not samples_dir.exists():
        return []
    presets = [d.name for d in samples_dir.iterdir() if d.is_dir()]
    return sorted(presets)


# Load preset data
def load_preset(preset_name):
    """Load the audio and text files from the selected preset."""
    if not preset_name:
        return None, "", ""

    preset_dir = Path("samples") / preset_name

    # Load audio
    audio_path = preset_dir / "audio.wav"
    audio_file = str(audio_path) if audio_path.exists() else None

    # Load text files
    prev_text = ""
    curr_text = ""

    prev_path = preset_dir / "prev.txt"
    if prev_path.exists():
        prev_text = prev_path.read_text().strip()

    text_path = preset_dir / "text.txt"
    if text_path.exists():
        curr_text = text_path.read_text().strip()

    return audio_file, prev_text, curr_text


# Run inference
def run_inference(audio_file, prev_text, curr_text):
    """Run turn detection inference."""
    if audio_file is None:
        return "Error: No audio file provided"

    if not curr_text:
        return "Error: A transcript of the audio must be provided"

    if prev_text is None:
        prev_text = ""

    try:
        # Load audio file
        audio, sr = sf.read(audio_file)

        # Convert to mono if stereo
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

        # Ensure audio is float32
        audio = audio.astype(np.float32)

        # Run prediction with context
        result = detector.predict(
            audio,
            prev_line=prev_text if prev_text else None,
            curr_line=curr_text if curr_text else None,
            return_probs=True,
            sample_rate=sr,
        )

        # Format output
        is_endpoint = result['is_endpoint']
        prob_endpoint = result['prob_endpoint']
        prob_continue = result['prob_continue']

        output = f"""
**Turn Detection Result:**

- **Prediction:** {'Turn Complete (Endpoint)' if is_endpoint else 'Turn Incomplete (Continue)'}
- **Probability of Endpoint:** {prob_endpoint:.4f}
- **Probability of Continue:** {prob_continue:.4f}
"""
        return output

    except Exception as e:
        return f"Error during inference: {str(e)}"


# Get the default preset and load its data
presets = get_presets()
default_preset = presets[0] if presets else None
default_audio, default_prev_text, default_curr_text = (
    load_preset(default_preset) if default_preset else (None, "", "")
)

# Create the Gradio interface
with gr.Blocks(title="Vogent Turn Demo") as demo:
    gr.Markdown("# Vogent Turn Demo")
    gr.Markdown("Multimodal turn detection using audio and text context")
    gr.Markdown("""
    [GitHub](https://github.com/vogent/vogent-turn) |
    [Technical Report](https://blog.vogent.ai/posts/voturn-80m-state-of-the-art-turn-detection-for-voice-agents) |
    [Model Weights](http://huggingface.co/vogent/Vogent-Turn-80M)
    """)

    with gr.Row():
        with gr.Column():
            # Preset selector
            preset_dropdown = gr.Dropdown(
                choices=presets,
                label="Preset Samples",
                info="Select a preset to auto-fill the fields",
                value=default_preset,
            )

            # Input fields
            prev_text_input = gr.Textbox(
                label="Previous Line (The previous line spoken in the dialog)",
                placeholder="Enter the previous line of dialog...",
                lines=2,
                value=default_prev_text,
            )

            curr_text_input = gr.Textbox(
                label="Current Line (The transcript of the audio below; omit punctuation)",
                placeholder="Enter the current line being spoken...",
                lines=2,
                value=default_curr_text,
            )

            audio_input = gr.Audio(
                label="Audio",
                type="filepath",
                value=default_audio,
            )

            # Inference button
            inference_btn = gr.Button("Run Inference", variant="primary")

        with gr.Column():
            # Output
            output_text = gr.Markdown(label="Results")

    # Connect the preset dropdown to the load function
    preset_dropdown.change(
        fn=load_preset,
        inputs=[preset_dropdown],
        outputs=[audio_input, prev_text_input, curr_text_input],
    )

    # Connect the inference button
    inference_btn.click(
        fn=run_inference,
        inputs=[audio_input, prev_text_input, curr_text_input],
        outputs=[output_text],
    )


if __name__ == "__main__":
    demo.launch()