Spaces:
Running
on
Zero
Running
on
Zero
| import os.path | |
| import random | |
| import re | |
| import unicodedata | |
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from typing import List, Union | |
def clean_filename(s):
    """Turn an arbitrary string into a lowercase, ASCII-only file name stem.

    Steps, in order: strip surrounding whitespace and then surrounding dots,
    drop every character that has no ASCII form after NFKD decomposition,
    replace ``/`` with ``_``, collapse runs of underscores into one,
    lowercase, and truncate to 200 characters.

    Returns:
        The sanitized name, or ``'untitled'`` when nothing is left.
    """
    # No reserved names are configured, so the suffix step below never fires.
    reserved = frozenset()
    length_cap = 200

    trimmed = s.strip().strip('.')
    # NFKD splits accented letters into base letter + combining mark; the
    # ASCII round-trip then discards the marks and any other non-ASCII char.
    decomposed = unicodedata.normalize('NFKD', trimmed)
    ascii_only = decomposed.encode('ASCII', 'ignore').decode('ASCII')

    underscored = re.sub(r'[/]', '_', ascii_only)
    name = re.sub(r'_{2,}', '_', underscored).lower()
    if name.upper() in reserved:
        name += '_'

    name = name[:length_cap]
    return name if name else 'untitled'
def save_fn(image, metadata, root_path):
    """Save ``image`` as ``<root_path>/<metadata['filename']>.png``.

    ``image`` is converted with ``Image.fromarray`` before being written;
    ``metadata['filename']`` is stringified and given a ``.png`` suffix.
    """
    png_name = str(metadata['filename']) + ".png"
    destination = os.path.join(root_path, png_name)
    picture = Image.fromarray(image)
    picture.save(destination)
class RandomNDataset(Dataset):
    """Dataset of Gaussian noise latents, each paired with a condition.

    The dataset is laid out as ``num_conditons * num_seeds`` items: item
    ``idx`` uses condition ``idx // num_seeds`` and, when explicit seeds were
    supplied, seed ``seeds[idx % num_seeds]``. Without explicit seeds a fresh
    random seed is drawn on every access, so the same index yields a
    different latent each time it is read.
    """

    def __init__(self, latent_shape=(4, 64, 64), conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        """Resolve the condition list and compute the dataset size.

        Args:
            latent_shape: Shape passed to ``torch.randn`` for every latent.
            conditions: An ``int`` n (expanded to ``list(range(n))``), a
                ``str`` path to a text file with one condition per line, or a
                list used as-is.
            seeds: Optional sequence of integer seeds. When given, the
                dataset has ``len(seeds) * len(conditions)`` items and the two
                size arguments below are ignored.
            max_num_instances: Requested number of items when ``seeds`` is
                ``None``; rounded up to a multiple of the number of
                conditions.
            num_samples_per_instance: When positive, replaces
                ``max_num_instances`` with this many items per condition.

        Raises:
            TypeError: If ``conditions`` is ``None``.
            FileNotFoundError: If ``conditions`` is a path that does not exist.
        """
        if conditions is None:
            # Previously this surfaced as an opaque ``len(None)`` TypeError.
            raise TypeError(
                "conditions must be an int, a list, or a path to a text file, not None"
            )
        if isinstance(conditions, int):
            conditions = list(range(conditions))  # class labels
        elif isinstance(conditions, str):
            if not os.path.exists(conditions):
                raise FileNotFoundError(conditions)
            # Context manager so the file handle is closed deterministically.
            with open(conditions, "r") as f:
                conditions = f.read().splitlines()

        self.conditions = conditions
        # NOTE: attribute name keeps its historical spelling; other code may read it.
        self.num_conditons = len(conditions)
        self.seeds = seeds
        self.latent_shape = latent_shape

        if num_samples_per_instance > 0:
            max_num_instances = num_samples_per_instance * self.num_conditons

        if seeds is not None:
            self.num_seeds = len(seeds)
            self.max_num_instances = self.num_seeds * self.num_conditons
        else:
            # Ceiling division: every condition gets the same number of seeds.
            self.num_seeds = (max_num_instances + self.num_conditons - 1) // self.num_conditons
            self.max_num_instances = self.num_seeds * self.num_conditons

    def __getitem__(self, idx):
        """Return ``(latent, condition, metadata)`` for item ``idx``.

        ``latent`` is a float32 tensor of shape ``latent_shape`` sampled from
        a generator seeded with the item's seed. ``metadata`` is a dict with
        the keys ``filename``, ``seed``, ``condition`` and ``save_fn``.
        """
        condition = self.conditions[idx // self.num_seeds]
        # Drawn on every access; replaced below when explicit seeds exist.
        seed = random.randint(0, 1 << 31)
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]

        filename = f"{clean_filename(str(condition))}_{seed}"
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)

        metadata = dict(
            filename=filename,
            seed=seed,
            condition=condition,
            save_fn=save_fn,
        )
        return latent, condition, metadata

    def __len__(self):
        """Return the total number of items (conditions times seeds)."""
        return self.max_num_instances
class ClassLabelRandomNDataset(RandomNDataset):
    """RandomNDataset whose conditions default to the labels ``0 .. num_classes - 1``."""

    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        """Forward everything to the base class, filling in default conditions.

        ``num_classes`` is only consulted when ``conditions`` is ``None``;
        an explicit ``conditions`` value is passed through unchanged.
        """
        label_source = list(range(num_classes)) if conditions is None else conditions
        super().__init__(
            latent_shape=latent_shape,
            conditions=label_source,
            seeds=seeds,
            max_num_instances=max_num_instances,
            num_samples_per_instance=num_samples_per_instance,
        )