The number of speakers is no longer limited to 10; a speaker label has been created for "being processed" (-1), and another for "no speaker detected" (-2)
Browse files
src/diarization/diarization_online.py
CHANGED
|
@@ -5,6 +5,11 @@ from rx.subject import Subject
|
|
| 5 |
import threading
|
| 6 |
import numpy as np
|
| 7 |
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class WebSocketAudioSource(AudioSource):
|
| 10 |
"""
|
|
@@ -44,37 +49,48 @@ def create_pipeline(SAMPLE_RATE):
|
|
| 44 |
return inference, ws_source
|
| 45 |
|
| 46 |
|
| 47 |
-
def init_diart(SAMPLE_RATE):
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def diar_hook(result):
|
| 51 |
"""
|
| 52 |
Hook called each time Diart processes a chunk.
|
| 53 |
result is (annotation, audio).
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
-
global l_speakers
|
| 57 |
-
l_speakers = []
|
| 58 |
annotation, audio = result
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
l_speakers_queue = asyncio.Queue()
|
| 67 |
inference.attach_hooks(diar_hook)
|
| 68 |
-
|
| 69 |
-
# Launch Diart in a background thread
|
| 70 |
loop = asyncio.get_event_loop()
|
| 71 |
diar_future = loop.run_in_executor(None, inference)
|
| 72 |
return inference, l_speakers_queue, ws_source
|
| 73 |
|
| 74 |
-
|
| 75 |
-
class DiartDiarization():
|
| 76 |
def __init__(self, SAMPLE_RATE):
|
| 77 |
-
self.
|
|
|
|
| 78 |
self.segment_speakers = []
|
| 79 |
|
| 80 |
async def diarize(self, pcm_array):
|
|
@@ -82,20 +98,21 @@ class DiartDiarization():
|
|
| 82 |
self.segment_speakers = []
|
| 83 |
while not self.l_speakers_queue.empty():
|
| 84 |
self.segment_speakers.append(await self.l_speakers_queue.get())
|
| 85 |
-
|
| 86 |
def close(self):
|
| 87 |
self.ws_source.close()
|
| 88 |
|
| 89 |
-
|
| 90 |
def assign_speakers_to_chunks(self, chunks):
|
| 91 |
"""
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
"""
|
| 97 |
-
|
| 98 |
-
|
| 99 |
|
| 100 |
for segment in self.segment_speakers:
|
| 101 |
seg_beg = segment["beg"]
|
|
@@ -104,7 +121,10 @@ class DiartDiarization():
|
|
| 104 |
for ch in chunks:
|
| 105 |
if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
|
| 106 |
continue
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
return chunks
|
|
|
|
| 5 |
import threading
|
| 6 |
import numpy as np
|
| 7 |
import asyncio
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
def extract_number(s):
    """Return the first run of digits in *s* as an int, or None if *s* has none."""
    digits = re.search(r"\d+", s)
    if digits is None:
        return None
    return int(digits.group())
