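"""WhisperX demo Space: speech recognition with speaker diarization.

Pipeline: load audio -> transcribe with Whisper large-v3 -> align the output
for word-level timestamps -> diarize -> attach speaker labels to each segment.
"""
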
import spaces
import torch
import gradio as gr
import whisperx
from transformers.pipelines.audio_utils import ffmpeg_read
import tempfile
import gc
import os

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4  # reduce if low on GPU memory
COMPUTE_TYPE = "float32"  # change to "int8" if low on GPU memory
FILE_LIMIT_MB = 1000
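
# Assumption: this Space runs on ZeroGPU hardware, so the handler is wrapped
# with spaces.GPU to request a GPU for the duration of each call. Drop the
# decorator if the Space has dedicated GPU hardware instead.
@spaces.GPU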
def transcribe_audio(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        # Load audio. The gr.Audio component is configured with type="filepath",
        # so both microphone recordings and uploads arrive as a path on disk.
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe with the base Whisper model, forwarding the selected
        # task ("transcribe" or "translate") via load_model's task keyword
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE, task=task)
        result = model.transcribe(audio, batch_size=BATCH_SIZE)

        # Clear GPU memory
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)

        # Clear GPU memory again
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        # 3. Diarize audio
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ["HF_TOKEN"], device=DEVICE)
        diarize_segments = diarize_model(audio)

        # 4. Assign speaker labels
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Format output as "SPEAKER: text" lines
        output_text = ""
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            text = segment['text']
            output_text += f"{speaker}: {text}\n"

        return output_text

    except Exception as e:
        raise gr.Error(f"Error processing audio: {str(e)}")
    finally:
        # Final cleanup
        gc.collect()
        torch.cuda.empty_cache()


# Create Gradio interface
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)"
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")

        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )

    gr.Markdown("""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - GPU-accelerated processing

    ### Note:
    Processing may take a few moments depending on the audio length and system resources.
    """)

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )

demo.queue().launch(ssr_mode=False)
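
# Deployment notes (assumptions about this Space's configuration):
# - requirements.txt must provide whisperx, torch, gradio, transformers, and the spaces package.
# - An HF_TOKEN secret is required, since the diarization pipeline reads
#   os.environ["HF_TOKEN"] to download the gated pyannote diarization models;
#   the token's account must have accepted those models' terms.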