DiartDiarization now uses SpeakerSegment
Browse files
src/diarization/diarization_online.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
| 6 |
from diart import SpeakerDiarization
|
| 7 |
from diart.inference import StreamingInference
|
| 8 |
from diart.sources import AudioSource
|
| 9 |
-
|
| 10 |
|
| 11 |
def extract_number(s: str) -> int:
|
| 12 |
m = re.search(r'\d+', s)
|
|
@@ -58,15 +58,15 @@ class DiartDiarization:
|
|
| 58 |
annotation, audio = result
|
| 59 |
if annotation._labels:
|
| 60 |
for speaker, label in annotation._labels.items():
|
| 61 |
-
|
| 62 |
end = label.segments_boundaries_[-1]
|
| 63 |
if end > self.processed_time:
|
| 64 |
self.processed_time = end
|
| 65 |
-
asyncio.create_task(self.speakers_queue.put(
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
else:
|
| 71 |
dur = audio.extent.end
|
| 72 |
if dur > self.processed_time:
|
|
@@ -84,7 +84,7 @@ class DiartDiarization:
|
|
| 84 |
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
|
| 85 |
for token in tokens:
|
| 86 |
for segment in self.segment_speakers:
|
| 87 |
-
if not (segment
|
| 88 |
-
token.speaker = extract_number(segment
|
| 89 |
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
| 90 |
return end_attributed_speaker
|
|
|
|
| 6 |
from diart import SpeakerDiarization
|
| 7 |
from diart.inference import StreamingInference
|
| 8 |
from diart.sources import AudioSource
|
| 9 |
+
from src.whisper_streaming.timed_objects import SpeakerSegment
|
| 10 |
|
| 11 |
def extract_number(s: str) -> int:
|
| 12 |
m = re.search(r'\d+', s)
|
|
|
|
| 58 |
annotation, audio = result
|
| 59 |
if annotation._labels:
|
| 60 |
for speaker, label in annotation._labels.items():
|
| 61 |
+
start = label.segments_boundaries_[0]
|
| 62 |
end = label.segments_boundaries_[-1]
|
| 63 |
if end > self.processed_time:
|
| 64 |
self.processed_time = end
|
| 65 |
+
asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
|
| 66 |
+
speaker=speaker,
|
| 67 |
+
start=start,
|
| 68 |
+
end=end,
|
| 69 |
+
)))
|
| 70 |
else:
|
| 71 |
dur = audio.extent.end
|
| 72 |
if dur > self.processed_time:
|
|
|
|
| 84 |
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
    """Assign a speaker id to each token that overlaps a diarized segment.

    For every token, scan the buffered ``self.segment_speakers`` (SpeakerSegment
    objects) and, when the token's [start, end) interval overlaps a segment's
    interval, stamp the token with that segment's numeric speaker id.
    Tokens are mutated in place.

    Args:
        end_attributed_speaker: Latest timestamp (seconds) up to which speaker
            attribution has already been performed.
        tokens: Timed token objects exposing ``start``, ``end`` and a writable
            ``speaker`` attribute.

    Returns:
        The updated high-water-mark timestamp of speaker attribution (a float,
        not a list — the previous ``-> list`` annotation was incorrect).
    """
    for token in tokens:
        for segment in self.segment_speakers:
            # Overlap test: two intervals overlap unless one ends before
            # the other starts (half-open interval logic).
            if not (segment.end <= token.start or segment.start >= token.end):
                # Diart labels speakers "speaker0", "speaker1", ...; extract
                # the number and shift to a 1-based id for display.
                token.speaker = extract_number(segment.speaker) + 1
                # Advance the attribution high-water mark past this token.
                end_attributed_speaker = max(token.end, end_attributed_speaker)
    return end_attributed_speaker
|