Spaces:
Build error
Build error
| import os | |
| import pickle | |
| import random | |
| from pathlib import Path | |
| from typing import Dict | |
| from typing import List | |
| import torch | |
| from loguru import logger | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from configs.mode import FaceSwapMode | |
| from configs.train_config import TrainConfig | |
class ManyToManyTrainDataset(Dataset):
    """Many-to-many face-swap training dataset.

    One epoch visits every identity once. For each identity, a random source
    image is drawn; the target image is drawn from the same identity with
    probability ``same_rate``, otherwise from a different identity.
    """

    def __init__(self, dataset_root: str, dataset_index: str, same_rate: float = 0.5):
        """
        Parameters
        ----------
        dataset_root: str, root directory of the dataset
        dataset_index: str, path to the pickled dataset index file
        same_rate: float, fraction of same-identity (source == target id) pairs per batch
        """
        super().__init__()
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop((256, 256)),
                transforms.ToTensor(),
            ]
        )
        self.data_root = Path(dataset_root)
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # index files produced by this project, never untrusted input.
        with open(dataset_index, "rb") as f:
            self.file_index: Dict[str, List[str]] = pickle.load(f, encoding="bytes")
        self.same_rate = same_rate
        self.id_list: List[str] = list(self.file_index.keys())
        # Visiting every id once counts as one epoch.
        self.length = len(self.id_list)
        self.image_num = sum(len(v) for v in self.file_index.values())
        self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
        logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
        logger.info(f"will use mask mode: {self.mask_dir}")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        source_id_index = index
        source_file = random.choice(self.file_index[self.id_list[source_id_index]])
        if random.random() < self.same_rate:
            # Choose the target from the same identity's file list
            # (may coincide with source_file).
            target_file = random.choice(self.file_index[self.id_list[source_id_index]])
            same = torch.ones(1)
        else:
            # Choose a different identity in O(1): draw from length-1 slots
            # and shift past the source index, instead of materialising a
            # set difference over all ids on every sample.
            target_id_index = random.randrange(self.length - 1)
            if target_id_index >= source_id_index:
                target_id_index += 1
            target_file = random.choice(self.file_index[self.id_list[target_id_index]])
            same = torch.zeros(1)
        source_file = self.data_root / Path(source_file)
        target_file = self.data_root / Path(target_file)
        # Masks live in a parallel tree: <root>/<mask_dir>/<id dir>/<image name>.
        target_mask_file = (
            target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
        )
        target_img = Image.open(target_file.as_posix()).convert("RGB")
        source_img = Image.open(source_file.as_posix()).convert("RGB")
        target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
        source_img = self.transform(source_img)
        target_img = self.transform(target_img)
        # Keep a single mask channel, shape (1, H, W).
        target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
        return {
            "source_image": source_img,
            "target_image": target_img,
            "target_mask": target_mask,
            "same": same,
        }
class OneToManyTrainDataset(Dataset):
    """One-to-many face-swap training dataset.

    One epoch visits every identity once as the target. The source image
    comes from the target's own identity with probability ``same_rate``,
    otherwise from the single fixed ``source_name`` identity (the "one").
    """

    def __init__(self, dataset_root: str, dataset_index: str, source_name: str, same_rate: float = 0.5):
        """
        Parameters
        ----------
        dataset_root: str, root directory of the dataset
        dataset_index: str, path to the pickled dataset index file
        source_name: str, name of the source face id -- the "one" in one-to-many
        same_rate: float, fraction of same-identity pairs per batch

        Raises
        ------
        ValueError: if ``source_name`` is not present in the dataset index.
        """
        super().__init__()
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop((256, 256)),
                transforms.ToTensor(),
            ]
        )
        self.data_root = Path(dataset_root)
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # index files produced by this project, never untrusted input.
        with open(dataset_index, "rb") as f:
            self.file_index: Dict[str, List[str]] = pickle.load(f, encoding="bytes")
        self.same_rate = same_rate
        self.source_name = source_name
        self.id_list: List[str] = list(self.file_index.keys())
        try:
            self.source_id_index: int = self.id_list.index(self.source_name)
        except ValueError:
            # list.index raises ValueError when absent; re-raise with a
            # message naming the missing id (ValueError subclasses Exception,
            # so existing broad handlers still work).
            raise ValueError(f"{self.source_name} not in dataset dir") from None
        # Visiting every id once counts as one epoch.
        self.length = len(self.id_list)
        self.image_num = sum(len(v) for v in self.file_index.values())
        self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
        logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
        logger.info(f"will use mask mode: {self.mask_dir}")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        target_id_index = index
        target_file = random.choice(self.file_index[self.id_list[target_id_index]])
        if random.random() < self.same_rate:
            # Choose the source from the target's own identity.
            source_file = random.choice(self.file_index[self.id_list[target_id_index]])
            same = torch.ones(1)
        else:
            # Choose directly from the fixed source identity.
            source_file = random.choice(self.file_index[self.source_name])
            # The fixed source may itself be the current target identity.
            if self.source_id_index == target_id_index:
                same = torch.ones(1)
            else:
                same = torch.zeros(1)
        source_file = self.data_root / Path(source_file)
        target_file = self.data_root / Path(target_file)
        # Masks live in a parallel tree: <root>/<mask_dir>/<id dir>/<image name>.
        target_mask_file = (
            target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
        )
        target_img = Image.open(target_file.as_posix()).convert("RGB")
        source_img = Image.open(source_file.as_posix()).convert("RGB")
        target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
        source_img = self.transform(source_img)
        target_img = self.transform(target_img)
        # Keep a single mask channel, shape (1, H, W).
        target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
        return {
            "source_image": source_img,
            "target_image": target_img,
            "target_mask": target_mask,
            "same": same,
        }
class TrainDatasetDataLoader:
    """Wrapper that builds the configured dataset and a multi-worker DataLoader."""

    def __init__(self):
        """Select the dataset from TrainConfig and wrap it in a torch DataLoader."""
        opt = TrainConfig()
        if opt.mode is FaceSwapMode.MANY_TO_MANY:
            self.dataset = ManyToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.same_rate)
        elif opt.mode is FaceSwapMode.ONE_TO_MANY:
            logger.info(f"In one-to-many mode, source face is {opt.source_name}")
            self.dataset = OneToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.source_name, opt.same_rate)
        else:
            raise NotImplementedError
        logger.info(f"dataset {type(self.dataset).__name__} created")
        # Options shared by both the DDP and the single-process loader.
        common = dict(
            batch_size=opt.batch_size,
            num_workers=int(opt.num_threads),
            drop_last=True,
            pin_memory=True,
        )
        if opt.use_ddp:
            # Under DDP each rank gets a disjoint shard; shuffling is the
            # sampler's job, so `shuffle` is not passed to the DataLoader.
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, shuffle=True)
            self.dataloader = torch.utils.data.DataLoader(self.dataset, sampler=self.train_sampler, **common)
        else:
            self.dataloader = torch.utils.data.DataLoader(self.dataset, shuffle=True, **common)

    def load_data(self):
        """Fluent accessor: the wrapper itself is iterable."""
        return self

    def __len__(self):
        """Number of items in the underlying dataset."""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches from the wrapped DataLoader."""
        yield from self.dataloader
if __name__ == "__main__":
    # Smoke test: build the configured loader and print the same-identity
    # flag of each batch.
    dataloader = TrainDatasetDataLoader()
    for data in dataloader:
        print(data["same"])