import os
import json
import random

from tqdm import tqdm
import numpy as np
from PIL import Image, ImageStat

import torch
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms as T


### >>>>>>>> >>>>>>>> text related >>>>>>>> >>>>>>>> ###
class TokenizerWrapper:
    """Wraps a tokenizer so captions can be randomly dropped (for classifier-free
    guidance training) or replaced by a generic quality prompt before tokenization."""

    def __init__(self, tokenizer, is_train, proportion_empty_prompts, use_generic_prompts=False):
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.proportion_empty_prompts = proportion_empty_prompts
        self.use_generic_prompts = use_generic_prompts

    def __call__(self, prompts):
        if isinstance(prompts, str):
            prompts = [prompts]
        captions = []
        for caption in prompts:
            if random.random() < self.proportion_empty_prompts:
                # drop the caption so the model also learns the unconditional branch
                captions.append("")
            elif self.use_generic_prompts:
                captions.append("best quality, high quality")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if self.is_train else caption[0])
            else:
                raise ValueError("Caption column should contain either strings or lists of strings.")
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length",
            truncation=True, return_tensors="pt",
        )
        return inputs.input_ids
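
# Usage sketch (assumes a CLIP-style tokenizer from `transformers`; the checkpoint
# name below is illustrative only, not the one used by this repo):
#   from transformers import CLIPTokenizer
#   tok = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#   wrap = TokenizerWrapper(tok, is_train=True, proportion_empty_prompts=0.1)
#   ids = wrap("a photo of a cat")  # LongTensor of shape (1, tok.model_max_length)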


### >>>>>>>> >>>>>>>> image related >>>>>>>> >>>>>>>> ###

MONOCHROMATIC_MAX_VARIANCE = 0.3

def is_monochromatic_image(pil_img):
    # a near-zero total per-channel variance means the image is (almost) one flat color
    v = ImageStat.Stat(pil_img.convert('RGB')).var
    return sum(v) < MONOCHROMATIC_MAX_VARIANCE

def isnumeric(text):
    return (''.join(filter(str.isalnum, text))).isnumeric()
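
# e.g. a solid-color image has zero variance in every channel:
#   is_monochromatic_image(Image.new('RGB', (64, 64), 'gray'))  # -> True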


class TextPromptDataset(IterableDataset):
    '''
    Streams (image, caption) pairs from a sharded layout:
    `<data_root>/JPEGImages/<shard>/<key>.jpg` for images, and
    `<data_root>/list/shard*.json` files holding [key, caption] pairs.
    '''
    def __init__(self,
        data_root,
        tokenizer=None,
        transform=None,
        rank=0,
        world_size=1,
        shuffle=True,
    ):
        self.tokenizer = tokenizer
        self.transform = transform
        self.img_root = os.path.join(data_root, 'JPEGImages')

        self.data_list = []
        print("#### Loading filename list...")
        json_root = os.path.join(data_root, 'list')
        json_list = [p for p in os.listdir(json_root) if p.startswith("shard") and p.endswith('.json')]

        # duplicate a few shards so every process ends up with the same number of shards
        assert len(json_list) > world_size
        duplicate = world_size - len(json_list) % world_size if len(json_list) % world_size > 0 else 0
        json_list = json_list + json_list[:duplicate]
        json_list = json_list[rank::world_size]

        for json_file in tqdm(json_list):
            shard_name = os.path.basename(json_file).split('.')[0]
            with open(os.path.join(json_root, json_file)) as f:
                key_text_pairs = json.load(f)
            for pair in key_text_pairs:
                self.data_list.append([shard_name] + pair)
        print("#### All filenames loaded.")

        self.shuffle = shuffle

    def __len__(self):
        return len(self.data_list)
    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # single-process data loading: iterate the full per-rank list
            data_list = self.data_list
        else:
            # split the per-rank list across DataLoader workers, truncating the
            # remainder so every worker yields the same number of samples
            len_data = len(self.data_list) - len(self.data_list) % worker_info.num_workers
            data_list = self.data_list[:len_data][worker_info.id::worker_info.num_workers]

        if self.shuffle:
            random.shuffle(data_list)

        while True:  # stream forever; the training loop decides when to stop
            for idx in range(len(data_list)):
                shard_name = data_list[idx][0]
                data = {}

                img_file = data_list[idx][1]
                img = Image.open(os.path.join(self.img_root, shard_name, img_file + '.jpg')).convert("RGB")
                if is_monochromatic_image(img):
                    continue
                if self.transform is not None:
                    img = self.transform(img)
                data['pixel_values'] = img

                text = data_list[idx][2]
                if self.tokenizer is not None:
                    if isinstance(self.tokenizer, list):
                        # two tokenizers, e.g. for pipelines with dual text encoders
                        assert len(self.tokenizer) == 2
                        data['input_ids'] = self.tokenizer[0](text)[0]
                        data['input_ids_2'] = self.tokenizer[1](text)[0]
                    else:
                        data['input_ids'] = self.tokenizer(text)[0]
                else:
                    # no tokenizer: hand the raw caption string downstream
                    data['input_ids'] = text

                yield data
    def collate_fn(self, examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        if self.tokenizer is not None:
            if isinstance(self.tokenizer, list):
                assert len(self.tokenizer) == 2
                input_ids = torch.stack([example["input_ids"] for example in examples])
                input_ids_2 = torch.stack([example["input_ids_2"] for example in examples])
                return {"pixel_values": pixel_values, "input_ids": input_ids, "input_ids_2": input_ids_2}
            input_ids = torch.stack([example["input_ids"] for example in examples])
            return {"pixel_values": pixel_values, "input_ids": input_ids}
        # no tokenizer: captions stay as a plain list of strings
        input_ids = [example["input_ids"] for example in examples]
        return {"pixel_values": pixel_values, "input_ids": input_ids}


def make_train_dataset(
    train_data_path,
    size=512,
    tokenizer=None,
    cfg_drop_ratio=0,
    rank=0,
    world_size=1,
    shuffle=True,
):
    _image_transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(size),
        T.CenterCrop((size, size)),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map [0, 1] to [-1, 1]
    ])

    if tokenizer is not None:
        if isinstance(tokenizer, list):
            assert len(tokenizer) == 2
            tokenizer = [
                TokenizerWrapper(
                    t,
                    is_train=True,
                    proportion_empty_prompts=cfg_drop_ratio,
                    use_generic_prompts=False,
                )
                for t in tokenizer
            ]
        else:
            tokenizer = TokenizerWrapper(
                tokenizer,
                is_train=True,
                proportion_empty_prompts=cfg_drop_ratio,
                use_generic_prompts=False,
            )

    train_dataset = TextPromptDataset(
        data_root=train_data_path,
        transform=_image_transform,
        rank=rank,
        world_size=world_size,
        tokenizer=tokenizer,
        shuffle=shuffle,
    )
    return train_dataset
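
# Dual-encoder sketch (assumption: the list form of `tokenizer` targets pipelines
# with two text encoders; `tokenizer_one` / `tokenizer_two` are hypothetical names):
#   dataset = make_train_dataset('/path/to/data', tokenizer=[tokenizer_one, tokenizer_two])
#   # each sample then carries both 'input_ids' and 'input_ids_2'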


### >>>>>>>> >>>>>>>> Test >>>>>>>> >>>>>>>> ###

if __name__ == "__main__":
    from einops import rearrange
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained(
        "/mnt/bn/ic-research-aigc-editing/fast-diffusion-models/assets/public_models/StableDiffusion/stable-diffusion-v1-5",
        subfolder="tokenizer",
    )
    train_dataset = make_train_dataset(
        "/path/to/dataset",  # placeholder: expects JPEGImages/ and list/shard*.json underneath
        tokenizer=tokenizer, rank=0, world_size=10,
    )
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, num_workers=0,
        collate_fn=train_dataset.collate_fn,
    )
    for batch in loader:
        pixel_values = batch["pixel_values"]
        prompt_ids = batch["input_ids"]
        pixel_values = rearrange(pixel_values, 'b c h w -> b h w c')
        for i in range(pixel_values.shape[0]):
            # undo the [-1, 1] normalization before saving for visual inspection
            Image.fromarray(((pixel_values[i] + 1) / 2 * 255).numpy().astype(np.uint8)).save('tmp.png')
            input_id = prompt_ids[i]
            text = tokenizer.decode(input_id).split('<|startoftext|>')[-1].split('<|endoftext|>')[0]
            print(text)