# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path

import os
import time
import subprocess
from typing import List

import numpy as np
from PIL import Image
import torch
import torch.utils.checkpoint
from pytorch_lightning import seed_everything
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from torchvision import transforms

from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
from utils.wavelet_color_fix import wavelet_color_fix
from ram.models.ram_lora import ram
from ram import inference_ram as inference
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
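# pipelines.*, utils.*, ram.* and models.* are local modules shipped with the SeeSR
# codebase; everything else above is installed from pip.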

MODEL_URL = "https://weights.replicate.delivery/default/stabilityai/sd-2-1-base.tar"

tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
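
# tensor_transforms only converts the PIL input into a [0, 1] tensor; ram_transforms then
# resizes it to 384x384 and applies ImageNet normalization, which is what the RAM tagging
# model expects. The diffusion pipeline itself consumes the original PIL image.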

device = "cuda"

def download_weights(url, dest):
    """Download the weights tarball with pget and extract it into `dest`."""
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Load scheduler, tokenizer and models.
        pretrained_model_path = 'preset/models/stable-diffusion-2-1-base'
        seesr_model_path = 'preset/models/seesr'

        # Download SD-2-1 weights
        if not os.path.exists(pretrained_model_path):
            download_weights(MODEL_URL, pretrained_model_path)

        scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
        feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
        unet = UNet2DConditionModel.from_pretrained(seesr_model_path, subfolder="unet")
        controlnet = ControlNetModel.from_pretrained(seesr_model_path, subfolder="controlnet")

        # Freeze vae, text_encoder, unet and controlnet (inference only)
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)
        controlnet.requires_grad_(False)

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

        # Build the validation pipeline
        validation_pipeline = StableDiffusionControlNetPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
            unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None,
            requires_safety_checker=False,
        )
        validation_pipeline._init_tiled_vae(encoder_tile_size=1024, decoder_tile_size=224)
        self.validation_pipeline = validation_pipeline

        weight_dtype = torch.float16
        # Move text_encoder, vae, unet and controlnet to the GPU and cast to weight_dtype
        text_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)
        controlnet.to(device, dtype=weight_dtype)

        # RAM (Swin-L) tagging model, with the DAPE weights loaded as the conditioning checkpoint
        tag_model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
                        pretrained_condition='preset/models/DAPE.pth',
                        image_size=384,
                        vit='swin_l')
        tag_model.eval()
        self.tag_model = tag_model.to(device, dtype=weight_dtype)
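
    # Note: all weights stay on the GPU in fp16; the tiled VAE (1024-px encoder tiles,
    # 224-px decoder tiles) and xformers attention are meant to keep peak memory
    # manageable when upscaling large inputs.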

    # @torch.no_grad()
    def process(
        self,
        input_image: Image.Image,
        user_prompt: str,
        positive_prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        scale_factor: int,
        cfg_scale: float,
        seed: int,
        latent_tiled_size: int,
        latent_tiled_overlap: int,
        sample_times: int,
    ) -> List[np.ndarray]:
        process_size = 512
        resize_preproc = transforms.Compose([
            transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
        ])

        seed_everything(seed)
        generator = torch.Generator(device=device)

        # Tag the low-quality input with the RAM model and build the validation prompt
        validation_prompt = ""
        lq = tensor_transforms(input_image).unsqueeze(0).to(device).half()
        lq = ram_transforms(lq)
        res = inference(lq, self.tag_model)
        ram_encoder_hidden_states = self.tag_model.generate_image_embeds(lq)
        validation_prompt = f"{res[0]}, {positive_prompt},"
        validation_prompt = validation_prompt if user_prompt == '' else f"{user_prompt}, {validation_prompt}"

        # Upscale the input, then snap both sides to multiples of 8 for the VAE
        ori_width, ori_height = input_image.size
        resize_flag = False
        rscale = scale_factor
        input_image = input_image.resize((int(input_image.size[0] * rscale), int(input_image.size[1] * rscale)))

        if min(input_image.size) < process_size:
            input_image = resize_preproc(input_image)

        input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
        width, height = input_image.size
        resize_flag = True

        images = []
        for _ in range(sample_times):
            try:
                with torch.autocast("cuda"):
                    image = self.validation_pipeline(
                        validation_prompt, input_image, negative_prompt=negative_prompt,
                        num_inference_steps=num_inference_steps, generator=generator,
                        height=height, width=width,
                        guidance_scale=cfg_scale, conditioning_scale=1,
                        start_point='lr', start_steps=999, ram_encoder_hidden_states=ram_encoder_hidden_states,
                        latent_tiled_size=latent_tiled_size, latent_tiled_overlap=latent_tiled_overlap,
                    ).images[0]

                if True:  # alpha < 1.0:
                    image = wavelet_color_fix(image, input_image)

                if resize_flag:
                    image = image.resize((ori_width * rscale, ori_height * rscale))
            except Exception as e:
                # Fall back to a blank 512x512 image if this sample fails
                print(e)
                image = Image.new(mode="RGB", size=(512, 512))

            images.append(np.array(image))

        return images

    def predict(
        self,
        image: Path = Input(description="Input image"),
        user_prompt: str = Input(description="Prompt to condition on", default=""),
        positive_prompt: str = Input(description="Prompt to add", default="clean, high-resolution, 8k"),
        negative_prompt: str = Input(description="Prompt to remove", default="dotted, noise, blur, lowres, smooth"),
        cfg_scale: float = Input(description="Guidance scale, set value to >1 to use", default=5.5, ge=0.1, le=10.0),
        num_inference_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=100),
        sample_times: int = Input(description="Number of samples to generate", default=1, ge=1, le=10),
        latent_tiled_size: int = Input(description="Size of latent tiles", default=320, ge=128, le=480),
        latent_tiled_overlap: int = Input(description="Overlap of latent tiles", default=4, ge=4, le=16),
        scale_factor: int = Input(description="Scale factor", default=4),
        seed: int = Input(description="Seed", default=231, ge=0, le=2147483647),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        pil_image = Image.open(image).convert("RGB")

        imgs = self.process(
            pil_image, user_prompt, positive_prompt, negative_prompt, num_inference_steps,
            scale_factor, cfg_scale, seed, latent_tiled_size, latent_tiled_overlap, sample_times)

        # Clear output folder
        os.system("rm -rf /tmp/output")
        # Create output folder
        os.system("mkdir /tmp/output")

        # Save images to output folder
        output_paths = []
        for i, img in enumerate(imgs):
            img = Image.fromarray(img)
            output_path = f"/tmp/output/{i}.png"
            img.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths
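
# Example invocation (a sketch; assumes the Cog CLI is installed and a cog.yaml points
# at this Predictor):
#
#   cog predict -i image=@low_res.png -i scale_factor=4 -i num_inference_steps=50
#
# Outputs are written as PNGs under /tmp/output/ and returned as paths.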