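"""Pipeline utilities for parametric material editing with SDXL.

Wires together a depth-conditioned SDXL ControlNet inpainting pipeline, an
InstantStyle-style IP-Adapter, and small per-parameter control MLPs
(metallic, roughness, transparency, glow) that steer the edit strength.
"""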
import os
from typing import Dict, Optional

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image, ImageChops, ImageEnhance
from rembg import new_session, remove
from transformers import DPTForDepthEstimation, DPTImageProcessor

from ip_adapter_instantstyle import IPAdapterXL
from ip_adapter_instantstyle.utils import register_cross_attention_hook
from parametric_control_mlp import control_mlp
file_dir = os.path.dirname(os.path.abspath(__file__))

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

# Cache for the rembg background-removal session
_session_cache = None

CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]
def get_session():
    """Return a cached rembg session, creating it on first use."""
    global _session_cache
    if _session_cache is None:
        _session_cache = new_session()
    return _session_cache


def get_device():
    """Pick CUDA when available, otherwise fall back to CPU."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def setup_control_mlps(
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, torch.nn.Module]:
    """Load one control MLP per supported material parameter."""
    if device is None:
        device = get_device()
    print(f"Setting up control MLPs on {device}")
    return {mlp: setup_control_mlp(mlp, features, device, dtype) for mlp in CONTROL_MLPS}
def setup_control_mlp(
    material_parameter: str,
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
) -> torch.nn.Module:
    """Load the trained control MLP for a single material parameter."""
    if device is None:
        device = get_device()
    net = control_mlp(features)
    net.load_state_dict(
        torch.load(
            os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
            map_location=device,
        )
    )
    net.to(device, dtype=dtype)
    net.eval()
    return net
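
# Example (a sketch): assuming the bundled weights in model_weights/ are
# present next to this file, a single MLP can be loaded on its own:
#   metallic_mlp = setup_control_mlp("metallic")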
def download_ip_adapter():
    """Fetch the IP-Adapter weights from the Hub unless they are already present."""
    repo_id = "h94/IP-Adapter"
    target_folders = ["models/", "sdxl_models/"]
    local_dir = file_dir

    # Skip the download when both target folders exist and are non-empty
    folders_exist = all(
        os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
    )
    if folders_exist:
        folders_empty = any(
            len(os.listdir(os.path.join(local_dir, folder))) == 0
            for folder in target_folders
        )
        if not folders_empty:
            print("IP-Adapter files already downloaded. Skipping download.")
            return

    # List every file in the repo and keep those under the target folders
    all_files = list_repo_files(repo_id)
    filtered_files = [
        f for f in all_files if any(f.startswith(folder) for folder in target_folders)
    ]

    # Download each file into place
    # (local_dir_use_symlinks is deprecated in recent huggingface_hub, so it is omitted)
    for file_path in filtered_files:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
        )
        print(f"Downloaded: {file_path} to {local_path}")
def setup_pipeline(
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
):
    """Build the depth-ControlNet SDXL inpainting pipeline wrapped in an IP-Adapter."""
    if device is None:
        device = get_device()
    print(f"Setting up pipeline on {device}")
    download_ip_adapter()

    # InstantStyle injects the image prompt into a single style-relevant
    # attention block; ("up", 0, 1) selects up_blocks.0.attentions.1
    cur_block = ("up", 0, 1)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
    ).to(device)
    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=dtype,
        add_watermarker=False,
    ).to(device)
    pipe.unet = register_cross_attention_hook(pipe.unet)

    block_name = f"{cur_block[0]}_blocks.{cur_block[1]}.attentions.{cur_block[2]}"
    print(f"Targeting style block {block_name}")

    return IPAdapterXL(
        pipe,
        os.path.join(file_dir, image_encoder_path),
        os.path.join(file_dir, ip_ckpt),
        device,
        target_blocks=[block_name],
    )
def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
    """Load the Intel DPT-Hybrid-MiDaS depth estimator and its image processor."""
    if device is None:
        device = get_device()
    image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
    model.to(device, dtype=dtype)
    model.eval()
    return model, image_processor
def run_dpt_depth(
    image: Image.Image, model, processor, device: Optional[str] = None
) -> Image.Image:
    """Run DPT depth estimation on an image and return a 1024x1024 depth map."""
    if device is None:
        device = get_device()

    # Prepare the image in the model's dtype
    inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)

    # Predict depth
    with torch.no_grad():
        depth_map = model(**inputs).predicted_depth

    # Normalize to the 0-1 range, then scale to 8-bit
    depth_map = (depth_map - depth_map.min()) / (
        depth_map.max() - depth_map.min() + 1e-7
    )
    depth_map = depth_map.clip(0, 1) * 255

    # Convert to a PIL image
    depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
    return Image.fromarray(depth_map).resize((1024, 1024))
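
# Example (a sketch; "input.png" is an assumed file, not part of this repo):
#   model, processor = get_dpt_model()
#   depth = run_dpt_depth(Image.open("input.png").convert("RGB"), model, processor)
#   depth.save("depth.png")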
def prepare_mask(image: Image.Image) -> Image.Image:
    """Prepare a foreground mask from an image using rembg."""
    rm_bg = remove(image, session=get_session())
    # Binarize: any non-zero pixel becomes foreground (white)
    target_mask = (
        rm_bg.convert("RGB")
        .point(lambda x: 0 if x < 1 else 255)
        .convert("L")
        .convert("RGB")
    )
    return target_mask.resize((1024, 1024))
def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Prepare the init image for inpainting: grayscale foreground, original background."""
    # Work at the pipeline resolution so the ImageChops ops see matching sizes
    image = image.resize((1024, 1024))
    mask = mask.resize((1024, 1024))

    # Grayscale version of the input (enhance(1.0) leaves brightness unchanged)
    gray_image = image.convert("L").convert("RGB")
    gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)

    # Invert the mask to address the background
    invert_mask = ImageChops.invert(mask)

    # Keep the grayscale pixels inside the mask, the original pixels outside,
    # and merge the two halves
    grayscale_img = ImageChops.darker(gray_image, mask)
    img_black_mask = ImageChops.darker(image, invert_mask)
    init_img = ImageChops.lighter(img_black_mask, grayscale_img)
    return init_img
def run_parametric_control(
    ip_model,
    target_image: Image.Image,
    edit_mlps: Dict[torch.nn.Module, float],
    texture_image: Optional[Image.Image] = None,
    num_inference_steps: int = 30,
    seed: int = 42,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Run a parametric material edit (e.g. metallic or roughness adjustments)."""
    # Get a depth map, estimating one when none is provided
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare the mask and the init image
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))
    print("Preparing initial image")
    if texture_image is None:
        texture_image = target_image
    init_img = prepare_init_image(target_image, mask)

    # Generate the edit
    print("Generating parametric edit")
    images = ip_model.generate_parametric_edits(
        texture_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_mlps=edit_mlps,
        strength=1.0,
    )
    return images[0]
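
# Example parametric edit (a sketch; "input.png" and the 0.7 metallic weight
# are illustrative assumptions, not values shipped with this repo):
#   ip_model = setup_pipeline()
#   mlps = setup_control_mlps()
#   result = run_parametric_control(
#       ip_model,
#       Image.open("input.png").convert("RGB"),
#       edit_mlps={mlps["metallic"]: 0.7},
#   )
#   result.save("metallic_edit.png")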
def run_blend(
    ip_model,
    target_image: Image.Image,
    texture_image1: Image.Image,
    texture_image2: Image.Image,
    edit_strength: float = 0.0,
    num_inference_steps: int = 20,
    seed: int = 1,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Blend between two texture images on the masked region of the target."""
    # Get a depth map, estimating one when none is provided
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare the mask and the init image
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))
    print("Preparing initial image")
    init_img = prepare_init_image(target_image, mask)

    # Generate the edit
    print("Generating edit")
    images = ip_model.generate_edit(
        start_image=texture_image1,
        pil_image=texture_image1,
        pil_image2=texture_image2,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_strength=edit_strength,
        clip_strength=1.0,
        strength=1.0,
    )
    return images[0]
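

if __name__ == "__main__":
    # Minimal smoke test (a sketch: the file names below are illustrative
    # assumptions, not assets shipped with this repo).
    ip_model = setup_pipeline()
    target = Image.open("target.png").convert("RGB")
    texture_a = Image.open("texture_a.png").convert("RGB")
    texture_b = Image.open("texture_b.png").convert("RGB")
    blended = run_blend(ip_model, target, texture_a, texture_b, edit_strength=0.5)
    blended.save("blend_result.png")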