# File-viewer page header captured together with the source (not part of the module):
#   uploader: fciannella; commit 53ea588 ("Working with service run on 7860"); size: 10.2 kB
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""Pipeline runner for ACE."""
import asyncio
import gc
from dataclasses import dataclass
from fastapi import WebSocket
from loguru import logger
from pipecat.pipeline.task import PipelineTask
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from starlette.websockets import WebSocketState
@dataclass
class PipelineMetadata:
    """Bookkeeping record for one pipeline managed by the runner.

    A record is created per stream ID and carries everything the runner needs
    to reach that pipeline later: the client websocket (when one is attached),
    the RTSP URL, the pipeline task and the asyncio task that executes it.

    Attributes:
        stream_id: Unique identifier of the stream this pipeline serves.
        websocket: WebSocket attached to the pipeline, or None when no
            websocket has been attached.
        rtsp_url: RTSP URL used for video/audio streaming; empty by default.
        pipeline_task: Task object representing the pipeline, or None until
            one has been assigned.
        runner_task: Asyncio task executing the pipeline, or None until one
            has been assigned.
    """

    # The only required field; every other field has a default.
    stream_id: str
    websocket: WebSocket | None = None
    rtsp_url: str = ""
    pipeline_task: PipelineTask | None = None
    runner_task: asyncio.Task | None = None
class ACEPipelineRunner:
    """Singleton class for managing ACE pipelines.

    The runner keeps one metadata record per stream ID and offers methods to
    add a pipeline, attach a websocket to it and remove it again. A pipeline's
    record is deleted by a cleanup coroutine scheduled when its runner task
    finishes.

    Attributes:
        _pipelines: Dictionary storing pipeline metadata for each stream ID
        _enable_rtsp: Boolean flag indicating if RTSP is enabled
        _pipelines_callback: Awaitable callback that builds the pipeline task
            from a pipeline metadata record
        _background_tasks: Strong references to in-flight cleanup tasks
        __instance: Singleton instance of the class
    """

    __instance = None

    def __init__(self, pipeline_callback: callable, enable_rtsp: bool = False):
        """Initialize the ACEPipelineRunner singleton.

        Args:
            pipeline_callback: Callback function for pipeline creation
            enable_rtsp: Boolean flag indicating if RTSP is enabled

        Raises:
            RuntimeError: If an instance has already been created
        """
        if ACEPipelineRunner.__instance is not None:
            raise RuntimeError("This class is a singleton!")
        self._pipelines: dict[str, PipelineMetadata] = {}
        self._enable_rtsp = enable_rtsp
        self._pipelines_callback = pipeline_callback
        # The event loop only keeps weak references to tasks, so cleanup tasks
        # are parked here until they complete.
        self._background_tasks: set[asyncio.Task] = set()
        ACEPipelineRunner.__instance = self

    @staticmethod
    def create_instance(pipeline_callback: callable, enable_rtsp: bool = False):
        """Create the singleton, or return the existing one.

        When an instance already exists the arguments are ignored.

        Args:
            pipeline_callback: Callback function for pipeline creation
            enable_rtsp: Boolean flag indicating if RTSP is enabled

        Returns:
            ACEPipelineRunner: The singleton instance of the class
        """
        if ACEPipelineRunner.__instance is None:
            # __init__ registers the new object as the singleton.
            ACEPipelineRunner(pipeline_callback, enable_rtsp)
        return ACEPipelineRunner.__instance

    @staticmethod
    def get_instance():
        """Get the singleton instance of the ACEPipelineRunner.

        Returns:
            ACEPipelineRunner: The singleton instance of the class

        Raises:
            RuntimeError: If the class is not initialized
        """
        if ACEPipelineRunner.__instance is None:
            raise RuntimeError("This class is a singleton!, Please create an instance first.")
        return ACEPipelineRunner.__instance

    async def add_pipeline(self, stream_id: str, rtsp_url: str):
        """Add a new pipeline to the runner and start it.

        Args:
            stream_id: Unique identifier for the pipeline stream
            rtsp_url: RTSP URL string for video/audio streaming

        Raises:
            ValueError: If a pipeline already exists for the stream ID, or if
                starting the pipeline fails
        """
        logger.debug(f"Found {len(self._pipelines)} active pipelines with stream IDs: {self._pipelines.keys()}")
        with logger.contextualize(stream_id=stream_id):
            if stream_id in self._pipelines:
                raise ValueError(f"Pipeline for Stream ID {stream_id} already exists")
            self._pipelines[stream_id] = PipelineMetadata(stream_id, rtsp_url=rtsp_url)
            logger.info(f"Received add pipeline request, Running pipeline for stream {stream_id}")
            try:
                await self._run_pipeline(stream_id)
            except Exception as e:
                logger.error(f"Error while creating pipeline: {e}")
                # Also unregisters the half-created entry so the stream ID can be reused.
                await self.remove_pipeline(stream_id)
                raise ValueError(f"Error while creating pipeline: {e}") from e

    async def connect_websocket(self, stream_id: str, websocket: WebSocket):
        """Connect a websocket.

        With RTSP enabled the pipeline must already exist. With RTSP disabled a
        missing pipeline is created for the stream ID and removed again once the
        websocket is no longer connected. For an existing pipeline the websocket
        is swapped in. In every successful case the method returns only after
        the websocket has left the CONNECTED state or the pipeline entry is gone.

        Args:
            stream_id: Unique identifier for the pipeline stream
            websocket: WebSocket connection for the pipeline

        Raises:
            ValueError: If RTSP is enabled and no pipeline exists, if the
                pipeline has already terminated, or if updating or waiting on
                the websocket fails
        """
        with logger.contextualize(stream_id=stream_id):
            metadata = self._pipelines.get(stream_id)
            if metadata is None:
                if self._enable_rtsp:
                    raise ValueError(f"Pipeline for Stream ID {stream_id} does not exist")
                self._pipelines[stream_id] = PipelineMetadata(stream_id, websocket=websocket)
                try:
                    await self._run_pipeline(stream_id)
                    await self._wait_for_websocket_close(stream_id)
                finally:
                    # Always tear the pipeline down, including when startup or
                    # the wait failed, so no stale entry is left registered.
                    await self.remove_pipeline(stream_id)
            elif metadata.runner_task and metadata.runner_task.done():
                raise ValueError(f"Pipeline for Stream ID {stream_id} is already terminated")
            else:
                await self._update_websocket(stream_id, websocket)
                await self._wait_for_websocket_close(stream_id)

    async def remove_pipeline(self, stream_id: str):
        """Remove a pipeline from the runner.

        Signals the pipeline task to stop and waits for its runner task. An
        entry whose runner task was never started is dropped directly because
        no done-callback exists to clean it up. Unknown stream IDs are ignored.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        with logger.contextualize(stream_id=stream_id):
            # Hold the record locally: the cleanup task may delete the dict
            # entry while this coroutine is suspended in an await below.
            metadata = self._pipelines.get(stream_id)
            if metadata is None:
                return
            if metadata.pipeline_task is not None:
                try:
                    # Signal shutdown
                    await metadata.pipeline_task.stop_when_done()
                except Exception as e:
                    logger.error(f"Error while removing Pipeline: {e}")
            runner_task = metadata.runner_task
            if runner_task is None:
                # Never started (e.g. the creation callback failed).
                if self._pipelines.get(stream_id) is metadata:
                    del self._pipelines[stream_id]
            elif not runner_task.done():
                logger.info("Waiting for pipeline runner task to finish ...")
                await runner_task
            logger.info(f"Pipeline for Stream ID {stream_id} removed")

    async def _cleanup_pipeline(self, stream_id: str):
        """Close the pipeline's websocket if connected and delete its entry.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        metadata = self._pipelines.get(stream_id)
        try:
            if (
                metadata is not None
                and metadata.websocket
                and metadata.websocket.client_state == WebSocketState.CONNECTED
            ):
                await metadata.websocket.close()
        except Exception as e:
            logger.error(f"Error while closing websocket: {e}")
        # pop() instead of del: the entry may already have been removed, and an
        # exception here would be lost inside a background task.
        self._pipelines.pop(stream_id, None)
        gc.collect()
        logger.info(f"Pipeline for Stream ID {stream_id} deleted")

    def _schedule_cleanup(self, stream_id: str) -> None:
        """Start the cleanup coroutine as a task and keep it referenced until done."""
        cleanup_task = asyncio.create_task(self._cleanup_pipeline(stream_id))
        self._background_tasks.add(cleanup_task)
        cleanup_task.add_done_callback(self._background_tasks.discard)

    async def _run_pipeline(self, stream_id: str):
        """Run a pipeline in background.

        Builds the pipeline task through the creation callback, starts it as an
        asyncio task and arranges for cleanup once that task finishes.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        try:
            pipeline_task = await self._pipelines_callback(self._pipelines[stream_id])
            # Looked up again after the await, as the entry is shared state.
            metadata = self._pipelines[stream_id]
            metadata.pipeline_task = pipeline_task
            pipeline_task.set_event_loop(asyncio.get_running_loop())
            runner_task = asyncio.create_task(pipeline_task.run())
            metadata.runner_task = runner_task
            runner_task.add_done_callback(lambda _: self._schedule_cleanup(stream_id))
            logger.info(f"Pipeline started successfully for stream {stream_id}")
        except Exception as e:
            logger.error(f"Error while creating pipeline task: {e}")
            raise

    async def _update_websocket(self, stream_id: str, websocket: WebSocket):
        """Update the websocket for a pipeline.

        The websocket is handed to the first input/output transport processor
        found in the pipeline.

        Args:
            stream_id: Unique identifier for the pipeline stream
            websocket: WebSocket connection for the pipeline

        Raises:
            ValueError: If the pipeline task has not been created yet, or if
                the transport does not provide ``update_websocket``
        """
        metadata = self._pipelines[stream_id]
        if metadata.pipeline_task is None:
            raise ValueError(f"Pipeline for Stream ID {stream_id} is not ready to accept a websocket")
        metadata.websocket = websocket
        pipeline = metadata.pipeline_task._pipeline
        for component in pipeline._processors:
            if not isinstance(component, BaseInputTransport | BaseOutputTransport):
                continue
            if not hasattr(component._transport, "update_websocket"):
                raise ValueError(f"Component {component.__class__.__name__} doesn't support updating websocket.")
            await component._transport.update_websocket(websocket)
            logger.info(f"Websocket for Stream ID {stream_id} updated")
            return
        # Kept non-fatal as before, but no longer silent.
        logger.warning(f"No transport found for Stream ID {stream_id}; websocket was not forwarded")

    async def _wait_for_websocket_close(self, stream_id: str):
        """Wait for the websocket to close. This is used to keep connection alive.

        Polls every 0.1 seconds until the pipeline entry disappears or its
        current websocket is no longer in the CONNECTED state.

        Args:
            stream_id: Unique identifier for the pipeline stream

        Raises:
            ValueError: If checking the websocket state fails
        """
        try:
            # Wait until websocket is closed
            while True:
                metadata = self._pipelines.get(stream_id)
                if metadata is None or metadata.websocket.client_state != WebSocketState.CONNECTED:
                    break
                await asyncio.sleep(0.1)
            logger.info(f"Websocket for Stream ID {stream_id} closed")
        except Exception as e:
            raise ValueError(f"Error while waiting for websocket close: {e}") from e