# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
"""NVIDIA RAG bot."""
import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata
from nvidia_pipecat.processors.nvidia_context_aggregator import (
# NvidiaTTSResponseCacher, # Uncomment to enable speculative speech processing
create_nvidia_context_aggregator,
)
from nvidia_pipecat.processors.transcript_synchronization import (
BotTranscriptSynchronization,
UserTranscriptSynchronization,
)
from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
from nvidia_pipecat.transports.network.ace_fastapi_websocket import (
ACETransport,
ACETransportParams,
)
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging
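
# load_dotenv below picks up NVIDIA_API_KEY (used by the Riva services in this
# file) from a local .env file, if one is present.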
load_dotenv(override=True)
setup_default_ace_logging(level="INFO")


async def create_pipeline_task(pipeline_metadata: PipelineMetadata):
"""Create the pipeline to be run.
Args:
pipeline_metadata (PipelineMetadata): Metadata containing websocket and other pipeline configuration.
Returns:
PipelineTask: The configured pipeline task for handling NVIDIA RAG.
"""
transport = ACETransport(
websocket=pipeline_metadata.websocket,
params=ACETransportParams(
vad_analyzer=SileroVADAnalyzer(),
),
)
    # Set the name of your NVIDIA RAG collection here.
rag = NvidiaRAGService(collection_name="nvidia_blogs")
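    # Riva ASR streaming speech recognition; the server and model values below
    # assume a local Riva deployment listening on port 50051.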
stt = RivaASRService(
server="localhost:50051",
api_key=os.getenv("NVIDIA_API_KEY"),
language="en-US",
sample_rate=16000,
model="parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble",
)
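    # Riva TTS synthesizes the bot's reply as 16 kHz audio using a prebuilt
    # English-US voice.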
tts = RivaTTSService(
server="localhost:50051",
api_key=os.getenv("NVIDIA_API_KEY"),
voice_id="English-US.Female-1",
language="en-US",
zero_shot_quality=20,
sample_rate=16000,
model="fastpitch-hifigan-tts",
)
messages = [
{
"role": "system",
"content": "You are a helpful Large Language Model. "
"Your goal is to demonstrate your capabilities in a succinct way. "
"Your output will be converted to audio so don't include special characters in your answers. "
"Respond to what the user said in a creative and helpful way.",
}
]
context = OpenAILLMContext(messages)
    # Components for speculative speech processing:
    # - NVIDIA context aggregator: handles interim transcripts and early response
    #   generation. send_interims=False processes only final transcripts; set it
    #   to True to process interim transcripts when enabling speculative speech
    #   processing.
    nvidia_context_aggregator = create_nvidia_context_aggregator(context, send_interims=False)
    # - TTS response cacher: manages response timing and delivery for a natural
    #   conversation flow.
    # nvidia_tts_response_cacher = NvidiaTTSResponseCacher()  # Uncomment to enable speculative speech processing
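    # To enable speculative speech processing end to end (a sketch assembled
    # from the "Uncomment" comments in this file, not a separately documented
    # recipe):
    #   1. uncomment the NvidiaTTSResponseCacher import at the top of the file,
    #   2. pass send_interims=True to create_nvidia_context_aggregator above,
    #   3. uncomment nvidia_tts_response_cacher here and in the pipeline below.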
# Used to synchronize the user and bot transcripts in the UI
stt_transcript_synchronization = UserTranscriptSynchronization()
tts_transcript_synchronization = BotTranscriptSynchronization()
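    # Frames flow through the processors below in order: client audio in,
    # transcription, context aggregation, RAG response generation, speech
    # synthesis, and audio back out to the client.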
pipeline = Pipeline(
[
transport.input(), # Websocket input from client
stt, # Speech-To-Text
stt_transcript_synchronization,
nvidia_context_aggregator.user(),
rag, # NVIDIA RAG
tts, # Text-To-Speech
# Caches TTS responses for coordinated delivery in speculative
# speech processing
# nvidia_tts_response_cacher, # Uncomment to enable speculative speech processing
tts_transcript_synchronization,
transport.output(), # Websocket output to client
nvidia_context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
send_initial_empty_metrics=True,
report_only_initial_ttfb=True,
start_metadata={"stream_id": pipeline_metadata.stream_id},
),
)

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Kick off the conversation.
        messages.append({"role": "user", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMMessagesFrame(messages)])

    return task


app = FastAPI()
app.include_router(websocket_router)
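# The runner calls create_pipeline_task for each new client stream routed in
# through the websocket router included above.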
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "../static")), name="static")
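

# Run directly (python bot.py) or with uvicorn: `uvicorn bot:app --port 8100`.
# Either way assumes a reachable Riva server and the RAG collection configured
# above.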
if __name__ == "__main__":
uvicorn.run("bot:app", host="0.0.0.0", port=8100, workers=1)