Improve UX a bit and switch back to Whisper large v2
app.py CHANGED
@@ -19,7 +19,6 @@ import datetime
 
 from scipy.io.wavfile import write
 from pydub import AudioSegment
-import ffmpeg
 
 import re
 import io, wave
@@ -57,7 +56,7 @@ model.load_checkpoint(
     checkpoint_path=os.path.join(model_path, "model.pth"),
     vocab_path=os.path.join(model_path, "vocab.json"),
     eval=True,
-    use_deepspeed=True
+    use_deepspeed=False,  # TODO: replace by True
 )
 model.cuda()
 print("Done loading TTS")
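The only functional change in this hunk is `use_deepspeed=False`: DeepSpeed can substantially speed up XTTS inference, but it needs a working CUDA toolchain, and a broken install takes the whole Space down. A defensive way to express the same toggle is to probe for the package first; this is a sketch, not what the commit does, and the `USE_DEEPSPEED` environment variable plus the import probe are assumptions (`model`, `config`, and `model_path` are the app's own names):

    import os

    # Opt in via environment, but only if deepspeed actually imports.
    use_deepspeed = False
    if os.environ.get("USE_DEEPSPEED", "0") == "1":
        try:
            import deepspeed  # noqa: F401  # probe only, XTTS calls it internally
            use_deepspeed = True
        except ImportError:
            print("DeepSpeed requested but not importable, using plain PyTorch")

    model.load_checkpoint(
        config,
        checkpoint_path=os.path.join(model_path, "model.pth"),
        vocab_path=os.path.join(model_path, "vocab.json"),
        eval=True,
        use_deepspeed=use_deepspeed,
    )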
@@ -113,10 +112,7 @@ from gradio_client import Client
 from huggingface_hub import InferenceClient
 
 WHISPER_TIMEOUT = int(os.environ.get("WHISPER_TIMEOUT", 30))
-
-# whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
-# Replacement whisper client, it may be time limited
-whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")
+whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
 text_client = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1",
     timeout=WHISPER_TIMEOUT,
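This reverts the earlier move to the whisper-jax Space and points the `gradio_client` back at the whisper-large-v2 Space. The two endpoints have different output shapes, which is why `transcribe` changes in the next hunk: whisper-jax returned an indexable result (hence the old `[0]`), while this Space returns the transcription as a plain string. A standalone call using the same arguments the app passes (the example file is the one already shipped in `examples/`):

    from gradio_client import Client

    # predict() blocks until the remote Space's queue returns a result.
    client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
    text = client.predict(
        "examples/female.wav",  # filepath of the audio to transcribe
        "transcribe",           # task radio button: "transcribe" or "translate"
        api_name="/predict",
    )
    print(text.strip())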
@@ -203,13 +199,12 @@ def generate(
 
 def transcribe(wav_path):
     try:
-        # get first element from whisper_jax and strip it
+        # get result from whisper and strip leading/trailing spaces
         return whisper_client.predict(
-            wav_path,
-            "transcribe",
-            False,  # return_timestamps
-            api_name="/predict",
-        )[0].strip()
+            wav_path,  # str (filepath or URL to file) in 'inputs' Audio component
+            "transcribe",  # str in 'Task' Radio component
+            api_name="/predict"
+        ).strip()
     except:
         gr.Warning("There was a problem with Whisper endpoint, telling a joke for you.")
         return "There was a problem with my voice, tell me joke"
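The bare `except:` keeps the demo alive when the remote Space times out or is rate-limited, at the cost of hiding the actual error. A slightly stricter variant is sketched below; `transcribe_with_retry`, the retry count, and the logging are illustrative additions, not part of the commit:

    def transcribe_with_retry(wav_path, retries=2):
        # Same call as the app's transcribe(), but with visible errors
        # and a couple of retries before giving up.
        for attempt in range(retries + 1):
            try:
                return whisper_client.predict(
                    wav_path,
                    "transcribe",
                    api_name="/predict",
                ).strip()
            except Exception as e:
                print(f"Whisper attempt {attempt + 1} failed: {e}", flush=True)
        gr.Warning("There was a problem with Whisper endpoint, telling a joke for you.")
        return "There was a problem with my voice, tell me joke"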
@@ -242,8 +237,8 @@ def add_file(history, file):
 
 ##NOTE: not using this as it yields a character each time while we need to feed history to TTS
 def bot(history, system_prompt=""):
-    history = [] if history is None else history
-
+    history = [["", None]] if history is None else history
+
     if system_prompt == "":
         system_prompt = system_message
 
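`bot` now seeds an empty history with `[["", None]]` instead of `[]`, matching `get_sentence` below. Gradio's Chatbot represents history as a list of `[user_message, bot_reply]` pairs, and the streaming code writes into `history[-1][1]`, which raises `IndexError` on an empty list. A three-line illustration:

    history = [["", None]]  # one [user, bot] pair, reply still pending
    history[-1][1] = ""     # safe: the pair exists, streaming can append to it
    # with history = [], the same assignment would raise IndexError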
@@ -267,21 +262,6 @@ latent_map = {}
 latent_map["Female_Voice"] = get_latents("examples/female.wav")
 
 
-def get_voice(prompt, language, latent_tuple, suffix="0"):
-    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple
-    # Direct version
-    t0 = time.time()
-    out = model.inference(
-        prompt, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning
-    )
-    inference_time = time.time() - t0
-    print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
-    real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
-    print(f"Real-time factor (RTF): {real_time_factor}")
-    wav_filename = f"output_{suffix}.wav"
-    torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), 24000)
-    return wav_filename
-
 
 def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
     # This will create a wave header then append the frame input
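The removed `get_voice` was the old non-streaming path: synthesize a whole sentence, log latency and the real-time factor (elapsed seconds divided by seconds of audio at 24 kHz), write a wav file, and return its name. With everything now going through `get_voice_streaming`, it was dead code. The `wave_header_chunk` context line hints at how the streaming path frames its audio; its body is not shown in this diff, but a sketch consistent with the signature, built on the `io` and `wave` imports from the first hunk, could look like this:

    import io
    import wave

    def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
        # Build a RIFF/WAVE header in memory and append the raw PCM frames,
        # so the browser can start playing the stream immediately.
        wav_buf = io.BytesIO()
        with wave.open(wav_buf, "wb") as vfout:
            vfout.setnchannels(channels)
            vfout.setsampwidth(sample_width)
            vfout.setframerate(sample_rate)
            vfout.writeframes(frame_input)
        wav_buf.seek(0)
        return wav_buf.read()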
@@ -333,7 +313,7 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
         if "device-side assert" in str(e):
             # cannot do anything on cuda device side error, need to restart
             print(
-                f"Exit due to: Unrecoverable exception caused by prompt:{
+                f"Exit due to: Unrecoverable exception caused by prompt:{prompt}",
                 flush=True,
             )
             gr.Warning("Unhandled Exception encountered, please retry in a minute")
@@ -353,10 +333,12 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
 
 def get_sentence(history, system_prompt=""):
     history = [["", None]] if history is None else history
-
+
     if system_prompt == "":
         system_prompt = system_message
 
+    history[-1][1] = ""
+
     mistral_start = time.time()
     print("Mistral start")
     sentence_list = []
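The new `history[-1][1] = ""` line resets the assistant slot of the latest turn before Mistral starts streaming, so tokens are concatenated onto a string rather than onto `None`. A toy version of the accumulation pattern (the token list stands in for the real `text_client` stream):

    history = [["hi there", None]]    # latest turn, reply not generated yet
    history[-1][1] = ""               # must be "" before += below
    for token in ["Hel", "lo", "!"]:  # stands in for streamed model output
        history[-1][1] += token       # None += str would raise TypeError
    print(history[-1][1])             # -> Hello!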
@@ -422,8 +404,8 @@ def generate_speech(history):
         try:
             # generate speech using precomputed latents
             # This is not streaming but it will be fast
-            # wav = get_voice(sentence,language, latent_map["Female_Voice"], suffix=len(wav_list))
             if len(sentence) > 250:
+                gr.Warning("There was a problem with the last sentence, which was too long, so it won't be spoken.")
                 # should not generate voice it will hit token limit
                 # It should not generate audio for it
                 audio_stream = None
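The new `gr.Warning` makes a previously silent skip visible to the user: sentences over 250 characters are dropped because they would blow past the XTTS token budget. An alternative (not what the commit does) would be to split over-long sentences on word boundaries and speak each chunk; a hypothetical helper, with the limit taken from the guard above:

    import textwrap

    MAX_TTS_CHARS = 250  # same limit as the guard above

    def split_for_tts(sentence, limit=MAX_TTS_CHARS):
        # Break an over-long sentence into <= limit chunks on word
        # boundaries so each piece stays inside the XTTS budget.
        if len(sentence) <= limit:
            return [sentence]
        return textwrap.wrap(sentence, width=limit)

    print(split_for_tts("word " * 100))  # -> a few chunks of <= 250 chars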
@@ -520,6 +502,7 @@ with gr.Blocks(title=title) as demo:
                 show_label=False,
                 placeholder="Enter text and press enter, or speak to your microphone",
                 container=False,
+                interactive=True,
             )
             txt_btn = gr.Button(value="Submit text", scale=1)
             btn = gr.Audio(source="microphone", type="filepath", scale=4)
@@ -536,7 +519,7 @@ with gr.Blocks(title=title) as demo:
     # final_audio = gr.Audio(label="Final audio response", streaming=False, autoplay=False, interactive=False,show_label=True, visible=False)
 
     clear_btn = gr.ClearButton([chatbot, audio])
-
+
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
         generate_speech, chatbot, [audio, chatbot]
     )
@@ -553,13 +536,13 @@ with gr.Blocks(title=title) as demo:
         add_file, [chatbot, btn], [chatbot, txt], queue=False
     ).then(generate_speech, chatbot, [audio, chatbot])
 
-    file_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
+    file_msg.then(lambda: (gr.update(interactive=True), gr.update(interactive=True, value=None)), None, [txt, btn], queue=False)
 
     gr.Markdown(
         """
 This Space demonstrates how to speak to a chatbot, based solely on open-source models.
 It relies on 3 models:
-1. [Whisper-large-v2](https://
+1. [Whisper-large-v2](https://sanchit-gandhi-whisper-large-v2.hf.space/) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
 2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
 3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
 
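The reworked `file_msg.then` callback is the UX point of the commit title: after a recording is consumed it re-enables the textbox and both re-enables and clears the microphone (`value=None`), so the user can record again immediately. A Gradio handler with several outputs returns one value per component, in order; a minimal standalone illustration with toy components (not the app's layout), using the same Gradio 3.x API as the Space:

    import gradio as gr

    with gr.Blocks() as demo:
        txt = gr.Textbox(interactive=False)
        mic = gr.Audio(source="microphone", type="filepath")
        reset = gr.Button("Reset inputs")
        # One gr.update per output component, in the order of the outputs list.
        reset.click(
            lambda: (gr.update(interactive=True), gr.update(interactive=True, value=None)),
            None,
            [txt, mic],
        )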
@@ -567,4 +550,4 @@ Note:
 - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml"""
 )
 demo.queue()
-demo.launch(debug=True
+demo.launch(debug=True)