Spaces:
Sleeping
Sleeping
| from transformers import pipeline | |
| import torch | |
| from transformers.pipelines.audio_utils import ffmpeg_microphone_live | |
| from huggingface_hub import HfFolder, InferenceClient | |
| import requests | |
| from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan | |
| from datasets import load_dataset | |
| import sounddevice as sd | |
| import sys | |
| import os | |
| from dotenv import load_dotenv | |
| import gradio as gr | |
| import warnings | |
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hide only the "all zero values" mel-filter UserWarning; other warnings
# are still shown.
warnings.filterwarnings(
    "ignore",
    message="At least one mel filter has all zero values.*",
    category=UserWarning,
)

# Place every model on the first GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Audio-classification model that launch_fn() uses to spot the wake word.
classifier = pipeline(
    task="audio-classification",
    model="MIT/ast-finetuned-speech-commands-v2",
    device=device,
)
def launch_fn(wake_word="marvin", prob_threshold=0.5, chunk_length_s=2.0, stream_chunk_s=0.25, debug=False):
    """Block until the wake word is heard on the live microphone.

    Streams microphone audio through ``classifier`` and checks the first
    (top-ranked) label of each prediction.

    Args:
        wake_word: Label that triggers the assistant; must be one of the
            classifier's class labels.
        prob_threshold: Score the top prediction must strictly exceed.
        chunk_length_s: Length in seconds of each audio window, passed to
            ``ffmpeg_microphone_live``.
        stream_chunk_s: Streaming step in seconds, passed to
            ``ffmpeg_microphone_live``.
        debug: When true, print every top prediction.

    Returns:
        True once the wake word is detected, or False if the microphone
        stream ends first.

    Raises:
        ValueError: If ``wake_word`` is not a valid class label.
    """
    label2id = classifier.model.config.label2id
    if wake_word not in label2id:
        raise ValueError(
            f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {label2id.keys()}."
        )

    sampling_rate = classifier.feature_extractor.sampling_rate
    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )

    print("Listening for wake word...")
    try:
        for prediction in classifier(mic):
            # The audio-classification pipeline lists labels best-first, so
            # only the first entry is inspected.
            top = prediction[0]
            if debug:
                print(top)
            if top["label"] == wake_word and top["score"] > prob_threshold:
                return True
    finally:
        # Close the capture generator explicitly so the underlying ffmpeg
        # process is torn down now rather than whenever it is collected.
        mic.close()
    # Stream exhausted without a detection: report it explicitly instead of
    # implicitly returning None.
    return False
# English-only Whisper "base" speech-to-text pipeline used by transcribe().
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base.en",
    device=device,
)
def transcribe(chunk_length_s=5.0, stream_chunk_s=1.0):
    """Transcribe one utterance from the live microphone.

    Streams audio into ``transcriber`` and redraws the running transcript
    on a single terminal line until an item whose ``partial`` flag is false
    arrives.

    Args:
        chunk_length_s: Audio window length in seconds, passed to
            ``ffmpeg_microphone_live``.
        stream_chunk_s: Streaming step in seconds, passed to
            ``ffmpeg_microphone_live``.

    Returns:
        The text of the last item produced, or "" if the stream yielded
        nothing.
    """
    sampling_rate = transcriber.feature_extractor.sampling_rate
    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )

    print("Start speaking...")
    # Initialised up front: previously an empty stream left the loop
    # variable unbound and the final return raised UnboundLocalError.
    text = ""
    try:
        for item in transcriber(mic, generate_kwargs={"max_new_tokens": 128}):
            text = item["text"]
            sys.stdout.write("\033[K")  # clear the line before redrawing
            print(text, end="\r")
            if not item["partial"][0]:
                break
    finally:
        # Release the ffmpeg capture process as soon as we are done.
        mic.close()
    return text
