# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""NVIDIA RAG bot.

FastAPI + pipecat application that wires Riva ASR/TTS and the NVIDIA RAG
service into a realtime voice pipeline served over a websocket.
"""

import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata
from nvidia_pipecat.processors.nvidia_context_aggregator import (
    # NvidiaTTSResponseCacher,  # Uncomment to enable speculative speech processing
    create_nvidia_context_aggregator,
)
from nvidia_pipecat.processors.transcript_synchronization import (
    BotTranscriptSynchronization,
    UserTranscriptSynchronization,
)
from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
from nvidia_pipecat.transports.network.ace_fastapi_websocket import (
    ACETransport,
    ACETransportParams,
)
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging

# Pull configuration (e.g. NVIDIA_API_KEY) from a local .env file; override=True
# lets the .env value win over anything already set in the process environment.
load_dotenv(override=True)
setup_default_ace_logging(level="INFO")
async def create_pipeline_task(pipeline_metadata: PipelineMetadata):
    """Create the pipeline to be run.

    Args:
        pipeline_metadata (PipelineMetadata): Metadata containing websocket and other pipeline configuration.

    Returns:
        PipelineTask: The configured pipeline task for handling NVIDIA RAG.
    """
    transport = ACETransport(
        websocket=pipeline_metadata.websocket,
        params=ACETransportParams(
            vad_analyzer=SileroVADAnalyzer(),
        ),
    )

    # Please set your nvidia rag collection name here
    rag = NvidiaRAGService(collection_name="nvidia_blogs")

    # Riva speech services; both point at a local Riva server and expect
    # NVIDIA_API_KEY in the environment (loaded from .env at module import).
    stt = RivaASRService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        language="en-US",
        sample_rate=16000,
        model="parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble",
    )
    tts = RivaTTSService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        voice_id="English-US.Female-1",
        language="en-US",
        zero_shot_quality=20,
        sample_rate=16000,
        model="fastpitch-hifigan-tts",
    )

    messages = [
        {
            "role": "system",
            "content": "You are a helpful Large Language Model. "
            "Your goal is to demonstrate your capabilities in a succinct way. "
            "Your output will be converted to audio so don't include special characters in your answers. "
            "Respond to what the user said in a creative and helpful way.",
        }
    ]
    context = OpenAILLMContext(messages)

    # Required components for Speculative Speech Processing
    # - Nvidia Context aggregator: Handles interim transcripts and early response generation
    #   send_interims=False: Only process final transcripts
    #   Set send_interims=True to process interim transcripts when enabling speculative speech processing
    nvidia_context_aggregator = create_nvidia_context_aggregator(context, send_interims=False)
    # - TTS response cacher: Manages response timing and delivery for natural conversation flow
    # nvidia_tts_response_cacher = NvidiaTTSResponseCacher()  # Uncomment to enable speculative speech processing

    # Used to synchronize the user and bot transcripts in the UI
    stt_transcript_synchronization = UserTranscriptSynchronization()
    tts_transcript_synchronization = BotTranscriptSynchronization()

    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            stt_transcript_synchronization,
            nvidia_context_aggregator.user(),
            rag,  # NVIDIA RAG
            tts,  # Text-To-Speech
            # Caches TTS responses for coordinated delivery in speculative
            # speech processing
            # nvidia_tts_response_cacher,  # Uncomment to enable speculative speech processing
            tts_transcript_synchronization,
            transport.output(),  # Websocket output to client
            nvidia_context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            send_initial_empty_metrics=True,
            report_only_initial_ttfb=True,
            start_metadata={"stream_id": pipeline_metadata.stream_id},
        ),
    )

    # FIX: the handler below was defined but never registered, so the bot
    # would never introduce itself. Register it on the transport's
    # "on_client_connected" event (standard pipecat idiom).
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Kick off the conversation.
        messages.append({"role": "user", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMMessagesFrame(messages)])

    return task
app = FastAPI()
app.include_router(websocket_router)

# Singleton runner: invokes create_pipeline_task for each incoming websocket
# stream routed through websocket_router.
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)

# Serve the web client assets from the sibling "static" directory.
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "../static")), name="static")

if __name__ == "__main__":
    # Single worker: the runner holds in-process per-stream state, so the app
    # is not safe to scale by adding uvicorn workers.
    uvicorn.run("bot:app", host="0.0.0.0", port=8100, workers=1)