import os
import gradio as gr
import torch
import numpy as np
import librosa
from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels
from torch import autocast
from contextlib import nullcontext
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory

MODEL_NAME = "mn40_as"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

# State shared across Gradio calls: the last detected audio class and the
# LangChain prompt/chain built for it.
cached_audio_class = "c"
template = None
prompt = None
chain = None


def audio_tag(
    audio_path,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    with torch.no_grad(), autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext():
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    sorted_indexes = np.argsort(preds)[::-1]

    # Take the top-scoring audio tag; pass the clip length (in seconds) along for the prompt
    label = labels[sorted_indexes[0]]
    audio_length = waveform.shape[1] / sample_rate
    return formatted_message(label, audio_length)


def formatted_message(audio_class, audio_length, user_text="Hello!"):
    global cached_audio_class, template, prompt, chain

    # Rebuild the prompt and chain only when the detected audio class changes.
    if cached_audio_class != audio_class:
        cached_audio_class = audio_class

        prefix = '''You are going to act as a magical tool that allows for humans to communicate with non-human entities like
rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you a data string which
represents the audio input, the source of the audio, and the human's text input for the conversation.
The goal is for you to embody the source of the audio, and use the length and variance in the signal data to produce
plausible responses to the human's input. Remember to embody the source data. When we start the conversation,
you should generate a "personality profile" for the source and utilize that personality profile in your responses.
Let's begin:'''
        # {history} and {human_input} stay as LangChain template variables;
        # the audio class and clip length are filled in now.
        suffix = f'''
Source: {audio_class}
Length of Audio in Seconds: {audio_length}
{{history}}
Human Input: {{human_input}}
{audio_class} Response:'''
        template = prefix + suffix

        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template,
        )
        chain = LLMChain(
            # Assumption: the OpenAI key comes from the OPENAI_API_KEY environment
            # variable (the original referenced an undefined session_token).
            llm=OpenAI(temperature=0.5, openai_api_key=os.environ.get("OPENAI_API_KEY")),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2),
        )

    # user_text falls back to a placeholder because the current UI only collects audio.
    return chain.predict(human_input=user_text)


demo = gr.Interface(
    audio_tag,
    gr.Audio(source="upload", type="filepath", label="Your audio"),
    gr.Textbox(),
    examples=[["metro_station-paris.wav"]],
).launch(debug=True)
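
# Quick local sanity check (a sketch, not part of the Space): exercises the
# tagging + chat path without the web UI. Assumes the bundled example file
# "metro_station-paris.wav" is present and OPENAI_API_KEY is set; note that
# .launch(debug=True) above blocks, so run this only with the launch call disabled.
# print(audio_tag("metro_station-paris.wav"))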