import random

import cv2
import numpy as np
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFilter
from torch.utils.data import Dataset


class Subject200KDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.to_tensor = T.ToTensor()

    def __len__(self):
        # Each base item is a side-by-side image pair, so it yields two samples.
        return len(self.base_dataset) * 2

    def __getitem__(self, idx):
        # If target is 0, the left image is the target and the right image is
        # the condition; odd indices swap the roles.
        target = idx % 2
        item = self.base_dataset[idx // 2]

        # Crop the paired image into its left and right halves
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # Pick the target and condition images
        target_image, condition_img = (
            (left_img, right_img) if target == 0 else (right_img, left_img)
        )

        # Resize both images
        condition_img = condition_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = target_image.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Get the matching description
        description = item["description"][
            "description_0" if target == 0 else "description_1"
        ]

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image, so a 512-px condition
            # shifts the position grid by 512 // 16 = 32.
            "position_delta": np.array([0, -self.condition_size // 16]),
            **({"pil_image": image} if self.return_pil_image else {}),
        }
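

# A minimal usage sketch, not from the original file. The Hugging Face dataset
# id "Yuanshi/Subjects200K" is an assumption for illustration; the only real
# requirement is that base items expose the "image" and "description" fields
# indexed by __getitem__ above.
def _demo_subject200k_dataset():
    from datasets import load_dataset  # assumed dependency for this demo

    base = load_dataset("Yuanshi/Subjects200K", split="train")
    dataset = Subject200KDataset(base, condition_size=512, target_size=512)
    sample = dataset[0]  # even index: left crop is the target
    print(sample["image"].shape)      # torch.Size([3, 512, 512])
    print(sample["condition"].shape)  # torch.Size([3, 512, 512])
    print(sample["position_delta"])   # [  0 -32], since 512 // 16 == 32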


class ImageConditionDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        # Lazily construct the depth-estimation pipeline on first access. The
        # @property decorator is required because __getitem__ invokes
        # self.depth_pipe(image) as if it were the pipeline object itself.
        if not hasattr(self, "_depth_pipe"):
            from transformers import pipeline

            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )
        return self._depth_pipe

    def _get_canny_edge(self, img):
        # Scale the longer side down to condition_size before edge detection
        resize_ratio = self.condition_size / max(img.size)
        img = img.resize(
            (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
        )
        img_np = np.array(img)
        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(img_gray, 100, 200)
        return Image.fromarray(edges).convert("RGB")
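
    # A worked example of the resize above (illustrative numbers, not from the
    # original file): a 1024x768 input with condition_size=512 gives
    # resize_ratio = 512 / 1024 = 0.5, so Canny runs on a 512x384 grayscale
    # image; 100/200 are OpenCV's usual low/high hysteresis thresholds.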

    def __getitem__(self, idx):
        image = self.base_dataset[idx]["jpg"]
        image = image.resize((self.target_size, self.target_size)).convert("RGB")
        description = self.base_dataset[idx]["json"]["prompt"]

        # Note: random.random() < 1 is always True, so the downscaled branch
        # below is effectively dead code and the configured position_scale is
        # always used.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Get the condition image
        position_delta = np.array([0, 0])
        if self.condition_type == "canny":
            condition_img = self._get_canny_edge(image)
        elif self.condition_type == "coloring":
            condition_img = (
                image.resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
        elif self.condition_type == "deblurring":
            blur_radius = random.randint(1, 10)
            condition_img = (
                image.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
        elif self.condition_type == "depth":
            condition_img = self.depth_pipe(image)["depth"].convert("RGB")
            condition_img = condition_img.resize((condition_size, condition_size))
        elif self.condition_type == "depth_pred":
            # Inverted task: the RGB image is the condition and the predicted
            # depth map becomes the generation target.
            condition_img = image
            image = self.depth_pipe(condition_img)["depth"].convert("RGB")
            description = f"[depth] {description}"
        elif self.condition_type == "fill":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            w, h = image.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle([x1, y1, x2, y2], fill=255)
            if random.random() > 0.5:
                mask = Image.eval(mask, lambda a: 255 - a)
            # Note: this composite overwrites the resized image above, so the
            # condition keeps the full target resolution, not condition_size.
            condition_img = Image.composite(
                image, Image.new("RGB", image.size, (0, 0, 0)), mask
            )
        elif self.condition_type == "sr":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            position_delta = np.array([0, -condition_size // 16])
        else:
            raise ValueError(f"Condition type {self.condition_type} not implemented")

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (condition_size, condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
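

# A minimal usage sketch, not from the original file: the base dataset is
# assumed to yield webdataset-style items with a "jpg" PIL image and a "json"
# dict carrying a "prompt", which is exactly what __getitem__ above reads.
def _demo_image_condition_dataset():
    class _FakeBase:
        # Tiny in-memory stand-in so the sketch runs without any downloads.
        def __len__(self):
            return 1

        def __getitem__(self, idx):
            return {
                "jpg": Image.new("RGB", (768, 768), (128, 64, 32)),
                "json": {"prompt": "a toy prompt"},
            }

    dataset = ImageConditionDataset(_FakeBase(), condition_type="canny")
    sample = dataset[0]
    print(sample["image"].shape)      # torch.Size([3, 512, 512])
    print(sample["condition"].shape)  # torch.Size([3, 512, 512])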


class CartoonDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 1024,
        target_size: int = 1024,
        image_size: int = 1024,
        padding: int = 0,
        condition_type: str = "cartoon",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data = self.base_dataset[idx]
        condition_img = data["condition"]
        target_image = data["target"]

        # Map the item's first tag to a descriptive phrase
        tag = data["tags"][0]
        target_description = data["target_description"]
        tag_phrases = {
            "lion": "lion like animal",
            "bear": "bear like animal",
            "gorilla": "gorilla like animal",
            "dog": "dog like animal",
            "elephant": "elephant like animal",
            "eagle": "eagle like bird",
            "tiger": "tiger like animal",
            "owl": "owl like bird",
            "woman": "woman",
            "parrot": "parrot like bird",
            "mouse": "mouse like animal",
            "man": "man",
            "pigeon": "pigeon like bird",
            "girl": "girl",
            "panda": "panda like animal",
            "crocodile": "crocodile like animal",
            "rabbit": "rabbit like animal",
            "boy": "boy",
            "monkey": "monkey like animal",
            "cat": "cat like animal",
        }

        # Resize the images
        condition_img = condition_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = target_image.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Use the item's own description if present; otherwise build one
        description = data.get(
            "description",
            f"Photo of a {tag_phrases[tag]} cartoon character in a white "
            f"background. Character is facing "
            f"{target_description['facing_direction']}. "
            f"Character pose is {target_description['pose']}.",
        )

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image
            "position_delta": np.array([0, -16]),
        }
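

# A minimal usage sketch, not from the original file: base items are assumed
# to provide "condition"/"target" PIL images, a "tags" list whose first entry
# is a key in the phrase table above, and a "target_description" dict with
# "facing_direction" and "pose" fields.
def _demo_cartoon_dataset():
    fake_item = {
        "condition": Image.new("RGB", (1024, 1024), (255, 255, 255)),
        "target": Image.new("RGB", (1024, 1024), (200, 200, 200)),
        "tags": ["cat"],
        "target_description": {"facing_direction": "left", "pose": "sitting"},
    }
    dataset = CartoonDataset([fake_item])
    sample = dataset[0]
    print(sample["description"])  # built from the phrase table unless dropped
    print(sample["image"].shape, sample["condition"].shape)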