# Source: Hugging Face Spaces page (status shown at capture time: "Build error").
| import os | |
| from PIL import Image | |
| import random | |
| import shutil | |
| import datetime | |
| import torchvision.transforms.functional as f | |
| import torch | |
| from typing import Optional, Tuple | |
| from threading import Lock | |
| from langchain import ConversationChain | |
| from chat_anything.tts_talker.tts_edge import TTSTalker | |
| from chat_anything.sad_talker.sad_talker import SadTalker | |
| from chat_anything.chatbot.chat import load_chain | |
| from chat_anything.chatbot.select import model_selection_chain | |
| from chat_anything.chatbot.voice_select import voice_selection_chain | |
| import gradio as gr | |
# Width value handed to create_html_video as its ``width`` argument for the
# talking-head widget (kept as a string).
TALKING_HEAD_WIDTH = "350"

# SadTalker checkpoint directory and its configuration directory, both passed
# to the SadTalker constructor in ChatWrapper.__init__.
sadtalker_checkpoint_path = "MODELS/SadTalker"
config_path = "chat_anything/sad_talker/config"
class ChatWrapper:
    """Glue between the chat UI, the LLM conversation chain and the TTS /
    SadTalker talking-head pipeline.

    Holds one ``threading.Lock`` that serialises ``__call__`` invocations, and
    a SadTalker instance constructed with ``lazy_load=True``.
    """

    def __init__(self):
        self.lock = Lock()
        # lazy_load=True: presumably defers loading SadTalker weights until
        # first use — confirm against SadTalker.
        self.sad_talker = SadTalker(
            sadtalker_checkpoint_path, config_path, lazy_load=True)

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[list],
        chain: Optional["ConversationChain"],
        speak_text: bool,
        talking_head: bool,
        uid: str,
        talker: Optional["TTSTalker"],
        fullbody: str,
    ):
        """Run one chat turn and optionally voice and/or animate the reply.

        Args:
            api_key: Not used by this method.
            inp: The user's message.
            history: List of ``(user, bot)`` message pairs, or ``None``.
            chain: Conversation chain; ``None`` means no API key has been
                registered yet, and a reminder is appended instead of a reply.
            speak_text: Synthesise the reply as speech.
            talking_head: With ``speak_text``, render a talking-head video;
                without it, point the video output at ``tmp/<uid>/videos``.
            uid: Per-user id; outputs live under ``tmp/<uid>/``.
            talker: TTS object whose ``test(text=..., audio_path=...)`` is called.
            fullbody: Truthy selects SadTalker's ``'full'`` preprocessing,
                falsy selects ``'crop'``.

        Returns:
            ``(history, history, html_video, temp_file, html_audio,
            temp_aud_file, "")``. Media slots that were not produced are ``None``.
        """
        # Normalise up front so the no-chain branch can append safely too.
        history = history or []
        html_video, temp_file, html_audio, temp_aud_file = None, None, None, None
        # ``with`` releases the lock on every path (including chain is None and
        # exceptions), so one bad turn cannot deadlock every later turn.
        with self.lock:
            if chain is None:
                history.append((inp, "Please register with your API key first!"))
            else:
                print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
                print("inp: " + inp)
                print("speak_text: ", speak_text)
                print("talking_head: ", talking_head)
                output = chain.predict(input=inp).strip()
                # Double each newline, presumably so the chat widget's Markdown
                # renders them as paragraph breaks.
                output = output.replace("\n", "\n\n")
                history.append((inp, output))
                if speak_text:
                    if talking_head:
                        html_video, temp_file = self.do_html_video_speak(
                            talker, output, fullbody, uid)
                    else:
                        html_audio, temp_aud_file = self.do_html_audio_speak(
                            talker, output, uid)
                elif talking_head:
                    # NOTE(review): this is the videos *directory*, whereas
                    # do_html_video_speak writes videos/tempfile.mp4 — confirm
                    # the video output accepts a directory here.
                    temp_file = os.path.join('tmp', uid, 'videos')
                    html_video = create_html_video(temp_file, TALKING_HEAD_WIDTH)
        return history, history, html_video, temp_file, html_audio, temp_aud_file, ""

    def do_html_audio_speak(self, talker, words_to_speak, uid):
        """Synthesise ``words_to_speak`` and build an autoplaying ``<audio>`` tag.

        ``talker.test`` writes the audio under ``tmp/<uid>/audios``. The file
        is then wrapped in ``gr.File`` and referenced through ``/file=``.

        Returns:
            ``(html_audio, audio_file_path)``, or ``(None, None)`` when wrapping
            the file raised ``IOError``.
        """
        from html import escape

        audio_path = os.path.join('tmp', uid, 'audios')
        print('uid:', uid, ":", words_to_speak)
        audio_file_path = talker.test(text=words_to_speak, audio_path=audio_path)
        try:
            temp_aud_file = gr.File(audio_file_path)
            print("audio-----------------------------------------------------success")
            # NOTE(review): ``.value['name']`` depends on the Gradio version.
            temp_aud_file_url = "/file=" + temp_aud_file.value['name']
        except IOError as error:
            # Could not write to file; report and let the caller carry on.
            print(error)
            return None, None
        # Quote and escape the URL so paths with spaces/quotes cannot break the tag.
        html_audio = (
            f'<audio autoplay><source src="{escape(temp_aud_file_url, quote=True)}" '
            f'type="audio/mp3"></audio>'
        )
        return html_audio, audio_file_path

    def do_html_video_speak(self, talker, words_to_speak, fullbody, uid):
        """Speak ``words_to_speak`` and animate ``tmp/<uid>/images/test.jpg``.

        Returns:
            ``(video_file_path, video_file_path)``, where the path is
            ``tmp/<uid>/videos/tempfile.mp4``.

        Raises:
            RuntimeError: if speech synthesis produced no audio file.
        """
        # fullbody selects SadTalker's 'full' preprocessing, otherwise 'crop'.
        preprocess = 'full' if fullbody else 'crop'
        print("success")
        video_path = os.path.join('tmp', uid, 'videos')
        os.makedirs(video_path, exist_ok=True)
        video_file_path = os.path.join(video_path, 'tempfile.mp4')
        _, audio_path = self.do_html_audio_speak(talker, words_to_speak, uid)
        if audio_path is None:
            # do_html_audio_speak returns (None, None) on IOError; fail with a
            # clear message instead of handing None to SadTalker.
            raise RuntimeError(
                "speech synthesis failed; no audio to drive the talking head")
        face_file_path = os.path.join('tmp', uid, 'images', 'test.jpg')
        video = self.sad_talker.test(face_file_path, audio_path, preprocess, uid=uid)
        print("---------------------------------------------------------success")
        print(f"moving {video} -> {video_file_path}")
        shutil.move(video, video_file_path)
        return video_file_path, video_file_path

    def generate_init_face_video(self, class_concept="clock", llm=None, uid=None,
                                 fullbody=None, ref_image=None, seed=None):
        """Create a persona for ``class_concept`` and render its greeting clip.

        The method runs these steps:

        1. Load the conversation chain and personality.
        2. Have the LLM pick an image model (``resources/models.yaml``, by
           concept) and a voice (``resources/voices_edge.yaml``, by personality).
        3. Generate a face into ``tmp/<uid>/images``.
        4. Render a spoken greeting with SadTalker.

        Up to 8 attempts are made. Each failure prepends a random adjective to
        the concept and adjusts the generation strength.

        Args:
            class_concept: The object/concept to personify.
            llm: Language model for the chain and the selection helpers.
            uid: Per-user id used for ``tmp/<uid>/`` paths.
            fullbody: Forwarded to ``do_html_video_speak``.
            ref_image: Reference image. When ``None``, a random file from
                ``FACE_DIR`` is opened.
            seed: Forwarded to ``generate_face_image``.

        Returns:
            ``(chain, memory, video_path, talker)``.

        Raises:
            RuntimeError: if every attempt fails (chained to the last error).
        """
        print('generate concept of', class_concept)
        print("=================================================")
        print('fullbody:', fullbody)
        print('uid:', uid)
        print("==================================================")
        chain, memory, personality_text = load_chain(llm, class_concept)
        # Each selection is an LLM call, so run each exactly once. The class
        # concept (not the personality) chooses the image-generation model.
        model_conf, _ = model_selection_chain(
            llm, class_concept, conf_file='resources/models.yaml')
        voice_conf, selected_voice = model_selection_chain(
            llm, personality_text, conf_file='resources/voices_edge.yaml')
        talker = TTSTalker(selected_voice=selected_voice,
                           gender=voice_conf['gender'],
                           language=voice_conf['language'])

        # Adjectives prepended to the concept after a failed attempt
        # (originally noted as being "for safe face generation").
        augment_word_list = ["Female ", "female ", "beautiful ", "small ", "cute "]
        first_sentence = "Hello, how are you doing today?"

        if ref_image is None:
            face_files = os.listdir(FACE_DIR)
            face_img_path = os.path.join(FACE_DIR, random.choice(face_files))
            ref_image = Image.open(face_img_path)

        print('loading face generating model')
        anything_facemaker = load_face_generator(
            model_dir=model_conf['model_dir'],
            lora_path=model_conf['lora_path'],
            prompt_template=model_conf['prompt_template'],
            negative_prompt=model_conf['negative_prompt'],
        )

        has_face = anything_facemaker.has_face(ref_image)
        # With a face in the reference, start at 1.0 and step down;
        # otherwise start at 0.85 and step up.
        init_strength = 1.0 if has_face else 0.85
        strength_retry_step = -0.04 if has_face else 0.04
        max_attempts = 8
        last_error = None
        for attempt in range(max_attempts):
            try:
                generate_face_image(
                    anything_facemaker,
                    class_concept,
                    ref_image,
                    uid=uid,
                    # NOTE(review): without a face this reaches 0.85 + 7*0.04 =
                    # 1.13 on the last attempt; confirm strength > 1 is accepted.
                    strength=init_strength + attempt * strength_retry_step,
                    # NOTE(review): 0.3 on every attempt; if a different scale
                    # was intended for a particular attempt, make it explicit.
                    controlnet_conditioning_scale=0.3,
                    seed=seed,
                )
                self.do_html_video_speak(talker, first_sentence, fullbody, uid=uid)
                video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4')
                htm_video = create_html_video(video_file_path, TALKING_HEAD_WIDTH)
                break
            except Exception as e:
                # Broad on purpose: any failure in generation or rendering
                # triggers another attempt with a modified concept.
                last_error = e
                class_concept = random.choice(augment_word_list) + class_concept
                print(e)
        else:
            raise RuntimeError(
                f"failed to generate the initial talking-head video after "
                f"{max_attempts} attempts"
            ) from last_error
        return chain, memory, htm_video, talker

    def update_talking_head(self, widget, uid, state):
        """Toggle the talking-head output.

        Returns:
            When ``widget`` is truthy: ``(widget, path)``, where ``path`` is
            ``tmp/<uid>/videos`` passed through ``create_html_video``.
            Otherwise: ``(None, "<pre></pre>")``.
        """
        print("success----------------")
        if widget:
            state = widget
            temp_file = os.path.join('tmp', uid, 'videos')
            video_html_talking_head = create_html_video(
                temp_file, TALKING_HEAD_WIDTH)
            return state, video_html_talking_head
        else:
            return None, "<pre></pre>"
