Spaces:
Sleeping
Sleeping
add dataset_id arg
Browse files
train_dreambooth_lora_sdxl.py
CHANGED
|
@@ -64,7 +64,7 @@ logger = get_logger(__name__)
|
|
| 64 |
|
| 65 |
|
| 66 |
def save_model_card(
|
| 67 |
-
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
| 68 |
):
|
| 69 |
img_str = ""
|
| 70 |
for i, image in enumerate(images):
|
|
@@ -82,6 +82,8 @@ tags:
|
|
| 82 |
- diffusers
|
| 83 |
- lora
|
| 84 |
inference: false
|
|
|
|
|
|
|
| 85 |
---
|
| 86 |
"""
|
| 87 |
model_card = f"""
|
|
@@ -180,6 +182,13 @@ def parse_args(input_args=None):
|
|
| 180 |
required=False,
|
| 181 |
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 182 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
parser.add_argument(
|
| 184 |
"--instance_data_dir",
|
| 185 |
type=str,
|
|
@@ -1386,6 +1395,7 @@ def main(args):
|
|
| 1386 |
save_model_card(
|
| 1387 |
repo_id,
|
| 1388 |
images=images,
|
|
|
|
| 1389 |
base_model=args.pretrained_model_name_or_path,
|
| 1390 |
train_text_encoder=args.train_text_encoder,
|
| 1391 |
prompt=args.instance_prompt,
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def save_model_card(
|
| 67 |
+
repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
| 68 |
):
|
| 69 |
img_str = ""
|
| 70 |
for i, image in enumerate(images):
|
|
|
|
| 82 |
- diffusers
|
| 83 |
- lora
|
| 84 |
inference: false
|
| 85 |
+
datasets:
|
| 86 |
+
- {dataset_id}
|
| 87 |
---
|
| 88 |
"""
|
| 89 |
model_card = f"""
|
|
|
|
| 182 |
required=False,
|
| 183 |
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 184 |
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--dataset_id",
|
| 187 |
+
type=str,
|
| 188 |
+
default=None,
|
| 189 |
+
required=True,
|
| 190 |
+
help="The dataset ID you want to train images from",
|
| 191 |
+
)
|
| 192 |
parser.add_argument(
|
| 193 |
"--instance_data_dir",
|
| 194 |
type=str,
|
|
|
|
| 1395 |
save_model_card(
|
| 1396 |
repo_id,
|
| 1397 |
images=images,
|
| 1398 |
+
dataset_id=args.dataset_id,
|
| 1399 |
base_model=args.pretrained_model_name_or_path,
|
| 1400 |
train_text_encoder=args.train_text_encoder,
|
| 1401 |
prompt=args.instance_prompt,
|