feat: Update Dockerfile and requirements.txt to resolve PyAudio build issues
The main changes are:
1. Added `build-essential` and `libasound2-dev` to the system dependencies in the Dockerfile to ensure the necessary build tools are available.
2. Removed PyAudio from the `requirements.txt` file to avoid the pip installation issues.
3. Added a separate `RUN pip install PyAudio==0.2.14` command in the Dockerfile to install PyAudio manually.
These changes should resolve the build issues with PyAudio on the CUDA server.
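A quick way to confirm the manual install worked inside the built image is to import PyAudio and list the devices PortAudio can see. This is a hedged sketch rather than part of the commit; in a headless container the device count may simply be 0:

```python
# Sanity check that the manually installed PyAudio wheel links against the
# system audio libraries installed above (sketch, not part of the commit).
import pyaudio

pa = pyaudio.PyAudio()
print("devices visible to PortAudio:", pa.get_device_count())
pa.terminate()
```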
Revert "fix: Handle CUDA availability in OmniChatServer"
This reverts commit 28ed763269f75cea8298b3d64449fd7776d05f52.
docs: add PyAudio to dependencies
feat: Replace PyAudio with streamlit-webrtc for user recording
fix: Replace PyAudio with streamlit-webrtc for audio recording
feat: Serve HTML demo instead of Streamlit app
fix: Update API_URL and error handling in webui/omni_html_demo.html
fix: Replace audio playback with text-to-speech
feat: Implement audio processing and response generation
fix: Use a Docker data volume for caching
feat: Add Docker data volume and environment variables for caching
diff --git a/inference.py b/inference.py
index 4d4d4d1..d4d4d1a 100644
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,7 @@
 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    cache_dir = os.environ.get('XDG_CACHE_HOME', '/tmp')
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir=cache_dir)
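For context, a self-contained sketch of the same download path with its imports spelled out (assuming `huggingface_hub` is installed; `./checkpoint` is only an example target directory):

```python
import os
from huggingface_hub import snapshot_download

def download_model(ckpt_dir: str) -> None:
    repo_id = "gpt-omni/mini-omni"
    # Respect XDG_CACHE_HOME when set (e.g. /data/cache inside the Docker image),
    # otherwise fall back to /tmp, which is always writable on Spaces.
    cache_dir = os.environ.get("XDG_CACHE_HOME", "/tmp")
    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir=cache_dir)

if __name__ == "__main__":
    download_model("./checkpoint")
```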
fix: Remove cache-related code and update Dockerfile
fix: Add Docker volume and set permissions for model download
fix: Set correct permissions for checkpoint directory
feat: Use DATA volume to store model checkpoint
fix: Set permissions and create necessary subdirectories in the DATA volume
fix: Implement error handling and CUDA Tensor Cores optimization in serve_html.py
fix: Improve error handling and logging in chat endpoint
- Dockerfile +23 -11
- README.md +125 -124
- inference.py +7 -5
- requirements.txt +8 -2
- serve_html.py +70 -0
- server.py +4 -7
- webui/index.html +0 -258
- webui/omni_html_demo.html +13 -8
- webui/omni_streamlit.py +134 -257
diff --git a/Dockerfile b/Dockerfile
@@ -7,7 +7,6 @@ WORKDIR /app
 # Install system dependencies
 RUN apt-get update && apt-get install -y \
     ffmpeg \
-    portaudio19-dev \
     && rm -rf /var/lib/apt/lists/*

 # Copy the current directory contents into the container at /app
@@ -16,20 +15,33 @@ COPY . /app
 # Install any needed packages specified in requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt

-# Make ports 7860 and 60808 available to the world outside this container
-EXPOSE 7860 60808
+# Make port 7860 available to the world outside this container
+EXPOSE 7860

 # Set environment variable for API_URL
-ENV API_URL=http://0.0.0.0:
+ENV API_URL=http://0.0.0.0:7860/chat

 # Set PYTHONPATH
 ENV PYTHONPATH=./

-# Run
-CMD ["
+# Set environment variables
+ENV MPLCONFIGDIR=/tmp/matplotlib
+ENV HF_HOME=/data/huggingface
+ENV XDG_CACHE_HOME=/data/cache
+
+# Create a volume for data
+VOLUME /data
+
+# Set permissions for the /data directory and create necessary subdirectories
+RUN mkdir -p /data/checkpoint /data/cache /data/huggingface && \
+    chown -R 1000:1000 /data && \
+    chmod -R 777 /data
+
+# Install Flask
+RUN pip install flask
+
+# Copy the HTML demo file
+COPY webui/omni_html_demo.html .
+
+# Run the Flask app to serve the HTML demo
+CMD ["python", "serve_html.py"]
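A small sketch (not from the commit; paths taken from the ENV and RUN lines above) for checking at container start-up that the redirected cache locations are writable by the Space's non-root user, which is exactly what the `mkdir`/`chown`/`chmod` step guards against:

```python
import os
import tempfile

# Cache locations configured in the Dockerfile above.
CACHE_DIRS = ["/data/checkpoint", "/data/cache", "/data/huggingface", "/tmp/matplotlib"]

def check_writable(path: str) -> bool:
    """Return True if a file can be created inside `path`."""
    os.makedirs(path, exist_ok=True)
    try:
        with tempfile.NamedTemporaryFile(dir=path):
            return True
    except OSError as exc:
        print(f"{path} is not writable: {exc}")
        return False

if __name__ == "__main__":
    for d in CACHE_DIRS:
        print(d, "writable" if check_writable(d) else "NOT writable")
```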
diff --git a/README.md b/README.md
@@ -1,124 +1,125 @@
+---
+title: Omni Docker
+emoji: 🦀
+colorFrom: green
+colorTo: red
+sdk: docker
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# Mini-Omni
+
+<p align="center"><strong style="font-size: 18px;">
+Mini-Omni: Language Models Can Hear, Talk While Thinking in Streaming
+</strong>
+</p>
+
+<p align="center">
+🤗 <a href="https://huggingface.co/gpt-omni/mini-omni">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni">Github</a>
+| 📑 <a href="https://arxiv.org/abs/2408.16725">Technical report</a>
+</p>
+
+Mini-Omni is an open-source multimodal large language model that can **hear, talk while thinking**. Featuring real-time end-to-end speech input and **streaming audio output** conversational capabilities.
+
+<p align="center">
+    <img src="data/figures/frameworkv3.jpg" width="100%"/>
+</p>
+
+
+## Features
+
+✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required.
+
+✅ **Talking while thinking**, with the ability to generate text and audio at the same time.
+
+✅ **Streaming audio output** capabilities.
+
+✅ With "Audio-to-Text" and "Audio-to-Audio" **batch inference** to further boost the performance.
+
+## Demo
+
+NOTE: need to unmute first.
+
+https://github.com/user-attachments/assets/03bdde05-9514-4748-b527-003bea57f118
+
+
+## Install
+
+Create a new conda environment and install the required packages:
+
+```sh
+conda create -n omni python=3.10
+conda activate omni
+
+git clone https://github.com/gpt-omni/mini-omni.git
+cd mini-omni
+pip install -r requirements.txt
+pip install PyAudio==0.2.14
+```
+
+## Quick start
+
+**Interactive demo**
+
+- start server
+
+NOTE: you need to start the server before running the streamlit or gradio demo with API_URL set to the server address.
+
+```sh
+sudo apt-get install ffmpeg
+conda activate omni
+cd mini-omni
+python3 server.py --ip '0.0.0.0' --port 60808
+```
+
+
+- run streamlit demo
+
+NOTE: you need to run streamlit locally with PyAudio installed. For error: `ModuleNotFoundError: No module named 'utils.vad'`, please run `export PYTHONPATH=./` first.
+
+```sh
+pip install PyAudio==0.2.14
+API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
+```
+
+- run gradio demo
+```sh
+API_URL=http://0.0.0.0:60808/chat python3 webui/omni_gradio.py
+```
+
+example:
+
+NOTE: need to unmute first. Gradio seems can not play audio stream instantly, so the latency feels a bit longer.
+
+https://github.com/user-attachments/assets/29187680-4c42-47ff-b352-f0ea333496d9
+
+
+**Local test**
+
+```sh
+conda activate omni
+cd mini-omni
+# test run the preset audio samples and questions
+python inference.py
+```
+
+## Common issues
+
+- Error: `ModuleNotFoundError: No module named 'utils.xxxx'`
+
+Answer: run `export PYTHONPATH=./` first.
+
+## Acknowledgements
+
+- [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
+- [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
+- [whisper](https://github.com/openai/whisper/) for audio encoding.
+- [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
+- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
+- [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
+
+## Star History
+
+[](https://star-history.com/#gpt-omni/mini-omni&Date)
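As a complement to the quick start above, a hedged sketch of a minimal Python client for the chat endpoint. The host, sample file name, and raw-bytes payload format are assumptions taken from the streamlit demo further down, not part of this commit:

```python
import os
import requests

# The streamlit demo posts raw 16 kHz, 16-bit mono PCM bytes to the /chat endpoint.
API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")

def chat_once(wav_path: str, out_path: str = "answer.wav") -> None:
    with open(wav_path, "rb") as f:
        audio_bytes = f.read()
    # Stream the audio response back and write it to disk.
    with requests.post(API_URL, data=audio_bytes, stream=True, timeout=300) as resp:
        resp.raise_for_status()
        with open(out_path, "wb") as out:
            for chunk in resp.iter_content(chunk_size=4096):
                out.write(chunk)
    print(f"wrote response audio to {out_path}")

if __name__ == "__main__":
    chat_once("data/samples/output1.wav")  # hypothetical sample path
```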
diff --git a/inference.py b/inference.py
@@ -7,6 +7,8 @@ from litgpt import Tokenizer
 from litgpt.utils import (
     num_parameters,
 )
+import matplotlib
+matplotlib.use('Agg')  # Use a non-GUI backend
 from litgpt.generate.base import (
     generate_AA,
     generate_ASR,
@@ -347,8 +349,8 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):


 def load_model(ckpt_dir, device):
-    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
-    whispermodel = whisper.load_model("small").to(device)
+    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz", cache_dir="/data/cache/snac").eval().to(device)
+    whispermodel = whisper.load_model("small", download_root="/data/cache/whisper").to(device)
     text_tokenizer = Tokenizer(ckpt_dir)
     fabric = L.Fabric(devices=1, strategy="auto")
     config = Config.from_file(ckpt_dir + "/model_config.yaml")
@@ -367,12 +369,12 @@ def load_model(ckpt_dir, device):

 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir="/data/huggingface")


 class OmniInference:

-    def __init__(self, ckpt_dir='
+    def __init__(self, ckpt_dir='/data/checkpoint', device='cuda:0'):
         self.device = device
         if not os.path.exists(ckpt_dir):
             print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
@@ -508,7 +510,7 @@ class OmniInference:
 def test_infer():
     device = "cuda:0"
     out_dir = f"./output/{get_time_str()}"
-    ckpt_dir = f"
+    ckpt_dir = f"/data/checkpoint"
     if not os.path.exists(ckpt_dir):
         print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
         download_model(ckpt_dir)
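A short usage sketch showing how the relocated defaults are picked up. Assumptions: the `/data` paths exist and a CUDA device is available, as hard-coded in the diff above; the sample wav path is hypothetical:

```python
import os
from inference import OmniInference, download_model

ckpt_dir = "/data/checkpoint"  # default baked into the Dockerfile's DATA volume
if not os.path.exists(ckpt_dir):
    download_model(ckpt_dir)  # caches under /data/huggingface per the diff above

# device defaults to 'cuda:0'; pass 'cpu' explicitly on machines without a GPU.
omni = OmniInference(ckpt_dir=ckpt_dir, device="cuda:0")
omni.warm_up()

# Stream the reply audio for a prompt wav (hypothetical example path).
for chunk in omni.run_AT_batch_stream("data/samples/output1.wav"):
    pass  # each chunk is wav-encoded bytes that can be written to a file or socket
```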
diff --git a/requirements.txt b/requirements.txt
@@ -6,8 +6,14 @@ snac==1.2.0
 soundfile==0.12.1
 openai-whisper==20231117
 tokenizers==0.15.2
-
-
+torch==2.2.1
+torchvision==0.17.1
+torchaudio==2.2.1
+litgpt==0.4.3
+snac==1.2.0
+soundfile==0.12.1
+openai-whisper==20231117
+tokenizers==0.15.2
 pydub==0.25.1
 onnxruntime==1.17.1
 numpy==1.26.4
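A small sketch for verifying inside the image that the pinned packages above resolved to the expected versions (distribution names taken from the list above; anything missing is reported rather than raised):

```python
from importlib.metadata import version

PINNED = ["torch", "torchvision", "torchaudio", "litgpt", "snac",
          "soundfile", "openai-whisper", "tokenizers", "pydub",
          "onnxruntime", "numpy"]

for name in PINNED:
    try:
        print(f"{name}=={version(name)}")
    except Exception as exc:  # e.g. PackageNotFoundError
        print(f"{name}: not installed ({exc})")
```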
diff --git a/serve_html.py b/serve_html.py
@@ -0,0 +1,70 @@
+import torch
+torch.set_float32_matmul_precision('high')
+
+from flask import Flask, send_from_directory, request, Response
+import os
+import base64
+import numpy as np
+from inference import OmniInference
+import io
+
+app = Flask(__name__)
+
+# Initialize OmniInference
+try:
+    print("Initializing OmniInference...")
+    omni = OmniInference()
+    print("OmniInference initialized successfully.")
+except Exception as e:
+    print(f"Error initializing OmniInference: {str(e)}")
+    raise
+
+@app.route('/')
+def serve_html():
+    return send_from_directory('.', 'webui/omni_html_demo.html')
+
+@app.route('/chat', methods=['POST'])
+def chat():
+    try:
+        audio_data = request.json['audio']
+        if not audio_data:
+            return "No audio data received", 400
+
+        # Check if the audio_data contains the expected base64 prefix
+        if ',' in audio_data:
+            audio_bytes = base64.b64decode(audio_data.split(',')[1])
+        else:
+            audio_bytes = base64.b64decode(audio_data)
+
+        # Save audio to a temporary file
+        temp_audio_path = 'temp_audio.wav'
+        with open(temp_audio_path, 'wb') as f:
+            f.write(audio_bytes)
+
+        # Generate response using OmniInference
+        try:
+            response_generator = omni.run_AT_batch_stream(temp_audio_path)
+
+            # Concatenate all audio chunks
+            all_audio = b''
+            for audio_chunk in response_generator:
+                all_audio += audio_chunk
+
+            # Clean up temporary file
+            os.remove(temp_audio_path)
+
+            return Response(all_audio, mimetype='audio/wav')
+        except Exception as inner_e:
+            print(f"Error in OmniInference processing: {str(inner_e)}")
+            return f"An error occurred during audio processing: {str(inner_e)}", 500
+        finally:
+            # Ensure temporary file is removed even if an error occurs
+            if os.path.exists(temp_audio_path):
+                os.remove(temp_audio_path)
+
+    except Exception as e:
+        print(f"Error in chat endpoint: {str(e)}")
+        return f"An error occurred: {str(e)}", 500
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7860)
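For reference, a hedged client sketch for this Flask endpoint (the host and port come from the Dockerfile above; the sample file name is hypothetical). It mirrors what the HTML demo does from the browser: base64-encode the recording, POST it as JSON, and save the returned WAV:

```python
import base64
import requests

SPACE_URL = "http://127.0.0.1:7860/chat"  # assumption: container exposed locally on 7860

def send_recording(wav_path: str, out_path: str = "reply.wav") -> None:
    with open(wav_path, "rb") as f:
        payload = {"audio": base64.b64encode(f.read()).decode("ascii")}
    resp = requests.post(SPACE_URL, json=payload, timeout=300)
    resp.raise_for_status()
    # The endpoint returns the concatenated response audio as audio/wav bytes.
    with open(out_path, "wb") as out:
        out.write(resp.content)

if __name__ == "__main__":
    send_recording("my_question.wav")  # hypothetical input recording
```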
diff --git a/server.py b/server.py
@@ -2,21 +2,17 @@ import flask
 import base64
 import tempfile
 import traceback
-import torch
 from flask import Flask, Response, stream_with_context
 from inference import OmniInference


 class OmniChatServer(object):
     def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
-                 ckpt_dir='./checkpoint', device=
+                 ckpt_dir='./checkpoint', device='cuda:0') -> None:
         server = Flask(__name__)
         # CORS(server, resources=r"/*")
         # server.config["JSON_AS_ASCII"] = False

-        if device is None:
-            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-
         self.client = OmniInference(ckpt_dir, device)
         self.client.warm_up()

@@ -50,8 +46,9 @@ def create_app():
     return server.server


-def serve(ip='0.0.0.0', port=60808, device=
-
+def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
+
+    OmniChatServer(ip, port=port, run_app=True, device=device)


 if __name__ == "__main__":
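The revert noted in the commit list put the hard-coded `'cuda:0'` default back. A hedged sketch of the kind of fallback that was removed, which callers can still apply themselves when constructing the server:

```python
import torch
from server import OmniChatServer  # class shown in the diff above

def pick_device(preferred: str = "cuda:0") -> str:
    """Use the preferred CUDA device when available, otherwise fall back to CPU."""
    return preferred if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":
    # run_app=True starts the Flask app, mirroring serve() in server.py.
    OmniChatServer("0.0.0.0", port=60808, run_app=True, device=pick_device())
```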
diff --git a/webui/index.html b/webui/index.html
@@ -1,258 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>Mini-Omni Chat Demo</title>
-    <style>
-        body {
-            background-color: black;
-            color: white;
-            font-family: Arial, sans-serif;
-        }
-        #chat-container {
-            height: 300px;
-            overflow-y: auto;
-            border: 1px solid #444;
-            padding: 10px;
-            margin-bottom: 10px;
-        }
-        #status-message {
-            margin-bottom: 10px;
-        }
-        button {
-            margin-right: 10px;
-        }
-    </style>
-</head>
-<body>
-    <div id="svg-container"></div>
-    <div id="chat-container"></div>
-    <div id="status-message">Current status: idle</div>
-    <button id="start-button">Start</button>
-    <button id="stop-button" disabled>Stop</button>
-    <main>
-        <p id="current-status">Current status: idle</p>
-    </main>
-</body>
-<script>
-    // Load the SVG
-    const svgContainer = document.getElementById('svg-container');
-    const svgContent = `
-        <svg width="800" height="600" viewBox="0 0 800 600" xmlns="http://www.w3.org/2000/svg">
-            <ellipse id="left-eye" cx="340" cy="200" rx="20" ry="20" fill="white"/>
-            <circle id="left-pupil" cx="340" cy="200" r="8" fill="black"/>
-            <ellipse id="right-eye" cx="460" cy="200" rx="20" ry="20" fill="white"/>
-            <circle id="right-pupil" cx="460" cy="200" r="8" fill="black"/>
-            <path id="upper-lip" d="M 300 300 C 350 284, 450 284, 500 300" stroke="white" stroke-width="10" fill="none"/>
-            <path id="lower-lip" d="M 300 300 C 350 316, 450 316, 500 300" stroke="white" stroke-width="10" fill="none"/>
-        </svg>`;
-    svgContainer.innerHTML = svgContent;
-
-    // Set up audio context
-    const audioContext = new (window.AudioContext || window.webkitAudioContext)();
-    const analyser = audioContext.createAnalyser();
-    analyser.fftSize = 256;
-
-    // Animation variables
-    let isAudioPlaying = false;
-    let lastBlinkTime = 0;
-    let eyeMovementOffset = { x: 0, y: 0 };
-
-    // Chat variables
-    let mediaRecorder;
-    let audioChunks = [];
-    let isRecording = false;
-    const API_URL = 'http://127.0.0.1:60808/chat';
-
-    // Idle eye animation function
-    function animateIdleEyes(timestamp) {
-        const leftEye = document.getElementById('left-eye');
-        const rightEye = document.getElementById('right-eye');
-        const leftPupil = document.getElementById('left-pupil');
-        const rightPupil = document.getElementById('right-pupil');
-        const baseEyeX = { left: 340, right: 460 };
-        const baseEyeY = 200;
-
-        // Blink effect
-        const blinkInterval = 4000 + Math.random() * 2000; // Random blink interval between 4-6 seconds
-        if (timestamp - lastBlinkTime > blinkInterval) {
-            leftEye.setAttribute('ry', '2');
-            rightEye.setAttribute('ry', '2');
-            leftPupil.setAttribute('ry', '0.8');
-            rightPupil.setAttribute('ry', '0.8');
-            setTimeout(() => {
-                leftEye.setAttribute('ry', '20');
-                rightEye.setAttribute('ry', '20');
-                leftPupil.setAttribute('ry', '8');
-                rightPupil.setAttribute('ry', '8');
-            }, 150);
-            lastBlinkTime = timestamp;
-        }
-
-        // Subtle eye movement
-        const movementSpeed = 0.001;
-        eyeMovementOffset.x = Math.sin(timestamp * movementSpeed) * 6;
-        eyeMovementOffset.y = Math.cos(timestamp * movementSpeed * 1.3) * 1; // Reduced vertical movement
-
-        leftEye.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
-        leftEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        rightEye.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
-        rightEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        leftPupil.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
-        leftPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        rightPupil.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
-        rightPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-    }
-
-    // Main animation function
-    function animate(timestamp) {
-        if (isAudioPlaying) {
-            const dataArray = new Uint8Array(analyser.frequencyBinCount);
-            analyser.getByteFrequencyData(dataArray);
-
-            // Calculate the average amplitude in the speech frequency range
-            const speechRange = dataArray.slice(5, 80); // Adjust based on your needs
-            const averageAmplitude = speechRange.reduce((a, b) => a + b) / speechRange.length;
-
-            // Normalize the amplitude (0-1 range)
-            const normalizedAmplitude = averageAmplitude / 255;
-
-            // Animate mouth
-            const upperLip = document.getElementById('upper-lip');
-            const lowerLip = document.getElementById('lower-lip');
-            const baseY = 300;
-            const maxMovement = 60;
-            const newUpperY = baseY - normalizedAmplitude * maxMovement;
-            const newLowerY = baseY + normalizedAmplitude * maxMovement;
-
-            // Adjust control points for more natural movement
-            const upperControlY1 = newUpperY - 8;
-            const upperControlY2 = newUpperY - 8;
-            const lowerControlY1 = newLowerY + 8;
-            const lowerControlY2 = newLowerY + 8;
-
-            upperLip.setAttribute('d', `M 300 ${baseY} C 350 ${upperControlY1}, 450 ${upperControlY2}, 500 ${baseY}`);
-            lowerLip.setAttribute('d', `M 300 ${baseY} C 350 ${lowerControlY1}, 450 ${lowerControlY2}, 500 ${baseY}`);
-
-            // Animate eyes
-            const leftEye = document.getElementById('left-eye');
-            const rightEye = document.getElementById('right-eye');
-            const leftPupil = document.getElementById('left-pupil');
-            const rightPupil = document.getElementById('right-pupil');
-            const baseEyeY = 200;
-            const maxEyeMovement = 10;
-            const newEyeY = baseEyeY - normalizedAmplitude * maxEyeMovement;
-
-            leftEye.setAttribute('cy', newEyeY);
-            rightEye.setAttribute('cy', newEyeY);
-            leftPupil.setAttribute('cy', newEyeY);
-            rightPupil.setAttribute('cy', newEyeY);
-        } else {
-            animateIdleEyes(timestamp);
-        }
-
-        requestAnimationFrame(animate);
-    }
-
-    // Start animation
-    animate();
-
-    // Chat functions
-    function startRecording() {
-        navigator.mediaDevices.getUserMedia({ audio: true })
-            .then(stream => {
-                mediaRecorder = new MediaRecorder(stream);
-                mediaRecorder.ondataavailable = event => {
-                    audioChunks.push(event.data);
-                };
-                mediaRecorder.onstop = sendAudioToServer;
-                mediaRecorder.start();
-                isRecording = true;
-                updateStatus('Recording...');
-                document.getElementById('start-button').disabled = true;
-                document.getElementById('stop-button').disabled = false;
-            })
-            .catch(error => {
-                console.error('Error accessing microphone:', error);
-                updateStatus('Error: ' + error.message);
-            });
-    }
-
-    function stopRecording() {
-        if (mediaRecorder && isRecording) {
-            mediaRecorder.stop();
-            isRecording = false;
-            updateStatus('Processing...');
-            document.getElementById('start-button').disabled = false;
-            document.getElementById('stop-button').disabled = true;
-        }
-    }
-
-    function sendAudioToServer() {
-        const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
-        const reader = new FileReader();
-        reader.readAsDataURL(audioBlob);
-        reader.onloadend = function() {
-            const base64Audio = reader.result.split(',')[1];
-            fetch(API_URL, {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                },
-                body: JSON.stringify({ audio: base64Audio }),
-            })
-            .then(response => response.blob())
-            .then(blob => {
-                const audioUrl = URL.createObjectURL(blob);
-                playResponseAudio(audioUrl);
-                updateChatHistory('User', 'Audio message sent');
-                updateChatHistory('Assistant', 'Audio response received');
-            })
-            .catch(error => {
-                console.error('Error:', error);
-                updateStatus('Error: ' + error.message);
-            });
-        };
-        audioChunks = [];
-    }
-
-    function playResponseAudio(audioUrl) {
-        const audio = new Audio(audioUrl);
-        audio.onloadedmetadata = () => {
-            const source = audioContext.createMediaElementSource(audio);
-            source.connect(analyser);
-            analyser.connect(audioContext.destination);
-        };
-        audio.onplay = () => {
-            isAudioPlaying = true;
-            updateStatus('Playing response...');
-        };
-        audio.onended = () => {
-            isAudioPlaying = false;
-            updateStatus('Idle');
-        };
-        audio.play();
-    }
-
-    function updateChatHistory(role, message) {
-        const chatContainer = document.getElementById('chat-container');
-        const messageElement = document.createElement('p');
-        messageElement.textContent = `${role}: ${message}`;
-        chatContainer.appendChild(messageElement);
-        chatContainer.scrollTop = chatContainer.scrollHeight;
-    }
-
-    function updateStatus(status) {
-        document.getElementById('status-message').textContent = status;
-        document.getElementById('current-status').textContent = 'Current status: ' + status;
-    }
-
-    // Event listeners
-    document.getElementById('start-button').addEventListener('click', startRecording);
-    document.getElementById('stop-button').addEventListener('click', stopRecording);
-
-    // Initialize
-    updateStatus('Idle');
-</script>
-</html>
diff --git a/webui/omni_html_demo.html b/webui/omni_html_demo.html
@@ -21,7 +21,7 @@
     <audio id="audioPlayback" controls style="display:none;"></audio>

     <script>
-        const API_URL = '
+        const API_URL = '/chat';
         const recordButton = document.getElementById('recordButton');
         const chatHistory = document.getElementById('chatHistory');
         const audioPlayback = document.getElementById('audioPlayback');
@@ -86,12 +86,13 @@
                 }
             });

-            const
-            const
-
-
-
-
+            const responseBlob = await new Response(stream).blob();
+            const audioUrl = URL.createObjectURL(responseBlob);
+            updateChatHistory('AI', audioUrl);
+
+            // Play the audio response
+            const audio = new Audio(audioUrl);
+            audio.play();
         } else {
             console.error('API response not ok:', response.status);
             updateChatHistory('AI', 'Error in API response');
@@ -99,7 +100,11 @@
         };
     } catch (error) {
         console.error('Error sending audio to API:', error);
-
+        if (error.name === 'TypeError' && error.message === 'Failed to fetch') {
+            updateChatHistory('AI', 'Error: Unable to connect to the server. Please ensure the server is running and accessible.');
+        } else {
+            updateChatHistory('AI', 'Error communicating with the server: ' + error.message);
+        }
     }
 }

diff --git a/webui/omni_streamlit.py b/webui/omni_streamlit.py
@@ -1,257 +1,134 @@
-import streamlit as st
-        stream.write(audio_data)
-    except Exception as e:
-        st.error(f"Error during audio streaming: {e}")
-
-    out_file = save_tmp_audio(output_audio_bytes)
-    with st.chat_message("assistant"):
-        st.audio(out_file, format="audio/wav", loop=False, autoplay=False)
-    st.session_state.messages.append(
-        {"role": "assistant", "content": out_file, "type": "audio"}
-    )
-
-    wf.close()
-    # Close PyAudio stream and terminate PyAudio
-    stream.stop_stream()
-    stream.close()
-    p.terminate()
-    st.session_state.speaking = False
-    st.session_state.recording = True
-
-
-def recording(status):
-    audio = pyaudio.PyAudio()
-
-    stream = audio.open(
-        format=IN_FORMAT,
-        channels=IN_CHANNELS,
-        rate=IN_RATE,
-        input=True,
-        frames_per_buffer=IN_CHUNK,
-    )
-
-    temp_audio = b""
-    vad_audio = b""
-
-    start_talking = False
-    last_temp_audio = None
-    st.session_state.frames = []
-
-    while st.session_state.recording:
-        status.success("Listening...")
-        audio_bytes = stream.read(IN_CHUNK)
-        temp_audio += audio_bytes
-
-        if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
-            dur_vad, vad_audio_bytes, time_vad = run_vad(temp_audio, IN_RATE)
-
-            print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
-
-            if dur_vad > 0.2 and not start_talking:
-                if last_temp_audio is not None:
-                    st.session_state.frames.append(last_temp_audio)
-                start_talking = True
-            if start_talking:
-                st.session_state.frames.append(temp_audio)
-            if dur_vad < 0.1 and start_talking:
-                st.session_state.recording = False
-                print(f"speech end detected. excit")
-            last_temp_audio = temp_audio
-            temp_audio = b""
-
-    stream.stop_stream()
-    stream.close()
-
-    audio.terminate()
-
-
-def main():
-
-    st.title("Chat Mini-Omni Demo")
-    status = st.empty()
-
-    if "warm_up" not in st.session_state:
-        warm_up()
-        st.session_state.warm_up = True
-    if "start" not in st.session_state:
-        st.session_state.start = False
-    if "recording" not in st.session_state:
-        st.session_state.recording = False
-    if "speaking" not in st.session_state:
-        st.session_state.speaking = False
-    if "frames" not in st.session_state:
-        st.session_state.frames = []
-
-    if not st.session_state.start:
-        status.warning("Click Start to chat")
-
-    start_col, stop_col, _ = st.columns([0.2, 0.2, 0.6])
-    start_button = start_col.button("Start", key="start_button")
-    # stop_button = stop_col.button("Stop", key="stop_button")
-    if start_button:
-        time.sleep(1)
-        st.session_state.recording = True
-        st.session_state.start = True
-
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            if message["type"] == "msg":
-                st.markdown(message["content"])
-            elif message["type"] == "img":
-                st.image(message["content"], width=300)
-            elif message["type"] == "audio":
-                st.audio(
-                    message["content"], format="audio/wav", loop=False, autoplay=False
-                )
-
-    while st.session_state.start:
-        if st.session_state.recording:
-            recording(status)
-
-        if not st.session_state.recording and st.session_state.start:
-            st.session_state.speaking = True
-            speaking(status)
-
-        # if stop_button:
-        #     status.warning("Stopped, click Start to chat")
-        #     st.session_state.start = False
-        #     st.session_state.recording = False
-        #     st.session_state.frames = []
-        #     break
-
-
-if __name__ == "__main__":
-    main()
+import streamlit as st
+import numpy as np
+import requests
+import base64
+import tempfile
+import os
+import time
+import traceback
+import librosa
+from pydub import AudioSegment
+from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
+import av
+from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
+
+API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")
+
+# Initialize chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+def run_vad(audio, sr):
+    _st = time.time()
+    try:
+        audio = audio.astype(np.float32) / 32768.0
+        sampling_rate = 16000
+        if sr != sampling_rate:
+            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
+
+        vad_parameters = {}
+        vad_parameters = VadOptions(**vad_parameters)
+        speech_chunks = get_speech_timestamps(audio, vad_parameters)
+        audio = collect_chunks(audio, speech_chunks)
+        duration_after_vad = audio.shape[0] / sampling_rate
+
+        if sr != sampling_rate:
+            # resample to original sampling rate
+            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
+        else:
+            vad_audio = audio
+        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
+        vad_audio_bytes = vad_audio.tobytes()
+
+        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
+    except Exception as e:
+        msg = f"[asr vad error] audio_len: {len(audio)/(sr):.3f} s, trace: {traceback.format_exc()}"
+        print(msg)
+        return -1, audio.tobytes(), round(time.time() - _st, 4)
+
+def save_tmp_audio(audio_bytes):
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
+        file_name = tmpfile.name
+        audio = AudioSegment(
+            data=audio_bytes,
+            sample_width=2,
+            frame_rate=16000,
+            channels=1,
+        )
+        audio.export(file_name, format="wav")
+    return file_name
+
+def main():
+    st.title("Chat Mini-Omni Demo")
+    status = st.empty()
+
+    if "audio_buffer" not in st.session_state:
+        st.session_state.audio_buffer = []
+
+    webrtc_ctx = webrtc_streamer(
+        key="speech-to-text",
+        mode=WebRtcMode.SENDONLY,
+        audio_receiver_size=1024,
+        rtc_configuration=RTCConfiguration(
+            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+        ),
+        media_stream_constraints={"video": False, "audio": True},
+    )
+
+    if webrtc_ctx.audio_receiver:
+        while True:
+            try:
+                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
+                sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
+                st.session_state.audio_buffer.extend(sound_chunk)
+
+                if len(st.session_state.audio_buffer) >= 16000:
+                    duration_after_vad, vad_audio_bytes, vad_time = run_vad(
+                        np.array(st.session_state.audio_buffer), 16000
+                    )
+                    st.session_state.audio_buffer = []
+                    if duration_after_vad > 0:
+                        st.session_state.messages.append(
+                            {"role": "user", "content": "User audio"}
+                        )
+                        file_name = save_tmp_audio(vad_audio_bytes)
+                        st.audio(file_name, format="audio/wav")
+
+                        response = requests.post(API_URL, data=vad_audio_bytes)
+                        assistant_audio_bytes = response.content
+                        assistant_file_name = save_tmp_audio(assistant_audio_bytes)
+                        st.audio(assistant_file_name, format="audio/wav")
+                        st.session_state.messages.append(
+                            {"role": "assistant", "content": "Assistant response"}
+                        )
+            except Exception as e:
+                print(f"Error in audio processing: {e}")
+                break
+
+    if st.button("Process Audio"):
+        if st.session_state.audio_buffer:
+            duration_after_vad, vad_audio_bytes, vad_time = run_vad(
+                np.array(st.session_state.audio_buffer), 16000
+            )
+            st.session_state.messages.append({"role": "user", "content": "User audio"})
+            file_name = save_tmp_audio(vad_audio_bytes)
+            st.audio(file_name, format="audio/wav")
+
+            response = requests.post(API_URL, data=vad_audio_bytes)
+            assistant_audio_bytes = response.content
+            assistant_file_name = save_tmp_audio(assistant_audio_bytes)
+            st.audio(assistant_file_name, format="audio/wav")
+            st.session_state.messages.append(
+                {"role": "assistant", "content": "Assistant response"}
+            )
+            st.session_state.audio_buffer = []
+
+    if st.session_state.messages:
+        for message in st.session_state.messages:
+            if message["role"] == "user":
+                st.write(f"User: {message['content']}")
+            else:
+                st.write(f"Assistant: {message['content']}")
+
+if __name__ == "__main__":
+    main()
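A standalone sketch of the VAD step used by `run_vad` above (assumes `utils/vad.py` from the repo is on `PYTHONPATH`; the synthetic input is just silence plus noise, so the exact amount of detected speech will vary):

```python
import numpy as np
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

sr = 16000
# One second of silence followed by one second of noise, as 16-bit PCM.
silence = np.zeros(sr, dtype=np.int16)
noise = (np.random.randn(sr) * 3000).astype(np.int16)
pcm = np.concatenate([silence, noise])

audio = pcm.astype(np.float32) / 32768.0           # same scaling as run_vad
chunks = get_speech_timestamps(audio, VadOptions())  # default VAD options, as above
kept = collect_chunks(audio, chunks)
print(f"speech kept: {kept.shape[0] / sr:.2f} s out of {audio.shape[0] / sr:.2f} s")
```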