def reset_memory(history, memory):
    """Wipe the conversation memory and start an empty chat history.

    The incoming ``history`` is discarded. ``memory.clear()`` is called in place.

    Returns:
        ``(new_history, new_history, memory)``. Both history slots are the same
        fresh empty list.
    """
    del history  # the previous transcript is intentionally discarded
    memory.clear()
    fresh = []
    return fresh, fresh, memory
def create_html_video(file_name, width):
    """Return ``file_name`` unchanged.

    Despite the name, no HTML is produced, and ``width`` is accepted but unused.
    """
    _unused_width = width
    return file_name
def create_html_audio(file_name):
    """Return a (non-autoplaying) ``<audio>`` tag for ``file_name``.

    When ``file_name`` exists, it is wrapped in a hidden ``gr.File`` and
    referenced through Gradio's ``/file=`` route.

    Returns:
        The HTML string, or ``""`` if ``file_name`` does not exist.
    """
    from html import escape

    if not os.path.exists(file_name):
        return ''
    tmp_audio_file = gr.File(file_name, visible=False)
    # NOTE(review): ``.value['name']`` depends on the Gradio version.
    tmp_aud_file_url = "/file=" + tmp_audio_file.value['name']
    # Quote and escape the URL so paths with spaces/quotes cannot break the tag.
    return (
        f'<audio><source src="{escape(tmp_aud_file_url, quote=True)}" '
        f'type="audio/mp3"></audio>'
    )
