Spaces:
Running
on
Zero
Running
on
Zero
| import concurrent.futures | |
| import random | |
| import gradio as gr | |
| import requests | |
| import io, base64, json, os | |
| import spaces | |
| from PIL import Image | |
| from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, DESIRED_APPEAR_MODEL_CHANCE, load_pipeline | |
| from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum | |
| from .pre_download import pre_download_all_models, pre_download_image_models_gen, pre_download_image_models_edit, pre_download_video_models_gen | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import re | |
def debug_packages():
    """Print every installed distribution as ``name==version``, one per line.

    Uses the stdlib ``importlib.metadata`` instead of the deprecated
    ``pkg_resources``. Names are normalised like pip does (PEP 503: lowercase,
    runs of ``-``, ``_`` and ``.`` collapsed to ``-``). Duplicate installs are
    reported once, keeping the first one found on ``sys.path``. Output is sorted
    by name so it is stable between runs.
    """
    from importlib import metadata

    versions = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if not name:
            # Broken/partial installs can lack a Name field; skip them.
            continue
        key = re.sub(r"[-_.]+", "-", name).lower()
        versions.setdefault(key, dist.version)
    for key in sorted(versions):
        print(f"{key}=={versions[key]}")
# Fallback word list used by fetch_unsafe_words when the words file cannot be read.
_DEFAULT_UNSAFE_WORDS = (
    "anal", "anus", "arse", "ass", "ballsack", "bastard", "bdsm", "bitch", "bimbo",
    "blow job", "blowjob", "blue waffle", "boob", "booobs", "breasts", "booty call",
    "boner", "bondage", "bullshit", "busty", "butthole", "cawk", "chink", "clit",
    "cnut", "cock", "cokmuncher", "cowgirl", "crap", "crotch", "cum", "cunt", "damn",
    "dick", "dildo", "dink", "deepthroat", "deep throat", "dog style", "doggie style",
    "doggy style", "doosh", "douche", "duche", "ejaculate", "ejaculating",
    "ejaculation", "ejakulate", "erotic", "erotism", "fag", "fatass", "femdom",
    "fingering", "footjob", "foot job", "fuck", "fcuk", "fingerfuck", "fistfuck",
    "fook", "fooker", "fuk", "gangbang", "gang bang", "gaysex", "handjob",
    "hand job", "hentai", "hooker", "hoer", "homo", "horny", "incest", "jackoff",
    "jack off", "jerkoff", "jerk off", "jizz", "masturbate", "mofo", "mothafuck",
    "motherfuck", "milf", "muff", "nigga", "nigger", "nipple", "nob", "numbnuts",
    "nutsack", "nude", "orgy", "orgasm", "panty", "panties", "penis", "playboy",
    "porn", "pussy", "pussies", "rape", "raping", "rapist", "rectum", "retard",
    "rimming", "sadist", "sadism", "scrotum", "sex", "semen", "shemale", "she male",
    "shit", "slut", "spunk", "strip club", "stripclub", "tit", "threesome",
    "three some", "throating", "twat", "viagra", "vagina", "wank", "whore", "whoar",
    "xxx",
)