|
| 13 |
|
| 14 |
class WebSocketAudioSource(AudioSource):
|
| 15 |
"""
|
|
|
|
| 49 |
return inference, ws_source
|
| 50 |
|
| 51 |
|
| 52 |
+
def init_diart(SAMPLE_RATE, diar_instance):
    """
    Build and start a Diart streaming diarization pipeline.

    Parameters
    ----------
    SAMPLE_RATE : int
        Audio sample rate handed to the websocket audio source.
    diar_instance :
        Owner object (a DiartDiarization); its ``processed_time`` attribute is
        advanced from the inference hook so callers can tell how far into the
        stream diarization has progressed.

    Returns
    -------
    tuple
        ``(inference, l_speakers_queue, ws_source)``.
    """
    diar_pipeline = SpeakerDiarization()
    ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
    inference = StreamingInference(
        pipeline=diar_pipeline,
        source=ws_source,
        do_plot=False,
        show_progress=False,
    )

    l_speakers_queue = asyncio.Queue()
    # Capture the calling thread's event loop NOW: diar_hook runs inside the
    # executor worker thread started below, where no event loop is running, so
    # asyncio.create_task() would raise RuntimeError there. We instead schedule
    # queue puts back onto this loop with call_soon_threadsafe().
    loop = asyncio.get_event_loop()

    def diar_hook(result):
        """
        Hook called each time Diart processes a chunk.
        result is (annotation, audio).
        For each detected speaker segment, push its info to the queue and update processed_time.
        """
        annotation, audio = result
        if annotation._labels:
            for speaker in annotation._labels:
                segments_beg = annotation._labels[speaker].segments_boundaries_[0]
                segments_end = annotation._labels[speaker].segments_boundaries_[-1]
                if segments_end > diar_instance.processed_time:
                    diar_instance.processed_time = segments_end
                # Thread-safe hand-off to the asyncio queue; put_nowait never
                # blocks because the queue is unbounded.
                loop.call_soon_threadsafe(
                    l_speakers_queue.put_nowait,
                    {"speaker": speaker, "beg": segments_beg, "end": segments_end},
                )
        else:
            # No speech in this chunk: still advance processed_time so fully
            # processed silent spans can later be labelled "no speaker".
            audio_duration = audio.extent.end
            if audio_duration > diar_instance.processed_time:
                diar_instance.processed_time = audio_duration

    inference.attach_hooks(diar_hook)

    # Launch Diart in a background thread. Keep a reference to the future so
    # the executor task is not garbage-collected while inference runs.
    diar_future = loop.run_in_executor(None, inference)
    return inference, l_speakers_queue, ws_source
|
| 89 |
|
| 90 |
+
class DiartDiarization:
|
|
|
|
| 91 |
def __init__(self, SAMPLE_RATE):
    """Wire up and start the streaming Diart pipeline for *SAMPLE_RATE* audio."""
    # processed_time must be initialised before init_diart() runs: the
    # inference hook reads and updates it as soon as the pipeline starts.
    self.processed_time = 0
    self.segment_speakers = []
    self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE, self)
|
| 95 |
|
| 96 |
async def diarize(self, pcm_array):
|
|
|
|
| 98 |
self.segment_speakers = []
|
| 99 |
while not self.l_speakers_queue.empty():
|
| 100 |
self.segment_speakers.append(await self.l_speakers_queue.get())
|
| 101 |
+
|
| 102 |
def close(self):
    """Stop the diarization stream by closing the websocket audio source."""
    self.ws_source.close()
|
| 104 |
|
|
|
|
| 105 |
def assign_speakers_to_chunks(self, chunks):
|
| 106 |
"""
|
| 107 |
+
For each chunk (a dict with keys "beg" and "end"), assign a speaker label.
|
| 108 |
+
|
| 109 |
+
- If a chunk overlaps with a detected speaker segment, assign that label.
|
| 110 |
+
- If the chunk's end time is within the processed time and no speaker was assigned,
|
| 111 |
+
mark it as "No speaker".
|
| 112 |
+
- If the chunk's time hasn't been fully processed yet, leave it (or mark as "Processing").
|
| 113 |
"""
|
| 114 |
+
for ch in chunks:
|
| 115 |
+
ch["speaker"] = ch.get("speaker", -1)
|
| 116 |
|
| 117 |
for segment in self.segment_speakers:
|
| 118 |
seg_beg = segment["beg"]
|
|
|
|
| 121 |
for ch in chunks:
|
| 122 |
if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
|
| 123 |
continue
|
| 124 |
+
ch["speaker"] = extract_number(speaker) + 1
|
| 125 |
+
if self.processed_time > 0:
|
| 126 |
+
for ch in chunks:
|
| 127 |
+
if ch["end"] <= self.processed_time and ch["speaker"] == -1:
|
| 128 |
+
ch["speaker"] = -2
|
| 129 |
|
| 130 |
+
return chunks
|