Spaces:
Running
Running
File size: 10,223 Bytes
53ea588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""Pipeline runner for ACE."""
import asyncio
import gc
from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from fastapi import WebSocket
from loguru import logger
from pipecat.pipeline.task import PipelineTask
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from starlette.websockets import WebSocketState
@dataclass
class PipelineMetadata:
    """Bookkeeping record for a single pipeline.

    One record is kept per stream. It bundles the handles needed to reach the
    pipeline after it has been created: the client connection, the media
    source and the objects that drive the pipeline.

    Attributes:
        stream_id: Identifier of the stream this pipeline belongs to.
        websocket: Client WebSocket attached to the pipeline, or ``None`` while
            no client is attached.
        rtsp_url: RTSP URL used for video/audio streaming; empty when unset.
        pipeline_task: The pipeline task object, or ``None`` until it has been
            created.
        runner_task: The asyncio task executing the pipeline, or ``None`` until
            the pipeline has been started.
    """

    # Only ``stream_id`` is mandatory; every other field starts out empty and is
    # filled in as the pipeline is created, started and connected.
    stream_id: str
    websocket: WebSocket | None = None
    rtsp_url: str = ""
    pipeline_task: PipelineTask | None = None
    runner_task: asyncio.Task | None = None
class ACEPipelineRunner:
    """Singleton that creates, tracks and tears down ACE pipelines.

    Pipelines are keyed by stream ID. With RTSP enabled, a pipeline must first be
    registered through :meth:`add_pipeline` before a websocket can be attached
    to it. With RTSP disabled, :meth:`connect_websocket` creates the pipeline on
    demand and removes it again once the websocket is closed.

    An entry is deleted from the registry by :meth:`_cleanup_pipeline`, which is
    scheduled automatically when the pipeline's runner task finishes.

    Attributes:
        _pipelines: Maps each stream ID to its :class:`PipelineMetadata`.
        _enable_rtsp: Whether pipelines are pre-registered via ``add_pipeline``.
        _pipelines_callback: Awaitable factory that receives the pipeline
            metadata and returns the ``PipelineTask`` to run.
        _background_tasks: Strong references to in-flight cleanup tasks.
        __instance: The singleton instance, or ``None`` before creation.
    """

    __instance = None

    # Seconds between two checks of the websocket state while waiting for the
    # client to disconnect.
    _WEBSOCKET_POLL_INTERVAL = 0.1

    def __init__(
        self,
        pipeline_callback: Callable[[PipelineMetadata], Awaitable[PipelineTask]],
        enable_rtsp: bool = False,
    ):
        """Initialize the ACEPipelineRunner singleton.

        Args:
            pipeline_callback: Awaitable factory called with the pipeline
                metadata; must return the pipeline task to run.
            enable_rtsp: Whether RTSP mode is enabled.

        Raises:
            RuntimeError: If an instance has already been created.
        """
        if ACEPipelineRunner.__instance is not None:
            raise RuntimeError("This class is a singleton!")
        self._pipelines: dict[str, PipelineMetadata] = {}
        self._enable_rtsp = enable_rtsp
        self._pipelines_callback = pipeline_callback
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # cleanup tasks are parked here until they complete.
        self._background_tasks: set[asyncio.Task] = set()
        ACEPipelineRunner.__instance = self

    @staticmethod
    def create_instance(
        pipeline_callback: Callable[[PipelineMetadata], Awaitable[PipelineTask]],
        enable_rtsp: bool = False,
    ) -> "ACEPipelineRunner":
        """Create the singleton instance, or return the existing one.

        The arguments are ignored when an instance already exists.

        Args:
            pipeline_callback: Awaitable factory for pipeline creation.
            enable_rtsp: Whether RTSP mode is enabled.

        Returns:
            ACEPipelineRunner: The singleton instance of the class.
        """
        if ACEPipelineRunner.__instance is None:
            # __init__ registers the new object as the singleton.
            ACEPipelineRunner(pipeline_callback, enable_rtsp)
        return ACEPipelineRunner.__instance

    @staticmethod
    def get_instance() -> "ACEPipelineRunner":
        """Get the singleton instance of the ACEPipelineRunner.

        Returns:
            ACEPipelineRunner: The singleton instance of the class.

        Raises:
            RuntimeError: If no instance has been created yet.
        """
        if ACEPipelineRunner.__instance is None:
            raise RuntimeError("This class is a singleton!, Please create an instance first.")
        return ACEPipelineRunner.__instance

    async def add_pipeline(self, stream_id: str, rtsp_url: str) -> None:
        """Register a new pipeline and start it in the background.

        Args:
            stream_id: Unique identifier for the pipeline stream.
            rtsp_url: RTSP URL string for video/audio streaming.

        Raises:
            ValueError: If a pipeline already exists for ``stream_id`` or the
                pipeline could not be started.
        """
        logger.debug(f"Found {len(self._pipelines)} active pipelines with stream IDs: {self._pipelines.keys()}")
        with logger.contextualize(stream_id=stream_id):
            if stream_id in self._pipelines:
                raise ValueError(f"Pipeline for Stream ID {stream_id} already exists")

            self._pipelines[stream_id] = PipelineMetadata(stream_id, rtsp_url=rtsp_url)
            logger.info(f"Received add pipeline request, Running pipeline for stream {stream_id}")
            try:
                await self._run_pipeline(stream_id)
            except Exception as e:
                logger.error(f"Error while creating pipeline: {e}")
                # Drop the half-initialised entry so the stream ID can be reused.
                await self._discard_failed_pipeline(stream_id)
                raise ValueError(f"Error while creating pipeline: {e}") from e

    async def connect_websocket(self, stream_id: str, websocket: WebSocket) -> None:
        """Connect a websocket.

        Attaches the websocket to the running pipeline, or - when RTSP is
        disabled and no pipeline exists yet - creates a new pipeline for the
        stream. The coroutine returns only after the websocket has been closed
        (or the pipeline has been removed from the registry).

        Args:
            stream_id: Unique identifier for the pipeline stream.
            websocket: WebSocket connection for the pipeline.

        Raises:
            ValueError: If RTSP is enabled and no pipeline exists for the
                stream, if the pipeline has already terminated, or if the
                websocket cannot be attached or monitored.
        """
        with logger.contextualize(stream_id=stream_id):
            metadata = self._pipelines.get(stream_id)

            if metadata is None:
                if self._enable_rtsp:
                    # In RTSP mode pipelines are created by add_pipeline() only.
                    raise ValueError(f"Pipeline for Stream ID {stream_id} does not exist")

                # Websocket-only mode: the pipeline lives exactly as long as
                # this connection.
                self._pipelines[stream_id] = PipelineMetadata(stream_id, websocket=websocket)
                try:
                    await self._run_pipeline(stream_id)
                except Exception:
                    # Without this the stale entry would block every later
                    # connection attempt for the same stream ID.
                    await self._discard_failed_pipeline(stream_id)
                    raise
                await self._wait_for_websocket_close(stream_id)
                await self.remove_pipeline(stream_id)
            elif metadata.runner_task is not None and metadata.runner_task.done():
                raise ValueError(f"Pipeline for Stream ID {stream_id} is already terminated")
            else:
                await self._update_websocket(stream_id, websocket)
                await self._wait_for_websocket_close(stream_id)

    async def remove_pipeline(self, stream_id: str) -> None:
        """Stop a pipeline and wait for its runner task to finish.

        The registry entry itself is deleted by the cleanup that runs when the
        runner task completes. Unknown stream IDs are ignored.

        Args:
            stream_id: Unique identifier for the pipeline stream.
        """
        with logger.contextualize(stream_id=stream_id):
            # Take the metadata once: the cleanup callback may delete the
            # registry entry while this coroutine is suspended in an await.
            metadata = self._pipelines.get(stream_id)
            if metadata is None:
                return

            if metadata.pipeline_task is not None:
                try:
                    # Signal shutdown
                    await metadata.pipeline_task.stop_when_done()
                except Exception as e:
                    logger.error(f"Error while removing Pipeline: {e}")

                runner_task = metadata.runner_task
                if runner_task is not None and not runner_task.done():
                    logger.info("Waiting for pipeline runner task to finish ...")
                    await runner_task

            logger.info(f"Pipeline for Stream ID {stream_id} removed")

    async def _discard_failed_pipeline(self, stream_id: str) -> None:
        """Tear down a pipeline whose start-up failed.

        Stops whatever was created and, if no runner task exists (so no
        done-callback will ever clean up), deletes the registry entry directly.
        The websocket is left untouched so the caller can still report the error.

        Args:
            stream_id: Unique identifier for the pipeline stream.
        """
        await self.remove_pipeline(stream_id)
        metadata = self._pipelines.get(stream_id)
        if metadata is not None and metadata.runner_task is None:
            del self._pipelines[stream_id]
            logger.info(f"Pipeline for Stream ID {stream_id} discarded after failed start")

    async def _cleanup_pipeline(self, stream_id: str) -> None:
        """Close the pipeline's websocket and delete its registry entry.

        Safe to call for a stream ID that has already been cleaned up.

        Args:
            stream_id: Unique identifier for the pipeline stream.
        """
        metadata = self._pipelines.get(stream_id)
        if metadata is None:
            # Already cleaned up; nothing to close or delete.
            return

        websocket = metadata.websocket
        try:
            if websocket is not None and websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close()
        except Exception as e:
            # Best effort: a failing close must not prevent the entry removal.
            logger.error(f"Error while closing websocket: {e}")

        self._pipelines.pop(stream_id, None)
        gc.collect()
        logger.info(f"Pipeline for Stream ID {stream_id} deleted")

    async def _run_pipeline(self, stream_id: str) -> None:
        """Create the pipeline task and run it in the background.

        The pipeline is built by the configured callback and executed in its own
        asyncio task; when that task finishes, cleanup of the registry entry is
        scheduled automatically.

        Args:
            stream_id: Unique identifier for the pipeline stream.

        Raises:
            Exception: Whatever the callback or the task set-up raises; the
                error is logged and re-raised unchanged.
        """
        try:
            metadata = self._pipelines[stream_id]
            pipeline_task = await self._pipelines_callback(metadata)
            metadata.pipeline_task = pipeline_task
            pipeline_task.set_event_loop(asyncio.get_running_loop())

            def _schedule_cleanup(_finished: asyncio.Task) -> None:
                # Done-callbacks are synchronous, so the async cleanup runs in
                # its own task; keep a reference until it completes.
                cleanup = asyncio.create_task(self._cleanup_pipeline(stream_id))
                self._background_tasks.add(cleanup)
                cleanup.add_done_callback(self._background_tasks.discard)

            metadata.runner_task = asyncio.create_task(pipeline_task.run())
            metadata.runner_task.add_done_callback(_schedule_cleanup)
            logger.info(f"Pipeline started successfully for stream {stream_id}")
        except Exception as e:
            logger.error(f"Error while creating pipeline task: {e}")
            raise

    async def _update_websocket(self, stream_id: str, websocket: WebSocket) -> None:
        """Update the websocket for a pipeline.

        Looks for the first input/output transport processor in the pipeline and
        hands the new websocket to its transport.

        Args:
            stream_id: Unique identifier for the pipeline stream.
            websocket: WebSocket connection for the pipeline.

        Raises:
            ValueError: If the pipeline task has not been created yet, or the
                transport does not support replacing its websocket.
        """
        metadata = self._pipelines[stream_id]
        if metadata.pipeline_task is None:
            raise ValueError(f"Pipeline for Stream ID {stream_id} is not ready yet")

        metadata.websocket = websocket
        # NOTE(review): relies on private pipecat attributes (_pipeline,
        # _processors, _transport) - confirm they exist in the pinned version.
        pipeline = metadata.pipeline_task._pipeline
        for component in pipeline._processors:
            if not isinstance(component, BaseInputTransport | BaseOutputTransport):
                continue
            if not hasattr(component._transport, "update_websocket"):
                raise ValueError(f"Component {component.__class__.__name__} doesn't support updating websocket.")
            await component._transport.update_websocket(websocket)
            logger.info(f"Websocket for Stream ID {stream_id} updated")
            return

        logger.warning(f"No transport found to update websocket for Stream ID {stream_id}")

    async def _wait_for_websocket_close(self, stream_id: str) -> None:
        """Wait for the websocket to close. This is used to keep connection alive.

        Polls until the pipeline's websocket is no longer connected or the
        pipeline has been removed from the registry.

        Args:
            stream_id: Unique identifier for the pipeline stream.

        Raises:
            ValueError: If checking the websocket state fails.
        """
        try:
            # Wait until websocket is closed
            while (
                stream_id in self._pipelines
                and self._pipelines[stream_id].websocket.client_state == WebSocketState.CONNECTED
            ):
                await asyncio.sleep(self._WEBSOCKET_POLL_INTERVAL)
            logger.info(f"Websocket for Stream ID {stream_id} closed")
        except Exception as e:
            raise ValueError(f"Error while waiting for websocket close: {e}") from e
|