def fetch_unsafe_words(file_path):
    """
    Load unsafe words from ``file_path`` (one entry per line) and return them as a list.

    Surrounding whitespace is stripped from each line and blank lines are skipped.
    If the file cannot be opened or decoded as UTF-8, the error is printed and a
    fresh copy of the built-in default list is returned instead.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return [word for word in (line.strip() for line in file) if word]
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers UnicodeDecodeError.
        print(f"Error loading file: {e}. Using default unsafe words.")
        # Return a new list so callers may mutate it without touching the default.
        return list(_DEFAULT_UNSAFE_WORDS)
def check_prompt_safety(prompt, unsafe_words_file='./profanity_words.txt'):
    """
    Report whether ``prompt`` is free of unsafe words.

    The word list comes from ``fetch_unsafe_words(unsafe_words_file)``, which
    falls back to its built-in list when the file cannot be loaded. Matching is
    case-insensitive and counts only whole-word (regex word-boundary) hits.
    Returns True for a safe prompt and False as soon as any listed word matches.
    """
    blocked_terms = fetch_unsafe_words(unsafe_words_file)
    lowered = prompt.lower()
    found_blocked = any(
        re.search(r'\b' + re.escape(term) + r'\b', lowered)
        for term in blocked_terms
    )
    return not found_blocked
class ModelManager:
    """Load generation pipelines on demand and run two-model arena comparisons.

    Method families:
    * ``generate_*`` run one model via a cached pipeline from ``load_pipeline``.
    * ``*_parallel`` run two named models concurrently in worker threads.
    * ``*_parallel_anony`` do the same, but pick a random pair when both names are "".
    * ``*_museum*`` fetch precomputed results from the museum helpers instead of
      running models.
    """

    def __init__(self, enable_nsfw=False, do_pre_download=False, do_debug_packages=False):
        """Set up model lists and the pipeline cache.

        Args:
            enable_nsfw: forwarded to ``load_guard``; when truthy the Llama Guard
                model is loaded.
            do_pre_download: if truthy, call ``pre_download_all_models(include_video=False)``.
            do_debug_packages: if truthy, print installed packages via ``debug_packages``.
        """
        self.model_ig_list = IMAGE_GENERATION_MODELS
        self.model_ie_list = IMAGE_EDITION_MODELS
        self.model_vg_list = VIDEO_GENERATION_MODELS
        self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
        self.desired_model_list = DESIRED_APPEAR_MODEL
        self.enable_nsfw = enable_nsfw
        self.load_guard(enable_nsfw)
        # model name -> pipeline, filled lazily by load_model_pipe
        self.loaded_models = {}
        if do_debug_packages:
            debug_packages()
        if do_pre_download:
            pre_download_all_models(include_video=False)
        try:
            self.generate_image_ig.zerogpu = True
            self.generate_image_ie.zerogpu = True
            self.generate_video_vg.zerogpu = True
        except AttributeError:
            # NOTE(review): bound methods reject attribute assignment, so this
            # branch is always taken. Setting the flag on ``.__func__`` would
            # succeed, but confirm what reads ``zerogpu`` before changing that.
            print("Failed to set zerogpu for wrapper_fns")

    def load_model_pipe(self, model_name):
        """Return the pipeline for ``model_name``, loading and caching it on first use."""
        if model_name not in self.loaded_models:
            self.loaded_models[model_name] = load_pipeline(model_name)
        return self.loaded_models[model_name]

    def load_guard(self, enable_nsfw=True):
        """Load Llama Guard 3 into ``self.guard``/``self.guard_tokenizer``, or set both to None.

        The access token is read from ``HF_TOKEN`` or, failing that, ``HF_GUARD``,
        and passed explicitly to ``from_pretrained``.
        """
        if not enable_nsfw:
            self.guard_tokenizer = None
            self.guard = None
            return
        model_id = "meta-llama/Llama-Guard-3-8B"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16
        token = os.getenv("HF_TOKEN") or os.getenv("HF_GUARD")
        self.guard_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        self.guard = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map=device, token=token
        )

    def NSFW_filter_simple(self, prompt):
        """Word-list check: return "safe" or "unsafe" based on ``check_prompt_safety``."""
        return "safe" if check_prompt_safety(prompt) else "unsafe"

    def NSFW_filter(self, prompt):
        """Classify ``prompt`` with Llama Guard and return the decoded verdict text.

        Returns "safe" without running anything when the guard was not loaded.
        """
        if self.guard is None:
            # guard is disabled
            return "safe"
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.guard_tokenizer.apply_chat_template(chat, return_tensors="pt").to('cuda')
        # The guard is moved to CUDA at call time, matching where the inputs were placed.
        self.guard.cuda()
        output = self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = input_ids.shape[-1]
        return self.guard_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # NOTE: prompt safety filtering (NSFW_filter) is currently disabled in the
    # generate_* methods below; they call the pipeline unconditionally.

    def generate_image_ig(self, prompt, model_name):
        """Run text-to-image model ``model_name`` on ``prompt`` and return the pipeline result."""
        print('The prompt is safe')
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_image_ig_api(self, prompt, model_name):
        """Same as ``generate_image_ig``; used for models not prefixed with "imagenhub"."""
        print('The prompt is safe')
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    @staticmethod
    def _museum_name(model_name):
        """Return the second "_"-separated field of ``model_name``, as passed to the museum fetchers."""
        return model_name.split('_')[1]

    @staticmethod
    def _run_concurrently(calls):
        """Run each ``(fn, *args)`` tuple in a worker thread and return the results as a tuple, in order."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(*call) for call in calls]
            return tuple(future.result() for future in futures)

    def _pick_model_names(self, model_list, model_A, model_B):
        """Return ``[model_A, model_B]``, or a random pair from ``model_list`` when both are "".

        Random picks skip ``self.excluding_model_list``. With probability
        ``DESIRED_APPEAR_MODEL_CHANCE``, and when available, the pair holds exactly
        one model from ``self.desired_model_list`` in a random position. If every
        candidate is a desired model, a plain random pair is returned instead.

        Raises:
            ValueError: (from ``random.sample``) if fewer than two candidates remain.
        """
        if not (model_A == "" and model_B == ""):
            return [model_A, model_B]
        picking_list = [m for m in model_list if m not in self.excluding_model_list]
        valid_desired_models = [m for m in self.desired_model_list if m in picking_list]
        if valid_desired_models and random.random() < DESIRED_APPEAR_MODEL_CHANCE:
            desired_model = random.choice(valid_desired_models)
            regular_models = [m for m in picking_list if m not in valid_desired_models]
            if regular_models:
                regular_model = random.choice(regular_models)
                if random.random() < 0.5:
                    return [desired_model, regular_model]
                return [regular_model, desired_model]
        return random.sample(picking_list, 2)

    def generate_image_ig_museum(self, model_name):
        """Fetch one precomputed text-to-image result; return ``(image_link, prompt)``."""
        result_list = draw_from_imagen_museum("t2i", self._museum_name(model_name))
        return result_list[0], result_list[1]

    def generate_image_ig_parallel_anony(self, prompt, model_A, model_B):
        """Run two (possibly randomly chosen) image models; return both results and both names."""
        model_names = self._pick_model_names(self.model_ig_list, model_A, model_B)
        image_A, image_B = self.generate_image_ig_parallel(prompt, model_names[0], model_names[1])
        return image_A, image_B, model_names[0], model_names[1]

    def generate_image_ig_museum_parallel_anony(self, model_A, model_B):
        """Museum variant of the anonymous text-to-image comparison.

        Returns ``(image_A, image_B, name_A, name_B, prompt)``.
        """
        model_names = self._pick_model_names(self.model_ig_list, model_A, model_B)
        image_A, image_B, prompt = self.generate_image_ig_museum_parallel(model_names[0], model_names[1])
        return image_A, image_B, model_names[0], model_names[1], prompt

    def generate_image_ig_parallel(self, prompt, model_A, model_B):
        """Run two text-to-image models concurrently; return ``(result_A, result_B)``.

        Models whose name starts with "imagenhub" use ``generate_image_ig``; all
        others use ``generate_image_ig_api``.
        """
        calls = [
            (self.generate_image_ig if model.startswith("imagenhub") else self.generate_image_ig_api,
             prompt, model)
            for model in (model_A, model_B)
        ]
        return self._run_concurrently(calls)

    def generate_image_ig_museum_parallel(self, model_A, model_B):
        """Fetch a shared-prompt pair of museum images; return ``(image_A, image_B, prompt)``."""
        image_links, prompt_list = draw2_from_imagen_museum(
            "t2i", self._museum_name(model_A), self._museum_name(model_B)
        )[:2]
        return image_links[0], image_links[1], prompt_list[0]

    def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
        """Run image-editing model ``model_name`` on ``source_image`` with the three text prompts."""
        pipe = self.load_model_pipe(model_name)
        return pipe(src_image=source_image, src_prompt=textbox_source,
                    target_prompt=textbox_target, instruct_prompt=textbox_instruct)

    def generate_image_ie_museum(self, model_name):
        """Fetch one precomputed image-editing result.

        Returns ``(source_link, model_link, source_caption, target_caption, instruction)``.
        """
        image_links, prompt_list = draw_from_imagen_museum("tie", self._museum_name(model_name))[:2]
        return image_links[0], image_links[1], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Run two image-editing models concurrently on the same inputs; return ``(result_A, result_B)``."""
        calls = [
            (self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model)
            for model in (model_A, model_B)
        ]
        return self._run_concurrently(calls)

    def generate_image_ie_museum_parallel(self, model_A, model_B):
        """Fetch a museum image-editing pair.

        Returns ``(source_link, link_A, link_B, source_caption, target_caption, instruction)``.
        """
        image_links, prompt_list = draw2_from_imagen_museum(
            "tie", self._museum_name(model_A), self._museum_name(model_B)
        )[:2]
        return image_links[0], image_links[1], image_links[2], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Run two (possibly randomly chosen) image-editing models; return both results and both names."""
        model_names = self._pick_model_names(self.model_ie_list, model_A, model_B)
        result_A, result_B = self.generate_image_ie_parallel(
            textbox_source, textbox_target, textbox_instruct, source_image, model_names[0], model_names[1]
        )
        return result_A, result_B, model_names[0], model_names[1]

    def generate_image_ie_museum_parallel_anony(self, model_A, model_B):
        """Museum variant of the anonymous image-editing comparison.

        Returns the six values of ``generate_image_ie_museum_parallel`` followed by both model names.
        """
        model_names = self._pick_model_names(self.model_ie_list, model_A, model_B)
        src, image_A, image_B, source_caption, target_caption, instruction = \
            self.generate_image_ie_museum_parallel(model_names[0], model_names[1])
        return src, image_A, image_B, source_caption, target_caption, instruction, model_names[0], model_names[1]

    def generate_video_vg(self, prompt, model_name):
        """Run text-to-video model ``model_name`` on ``prompt`` and return the pipeline result."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_video_vg_api(self, prompt, model_name):
        """Same as ``generate_video_vg``; used for models not prefixed with "videogenhub"."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_video_vg_museum(self, model_name):
        """Fetch one precomputed text-to-video result; return ``(video_link, prompt)``."""
        result_list = draw_from_videogen_museum("t2v", self._museum_name(model_name))
        return result_list[0], result_list[1]

    def generate_video_vg_parallel_anony(self, prompt, model_A, model_B):
        """Run two (possibly randomly chosen) video models; return both results and both names."""
        model_names = self._pick_model_names(self.model_vg_list, model_A, model_B)
        video_A, video_B = self.generate_video_vg_parallel(prompt, model_names[0], model_names[1])
        return video_A, video_B, model_names[0], model_names[1]

    def generate_video_vg_museum_parallel_anony(self, model_A, model_B):
        """Museum variant of the anonymous text-to-video comparison.

        Returns ``(video_A, video_B, name_A, name_B, prompt)``.
        """
        model_names = self._pick_model_names(self.model_vg_list, model_A, model_B)
        video_A, video_B, prompt = self.generate_video_vg_museum_parallel(model_names[0], model_names[1])
        return video_A, video_B, model_names[0], model_names[1], prompt

    def generate_video_vg_parallel(self, prompt, model_A, model_B):
        """Run two text-to-video models concurrently; return ``(result_A, result_B)``.

        Models whose name starts with "videogenhub" use ``generate_video_vg``; all
        others use ``generate_video_vg_api``.
        """
        calls = [
            (self.generate_video_vg if model.startswith("videogenhub") else self.generate_video_vg_api,
             prompt, model)
            for model in (model_A, model_B)
        ]
        return self._run_concurrently(calls)

    def generate_video_vg_museum_parallel(self, model_A, model_B):
        """Fetch a shared-prompt pair of museum videos; return ``(video_A, video_B, prompt)``."""
        video_links, prompt_list = draw2_from_videogen_museum(
            "t2v", self._museum_name(model_A), self._museum_name(model_B)
        )[:2]
        return video_links[0], video_links[1], prompt_list[0]