Spaces:
Runtime error
Runtime error
| import os | |
# Pin the UI stack before gradio/diffusers are imported below: this demo is
# written against the gradio 3.x Image API (source="upload", tool="sketch").
import subprocess
import sys

for _pip_args in (
    ["uninstall", "-y", "gradio"],
    ["install", "gradio==3.47"],
    ["install", "diffusers", "-U"],
):
    # Invoke pip through the running interpreter so packages land in the same
    # environment this script imports from (a bare `pip` on PATH may belong to
    # another Python). Failures are not fatal, matching the previous behavior.
    subprocess.run([sys.executable, "-m", "pip", *_pip_args], check=False)
| import torch | |
| import random | |
| import numpy as np | |
| import gradio as gr | |
| from glob import glob | |
| from datetime import datetime | |
| from diffusers import StableDiffusionPipeline | |
| from diffusers import DDIMScheduler, LCMScheduler | |
| import torch.nn.functional as F | |
| from PIL import Image,ImageDraw | |
| from utils.masactrl_utils import (AttentionBase, | |
| regiter_attention_editor_diffusers) | |
| from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d | |
| from utils.style_attn_control import MaskPromptedStyleAttentionControl | |
| from utils.pipeline import MasaCtrlPipeline | |
| from torchvision.utils import save_image | |
| from segment_anything import sam_model_registry, SamPredictor | |
# Styling for the compact emoji buttons (refresh / random seed) that are given
# elem_classes="toolbutton". `margin-bottom` takes a single length; the former
# misspelled, four-value declaration was invalid and dropped by browsers.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
class GlobalText:
    """Application state and Gradio callbacks for the PortraitDiff demo.

    One instance holds the loaded ``MasaCtrlPipeline`` (and which personal
    checkpoint / LoRA / LCM-LoRA has been applied to it), the SAM predictor
    with the prompt points clicked by the user, and the cached DDIM inversion
    of the current style/source pair that ``generate`` reuses between calls.
    """

    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.personalized_model_dir = './models/Stable-diffusion'
        self.lora_model_dir = './models/Lora'
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.savedir_mask = os.path.join(self.savedir, "mask")
        self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5",
                                      "latent-consistency/lcm-lora-sdv1-5"]
        # {file basename: path}; populated by refresh_personalized_model().
        self.personalized_model_list = {}
        self.lora_model_list = {}
        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_loaded = None
        self.lcm_lora_loaded = False
        self.personal_model_loaded = None
        self.sam_predictor = None
        # SAM prompt points [x, y] and labels (1 = add to mask, 0 = remove).
        # Initialised here so clicking an image before SAM is loaded works.
        self.sam_point = []
        self.sam_point_label = []
        self.lora_model_state_dict = {}
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.refresh_personalized_model()
        self.reset_start_code()

    def load_base_pipeline(self, model_path):
        """Load ``model_path`` as a MasaCtrlPipeline driven by its DDIM scheduler."""
        print(f'loading {model_path} model')
        scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
                                                         scheduler=scheduler).to(self.device)
        # Any cached inversion was computed with the previous weights.
        self.reset_start_code()

    def refresh_stable_diffusion(self):
        """Reload the default base model and forget all applied add-ons.

        Returns:
            str: the repo id of the reloaded base model.
        """
        self.load_base_pipeline(self.stable_diffusion_list[0])
        self.lora_loaded = None
        self.personal_model_loaded = None
        self.lcm_lora_loaded = False
        return self.stable_diffusion_list[0]

    def refresh_personalized_model(self):
        """Rescan the checkpoint and LoRA folders for ``*.safetensors`` files."""
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True)
        self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list}
        lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True)
        self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Dropdown callback: (re)load the selected base model.

        Selecting the LCM-LoRA entry loads SD-1.5 with LCM-LoRA fused in; any
        other value is loaded as a regular DDIM pipeline. Either way a new
        pipeline is built, so personal-checkpoint / LoRA bookkeeping is cleared.
        """
        if not stable_diffusion_dropdown:
            # Cleared dropdown: keep whatever is currently loaded.
            return gr.Dropdown()
        if stable_diffusion_dropdown == 'latent-consistency/lcm-lora-sdv1-5':
            self.load_lcm_lora()
        else:
            self.load_base_pipeline(stable_diffusion_dropdown)
            # A plain pipeline must not be sampled in LCM mode.
            self.lcm_lora_loaded = False
        self.lora_loaded = None
        self.personal_model_loaded = None
        return gr.Dropdown()

    def update_base_model(self, base_model_dropdown):
        """Dropdown callback: swap in the VAE/UNet/text encoder of a local checkpoint.

        Returns None when no base pipeline is loaded yet; otherwise a no-op
        Dropdown update.
        """
        if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
            return None
        base_model = self.personalized_model_list.get(base_model_dropdown)
        if base_model is None:
            # Cleared dropdown, or a custom value that is not a scanned file.
            if base_model_dropdown:
                gr.Info(f"Checkpoint {base_model_dropdown} was not found in {self.personalized_model_dir}.")
            return gr.Dropdown()
        mid_model = StableDiffusionPipeline.from_single_file(base_model)
        self.pipeline.vae = mid_model.vae
        self.pipeline.unet = mid_model.unet
        self.pipeline.text_encoder = mid_model.text_encoder
        self.pipeline.to(self.device)
        # splitext keeps dotted names intact ("a.v2.safetensors" -> "a.v2").
        self.personal_model_loaded = os.path.splitext(base_model_dropdown)[0]
        self.reset_start_code()
        print(f'load {base_model_dropdown} model success!')
        return gr.Dropdown()

    def update_lora_model(self, lora_model_dropdown, lora_alpha_slider):
        """Dropdown callback: fuse the selected LoRA, or remove LoRAs for "none".

        Returns None when no base pipeline is loaded yet; otherwise a no-op
        Dropdown update.
        """
        if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
            return None
        if lora_model_dropdown == "none":
            # NOTE(review): this presumably also removes a fused LCM-LoRA while
            # lcm_lora_loaded stays True — confirm against diffusers behavior.
            self.pipeline.unfuse_lora()
            self.pipeline.unload_lora_weights()
            self.lora_loaded = None
            print("Restore lora.")
        else:
            lora_model_path = self.lora_model_list.get(lora_model_dropdown)
            if lora_model_path is None:
                if lora_model_dropdown:
                    gr.Info(f"LoRA {lora_model_dropdown} was not found in {self.lora_model_dir}.")
                return gr.Dropdown()
            self.pipeline.load_lora_weights(lora_model_path)
            # fuse_lora's first positional parameter is not the scale; pass it
            # by keyword so the alpha slider actually takes effect.
            self.pipeline.fuse_lora(lora_scale=lora_alpha_slider)
            self.lora_loaded = os.path.splitext(lora_model_dropdown)[0]
            print(f'load {lora_model_dropdown} LoRA Model Success!')
        self.reset_start_code()
        return gr.Dropdown()

    def load_lcm_lora(self, lora_alpha_slider=1.0):
        """Reload SD-1.5 with an LCM scheduler and fuse the LCM-LoRA weights in."""
        # set scheduler
        self.pipeline = MasaCtrlPipeline.from_pretrained(self.stable_diffusion_list[0]).to(self.device)
        self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
        # load LCM-LoRA
        self.pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
        self.pipeline.fuse_lora(lora_scale=lora_alpha_slider)
        self.lcm_lora_loaded = True
        self.reset_start_code()
        print('load LCM-LoRA model success!')

    @staticmethod
    def _apply_seed(seed):
        """Seed torch's global RNG and return the seed in effect.

        ``-1``, ``""`` or ``None`` — as an int or as the string a Textbox
        delivers — draws a fresh random seed; anything else is parsed as int.
        """
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text in ("", "-1"):
            torch.seed()
        else:
            torch.manual_seed(int(seed_text))
        return torch.initial_seed()

    @staticmethod
    def _register_upblocks(model, freeu, b1, b2, s1, s2, source_mask):
        """Install FreeU-modified (when ``freeu``) or plain up-blocks on ``model``."""
        if freeu:
            print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
            # source_mask is None in "Without mask" mode, so FreeU is unmasked then.
            register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
            register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
        else:
            print('++++++++++++++++++ Run without FreeU ++++++++++++++++')
            register_upblock2d(model)
            register_crossattn_upblock2d(model)

    def generate(self, source, style, source_mask, style_mask,
                 start_step, start_layer, Style_attn_step,
                 Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
                 target_prompt, negative_prompt_textbox,
                 inter_latents,
                 freeu, b1, b2, s1, s2,
                 width_slider, height_slider,
                 ):
        """Stylize ``source`` with ``style`` and save a comparison strip to disk.

        ``source`` / ``style`` are the sketch-tool values of the two image
        components (dicts with ``'image'`` and ``'mask'``); ``source_mask`` /
        ``style_mask`` are optional SAM masks overriding the sketched ones.
        A ``seed`` of ``-1`` or ``""`` means "pick a random seed".

        The DDIM inversion of the [style, source] pair is cached on the
        instance and reused until ``reset_start_code`` is called or the
        prompt, CFG scale, step count or resolution changes.

        Returns:
            list: the three pipeline outputs (style branch, source branch,
            stylized result) as HWC numpy arrays.

        Raises:
            gr.Error: if no pipeline is loaded or an input image is missing.
        """
        if self.pipeline is None:
            raise gr.Error("Please select a pretrained model path.")
        if source is None or style is None:
            raise gr.Error("Please provide both a source image and a style image.")

        os.makedirs(self.savedir_sample, exist_ok=True)
        os.makedirs(self.savedir_mask, exist_ok=True)
        model = self.pipeline
        seed = self._apply_seed(seed)
        print(f'seed: {seed}')

        sample_count = len(os.listdir(self.savedir_sample))
        mask_dir = os.path.join(self.savedir_mask, f"results_{sample_count}")
        os.makedirs(mask_dir, exist_ok=True)

        # The same prompt conditions the style, source and output branches.
        ref_prompt = [target_prompt, target_prompt]
        prompts = ref_prompt + [target_prompt]
        source_image, style_image, source_mask, style_mask = load_mask_images(
            source, style, source_mask, style_mask, self.device,
            width_slider, height_slider, out_dir=mask_dir)

        # The inversion depends on prompt, CFG scale, steps and resolution as
        # well as on the images; recompute whenever one of these changed.
        inversion_key = (target_prompt, scale, ddim_steps, width_slider, height_slider)
        with torch.no_grad():
            if (self.start_code is None or self.latents_list is None
                    or self.inversion_key != inversion_key):
                content_style = torch.cat([style_image, source_image], dim=0)
                editor = AttentionBase()
                regiter_attention_editor_diffusers(model, editor)
                st_code, latents_list = model.invert(content_style,
                                                     ref_prompt,
                                                     guidance_scale=scale,
                                                     num_inference_steps=ddim_steps,
                                                     return_intermediates=True)
                # Batch layout [style, source, source]; the last entry is the
                # one saved as the stylized result below.
                start_code = torch.cat([st_code, st_code[1:]], dim=0)
                self.start_code = start_code
                self.latents_list = latents_list
                self.inversion_key = inversion_key
            else:
                start_code = self.start_code
                latents_list = self.latents_list
                print('------------------------------------------ Use previous latents ------------------------------------------ ')

            if Method == "Without mask":
                style_mask = None
                source_mask = None
                only_masked_region = False
            else:
                # "Only masked region" or "Seperate Background Foreground".
                only_masked_region = Method == "Only masked region"

            controller = MaskPromptedStyleAttentionControl(start_step, start_layer,
                                                           style_attn_step=Style_attn_step,
                                                           style_guidance=Style_Guidance,
                                                           style_mask=style_mask,
                                                           source_mask=source_mask,
                                                           only_masked_region=only_masked_region,
                                                           guidance=scale,
                                                           de_bug=de_bug,
                                                           )
            self._register_upblocks(model, freeu, b1, b2, s1, s2, source_mask)
            regiter_attention_editor_diffusers(model, controller)

            # inference the synthesized image
            generate_image = model(prompts,
                                   width=width_slider,
                                   height=height_slider,
                                   latents=start_code,
                                   guidance_scale=scale,
                                   num_inference_steps=ddim_steps,
                                   ref_intermediate_latents=latents_list if inter_latents else None,
                                   neg_prompt=negative_prompt_textbox,
                                   return_intermediates=False,
                                   lcm_lora=self.lcm_lora_loaded,
                                   de_bug=de_bug,)

        save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
        if self.lora_loaded is not None:
            save_file_name = f"lora_{self.lora_loaded}_" + save_file_name
        if self.personal_model_loaded is not None:
            save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name
        save_file_path = os.path.join(self.savedir_sample, save_file_name)
        # Inputs are in [-1, 1]; x/2 + 0.5 brings them to the output's range
        # (presumably [0, 1]) so the strip reads: source | style | result.
        save_image(torch.cat([source_image / 2 + 0.5, style_image / 2 + 0.5, generate_image[2:]], dim=0),
                   save_file_path, nrow=3, padding=0)

        generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy()
        return [
            generate_image[0],
            generate_image[1],
            generate_image[2],
        ]

    def reset_start_code(self):
        """Drop the cached DDIM inversion so the next ``generate`` recomputes it."""
        self.start_code = None
        self.latents_list = None
        self.inversion_key = None

    def _load_sam(self, sam_path):
        """Build a vit_h SAM predictor from ``sam_path`` without touching points."""
        if not sam_path:
            raise gr.Error("Please provide the path of a SAM (vit_h) checkpoint.")
        sam = sam_model_registry["vit_h"](checkpoint=sam_path)
        sam.to(device=self.device)
        self.sam_predictor = SamPredictor(sam)

    def lora_sam_predictor(self, sam_path):
        """Button callback: load SAM from ``sam_path`` and clear clicked points."""
        self._load_sam(sam_path)
        self.sam_point = []
        self.sam_point_label = []

    def get_points_with_draw(self, image, image_with_points, label, evt: gr.SelectData):
        """Select callback: record a clicked SAM point and draw it.

        Yellow dots mark "Add Mask" points, magenta dots mark removals. Dots
        accumulate on ``image_with_points`` once it exists.
        """
        x, y = evt.index[0], evt.index[1]
        is_add = label == 'Add Mask'
        point_radius = 15
        point_color = (255, 255, 0) if is_add else (255, 0, 255)
        self.sam_point.append([x, y])
        self.sam_point_label.append(1 if is_add else 0)
        print(x, y, is_add)
        canvas = image if image_with_points is None else image_with_points
        draw = ImageDraw.Draw(canvas)
        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
        return canvas

    def reset_sam_points(self):
        """Forget all clicked SAM points; returns None to clear the preview image."""
        self.sam_point = []
        self.sam_point_label = []
        print('reset all points')
        return None

    def obtain_mask(self, image, sam_path):
        """Run SAM on ``image`` with the clicked points.

        Returns:
            np.ndarray: an (H, W, 3) uint8 mask with values 0/255.
        """
        if image is None:
            raise gr.Error("Please provide an image to segment.")
        if not self.sam_point:
            raise gr.Error("Please click at least one point on the image first.")
        if self.sam_predictor is None:
            # Lazy load must keep the points the user has already clicked.
            self._load_sam(sam_path)
        print("+++++++++++++++++++ Obtain Mask by SAM ++++++++++++++++++++++")
        input_point = np.array(self.sam_point)
        input_label = np.array(self.sam_point_label)
        predictor = self.sam_predictor
        predictor.set_image(np.array(image))
        masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)
        # multimask_output=False -> one mask; (1, H, W) bool -> (H, W, 3) uint8.
        masks = masks.astype(np.uint8) * 255
        masks = masks.transpose(1, 2, 0)
        return masks.repeat(3, axis=2)
# Single shared state object: every callback wired up in ui() is a method of
# this instance. Constructing it scans the local checkpoint/LoRA folders.
global_text = GlobalText()
def _image_to_tensor(img, device, width, height):
    """Convert an HWC 0-255 image (PIL or ndarray) to a (1, C, height, width) tensor in [-1, 1]."""
    tensor = torch.from_numpy(np.array(img)).to(device) / 127.5 - 1.
    tensor = tensor.unsqueeze(0).permute(0, 3, 1, 2)
    return F.interpolate(tensor, (height, width))


def _mask_to_tensor(mask, device, width, height):
    """Convert a 0-255 mask (PIL or ndarray, HW or HWC) to a (1, 1, height//8, width//8) tensor in [0, 1].

    Only the first channel is kept. The /8 downsampling presumably matches the
    latent resolution the attention controller works at.
    """
    arr = np.array(mask)
    if arr.ndim == 2:
        # Single-channel mask: add the channel axis the permute below expects.
        arr = arr[..., None]
    tensor = torch.from_numpy(arr).to(device) / 255.
    tensor = tensor.unsqueeze(0).permute(0, 3, 1, 2)[:, :1]
    return F.interpolate(tensor, (height // 8, width // 8))


def _save_mask(mask, path):
    """Write ``mask`` (PIL image or ndarray) to ``path``."""
    if isinstance(mask, np.ndarray):
        mask = Image.fromarray(mask)
    mask.save(path)


def load_mask_images(source, style, source_mask, style_mask, device, width, height, out_dir=None):
    """Turn the Gradio image/mask inputs into tensors for inversion and editing.

    Args:
        source, style: dicts with ``'image'`` and ``'mask'`` entries (PIL images
            or ndarrays) as produced by the sketch-tool image components.
        source_mask, style_mask: optional HWC 0-255 arrays (e.g. from SAM) that
            replace the sketched masks when not None.
        device: torch device for the returned tensors.
        width, height: output image resolution.
        out_dir: when given, the masks actually used are saved there as
            ``source_mask.jpg`` / ``style_mask.jpg``.

    Returns:
        tuple: ``(source_image, style_image, source_mask, style_mask)``; images
        are (1, C, height, width) in [-1, 1], masks are
        (1, 1, height//8, width//8) in [0, 1].
    """
    source_mask_input = source['mask'] if source_mask is None else source_mask
    style_mask_input = style['mask'] if style_mask is None else style_mask

    if out_dir is not None:
        _save_mask(source_mask_input, os.path.join(out_dir, 'source_mask.jpg'))
        _save_mask(style_mask_input, os.path.join(out_dir, 'style_mask.jpg'))

    source_image = _image_to_tensor(source['image'], device, width, height)
    style_image = _image_to_tensor(style['image'], device, width, height)
    source_mask = _mask_to_tensor(source_mask_input, device, width, height)
    style_mask = _mask_to_tensor(style_mask_input, device, width, height)
    return source_image, style_image, source_mask, style_mask
def ui():
    """Build the PortraitDiff Gradio interface.

    Every callback is a method of the module-level ``global_text`` state
    object (or a small closure over it).

    Returns:
        gr.Blocks: the assembled demo, not yet launched.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/2312.02212)
            Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/2312.02212) | [Github](https://github.com/liujin112/PortraitDiffusion)
            """
        )

        # ---- 1. Model selection ------------------------------------------------
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Select a pretrained model.
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=global_text.stable_diffusion_list,
                    interactive=True,
                    allow_custom_value=True,
                )
                stable_diffusion_dropdown.change(
                    fn=global_text.update_stable_diffusion,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    global_text.refresh_stable_diffusion()

                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[])

                base_model_dropdown = gr.Dropdown(
                    label="Select a ckpt model (optional)",
                    choices=sorted(global_text.personalized_model_list.keys()),
                    interactive=True,
                    allow_custom_value=True,
                )
                base_model_dropdown.change(
                    fn=global_text.update_base_model,
                    inputs=[base_model_dropdown],
                    outputs=[base_model_dropdown],
                )

                lora_model_dropdown = gr.Dropdown(
                    label="Select a LoRA model (optional)",
                    choices=["none"] + sorted(global_text.lora_model_list.keys()),
                    value="none",
                    interactive=True,
                    allow_custom_value=True,
                )
                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                lora_model_dropdown.change(
                    fn=global_text.update_lora_model,
                    inputs=[lora_model_dropdown, lora_alpha_slider],
                    outputs=[lora_model_dropdown],
                )

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    global_text.refresh_personalized_model()
                    return [
                        gr.Dropdown(choices=sorted(global_text.personalized_model_list.keys())),
                        gr.Dropdown(choices=["none"] + sorted(global_text.lora_model_list.keys())),
                    ]

                personalized_refresh_button.click(
                    fn=update_personalized_model,
                    inputs=[],
                    outputs=[base_model_dropdown, lora_model_dropdown],
                )

        # ---- 2. Stylization configs -----------------------------------------
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for PortraitDiff.
                """
            )
            with gr.Tab("Configs"):
                with gr.Row():
                    source_image = gr.Image(label="Source Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                    style_image = gr.Image(label="Style Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                with gr.Row():
                    prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
                    negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)

                # Row.style() is deprecated in gradio 3.x; use the constructor arg.
                with gr.Row(equal_height=False):
                    with gr.Column():
                        width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                        Method = gr.Dropdown(
                            ["Without mask", "Only masked region", "Seperate Background Foreground"],
                            value="Without mask",
                            label="Mask", info="Select how to use masks")
                        with gr.Tab('Base Configs'):
                            with gr.Row():
                                # At least one step: the scheduler cannot run zero steps.
                                ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=1, maximum=100, step=1)
                                Style_attn_step = gr.Slider(label="Step of Style Attention Control",
                                                            minimum=0,
                                                            maximum=50,
                                                            value=35,
                                                            step=1)
                                start_step = gr.Slider(label="Step of Attention Control",
                                                       minimum=0,
                                                       maximum=150,
                                                       value=0,
                                                       step=1)
                                start_layer = gr.Slider(label="Layer of Style Attention Control",
                                                        minimum=0,
                                                        maximum=16,
                                                        value=10,
                                                        step=1)
                                Style_Guidance = gr.Slider(label="Style Guidance Scale",
                                                           minimum=0,
                                                           maximum=4,
                                                           value=1.2,
                                                           step=0.05)
                                cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20)
                        with gr.Tab('FreeU'):
                            with gr.Row():
                                freeu = gr.Checkbox(label="Free Upblock", value=False)
                                de_bug = gr.Checkbox(value=False, label='DeBug')
                                inter_latents = gr.Checkbox(value=True, label='Use intermediate latents')
                            with gr.Row():
                                b1 = gr.Slider(label='b1:', minimum=-1, maximum=2, step=0.01, value=1.3)
                                b2 = gr.Slider(label='b2:', minimum=-1, maximum=2, step=0.01, value=1.5)
                            with gr.Row():
                                s1 = gr.Slider(label='s1: ', minimum=0, maximum=2, step=0.1, value=1.0)
                                s2 = gr.Slider(label='s2:', minimum=0, maximum=2, step=0.1, value=1.0)
                        with gr.Row():
                            seed_textbox = gr.Textbox(label="Seed", value=-1)
                            seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                            # Integer bounds: randint rejects a float stop on Python 3.12+.
                            seed_button.click(fn=lambda: random.randint(1, 10**8), inputs=[], outputs=[seed_textbox])
                    with gr.Column():
                        generate_button = gr.Button(value="Generate", variant='primary')
                        generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512,)
                        with gr.Row():
                            recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
                            recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)

            with gr.Tab("SAM"):
                with gr.Column():
                    with gr.Row():
                        add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
                        sam_path = gr.Textbox(label="Sam Model path", value='')
                        load_sam_btn = gr.Button(value="Load SAM from path")
                    with gr.Row():
                        send_source_btn = gr.Button(value="Send Source Image from PD Tab")
                        sam_source_btn = gr.Button(value="Segment Source")
                        send_style_btn = gr.Button(value="Send Style Image from PD Tab")
                        sam_style_btn = gr.Button(value="Segment Style")
                    with gr.Row():
                        source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
                        style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
                    with gr.Row():
                        source_image_with_points = gr.Image(label="source Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
                        source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
                        style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
                        style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)

                load_sam_btn.click(global_text.lora_sam_predictor, inputs=[sam_path], outputs=[])
                source_image_sam.select(global_text.get_points_with_draw, [source_image_sam, source_image_with_points, add_or_remove], source_image_with_points)
                style_image_sam.select(global_text.get_points_with_draw, [style_image_sam, style_image_with_points, add_or_remove], style_image_with_points)
                # The sketch components deliver {'image', 'mask'}; forward only the
                # image and clear the point preview.
                send_source_btn.click(lambda x: (x['image'], None), inputs=[source_image], outputs=[source_image_sam, source_image_with_points])
                send_style_btn.click(lambda x: (x['image'], None), inputs=[style_image], outputs=[style_image_sam, style_image_with_points])
                style_image_sam.change(global_text.reset_sam_points, inputs=[], outputs=[style_image_with_points])
                source_image_sam.change(global_text.reset_sam_points, inputs=[], outputs=[source_image_with_points])
                sam_source_btn.click(global_text.obtain_mask, [source_image_sam, sam_path], [source_mask])
                sam_style_btn.click(global_text.obtain_mask, [style_image_sam, sam_path], [style_mask])

            gr.Examples(
                [[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"),
                  os.path.join(os.path.dirname(__file__), "images/style/1.jpg")],
                 ],
                [source_image, style_image]
            )

        # Positional order must match GlobalText.generate's parameters.
        inputs = [
            source_image, style_image, source_mask, style_mask,
            start_step, start_layer, Style_attn_step,
            Method, Style_Guidance, ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
            prompt_textbox, negative_prompt_textbox, inter_latents,
            freeu, b1, b2, s1, s2,
            width_slider, height_slider,
        ]
        generate_button.click(
            fn=global_text.generate,
            inputs=inputs,
            outputs=[recons_style, recons_content, generate_image],
        )

        # The cached DDIM inversion depends on the images, the step count, the
        # prompt, the CFG scale and the output resolution; drop it whenever
        # any of these inputs changes.
        source_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        style_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        for component in (ddim_steps, prompt_textbox, cfg_scale_slider, width_slider, height_slider):
            component.change(fn=global_text.reset_start_code, inputs=[], outputs=[])

    return demo
if __name__ == "__main__":
    # Script entry point: assemble the interface (kept reachable as the
    # module-level name ``demo``) and serve it with Gradio's default settings.
    (demo := ui()).launch()