File size: 9,224 Bytes
53ea588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License

"""Speech-to-speech conversation bot."""

import os
from pathlib import Path

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata

# from nvidia_pipecat.processors.nvidia_context_aggregator import (
#     NvidiaTTSResponseCacher,
#     create_nvidia_context_aggregator,
# )
from nvidia_pipecat.processors.transcript_synchronization import (
    BotTranscriptSynchronization,
    UserTranscriptSynchronization,
)
from nvidia_pipecat.services.blingfire_text_aggregator import BlingfireTextAggregator
from nvidia_pipecat.services.nvidia_llm import NvidiaLLMService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
from nvidia_pipecat.transports.network.ace_fastapi_websocket import ACETransport, ACETransportParams
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging

# Pull settings from a local .env file first, letting them take precedence over
# variables already present in the process environment. This runs before the
# os.getenv() lookups further down in this module.
load_dotenv(override=True)
# Configure the ACE default logging setup at DEBUG verbosity.
setup_default_ace_logging(level="DEBUG")


# System prompt can be changed to fit the use case.
# Kept at module level so the pipeline wiring below stays readable; a fresh
# message list is still built for every pipeline in create_pipeline_task().
_SYSTEM_PROMPT = (
    "### CONVERSATION CONSTRAINTS\n"
    "STRICTLY answer in 1-2 sentences or less than 200 characters. "
    "This must be followed very rigorously; it is crucial.\n"
    "Output must be plain text, unformatted, and without any special characters - "
    "suitable for direct conversion to speech.\n"
    "DO NOT use bullet points, lists, code samples, or headers in your spoken responses.\n"
    "STRICTLY be short, concise, and to the point. Avoid elaboration, explanation, or repetition.\n"
    "Pronounce numbers, dates, and special terms. For phone numbers, read digits slowly and separately. "
    "For times, use natural phrasing like 'seven o'clock a.m.' instead of 'seven zero zero.'\n"
    "Silently correct likely transcription errors by inferring the intended meaning without saying "
    "`did you mean..` or `I think you meant..`. "
    "Prioritize what the user meant, not just the literal words.\n"
    "### OPENING PROTOCOL\n"
    "STRICTLY START CONVERSATION WITH 'Thank you for calling GreenForce Garden. "
    "What can I do for you today?'\n"
    "### CLOSING PROTOCOL\n"
    "End with either 'Have a green day!' or 'Have a good one.' Use one consistently per call.\n"
    "### YOU ARE ...\n"
    "You are Flora, the voice of 'GreenForce Garden', a San Francisco flower shop "
    "powered by NVIDIA GPUs.\n"
    "You're cool, upbeat, and love making people smile with your floral know-how.\n"
    "You embody warmth, expertise, and dedication to creating a perfect floral experience.\n"
    "### CONVERSATION GUIDELINES\n"
    "CORE RESPONSIBILITIES - Order Management, Consultation, Inventory Guidance, "
    "Delivery Coordination, Customer Care, Giving Fun Advice\n"
    "While taking orders, have occasion understanding, ask for recipient details, "
    "customer preferences, and delivery planning\n"
    "SUGGEST cards with personal messages\n"
    "SUGGEST seasonal recommendations (e.g., spring: tulips, pastels; romance: roses, peonies) "
    "and occasion-specific details (e.g., elegant wrapping).\n"
    "SUGGEST complementary items: vases, chocolates, cards. "
    "Also provide care instructions for long-lasting enjoyment.\n"
    "STRICTLY Confirm all order details before finalizing: flowers, colors, "
    "delivery address, timing\n"
    "STRICTLY Collect complete contact information for order updates\n"
    "STRICTLY Provide ORDER CONFIRMATION with ESTIMATED DELIVERY TIMES\n"
    "OFFER MULTIPLE PAYMENT OPTIONS (e.g., card, cash, online) and confirm SECURE PROCESSING.\n"
    "STRICTLY If you are unsure about a request, ask clarifying questions "
    "to ensure you understand before responding."
)


