# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""NVIDIA RAG bot."""
import os
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata
from nvidia_pipecat.processors.nvidia_context_aggregator import (
    # NvidiaTTSResponseCacher,  # Uncomment to enable speculative speech processing
    create_nvidia_context_aggregator,
)
from nvidia_pipecat.processors.transcript_synchronization import (
    BotTranscriptSynchronization,
    UserTranscriptSynchronization,
)
from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
from nvidia_pipecat.transports.network.ace_fastapi_websocket import (
    ACETransport,
    ACETransportParams,
)
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging

load_dotenv(override=True)
setup_default_ace_logging(level="INFO")


async def create_pipeline_task(pipeline_metadata: PipelineMetadata):
    """Create the pipeline to be run.

    Args:
        pipeline_metadata (PipelineMetadata): Metadata containing websocket and other pipeline configuration.

    Returns:
        PipelineTask: The configured pipeline task for handling NVIDIA RAG.
    """
    transport = ACETransport(
        websocket=pipeline_metadata.websocket,
        params=ACETransportParams(
            vad_analyzer=SileroVADAnalyzer(),
        ),
    )
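    # The transport above wraps the client websocket from the pipeline metadata and runs
    # Silero VAD on incoming audio to detect when the user is speaking.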

    # Set the name of your NVIDIA RAG collection here.
    rag = NvidiaRAGService(collection_name="nvidia_blogs")
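    # "nvidia_blogs" above is an example name; the collection is assumed to already exist
    # (i.e., documents have been ingested) in the NVIDIA RAG server the service talks to.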
    stt = RivaASRService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        language="en-US",
        sample_rate=16000,
        model="parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble",
    )
    tts = RivaTTSService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        voice_id="English-US.Female-1",
        language="en-US",
        zero_shot_quality=20,
        sample_rate=16000,
        model="fastpitch-hifigan-tts",
    )
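    # Both Riva services above point at a local Riva Speech server on localhost:50051 and
    # read NVIDIA_API_KEY from the environment (loaded by load_dotenv).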
    messages = [
        {
            "role": "system",
            "content": "You are a helpful Large Language Model. "
            "Your goal is to demonstrate your capabilities in a succinct way. "
            "Your output will be converted to audio so don't include special characters in your answers. "
            "Respond to what the user said in a creative and helpful way.",
        }
    ]
    context = OpenAILLMContext(messages)
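    # The system prompt above seeds the OpenAI-style LLM context that the context
    # aggregators below keep in sync with the conversation.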
    # Components required for speculative speech processing:
    # - NVIDIA context aggregator: handles interim transcripts and early response generation.
    #   send_interims=False processes only final transcripts; set send_interims=True to
    #   process interim transcripts when enabling speculative speech processing.
    nvidia_context_aggregator = create_nvidia_context_aggregator(context, send_interims=False)
    # - TTS response cacher: manages response timing and delivery for natural conversation flow.
    # nvidia_tts_response_cacher = NvidiaTTSResponseCacher()  # Uncomment to enable speculative speech processing

    # Used to synchronize the user and bot transcripts in the UI.
    stt_transcript_synchronization = UserTranscriptSynchronization()
    tts_transcript_synchronization = BotTranscriptSynchronization()
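    # To enable speculative speech processing end to end: set send_interims=True above,
    # uncomment the NvidiaTTSResponseCacher import and instantiation, and uncomment the
    # nvidia_tts_response_cacher entry in the pipeline below.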
    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            stt_transcript_synchronization,
            nvidia_context_aggregator.user(),
            rag,  # NVIDIA RAG
            tts,  # Text-To-Speech
            # Caches TTS responses for coordinated delivery in speculative speech processing.
            # nvidia_tts_response_cacher,  # Uncomment to enable speculative speech processing
            tts_transcript_synchronization,
            transport.output(),  # Websocket output to client
            nvidia_context_aggregator.assistant(),
        ]
    )
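    # Frames flow through the processors above in list order: client audio is transcribed,
    # routed through RAG, synthesized to speech, and sent back over the websocket. Placing
    # the assistant context aggregator after transport output (a common Pipecat pattern)
    # keeps the context aligned with what was actually delivered to the client.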
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            send_initial_empty_metrics=True,
            report_only_initial_ttfb=True,
            start_metadata={"stream_id": pipeline_metadata.stream_id},
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Kick off the conversation.
        messages.append({"role": "user", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMMessagesFrame(messages)])

    return task


app = FastAPI()
app.include_router(websocket_router)
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "../static")), name="static")
if __name__ == "__main__":
uvicorn.run("bot:app", host="0.0.0.0", port=8100, workers=1)