File size: 10,223 Bytes
53ea588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License

"""Pipeline runner for ACE."""

import asyncio
import gc
from dataclasses import dataclass

from fastapi import WebSocket
from loguru import logger
from pipecat.pipeline.task import PipelineTask
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from starlette.websockets import WebSocketState


@dataclass
class PipelineMetadata:
    """Book-keeping record for one pipeline managed by ``ACEPipelineRunner``.

    The runner keeps one record per stream ID. It fills in ``pipeline_task`` and
    ``runner_task`` once the pipeline has been built and scheduled, and replaces
    ``websocket`` when a client connects to an already running pipeline.

    Attributes:
        stream_id: Unique identifier of the stream served by this pipeline
        websocket: Client websocket currently attached to the pipeline, if any
        rtsp_url: RTSP URL for video/audio streaming ("" when not provided)
        pipeline_task: Pipeline task returned by the runner's pipeline callback
        runner_task: Asyncio task executing ``pipeline_task.run()``
    """

    # Key under which the runner stores this record; the only required field.
    stream_id: str
    # Attached client connection; swapped by the runner on reconnect.
    websocket: WebSocket | None = None
    # Only populated for pipelines created through ``add_pipeline``.
    rtsp_url: str = ""
    # Set once the pipeline callback has built the task.
    pipeline_task: PipelineTask | None = None
    # Set once the pipeline task has been scheduled on the event loop.
    runner_task: asyncio.Task | None = None


