# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
"""Riva speech LangChain bot: a speech-to-speech pipeline built on Riva ASR/TTS and a LangChain LLM chain."""

import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
    LLMAssistantResponseAggregator,
    LLMUserResponseAggregator,
)
from pipecat.processors.frameworks.langchain import LangchainProcessor

from nvidia_pipecat.pipeline.ace_pipeline_runner import ACEPipelineRunner, PipelineMetadata
from nvidia_pipecat.services.riva_speech import (
    RivaASRService,
    RivaTTSService,
)
from nvidia_pipecat.transports.network.ace_fastapi_websocket import (
    ACETransport,
    ACETransportParams,
)
from nvidia_pipecat.transports.services.ace_controller.routers.websocket_router import router as websocket_router
from nvidia_pipecat.utils.logging import setup_default_ace_logging

load_dotenv(override=True)
setup_default_ace_logging(level="INFO")

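# In-memory chat histories: one ChatMessageHistory per session/stream id.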
message_store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Get the chat message history for a session, creating it on first use."""
    if session_id not in message_store:
        message_store[session_id] = ChatMessageHistory()
    return message_store[session_id]
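
# Each session id maps to the same ChatMessageHistory across calls, so
# RunnableWithMessageHistory accumulates conversation turns per stream.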


async def create_pipeline_task(pipeline_metadata: PipelineMetadata):
    """Create the pipeline task to be run.

    Args:
        pipeline_metadata (PipelineMetadata): Metadata containing the client
            websocket and other pipeline configuration.

    Returns:
        PipelineTask: The configured pipeline task for handling
            speech-to-speech conversation.
    """
    transport = ACETransport(
        websocket=pipeline_metadata.websocket,
        params=ACETransportParams(
            vad_analyzer=SileroVADAnalyzer(),
        ),
    )
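
    # Riva speech-to-text over gRPC. Assumes a local Riva server at
    # localhost:50051; NVIDIA_API_KEY matters mainly when targeting a hosted
    # NVIDIA endpoint instead.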
    stt = RivaASRService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        language="en-US",
        sample_rate=16000,
        model="parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble",
    )
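
    # Riva text-to-speech on the same server: FastPitch + HiFi-GAN with a
    # 16 kHz US English female voice.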
    tts = RivaTTSService(
        server="localhost:50051",
        api_key=os.getenv("NVIDIA_API_KEY"),
        voice_id="English-US.Female-1",
        language="en-US",
        zero_shot_quality=20,
        sample_rate=16000,
        model="fastpitch-hifigan-tts",
    )
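
    # The system prompt keeps replies brief and free of markdown symbols, since
    # everything the LLM emits is spoken aloud by the TTS stage.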
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Be nice and helpful. Answer very briefly and without special characters like `#` or `*`. "
                "Your response will be synthesized to voice and those characters will create unnatural sounds.",
            ),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
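
    # Pipe the prompt into GPT-4o, wrap it with per-session history, and bridge
    # it into pipecat with LangchainProcessor so transcribed user turns invoke
    # the chain.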
    chain = prompt | ChatOpenAI(model="gpt-4o", temperature=0.7)
    history_chain = RunnableWithMessageHistory(
        chain,
        get_session_history,
        history_messages_key="chat_history",
        input_messages_key="input",
    )
    lc = LangchainProcessor(history_chain)

    tma_in = LLMUserResponseAggregator()
    tma_out = LLMAssistantResponseAggregator()
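
    # Frames flow top to bottom. Placing the assistant aggregator after
    # transport.output() is the usual pipecat arrangement: assistant text is
    # committed to history only once it has been sent to the client.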
    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-to-text (Riva ASR)
            tma_in,  # Aggregates user transcripts into LLM messages
            lc,  # LangChain processor
            tts,  # Text-to-speech (Riva TTS)
            transport.output(),  # Websocket output to client
            tma_out,  # Aggregates assistant responses into the context
        ]
    )
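
    # allow_interruptions lets the user barge in over TTS playback; the metrics
    # flags emit TTFB and usage metrics tagged with this stream's id.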
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            send_initial_empty_metrics=True,
            report_only_initial_ttfb=True,
            start_metadata={"stream_id": pipeline_metadata.stream_id},
        ),
    )
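
    # Greet the user as soon as the client connects by queueing a first LLM turn.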
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Kick off the conversation.
        messages = [{"content": "Please briefly introduce yourself to the user."}]
        await task.queue_frames([LLMMessagesFrame(messages)])

    return task


app = FastAPI()
app.include_router(websocket_router)
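
# The runner builds one pipeline per incoming stream by invoking
# create_pipeline_task with that stream's metadata.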
runner = ACEPipelineRunner.create_instance(pipeline_callback=create_pipeline_task)
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "../static")), name="static")
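
# To try it out with the defaults above: start a local Riva server on
# localhost:50051, set NVIDIA_API_KEY and OPENAI_API_KEY (e.g. in .env), run
# `python bot.py`, and connect a client to the websocket route exposed by
# websocket_router on port 8100.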

if __name__ == "__main__":
    uvicorn.run("bot:app", host="0.0.0.0", port=8100, workers=1)