import argparse
import copy
import gc
import itertools
import logging
import math
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from os.path import join
from pathlib import Path

import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast

# from transformer_sd3 import SD3Transformer2DModel
import accelerate
import datasets
import diffusers
from datasets import load_dataset
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
    SD3Transformer2DModel,
    StableDiffusion3InstructPix2PixPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import (
    check_min_version,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from packaging import version
def load_text_encoders(class_one, class_two, class_three):
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    text_encoder_three = class_three.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
    )
    return text_encoder_one, text_encoder_two, text_encoder_three


def tokenize_prompt(tokenizer, prompt, max_sequence_length=77):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    text_encoder_dtype,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if text_input_ids is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    text_encoder_dtype,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if text_input_ids is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length=None,
    text_encoders_dtypes=[torch.float32, torch.float32, torch.float32],
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]
    clip_text_encoders_dtypes = text_encoders_dtypes[:2]

    if text_input_ids_list is not None:
        clip_text_input_ids_list = text_input_ids_list[:2]
    else:
        clip_text_input_ids_list = [None, None]

    zipped_text_encoders = zip(clip_tokenizers, clip_text_encoders, clip_text_encoders_dtypes, clip_text_input_ids_list)
    for tokenizer, text_encoder, clip_text_encoder_dtype, text_input_ids in zipped_text_encoders:
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            text_encoder_dtype=clip_text_encoder_dtype,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    if text_input_ids_list is not None:
        t5_text_input_ids = text_input_ids_list[-1]
    else:
        t5_text_input_ids = None

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        text_encoders_dtypes[-1],
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
        text_input_ids=t5_text_input_ids,
    )

    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
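
# Shape sketch (assuming the stock SD3-medium text encoders and max_sequence_length=77;
# these dimensions are model-dependent, so treat this as an illustration, not a guarantee):
# the two CLIP encoders yield (bsz, 77, 768) and (bsz, 77, 1280), concatenated to
# (bsz, 77, 2048) and zero-padded to the T5 width of 4096; the T5 encoder yields
# (bsz, 77, 4096). Concatenating along the sequence axis gives prompt_embeds of shape
# (bsz, 154, 4096), while pooled_prompt_embeds is (bsz, 2048).
#
# prompt_embeds, pooled_prompt_embeds = encode_prompt(
#     [text_encoder_one, text_encoder_two, text_encoder_three],
#     [tokenizer_one, tokenizer_two, tokenizer_three],
#     "make the sky stormy",
#     max_sequence_length=77,
# )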
logger = get_logger(__name__, log_level="INFO")

DATASET_NAME_MAPPING = {
    "BleachNick/UltraEdit_500k": ("source_image", "edited_image", "edit_prompt"),
}
WANDB_TABLE_COL_NAMES = ["source_image", "edited_image", "edit_prompt"]
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--ori_model_name_or_path",
        type=str,
        default=None,
        help="Path to the original model.",
    )
    parser.add_argument(
        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="Mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="Std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help='The optimizer type to use. Choose between ["AdamW", "prodigy"].',
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW.",
    )
    parser.add_argument(
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for the text encoders."
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="Coefficient for computing the Prodigy stepsize using running averages. If set to None, "
        "uses the value of the square root of beta2. Ignored if optimizer is AdamW.",
    )
    parser.add_argument(
        "--prodigy_use_bias_correction",
        type=bool,
        default=True,
        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is AdamW.",
    )
    parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay.")
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        type=bool,
        default=True,
        help="Remove lr from the denominator of the D estimate to avoid issues during the warm-up stage. True by default. "
        "Ignored if optimizer is AdamW.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_jsonl",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--original_image_column",
        type=str,
        default="source_image",
        help="The column of the dataset containing the original image on which edits were made.",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--edited_image_column",
        type=str,
        default="edited_image",
        help="The column of the dataset containing the edited image.",
    )
    parser.add_argument(
        "--edit_prompt_column",
        type=str,
        default="edit_prompt",
        help="The column of the dataset containing the edit instruction.",
    )
    parser.add_argument(
        "--val_image_url",
        type=str,
        default=None,
        help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--val_mask_url",
        type=str,
        default=None,
        help="URL to the mask image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--validation_step",
        type=int,
        default=5000,
        help=(
            "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--top_training_data_sample",
        type=int,
        default=None,
        help="Number of top samples to use for training, ranked by clip-sim-dit. If None, use the full dataset.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd3_edit",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--eval_resolution",
        type=int,
        default=512,
        help="The resolution used for the validation images generated during training.",
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=77,
        help="Maximum sequence length to use with the T5 text encoder.",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="Whether to randomly flip images horizontally.",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"].'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--conditioning_dropout_prob",
        type=float,
        default=None,
        help=(
            "Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training"
            " InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--text_encoder_lr",
        type=float,
        default=5e-6,
        help="Text encoder learning rate to use.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in the cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to also train the three text encoders.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--do_mask",
        action="store_true",
        help="Whether to condition on a region mask (the mask latents are appended as extra input channels).",
    )
    parser.add_argument(
        "--mask_column",
        type=str,
        default="mask_image",
        help="The column of the dataset containing the original image's mask.",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_jsonl is None:
        raise ValueError("Need either a dataset name or a training folder.")

    # default to using the same revision for the non-ema model if not specified
    return args
def combine_rgb_and_mask_to_rgba(rgb_image, mask_image):
    # Ensure the input images are the same size
    if rgb_image.size != mask_image.size:
        raise ValueError("The RGB image and the mask image must have the same dimensions")
    # Convert the mask image to 'L' mode (grayscale) if it is not
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')
    # Split the RGB image into its three channels
    r, g, b = rgb_image.split()
    # Combine the RGB channels with the mask to form an RGBA image
    rgba_image = Image.merge("RGBA", (r, g, b, mask_image))
    return rgba_image
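
# Usage sketch (hypothetical file names): pack the edit mask into the alpha channel so a
# single RGBA image carries both the source pixels and the editable region.
# rgb = Image.open("source.png").convert("RGB")
# mask = Image.open("mask.png")
# rgba = combine_rgb_and_mask_to_rgba(rgb, mask)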
def convert_to_np(image, resolution):
    try:
        if isinstance(image, str):
            if image == "NONE":
                image = PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))
            else:
                image = PIL.Image.open(image)
        image = image.convert("RGB").resize((resolution, resolution))
        return np.array(image).transpose(2, 0, 1)
    except Exception as e:
        print("Load error", image)
        print(e)
        # Fall back to a new blank image
        image = PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))
        return np.array(image).transpose(2, 0, 1)


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")
def main():
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    def download_image(path_or_url, resolution=512):
        if path_or_url is None:
            # Return a blank white RGB image
            return PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))
        # Check if path_or_url is a local file path
        if os.path.exists(path_or_url):
            image = Image.open(path_or_url).convert("RGB").resize((resolution, resolution))
        else:
            image = Image.open(requests.get(path_or_url, stream=True).raw).convert("RGB")
        image = PIL.ImageOps.exif_transpose(image)
        return image

    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id
    # Load scheduler, tokenizer and models.
    # Load the tokenizers
    tokenizer_one = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
    )
    tokenizer_two = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=args.revision,
    )
    tokenizer_three = T5TokenizerFast.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_3",
        revision=args.revision,
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )
    text_encoder_cls_three = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
    )

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
    )

    # TODO
    logger.info("Initializing the new channel of DIT from the pretrained DIT.")
    in_channels = int(1.5 * transformer.config.in_channels) if args.do_mask else 2 * transformer.config.in_channels  # 48 for mask
    out_channels = transformer.pos_embed.proj.out_channels
    load_num_channel = transformer.config.in_channels
    print("Do mask", args.do_mask)
    print("new in_channels", in_channels)
    print("load_num_channel", load_num_channel)
    transformer.register_to_config(in_channels=in_channels)
    print("transformer.pos_embed.proj.weight.shape", transformer.pos_embed.proj.weight.shape)
    print("load_num_channel", load_num_channel)

    with torch.no_grad():
        new_proj = nn.Conv2d(
            in_channels, out_channels, kernel_size=(transformer.config.patch_size, transformer.config.patch_size),
            stride=transformer.config.patch_size, bias=True
        )
        print("new_proj", new_proj)
        new_proj.weight.zero_()
        # init.kaiming_normal_(new_proj.weight, mode='fan_out', nonlinearity='relu')
        # if new_proj.bias is not None and transformer.pos_embed.proj.bias is not None:
        #     new_proj.bias.copy_(transformer.pos_embed.proj.bias)
        # else:
        #     if new_proj.bias is not None:
        #         new_proj.bias.zero_()
        new_proj = new_proj.to(transformer.pos_embed.proj.weight.dtype)
        new_proj.weight[:, :load_num_channel, :, :].copy_(transformer.pos_embed.proj.weight)
        new_proj.bias.copy_(transformer.pos_embed.proj.bias)
        print("new_proj", new_proj.weight.shape)
        print("transformer.pos_embed.proj", transformer.pos_embed.proj.weight.shape)
        transformer.pos_embed.proj = new_proj

    for param in transformer.parameters():
        param.requires_grad = True
    transformer.requires_grad_(True)
    vae.requires_grad_(False)
    if args.train_text_encoder:
        text_encoder_one.requires_grad_(True)
        text_encoder_two.requires_grad_(True)
        text_encoder_three.requires_grad_(True)
    else:
        text_encoder_one.requires_grad_(False)
        text_encoder_two.requires_grad_(False)
        text_encoder_three.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    vae.to(accelerator.device, dtype=torch.float32)
    if not args.train_text_encoder:
        text_encoder_one.to(accelerator.device, dtype=weight_dtype)
        text_encoder_two.to(accelerator.device, dtype=weight_dtype)
        text_encoder_three.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
            text_encoder_two.gradient_checkpointing_enable()
            text_encoder_three.gradient_checkpointing_enable()

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for i, model in enumerate(models):
                if isinstance(unwrap_model(model), SD3Transformer2DModel):
                    unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer"))
                elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
                    if isinstance(unwrap_model(model), CLIPTextModelWithProjection):
                        hidden_size = unwrap_model(model).config.hidden_size
                        if hidden_size == 768:
                            unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder"))
                        elif hidden_size == 1280:
                            unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2"))
                    else:
                        unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3"))
                else:
                    raise ValueError(f"Wrong model supplied: {type(model)=}.")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()
    def load_model_hook(models, input_dir):
        for _ in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            if isinstance(unwrap_model(model), SD3Transformer2DModel):
                load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer")
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
            elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
                try:
                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
                    model.load_state_dict(load_model.state_dict())
                except Exception:
                    try:
                        load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2")
                        model.load_state_dict(load_model.state_dict())
                    except Exception:
                        try:
                            load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3")
                            model.load_state_dict(load_model.state_dict())
                        except Exception:
                            raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
            else:
                raise ValueError(f"Unsupported model found: {type(model)=}")

            del load_model
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
    if args.train_text_encoder:
        # different learning rate for the text encoders and the transformer
        text_parameters_one_with_lr = {
            "params": text_encoder_one.parameters(),
            "weight_decay": args.adam_weight_decay_text_encoder,
            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
        }
        text_parameters_two_with_lr = {
            "params": text_encoder_two.parameters(),
            "weight_decay": args.adam_weight_decay_text_encoder,
            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
        }
        text_parameters_three_with_lr = {
            "params": text_encoder_three.parameters(),
            "weight_decay": args.adam_weight_decay_text_encoder,
            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
        }
        params_to_optimize = [
            transformer_parameters_with_lr,
            text_parameters_one_with_lr,
            text_parameters_two_with_lr,
            text_parameters_three_with_lr,
        ]
    else:
        params_to_optimize = [transformer_parameters_with_lr]

    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
            "Defaulting to adamW."
        )
        args.optimizer = "adamw"

    # Initialize the optimizer
    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.optimizer.lower() == "adamw":
        if args.use_8bit_adam:
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                )
            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

    if args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )
        if args.train_text_encoder and args.text_encoder_lr:
            logger.warning(
                f"Learning rates were provided both for the transformer and the text encoders, e.g. text_encoder_lr:"
                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                f"When using Prodigy, only learning_rate is used as the initial learning rate."
            )
            # changes the learning rate of the text encoder parameter groups to be --learning_rate
            params_to_optimize[1]["lr"] = args.learning_rate
            params_to_optimize[2]["lr"] = args.learning_rate
            params_to_optimize[3]["lr"] = args.learning_rate

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )
    text_encoders_dtypes = [text_encoder_one.dtype, text_encoder_two.dtype, text_encoder_three.dtype]

    if not args.train_text_encoder:
        tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
        text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]

        def compute_text_embeddings(prompt, text_encoders, tokenizers, text_encoders_dtypes):
            with torch.no_grad():
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    text_encoders, tokenizers, prompt, args.max_sequence_length, text_encoders_dtypes
                )
                prompt_embeds = prompt_embeds.to(accelerator.device)
                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
            return prompt_embeds, pooled_prompt_embeds

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        if args.train_data_jsonl is not None:
            dataset = load_dataset(
                "json",
                data_files=args.train_data_jsonl,
                cache_dir=args.cache_dir,
                # split="train"
            )
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names
    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.original_image_column is None:
        original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        original_image_column = args.original_image_column
        if original_image_column not in column_names:
            raise ValueError(
                f"`--original_image_column` value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.edit_prompt_column is None:
        edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        edit_prompt_column = args.edit_prompt_column
        if edit_prompt_column not in column_names:
            raise ValueError(
                f"`--edit_prompt_column` value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.edited_image_column is None:
        edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
    else:
        edited_image_column = args.edited_image_column
        if edited_image_column not in column_names:
            raise ValueError(
                f"`--edited_image_column` value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
            )
    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    # def tokenize_captions(captions):
    #     inputs = tokenizer(
    #         captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    #     )
    #     return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        ]
    )

    def preprocess_images(examples):
        original_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
        )
        edited_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
        )
        if args.do_mask:
            mask_images = np.concatenate(
                [convert_to_np(image, args.resolution) for image in examples[args.mask_column]]
            )
            # We need to ensure that the original and the edited images undergo the same
            # augmentation transforms.
            images = np.concatenate([original_images, edited_images, mask_images])
            images = torch.tensor(images)
            images = 2 * (images / 255) - 1
            # mask_index = torch.tensor([image == "NONE" for image in examples[args.mask_column]], dtype=torch.bool)
            # return train_transforms(images), mask_index
            return train_transforms(images)

        # We need to ensure that the original and the edited images undergo the same
        # augmentation transforms.
        images = np.concatenate([original_images, edited_images])
        images = torch.tensor(images)
        images = 2 * (images / 255) - 1
        return train_transforms(images)
    def preprocess_train(examples):
        # Preprocess images.
        # Since the original and edited images were concatenated before
        # applying the transformations, we need to separate them and reshape
        # them accordingly.
        preprocessed_images = preprocess_images(examples)
        if not args.do_mask:
            original_images, edited_images = preprocessed_images.chunk(2)
        else:
            # preprocessed_images, mask_index = preprocess_images(examples)
            original_images, edited_images, mask_images = preprocessed_images.chunk(3)
            mask_images = mask_images.reshape(-1, 3, args.resolution, args.resolution)
            # examples["mask_index"] = mask_index
            examples["mask_pixel_values"] = mask_images
        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
        examples["original_pixel_values"] = original_images
        examples["edited_pixel_values"] = edited_images

        # Preprocess the captions.
        # captions = list(examples[edit_prompt_column])
        # examples[edit_prompt_column] = captions
        return examples
    with accelerator.main_process_first():
        if args.top_training_data_sample is not None:
            dataset["train"] = dataset["train"].select(range(args.top_training_data_sample)).shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
        edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
        prompts = [example[edit_prompt_column] for example in examples]
        if args.do_mask:
            mask_pixel_values = torch.stack([example["mask_pixel_values"] for example in examples])
            mask_pixel_values = mask_pixel_values.to(memory_format=torch.contiguous_format).float()
            return {
                "original_pixel_values": original_pixel_values,
                "edited_pixel_values": edited_pixel_values,
                edit_prompt_column: prompts,
                "mask_pixel_values": mask_pixel_values,
            }
        else:
            return {
                "original_pixel_values": original_pixel_values,
                "edited_pixel_values": edited_pixel_values,
                edit_prompt_column: prompts,
            }

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        (
            transformer,
            text_encoder_one,
            text_encoder_two,
            text_encoder_three,
            optimizer,
            train_dataloader,
            lr_scheduler,
        ) = accelerator.prepare(
            transformer,
            text_encoder_one,
            text_encoder_two,
            text_encoder_three,
            optimizer,
            train_dataloader,
            lr_scheduler,
        )
    else:
        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            transformer, optimizer, train_dataloader, lr_scheduler
        )
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    if accelerator.is_main_process:
        pretrained_path = args.pretrained_model_name_or_path
        pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
            pretrained_path,
            vae=vae,
            text_encoder=accelerator.unwrap_model(text_encoder_one),
            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
            text_encoder_3=accelerator.unwrap_model(text_encoder_three),
            transformer=accelerator.unwrap_model(transformer),
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None

        if args.do_mask:
            original_image = download_image(args.val_image_url, args.eval_resolution)
            mask_image = download_image(args.val_mask_url, args.eval_resolution)
        else:
            original_image = download_image(args.val_image_url, args.eval_resolution)
            mask_image = None

        edited_images = []
        with torch.autocast(
            str(accelerator.device).replace(":0", ""),
            enabled=(accelerator.mixed_precision == "fp16") or (accelerator.mixed_precision == "bf16"),
        ):
            for i in range(args.num_validation_images):
                edited_images.append(
                    pipeline(
                        args.validation_prompt,
                        image=original_image,
                        mask_img=mask_image,
                        num_inference_steps=50,
                        image_guidance_scale=1.5,
                        guidance_scale=7.5,
                        generator=generator,
                    ).images[0]
                )

        path = join(args.output_dir, "start_test")
        os.makedirs(path, exist_ok=True)
        original_image.save(join(path, "original.jpg"))
        for idx, edited_image in enumerate(edited_images):
            edited_image.save(join(path, f"sample_{idx}.jpg"))

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    print('=========num_update_steps_per_epoch==========', num_update_steps_per_epoch)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("instruct-pix2pix_sd3", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
            resume_global_step = global_step * args.gradient_accumulation_steps
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    else:
        initial_global_step = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
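
    # Rough picture of the rectified-flow objective used below (a sketch of the math, assuming
    # the FlowMatchEulerDiscreteScheduler convention with sigma in [0, 1]): for a sampled sigma,
    # the noisy input is
    #     x_t = sigma * noise + (1 - sigma) * x_0,
    # and the transformer output v is converted back to an x_0 estimate via
    #     x_0_hat = x_t - sigma * v,
    # which is what `model_pred * (-sigmas) + noisy_model_input` computes before the weighted
    # MSE against the clean latents. For example, with sigma = 0.25 and a perfect prediction
    # v = noise - x_0, the estimate recovers x_0 exactly.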
    # with torch.autograd.set_detect_anomaly(True):
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()
        if args.train_text_encoder:
            text_encoder_one.train()
            text_encoder_two.train()
            text_encoder_three.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            # if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % args.gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue
            models_to_accumulate = [transformer]
            if args.train_text_encoder:
                models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three])
            with accelerator.accumulate(models_to_accumulate):
                # We want to learn the denoising process w.r.t the edited images which
                # are conditioned on the original image (which was edited) and the edit instruction.
                # So, first, convert images to latent space.
                pixel_values = batch["edited_pixel_values"].to(dtype=vae.dtype)
                prompt = batch[edit_prompt_column]
                if not args.train_text_encoder:
                    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
                        prompt, text_encoders, tokenizers, text_encoders_dtypes
                    )
                else:
                    tokens_one = tokenize_prompt(tokenizer_one, prompt)
                    tokens_two = tokenize_prompt(tokenizer_two, prompt)
                    tokens_three = tokenize_prompt(tokenizer_three, prompt, args.max_sequence_length)
| latents = vae.encode(pixel_values).latent_dist.sample() | |
| latents = latents * vae.config.scaling_factor | |
| latents = latents.to(dtype=weight_dtype) | |
| # Sample noise that we'll add to the latents | |
| noise = torch.randn_like(latents) | |
| bsz = latents.shape[0] | |
| # Sample a random timestep for each image | |
| # for weighting schemes where we sample timesteps non-uniformly | |
| if args.weighting_scheme == "logit_normal": | |
| # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). | |
| u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu") | |
| u = torch.nn.functional.sigmoid(u) | |
| elif args.weighting_scheme == "mode": | |
| u = torch.rand(size=(bsz,), device="cpu") | |
| u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) | |
| else: | |
| u = torch.rand(size=(bsz,), device="cpu") | |
| indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() | |
| timesteps = noise_scheduler_copy.timesteps[indices].to(device=latents.device) | |
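| # "logit_normal" squashes a Gaussian through a sigmoid so sampled timesteps concentrate around the | |
| # middle of the schedule (SD3 paper, sec. 3.1); "mode" biases uniform samples via args.mode_scale; | |
| # otherwise u is uniform. u in [0, 1) is then mapped to a discrete index into the scheduler's timesteps. | |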
| # Add noise to the latents according to the noise magnitude at each timestep | |
| # (this is the forward diffusion process) | |
| sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) | |
| noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents | |
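| # Rectified-flow forward process: x_t = sigma * noise + (1 - sigma) * x_0, so sigma = 1 corresponds | |
| # to pure noise and sigma = 0 to the clean latent. | |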
| # Get the additional image embedding for conditioning. | |
| # Instead of getting a diagonal Gaussian here, we simply take the mode. | |
| original_image_embeds = vae.encode(batch["original_pixel_values"].to(vae.dtype)).latent_dist.mode() | |
| concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1) | |
| if args.do_mask: | |
| mask_embeds = vae.encode(batch["mask_pixel_values"].to(vae.dtype)).latent_dist.mode() | |
| concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, mask_embeds], dim=1) | |
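| # InstructPix2Pix-style conditioning: the original-image latents (and, when args.do_mask is set, the | |
| # mask latents) are concatenated channel-wise with the noisy latents, so the transformer's input | |
| # channel count must match this widened latent stack. | |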
| # Predict the noise residual | |
| if not args.train_text_encoder: | |
| model_pred = transformer( | |
| hidden_states=concatenated_noisy_latents, | |
| timestep=timesteps, | |
| encoder_hidden_states=prompt_embeds, | |
| pooled_projections=pooled_prompt_embeds, | |
| return_dict=False, | |
| # mask_index = mask_index | |
| )[0] | |
| else: | |
| prompt_embeds, pooled_prompt_embeds = encode_prompt( | |
| text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], | |
| tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three], | |
| prompt=prompt, | |
| text_input_ids_list=[tokens_one, tokens_two, tokens_three], | |
| max_sequence_length=args.max_sequence_length, | |
| text_encoders_dtypes=text_encoders_dtypes, | |
| ) | |
| model_pred = transformer( | |
| hidden_states=concatenated_noisy_latents, | |
| timestep=timesteps, | |
| encoder_hidden_states=prompt_embeds, | |
| pooled_projections=pooled_prompt_embeds, | |
| return_dict=False, | |
| mask_index=mask_index | |
| )[0] | |
| model_pred = model_pred * (-sigmas) + noisy_model_input | |
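| # The transformer predicts the flow-matching velocity (roughly noise - latents); the line above | |
| # converts it into an estimate of the clean latents via x0_hat = x_t - sigma * v, which is then | |
| # regressed against `target = latents` below. | |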
| # these weighting schemes use a uniform timestep sampling | |
| # and instead post-weight the loss | |
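| # "sigma_sqrt" weights the squared error by 1 / sigma^2 and "cosmap" by 2 / (pi * (1 - 2*sigma + 2*sigma^2)); | |
| # any other value falls back to uniform weighting. | |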
| if args.weighting_scheme == "sigma_sqrt": | |
| weighting = (sigmas ** -2.0).float() | |
| elif args.weighting_scheme == "cosmap": | |
| bot = 1 - 2 * sigmas + 2 * sigmas ** 2 | |
| weighting = 2 / (math.pi * bot) | |
| else: | |
| weighting = torch.ones_like(sigmas) | |
| # Get the target for the loss: with the preconditioning above, the model regresses the clean latents. | |
| target = latents | |
| # The original-image (and optional mask) latents were concatenated with the noisy latents above; | |
| # conditioning dropout to support classifier-free guidance at inference follows section 3.2.1 of | |
| # the InstructPix2Pix paper (https://arxiv.org/abs/2211.09800). | |
| loss = torch.mean( | |
| (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), | |
| 1, | |
| ) | |
| loss = loss.mean() | |
| avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() | |
| train_loss += avg_loss.item() / args.gradient_accumulation_steps | |
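| # Gather the per-device losses so the logged train_loss reflects all processes; dividing by | |
| # gradient_accumulation_steps keeps the value comparable to a full optimizer step. | |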
| # Backpropagate | |
| accelerator.backward(loss) | |
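| # Gradients are only clipped on steps where they are actually synchronized, i.e. at | |
| # gradient-accumulation boundaries. | |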
| if accelerator.sync_gradients: | |
| params_to_clip = ( | |
| itertools.chain( | |
| transformer.parameters(), | |
| text_encoder_one.parameters(), | |
| text_encoder_two.parameters(), | |
| text_encoder_three.parameters(), | |
| ) | |
| if args.train_text_encoder | |
| else transformer.parameters() | |
| ) | |
| accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| # Checks if the accelerator has performed an optimization step behind the scenes | |
| if accelerator.sync_gradients: | |
| progress_bar.update(1) | |
| global_step += 1 | |
| accelerator.log({"train_loss": train_loss}, step=global_step) | |
| train_loss = 0.0 | |
| if accelerator.is_main_process: | |
| if global_step % args.checkpointing_steps == 0: | |
| # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` | |
| if args.checkpoints_total_limit is not None: | |
| checkpoints = os.listdir(args.output_dir) | |
| checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] | |
| checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) | |
| # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints | |
| if len(checkpoints) >= args.checkpoints_total_limit: | |
| num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 | |
| removing_checkpoints = checkpoints[0:num_to_remove] | |
| logger.info( | |
| f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" | |
| ) | |
| logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") | |
| for removing_checkpoint in removing_checkpoints: | |
| removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) | |
| shutil.rmtree(removing_checkpoint) | |
| save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") | |
| accelerator.save_state(save_path) | |
| logger.info(f"Saved state to {save_path}") | |
| if accelerator.is_main_process: | |
| if ( | |
| (args.val_image_url is not None) | |
| and (args.validation_prompt is not None) | |
| and (global_step % args.validation_step == 0) | |
| ): | |
| logger.info( | |
| f"Running validation... \n Generating {args.num_validation_images} images with prompt:" | |
| f" {args.validation_prompt}." | |
| ) | |
| # create pipeline | |
| # if not args.train_text_encoder: | |
| # text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( | |
| # text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three | |
| # ) | |
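| # The validation pipeline is assembled from the live (unwrapped) training modules rather than | |
| # weights reloaded from disk; when args.do_mask is set, the remaining components (tokenizers, | |
| # scheduler, configs) are taken from args.ori_model_name_or_path instead of | |
| # args.pretrained_model_name_or_path. | |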
| if args.do_mask: | |
| pretrained_path = args.ori_model_name_or_path | |
| else: | |
| pretrained_path = args.pretrained_model_name_or_path | |
| pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( | |
| pretrained_path, | |
| vae=vae, | |
| text_encoder=accelerator.unwrap_model(text_encoder_one), | |
| text_encoder_2=accelerator.unwrap_model(text_encoder_two), | |
| text_encoder_3=accelerator.unwrap_model(text_encoder_three), | |
| transformer=accelerator.unwrap_model(transformer), | |
| revision=args.revision, | |
| variant=args.variant, | |
| torch_dtype=weight_dtype, | |
| ) | |
| pipeline = pipeline.to(accelerator.device) | |
| pipeline.set_progress_bar_config(disable=True) | |
| generator = (torch.Generator(device=accelerator.device).manual_seed(args.seed) | |
| if args.seed is not None else None) | |
| # run inference | |
| original_image = download_image(args.val_image_url, args.eval_resolution) | |
| mask_image = download_image(args.val_mask_url, args.eval_resolution) if args.do_mask else None | |
| edited_images = [] | |
| with torch.autocast( | |
| str(accelerator.device).replace(":0", ""), | |
| enabled=accelerator.mixed_precision in ("fp16", "bf16"), | |
| ): | |
| for i in range(args.num_validation_images): | |
| edited_images.append( | |
| pipeline( | |
| args.validation_prompt, | |
| image=original_image, | |
| mask_img=mask_image, | |
| num_inference_steps=50, | |
| image_guidance_scale=1.5, | |
| guidance_scale=7.5, | |
| generator=generator, | |
| ).images[0] | |
| ) | |
| for tracker in accelerator.trackers: | |
| path = join(args.output_dir, f"eval_{global_step}") | |
| os.makedirs(path, exist_ok=True) | |
| original_image.save(join(path, f"original.jpg")) | |
| if tracker.name == "wandb": | |
| wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) | |
| for edited_image in edited_images: | |
| wandb_table.add_data( | |
| wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt | |
| ) | |
| tracker.log({"validation": wandb_table}) | |
| # Save the edited samples to the eval directory as well. | |
| for idx, edited_image in enumerate(edited_images): | |
| edited_image.save(join(path, f"sample_{idx}.jpg")) | |
| del pipeline | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} | |
| progress_bar.set_postfix(**logs) | |
| accelerator.log(logs, step=global_step) | |
| if global_step >= args.max_train_steps: | |
| break | |
| accelerator.wait_for_everyone() | |
| if accelerator.is_main_process: | |
| transformer = unwrap_model(transformer) | |
| if args.train_text_encoder: | |
| text_encoder_one = unwrap_model(text_encoder_one) | |
| text_encoder_two = unwrap_model(text_encoder_two) | |
| text_encoder_three = unwrap_model(text_encoder_three) | |
| pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( | |
| args.pretrained_model_name_or_path, | |
| transformer=transformer, | |
| text_encoder=text_encoder_one, | |
| text_encoder_2=text_encoder_two, | |
| text_encoder_3=text_encoder_three, | |
| ) | |
| else: | |
| pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( | |
| args.pretrained_model_name_or_path, transformer=transformer | |
| ) | |
| pipeline.save_pretrained(args.output_dir) | |
| accelerator.end_training() | |
| if __name__ == "__main__": | |
| args = parse_args() | |
| main() | |
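| # Example launch (illustrative only; flag names assume parse_args exposes the args.* attributes | |
| # used above as same-named CLI options, and the script filename is hypothetical): | |
| # accelerate launch train_sd3_instruct_pix2pix.py \ | |
| #     --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \ | |
| #     --output_dir ./sd3-ip2p --train_batch_size 1 --gradient_accumulation_steps 4 \ | |
| #     --max_train_steps 10000 --checkpointing_steps 500 --validation_step 500 \ | |
| #     --validation_prompt "make the sky look like sunset" --val_image_url <image-url> --seed 42 | |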