# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
"""Speech-to-speech conversation bot."""
import os
from pathlib import Path
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata
# from nvidia_pipecat.processors.nvidia_context_aggregator import (
# NvidiaTTSResponseCacher,
# create_nvidia_context_aggregator,
# )
from nvidia_pipecat.processors.transcript_synchronization import (
BotTranscriptSynchronization,
UserTranscriptSynchronization,
)
from nvidia_pipecat.services.blingfire_text_aggregator import BlingfireTextAggregator
from nvidia_pipecat.services.nvidia_llm import NvidiaLLMService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
from nvidia_pipecat.transports.network.ace_fastapi_websocket import ACETransport, ACETransportParams
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging
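
# Load settings from a local .env file (override=True lets .env values replace
# variables already exported in the shell). Environment variables read below,
# shown with illustrative placeholder values:
#   NVIDIA_API_KEY=nvapi-...                # API key used by the LLM, ASR, and TTS services
#   NVIDIA_LLM_URL / NVIDIA_LLM_MODEL       # OpenAI-compatible LLM endpoint and model name
#   RIVA_ASR_URL / RIVA_TTS_URL             # Riva gRPC endpoints (host:port)
#   RIVA_ASR_LANGUAGE / RIVA_TTS_LANGUAGE   # language codes, e.g. en-US
#   ZERO_SHOT_AUDIO_PROMPT                  # optional audio prompt for zero-shot TTS voices
#   STATIC_DIR                              # directory with the static web UI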
load_dotenv(override=True)
setup_default_ace_logging(level="DEBUG")


async def create_pipeline_task(pipeline_metadata: PipelineMetadata):
"""Create the pipeline to be run.
Args:
pipeline_metadata (PipelineMetadata): Metadata containing websocket and other pipeline configuration.
Returns:
PipelineTask: The configured pipeline task for handling speech-to-speech conversation.
"""
transport = ACETransport(
websocket=pipeline_metadata.websocket,
params=ACETransportParams(
vad_analyzer=SileroVADAnalyzer(),
audio_out_10ms_chunks=20,
),
)
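
    # LLM behind an OpenAI-compatible API; defaults to NVIDIA's hosted endpoint,
    # but NVIDIA_LLM_URL can point at a self-hosted deployment instead.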
llm = NvidiaLLMService(
api_key=os.getenv("NVIDIA_API_KEY"),
base_url=os.getenv("NVIDIA_LLM_URL", "https://integrate.api.nvidia.com/v1"),
model=os.getenv("NVIDIA_LLM_MODEL", "meta/llama-3.1-8b-instruct"),
)
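
    # Streaming speech-to-text via Riva ASR, expecting 16 kHz input audio.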
stt = RivaASRService(
server=os.getenv("RIVA_ASR_URL", "localhost:50051"),
api_key=os.getenv("NVIDIA_API_KEY"),
language=os.getenv("RIVA_ASR_LANGUAGE", "en-US"),
sample_rate=16000,
model=os.getenv("RIVA_ASR_MODEL", "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble"),
)
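
    # Riva text-to-speech with a Magpie multilingual voice. If ZERO_SHOT_AUDIO_PROMPT
    # is set, the referenced audio file is used as a zero-shot voice prompt. The
    # Blingfire aggregator chunks streamed LLM text into sentences so synthesis can
    # begin before the full response is available.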
tts = RivaTTSService(
server=os.getenv("RIVA_TTS_URL", "localhost:50051"),
api_key=os.getenv("NVIDIA_API_KEY"),
voice_id=os.getenv("RIVA_TTS_VOICE_ID", "Magpie-Multilingual.EN-US.Sofia"),
model=os.getenv("RIVA_TTS_MODEL", "magpie_tts_ensemble-Magpie-Multilingual"),
language=os.getenv("RIVA_TTS_LANGUAGE", "en-US"),
zero_shot_audio_prompt_file=(
Path(os.getenv("ZERO_SHOT_AUDIO_PROMPT")) if os.getenv("ZERO_SHOT_AUDIO_PROMPT") else None
),
text_aggregator=BlingfireTextAggregator(),
)
# Used to synchronize the user and bot transcripts in the UI
stt_transcript_synchronization = UserTranscriptSynchronization()
tts_transcript_synchronization = BotTranscriptSynchronization()
# System prompt can be changed to fit the use case
messages = [
{
"role": "system",
"content": (
"### CONVERSATION CONSTRAINTS\n"
"STRICTLY answer in 1-2 sentences or less than 200 characters. "
"This must be followed very rigorously; it is crucial.\n"
"Output must be plain text, unformatted, and without any special characters - "
"suitable for direct conversion to speech.\n"
"DO NOT use bullet points, lists, code samples, or headers in your spoken responses.\n"
"STRICTLY be short, concise, and to the point. Avoid elaboration, explanation, or repetition.\n"
"Pronounce numbers, dates, and special terms. For phone numbers, read digits slowly and separately. "
"For times, use natural phrasing like 'seven o'clock a.m.' instead of 'seven zero zero.'\n"
"Silently correct likely transcription errors by inferring the intended meaning without saying "
"`did you mean..` or `I think you meant..`. "
"Prioritize what the user meant, not just the literal words.\n"
"### OPENING PROTOCOL\n"
"STRICTLY START CONVERSATION WITH 'Thank you for calling GreenForce Garden. "
"What can I do for you today?'\n"
"### CLOSING PROTOCOL\n"
"End with either 'Have a green day!' or 'Have a good one.' Use one consistently per call.\n"
"### YOU ARE ...\n"
"You are Flora, the voice of 'GreenForce Garden', a San Francisco flower shop "
"powered by NVIDIA GPUs.\n"
"You're cool, upbeat, and love making people smile with your floral know-how.\n"
"You embody warmth, expertise, and dedication to creating a perfect floral experience.\n"
"### CONVERSATION GUIDELINES\n"
"CORE RESPONSIBILITIES - Order Management, Consultation, Inventory Guidance, "
"Delivery Coordination, Customer Care, Giving Fun Advice\n"
"While taking orders, have occasion understanding, ask for recipient details, "
"customer preferences, and delivery planning\n"
"SUGGEST cards with personal messages\n"
"SUGGEST seasonal recommendations (e.g., spring: tulips, pastels; romance: roses, peonies) "
"and occasion-specific details (e.g., elegant wrapping).\n"
"SUGGEST complementary items: vases, chocolates, cards. "
"Also provide care instructions for long-lasting enjoyment.\n"
"STRICTLY Confirm all order details before finalizing: flowers, colors, "
"delivery address, timing\n"
"STRICTLY Collect complete contact information for order updates\n"
"STRICTLY Provide ORDER CONFIRMATION with ESTIMATED DELIVERY TIMES\n"
"OFFER MULTIPLE PAYMENT OPTIONS (e.g., card, cash, online) and confirm SECURE PROCESSING.\n"
"STRICTLY If you are unsure about a request, ask clarifying questions "
"to ensure you understand before responding."
),
},
]
context = OpenAILLMContext(messages)
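
    # Speculative speech processing (disabled by default) sends interim ASR
    # transcripts into the context so the LLM can start responding earlier; the TTS
    # response cacher then holds synthesized audio until the transcript is final.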
    # Comment out the line below when enabling speculative speech processing
    context_aggregator = llm.create_context_aggregator(context)
    # Uncomment the two lines below to enable speculative speech processing
    # nvidia_context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
    # nvidia_tts_response_cacher = NvidiaTTSResponseCacher()
pipeline = Pipeline(
[
transport.input(), # Websocket input from client
stt, # Speech-To-Text
stt_transcript_synchronization,
            # Comment out the line below when enabling speculative speech processing
context_aggregator.user(),
# Uncomment the below line to enable speculative speech processing
# nvidia_context_aggregator.user(),
llm, # LLM
tts, # Text-To-Speech
# Caches TTS responses for coordinated delivery in speculative
# speech processing
# nvidia_tts_response_cacher, # Uncomment to enable speculative speech processing
tts_transcript_synchronization,
transport.output(), # Websocket output to client
context_aggregator.assistant(),
# Uncomment the below line to enable speculative speech processing
# nvidia_context_aggregator.assistant(),
]
)
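
    # Interruptions let the user barge in over bot speech; metrics are tagged with
    # the stream id via start_metadata.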
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
send_initial_empty_metrics=True,
start_metadata={"stream_id": pipeline_metadata.stream_id},
),
)
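
    # The transport fires "on_client_connected" once the websocket client joins;
    # queueing an LLMMessagesFrame here triggers the bot's opening line.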
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMMessagesFrame(messages)])
return task
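

# FastAPI application: the websocket router exposes the ACE endpoint, the runner
# builds one pipeline task per connected stream via create_pipeline_task, and the
# static web UI is served under /static.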
app = FastAPI()
app.include_router(websocket_router)
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)
app.mount("/static", StaticFiles(directory=os.getenv("STATIC_DIR", "../static")), name="static")
if __name__ == "__main__":
uvicorn.run("bot:app", host="0.0.0.0", port=8100, workers=4)