Commit: logs current steps + auto push checkpoint step to repo

Changed file: train_dreambooth_lora_sdxl.py (+103 -7)
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gradio as gr
 import argparse
 import gc
 import hashlib
@@ -62,6 +63,73 @@ check_min_version("0.21.0.dev0")
 
 logger = get_logger(__name__)
 
+def save_tempo_model_card(
+    repo_id: str, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None, last_checkpoint=str
+):
+
+    yaml = f"""
+---
+base_model: {base_model}
+instance_prompt: {prompt}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- lora
+inference: false
+datasets:
+- {dataset_id}
+---
+    """
+    model_card = f"""
+# LoRA DreamBooth - {repo_id}
+
+## MODEL IS CURRENTLY TRAINING ...
+Last checkpoint saved: {last_checkpoint}
+
+These are LoRA adaption weights for {base_model}.
+
+The weights were trained on the concept prompt:
+```
+{prompt}
+```
+Use this keyword to trigger your custom model in your prompts.
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+
+## Usage
+
+Make sure to upgrade diffusers to >= 0.19.0:
+```
+pip install diffusers --upgrade
+```
+In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
+```
+pip install invisible_watermark transformers accelerate safetensors
+```
+To just use the base model, you can run:
+```python
+import torch
+from diffusers import DiffusionPipeline, AutoencoderKL
+vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    vae=vae, torch_dtype=torch.float16, variant="fp16",
+    use_safetensors=True
+)
+pipe.to("cuda")
+# This is where you load your trained weights
+pipe.load_lora_weights('{repo_id}')
+
+prompt = "A majestic {prompt} jumping from a big stone at night"
+image = pipe(prompt=prompt, num_inference_steps=50).images[0]
+```
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
 
 def save_model_card(
     repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
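A note on `save_tempo_model_card`: it only renders the YAML front matter plus the "currently training" card body and writes them to `<repo_folder>/README.md`; nothing is uploaded at this point. A minimal, hypothetical smoke test (all values below are invented, with the function above in scope):

```python
import os

# All names here are placeholders; the helper only writes
# ./output/README.md (front matter + card body). The upload to the Hub
# happens later, inside the training loop, not here.
os.makedirs("./output", exist_ok=True)
save_tempo_model_card(
    repo_id="someuser/my-sdxl-lora",
    dataset_id="someuser/my-dreambooth-images",
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    train_text_encoder=False,
    prompt="a photo of sks dog",
    repo_folder="./output",
    vae_path="madebyollin/sdxl-vae-fp16-fix",
    last_checkpoint="checkpoint-500",
)
```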
@@ -92,9 +160,9 @@ datasets:
 These are LoRA adaption weights for {base_model}.
 
 The weights were trained on the concept prompt:
-
-
-
+```
+{prompt}
+```
 Use this keyword to trigger your custom model in your prompts.
 
 LoRA for the text encoder was enabled: {train_text_encoder}.
@@ -126,11 +194,11 @@ pipe = DiffusionPipeline.from_pretrained(
     use_safetensors=True
 )
 
+pipe.to("cuda")
+
 # This is where you load your trained weights
 pipe.load_lora_weights('{repo_id}')
 
-pipe.to("cuda")
-
 prompt = "A majestic {prompt} jumping from a big stone at night"
 
 image = pipe(prompt=prompt, num_inference_steps=50).images[0]
@@ -1067,6 +1135,7 @@ def main(args):
         accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
 
     # Train!
+    gr.Info("Training Starts now")
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
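`gr.Info` does not log to the console; it pops a toast notification in the browser session of the running Gradio app, and it only has a visible effect when called from inside a Gradio event handler (which is how this Space invokes training). A minimal sketch of that mechanism, assuming Gradio >= 3.44; the demo below is invented for illustration:

```python
import gradio as gr

def start_training():
    # Raised inside an event handler, gr.Info(...) shows an info toast
    # in the connected browser; outside a Gradio request context it has
    # no visible effect in the UI.
    gr.Info("Training Starts now")
    return "Training launched"

with gr.Blocks() as demo:
    launch_btn = gr.Button("Train")
    status = gr.Textbox(label="Status")
    launch_btn.click(start_training, outputs=status)

demo.launch()
```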
@@ -1110,6 +1179,9 @@
     progress_bar.set_description("Steps")
 
     for epoch in range(first_epoch, args.num_train_epochs):
+        # Print a message for each epoch
+        print(f"Epoch {epoch}: Training in progress...")
+
         unet.train()
         if args.train_text_encoder:
             text_encoder_one.train()
@@ -1118,7 +1190,7 @@
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                 if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
+                    progress_bar.update(1)
                 continue
 
             with accelerator.accumulate(unet):
@@ -1211,6 +1283,8 @@
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                # Print a message for each step
+                print(f"Step {global_step}/{args.max_train_steps}: Done")
                 progress_bar.update(1)
                 global_step += 1
 
@@ -1240,10 +1314,32 @@
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                        gr.Info(f"Saving checkpoint-{global_step} to {repo_id}")
+                        save_tempo_model_card(
+                            repo_id,
+                            dataset_id=args.dataset_id,
+                            base_model=args.pretrained_model_name_or_path,
+                            train_text_encoder=args.train_text_encoder,
+                            prompt=args.instance_prompt,
+                            repo_folder=args.output_dir,
+                            vae_path=args.pretrained_vae_model_name_or_path,
+                            last_checkpoint=f"checkpoint-{global_step}"
+                        )
+
+                        upload_folder(
+                            repo_id=repo_id,
+                            folder_path=args.output_dir,
+                            commit_message=f"saving checkpoint-{global_step}",
+                            ignore_patterns=["step_*", "epoch_*"],
+                            token=args.hub_token
+                        )
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-
+
+
+
             if global_step >= args.max_train_steps:
                 break
 
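The push relies on `huggingface_hub.upload_folder`: each call creates one commit containing the current contents of the output directory, so the Hub repo accumulates a commit per saved checkpoint, and `ignore_patterns` keeps paths matching `step_*`/`epoch_*` out of each commit. A standalone sketch with placeholder values:

```python
from huggingface_hub import upload_folder

# Placeholder repo, folder and token; mirrors the call added in the
# training loop above. One call produces one commit on the Hub repo,
# uploading everything in the folder except ignored paths.
upload_folder(
    repo_id="someuser/my-sdxl-lora",
    folder_path="./output",
    commit_message="saving checkpoint-500",
    ignore_patterns=["step_*", "epoch_*"],  # skip intermediate state folders
    token="hf_xxx",  # needs write access to the repo
)
```

Because the model card written by `save_tempo_model_card` sits in the same folder, every checkpoint commit also refreshes the "currently training" README.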