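"""Gradio app: speaker-diarized transcription with per-speaker sentiment analysis.

Audio is segmented by speaker (pyannote), transcribed with word-level timestamps
(wav2vec2 + n-gram LM decoding), re-punctuated (rpunct), and the selected
speaker's utterances are scored with a DistilBERT SST-2 sentiment classifier.
"""
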
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from pyannote.audio import Pipeline
from librosa import load, resample
from rpunct import RestorePuncts

# Audio components
asr_model = 'patrickvonplaten/wav2vec2-base-960h-4-gram'
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    'automatic-speech-recognition',
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)
speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
rpunct = RestorePuncts()
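# The "-4-gram" checkpoint pairs wav2vec2 CTC output with an n-gram language-model
# decoder (via pyctcdecode); rpunct then restores punctuation and casing, since
# CTC transcripts come out lowercase and unpunctuated.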

# Text components
sentiment_pipeline = pipeline('text-classification', model="distilbert-base-uncased-finetuned-sst-2-english")
sentiment_threshold = 0.75
EXAMPLES = ["example_audio.wav"]
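
# Align the two audio passes: ASR returns word-level (start, end) timestamps, and
# each word is assigned to the current pyannote speaker turn while its end time
# falls inside that turn. Speaker labels simply alternate per turn, which assumes
# a two-person conversation.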
def speech_to_text(speech):
    speaker_output = speaker_segmentation(speech)
    speech, sampling_rate = load(speech)
    # wav2vec2 expects 16 kHz audio
    if sampling_rate != 16000:
        speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
    text = asr(speech, return_timestamps="word")
    full_text = text['text'].lower()
    chunks = text['chunks']
    diarized_output = []
    i = 0
    speaker_counter = 0

    # New iteration every time the speaker changes
    for turn, _, _ in speaker_output.itertracks(yield_label=True):
        speaker = "Speaker 0" if speaker_counter % 2 == 0 else "Speaker 1"
        diarized = ""
        # Consume word chunks until the next word ends after this turn
        while i < len(chunks) and chunks[i]['timestamp'][1] <= turn.end:
            diarized += chunks[i]['text'].lower() + ' '
            i += 1
        if diarized != "":
            diarized = rpunct.punctuate(diarized)
            diarized_output.extend([
                (diarized, speaker),
                ('from {:.2f}-{:.2f}'.format(turn.start, turn.end), None),
            ])
        speaker_counter += 1
    return diarized_output, full_text
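
# diarized_output is a list of (text, label) pairs, the format gr.HighlightedText
# renders; sentiment() below consumes those same pairs, keeping the selected
# speaker's utterances and tagging each with the classifier's label.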
def sentiment(checked_options, diarized):
    customer_id = checked_options
    customer_sentiments = []
    for transcript in diarized:
        speaker_speech, speaker_id = transcript
        if speaker_id == customer_id:
            output = sentiment_pipeline(speaker_speech)[0]
            # SST-2 only emits POSITIVE/NEGATIVE, so the score threshold is the
            # effective filter here
            if output["label"] != "neutral" and output["score"] > sentiment_threshold:
                customer_sentiments.append((speaker_speech, output["label"]))
            else:
                customer_sentiments.append((speaker_speech, None))
    return customer_sentiments
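
# UI layout: audio input and clickable examples on the left; diarized transcript,
# full transcript, speaker selector, and the sentiment view on the right.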
demo = gr.Blocks(enable_queue=True)
demo.encrypt = False

with demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio file", type='filepath')
            with gr.Row():
                btn = gr.Button("Transcribe")
            with gr.Row():
                examples = gr.components.Dataset(components=[audio], samples=[[e] for e in EXAMPLES], type="index")
        with gr.Column():
            gr.Markdown("**Diarized Output:**")
            diarized = gr.HighlightedText(label="Diarized Output")
            full = gr.Textbox(lines=4, label="Full Transcript")
            check = gr.Radio(["Speaker 0", "Speaker 1"], label='Choose speaker for sentiment analysis')
            analyzed = gr.HighlightedText(label="Customer Sentiment")

    btn.click(speech_to_text, audio, [diarized, full])
    check.change(sentiment, [check, diarized], analyzed)
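
    # Example caching: run the full pipeline on each bundled example once at
    # startup, so clicking an example loads stored results instead of re-running
    # the models.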
    def cache_example(example):
        processed_examples = audio.preprocess_example(example)
        diarized_output, full_text = speech_to_text(example)
        return processed_examples, diarized_output, full_text

    cache = [cache_example(e) for e in EXAMPLES]

    def load_example(example_id):
        return cache[example_id]

    # _click_no_postprocess is a private Gradio helper: the cached outputs are
    # already in display format, so output postprocessing is skipped
    examples._click_no_postprocess(load_example, inputs=[examples], outputs=[audio, diarized, full], queue=False)

demo.launch()