def update_foo(widget, state):
    """Adopt ``widget`` as the new state when it is truthy; otherwise keep ``state``."""
    return widget if widget else state
# Part of the question-answering functionality.
def update_use_embeddings(widget, state):
    """Return ``widget`` if it is truthy, otherwise the unchanged ``state``."""
    return widget or state
# Face (image) generation.
def load_face_generator(model_dir, lora_path, prompt_template, negative_prompt):
    """Construct a ``LongPromptControlGenerator`` and load its weights.

    Two auxiliary files are read from the local ``MODELS`` directory: the
    Face-Landmark-ControlNet weights (``models_for_diffusers``) and the
    68-point landmark predictor shipped with SadTalker.
    ``model_dir`` must be a model derived from Stable Diffusion v1-5.
    ``load_model`` is called with ``safety_checker=None``.

    Returns:
        The loaded generator.
    """
    # Imported here rather than at module top (presumably to defer the heavy
    # diffusion dependencies until a generator is actually needed).
    from chat_anything.face_generator.long_prompt_control_generator import LongPromptControlGenerator

    models_root = "MODELS"
    facemaker = LongPromptControlGenerator(
        model_dir=model_dir,
        lora_path=lora_path,
        prompt_template=prompt_template,
        negative_prompt=negative_prompt,
        face_control_dir=os.path.join(
            models_root, "Face-Landmark-ControlNet", "models_for_diffusers"),
        face_detect_path=os.path.join(
            models_root, "SadTalker", "shape_predictor_68_face_landmarks.dat"),
    )
    facemaker.load_model(safety_checker=None)
    return facemaker