async def create_pipeline_task(pipeline_metadata: PipelineMetadata) -> PipelineTask:
    """Create the pipeline to be run.

    Builds a websocket transport, Riva ASR and TTS, and an NVIDIA LLM service.
    It then wires them into a single pipeline: input -> STT -> user context ->
    LLM -> TTS -> output -> assistant context.

    Args:
        pipeline_metadata (PipelineMetadata): Metadata containing the websocket and the
            stream id used for this pipeline.

    Returns:
        PipelineTask: The configured pipeline task for handling speech-to-speech conversation.
    """
    # Read each environment value once. The same key is passed to the LLM,
    # ASR and TTS services below.
    api_key = os.getenv("NVIDIA_API_KEY")
    zero_shot_audio_prompt = os.getenv("ZERO_SHOT_AUDIO_PROMPT")

    transport = ACETransport(
        websocket=pipeline_metadata.websocket,
        params=ACETransportParams(
            vad_analyzer=SileroVADAnalyzer(),
            audio_out_10ms_chunks=20,
        ),
    )

    llm = NvidiaLLMService(
        api_key=api_key,
        base_url=os.getenv("NVIDIA_LLM_URL", "https://integrate.api.nvidia.com/v1"),
        model=os.getenv("NVIDIA_LLM_MODEL", "meta/llama-3.1-8b-instruct"),
    )

    stt = RivaASRService(
        server=os.getenv("RIVA_ASR_URL", "localhost:50051"),
        api_key=api_key,
        language=os.getenv("RIVA_ASR_LANGUAGE", "en-US"),
        sample_rate=16000,
        model=os.getenv("RIVA_ASR_MODEL", "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble"),
    )

    tts = RivaTTSService(
        server=os.getenv("RIVA_TTS_URL", "localhost:50051"),
        api_key=api_key,
        voice_id=os.getenv("RIVA_TTS_VOICE_ID", "Magpie-Multilingual.EN-US.Sofia"),
        model=os.getenv("RIVA_TTS_MODEL", "magpie_tts_ensemble-Magpie-Multilingual"),
        language=os.getenv("RIVA_TTS_LANGUAGE", "en-US"),
        # An unset or empty variable means "no zero-shot audio prompt".
        zero_shot_audio_prompt_file=Path(zero_shot_audio_prompt) if zero_shot_audio_prompt else None,
        text_aggregator=BlingfireTextAggregator(),
    )

    # Used to synchronize the user and bot transcripts in the UI
    stt_transcript_synchronization = UserTranscriptSynchronization()
    tts_transcript_synchronization = BotTranscriptSynchronization()

    # A new list per pipeline: on_client_connected appends to it, so it must
    # not be shared between pipelines.
    messages = [{"role": "system", "content": _SYSTEM_PROMPT}]

    context = OpenAILLMContext(messages)

    # Comment out the below line when enabling Speculative Speech Processing
    context_aggregator = llm.create_context_aggregator(context)

    # Uncomment the below line to enable speculative speech processing
    # nvidia_context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
    # Uncomment the below line to enable speculative speech processing
    # nvidia_tts_response_cacher = NvidiaTTSResponseCacher()

    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            stt_transcript_synchronization,
            # Comment out the below line when enabling Speculative Speech Processing
            context_aggregator.user(),
            # Uncomment the below line to enable speculative speech processing
            # nvidia_context_aggregator.user(),
            llm,  # LLM
            tts,  # Text-To-Speech
            # Caches TTS responses for coordinated delivery in speculative
            # speech processing
            # nvidia_tts_response_cacher,  # Uncomment to enable speculative speech processing
            tts_transcript_synchronization,
            transport.output(),  # Websocket output to client
            context_aggregator.assistant(),
            # Uncomment the below line to enable speculative speech processing
            # nvidia_context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            send_initial_empty_metrics=True,
            start_metadata={"stream_id": pipeline_metadata.stream_id},
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Kick off the conversation by asking the LLM to introduce itself.
        # NOTE(review): this appends once per event; if the event can fire more
        # than once per transport, the instruction accumulates. Confirm.
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMMessagesFrame(messages)])

    return task


app = FastAPI()
app.include_router(websocket_router)
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)

# StaticFiles checks that the directory exists when it is constructed. The
# default is therefore resolved relative to this file rather than the current
# working directory, so the server can be started from anywhere. An explicit
# STATIC_DIR is used exactly as given.
_default_static_dir = Path(__file__).resolve().parent.parent / "static"
app.mount("/static", StaticFiles(directory=os.getenv("STATIC_DIR", str(_default_static_dir))), name="static")

if __name__ == "__main__":
    # uvicorn requires an import string (not the app object) when workers > 1.
    # Derive the module name from this file so renaming it does not break startup.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=8100, workers=4)