import os
import gradio as gr
import torch
import numpy as np
import librosa
from efficientat.models.MobileNetV3 import get_model as get_mobilenet
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels
from torch import autocast
from contextlib import nullcontext
from langchain import OpenAI, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
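
# AnyChat: tag an uploaded audio clip with EfficientAT's pretrained MobileNetV3
# audio tagger, then have an OpenAI LLM role-play as the detected sound source
# ("non-human entity") in a short conversation.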

MODEL_NAME = "mn40_as"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

# Cache the chain between calls so the conversation keeps its memory for as
# long as the detected audio class stays the same.
cached_audio_class = None
chatgpt_chain = None

def format_classname(classname):
    return classname.capitalize()

def audio_tag(
    audio_path,
    human_input,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
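    # The defaults above (32 kHz mono, 800-sample window, 320-sample hop,
    # 128 mel bins) match the preprocessing settings the EfficientAT models
    # expect.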
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)
    # Our models are trained in half precision mode (torch.float16).
    # Run on CUDA with torch.float16 to get the best performance;
    # running on CPU with torch.float32 gives similar performance, while torch.bfloat16 is worse.
    with torch.no_grad(), (autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext()):
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()
    # Keep only the top-scoring audio tag.
    sorted_indexes = np.argsort(preds)[::-1]
    label = labels[sorted_indexes[0]]
    return formatted_message(format_classname(label), human_input)

def formatted_message(audio_class, human_input):
    global cached_audio_class, chatgpt_chain
    # Rebuild the prompt and chain only when the detected entity changes.
    if cached_audio_class != audio_class:
        cached_audio_class = audio_class
        prefix = f"""You are going to act as a magical tool that allows humans to communicate with non-human entities like
rocks, crackling fire, trees, animals, and the wind. To do this, we're going to provide you with the human's text input for the conversation.
The goal is for you to embody that non-human entity and converse with the human.

Examples:
Non-human Entity: Tree
Human Input: Hello tree
Tree: Hello human, I am a tree

Let's begin:
Non-human Entity: {audio_class}"""
        suffix = f'''
{{history}}
Human Input: {{human_input}}
{audio_class}:'''
        template = prefix + suffix
        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template,
        )
        chatgpt_chain = LLMChain(
            # The OpenAI key is read from the OPENAI_API_KEY environment
            # variable (e.g. a Space secret).
            llm=OpenAI(temperature=0.5, openai_api_key=os.environ["OPENAI_API_KEY"]),
            prompt=prompt,
            verbose=True,
            # ai_prefix names the AI speaker so the history reads "{audio_class}: ...".
            memory=ConversationalBufferWindowMemory(k=2, ai_prefix=audio_class),
        )
    return chatgpt_chain.predict(human_input=human_input)
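

# Gradio UI: an uploaded audio file plus a text message in, the entity's reply out.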
demo = gr.Interface(
    audio_tag,
    [
        gr.Audio(source="upload", type="filepath", label="Your audio"),
        gr.Textbox(label="Your message"),
    ],
    gr.Textbox(label="Response"),
    title="AnyChat",
    description="""
    <div style='text-align: center; width:100%; margin: auto;'>
        <img src='./logo.png' alt='anychat' width='250px' />
        <h3>Non-Human entities have many things to say, listen to them!</h3>
    </div>
    """,
).launch(debug=True)
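
# Dependencies implied by the imports: gradio, torch, numpy, librosa, langchain,
# and the EfficientAT code (the `efficientat` package). On a Hugging Face Space
# these would typically be pinned in a requirements.txt alongside this app.py.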