# Chat-completion client routed through the "fireworks-ai" provider and
# authenticated with the HF_TOKEN read from the environment.
client = InferenceClient(api_key=HF_TOKEN, provider="fireworks-ai")
def query(text, model_id="meta-llama/Llama-3.1-8B-Instruct"):
    """Send ``text`` as a single user message and return the model's reply.

    Any exception raised while calling the API or reading the reply is
    printed and swallowed.

    Args:
        text: Prompt sent as the only chat message.
        model_id: Model identifier passed to the chat-completions call.

    Returns:
        The message content of the first completion choice, or None if the
        request failed.
    """
    conversation = [{"role": "user", "content": text}]
    try:
        reply = client.chat.completions.create(model=model_id, messages=conversation)
        answer = reply.choices[0].message.content
    except Exception as err:
        print(f"Erreur: {err!s}")
        return None
    return answer
# SpeechT5 text-to-speech components used by synthesise(): the processor
# tokenises the text, and the model plus HiFi-GAN vocoder generate audio.
tts_checkpoint_processor = processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Fixed speaker voice: x-vector at row 7306 of the validation split, with a
# leading batch axis added.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(dim=0)
def synthesise(text):
    """Convert ``text`` to a speech waveform with SpeechT5.

    Args:
        text: Text to speak. None or whitespace-only text produces no audio.

    Returns:
        The generated waveform moved to the CPU, or None when ``text`` is
        empty or synthesis fails (the error is printed).
    """
    # query() returns None on failure; previously that None reached the
    # tokenizer outside the try block and crashed the caller.
    if text is None or not str(text).strip():
        return None
    try:
        # Tokenisation is inside the try so its errors take the same
        # "print and return None" path as generation errors.
        input_ids = processor(text=text, return_tensors="pt")["input_ids"]
        speech = model.generate_speech(
            input_ids.to(device),
            speaker_embeddings.to(device),
            vocoder=vocoder,
        )
        return speech.cpu()
    except Exception as e:
        print(f"Erreur lors de la synthèse vocale : {e}")
        return None
# Manual end-to-end run without the Gradio UI (kept for reference):
# launch_fn(debug=True)
# transcription = transcribe()
# response = query(transcription)
# audio = synthesise(response)
#
# sd.play(audio.numpy(), 16000)
# sd.wait()
# Gradio interface
def assistant_vocal_interface():
    """Run one wake word -> transcription -> LLM reply -> speech cycle.

    Returns:
        A ``(transcription, response, audio)`` triple for the three Gradio
        outputs. ``audio`` is ``(16000, waveform_array)``, or None when no
        speech could be produced, so the UI shows empty audio instead of
        the callback raising.
    """
    if not launch_fn(debug=True):
        # The microphone stream ended before the wake word was detected.
        return "", "", None

    transcription = transcribe()
    response = query(transcription)
    if response is None:
        # query() already printed the error; previously the None fell
        # through to synthesise() and crashed the callback.
        return transcription, "", None

    audio = synthesise(response)
    if audio is None:
        # synthesise() already printed the error; avoid calling .numpy() on None.
        return transcription, response, None

    # NOTE(review): assumes the SpeechT5 HiFi-GAN vocoder outputs 16 kHz audio.
    return transcription, response, (16000, audio.numpy())
with gr.Blocks(title="Assistant Vocal") as demo:
    gr.Markdown(value="## Assistant vocal : détection, transcription, génération et synthèse")
    start_btn = gr.Button(value="Démarrer l'assistant")
    transcription_box = gr.Textbox(label="Transcription")
    response_box = gr.Textbox(label="Réponse IA")
    audio_output = gr.Audio(label="Synthèse vocale", type="numpy", autoplay=True)

    # A click takes no UI inputs and fills the three output widgets, in order,
    # from the triple returned by assistant_vocal_interface().
    start_btn.click(
        fn=assistant_vocal_interface,
        inputs=[],
        outputs=[transcription_box, response_box, audio_output],
    )

demo.launch(share=True)