import os

import gradio as gr
import nltk  # we'll use this to split the story into sentences
import numpy as np
import torch
from gradio_client import Client
from huggingface_hub import InferenceClient
from transformers import AutoProcessor, BarkModel

nltk.download("punkt")
def _grab_best_device(use_gpu=True):
    if torch.cuda.device_count() > 0 and use_gpu:
        device = "cuda"
    else:
        device = "cpu"
    return device


device = _grab_best_device()
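# Illustrative only: on a machine with at least one visible CUDA GPU this
# resolves to "cuda", otherwise to "cpu".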
SYST_PROMPT = """You are a storyteller crafting a short tale for young listeners. Please abide by these guidelines:
- Keep your sentences short, concise and easy to understand.
- Only the narrator should speak. Any dialogue should be reported indirectly."""
# story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
story_prompt = "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom."

temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

TIMEOUT = int(os.environ.get("TIMEOUT", 45))
# TODO: requirements: accelerate optimum
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    timeout=TIMEOUT,
)
image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545bst2bq/")
image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
image_positive_prompt = ""
image_seed = 6
processor = AutoProcessor.from_pretrained("suno/bark")


def format_speaker_key(key):
    key = key.replace("v2/", "").split("_")
    return f"Speaker {key[2]} ({key[0]})"


voice_presets = [key for key in processor.speaker_embeddings.keys() if "v2/en" in key]
voice_presets_dict = {format_speaker_key(key): key for key in voice_presets}
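# Quick illustration of the mapping above (preset names follow Bark's
# "v2/<lang>_speaker_<id>" convention):
#   format_speaker_key("v2/en_speaker_6")  ->  "Speaker 6 (en)"
# so voice_presets_dict maps the display name "Speaker 6 (en)" back to "v2/en_speaker_6".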
model = BarkModel.from_pretrained(
    "suno/bark", torch_dtype=torch.float16, use_flash_attention_2=True
).to(device)
sampling_rate = model.generation_config.sample_rate
silence = np.zeros(int(0.25 * sampling_rate))  # quarter second of silence

voice_preset = "v2/en_speaker_6"
BATCH_SIZE = 32

# enable CPU offload: idle Bark sub-models are kept on CPU and moved to the GPU only when needed
model.enable_cpu_offload()
# MISTRAL ONLY
default_system_understand_message = "I understand, I am a Mistral chatbot."
system_understand_message = os.environ.get(
    "SYSTEM_UNDERSTAND_MESSAGE", default_system_understand_message
)
# Mistral instruct formatter
def format_prompt(message):
    prompt = "<s>[INST]" + SYST_PROMPT + "[/INST]" + system_understand_message + "</s>"
    prompt += f"[INST] {message} [/INST]"
    return prompt
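# Illustration of the resulting prompt string (assuming the default
# system_understand_message; SYST_PROMPT elided for brevity):
#   format_prompt("Tell me a story")
#   -> "<s>[INST]<SYST_PROMPT>[/INST]I understand, I am a Mistral chatbot.</s>[INST] Tell me a story [/INST]"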
def generate_story(
    story_prompt,
    temperature=0.9,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    try:
        output = text_client.text_generation(
            format_prompt(story_prompt),
            **generate_kwargs,
            details=False,
            return_full_text=False,
        )
    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: too many requests on the Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortunately I am not able to process your request now, too many people are asking me!"
        elif "Model not loaded on the server" in str(e):
            print("ERROR: Mistral server down")
            gr.Warning("Unfortunately the Mistral LLM is unable to process")
            output = "Unfortunately I am not able to process your request now, I have a problem with Mistral!"
        else:
            print("Unhandled exception:", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened, but I could not understand you."

    return output
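# Example call (illustrative; the returned text depends on the hosted endpoint):
#   story = generate_story("A panda going on an adventure with a caterpillar.")
# On success 'story' is the raw tale from Mistral; on failure it is one of the
# fallback apology strings above.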
def generate_audio_and_image(story_prompt, voice_preset="Speaker 3 (en)"):
    story = generate_story(story_prompt)
    print(story)

    # split the story into sentences so Bark gets short, manageable inputs
    model_input = story.replace("\n", " ").strip()
    model_input = nltk.sent_tokenize(model_input)

    print("text generated - now calling for image")
    # submit the image job asynchronously so the audio can be generated in parallel;
    # positional args follow the SDXL Space's fn_index=0 signature:
    # prompt, negative prompt, then (presumably) steps, guidance scale, width, height, seed
    job_img = image_client.submit(
        story_prompt + image_positive_prompt,
        image_negative_prompt,
        25,
        7,
        1024,
        1024,
        image_seed,
        fn_index=0,
    )
    print("image called - now generating audio")

    pieces = []
    for i in range(0, len(model_input), BATCH_SIZE):
        inputs = model_input[i : min(i + BATCH_SIZE, len(model_input))]

        if len(inputs) != 0:
            inputs = processor(inputs, voice_preset=voice_presets_dict[voice_preset])
            speech_output, output_lengths = model.generate(
                **inputs.to(device), return_output_lengths=True, min_eos_p=0.2
            )
            # trim each generated waveform to its true length before moving it to CPU
            speech_output = [
                output[:length].cpu().numpy()
                for (output, length) in zip(speech_output, output_lengths)
            ]
            print(f"{i}-th part generated")
            pieces += [*speech_output, silence.copy()]

    print("Calling image")
    try:
        img = job_img.result()
    except Exception as e:
        print("Unhandled exception:", str(e))
        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
        img = None

    return story, (sampling_rate, np.concatenate(pieces)), img
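# Note: the audio return value follows Gradio's (sample_rate, np.ndarray)
# convention; with Bark this is nominally (24000, <float array>), since the
# model generates 24 kHz audio.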
# Gradio Blocks demo
with gr.Blocks() as demo_blocks:
    gr.Markdown("""<h1 align="center">🐶 Children's story</h1>""")
    gr.HTML("""<h3 style="text-align:center;">Let Mistral tell you a story</h3>""")
    with gr.Group():
        with gr.Row():
            inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
        with gr.Row():
            with gr.Accordion("Advanced settings", open=False):
                voice_preset = gr.Dropdown(
                    list(voice_presets_dict.keys()),
                    value="Speaker 6 (en)",
                    label="Available speakers",
                )
        with gr.Row():
            btn = gr.Button("Create a story")
        with gr.Row():
            with gr.Column(scale=1):
                image_output = gr.Image(elem_id="gallery")
        with gr.Row():
            out_audio = gr.Audio(streaming=False, autoplay=True)  # play the audio as soon as it is ready
            out_text = gr.Text()

    btn.click(generate_audio_and_image, [inp_text, voice_preset], [out_text, out_audio, image_output])

    with gr.Row():
        gr.Examples(
            [
                "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson.",
                "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom.",
                "Tell me about the wonders of the world.",
            ],
            [inp_text],
            [out_text, out_audio, image_output],
            generate_audio_and_image,
            cache_examples=True,
        )

    with gr.Row():
        gr.Markdown(
            """
            This Space uses **[Bark](https://huggingface.co/docs/transformers/main/en/model_doc/bark)**, [Mistral-7b-instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Fast SD-XL](https://huggingface.co/spaces/openskyml/fast-sdxl-stable-diffusion-xl)!
            """
        )

demo_blocks.queue().launch(debug=True)