# Directory of reference face images. ChatWrapper.generate_init_face_video
# picks one at random when no reference image is supplied.
FACE_DIR = "resources/images/faces"
def generate_face_image(
    anything_facemaker,
    class_concept,
    face_img_pil,
    uid=None,
    controlnet_conditioning_scale=1.0,
    strength=0.95,
    seed=42,
):
    """Generate a face image for ``class_concept`` from a reference image.

    The reference is processed and passed to the generator as follows:

    - It is resized (shorter side to 512), center-cropped to 512x512 and
      converted to RGB.
    - The concept is formatted into ``anything_facemaker.prompt_template``.
    - ``face_control_generate`` is called without inversion when ``strength``
      is ``None``, and with inversion at ``strength`` otherwise.
    - The result is saved to ``tmp/<uid>/images/test.jpg``.

    Args:
        anything_facemaker: Generator from ``load_face_generator``.
        class_concept: Text substituted into the prompt template.
        face_img_pil: Reference image.
        uid: Per-user id for the output directory (required in practice).
        controlnet_conditioning_scale: Forwarded to ``face_control_generate``.
        strength: Inversion strength, or ``None`` for pure control generation.
        seed: Seed for the torch generator. ``None`` draws a non-deterministic seed.

    Returns:
        The generated image as returned by ``face_control_generate``.
    """
    face_img_pil = f.center_crop(f.resize(face_img_pil, 512), 512).convert('RGB')
    prompt = anything_facemaker.prompt_template.format(class_concept)
    print('USING SEED:', seed)
    generator = torch.Generator(device=anything_facemaker.face_control_pipe.device)
    if seed is None:
        # manual_seed(None) raises, and callers (e.g. generate_init_face_video)
        # default seed to None; fall back to a non-deterministic seed.
        generator.seed()
    else:
        generator.manual_seed(seed)
    if strength is None:
        init_face_pil = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=False,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator)
    else:
        init_face_pil = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=True,
            strength=strength,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator)
    print('succeeded generating face image')
    face_path = os.path.join('tmp', uid, 'images')
    os.makedirs(face_path, exist_ok=True)
    # TODO: return the image directly instead of round-tripping via the filesystem.
    face_file_path = os.path.join(face_path, 'test.jpg')
    init_face_pil.save(face_file_path)
    return init_face_pil