class ACEPipelineRunner:
    """Singleton class for managing ACE pipelines.

    This class provides a singleton interface for managing multiple ACE pipelines,
    including the addition of new pipelines, connection of websockets, and removal of pipelines.

    Every pipeline is tracked by a ``PipelineMetadata`` entry in ``_pipelines``. An entry is
    deleted either by ``_cleanup_pipeline`` (scheduled when the pipeline's runner task
    finishes) or by ``_run_pipeline`` when the pipeline could not be started at all.

    Attributes:
        _pipelines: Dictionary storing pipeline metadata for each stream ID
        _enable_rtsp: Boolean flag indicating if RTSP is enabled
        _pipelines_callback: Callback function for pipeline creation
        _background_tasks: Strong references to pending cleanup tasks. The event loop only
            keeps weak references to tasks, so an unreferenced task can be garbage-collected
            before it finishes.
        __instance: Singleton instance of the class
    """

    __instance = None

    def __init__(self, pipeline_callback: callable, enable_rtsp: bool = False):
        """Initialize the ACEPipelineRunner singleton.

        Prefer ``create_instance``, which returns the existing instance instead of raising.

        Args:
            pipeline_callback: Async callback that receives a ``PipelineMetadata`` and returns
                the pipeline task to run for it
            enable_rtsp: If True, a pipeline must be created with ``add_pipeline`` before a
                websocket can connect to it; if False, ``connect_websocket`` creates pipelines
                on demand

        Raises:
            Exception: If an instance has already been created
        """
        if ACEPipelineRunner.__instance is not None:
            raise Exception("This class is a singleton!")
        self._pipelines: dict[str, PipelineMetadata] = {}
        self._enable_rtsp = enable_rtsp
        self._pipelines_callback = pipeline_callback
        self._background_tasks: set[asyncio.Task] = set()
        ACEPipelineRunner.__instance = self

    @staticmethod
    def create_instance(pipeline_callback: callable, enable_rtsp: bool = False):
        """Create the ACEPipelineRunner singleton, or return the existing one.

        When an instance already exists, the arguments are ignored.

        Args:
            pipeline_callback: Callback function for pipeline creation
            enable_rtsp: Boolean flag indicating if RTSP is enabled

        Returns:
            ACEPipelineRunner: The singleton instance of the class
        """
        if ACEPipelineRunner.__instance is None:
            # The constructor registers itself as the singleton instance.
            ACEPipelineRunner.__instance = ACEPipelineRunner(pipeline_callback, enable_rtsp)
        return ACEPipelineRunner.__instance

    @staticmethod
    def get_instance():
        """Get the singleton instance of the ACEPipelineRunner.

        Returns:
            ACEPipelineRunner: The singleton instance of the class

        Raises:
            Exception: If the class is not initialized
        """
        if ACEPipelineRunner.__instance is None:
            raise Exception("This class is a singleton!, Please create an instance first.")
        return ACEPipelineRunner.__instance

    async def add_pipeline(self, stream_id: str, rtsp_url: str):
        """Add a new pipeline to the runner and start it in the background.

        Args:
            stream_id: Unique identifier for the pipeline stream
            rtsp_url: RTSP URL string for video/audio streaming

        Raises:
            ValueError: If a pipeline for ``stream_id`` already exists, or if creating or
                starting the pipeline failed
        """
        logger.debug(f"Found {len(self._pipelines)} active pipelines with stream IDs: {self._pipelines.keys()}")
        with logger.contextualize(stream_id=stream_id):
            if stream_id in self._pipelines:
                raise ValueError(f"Pipeline for Stream ID {stream_id} already exists")

            self._pipelines[stream_id] = PipelineMetadata(stream_id, rtsp_url=rtsp_url)
            logger.info(f"Received add pipeline request, Running pipeline for stream {stream_id}")
            try:
                await self._run_pipeline(stream_id)
            except Exception as e:
                logger.error(f"Error while creating pipeline: {e}")
                # _run_pipeline already forgets the entry when nothing was started;
                # this stops whatever may have been started.
                await self.remove_pipeline(stream_id)
                raise ValueError(f"Error while creating pipeline: {e}") from e

    async def connect_websocket(self, stream_id: str, websocket: WebSocket):
        """Connect a websocket.

        Connects a websocket to the running pipeline or creates a new pipeline if
        it is not running for the given stream id. The method will wait until the websocket
        connection is closed by the client and only then return.

        In non-RTSP mode, a pipeline created by this call is removed once the wait ends,
        including when the wait fails or is cancelled.

        Args:
            stream_id: Unique identifier for the pipeline stream
            websocket: WebSocket connection for the pipeline

        Raises:
            ValueError: If RTSP is enabled and no pipeline exists for ``stream_id``, if the
                pipeline has already terminated or is not running yet, if its transport
                cannot take a new websocket, or if waiting for the websocket to close fails
        """
        with logger.contextualize(stream_id=stream_id):
            if stream_id not in self._pipelines:
                if self._enable_rtsp:
                    raise ValueError(f"Pipeline for Stream ID {stream_id} does not exist")
                # Non-RTSP mode: the websocket owns the pipeline's lifetime.
                self._pipelines[stream_id] = PipelineMetadata(stream_id, websocket=websocket)
                await self._run_pipeline(stream_id)
                try:
                    await self._wait_for_websocket_close(stream_id)
                finally:
                    # Tear the pipeline down even if waiting raised or was cancelled,
                    # otherwise it would keep running with no client attached.
                    await self.remove_pipeline(stream_id)
                return

            runner_task = self._pipelines[stream_id].runner_task
            if runner_task and runner_task.done():
                raise ValueError(f"Pipeline for Stream ID {stream_id} is already terminated")
            await self._update_websocket(stream_id, websocket)
            await self._wait_for_websocket_close(stream_id)

    async def remove_pipeline(self, stream_id: str):
        """Stop a pipeline and wait for its runner task to finish.

        Does nothing when no pipeline is registered for ``stream_id``. The metadata entry
        itself is deleted asynchronously by ``_cleanup_pipeline`` once the runner task
        completes, so it may still be present when this method returns.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        with logger.contextualize(stream_id=stream_id):
            # Hold a local reference: the cleanup task may delete the dict entry while we
            # are suspended in one of the awaits below.
            metadata = self._pipelines.get(stream_id)
            if metadata is None:
                return
            if metadata.pipeline_task is not None:
                try:
                    # Signal shutdown
                    await metadata.pipeline_task.stop_when_done()
                except Exception as e:
                    logger.error(f"Error while removing Pipeline: {e}")

                runner_task = metadata.runner_task
                if runner_task and not runner_task.done():
                    logger.info("Waiting for pipeline runner task to finish ...")
                    await runner_task
            logger.info(f"Pipeline for Stream ID {stream_id} removed")

    async def _cleanup_pipeline(self, stream_id: str):
        """Close the pipeline's websocket if still connected, then forget the pipeline.

        Scheduled through ``_schedule_cleanup`` when the pipeline's runner task finishes.
        Safe to call when the entry has already been removed.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        metadata = self._pipelines.get(stream_id)
        if metadata is None:
            return
        try:
            if metadata.websocket and metadata.websocket.client_state == WebSocketState.CONNECTED:
                await metadata.websocket.close()
        except Exception as e:
            logger.error(f"Error while closing websocket: {e}")
        self._pipelines.pop(stream_id, None)
        gc.collect()
        logger.info(f"Pipeline for Stream ID {stream_id} deleted")

    def _schedule_cleanup(self, stream_id: str):
        """Spawn ``_cleanup_pipeline`` and keep a strong reference to it until it completes.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        task = asyncio.create_task(self._cleanup_pipeline(stream_id))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _run_pipeline(self, stream_id: str):
        """Build the pipeline for ``stream_id`` and run it in the background.

        The metadata entry must already be registered in ``_pipelines``. If building or
        scheduling fails before the runner task exists, the entry is dropped, because no
        done-callback will ever clean it up. The exception is then re-raised.

        Args:
            stream_id: Unique identifier for the pipeline stream
        """
        metadata = self._pipelines[stream_id]
        try:
            metadata.pipeline_task = await self._pipelines_callback(metadata)
            metadata.pipeline_task.set_event_loop(asyncio.get_running_loop())
            metadata.runner_task = asyncio.create_task(metadata.pipeline_task.run())
            metadata.runner_task.add_done_callback(lambda _: self._schedule_cleanup(stream_id))
            logger.info(f"Pipeline started successfully for stream {stream_id}")
        except Exception as e:
            logger.error(f"Error while creating pipeline task: {e}")
            if metadata.runner_task is None and self._pipelines.get(stream_id) is metadata:
                del self._pipelines[stream_id]
            raise

    async def _update_websocket(self, stream_id: str, websocket: WebSocket):
        """Attach a new websocket to an already running pipeline.

        The websocket is stored on the metadata and handed to the transport of the first
        input/output transport processor found in the pipeline.

        Args:
            stream_id: Unique identifier for the pipeline stream
            websocket: WebSocket connection for the pipeline

        Raises:
            ValueError: If the pipeline task has not been created yet, or if the first
                transport processor does not support updating its websocket
        """
        metadata = self._pipelines[stream_id]
        if metadata.pipeline_task is None:
            # add_pipeline registers the entry before the pipeline callback returns.
            raise ValueError(f"Pipeline for Stream ID {stream_id} is not running yet")
        metadata.websocket = websocket
        pipeline = metadata.pipeline_task._pipeline
        for component in pipeline._processors:
            if isinstance(component, BaseInputTransport | BaseOutputTransport):
                if hasattr(component._transport, "update_websocket"):
                    await component._transport.update_websocket(websocket)
                    logger.info(f"Websocket for Stream ID {stream_id} updated")
                    return
                else:
                    raise ValueError(f"Component {component.__class__.__name__} doesn't support updating websocket.")

    async def _wait_for_websocket_close(self, stream_id: str):
        """Wait for the websocket to close. This is used to keep connection alive.

        Polls every 0.1 s until the pipeline entry is gone or the stored websocket is no
        longer connected. This checks whichever websocket is currently stored on the metadata,
        and a later ``connect_websocket`` call may have replaced it.

        Args:
            stream_id: Unique identifier for the pipeline stream

        Raises:
            ValueError: If inspecting the websocket state fails
        """
        try:
            # Wait until websocket is closed
            while (
                stream_id in self._pipelines
                and self._pipelines[stream_id].websocket.client_state == WebSocketState.CONNECTED
            ):
                await asyncio.sleep(0.1)
            logger.info(f"Websocket for Stream ID {stream_id} closed")
        except Exception as e:
            raise ValueError(f"Error while waiting for websocket close: {e}") from e