PUYONE committed on
Commit e69a9f5 (verified)
1 Parent(s): a708532

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .github/workflows/typos.yaml +21 -0
  3. .gitignore +11 -0
  4. .gradio/certificate.pem +31 -0
  5. LICENSE.md +201 -0
  6. README.md +17 -8
  7. XTI_hijack.py +209 -0
  8. _typos.toml +15 -0
  9. cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 +3 -0
  10. config_README-ja.md +279 -0
  11. config_files/accelerate/default_config.yaml +22 -0
  12. dreambooth_gui.py +944 -0
  13. fine_tune.py +430 -0
  14. fine_tune_README.md +465 -0
  15. fine_tune_README_ja.md +140 -0
  16. finetune/blip/blip.py +240 -0
  17. finetune/blip/med.py +955 -0
  18. finetune/blip/med_config.json +22 -0
  19. finetune/blip/vit.py +305 -0
  20. finetune/clean_captions_and_tags.py +190 -0
  21. finetune/hypernetwork_nai.py +96 -0
  22. finetune/make_captions.py +168 -0
  23. finetune/make_captions_by_git.py +151 -0
  24. finetune/merge_captions_to_metadata.py +76 -0
  25. finetune/merge_dd_tags_to_metadata.py +71 -0
  26. finetune/prepare_buckets_latents.py +267 -0
  27. finetune/tag_images_by_wd14_tagger.py +206 -0
  28. finetune_gui.py +888 -0
  29. gen_img_diffusers.py +0 -0
  30. gui.sh +9 -0
  31. kohya_gui.py +110 -0
  32. kohya_ss_colab.ipynb +448 -0
  33. library/__init__.py +0 -0
  34. library/basic_caption_gui.py +140 -0
  35. library/blip_caption_gui.py +149 -0
  36. library/common_gui.py +978 -0
  37. library/config_util.py +536 -0
  38. library/convert_model_gui.py +247 -0
  39. library/custom_train_functions.py +18 -0
  40. library/dataset_balancing_gui.py +146 -0
  41. library/dreambooth_folder_creation_gui.py +210 -0
  42. library/extract_lora_gui.py +178 -0
  43. library/extract_lycoris_locon_gui.py +309 -0
  44. library/git_caption_gui.py +136 -0
  45. library/lpw_stable_diffusion.py +1179 -0
  46. library/merge_lora_gui.py +156 -0
  47. library/model_util.py +1165 -0
  48. library/resize_lora_gui.py +173 -0
  49. library/sampler_gui.py +102 -0
  50. library/svd_merge_lora_gui.py +190 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 filter=lfs diff=lfs merge=lfs -text
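The added rule marks the bundled frpc binary as a Git LFS object. As a hedged sketch only (assuming `git-lfs` is installed), an equivalent `filter=lfs` line is what `git lfs track` appends to `.gitattributes`:

```bash
# Track the binary with Git LFS; this writes a filter=lfs line like the one added above.
git lfs track "cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3"
git add .gitattributes cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3
```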
.github/workflows/typos.yaml ADDED
@@ -0,0 +1,21 @@
+---
+# yamllint disable rule:line-length
+name: Typos
+
+on:  # yamllint disable-line rule:truthy
+  push:
+  pull_request:
+    types:
+      - opened
+      - synchronize
+      - reopened
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: typos-action
+        uses: crate-ci/typos@v1.13.10
.gitignore ADDED
@@ -0,0 +1,11 @@
+venv
+__pycache__
+cudnn_windows
+.vscode
+*.egg-info
+build
+wd14_tagger_model
+.DS_Store
+locon
+gui-user.bat
+gui-user.ps1
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE.md ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2022] [kohya-ss]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,21 @@
 ---
-title: Kohya Ss Colab
-emoji: 📈
-colorFrom: indigo
-colorTo: gray
+title: kohya_ss_colab
+app_file: dreambooth_gui.py
 sdk: gradio
-sdk_version: 5.49.0
-app_file: app.py
-pinned: false
+sdk_version: 5.47.2
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb)
+
+# Kohya SS WebUI Colab Setup
+
+This Colab notebook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS provides a Gradio-based WebUI and Python scripts for training Stable Diffusion models (DreamBooth, fine-tuning, and LoRA), and this notebook is a convenient way to run it without installing anything on your local machine.
+
+This notebook was inspired by [Spaceginner](https://github.com/Spaceginner)'s original Colab notebook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab notebook was coded by [panguin6010](https://github.com/panguin6010).
+
+
+## Tutorials
+
+Before running this code, make sure you are familiar with Colab notebooks and have a basic understanding of Kohya SS and its usage; tutorials for both are available online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project.
+
+## Link
+```https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb```
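The README mentions reaching the WebUI "on Gradio Live". As an illustration only (not the notebook's actual launch code), a public `*.gradio.live` link is what Gradio produces when an app is launched with `share=True`:

```python
import gradio as gr

def echo(text: str) -> str:
    return text

# Minimal sketch: share=True asks Gradio to open a temporary public gradio.live URL,
# which is how a Colab-hosted GUI can be reached from a local browser.
demo = gr.Interface(fn=echo, inputs="text", outputs="text")
demo.launch(share=True)
```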
XTI_hijack.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from typing import Union, List, Optional, Dict, Any, Tuple
3
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+ from diffusers.utils import logging
+
+ # `logger` is referenced in unet_forward_XTI below; use diffusers' logging helper
+ # so that call does not raise NameError.
+ logger = logging.get_logger(__name__)
4
+
5
+ def unet_forward_XTI(self,
6
+ sample: torch.FloatTensor,
7
+ timestep: Union[torch.Tensor, float, int],
8
+ encoder_hidden_states: torch.Tensor,
9
+ class_labels: Optional[torch.Tensor] = None,
10
+ return_dict: bool = True,
11
+ ) -> Union[UNet2DConditionOutput, Tuple]:
12
+ r"""
13
+ Args:
14
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
15
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
16
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
17
+ return_dict (`bool`, *optional*, defaults to `True`):
18
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
19
+
20
+ Returns:
21
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
22
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
23
+ returning a tuple, the first element is the sample tensor.
24
+ """
25
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
26
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
27
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
28
+ # on the fly if necessary.
29
+ default_overall_up_factor = 2**self.num_upsamplers
30
+
31
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
32
+ forward_upsample_size = False
33
+ upsample_size = None
34
+
35
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
36
+ logger.info("Forward upsample size to force interpolation output size.")
37
+ forward_upsample_size = True
38
+
39
+ # 0. center input if necessary
40
+ if self.config.center_input_sample:
41
+ sample = 2 * sample - 1.0
42
+
43
+ # 1. time
44
+ timesteps = timestep
45
+ if not torch.is_tensor(timesteps):
46
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
47
+ # This would be a good case for the `match` statement (Python 3.10+)
48
+ is_mps = sample.device.type == "mps"
49
+ if isinstance(timestep, float):
50
+ dtype = torch.float32 if is_mps else torch.float64
51
+ else:
52
+ dtype = torch.int32 if is_mps else torch.int64
53
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
54
+ elif len(timesteps.shape) == 0:
55
+ timesteps = timesteps[None].to(sample.device)
56
+
57
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
58
+ timesteps = timesteps.expand(sample.shape[0])
59
+
60
+ t_emb = self.time_proj(timesteps)
61
+
62
+ # timesteps does not contain any weights and will always return f32 tensors
63
+ # but time_embedding might actually be running in fp16. so we need to cast here.
64
+ # there might be better ways to encapsulate this.
65
+ t_emb = t_emb.to(dtype=self.dtype)
66
+ emb = self.time_embedding(t_emb)
67
+
68
+ if self.config.num_class_embeds is not None:
69
+ if class_labels is None:
70
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
71
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
72
+ emb = emb + class_emb
73
+
74
+ # 2. pre-process
75
+ sample = self.conv_in(sample)
76
+
77
+ # 3. down
78
+ down_block_res_samples = (sample,)
79
+ down_i = 0
80
+ for downsample_block in self.down_blocks:
81
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
82
+ sample, res_samples = downsample_block(
83
+ hidden_states=sample,
84
+ temb=emb,
85
+ encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
86
+ )
87
+ down_i += 2
88
+ else:
89
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
90
+
91
+ down_block_res_samples += res_samples
92
+
93
+ # 4. mid
94
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
95
+
96
+ # 5. up
97
+ up_i = 7
98
+ for i, upsample_block in enumerate(self.up_blocks):
99
+ is_final_block = i == len(self.up_blocks) - 1
100
+
101
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
102
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
103
+
104
+ # if we have not reached the final block and need to forward the
105
+ # upsample size, we do it here
106
+ if not is_final_block and forward_upsample_size:
107
+ upsample_size = down_block_res_samples[-1].shape[2:]
108
+
109
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
110
+ sample = upsample_block(
111
+ hidden_states=sample,
112
+ temb=emb,
113
+ res_hidden_states_tuple=res_samples,
114
+ encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
115
+ upsample_size=upsample_size,
116
+ )
117
+ up_i += 3
118
+ else:
119
+ sample = upsample_block(
120
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
121
+ )
122
+ # 6. post-process
123
+ sample = self.conv_norm_out(sample)
124
+ sample = self.conv_act(sample)
125
+ sample = self.conv_out(sample)
126
+
127
+ if not return_dict:
128
+ return (sample,)
129
+
130
+ return UNet2DConditionOutput(sample=sample)
131
+
132
+ def downblock_forward_XTI(
133
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
134
+ ):
135
+ output_states = ()
136
+ i = 0
137
+
138
+ for resnet, attn in zip(self.resnets, self.attentions):
139
+ if self.training and self.gradient_checkpointing:
140
+
141
+ def create_custom_forward(module, return_dict=None):
142
+ def custom_forward(*inputs):
143
+ if return_dict is not None:
144
+ return module(*inputs, return_dict=return_dict)
145
+ else:
146
+ return module(*inputs)
147
+
148
+ return custom_forward
149
+
150
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
151
+ hidden_states = torch.utils.checkpoint.checkpoint(
152
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
153
+ )[0]
154
+ else:
155
+ hidden_states = resnet(hidden_states, temb)
156
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
157
+
158
+ output_states += (hidden_states,)
159
+ i += 1
160
+
161
+ if self.downsamplers is not None:
162
+ for downsampler in self.downsamplers:
163
+ hidden_states = downsampler(hidden_states)
164
+
165
+ output_states += (hidden_states,)
166
+
167
+ return hidden_states, output_states
168
+
169
+ def upblock_forward_XTI(
170
+ self,
171
+ hidden_states,
172
+ res_hidden_states_tuple,
173
+ temb=None,
174
+ encoder_hidden_states=None,
175
+ upsample_size=None,
176
+ ):
177
+ i = 0
178
+ for resnet, attn in zip(self.resnets, self.attentions):
179
+ # pop res hidden states
180
+ res_hidden_states = res_hidden_states_tuple[-1]
181
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
182
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
183
+
184
+ if self.training and self.gradient_checkpointing:
185
+
186
+ def create_custom_forward(module, return_dict=None):
187
+ def custom_forward(*inputs):
188
+ if return_dict is not None:
189
+ return module(*inputs, return_dict=return_dict)
190
+ else:
191
+ return module(*inputs)
192
+
193
+ return custom_forward
194
+
195
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
196
+ hidden_states = torch.utils.checkpoint.checkpoint(
197
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
198
+ )[0]
199
+ else:
200
+ hidden_states = resnet(hidden_states, temb)
201
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
202
+
203
+ i += 1
204
+
205
+ if self.upsamplers is not None:
206
+ for upsampler in self.upsamplers:
207
+ hidden_states = upsampler(hidden_states, upsample_size)
208
+
209
+ return hidden_states
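The three functions above are drop-in replacements for the diffusers UNet and cross-attention block `forward` methods: instead of a single text embedding they index into a sequence of per-block embeddings (`encoder_hidden_states[down_i:down_i+2]`, `[6]`, `[up_i:up_i+3]`), as used by the XTI (P+ / extended textual inversion) training path. The diff does not show how they are wired in; below is a hedged sketch of one way such hijack functions can be bound to a model instance (the checkpoint name is only an example, and the training script in this repo may wire this differently):

```python
import types

from diffusers import UNet2DConditionModel
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

# Load a UNet to patch; the model id here is illustrative.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Replace the top-level forward so it accepts a list of per-block text embeddings.
unet.forward = types.MethodType(unet_forward_XTI, unet)

# Replace the forward of every cross-attention down/up block the same way.
for block in unet.down_blocks:
    if getattr(block, "has_cross_attention", False):
        block.forward = types.MethodType(downblock_forward_XTI, block)
for block in unet.up_blocks:
    if getattr(block, "has_cross_attention", False):
        block.forward = types.MethodType(upblock_forward_XTI, block)
```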
_typos.toml ADDED
@@ -0,0 +1,15 @@
+# Files for typos
+# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+[default.extend-identifiers]
+
+[default.extend-words]
+NIN="NIN"
+parms="parms"
+nin="nin"
+extention="extention" # Intentionally left
+nd="nd"
+
+
+[files]
+extend-exclude = ["_typos.toml"]
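`_typos.toml` whitelists identifiers such as `NIN`, `parms`, and `extention` so the CI job defined in `.github/workflows/typos.yaml` above does not flag them. The same check can be run locally; a sketch assuming the crate-ci `typos` CLI is installed:

```bash
# Install the checker once, then run it against the repository with this config.
cargo install typos-cli
typos --config _typos.toml
```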
cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c791d1f047b41ff5885772fc4bf20b797c6059bbd82abb9e31de15e55d6a57c4
+size 11907224
config_README-ja.md ADDED
@@ -0,0 +1,279 @@
1
+ For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
2
+
3
+ `--dataset_config` で渡すことができる設定ファイルに関する説明です。
4
+
5
+ ## 概要
6
+
7
+ 設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
8
+
9
+ * 複数のデータセットが設定可能になります
10
+ * 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
11
+ * DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
12
+ * サブセットごとに設定を変更することが可能になります
13
+ * データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
14
+ * `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
15
+
16
+ 設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
17
+
18
+ TOML で記述した設定ファイルの例です。
19
+
20
+ ```toml
21
+ [general]
22
+ shuffle_caption = true
23
+ caption_extension = '.txt'
24
+ keep_tokens = 1
25
+
26
+ # これは DreamBooth 方式のデータセット
27
+ [[datasets]]
28
+ resolution = 512
29
+ batch_size = 4
30
+ keep_tokens = 2
31
+
32
+ [[datasets.subsets]]
33
+ image_dir = 'C:\hoge'
34
+ class_tokens = 'hoge girl'
35
+ # このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
36
+
37
+ [[datasets.subsets]]
38
+ image_dir = 'C:\fuga'
39
+ class_tokens = 'fuga boy'
40
+ keep_tokens = 3
41
+
42
+ [[datasets.subsets]]
43
+ is_reg = true
44
+ image_dir = 'C:\reg'
45
+ class_tokens = 'human'
46
+ keep_tokens = 1
47
+
48
+ # これは fine tuning 方式のデータセット
49
+ [[datasets]]
50
+ resolution = [768, 768]
51
+ batch_size = 2
52
+
53
+ [[datasets.subsets]]
54
+ image_dir = 'C:\piyo'
55
+ metadata_file = 'C:\piyo\piyo_md.json'
56
+ # このサブセットは keep_tokens = 1 (general の値が使われる)
57
+ ```
58
+
59
+ この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
60
+
61
+ ## データセット・サブセットに関する設定
62
+
63
+ データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。
64
+
65
+ * `[general]`
66
+ * 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。
67
+ * データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。
68
+ * `[[datasets]]`
69
+ * `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。
70
+ * サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。
71
+ * `[[datasets.subsets]]`
72
+ * `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。
73
+
74
+ 先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。
75
+
76
+ ```
77
+ C:\
78
+ ├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
79
+ ├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
80
+ ├─ reg -> [[datasets.subsets]] No.3 ┘ |
81
+ └─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
82
+ ```
83
+
84
+ 画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。
85
+
86
+ 登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。
87
+
88
+ 加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。
89
+
90
+ * DreamBooth 方式専用のオプション
91
+ * fine tuning 方式専用のオプション
92
+ * caption dropout の手法が使える場合のオプション
93
+
94
+ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。
95
+ 併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。
96
+ つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。
97
+
98
+ プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。
99
+ そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。
100
+
101
+ 以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。
102
+
103
+ ### 全学習方法で共通のオプション
104
+
105
+ 学習方法によらずに指定可能なオプションです。
106
+
107
+ #### データセット向けオプション
108
+
109
+ データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。
110
+
111
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` |
112
+ | ---- | ---- | ---- | ---- |
113
+ | `batch_size` | `1` | o | o |
114
+ | `bucket_no_upscale` | `true` | o | o |
115
+ | `bucket_reso_steps` | `64` | o | o |
116
+ | `enable_bucket` | `true` | o | o |
117
+ | `max_bucket_reso` | `1024` | o | o |
118
+ | `min_bucket_reso` | `128` | o | o |
119
+ | `resolution` | `256`, `[512, 512]` | o | o |
120
+
121
+ * `batch_size`
122
+ * コマンドライン引数の `--train_batch_size` と同等です。
123
+
124
+ これらの設定はデータセットごとに固定です。
125
+ つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
126
+ 例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。
127
+
128
+ #### サブセット向けオプション
129
+
130
+ サブセットの設定に関わるオプションです。
131
+
132
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
133
+ | ---- | ---- | ---- | ---- | ---- |
134
+ | `color_aug` | `false` | o | o | o |
135
+ | `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
136
+ | `flip_aug` | `true` | o | o | o |
137
+ | `keep_tokens` | `2` | o | o | o |
138
+ | `num_repeats` | `10` | o | o | o |
139
+ | `random_crop` | `false` | o | o | o |
140
+ | `shuffle_caption` | `true` | o | o | o |
141
+
142
+ * `num_repeats`
143
+ * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
144
+
145
+ ### DreamBooth 方式専用のオプション
146
+
147
+ DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
148
+
149
+ #### サブセット向けオプション
150
+
151
+ DreamBooth 方式のサブセットの設定に関わるオプションです。
152
+
153
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
154
+ | ---- | ---- | ---- | ---- | ---- |
155
+ | `image_dir` | `'C:\hoge'` | - | - | o(必須) |
156
+ | `caption_extension` | `".txt"` | o | o | o |
157
+ | `class_tokens` | `"sks girl"` | - | - | o |
158
+ | `is_reg` | `false` | - | - | o |
159
+
160
+ まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
161
+
162
+ * `image_dir`
163
+ * 画像ディレクトリのパスを指定します。指定必須オプションです。
164
+ * 画像はディレクトリ直下に置かれている必要があります。
165
+ * `class_tokens`
166
+ * クラストークンを設定します。
167
+ * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
168
+ * `is_reg`
169
+ * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
170
+
171
+ ### fine tuning 方式専用のオプション
172
+
173
+ fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。
174
+
175
+ #### サブセット向けオプション
176
+
177
+ fine tuning 方式のサブセットの設定に関わるオプションです。
178
+
179
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
180
+ | ---- | ---- | ---- | ---- | ---- |
181
+ | `image_dir` | `'C:\hoge'` | - | - | o |
182
+ | `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) |
183
+
184
+ * `image_dir`
185
+ * 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。
186
+ * 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。
187
+ * 画像はディレクトリ直下に置かれている必要があります。
188
+ * `metadata_file`
189
+ * サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。
190
+ * コマンドライン引数の `--in_json` と同等です。
191
+ * サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。
192
+
193
+ ### caption dropout の手法が使える場合に指定可能なオプション
194
+
195
+ caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。
196
+ DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。
197
+
198
+ #### サブセット向けオプション
199
+
200
+ caption dropout が使えるサブセットの設定に関わるオプションです。
201
+
202
+ | オプション名 | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
203
+ | ---- | ---- | ---- | ---- |
204
+ | `caption_dropout_every_n_epochs` | o | o | o |
205
+ | `caption_dropout_rate` | o | o | o |
206
+ | `caption_tag_dropout_rate` | o | o | o |
207
+
208
+ ## 重複したサブセットが存在する時の挙動
209
+
210
+ DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。
211
+ fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。
212
+ データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。
213
+
214
+ 一方、異なるデータセットに所属している場合は、重複しているとは見なされません。
215
+ 例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。
216
+ これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。
217
+
218
+ ```toml
219
+ # 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる
220
+
221
+ [[datasets]]
222
+ resolution = 512
223
+
224
+ [[datasets.subsets]]
225
+ image_dir = 'C:\hoge'
226
+
227
+ [[datasets]]
228
+ resolution = 768
229
+
230
+ [[datasets.subsets]]
231
+ image_dir = 'C:\hoge'
232
+ ```
233
+
234
+ ## コマンドライン引数との併用
235
+
236
+ 設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
237
+
238
+ 以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。
239
+
240
+ * `--train_data_dir`
241
+ * `--reg_data_dir`
242
+ * `--in_json`
243
+
244
+ 以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。
245
+
246
+ | コマンドライン引数のオプション | 優先される設定ファイルのオプション |
247
+ | ---------------------------------- | ---------------------------------- |
248
+ | `--bucket_no_upscale` | |
249
+ | `--bucket_reso_steps` | |
250
+ | `--caption_dropout_every_n_epochs` | |
251
+ | `--caption_dropout_rate` | |
252
+ | `--caption_extension` | |
253
+ | `--caption_tag_dropout_rate` | |
254
+ | `--color_aug` | |
255
+ | `--dataset_repeats` | `num_repeats` |
256
+ | `--enable_bucket` | |
257
+ | `--face_crop_aug_range` | |
258
+ | `--flip_aug` | |
259
+ | `--keep_tokens` | |
260
+ | `--min_bucket_reso` | |
261
+ | `--random_crop` | |
262
+ | `--resolution` | |
263
+ | `--shuffle_caption` | |
264
+ | `--train_batch_size` | `batch_size` |
265
+
266
+ ## エラーの手引き
267
+
268
+ 現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。
269
+ 将来的にはこの問題の改善に取り組む予定です。
270
+
271
+ 次善策として、頻出のエラーとその対処法について載せておきます。
272
+ 正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。
273
+
274
+ * `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。
275
+ * `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。
276
+ * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
277
+ * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
278
+
279
+
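config_README-ja.md above documents the TOML dataset-configuration file passed via `--dataset_config` (multiple datasets and subsets, per-dataset `resolution`/`batch_size`, per-subset `keep_tokens`, and so on). A hedged sketch of passing such a file to one of the training scripts; the script name and the remaining flags are illustrative, not prescribed by the README:

```bash
accelerate launch --num_cpu_threads_per_process=2 train_db.py \
  --dataset_config="dataset_config.toml" \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --output_dir="./output" \
  --output_name="last"
```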
config_files/accelerate/default_config.yaml ADDED
@@ -0,0 +1,22 @@
+command_file: null
+commands: null
+compute_environment: LOCAL_MACHINE
+deepspeed_config: {}
+distributed_type: 'NO'
+downcast_bf16: 'no'
+dynamo_backend: 'NO'
+fsdp_config: {}
+gpu_ids: all
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+megatron_lm_config: {}
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_name: null
+tpu_zone: null
+use_cpu: false
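This is a minimal single-process, single-machine Accelerate config with no mixed precision. Accelerate normally reads its default config from the user cache (typically `~/.cache/huggingface/accelerate/default_config.yaml`), but this repo-local file can be selected explicitly; a sketch, with the training script chosen only as an example:

```bash
accelerate launch --config_file config_files/accelerate/default_config.yaml \
  train_db.py --help
```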
dreambooth_gui.py ADDED
@@ -0,0 +1,944 @@
1
+ # v1: initial release
2
+ # v2: add open and save folder icons
3
+ # v3: Add new Utilities tab for Dreambooth folder preparation
4
+ # v3.1: Adding captioning of images to utilities
5
+
6
+ import gradio as gr
7
+ import json
8
+ import math
9
+ import os
10
+ import subprocess
11
+ import pathlib
12
+ import argparse
13
+ from library.common_gui import (
14
+ get_folder_path,
15
+ remove_doublequote,
16
+ get_file_path,
17
+ get_any_file_path,
18
+ get_saveasfile_path,
19
+ color_aug_changed,
20
+ save_inference_file,
21
+ gradio_advanced_training,
22
+ run_cmd_advanced_training,
23
+ run_cmd_training,
24
+ gradio_training,
25
+ gradio_config,
26
+ gradio_source_model,
27
+ # set_legacy_8bitadam,
28
+ update_my_data,
29
+ check_if_model_exist,
30
+ )
31
+ from library.tensorboard_gui import (
32
+ gradio_tensorboard,
33
+ start_tensorboard,
34
+ stop_tensorboard,
35
+ )
36
+ from library.dreambooth_folder_creation_gui import (
37
+ gradio_dreambooth_folder_creation_tab,
38
+ )
39
+ from library.utilities import utilities_tab
40
+ from library.sampler_gui import sample_gradio_config, run_cmd_sample
41
+ from easygui import msgbox
42
+
43
+ folder_symbol = '\U0001f4c2' # 📂
44
+ refresh_symbol = '\U0001f504' # 🔄
45
+ save_style_symbol = '\U0001f4be' # 💾
46
+ document_symbol = '\U0001F4C4' # 📄
47
+
48
+
49
+ def save_configuration(
50
+ save_as,
51
+ file_path,
52
+ pretrained_model_name_or_path,
53
+ v2,
54
+ v_parameterization,
55
+ logging_dir,
56
+ train_data_dir,
57
+ reg_data_dir,
58
+ output_dir,
59
+ max_resolution,
60
+ learning_rate,
61
+ lr_scheduler,
62
+ lr_warmup,
63
+ train_batch_size,
64
+ epoch,
65
+ save_every_n_epochs,
66
+ mixed_precision,
67
+ save_precision,
68
+ seed,
69
+ num_cpu_threads_per_process,
70
+ cache_latents,
71
+ caption_extension,
72
+ enable_bucket,
73
+ gradient_checkpointing,
74
+ full_fp16,
75
+ no_token_padding,
76
+ stop_text_encoder_training,
77
+ # use_8bit_adam,
78
+ xformers,
79
+ save_model_as,
80
+ shuffle_caption,
81
+ save_state,
82
+ resume,
83
+ prior_loss_weight,
84
+ color_aug,
85
+ flip_aug,
86
+ clip_skip,
87
+ vae,
88
+ output_name,
89
+ max_token_length,
90
+ max_train_epochs,
91
+ max_data_loader_n_workers,
92
+ mem_eff_attn,
93
+ gradient_accumulation_steps,
94
+ model_list,
95
+ keep_tokens,
96
+ persistent_data_loader_workers,
97
+ bucket_no_upscale,
98
+ random_crop,
99
+ bucket_reso_steps,
100
+ caption_dropout_every_n_epochs,
101
+ caption_dropout_rate,
102
+ optimizer,
103
+ optimizer_args,
104
+ noise_offset,
105
+ sample_every_n_steps,
106
+ sample_every_n_epochs,
107
+ sample_sampler,
108
+ sample_prompts,
109
+ additional_parameters,
110
+ vae_batch_size,
111
+ min_snr_gamma,
112
+ ):
113
+ # Get list of function parameters and values
114
+ parameters = list(locals().items())
115
+
116
+ original_file_path = file_path
117
+
118
+ save_as_bool = True if save_as.get('label') == 'True' else False
119
+
120
+ if save_as_bool:
121
+ print('Save as...')
122
+ file_path = get_saveasfile_path(file_path)
123
+ else:
124
+ print('Save...')
125
+ if file_path == None or file_path == '':
126
+ file_path = get_saveasfile_path(file_path)
127
+
128
+ # print(file_path)
129
+
130
+ if file_path == None or file_path == '':
131
+ return original_file_path # In case a file_path was provided and the user decide to cancel the open action
132
+
133
+ # Return the values of the variables as a dictionary
134
+ variables = {
135
+ name: value
136
+ for name, value in parameters # locals().items()
137
+ if name
138
+ not in [
139
+ 'file_path',
140
+ 'save_as',
141
+ ]
142
+ }
143
+
144
+ # Extract the destination directory from the file path
145
+ destination_directory = os.path.dirname(file_path)
146
+
147
+ # Create the destination directory if it doesn't exist
148
+ if not os.path.exists(destination_directory):
149
+ os.makedirs(destination_directory)
150
+
151
+ # Save the data to the selected file
152
+ with open(file_path, 'w') as file:
153
+ json.dump(variables, file, indent=2)
154
+
155
+ return file_path
156
+
157
+
158
+ def open_configuration(
159
+ ask_for_file,
160
+ file_path,
161
+ pretrained_model_name_or_path,
162
+ v2,
163
+ v_parameterization,
164
+ logging_dir,
165
+ train_data_dir,
166
+ reg_data_dir,
167
+ output_dir,
168
+ max_resolution,
169
+ learning_rate,
170
+ lr_scheduler,
171
+ lr_warmup,
172
+ train_batch_size,
173
+ epoch,
174
+ save_every_n_epochs,
175
+ mixed_precision,
176
+ save_precision,
177
+ seed,
178
+ num_cpu_threads_per_process,
179
+ cache_latents,
180
+ caption_extension,
181
+ enable_bucket,
182
+ gradient_checkpointing,
183
+ full_fp16,
184
+ no_token_padding,
185
+ stop_text_encoder_training,
186
+ # use_8bit_adam,
187
+ xformers,
188
+ save_model_as,
189
+ shuffle_caption,
190
+ save_state,
191
+ resume,
192
+ prior_loss_weight,
193
+ color_aug,
194
+ flip_aug,
195
+ clip_skip,
196
+ vae,
197
+ output_name,
198
+ max_token_length,
199
+ max_train_epochs,
200
+ max_data_loader_n_workers,
201
+ mem_eff_attn,
202
+ gradient_accumulation_steps,
203
+ model_list,
204
+ keep_tokens,
205
+ persistent_data_loader_workers,
206
+ bucket_no_upscale,
207
+ random_crop,
208
+ bucket_reso_steps,
209
+ caption_dropout_every_n_epochs,
210
+ caption_dropout_rate,
211
+ optimizer,
212
+ optimizer_args,
213
+ noise_offset,
214
+ sample_every_n_steps,
215
+ sample_every_n_epochs,
216
+ sample_sampler,
217
+ sample_prompts,
218
+ additional_parameters,
219
+ vae_batch_size,
220
+ min_snr_gamma,
221
+ ):
222
+ # Get list of function parameters and values
223
+ parameters = list(locals().items())
224
+
225
+ ask_for_file = True if ask_for_file.get('label') == 'True' else False
226
+
227
+ original_file_path = file_path
228
+
229
+ if ask_for_file:
230
+ file_path = get_file_path(file_path)
231
+
232
+ if not file_path == '' and not file_path == None:
233
+ # load variables from JSON file
234
+ with open(file_path, 'r') as f:
235
+ my_data = json.load(f)
236
+ print('Loading config...')
237
+ # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
238
+ my_data = update_my_data(my_data)
239
+ else:
240
+ file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
241
+ my_data = {}
242
+
243
+ values = [file_path]
244
+ for key, value in parameters:
245
+ # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
246
+ if not key in ['ask_for_file', 'file_path']:
247
+ values.append(my_data.get(key, value))
248
+ return tuple(values)
249
+
250
+
251
+ def train_model(
252
+ pretrained_model_name_or_path,
253
+ v2,
254
+ v_parameterization,
255
+ logging_dir,
256
+ train_data_dir,
257
+ reg_data_dir,
258
+ output_dir,
259
+ max_resolution,
260
+ learning_rate,
261
+ lr_scheduler,
262
+ lr_warmup,
263
+ train_batch_size,
264
+ epoch,
265
+ save_every_n_epochs,
266
+ mixed_precision,
267
+ save_precision,
268
+ seed,
269
+ num_cpu_threads_per_process,
270
+ cache_latents,
271
+ caption_extension,
272
+ enable_bucket,
273
+ gradient_checkpointing,
274
+ full_fp16,
275
+ no_token_padding,
276
+ stop_text_encoder_training_pct,
277
+ # use_8bit_adam,
278
+ xformers,
279
+ save_model_as,
280
+ shuffle_caption,
281
+ save_state,
282
+ resume,
283
+ prior_loss_weight,
284
+ color_aug,
285
+ flip_aug,
286
+ clip_skip,
287
+ vae,
288
+ output_name,
289
+ max_token_length,
290
+ max_train_epochs,
291
+ max_data_loader_n_workers,
292
+ mem_eff_attn,
293
+ gradient_accumulation_steps,
294
+ model_list, # Keep this. Yes, it is unused here but required given the common list used
295
+ keep_tokens,
296
+ persistent_data_loader_workers,
297
+ bucket_no_upscale,
298
+ random_crop,
299
+ bucket_reso_steps,
300
+ caption_dropout_every_n_epochs,
301
+ caption_dropout_rate,
302
+ optimizer,
303
+ optimizer_args,
304
+ noise_offset,
305
+ sample_every_n_steps,
306
+ sample_every_n_epochs,
307
+ sample_sampler,
308
+ sample_prompts,
309
+ additional_parameters,
310
+ vae_batch_size,
311
+ min_snr_gamma,
312
+ ):
313
+ if pretrained_model_name_or_path == '':
314
+ msgbox('Source model information is missing')
315
+ return
316
+
317
+ if train_data_dir == '':
318
+ msgbox('Image folder path is missing')
319
+ return
320
+
321
+ if not os.path.exists(train_data_dir):
322
+ msgbox('Image folder does not exist')
323
+ return
324
+
325
+ if reg_data_dir != '':
326
+ if not os.path.exists(reg_data_dir):
327
+ msgbox('Regularisation folder does not exist')
328
+ return
329
+
330
+ if output_dir == '':
331
+ msgbox('Output folder path is missing')
332
+ return
333
+
334
+ if check_if_model_exist(output_name, output_dir, save_model_as):
335
+ return
336
+
337
+ # Get a list of all subfolders in train_data_dir, excluding hidden folders
338
+ subfolders = [
339
+ f
340
+ for f in os.listdir(train_data_dir)
341
+ if os.path.isdir(os.path.join(train_data_dir, f))
342
+ and not f.startswith('.')
343
+ ]
344
+
345
+ # Check if subfolders are present. If not let the user know and return
346
+ if not subfolders:
347
+ print(
348
+ '\033[33mNo subfolders were found in',
349
+ train_data_dir,
350
+ " can't train...\033[0m",
351
+ )
352
+ return
353
+
354
+ total_steps = 0
355
+
356
+ # Loop through each subfolder and extract the number of repeats
357
+ for folder in subfolders:
358
+ # Extract the number of repeats from the folder name
359
+ try:
360
+ repeats = int(folder.split('_')[0])
361
+ except ValueError:
362
+ print(
363
+ '\033[33mSubfolder',
364
+ folder,
365
+ "does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m",
366
+ )
367
+ continue
368
+
369
+ # Count the number of images in the folder
370
+ num_images = len(
371
+ [
372
+ f
373
+ for f, lower_f in (
374
+ (file, file.lower())
375
+ for file in os.listdir(
376
+ os.path.join(train_data_dir, folder)
377
+ )
378
+ )
379
+ if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
380
+ ]
381
+ )
382
+
383
+ if num_images == 0:
384
+ print(f'{folder} folder contains no images, skipping...')
385
+ else:
386
+ # Calculate the total number of steps for this folder
387
+ steps = repeats * num_images
388
+ total_steps += steps
389
+
390
+ # Print the result
391
+ print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
392
+
393
+ if total_steps == 0:
394
+ print(
395
+ '\033[33mNo images were found in folder',
396
+ train_data_dir,
397
+ '... please rectify!\033[0m',
398
+ )
399
+ return
400
+
401
+ # Print the result
402
+ # print(f"{total_steps} total steps")
403
+
404
+ if reg_data_dir == '':
405
+ reg_factor = 1
406
+ else:
407
+ print(
408
+ '\033[94mRegularisation images are used... Will double the number of steps required...\033[0m'
409
+ )
410
+ reg_factor = 2
411
+
412
+ # calculate max_train_steps
413
+ max_train_steps = int(
414
+ math.ceil(
415
+ float(total_steps)
416
+ / int(train_batch_size)
417
+ * int(epoch)
418
+ * int(reg_factor)
419
+ )
420
+ )
421
+ print(f'max_train_steps = {max_train_steps}')
422
+
423
+ # calculate stop encoder training
424
+ if int(stop_text_encoder_training_pct) == -1:
425
+ stop_text_encoder_training = -1
426
+ elif stop_text_encoder_training_pct == None:
427
+ stop_text_encoder_training = 0
428
+ else:
429
+ stop_text_encoder_training = math.ceil(
430
+ float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
431
+ )
432
+ print(f'stop_text_encoder_training = {stop_text_encoder_training}')
433
+
434
+ lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
435
+ print(f'lr_warmup_steps = {lr_warmup_steps}')
436
+
437
+ run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
438
+ if v2:
439
+ run_cmd += ' --v2'
440
+ if v_parameterization:
441
+ run_cmd += ' --v_parameterization'
442
+ if enable_bucket:
443
+ run_cmd += ' --enable_bucket'
444
+ if no_token_padding:
445
+ run_cmd += ' --no_token_padding'
446
+ run_cmd += (
447
+ f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
448
+ )
449
+ run_cmd += f' --train_data_dir="{train_data_dir}"'
450
+ if len(reg_data_dir):
451
+ run_cmd += f' --reg_data_dir="{reg_data_dir}"'
452
+ run_cmd += f' --resolution={max_resolution}'
453
+ run_cmd += f' --output_dir="{output_dir}"'
454
+ run_cmd += f' --logging_dir="{logging_dir}"'
455
+ if not stop_text_encoder_training == 0:
456
+ run_cmd += (
457
+ f' --stop_text_encoder_training={stop_text_encoder_training}'
458
+ )
459
+ if not save_model_as == 'same as source model':
460
+ run_cmd += f' --save_model_as={save_model_as}'
461
+ # if not resume == '':
462
+ # run_cmd += f' --resume={resume}'
463
+ if not float(prior_loss_weight) == 1.0:
464
+ run_cmd += f' --prior_loss_weight={prior_loss_weight}'
465
+ if not vae == '':
466
+ run_cmd += f' --vae="{vae}"'
467
+ if not output_name == '':
468
+ run_cmd += f' --output_name="{output_name}"'
469
+ if int(max_token_length) > 75:
470
+ run_cmd += f' --max_token_length={max_token_length}'
471
+ if not max_train_epochs == '':
472
+ run_cmd += f' --max_train_epochs="{max_train_epochs}"'
473
+ if not max_data_loader_n_workers == '':
474
+ run_cmd += (
475
+ f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
476
+ )
477
+ if int(gradient_accumulation_steps) > 1:
478
+ run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
479
+
480
+ run_cmd += run_cmd_training(
481
+ learning_rate=learning_rate,
482
+ lr_scheduler=lr_scheduler,
483
+ lr_warmup_steps=lr_warmup_steps,
484
+ train_batch_size=train_batch_size,
485
+ max_train_steps=max_train_steps,
486
+ save_every_n_epochs=save_every_n_epochs,
487
+ mixed_precision=mixed_precision,
488
+ save_precision=save_precision,
489
+ seed=seed,
490
+ caption_extension=caption_extension,
491
+ cache_latents=cache_latents,
492
+ optimizer=optimizer,
493
+ optimizer_args=optimizer_args,
494
+ )
495
+
496
+ run_cmd += run_cmd_advanced_training(
497
+ max_train_epochs=max_train_epochs,
498
+ max_data_loader_n_workers=max_data_loader_n_workers,
499
+ max_token_length=max_token_length,
500
+ resume=resume,
501
+ save_state=save_state,
502
+ mem_eff_attn=mem_eff_attn,
503
+ clip_skip=clip_skip,
504
+ flip_aug=flip_aug,
505
+ color_aug=color_aug,
506
+ shuffle_caption=shuffle_caption,
507
+ gradient_checkpointing=gradient_checkpointing,
508
+ full_fp16=full_fp16,
509
+ xformers=xformers,
510
+ # use_8bit_adam=use_8bit_adam,
511
+ keep_tokens=keep_tokens,
512
+ persistent_data_loader_workers=persistent_data_loader_workers,
513
+ bucket_no_upscale=bucket_no_upscale,
514
+ random_crop=random_crop,
515
+ bucket_reso_steps=bucket_reso_steps,
516
+ caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
517
+ caption_dropout_rate=caption_dropout_rate,
518
+ noise_offset=noise_offset,
519
+ additional_parameters=additional_parameters,
520
+ vae_batch_size=vae_batch_size,
521
+ min_snr_gamma=min_snr_gamma,
522
+ )
523
+
524
+ run_cmd += run_cmd_sample(
525
+ sample_every_n_steps,
526
+ sample_every_n_epochs,
527
+ sample_sampler,
528
+ sample_prompts,
529
+ output_dir,
530
+ )
531
+
532
+ print(run_cmd)
533
+
534
+ # Run the command
535
+ if os.name == 'posix':
536
+ os.system(run_cmd)
537
+ else:
538
+ subprocess.run(run_cmd)
539
+
540
+ # check if output_dir/output_name is a folder... if so, it is a Diffusers model
541
+ last_dir = pathlib.Path(f'{output_dir}/{output_name}')
542
+
543
+ if not last_dir.is_dir():
544
+ # Copy inference model for v2 if required
545
+ save_inference_file(output_dir, v2, v_parameterization, output_name)
546
+
547
+
548
+ def dreambooth_tab(
549
+ train_data_dir=gr.Textbox(),
550
+ reg_data_dir=gr.Textbox(),
551
+ output_dir=gr.Textbox(),
552
+ logging_dir=gr.Textbox(),
553
+ ):
554
+ dummy_db_true = gr.Label(value=True, visible=False)
555
+ dummy_db_false = gr.Label(value=False, visible=False)
556
+ gr.Markdown('Train a custom model using kohya dreambooth python code...')
557
+ (
558
+ button_open_config,
559
+ button_save_config,
560
+ button_save_as_config,
561
+ config_file_name,
562
+ button_load_config,
563
+ ) = gradio_config()
564
+
565
+ (
566
+ pretrained_model_name_or_path,
567
+ v2,
568
+ v_parameterization,
569
+ save_model_as,
570
+ model_list,
571
+ ) = gradio_source_model()
572
+
573
+ with gr.Tab('Folders'):
574
+ with gr.Row():
575
+ train_data_dir = gr.Textbox(
576
+ label='Image folder',
577
+ placeholder='Folder where the training folders containing the images are located',
578
+ )
579
+ train_data_dir_input_folder = gr.Button(
580
+ '📂', elem_id='open_folder_small'
581
+ )
582
+ train_data_dir_input_folder.click(
583
+ get_folder_path,
584
+ outputs=train_data_dir,
585
+ show_progress=False,
586
+ )
587
+ reg_data_dir = gr.Textbox(
588
+ label='Regularisation folder',
589
+ placeholder='(Optional) Folder where the regularization folders containing the images are located',
590
+ )
591
+ reg_data_dir_input_folder = gr.Button(
592
+ '📂', elem_id='open_folder_small'
593
+ )
594
+ reg_data_dir_input_folder.click(
595
+ get_folder_path,
596
+ outputs=reg_data_dir,
597
+ show_progress=False,
598
+ )
599
+ with gr.Row():
600
+ output_dir = gr.Textbox(
601
+ label='Model output folder',
602
+ placeholder='Folder to output trained model',
603
+ )
604
+ output_dir_input_folder = gr.Button(
605
+ '📂', elem_id='open_folder_small'
606
+ )
607
+ output_dir_input_folder.click(get_folder_path, outputs=output_dir)
608
+ logging_dir = gr.Textbox(
609
+ label='Logging folder',
610
+ placeholder='Optional: enable logging and output TensorBoard log to this folder',
611
+ )
612
+ logging_dir_input_folder = gr.Button(
613
+ '📂', elem_id='open_folder_small'
614
+ )
615
+ logging_dir_input_folder.click(
616
+ get_folder_path,
617
+ outputs=logging_dir,
618
+ show_progress=False,
619
+ )
620
+ with gr.Row():
621
+ output_name = gr.Textbox(
622
+ label='Model output name',
623
+ placeholder='Name of the model to output',
624
+ value='last',
625
+ interactive=True,
626
+ )
627
+ train_data_dir.change(
628
+ remove_doublequote,
629
+ inputs=[train_data_dir],
630
+ outputs=[train_data_dir],
631
+ )
632
+ reg_data_dir.change(
633
+ remove_doublequote,
634
+ inputs=[reg_data_dir],
635
+ outputs=[reg_data_dir],
636
+ )
637
+ output_dir.change(
638
+ remove_doublequote,
639
+ inputs=[output_dir],
640
+ outputs=[output_dir],
641
+ )
642
+ logging_dir.change(
643
+ remove_doublequote,
644
+ inputs=[logging_dir],
645
+ outputs=[logging_dir],
646
+ )
647
+ with gr.Tab('Training parameters'):
648
+ (
649
+ learning_rate,
650
+ lr_scheduler,
651
+ lr_warmup,
652
+ train_batch_size,
653
+ epoch,
654
+ save_every_n_epochs,
655
+ mixed_precision,
656
+ save_precision,
657
+ num_cpu_threads_per_process,
658
+ seed,
659
+ caption_extension,
660
+ cache_latents,
661
+ optimizer,
662
+ optimizer_args,
663
+ ) = gradio_training(
664
+ learning_rate_value='1e-5',
665
+ lr_scheduler_value='cosine',
666
+ lr_warmup_value='10',
667
+ )
668
+ with gr.Row():
669
+ max_resolution = gr.Textbox(
670
+ label='Max resolution',
671
+ value='512,512',
672
+ placeholder='512,512',
673
+ )
674
+ stop_text_encoder_training = gr.Slider(
675
+ minimum=-1,
676
+ maximum=100,
677
+ value=0,
678
+ step=1,
679
+ label='Stop text encoder training',
680
+ )
681
+ enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
682
+ with gr.Accordion('Advanced Configuration', open=False):
683
+ with gr.Row():
684
+ no_token_padding = gr.Checkbox(
685
+ label='No token padding', value=False
686
+ )
687
+ gradient_accumulation_steps = gr.Number(
688
+ label='Gradient accumulate steps', value='1'
689
+ )
690
+ with gr.Row():
691
+ prior_loss_weight = gr.Number(
692
+ label='Prior loss weight', value=1.0
693
+ )
694
+ vae = gr.Textbox(
695
+ label='VAE',
696
+ placeholder='(Optional) path to checkpoint of VAE to replace for training',
697
+ )
698
+ vae_button = gr.Button('📂', elem_id='open_folder_small')
699
+ vae_button.click(
700
+ get_any_file_path,
701
+ outputs=vae,
702
+ show_progress=False,
703
+ )
704
+ (
705
+ # use_8bit_adam,
706
+ xformers,
707
+ full_fp16,
708
+ gradient_checkpointing,
709
+ shuffle_caption,
710
+ color_aug,
711
+ flip_aug,
712
+ clip_skip,
713
+ mem_eff_attn,
714
+ save_state,
715
+ resume,
716
+ max_token_length,
717
+ max_train_epochs,
718
+ max_data_loader_n_workers,
719
+ keep_tokens,
720
+ persistent_data_loader_workers,
721
+ bucket_no_upscale,
722
+ random_crop,
723
+ bucket_reso_steps,
724
+ caption_dropout_every_n_epochs,
725
+ caption_dropout_rate,
726
+ noise_offset,
727
+ additional_parameters,
728
+ vae_batch_size,
729
+ min_snr_gamma,
730
+ ) = gradio_advanced_training()
731
+ color_aug.change(
732
+ color_aug_changed,
733
+ inputs=[color_aug],
734
+ outputs=[cache_latents],
735
+ )
736
+
737
+ (
738
+ sample_every_n_steps,
739
+ sample_every_n_epochs,
740
+ sample_sampler,
741
+ sample_prompts,
742
+ ) = sample_gradio_config()
743
+
744
+ with gr.Tab('Tools'):
745
+ gr.Markdown(
746
+ 'This section provides Dreambooth tools to help set up your dataset...'
747
+ )
748
+ gradio_dreambooth_folder_creation_tab(
749
+ train_data_dir_input=train_data_dir,
750
+ reg_data_dir_input=reg_data_dir,
751
+ output_dir_input=output_dir,
752
+ logging_dir_input=logging_dir,
753
+ )
754
+
755
+ button_run = gr.Button('Train model', variant='primary')
756
+
757
+ # Setup gradio tensorboard buttons
758
+ button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
759
+
760
+ button_start_tensorboard.click(
761
+ start_tensorboard,
762
+ inputs=logging_dir,
763
+ show_progress=False,
764
+ )
765
+
766
+ button_stop_tensorboard.click(
767
+ stop_tensorboard,
768
+ show_progress=False,
769
+ )
770
+
771
+ settings_list = [
772
+ pretrained_model_name_or_path,
773
+ v2,
774
+ v_parameterization,
775
+ logging_dir,
776
+ train_data_dir,
777
+ reg_data_dir,
778
+ output_dir,
779
+ max_resolution,
780
+ learning_rate,
781
+ lr_scheduler,
782
+ lr_warmup,
783
+ train_batch_size,
784
+ epoch,
785
+ save_every_n_epochs,
786
+ mixed_precision,
787
+ save_precision,
788
+ seed,
789
+ num_cpu_threads_per_process,
790
+ cache_latents,
791
+ caption_extension,
792
+ enable_bucket,
793
+ gradient_checkpointing,
794
+ full_fp16,
795
+ no_token_padding,
796
+ stop_text_encoder_training,
797
+ # use_8bit_adam,
798
+ xformers,
799
+ save_model_as,
800
+ shuffle_caption,
801
+ save_state,
802
+ resume,
803
+ prior_loss_weight,
804
+ color_aug,
805
+ flip_aug,
806
+ clip_skip,
807
+ vae,
808
+ output_name,
809
+ max_token_length,
810
+ max_train_epochs,
811
+ max_data_loader_n_workers,
812
+ mem_eff_attn,
813
+ gradient_accumulation_steps,
814
+ model_list,
815
+ keep_tokens,
816
+ persistent_data_loader_workers,
817
+ bucket_no_upscale,
818
+ random_crop,
819
+ bucket_reso_steps,
820
+ caption_dropout_every_n_epochs,
821
+ caption_dropout_rate,
822
+ optimizer,
823
+ optimizer_args,
824
+ noise_offset,
825
+ sample_every_n_steps,
826
+ sample_every_n_epochs,
827
+ sample_sampler,
828
+ sample_prompts,
829
+ additional_parameters,
830
+ vae_batch_size,
831
+ min_snr_gamma,
832
+ ]
833
+
834
+ button_open_config.click(
835
+ open_configuration,
836
+ inputs=[dummy_db_true, config_file_name] + settings_list,
837
+ outputs=[config_file_name] + settings_list,
838
+ show_progress=False,
839
+ )
840
+
841
+ button_load_config.click(
842
+ open_configuration,
843
+ inputs=[dummy_db_false, config_file_name] + settings_list,
844
+ outputs=[config_file_name] + settings_list,
845
+ show_progress=False,
846
+ )
847
+
848
+ button_save_config.click(
849
+ save_configuration,
850
+ inputs=[dummy_db_false, config_file_name] + settings_list,
851
+ outputs=[config_file_name],
852
+ show_progress=False,
853
+ )
854
+
855
+ button_save_as_config.click(
856
+ save_configuration,
857
+ inputs=[dummy_db_true, config_file_name] + settings_list,
858
+ outputs=[config_file_name],
859
+ show_progress=False,
860
+ )
861
+
862
+ button_run.click(
863
+ train_model,
864
+ inputs=settings_list,
865
+ show_progress=False,
866
+ )
867
+
868
+ return (
869
+ train_data_dir,
870
+ reg_data_dir,
871
+ output_dir,
872
+ logging_dir,
873
+ )
874
+
875
+
876
+ def UI(**kwargs):
877
+ css = ''
878
+
879
+ if os.path.exists('./style.css'):
880
+ with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
881
+ print('Load CSS...')
882
+ css += file.read() + '\n'
883
+
884
+ interface = gr.Blocks(css=css)
885
+
886
+ with interface:
887
+ with gr.Tab('Dreambooth'):
888
+ (
889
+ train_data_dir_input,
890
+ reg_data_dir_input,
891
+ output_dir_input,
892
+ logging_dir_input,
893
+ ) = dreambooth_tab()
894
+ with gr.Tab('Utilities'):
895
+ utilities_tab(
896
+ train_data_dir_input=train_data_dir_input,
897
+ reg_data_dir_input=reg_data_dir_input,
898
+ output_dir_input=output_dir_input,
899
+ logging_dir_input=logging_dir_input,
900
+ enable_copy_info_button=True,
901
+ )
902
+
903
+ # Show the interface
904
+ launch_kwargs = {}
905
+ if not kwargs.get('username', None) == '':
906
+ launch_kwargs['auth'] = (
907
+ kwargs.get('username', None),
908
+ kwargs.get('password', None),
909
+ )
910
+ if kwargs.get('server_port', 0) > 0:
911
+ launch_kwargs['server_port'] = kwargs.get('server_port', 0)
912
+ if kwargs.get('inbrowser', False):
913
+ launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
914
+ print(launch_kwargs)
915
+ interface.launch(**launch_kwargs)
916
+
917
+
918
+ if __name__ == '__main__':
919
+ # torch.cuda.set_per_process_memory_fraction(0.48)
920
+ parser = argparse.ArgumentParser()
921
+ parser.add_argument(
922
+ '--username', type=str, default='', help='Username for authentication'
923
+ )
924
+ parser.add_argument(
925
+ '--password', type=str, default='', help='Password for authentication'
926
+ )
927
+ parser.add_argument(
928
+ '--server_port',
929
+ type=int,
930
+ default=0,
931
+ help='Port to run the server listener on',
932
+ )
933
+ parser.add_argument(
934
+ '--inbrowser', action='store_true', help='Open in browser'
935
+ )
936
+
937
+ args = parser.parse_args()
938
+
939
+ UI(
940
+ username=args.username,
941
+ password=args.password,
942
+ inbrowser=args.inbrowser,
943
+ server_port=args.server_port,
944
+ )
fine_tune.py ADDED
@@ -0,0 +1,430 @@
1
+ # training with captions
2
+ # XXX dropped option: hypernetwork training
3
+
4
+ import argparse
5
+ import gc
6
+ import math
7
+ import os
8
+ import toml
9
+ from multiprocessing import Value
10
+
11
+ from tqdm import tqdm
12
+ import torch
13
+ from accelerate.utils import set_seed
14
+ import diffusers
15
+ from diffusers import DDPMScheduler
16
+
17
+ import library.train_util as train_util
18
+ import library.config_util as config_util
19
+ from library.config_util import (
20
+ ConfigSanitizer,
21
+ BlueprintGenerator,
22
+ )
23
+ import library.custom_train_functions as custom_train_functions
24
+ from library.custom_train_functions import apply_snr_weight
25
+
26
+
27
+ def train(args):
28
+ train_util.verify_training_args(args)
29
+ train_util.prepare_dataset_args(args, True)
30
+
31
+ cache_latents = args.cache_latents
32
+
33
+ if args.seed is not None:
34
+ set_seed(args.seed) # 乱数系列を初期化する
35
+
36
+ tokenizer = train_util.load_tokenizer(args)
37
+
38
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
39
+ if args.dataset_config is not None:
40
+ print(f"Load dataset config from {args.dataset_config}")
41
+ user_config = config_util.load_user_config(args.dataset_config)
42
+ ignored = ["train_data_dir", "in_json"]
43
+ if any(getattr(args, attr) is not None for attr in ignored):
44
+ print(
45
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
46
+ ", ".join(ignored)
47
+ )
48
+ )
49
+ else:
50
+ user_config = {
51
+ "datasets": [
52
+ {
53
+ "subsets": [
54
+ {
55
+ "image_dir": args.train_data_dir,
56
+ "metadata_file": args.in_json,
57
+ }
58
+ ]
59
+ }
60
+ ]
61
+ }
62
+
63
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
64
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
65
+
66
+ current_epoch = Value("i", 0)
67
+ current_step = Value("i", 0)
68
+ ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
69
+ collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
70
+
71
+ if args.debug_dataset:
72
+ train_util.debug_dataset(train_dataset_group)
73
+ return
74
+ if len(train_dataset_group) == 0:
75
+ print(
76
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
77
+ )
78
+ return
79
+
80
+ if cache_latents:
81
+ assert (
82
+ train_dataset_group.is_latent_cacheable()
83
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
84
+
85
+ # acceleratorを準備する
86
+ print("prepare accelerator")
87
+ accelerator, unwrap_model = train_util.prepare_accelerator(args)
88
+
89
+ # mixed precisionに対応した型を用意しておき適宜castする
90
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
91
+
92
+ # モデルを読み込む
93
+ text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
94
+
95
+ # verify load/save model formats
96
+ if load_stable_diffusion_format:
97
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
98
+ src_diffusers_model_path = None
99
+ else:
100
+ src_stable_diffusion_ckpt = None
101
+ src_diffusers_model_path = args.pretrained_model_name_or_path
102
+
103
+ if args.save_model_as is None:
104
+ save_stable_diffusion_format = load_stable_diffusion_format
105
+ use_safetensors = args.use_safetensors
106
+ else:
107
+ save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
108
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
109
+
110
+ # Diffusers版のxformers使用フラグを設定する関数
111
+ def set_diffusers_xformers_flag(model, valid):
112
+ # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
113
+ # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
114
+ # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
115
+ # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
116
+
117
+ # Recursively walk through all the children.
118
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
119
+ # gets the message
120
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
121
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
122
+ module.set_use_memory_efficient_attention_xformers(valid)
123
+
124
+ for child in module.children():
125
+ fn_recursive_set_mem_eff(child)
126
+
127
+ fn_recursive_set_mem_eff(model)
128
+
129
+ # モデルに xformers とか memory efficient attention を組み込む
130
+ if args.diffusers_xformers:
131
+ print("Use xformers by Diffusers")
132
+ set_diffusers_xformers_flag(unet, True)
133
+ else:
134
+ # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
135
+ print("Disable Diffusers' xformers")
136
+ set_diffusers_xformers_flag(unet, False)
137
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
138
+
139
+ # 学習を準備する
140
+ if cache_latents:
141
+ vae.to(accelerator.device, dtype=weight_dtype)
142
+ vae.requires_grad_(False)
143
+ vae.eval()
144
+ with torch.no_grad():
145
+ train_dataset_group.cache_latents(vae, args.vae_batch_size)
146
+ vae.to("cpu")
147
+ if torch.cuda.is_available():
148
+ torch.cuda.empty_cache()
149
+ gc.collect()
150
+
151
+ # 学習を準備する:モデルを適切な状態にする
152
+ training_models = []
153
+ if args.gradient_checkpointing:
154
+ unet.enable_gradient_checkpointing()
155
+ training_models.append(unet)
156
+
157
+ if args.train_text_encoder:
158
+ print("enable text encoder training")
159
+ if args.gradient_checkpointing:
160
+ text_encoder.gradient_checkpointing_enable()
161
+ training_models.append(text_encoder)
162
+ else:
163
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
164
+ text_encoder.requires_grad_(False) # text encoderは学習しない
165
+ if args.gradient_checkpointing:
166
+ text_encoder.gradient_checkpointing_enable()
167
+ text_encoder.train() # required for gradient_checkpointing
168
+ else:
169
+ text_encoder.eval()
170
+
171
+ if not cache_latents:
172
+ vae.requires_grad_(False)
173
+ vae.eval()
174
+ vae.to(accelerator.device, dtype=weight_dtype)
175
+
176
+ for m in training_models:
177
+ m.requires_grad_(True)
178
+ params = []
179
+ for m in training_models:
180
+ params.extend(m.parameters())
181
+ params_to_optimize = params
182
+
183
+ # 学習に必要なクラスを準備する
184
+ print("prepare optimizer, data loader etc.")
185
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
186
+
187
+ # dataloaderを準備する
188
+ # DataLoaderのプロセス数:0はメインプロセスになる
189
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
190
+ train_dataloader = torch.utils.data.DataLoader(
191
+ train_dataset_group,
192
+ batch_size=1,
193
+ shuffle=True,
194
+ collate_fn=collater,
195
+ num_workers=n_workers,
196
+ persistent_workers=args.persistent_data_loader_workers,
197
+ )
198
+
199
+ # 学習ステップ数を計算する
200
+ if args.max_train_epochs is not None:
201
+ args.max_train_steps = args.max_train_epochs * math.ceil(
202
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
203
+ )
204
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
205
+
206
+ # データセット側にも学習ステップを送信
207
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
208
+
209
+ # lr schedulerを用意する
210
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
211
+
212
+ # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
213
+ if args.full_fp16:
214
+ assert (
215
+ args.mixed_precision == "fp16"
216
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
217
+ print("enable full fp16 training.")
218
+ unet.to(weight_dtype)
219
+ text_encoder.to(weight_dtype)
220
+
221
+ # acceleratorがなんかよろしくやってくれるらしい
222
+ if args.train_text_encoder:
223
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
224
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
225
+ )
226
+ else:
227
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
228
+
229
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
230
+ if args.full_fp16:
231
+ train_util.patch_accelerator_for_fp16_training(accelerator)
232
+
233
+ # resumeする
234
+ if args.resume is not None:
235
+ print(f"resume training from state: {args.resume}")
236
+ accelerator.load_state(args.resume)
237
+
238
+ # epoch数を計算する
239
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
240
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
241
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
242
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
243
+
244
+ # 学習する
245
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
246
+ print("running training / 学習開始")
247
+ print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
248
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
249
+ print(f" num epochs / epoch数: {num_train_epochs}")
250
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
251
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
252
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
253
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
254
+
255
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
256
+ global_step = 0
257
+
258
+ noise_scheduler = DDPMScheduler(
259
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
260
+ )
261
+
262
+ if accelerator.is_main_process:
263
+ accelerator.init_trackers("finetuning")
264
+
265
+ for epoch in range(num_train_epochs):
266
+ print(f"epoch {epoch+1}/{num_train_epochs}")
267
+ current_epoch.value = epoch + 1
268
+
269
+ for m in training_models:
270
+ m.train()
271
+
272
+ loss_total = 0
273
+ for step, batch in enumerate(train_dataloader):
274
+ current_step.value = global_step
275
+ with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
276
+ with torch.no_grad():
277
+ if "latents" in batch and batch["latents"] is not None:
278
+ latents = batch["latents"].to(accelerator.device)
279
+ else:
280
+ # latentに変換
281
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
282
+ latents = latents * 0.18215
283
+ b_size = latents.shape[0]
284
+
285
+ with torch.set_grad_enabled(args.train_text_encoder):
286
+ # Get the text embedding for conditioning
287
+ input_ids = batch["input_ids"].to(accelerator.device)
288
+ encoder_hidden_states = train_util.get_hidden_states(
289
+ args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
290
+ )
291
+
292
+ # Sample noise that we'll add to the latents
293
+ noise = torch.randn_like(latents, device=latents.device)
294
+ if args.noise_offset:
295
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
296
+ noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
297
+
298
+ # Sample a random timestep for each image
299
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
300
+ timesteps = timesteps.long()
301
+
302
+ # Add noise to the latents according to the noise magnitude at each timestep
303
+ # (this is the forward diffusion process)
304
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
305
+
306
+ # Predict the noise residual
307
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
308
+
309
+ if args.v_parameterization:
310
+ # v-parameterization training
311
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
312
+ else:
313
+ target = noise
314
+
315
+ if args.min_snr_gamma:
316
+ # do not mean over batch dimension for snr weight
317
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
318
+ loss = loss.mean([1, 2, 3])
319
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
320
+ loss = loss.mean() # mean over batch dimension
321
+ else:
322
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
323
+
324
+ accelerator.backward(loss)
325
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
326
+ params_to_clip = []
327
+ for m in training_models:
328
+ params_to_clip.extend(m.parameters())
329
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
330
+
331
+ optimizer.step()
332
+ lr_scheduler.step()
333
+ optimizer.zero_grad(set_to_none=True)
334
+
335
+ # Checks if the accelerator has performed an optimization step behind the scenes
336
+ if accelerator.sync_gradients:
337
+ progress_bar.update(1)
338
+ global_step += 1
339
+
340
+ train_util.sample_images(
341
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
342
+ )
343
+
344
+ current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
345
+ if args.logging_dir is not None:
346
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
347
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
348
+ logs["lr/d*lr"] = (
349
+ lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
350
+ )
351
+ accelerator.log(logs, step=global_step)
352
+
353
+ # TODO moving averageにする
354
+ loss_total += current_loss
355
+ avr_loss = loss_total / (step + 1)
356
+ logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
357
+ progress_bar.set_postfix(**logs)
358
+
359
+ if global_step >= args.max_train_steps:
360
+ break
361
+
362
+ if args.logging_dir is not None:
363
+ logs = {"loss/epoch": loss_total / len(train_dataloader)}
364
+ accelerator.log(logs, step=epoch + 1)
365
+
366
+ accelerator.wait_for_everyone()
367
+
368
+ if args.save_every_n_epochs is not None:
369
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
370
+ train_util.save_sd_model_on_epoch_end(
371
+ args,
372
+ accelerator,
373
+ src_path,
374
+ save_stable_diffusion_format,
375
+ use_safetensors,
376
+ save_dtype,
377
+ epoch,
378
+ num_train_epochs,
379
+ global_step,
380
+ unwrap_model(text_encoder),
381
+ unwrap_model(unet),
382
+ vae,
383
+ )
384
+
385
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
386
+
387
+ is_main_process = accelerator.is_main_process
388
+ if is_main_process:
389
+ unet = unwrap_model(unet)
390
+ text_encoder = unwrap_model(text_encoder)
391
+
392
+ accelerator.end_training()
393
+
394
+ if args.save_state:
395
+ train_util.save_state_on_train_end(args, accelerator)
396
+
397
+ del accelerator # この後メモリを使うのでこれは消す
398
+
399
+ if is_main_process:
400
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
401
+ train_util.save_sd_model_on_train_end(
402
+ args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
403
+ )
404
+ print("model saved.")
405
+
406
+
407
+ def setup_parser() -> argparse.ArgumentParser:
408
+ parser = argparse.ArgumentParser()
409
+
410
+ train_util.add_sd_models_arguments(parser)
411
+ train_util.add_dataset_arguments(parser, False, True, True)
412
+ train_util.add_training_arguments(parser, False)
413
+ train_util.add_sd_saving_arguments(parser)
414
+ train_util.add_optimizer_arguments(parser)
415
+ config_util.add_config_arguments(parser)
416
+ custom_train_functions.add_custom_train_arguments(parser)
417
+
418
+ parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
419
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
420
+
421
+ return parser
422
+
423
+
424
+ if __name__ == "__main__":
425
+ parser = setup_parser()
426
+
427
+ args = parser.parse_args()
428
+ args = train_util.read_config_from_file(args, parser)
429
+
430
+ train(args)
fine_tune_README.md ADDED
@@ -0,0 +1,465 @@
1
+ This is a fine-tuning implementation that supports NovelAI's proposed training method, automatic captioning, automatic tagging, a Windows + 12GB VRAM environment (for SD v1.4/1.5), and more.
2
+
3
+ ## Overview
4
+ Fine tuning of Stable Diffusion's U-Net using Diffusers. It supports the following improvements from NovelAI's article (for Aspect Ratio Bucketing I referred to NovelAI's code, but the final code is entirely original).
5
+
6
+ * Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder).
7
+ * Learning at non-square resolutions (Aspect Ratio Bucketing).
8
+ * Extend token length from 75 to 225.
9
+ * Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger.
10
+ * Also supports Hypernetwork learning.
11
+ * Supports Stable Diffusion v2.0 (base and 768/v).
12
+ * By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning.
13
+
14
+ The Text Encoder is not trained by default. When fine tuning the whole model, it seems common to train only the U-Net (NovelAI appears to do the same). The Text Encoder can also be trained as an option.
15
+
16
+ ## Additional features
17
+ ### Change CLIP output
18
+ CLIP (Text Encoder) converts the text into features in order to reflect the prompt in the image. Stable diffusion uses the output of the last layer of CLIP, but you can change it to use the output of the penultimate layer. According to NovelAI, this will reflect prompts more accurately.
19
+ It is also possible to use the output of the last layer as is.
20
+ *Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option.
21
+
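
As a rough illustration of the idea, this is how the penultimate-layer output can be obtained with Hugging Face transformers. This is a minimal sketch for illustration only; the repository's own handling (train_util.get_hidden_states) may differ in detail, and the model ID is just an example.

```python
# Minimal sketch: using the penultimate CLIP text-encoder layer (clip_skip=2).
import torch
from transformers import CLIPTokenizer, CLIPTextModel

model_id = "openai/clip-vit-large-patch14"  # example; the SD v1.x text encoder
tokenizer = CLIPTokenizer.from_pretrained(model_id)
text_encoder = CLIPTextModel.from_pretrained(model_id)

tokens = tokenizer("a photo of a cat", padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

last_layer = out.last_hidden_state        # default behaviour: final layer output
penultimate = out.hidden_states[-2]       # penultimate layer output
# The penultimate output is typically passed through the final layer norm before use.
penultimate = text_encoder.text_model.final_layer_norm(penultimate)
```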
22
+ ### Training in non-square resolutions
23
+ Stable Diffusion is trained at 512\*512, but here training is also performed at resolutions such as 256\*1024 and 384\*640. This is expected to reduce the amount of cropping and to learn the relationship between prompts and images more correctly.
24
+ The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter.
25
+
26
+ In machine learning it is common to unify all input sizes, but there is no strict requirement: it is enough that sizes are unified within each batch. NovelAI's bucketing appears to mean classifying the training data into buckets by training resolution (according to aspect ratio) in advance; by building each batch only from images in the same bucket, the image size within a batch is unified. See the sketch below for the idea.
27
+
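
As a rough sketch of that idea (not the repository's actual bucketing code, which also handles minimum/maximum sizes and the assignment of images to buckets), the candidate resolutions can be enumerated like this:

```python
# Enumerate candidate bucket resolutions in 64-pixel steps whose pixel area
# does not exceed the area of the base training resolution (e.g. 512x512).
def make_buckets(base=512, step=64, min_size=256, max_size=1024):
    max_area = base * base
    buckets = set()
    width = min_size
    while width <= max_size:
        height = min(max_size, (max_area // width) // step * step)
        if height >= min_size:
            buckets.add((width, height))
            buckets.add((height, width))  # also add the rotated aspect ratio
        width += step
    return sorted(buckets)

print(make_buckets())  # includes (256, 1024), (320, 768), (384, 640), (512, 512), ...
```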
28
+ ### Extending token length from 75 to 225
29
+ Stable Diffusion accepts a maximum of 75 tokens (77 including the start and end tokens), but here this is extended to 225 tokens.
30
+ However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens, we simply divide it into thirds, call CLIP, and then concatenate the results.
31
+
32
+ *I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently.
33
+
34
+ *Automatic1111's Web UI seems to divide the text with commas in mind, but in my case, it's a simple division.
35
+
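
Conceptually, the long prompt is wrapped into several 77-token CLIP calls (each with its own start/end tokens) and the resulting hidden states are concatenated along the sequence axis. The sketch below is a simplified illustration of that idea only; the repository's real handling (in library/train_util.py) also deals with padding, clip_skip, and weighting.

```python
# Sketch: encode a long prompt by splitting it into 75-token chunks,
# running CLIP on each chunk within its 77-token limit, and concatenating.
import torch

def encode_long_prompt(tokenizer, text_encoder, prompt, chunk_len=75):
    ids = tokenizer(prompt, add_special_tokens=False).input_ids
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
    chunks = [ids[i:i + chunk_len] for i in range(0, max(len(ids), 1), chunk_len)]
    states = []
    for chunk in chunks:
        chunk = [bos] + chunk + [eos] * (chunk_len + 1 - len(chunk))  # pad to 77 tokens
        with torch.no_grad():
            states.append(text_encoder(torch.tensor([chunk])).last_hidden_state)
    return torch.cat(states, dim=1)  # shape: (1, n_chunks * 77, hidden_dim)
```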
36
+ ## Environment setup
37
+
38
+ See the [README](./README-en.md) in this repository.
39
+
40
+ ## Preparing teacher data
41
+
42
+ Prepare the image data you want to learn and put it in any folder. No prior preparation such as resizing is required.
43
+ However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining the quality using super-resolution.
44
+
45
+ It also supports multiple teacher data folders. Preprocessing will be executed for each folder.
46
+
47
+ For example, store an image like this:
48
+
49
+ ![Teacher data folder screenshot](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
50
+
51
+ ## Automatic captioning
52
+ Skip if you just want to learn tags without captions.
53
+
54
+ Also, when preparing captions manually, place them in the same directory as the teacher data images, with the same file name and an extension such as .caption. Each file should be a text file containing a single line.
55
+
56
+ ### Captioning with BLIP
57
+
58
+ The latest version no longer requires BLIP downloads, weight downloads, and additional virtual environments. Works as-is.
59
+
60
+ Run make_captions.py in the finetune folder.
61
+
62
+ ```
63
+ python finetune\make_captions.py --batch_size <batch size> <teacher data folder>
64
+ ```
65
+
66
+ If the batch size is 8 and the training data is placed in the parent folder train_data, it will be as follows.
67
+
68
+ ```
69
+ python finetune\make_captions.py --batch_size 8 ..\train_data
70
+ ```
71
+
72
+ A caption file is created in the same directory as the teacher data image with the same file name and extension .caption.
73
+
74
+ Increase or decrease batch_size according to the GPU's VRAM capacity. Bigger is faster (with 12GB of VRAM it can probably be raised a little further).
75
+ You can specify the maximum length of the caption with the max_length option. Default is 75. It may be longer if the model is trained with a token length of 225.
76
+ You can change the caption extension with the caption_extension option. Default is .caption (.txt conflicts with DeepDanbooru described later).
77
+
78
+ If there are multiple teacher data folders, execute for each folder.
79
+
80
+ Note that the inference is random, so the results will change each time you run it. If you want to fix it, specify a random number seed like "--seed 42" with the --seed option.
81
+
82
+ For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source).
83
+
84
+ A caption file is generated with the extension .caption by default.
85
+
86
+ ![Folder where caption is generated](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
87
+
88
+ For example, with captions like:
89
+
90
+ ![captions and images](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
91
+
92
+ ## Tagging with DeepDanbooru
93
+ If you do not want to use danbooru tags themselves, please skip ahead to "Preprocessing caption and tag information".
94
+
95
+ Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter.
96
+
97
+ ### Environment setup
98
+ Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it.
99
+ Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder.
100
+
101
+ Download from below. Click to open Assets and download from there.
102
+
103
+ ![DeepDanbooru download page](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png)
104
+
105
+ Make a directory structure like this
106
+
107
+ ![DeepDanbooru directory structure](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
108
+
109
+ Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io).
110
+
111
+ ```
112
+ pip install -r requirements.txt
113
+ ```
114
+
115
+ Next, install DeepDanbooru itself.
116
+
117
+ ```
118
+ pip install .
119
+ ```
120
+
121
+ This completes the preparation of the environment for tagging.
122
+
123
+ ### Performing tagging
124
+ Go to DeepDanbooru's folder and run deepdanbooru to tag.
125
+
126
+ ```
127
+ deepdanbooru evaluate <teacher data folder> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
128
+ ```
129
+
130
+ If you put the training data in the parent folder train_data, it will be as follows.
131
+
132
+ ```
133
+ deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
134
+ ```
135
+
136
+ A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. It is slow because it is processed one by one.
137
+
138
+ If there are multiple teacher data folders, execute for each folder.
139
+
140
+ It is generated as follows.
141
+
142
+ ![DeepDanbooru generated files](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png)
143
+
144
+ A tag is attached like this (great amount of information...).
145
+
146
+ ![Deep Danbooru tag and image](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png)
147
+
148
+ ## Tagging with WD14Tagger
149
+ This procedure uses WD14Tagger instead of DeepDanbooru.
150
+
151
+ Use the tagger used in Automatic1111's WebUI. I referred to the information on this GitHub page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
152
+
153
+ The required modules were already installed during the initial environment setup. Weights are downloaded automatically from Hugging Face.
154
+
155
+ ### Performing tagging
156
+ Run the script to do the tagging.
157
+ ```
158
+ python tag_images_by_wd14_tagger.py --batch_size <batch size> <teacher data folder>
159
+ ```
160
+
161
+ If you put the training data in the parent folder train_data, it will be as follows.
162
+ ```
163
+ python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
164
+ ```
165
+
166
+ The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows.
167
+
168
+ ![downloaded file](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png)
169
+
170
+ A tag file is created in the same directory as the teacher data image with the same file name and extension .txt.
171
+
172
+ ![generated tag file](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
173
+
174
+ ![tags and images](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
175
+
176
+ With the thresh option, you can specify the confidence threshold a tag must reach to be attached. The default is 0.35, the same as in the WD14Tagger sample. Lower values attach more tags but with lower accuracy.
177
+ Increase or decrease batch_size according to the GPU's VRAM capacity. Bigger is faster (with 12GB of VRAM it can probably be raised a little further). You can change the tag file extension with the caption_extension option. The default is .txt.
178
+ You can specify the folder where the model is saved with the model_dir option.
179
+ Also, if you specify the force_download option, the model will be re-downloaded even if there is a save destination folder.
180
+
181
+ If there are multiple teacher data folders, execute for each folder.
182
+
183
+ ## Preprocessing caption and tag information
184
+
185
+ Combine captions and tags into a single file as metadata for easy processing from scripts.
186
+
187
+ ### Caption preprocessing
188
+
189
+ To put captions into the metadata, run the following in your working folder (if you don't use captions for training, you don't need to run this). The command is actually entered on a single line; the same applies to the commands below.
190
+
191
+ ```
192
+ python merge_captions_to_metadata.py <teacher data folder>
193
+ --in_json <metadata file name to read>
194
+ <metadata file name>
195
+ ```
196
+
197
+ The metadata file name is an arbitrary name.
198
+ If the training data is train_data, there is no metadata file to read, and the metadata file is meta_cap.json, it will be as follows.
199
+
200
+ ```
201
+ python merge_captions_to_metadata.py train_data meta_cap.json
202
+ ```
203
+
204
+ You can specify the caption extension with the caption_extension option.
205
+
206
+ If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder.
207
+
208
+ ```
209
+ python merge_captions_to_metadata.py --full_path
210
+ train_data1 meta_cap1.json
211
+ python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
212
+ train_data2 meta_cap2.json
213
+ ```
214
+
215
+ If in_json is omitted and the destination metadata file already exists, it will be read from there and overwritten in place.
216
+
217
+ __*It is safer to change the in_json option and the write destination each time, writing to a separate metadata file.__
218
+
219
+ ### Tag preprocessing
220
+
221
+ Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning).
222
+ ```
223
+ python merge_dd_tags_to_metadata.py <teacher data folder>
224
+ --in_json <metadata file name to load>
225
+ <metadata file name to write>
226
+ ```
227
+
228
+ With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows.
229
+ ```
230
+ python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json
231
+ ```
232
+
233
+ If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
234
+
235
+ ```
236
+ python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
237
+ train_data1 meta_cap_dd1.json
238
+ python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
239
+ train_data2 meta_cap_dd2.json
240
+ ```
241
+
242
+ If in_json is omitted and the destination metadata file already exists, it will be read from there and overwritten in place.
243
+
244
+ __*It is safer to change the in_json option and the write destination each time, writing to a separate metadata file.__
245
+
246
+ ### Cleaning captions and tags
247
+ Up to this point, captions and DeepDanbooru tags have been collected into the metadata file. However, automatically generated captions suffer from spelling variations (*), and tags may contain underscores and ratings (in the case of DeepDanbooru), so you should clean the captions and tags using your editor's search-and-replace function or similar.
248
+
249
+ *For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl".
250
+
251
+ A script for cleaning is provided, so please edit the contents of the script according to the situation and use it.
252
+
253
+ (It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.)
254
+
255
+ ```
256
+ python clean_captions_and_tags.py <metadata file name to read> <metadata file name to write>
257
+ ```
258
+
259
+ Please note that --in_json is not included. For example:
260
+
261
+ ```
262
+ python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
263
+ ```
264
+
265
+ Preprocessing of captions and tags is now complete.
266
+
267
+ ## Get latents in advance
268
+
269
+ In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed.
270
+
271
+ In your working folder, type:
272
+ ```
273
+ python prepare_buckets_latents.py <teacher data folder>
274
+ <metadata file name to read> <metadata file name to write>
275
+ <model name or checkpoint for fine tuning>
276
+ --batch_size <batch size>
277
+ --max_resolution <resolution width, height>
278
+ --mixed_precision <precision>
279
+ ```
280
+
281
+ If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json:
282
+
283
+ ```
284
+ python prepare_buckets_latents.py
285
+ train_data meta_clean.json meta_lat.json model.ckpt
286
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
287
+ ```
288
+
289
+ Latents are saved in numpy npz format in the teacher data folder.
290
+
291
+ Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required).
292
+
293
+ You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will not use resolutions such as 256\*1024 or 320\*768.
294
+ If you increase the resolution to something like 768\*768, you should specify something like 1280 for the maximum size.
295
+
296
+ If you specify the --flip_aug option, it will perform horizontal flip augmentation (data augmentation). You can artificially double the amount of data, but if you specify it when the data is not left-right symmetrical (for example, character appearance, hairstyle, etc.), learning will not go well.
297
+ (This is a simple implementation that also computes latents for the flipped image and saves them as \*\_flip.npz files. No extra options are needed for fine_tune.py; if a file with \_flip exists, the files with and without \_flip are loaded at random.)
298
+
299
+ The batch size may be increased a little more even with 12GB of VRAM.
300
+ The resolution must be a number divisible by 64 and is specified as "width,height". The resolution is directly linked to memory usage during fine tuning. 512,512 seems to be the limit with 12GB of VRAM (*). With 16GB it may be possible to raise it to 512,704 or 512,768. Even at 256,256 and so on, 8GB of VRAM seems difficult (because the parameters and the optimizer require a certain amount of memory regardless of resolution).
301
+
302
+ *There was also a report that learning batch size 1 worked with 12GB VRAM and 640,640.
303
+
304
+ The result of bucketing is displayed as follows.
305
+
306
+ ![bucketing result](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
307
+
308
+ If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
309
+ ```
310
+ python prepare_buckets_latents.py --full_path
311
+ train_data1 meta_clean.json meta_lat1.json model.ckpt
312
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
313
+
314
+ python prepare_buckets_latents.py --full_path
315
+ train_data2 meta_lat1.json meta_lat2.json model.ckpt
316
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
317
+
318
+ ```
319
+ It is possible to make the read source and write destination the same, but separate is safer.
320
+
321
+ __*It is safer to change the arguments each time and write to a separate metadata file.__
322
+
323
+
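
What gets written next to each image is essentially the VAE-encoded (and 0.18215-scaled) latent of the bucket-resized image, stored as an .npz file. Below is a rough sketch of that idea only, assuming a Diffusers AutoencoderKL, an already-resized RGB image, and example model ID/paths; the actual file layout and key names are defined by prepare_buckets_latents.py and may differ.

```python
# Sketch: encode one (already bucket-resized) image to a latent and save it as .npz.
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")  # example ID
vae.eval()

image = Image.open("train_data/example.png").convert("RGB")   # assumed path; H and W multiples of 64
x = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0   # scale pixels to [-1, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                           # (1, 3, H, W)

with torch.no_grad():
    latent = vae.encode(x).latent_dist.sample() * 0.18215     # same scaling used at training time

np.savez("train_data/example.npz", latents=latent[0].numpy())  # key name illustrative
```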
324
+ ## Run training
325
+ For example: Below are the settings for saving memory.
326
+ ```
327
+ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
328
+ --pretrained_model_name_or_path=model.ckpt
329
+ --in_json meta_lat.json
330
+ --train_data_dir=train_data
331
+ --output_dir=fine_tuned
332
+ --shuffle_caption
333
+ --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
334
+ --use_8bit_adam --xformers --gradient_checkpointing
335
+ --mixed_precision=bf16
336
+ --save_every_n_epochs=4
337
+ ```
338
+
339
+ It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process of accelerate.
340
+
341
+ Specify the model to be trained in pretrained_model_name_or_path (Stable Diffusion checkpoint or Diffusers model). Stable Diffusion checkpoint supports .ckpt and .safetensors (automatically determined by extension).
342
+
343
+ Specify the metadata file produced when caching latents with in_json.
344
+
345
+ Specify the training data folder for train_data_dir and the output destination folder for the trained model for output_dir.
346
+
347
+ If shuffle_caption is specified, captions and tags are shuffled and learned in units separated by commas (this is the method used in Waifu Diffusion v1.3).
348
+ (You can keep some of the leading tokens fixed without shuffling. See keep_tokens for other options.)
349
+
350
+ Specify the batch size in train_batch_size. Specify 1 or 2 for VRAM 12GB. The number that can be specified also changes depending on the resolution.
351
+ The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly.
352
+
353
+ Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6.
354
+ Specify the number of steps in max_train_steps.
355
+
356
+ Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease.
357
+
358
+ Specifying xformers replaces CrossAttention to save memory and speed up.
359
+ * As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers).
360
+
361
+ Specify gradient_checkpointing to enable gradient checkpointing (intermediate activations are recomputed instead of stored). It is slower but uses less memory.
362
+
363
+ Specifies whether to use mixed precision with mixed_precision. Specifying "fp16" or "bf16" saves memory, but accuracy is inferior.
364
+ "fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried).
365
+ If "no" is specified, it will not be used (it will be float32).
366
+
367
+ * It seems that an error occurs when loading checkpoints trained with bf16 in AUTOMATIC1111's Web UI. This appears to be because the bfloat16 data type causes an error in the Web UI's model safety checker. Save in fp16 or float32 format with the save_precision option, or save in safetensors format.
368
+
369
+ Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed.
370
+
371
+ ### Supports Stable Diffusion 2.0
372
+ Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt.
373
+
374
+ ### Increase accuracy and speed when memory is available
375
+ First, removing gradient_checkpointing will speed things up. However, the batch size that can be set becomes smaller, so adjust it while weighing accuracy against speed.
376
+
377
+ Increasing the batch size increases speed and accuracy. Increase it while checking the speed per sample within the range where memory is sufficient (speed may actually drop when memory is at its limit).
378
+
379
+ ### Change CLIP output used
380
+ Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 is specified or the option is omitted, the last layer is used.
381
+ The learned model should be able to be inferred by Automatic1111's web UI.
382
+
383
+ *SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0.
384
+
385
+ If the model being trained was originally trained to use the second layer, 2 is a good value.
386
+
387
+ If you were using the last layer instead, the entire model would have been trained on that assumption. Therefore, if you train again using the second layer, you may need a certain number of teacher data and longer learning to obtain the desired learning result.
388
+
389
+ ### Extending Token Length
390
+ You can train with an extended token length by specifying 150 or 225 for max_token_length.
391
+ The learned model should be able to be inferred by Automatic1111's web UI.
392
+
393
+ As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time.
394
+
395
+ ### Save learning log
396
+ Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved.
397
+
398
+ For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder.
399
+ Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=fine_tune_style1" for identification.
400
+
401
+ To check the logs with TensorBoard, open another command prompt, move to the working folder, and enter the following (TensorBoard should have been installed together with Diffusers; if it is not installed, add it with "pip install tensorboard").
402
+ ```
403
+ tensorboard --logdir=logs
404
+ ```
405
+
406
+ ### Learning Hypernetworks
407
+ It will be explained in another article.
408
+
409
+ ### Learning with fp16 gradient (experimental feature)
410
+ The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). As a result, it seems that the SD1.x 512*512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512*512 size can be learned with a VRAM usage of less than 12GB.
411
+
412
+ Specify fp16 in advance in accelerate config and optionally set mixed_precision="fp16" (does not work with bf16).
413
+
414
+ To minimize memory usage, use the xformers, use_8bit_adam, gradient_checkpointing options and set train_batch_size to 1.
415
+ (If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.)
416
+
417
+ This is achieved by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably, and the probability that training fails partway through also increases. The learning rate and step count settings seem to be sensitive. Be aware of these issues and use it at your own risk.
418
+
419
+ ### Other Options
420
+
421
+ #### keep_tokens
422
+ If a number is specified, the specified number of tokens (comma-separated strings) from the beginning of the caption are fixed without being shuffled.
423
+
424
+ If there are both captions and tags, the prompts during training are concatenated like "caption, tag 1, tag 2 ...", so if you set "--keep_tokens=1", the caption will always come first during training.
425
+
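
A tiny sketch of what keep_tokens means for caption shuffling, assuming a comma-separated caption (illustrative only):

```python
# Sketch: shuffle a comma-separated caption while keeping the first
# `keep_tokens` chunks fixed at the start.
import random

def shuffle_caption(caption: str, keep_tokens: int = 1) -> str:
    parts = [p.strip() for p in caption.split(",")]
    fixed, rest = parts[:keep_tokens], parts[keep_tokens:]
    random.shuffle(rest)
    return ", ".join(fixed + rest)

print(shuffle_caption("a photo of a girl, long hair, smile, outdoors", keep_tokens=1))
# The caption stays first; the remaining tags appear in random order.
```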
426
+ #### dataset_repeats
427
+ If the dataset is extremely small, each epoch ends very quickly (and epoch boundaries take some time), so specify a number to repeat the data that many times and make each epoch longer.
428
+
429
+ #### train_text_encoder
430
+ The Text Encoder is also trained. Memory usage increases slightly.
431
+
432
+ In normal fine tuning, the Text Encoder is not trained (probably because the U-Net is trained to follow the Text Encoder's output), but when the amount of training data is small, training the Text Encoder as in DreamBooth also appears to be effective.
433
+
434
+ #### save_precision
435
+ The data format used when saving checkpoints can be chosen from float, fp16, and bf16 (if not specified, the same format as during training is used). fp16 and bf16 save disk space, but the model will produce slightly different results. If you specify float or fp16, the checkpoint should be loadable in Automatic1111's Web UI.
436
+
437
+ *The VAE keeps the data format of the original checkpoint, so the model size may not shrink to just over 2GB even with fp16.
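+
+ For example, to save checkpoints as fp16 safetensors (a sketch; save_model_as is described in the next section):
+
+ ```
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
+ <your usual options>
+ --save_precision=fp16 --save_model_as=safetensors
+ ```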
438
+
439
+ #### save_model_as
440
+ Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors.
441
+
442
+ When a Stable Diffusion format model (ckpt or safetensors) is loaded and saved in Diffusers format, missing information is filled in by downloading the v1.5 or v2.1 information from Hugging Face.
443
+
444
+ #### use_safetensors
445
+ This option saves checkpoints in safetensors format when the save format is left at its default (the same format as the loaded model).
446
+
447
+ #### save_state and resume
448
+ The save_state option saves the training state (optimizer state, etc.) to a folder alongside the checkpoint, both at intermediate saves and at the final save. This avoids a drop in accuracy when training is resumed after an interruption (since the optimizer keeps internal state, resetting that state means optimization has to start over from the initial state). Note that the step count is not saved due to Accelerate's specifications.
449
+
450
+ When starting the script, you can resume by specifying the folder where the state is saved with the resume option.
451
+
452
+ Note that each saved state is roughly 5GB, so watch your disk capacity.
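+
+ For example, a sketch of saving the state on the first run and resuming later (the state folder name is a placeholder; use the folder actually created by the first run):
+
+ ```
+ # first run: also save the training state
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py <your usual options> --save_state
+
+ # later run: resume from the saved state folder
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py <your usual options> --resume=<saved state folder>
+ ```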
453
+
454
+ #### gradient_accumulation_steps
455
+ Accumulates gradients over the specified number of steps before updating. This has an effect similar to increasing the batch size, but consumes slightly more memory.
456
+
457
+ *Accelerate does not support accumulating gradients for multiple models, so if the Text Encoder is made a training target and this option is set to 2 or more, an error may occur.
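+
+ For example, to get an effective batch size of about 4 while keeping per-step memory low (the values are illustrative):
+
+ ```
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
+ <your usual options>
+ --train_batch_size=1 --gradient_accumulation_steps=4
+ ```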
458
+
459
+ #### lr_scheduler / lr_warmup_steps
460
+ With the lr_scheduler option you can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, and constant_with_warmup. The default is constant.
461
+
462
+ With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details.
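+
+ For example, a cosine schedule with a short warmup (the step count is only illustrative):
+
+ ```
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
+ <your usual options>
+ --lr_scheduler=cosine --lr_warmup_steps=500
+ ```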
463
+
464
+ #### diffusers_xformers
465
+ Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork learning is no longer possible.
fine_tune_README_ja.md ADDED
@@ -0,0 +1,140 @@
+ This is fine tuning that supports the training method proposed by NovelAI, automatic captioning, automatic tagging, a Windows + 12GB VRAM environment (for SD v1.x), and so on. Here, fine tuning means training a model on images and captions (LoRA, Textual Inversion, and Hypernetworks are not included).
+
+ Please also see the [common documentation on training](./train_README-ja.md).
+
+ # Overview
+
+ Fine tuning of Stable Diffusion's U-Net is performed using Diffusers. It supports the following improvements from NovelAI's article (for Aspect Ratio Bucketing, NovelAI's code was used as a reference, but the final code is entirely original).
+
+ * Use the output of the second-to-last layer of CLIP (Text Encoder) instead of the last layer.
+ * Training at non-square resolutions (Aspect Ratio Bucketing).
+ * Extend the token length from 75 to 225.
+ * Captioning with BLIP (automatic caption creation) and automatic tagging with DeepDanbooru or WD14Tagger.
+ * Hypernetwork training is also supported.
+ * Supports Stable Diffusion v2.0 (base and 768/v).
+ * Obtain the VAE output in advance and save it to disk to reduce memory usage and speed up training.
+
+ By default, the Text Encoder is not trained. When fine tuning the whole model, it seems common to train only the U-Net (NovelAI appears to do the same). The Text Encoder can be made a training target by specifying an option.
+
+ # Additional features
+
+ ## Changing the CLIP output
+
+ CLIP (Text Encoder) converts the text into features so that the prompt is reflected in the image. Stable Diffusion uses the output of CLIP's last layer, but this can be changed to use the output of the second-to-last layer. According to NovelAI, this makes the prompt be reflected more accurately.
+ It is also possible to keep using the output of the last layer, as in the original.
+
+ *Stable Diffusion 2.0 uses the second-to-last layer by default. Do not specify the clip_skip option.
+
+ ## Training at non-square resolutions
+
+ Stable Diffusion was trained at 512\*512, but in addition this trains at resolutions such as 256\*1024 and 384\*640. This reduces the amount of cropping, and the relationship between prompts and images is expected to be learned more correctly.
+ The training resolutions are created by adjusting the width and height in 64-pixel steps, within a range that does not exceed the area (= memory usage) of the resolution given as a parameter.
+
+ In machine learning it is common to unify all input sizes, but there is no particular constraint requiring this; in practice it is enough for sizes to be unified within a single batch. The bucketing that NovelAI refers to seems to mean classifying the training data in advance into training resolutions according to aspect ratio. Batches are then created from the images within each bucket, which unifies the image sizes in the batch.
+
+ ## Extending the token length from 75 to 225
+
+ Stable Diffusion allows a maximum of 75 tokens (77 tokens including the start and end tokens), but this extends it to 225 tokens.
+ However, since the maximum length CLIP accepts is 75 tokens, with 225 tokens the input is simply split into three parts, CLIP is called for each, and the results are concatenated.
+
+ *It is not entirely clear whether this is the desirable implementation. It seems to work for now. In particular, for 2.0 there is no reference implementation at all, so it is implemented independently.
+
+ *Automatic1111's Web UI seems to split with commas in mind, but in my case the split is simpler and does not go that far.
+
+ # Training procedure
+
+ Refer to this repository's README beforehand and set up the environment.
+
+ ## Preparing the data
+
+ See [Preparing training data](./train_README-ja.md). Fine tuning only supports the fine tuning method that uses metadata.
+
+ ## Running the training
+ For example, run as follows. The following settings are for reducing memory usage. Rewrite each line as necessary.
+
+ ```
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
+ --pretrained_model_name_or_path=<.ckpt or .safetensors file, or Diffusers model directory>
+ --output_dir=<output folder for the trained model>
+ --output_name=<file name for the trained model output>
+ --dataset_config=<.toml file created during data preparation>
+ --save_model_as=safetensors
+ --learning_rate=5e-6 --max_train_steps=10000
+ --use_8bit_adam --xformers --gradient_checkpointing
+ --mixed_precision=fp16
+ ```
+
+ It usually seems best to specify 1 for `num_cpu_threads_per_process`.
+
+ Specify the base model for additional training in `pretrained_model_name_or_path`. You can specify a Stable Diffusion checkpoint file (.ckpt or .safetensors), a Diffusers model directory on local disk, or a Diffusers model ID (such as "stabilityai/stable-diffusion-2").
+
+ Specify the folder to save the trained model to in `output_dir`, and the model file name without the extension in `output_name`. `save_model_as` specifies saving in safetensors format.
+
+ Specify the `.toml` file in `dataset_config`. Initially set the batch size inside the file to `1` to keep memory consumption down.
+
+ The number of training steps `max_train_steps` is set to 10000. The learning rate `learning_rate` is set to 5e-6 here.
+
+ Specify `mixed_precision="fp16"` to save memory (on RTX 30 series and later, `bf16` can also be specified; match the setting you made in accelerate during environment setup). Also specify `gradient_checkpointing`.
+
+ To use the memory-efficient 8bit AdamW as the optimizer (the class that optimizes, i.e. trains, the model to fit the training data), specify `optimizer_type="AdamW8bit"`.
+
+ Specify the `xformers` option to use xformers' CrossAttention. If xformers is not installed or causes an error (depending on the environment, for example with `mixed_precision="no"`), you can specify the `mem_eff_attn` option instead to use a memory-saving version of CrossAttention (it is slower).
+
+ If you have some memory to spare, edit the `.toml` file and increase the batch size to, for example, around `4` (this may improve speed and accuracy).
+
+ ### Commonly used options
+
+ Refer to the documentation on options in the following cases.
+
+ - Training Stable Diffusion 2.x or a model derived from it
+ - Training a model that assumes a clip skip of 2 or more
+ - Training with captions longer than 75 tokens
+
+ ### About batch size
+
+ Since the whole model is trained, memory consumption is higher than for training LoRA and the like (the same as DreamBooth).
+
+ ### About the learning rate
+
+ Around 1e-6 to 5e-6 seems to be common. Also look at other fine tuning examples.
+
+ ### Command line when using the old-style dataset specification
+
+ Specify the resolution and batch size with options. An example command line is as follows.
+
+ ```
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
+ --pretrained_model_name_or_path=model.ckpt
+ --in_json meta_lat.json
+ --train_data_dir=train_data
+ --output_dir=fine_tuned
+ --shuffle_caption
+ --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
+ --use_8bit_adam --xformers --gradient_checkpointing
+ --mixed_precision=bf16
+ --save_every_n_epochs=4
+ ```
+
+ <!--
+ ### Training with fp16 gradients (experimental feature)
+ Specifying the full_fp16 option changes the gradients from the usual float32 to float16 (fp16) for training (this appears to be full fp16 training rather than mixed precision). With this, SD1.x at 512*512 can apparently be trained with less than 8GB of VRAM, and SD2.x at 512*512 with less than 12GB.
+
+ Specify fp16 in accelerate config in advance, and set the mixed_precision="fp16" option (it does not work with bf16).
+
+ To minimize memory usage, specify the xformers, use_8bit_adam, and gradient_checkpointing options and set train_batch_size to 1.
+ (If you have room, increasing train_batch_size gradually should improve accuracy slightly.)
+
+ This is forcibly achieved by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably and the probability of training failing partway through also increases. The learning rate and step count settings also seem to be sensitive. Please use this at your own risk with these points in mind.
+ -->
+
+ # Other main options specific to fine tuning
+
+ For all options, refer to the separate document.
+
+ ## `train_text_encoder`
+ Makes the Text Encoder a training target as well. Memory usage increases slightly.
+
+ In normal fine tuning the Text Encoder is not trained (probably because the U-Net is trained to follow the Text Encoder's output), but when the amount of training data is small, having the Text Encoder learn as in DreamBooth also seems to be effective.
+
+ ## `diffusers_xformers`
+ Uses Diffusers' xformers feature instead of the script's own xformers replacement feature. Hypernetwork training becomes unavailable.
finetune/blip/blip.py ADDED
@@ -0,0 +1,240 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # from models.vit import VisionTransformer, interpolate_pos_embed
12
+ # from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from blip.vit import VisionTransformer, interpolate_pos_embed
14
+ from blip.med import BertConfig, BertModel, BertLMHeadModel
15
+ from transformers import BertTokenizer
16
+
17
+ import torch
18
+ from torch import nn
19
+ import torch.nn.functional as F
20
+
21
+ import os
22
+ from urllib.parse import urlparse
23
+ from timm.models.hub import download_cached_file
24
+
25
+ class BLIP_Base(nn.Module):
26
+ def __init__(self,
27
+ med_config = 'configs/med_config.json',
28
+ image_size = 224,
29
+ vit = 'base',
30
+ vit_grad_ckpt = False,
31
+ vit_ckpt_layer = 0,
32
+ ):
33
+ """
34
+ Args:
35
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
36
+ image_size (int): input image size
37
+ vit (str): model size of vision transformer
38
+ """
39
+ super().__init__()
40
+
41
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
42
+ self.tokenizer = init_tokenizer()
43
+ med_config = BertConfig.from_json_file(med_config)
44
+ med_config.encoder_width = vision_width
45
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
46
+
47
+
48
+ def forward(self, image, caption, mode):
49
+
50
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
51
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
52
+
53
+ if mode=='image':
54
+ # return image features
55
+ image_embeds = self.visual_encoder(image)
56
+ return image_embeds
57
+
58
+ elif mode=='text':
59
+ # return text features
60
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
61
+ return_dict = True, mode = 'text')
62
+ return text_output.last_hidden_state
63
+
64
+ elif mode=='multimodal':
65
+ # return multimodal features
66
+ image_embeds = self.visual_encoder(image)
67
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
68
+
69
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
70
+ output = self.text_encoder(text.input_ids,
71
+ attention_mask = text.attention_mask,
72
+ encoder_hidden_states = image_embeds,
73
+ encoder_attention_mask = image_atts,
74
+ return_dict = True,
75
+ )
76
+ return output.last_hidden_state
77
+
78
+
79
+
80
+ class BLIP_Decoder(nn.Module):
81
+ def __init__(self,
82
+ med_config = 'configs/med_config.json',
83
+ image_size = 384,
84
+ vit = 'base',
85
+ vit_grad_ckpt = False,
86
+ vit_ckpt_layer = 0,
87
+ prompt = 'a picture of ',
88
+ ):
89
+ """
90
+ Args:
91
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
92
+ image_size (int): input image size
93
+ vit (str): model size of vision transformer
94
+ """
95
+ super().__init__()
96
+
97
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
98
+ self.tokenizer = init_tokenizer()
99
+ med_config = BertConfig.from_json_file(med_config)
100
+ med_config.encoder_width = vision_width
101
+ self.text_decoder = BertLMHeadModel(config=med_config)
102
+
103
+ self.prompt = prompt
104
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
105
+
106
+
107
+ def forward(self, image, caption):
108
+
109
+ image_embeds = self.visual_encoder(image)
110
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
111
+
112
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
113
+
114
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
115
+
116
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
117
+ decoder_targets[:,:self.prompt_length] = -100
118
+
119
+ decoder_output = self.text_decoder(text.input_ids,
120
+ attention_mask = text.attention_mask,
121
+ encoder_hidden_states = image_embeds,
122
+ encoder_attention_mask = image_atts,
123
+ labels = decoder_targets,
124
+ return_dict = True,
125
+ )
126
+ loss_lm = decoder_output.loss
127
+
128
+ return loss_lm
129
+
130
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
131
+ image_embeds = self.visual_encoder(image)
132
+
133
+ if not sample:
134
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
135
+
136
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
137
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
138
+
139
+ prompt = [self.prompt] * image.size(0)
140
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
141
+ input_ids[:,0] = self.tokenizer.bos_token_id
142
+ input_ids = input_ids[:, :-1]
143
+
144
+ if sample:
145
+ #nucleus sampling
146
+ outputs = self.text_decoder.generate(input_ids=input_ids,
147
+ max_length=max_length,
148
+ min_length=min_length,
149
+ do_sample=True,
150
+ top_p=top_p,
151
+ num_return_sequences=1,
152
+ eos_token_id=self.tokenizer.sep_token_id,
153
+ pad_token_id=self.tokenizer.pad_token_id,
154
+ repetition_penalty=1.1,
155
+ **model_kwargs)
156
+ else:
157
+ #beam search
158
+ outputs = self.text_decoder.generate(input_ids=input_ids,
159
+ max_length=max_length,
160
+ min_length=min_length,
161
+ num_beams=num_beams,
162
+ eos_token_id=self.tokenizer.sep_token_id,
163
+ pad_token_id=self.tokenizer.pad_token_id,
164
+ repetition_penalty=repetition_penalty,
165
+ **model_kwargs)
166
+
167
+ captions = []
168
+ for output in outputs:
169
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
170
+ captions.append(caption[len(self.prompt):])
171
+ return captions
172
+
173
+
174
+ def blip_decoder(pretrained='',**kwargs):
175
+ model = BLIP_Decoder(**kwargs)
176
+ if pretrained:
177
+ model,msg = load_checkpoint(model,pretrained)
178
+ assert(len(msg.missing_keys)==0)
179
+ return model
180
+
181
+ def blip_feature_extractor(pretrained='',**kwargs):
182
+ model = BLIP_Base(**kwargs)
183
+ if pretrained:
184
+ model,msg = load_checkpoint(model,pretrained)
185
+ assert(len(msg.missing_keys)==0)
186
+ return model
187
+
188
+ def init_tokenizer():
189
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
190
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
191
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
192
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
193
+ return tokenizer
194
+
195
+
196
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
197
+
198
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
199
+ if vit=='base':
200
+ vision_width = 768
201
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
202
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
203
+ drop_path_rate=0 or drop_path_rate
204
+ )
205
+ elif vit=='large':
206
+ vision_width = 1024
207
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
208
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
209
+ drop_path_rate=0.1 or drop_path_rate
210
+ )
211
+ return visual_encoder, vision_width
212
+
213
+ def is_url(url_or_filename):
214
+ parsed = urlparse(url_or_filename)
215
+ return parsed.scheme in ("http", "https")
216
+
217
+ def load_checkpoint(model,url_or_filename):
218
+ if is_url(url_or_filename):
219
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
220
+ checkpoint = torch.load(cached_file, map_location='cpu')
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
223
+ else:
224
+ raise RuntimeError('checkpoint url or path is invalid')
225
+
226
+ state_dict = checkpoint['model']
227
+
228
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
229
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
230
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
231
+ model.visual_encoder_m)
232
+ for key in model.state_dict().keys():
233
+ if key in state_dict.keys():
234
+ if state_dict[key].shape!=model.state_dict()[key].shape:
235
+ del state_dict[key]
236
+
237
+ msg = model.load_state_dict(state_dict,strict=False)
238
+ print('load checkpoint from %s'%url_or_filename)
239
+ return model,msg
240
+
finetune/blip/med.py ADDED
@@ -0,0 +1,955 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
finetune/blip/med_config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
22
+
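For reference, a minimal loading sketch for this configuration (assumptions: the transformers package is installed and the repo-relative path below matches the checkout). The BLIP-specific keys such as encoder_width and add_cross_attention simply become extra attributes on the resulting config object.

from transformers import BertConfig

# Hypothetical path; adjust to wherever the repository is checked out.
config = BertConfig.from_json_file("finetune/blip/med_config.json")
print(config.hidden_size, config.encoder_width, config.add_cross_attention)  # 768 768 True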
finetune/blip/vit.py ADDED
@@ -0,0 +1,305 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
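A minimal construction sketch for the VisionTransformer above (assumptions: timm and fairscale are installed, since they are imported at module load, and the import path is illustrative of a checkout where finetune/blip is importable):

import torch
from finetune.blip.vit import VisionTransformer

# ViT-Base-style dimensions matching the constructor defaults above.
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # (1, 197, 768): one CLS token plus 14x14 patch tokens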
finetune/clean_captions_and_tags.py ADDED
@@ -0,0 +1,190 @@
1
+ # このスクリプトのライセンスは、Apache License 2.0とします
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import json
8
+ import re
9
+
10
+ from tqdm import tqdm
11
+
12
+ PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
13
+ PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
14
+ PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
15
+ PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
16
+
17
+ # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
18
+ PATTERNS_REMOVE_IN_MULTI = [
19
+ PATTERN_HAIR_LENGTH,
20
+ PATTERN_HAIR_CUT,
21
+ re.compile(r', [\w\-]+ eyes, '),
22
+ re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
23
+ # 複数の髪型定義がある場合は削除する
24
+ re.compile(
25
+ r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
26
+ ]
27
+
28
+
29
+ def clean_tags(image_key, tags):
30
+ # replace '_' to ' '
31
+ tags = tags.replace('^_^', '^@@@^')
32
+ tags = tags.replace('_', ' ')
33
+ tags = tags.replace('^@@@^', '^_^')
34
+
35
+ # remove rating: deepdanbooruのみ
36
+ tokens = tags.split(", rating")
37
+ if len(tokens) == 1:
38
+ # WD14 taggerのときはこちらになるのでメッセージは出さない
39
+ # print("no rating:")
40
+ # print(f"{image_key} {tags}")
41
+ pass
42
+ else:
43
+ if len(tokens) > 2:
44
+ print("multiple ratings:")
45
+ print(f"{image_key} {tags}")
46
+ tags = tokens[0]
47
+
48
+ tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
49
+
50
+ # 複数の人物がいる場合は髪色等のタグを削除する
51
+ if 'girls' in tags or 'boys' in tags:
52
+ for pat in PATTERNS_REMOVE_IN_MULTI:
53
+ found = pat.findall(tags)
54
+ if len(found) > 1: # 二つ以上、タグがある
55
+ tags = pat.sub("", tags)
56
+
57
+ # 髪の特殊対応
58
+ srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
59
+ if srch_hair_len:
60
+ org = srch_hair_len.group()
61
+ tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
62
+
63
+ found = PATTERN_HAIR.findall(tags)
64
+ if len(found) > 1:
65
+ tags = PATTERN_HAIR.sub("", tags)
66
+
67
+ if srch_hair_len:
68
+ tags = tags.replace(", @@@, ", org) # 戻す
69
+
70
+ # white shirtとshirtみたいな重複タグの削除
71
+ found = PATTERN_WORD.findall(tags)
72
+ for word in found:
73
+ if re.search(f", ((\w+) )+{word}, ", tags):
74
+ tags = tags.replace(f", {word}, ", "")
75
+
76
+ tags = tags.replace(", , ", ", ")
77
+ assert tags.startswith(", ") and tags.endswith(", ")
78
+ tags = tags[2:-2]
79
+ return tags
80
+
81
+
82
+ # 上から順に検索、置換される
83
+ # ('置換元文字列', '置換後文字列')
84
+ CAPTION_REPLACEMENTS = [
85
+ ('anime anime', 'anime'),
86
+ ('young ', ''),
87
+ ('anime girl', 'girl'),
88
+ ('cartoon female', 'girl'),
89
+ ('cartoon lady', 'girl'),
90
+ ('cartoon character', 'girl'), # a or ~s
91
+ ('cartoon woman', 'girl'),
92
+ ('cartoon women', 'girls'),
93
+ ('cartoon girl', 'girl'),
94
+ ('anime female', 'girl'),
95
+ ('anime lady', 'girl'),
96
+ ('anime character', 'girl'), # a or ~s
97
+ ('anime woman', 'girl'),
98
+ ('anime women', 'girls'),
99
+ ('lady', 'girl'),
100
+ ('female', 'girl'),
101
+ ('woman', 'girl'),
102
+ ('women', 'girls'),
103
+ ('people', 'girls'),
104
+ ('person', 'girl'),
105
+ ('a cartoon figure', 'a figure'),
106
+ ('a cartoon image', 'an image'),
107
+ ('a cartoon picture', 'a picture'),
108
+ ('an anime cartoon image', 'an image'),
109
+ ('a cartoon anime drawing', 'a drawing'),
110
+ ('a cartoon drawing', 'a drawing'),
111
+ ('girl girl', 'girl'),
112
+ ]
113
+
114
+
115
+ def clean_caption(caption):
116
+ for rf, rt in CAPTION_REPLACEMENTS:
117
+ replaced = True
118
+ while replaced:
119
+ bef = caption
120
+ caption = caption.replace(rf, rt)
121
+ replaced = bef != caption
122
+ return caption
123
+
124
+
125
+ def main(args):
126
+ if os.path.exists(args.in_json):
127
+ print(f"loading existing metadata: {args.in_json}")
128
+ with open(args.in_json, "rt", encoding='utf-8') as f:
129
+ metadata = json.load(f)
130
+ else:
131
+ print("no metadata / メタデータファイルがありません")
132
+ return
133
+
134
+ print("cleaning captions and tags.")
135
+ image_keys = list(metadata.keys())
136
+ for image_key in tqdm(image_keys):
137
+ tags = metadata[image_key].get('tags')
138
+ if tags is None:
139
+ print(f"image does not have tags / メタデータにタグがありません: {image_key}")
140
+ else:
141
+ org = tags
142
+ tags = clean_tags(image_key, tags)
143
+ metadata[image_key]['tags'] = tags
144
+ if args.debug and org != tags:
145
+ print("FROM: " + org)
146
+ print("TO: " + tags)
147
+
148
+ caption = metadata[image_key].get('caption')
149
+ if caption is None:
150
+ print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
151
+ else:
152
+ org = caption
153
+ caption = clean_caption(caption)
154
+ metadata[image_key]['caption'] = caption
155
+ if args.debug and org != caption:
156
+ print("FROM: " + org)
157
+ print("TO: " + caption)
158
+
159
+ # metadataを書き出して終わり
160
+ print(f"writing metadata: {args.out_json}")
161
+ with open(args.out_json, "wt", encoding='utf-8') as f:
162
+ json.dump(metadata, f, indent=2)
163
+ print("done!")
164
+
165
+
166
+ def setup_parser() -> argparse.ArgumentParser:
167
+ parser = argparse.ArgumentParser()
168
+ # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
169
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
170
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
171
+ parser.add_argument("--debug", action="store_true", help="debug mode")
172
+
173
+ return parser
174
+
175
+
176
+ if __name__ == '__main__':
177
+ parser = setup_parser()
178
+
179
+ args, unknown = parser.parse_known_args()
180
+ if len(unknown) == 1:
181
+ print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in the future. Please specify two arguments: in_json and out_json.")
182
+ print("All captions and tags in the metadata are processed.")
183
+ print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
184
+ print("メタデータ内のすべてのキャプションとタグが処理されます。")
185
+ args.in_json = args.out_json
186
+ args.out_json = unknown[0]
187
+ elif len(unknown) > 0:
188
+ raise ValueError(f"error: unrecognized arguments: {unknown}")
189
+
190
+ main(args)
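The two helpers above can also be exercised directly; a minimal sketch (the import path and the sample strings are illustrative only):

from finetune.clean_captions_and_tags import clean_caption, clean_tags

# Caption replacements are applied repeatedly until the text stops changing.
print(clean_caption("a cartoon drawing of an anime woman with long hair"))

# Tag cleaning replaces underscores and drops redundant tags such as "shirt" next to "white shirt".
print(clean_tags("img001", "1girl, long_hair, blue_eyes, white shirt, shirt"))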
finetune/hypernetwork_nai.py ADDED
@@ -0,0 +1,96 @@
1
+ # NAI compatible
2
+
3
+ import torch
4
+
5
+
6
+ class HypernetworkModule(torch.nn.Module):
7
+ def __init__(self, dim, multiplier=1.0):
8
+ super().__init__()
9
+
10
+ linear1 = torch.nn.Linear(dim, dim * 2)
11
+ linear2 = torch.nn.Linear(dim * 2, dim)
12
+ linear1.weight.data.normal_(mean=0.0, std=0.01)
13
+ linear1.bias.data.zero_()
14
+ linear2.weight.data.normal_(mean=0.0, std=0.01)
15
+ linear2.bias.data.zero_()
16
+ linears = [linear1, linear2]
17
+
18
+ self.linear = torch.nn.Sequential(*linears)
19
+ self.multiplier = multiplier
20
+
21
+ def forward(self, x):
22
+ return x + self.linear(x) * self.multiplier
23
+
24
+
25
+ class Hypernetwork(torch.nn.Module):
26
+ enable_sizes = [320, 640, 768, 1280]
27
+ # return self.modules[Hypernetwork.enable_sizes.index(size)]
28
+
29
+ def __init__(self, multiplier=1.0) -> None:
30
+ super().__init__()
31
+ self.modules = []
32
+ for size in Hypernetwork.enable_sizes:
33
+ self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34
+ self.register_module(f"{size}_0", self.modules[-1][0])
35
+ self.register_module(f"{size}_1", self.modules[-1][1])
36
+
37
+ def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38
+ blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39
+ for block in blocks:
40
+ for subblk in block:
41
+ if 'SpatialTransformer' in str(type(subblk)):
42
+ for tf_block in subblk.transformer_blocks:
43
+ for attn in [tf_block.attn1, tf_block.attn2]:
44
+ size = attn.context_dim
45
+ if size in Hypernetwork.enable_sizes:
46
+ attn.hypernetwork = self
47
+ else:
48
+ attn.hypernetwork = None
49
+
50
+ def apply_to_diffusers(self, text_encoder, vae, unet):
51
+ blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52
+ for block in blocks:
53
+ if hasattr(block, 'attentions'):
54
+ for subblk in block.attentions:
55
+ if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56
+ for tf_block in subblk.transformer_blocks:
57
+ for attn in [tf_block.attn1, tf_block.attn2]:
58
+ size = attn.to_k.in_features
59
+ if size in Hypernetwork.enable_sizes:
60
+ attn.hypernetwork = self
61
+ else:
62
+ attn.hypernetwork = None
63
+ return True # TODO error checking
64
+
65
+ def forward(self, x, context):
66
+ size = context.shape[-1]
67
+ assert size in Hypernetwork.enable_sizes
68
+ module = self.modules[Hypernetwork.enable_sizes.index(size)]
69
+ return module[0].forward(context), module[1].forward(context)
70
+
71
+ def load_from_state_dict(self, state_dict):
72
+ # old ver to new ver
73
+ changes = {
74
+ 'linear1.bias': 'linear.0.bias',
75
+ 'linear1.weight': 'linear.0.weight',
76
+ 'linear2.bias': 'linear.1.bias',
77
+ 'linear2.weight': 'linear.1.weight',
78
+ }
79
+ for key_from, key_to in changes.items():
80
+ if key_from in state_dict:
81
+ state_dict[key_to] = state_dict[key_from]
82
+ del state_dict[key_from]
83
+
84
+ for size, sd in state_dict.items():
85
+ if type(size) == int:
86
+ self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87
+ self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88
+ return True
89
+
90
+ def get_state_dict(self):
91
+ state_dict = {}
92
+ for i, size in enumerate(Hypernetwork.enable_sizes):
93
+ sd0 = self.modules[i][0].state_dict()
94
+ sd1 = self.modules[i][1].state_dict()
95
+ state_dict[size] = [sd0, sd1]
96
+ return state_dict
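A minimal loading sketch for the NAI-compatible hypernetwork above (assumptions: the import path is illustrative, and "hypernet.pt" is a hypothetical checkpoint whose state dict uses the integer-keyed layout that load_from_state_dict expects):

import torch
from finetune.hypernetwork_nai import Hypernetwork

hyper = Hypernetwork(multiplier=1.0)
state_dict = torch.load("hypernet.pt", map_location="cpu")
hyper.load_from_state_dict(state_dict)

# forward() returns replacement key/value contexts for a cross-attention layer;
# the context dimension must be one of Hypernetwork.enable_sizes.
context = torch.zeros(1, 77, 768)
k_ctx, v_ctx = hyper(context, context)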
finetune/make_captions.py ADDED
@@ -0,0 +1,168 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import json
5
+ import random
6
+
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from blip.blip import blip_decoder
14
+ import library.train_util as train_util
15
+
16
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+
19
+ IMAGE_SIZE = 384
20
+
21
+ # 正方形でいいのか? という気がするがソースがそうなので
22
+ IMAGE_TRANSFORM = transforms.Compose([
23
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
26
+ ])
27
+
28
+ # 共通化したいが微妙に処理が異なる……
29
+ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
30
+ def __init__(self, image_paths):
31
+ self.images = image_paths
32
+
33
+ def __len__(self):
34
+ return len(self.images)
35
+
36
+ def __getitem__(self, idx):
37
+ img_path = self.images[idx]
38
+
39
+ try:
40
+ image = Image.open(img_path).convert("RGB")
41
+ # convert to tensor temporarily so dataloader will accept it
42
+ tensor = IMAGE_TRANSFORM(image)
43
+ except Exception as e:
44
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
45
+ return None
46
+
47
+ return (tensor, img_path)
48
+
49
+
50
+ def collate_fn_remove_corrupted(batch):
51
+ """Collate function that allows to remove corrupted examples in the
52
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
53
+ The 'None's in the batch are removed.
54
+ """
55
+ # Filter out all the Nones (corrupted examples)
56
+ batch = list(filter(lambda x: x is not None, batch))
57
+ return batch
58
+
59
+
60
+ def main(args):
61
+ # fix the seed for reproducibility
62
+ seed = args.seed # + utils.get_rank()
63
+ torch.manual_seed(seed)
64
+ np.random.seed(seed)
65
+ random.seed(seed)
66
+
67
+ if not os.path.exists("blip"):
68
+ args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
69
+
70
+ cwd = os.getcwd()
71
+ print('Current Working Directory is: ', cwd)
72
+ os.chdir('finetune')
73
+
74
+ print(f"load images from {args.train_data_dir}")
75
+ image_paths = train_util.glob_images(args.train_data_dir)
76
+ print(f"found {len(image_paths)} images.")
77
+
78
+ print(f"loading BLIP caption: {args.caption_weights}")
79
+ model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
80
+ model.eval()
81
+ model = model.to(DEVICE)
82
+ print("BLIP loaded")
83
+
84
+ # captioningする
85
+ def run_batch(path_imgs):
86
+ imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
87
+
88
+ with torch.no_grad():
89
+ if args.beam_search:
90
+ captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
91
+ max_length=args.max_length, min_length=args.min_length)
92
+ else:
93
+ captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
94
+
95
+ for (image_path, _), caption in zip(path_imgs, captions):
96
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
97
+ f.write(caption + "\n")
98
+ if args.debug:
99
+ print(image_path, caption)
100
+
101
+ # 読み込みの高速化のためにDataLoaderを使うオプション
102
+ if args.max_data_loader_n_workers is not None:
103
+ dataset = ImageLoadingTransformDataset(image_paths)
104
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
105
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
106
+ else:
107
+ data = [[(None, ip)] for ip in image_paths]
108
+
109
+ b_imgs = []
110
+ for data_entry in tqdm(data, smoothing=0.0):
111
+ for data in data_entry:
112
+ if data is None:
113
+ continue
114
+
115
+ img_tensor, image_path = data
116
+ if img_tensor is None:
117
+ try:
118
+ raw_image = Image.open(image_path)
119
+ if raw_image.mode != 'RGB':
120
+ raw_image = raw_image.convert("RGB")
121
+ img_tensor = IMAGE_TRANSFORM(raw_image)
122
+ except Exception as e:
123
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
124
+ continue
125
+
126
+ b_imgs.append((image_path, img_tensor))
127
+ if len(b_imgs) >= args.batch_size:
128
+ run_batch(b_imgs)
129
+ b_imgs.clear()
130
+ if len(b_imgs) > 0:
131
+ run_batch(b_imgs)
132
+
133
+ print("done!")
134
+
135
+
136
+ def setup_parser() -> argparse.ArgumentParser:
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
139
+ parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
140
+ help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
141
+ parser.add_argument("--caption_extention", type=str, default=None,
142
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
143
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
144
+ parser.add_argument("--beam_search", action="store_true",
145
+ help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
146
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
147
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
148
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
149
+ parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
150
+ parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
151
+ parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
152
+ parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
153
+ parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
154
+ parser.add_argument("--debug", action="store_true", help="debug mode")
155
+
156
+ return parser
157
+
158
+
159
+ if __name__ == '__main__':
160
+ parser = setup_parser()
161
+
162
+ args = parser.parse_args()
163
+
164
+ # スペルミスしていたオプションを復元する
165
+ if args.caption_extention is not None:
166
+ args.caption_extension = args.caption_extention
167
+
168
+ main(args)
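A minimal invocation sketch (run from the repository root; "./train_images" is a placeholder directory, and the BLIP weights are fetched from the default --caption_weights URL on first use):

import subprocess

subprocess.run(
    ["python", "finetune/make_captions.py", "./train_images",
     "--batch_size", "4", "--beam_search", "--caption_extension", ".caption"],
    check=True,
)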
finetune/make_captions_by_git.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+ from transformers.generation.utils import GenerationMixin
10
+
11
+ import library.train_util as train_util
12
+
13
+
14
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ PATTERN_REPLACE = [
17
+ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
18
+ re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
19
+ re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
20
+ re.compile(r'with the number \d+ on (it|\w+ \w+)'),
21
+ re.compile(r'with the words "'),
22
+ re.compile(r'word \w+ on it'),
23
+ re.compile(r'that says the word \w+ on it'),
24
+ re.compile('that says\'the word "( on it)?'),
25
+ ]
26
+
27
+ # 誤検知しまくりの with the word xxxx を消す
28
+
29
+
30
+ def remove_words(captions, debug):
31
+ removed_caps = []
32
+ for caption in captions:
33
+ cap = caption
34
+ for pat in PATTERN_REPLACE:
35
+ cap = pat.sub("", cap)
36
+ if debug and cap != caption:
37
+ print(caption)
38
+ print(cap)
39
+ removed_caps.append(cap)
40
+ return removed_caps
41
+
42
+
43
+ def collate_fn_remove_corrupted(batch):
44
+ """Collate function that allows to remove corrupted examples in the
45
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
46
+ The 'None's in the batch are removed.
47
+ """
48
+ # Filter out all the Nones (corrupted examples)
49
+ batch = list(filter(lambda x: x is not None, batch))
50
+ return batch
51
+
52
+
53
+ def main(args):
54
+ # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
55
+ org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
56
+ curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
57
+
58
+ # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
59
+ # ここより上で置き換えようとするとすごく大変
60
+ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
61
+ input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
62
+ if input_ids.size()[0] != curr_batch_size[0]:
63
+ input_ids = input_ids.repeat(curr_batch_size[0], 1)
64
+ return input_ids
65
+ GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
66
+
67
+ print(f"load images from {args.train_data_dir}")
68
+ image_paths = train_util.glob_images(args.train_data_dir)
69
+ print(f"found {len(image_paths)} images.")
70
+
71
+ # できればcacheに依存せず明示的にダウンロードしたい
72
+ print(f"loading GIT: {args.model_id}")
73
+ git_processor = AutoProcessor.from_pretrained(args.model_id)
74
+ git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
75
+ print("GIT loaded")
76
+
77
+ # captioningする
78
+ def run_batch(path_imgs):
79
+ imgs = [im for _, im in path_imgs]
80
+
81
+ curr_batch_size[0] = len(path_imgs)
82
+ inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
83
+ generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
84
+ captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
85
+
86
+ if args.remove_words:
87
+ captions = remove_words(captions, args.debug)
88
+
89
+ for (image_path, _), caption in zip(path_imgs, captions):
90
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
91
+ f.write(caption + "\n")
92
+ if args.debug:
93
+ print(image_path, caption)
94
+
95
+ # 読み込みの高速化のためにDataLoaderを使うオプション
96
+ if args.max_data_loader_n_workers is not None:
97
+ dataset = train_util.ImageLoadingDataset(image_paths)
98
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
99
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
100
+ else:
101
+ data = [[(None, ip)] for ip in image_paths]
102
+
103
+ b_imgs = []
104
+ for data_entry in tqdm(data, smoothing=0.0):
105
+ for data in data_entry:
106
+ if data is None:
107
+ continue
108
+
109
+ image, image_path = data
110
+ if image is None:
111
+ try:
112
+ image = Image.open(image_path)
113
+ if image.mode != 'RGB':
114
+ image = image.convert("RGB")
115
+ except Exception as e:
116
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
117
+ continue
118
+
119
+ b_imgs.append((image_path, image))
120
+ if len(b_imgs) >= args.batch_size:
121
+ run_batch(b_imgs)
122
+ b_imgs.clear()
123
+
124
+ if len(b_imgs) > 0:
125
+ run_batch(b_imgs)
126
+
127
+ print("done!")
128
+
129
+
130
+ def setup_parser() -> argparse.ArgumentParser:
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
133
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
134
+ parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
135
+ help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
136
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
137
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
138
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
139
+ parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
140
+ parser.add_argument("--remove_words", action="store_true",
141
+ help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
142
+ parser.add_argument("--debug", action="store_true", help="debug mode")
143
+
144
+ return parser
145
+
146
+
147
+ if __name__ == '__main__':
148
+ parser = setup_parser()
149
+
150
+ args = parser.parse_args()
151
+ main(args)
finetune/merge_captions_to_metadata.py ADDED
@@ -0,0 +1,76 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+ import os
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge caption texts to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ caption_path = image_path.with_suffix(args.caption_extension)
30
+ if not os.path.exists(caption_path):
31
+     caption_path = os.path.join(image_path, args.caption_extension)
32
+
33
+ caption = caption_path.read_text(encoding='utf-8').strip()
34
+
35
+ image_key = str(image_path) if args.full_path else image_path.stem
36
+ if image_key not in metadata:
37
+ metadata[image_key] = {}
38
+
39
+ metadata[image_key]['caption'] = caption
40
+ if args.debug:
41
+ print(image_key, caption)
42
+
43
+ # metadataを書き出して終わり
44
+ print(f"writing metadata: {args.out_json}")
45
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
46
+ print("done!")
47
+
48
+
49
+ def setup_parser() -> argparse.ArgumentParser:
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
52
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
53
+ parser.add_argument("--in_json", type=str,
54
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
55
+ parser.add_argument("--caption_extention", type=str, default=None,
56
+ help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
57
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
58
+ parser.add_argument("--full_path", action="store_true",
59
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
60
+ parser.add_argument("--recursive", action="store_true",
61
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
62
+ parser.add_argument("--debug", action="store_true", help="debug mode")
63
+
64
+ return parser
65
+
66
+
67
+ if __name__ == '__main__':
68
+ parser = setup_parser()
69
+
70
+ args = parser.parse_args()
71
+
72
+ # スペルミスしていたオプションを復元する
73
+ if args.caption_extention is not None:
74
+ args.caption_extension = args.caption_extention
75
+
76
+ main(args)
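For reference, the metadata this script writes has the following shape (illustrative values; keys are image stems, or full paths when --full_path is given):

metadata_example = {
    "img_0001": {"caption": "a girl standing in a field"},
    "img_0002": {"caption": "a drawing of a girl with long hair"},
}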
finetune/merge_dd_tags_to_metadata.py ADDED
@@ -0,0 +1,71 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+ import os
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge tags to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ tags_path = image_path.with_suffix(args.caption_extension)
30
+ if not os.path.exists(tags_path):
31
+     tags_path = os.path.join(image_path, args.caption_extension)
32
+
33
+ tags = tags_path.read_text(encoding='utf-8').strip()
34
+
35
+ image_key = str(image_path) if args.full_path else image_path.stem
36
+ if image_key not in metadata:
37
+ metadata[image_key] = {}
38
+
39
+ metadata[image_key]['tags'] = tags
40
+ if args.debug:
41
+ print(image_key, tags)
42
+
43
+ # metadataを書き出して終わり
44
+ print(f"writing metadata: {args.out_json}")
45
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
46
+
47
+ print("done!")
48
+
49
+
50
+ def setup_parser() -> argparse.ArgumentParser:
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
53
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
54
+ parser.add_argument("--in_json", type=str,
55
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
56
+ parser.add_argument("--full_path", action="store_true",
57
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
58
+ parser.add_argument("--recursive", action="store_true",
59
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
60
+ parser.add_argument("--caption_extension", type=str, default=".txt",
61
+ help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
62
+ parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
63
+
64
+ return parser
65
+
66
+
67
+ if __name__ == '__main__':
68
+ parser = setup_parser()
69
+
70
+ args = parser.parse_args()
71
+ main(args)
finetune/prepare_buckets_latents.py ADDED
@@ -0,0 +1,267 @@
1
+ import argparse
2
+ import os
3
+ import json
4
+
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ import torch
10
+ from torchvision import transforms
11
+
12
+ import library.model_util as model_util
13
+ import library.train_util as train_util
14
+
15
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ IMAGE_TRANSFORMS = transforms.Compose(
18
+ [
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5]),
21
+ ]
22
+ )
23
+
24
+
25
+ def collate_fn_remove_corrupted(batch):
26
+ """Collate function that allows to remove corrupted examples in the
27
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
28
+ The 'None's in the batch are removed.
29
+ """
30
+ # Filter out all the Nones (corrupted examples)
31
+ batch = list(filter(lambda x: x is not None, batch))
32
+ return batch
33
+
34
+
35
+ def get_latents(vae, images, weight_dtype):
36
+ img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
37
+ img_tensors = torch.stack(img_tensors)
38
+ img_tensors = img_tensors.to(DEVICE, weight_dtype)
39
+ with torch.no_grad():
40
+ latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
41
+ return latents
42
+
43
+
44
+ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
45
+ if is_full_path:
46
+ base_name = os.path.splitext(os.path.basename(image_key))[0]
47
+ else:
48
+ base_name = image_key
49
+ if flip:
50
+ base_name += '_flip'
51
+ return os.path.join(data_dir, base_name)
52
+
53
+
54
+ def main(args):
55
+ # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
56
+ if args.bucket_reso_steps % 8 > 0:
57
+ print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
58
+
59
+ image_paths = train_util.glob_images(args.train_data_dir)
60
+ print(f"found {len(image_paths)} images.")
61
+
62
+ if os.path.exists(args.in_json):
63
+ print(f"loading existing metadata: {args.in_json}")
64
+ with open(args.in_json, "rt", encoding='utf-8') as f:
65
+ metadata = json.load(f)
66
+ else:
67
+ print(f"no metadata / メタデータファイルがありません: {args.in_json}")
68
+ return
69
+
70
+ weight_dtype = torch.float32
71
+ if args.mixed_precision == "fp16":
72
+ weight_dtype = torch.float16
73
+ elif args.mixed_precision == "bf16":
74
+ weight_dtype = torch.bfloat16
75
+
76
+ vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
77
+ vae.eval()
78
+ vae.to(DEVICE, dtype=weight_dtype)
79
+
80
+ # bucketのサイズを計算する
81
+ max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
82
+ assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
83
+
84
+ bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
85
+ args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
86
+ if not args.bucket_no_upscale:
87
+ bucket_manager.make_buckets()
88
+ else:
89
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
90
+
91
+ # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
92
+ img_ar_errors = []
93
+
94
+ def process_batch(is_last):
95
+ for bucket in bucket_manager.buckets:
96
+ if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
97
+ latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
98
+ assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
99
+ f"latent shape {latents.shape}, {bucket[0][1].shape}"
100
+
101
+ for (image_key, _), latent in zip(bucket, latents):
102
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
103
+ np.savez(npz_file_name, latent)
104
+
105
+ # flip
106
+ if args.flip_aug:
107
+ latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
108
+
109
+ for (image_key, _), latent in zip(bucket, latents):
110
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
111
+ np.savez(npz_file_name, latent)
112
+ else:
113
+ # remove existing flipped npz
114
+ for image_key, _ in bucket:
115
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
116
+ if os.path.isfile(npz_file_name):
117
+ print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
118
+ os.remove(npz_file_name)
119
+
120
+ bucket.clear()
121
+
122
+ # 読み込みの高速化のためにDataLoaderを使うオプション
123
+ if args.max_data_loader_n_workers is not None:
124
+ dataset = train_util.ImageLoadingDataset(image_paths)
125
+ data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
126
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
127
+ else:
128
+ data = [[(None, ip)] for ip in image_paths]
129
+
130
+ bucket_counts = {}
131
+ for data_entry in tqdm(data, smoothing=0.0):
132
+ if data_entry[0] is None:
133
+ continue
134
+
135
+ img_tensor, image_path = data_entry[0]
136
+ if img_tensor is not None:
137
+ image = transforms.functional.to_pil_image(img_tensor)
138
+ else:
139
+ try:
140
+ image = Image.open(image_path)
141
+ if image.mode != 'RGB':
142
+ image = image.convert("RGB")
143
+ except Exception as e:
144
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
145
+ continue
146
+
147
+ image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
148
+ if image_key not in metadata:
149
+ metadata[image_key] = {}
150
+
151
+ # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
152
+
153
+ reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
154
+ img_ar_errors.append(abs(ar_error))
155
+ bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
156
+
157
+ # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
158
+ metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
159
+
160
+ if not args.bucket_no_upscale:
161
+ # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
162
+ assert resized_size[0] == reso[0] or resized_size[1] == reso[
163
+ 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
164
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
165
+ 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
166
+
167
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
168
+ 1], f"internal error resized size is small: {resized_size}, {reso}"
169
+
170
+ # 既に存在するファイルがあればshapeを確認して同じならskipする
171
+ if args.skip_existing:
172
+ npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
173
+ if args.flip_aug:
174
+ npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
175
+
176
+ found = True
177
+ for npz_file in npz_files:
178
+ if not os.path.exists(npz_file):
179
+ found = False
180
+ break
181
+
182
+ dat = np.load(npz_file)['arr_0']
183
+ if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
184
+ found = False
185
+ break
186
+ if found:
187
+ continue
188
+
189
+ # 画像をリサイズしてトリミングする
190
+ # PILにinter_areaがないのでcv2で……
191
+ image = np.array(image)
192
+ if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
193
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
194
+
195
+ if resized_size[0] > reso[0]:
196
+ trim_size = resized_size[0] - reso[0]
197
+ image = image[:, trim_size//2:trim_size//2 + reso[0]]
198
+
199
+ if resized_size[1] > reso[1]:
200
+ trim_size = resized_size[1] - reso[1]
201
+ image = image[trim_size//2:trim_size//2 + reso[1]]
202
+
203
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
204
+
205
+ # # debug
206
+ # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
207
+
208
+ # add to the batch
209
+ bucket_manager.add_image(reso, (image_key, image))
210
+
211
+ # run inference on the batch when it is full
212
+ process_batch(False)
213
+
214
+ # process the remainder
215
+ process_batch(True)
216
+
217
+ bucket_manager.sort()
218
+ for i, reso in enumerate(bucket_manager.resos):
219
+ count = bucket_counts.get(reso, 0)
220
+ if count > 0:
221
+ print(f"bucket {i} {reso}: {count}")
222
+ img_ar_errors = np.array(img_ar_errors)
223
+ print(f"mean ar error: {np.mean(img_ar_errors)}")
224
+
225
+ # write out the metadata and finish
226
+ print(f"writing metadata: {args.out_json}")
227
+ with open(args.out_json, "wt", encoding='utf-8') as f:
228
+ json.dump(metadata, f, indent=2)
229
+ print("done!")
230
+
231
+
232
+ def setup_parser() -> argparse.ArgumentParser:
233
+ parser = argparse.ArgumentParser()
234
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
235
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
236
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
237
+ parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
238
+ parser.add_argument("--v2", action='store_true',
239
+ help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
240
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
241
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
242
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
243
+ parser.add_argument("--max_resolution", type=str, default="512,512",
244
+ help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
245
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
246
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
247
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
248
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
249
+ parser.add_argument("--bucket_no_upscale", action="store_true",
250
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
251
+ parser.add_argument("--mixed_precision", type=str, default="no",
252
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
253
+ parser.add_argument("--full_path", action="store_true",
254
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
255
+ parser.add_argument("--flip_aug", action="store_true",
256
+ help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
257
+ parser.add_argument("--skip_existing", action="store_true",
258
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
259
+
260
+ return parser
261
+
262
+
263
+ if __name__ == '__main__':
264
+ parser = setup_parser()
265
+
266
+ args = parser.parse_args()
267
+ main(args)
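
For reference, a minimal sketch (not part of this commit) of how the cached latents written by this script line up with the metadata it emits. It assumes the default layout in which, without --full_path, each non-flipped latent is saved as <train_data_dir>/<image_key>.npz; the directory and file names below are placeholders.

import json
import os

import numpy as np

TRAIN_DATA_DIR = "train_images"   # placeholder path
OUT_JSON = "meta_lat.json"        # placeholder metadata file

with open(OUT_JSON, "r", encoding="utf-8") as f:
    metadata = json.load(f)

for image_key, entry in metadata.items():
    width, height = entry["train_resolution"]   # already truncated to multiples of 8
    npz_path = os.path.join(TRAIN_DATA_DIR, image_key) + ".npz"   # assumed layout
    if not os.path.isfile(npz_path):
        continue
    latent = np.load(npz_path)["arr_0"]
    # np.savez stores the latent under 'arr_0'; its spatial size is 1/8 of the pixel size
    assert latent.shape[1] == height // 8 and latent.shape[2] == width // 8
print("cached latents match the recorded train_resolution")
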
finetune/tag_images_by_wd14_tagger.py ADDED
@@ -0,0 +1,206 @@
1
+ import argparse
2
+ import csv
3
+ import glob
4
+ import os
5
+
6
+ from PIL import Image
7
+ import cv2
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ from tensorflow.keras.models import load_model
11
+ from huggingface_hub import hf_hub_download
12
+ import torch
13
+
14
+ import library.train_util as train_util
15
+
16
+ # from wd14 tagger
17
+ IMAGE_SIZE = 448
18
+
19
+ # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
20
+ DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
21
+ FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
22
+ SUB_DIR = "variables"
23
+ SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
24
+ CSV_FILE = FILES[-1]
25
+
26
+
27
+ def preprocess_image(image):
28
+ image = np.array(image)
29
+ image = image[:, :, ::-1] # RGB->BGR
30
+
31
+ # pad to square
32
+ size = max(image.shape[0:2])
33
+ pad_x = size - image.shape[1]
34
+ pad_y = size - image.shape[0]
35
+ pad_l = pad_x // 2
36
+ pad_t = pad_y // 2
37
+ image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
38
+
39
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
40
+ image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
41
+
42
+ image = image.astype(np.float32)
43
+ return image
44
+
45
+
46
+ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
47
+ def __init__(self, image_paths):
48
+ self.images = image_paths
49
+
50
+ def __len__(self):
51
+ return len(self.images)
52
+
53
+ def __getitem__(self, idx):
54
+ img_path = self.images[idx]
55
+
56
+ try:
57
+ image = Image.open(img_path).convert("RGB")
58
+ image = preprocess_image(image)
59
+ tensor = torch.tensor(image)
60
+ except Exception as e:
61
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
62
+ return None
63
+
64
+ return (tensor, img_path)
65
+
66
+
67
+ def collate_fn_remove_corrupted(batch):
68
+ """Collate function that allows to remove corrupted examples in the
69
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
70
+ The 'None's in the batch are removed.
71
+ """
72
+ # Filter out all the Nones (corrupted examples)
73
+ batch = list(filter(lambda x: x is not None, batch))
74
+ return batch
75
+
76
+
77
+ def main(args):
78
+ # using hf_hub_download directly reportedly causes symlink issues, so work around it by specifying a cache directory and force_filename
79
+ # this emits a deprecation warning; deal with it if and when the option is removed
80
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
81
+ if not os.path.exists(args.model_dir) or args.force_download:
82
+ print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
83
+ for file in FILES:
84
+ hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
85
+ for file in SUB_DIR_FILES:
86
+ hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
87
+ args.model_dir, SUB_DIR), force_download=True, force_filename=file)
88
+ else:
89
+ print("using existing wd14 tagger model")
90
+
91
+ # load images
92
+ image_paths = train_util.glob_images(args.train_data_dir)
93
+ print(f"found {len(image_paths)} images.")
94
+
95
+ print("loading model and labels")
96
+ model = load_model(args.model_dir)
97
+
98
+ # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
99
+ # read the CSV manually to avoid adding another dependency
100
+ with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
101
+ reader = csv.reader(f)
102
+ l = [row for row in reader]
103
+ header = l[0] # tag_id,name,category,count
104
+ rows = l[1:]
105
+ assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
106
+
107
+ tags = [row[1] for row in rows[1:] if row[2] == '0'] # category 0, i.e. only general tags
108
+
109
+ # run inference
110
+ def run_batch(path_imgs):
111
+ imgs = np.array([im for _, im in path_imgs])
112
+
113
+ probs = model(imgs, training=False)
114
+ probs = probs.numpy()
115
+
116
+ for (image_path, _), prob in zip(path_imgs, probs):
117
+ # the first 4 outputs are ratings, so skip them
118
+ # # First 4 labels are actually ratings: pick one with argmax
119
+ # ratings_names = label_names[:4]
120
+ # rating_index = ratings_names["probs"].argmax()
121
+ # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
122
+
123
+ # the rest are tags; add those whose confidence exceeds the threshold
124
+ # Everything else is tags: pick any where prediction confidence > threshold
125
+ tag_text = ""
126
+ for i, p in enumerate(prob[4:]): # numpy would be faster, but the tag count is small enough for a plain loop
127
+ if p >= args.thresh and i < len(tags):
128
+ tag_text += ", " + tags[i]
129
+
130
+ if len(tag_text) > 0:
131
+ tag_text = tag_text[2:] # strip the leading ", "
132
+
133
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
134
+ f.write(tag_text + '\n')
135
+ if args.debug:
136
+ print(image_path, tag_text)
137
+
138
+ # option to use a DataLoader to speed up image loading
139
+ if args.max_data_loader_n_workers is not None:
140
+ dataset = ImageLoadingPrepDataset(image_paths)
141
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
142
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
143
+ else:
144
+ data = [[(None, ip)] for ip in image_paths]
145
+
146
+ b_imgs = []
147
+ for data_entry in tqdm(data, smoothing=0.0):
148
+ for data in data_entry:
149
+ if data is None:
150
+ continue
151
+
152
+ image, image_path = data
153
+ if image is not None:
154
+ image = image.detach().numpy()
155
+ else:
156
+ try:
157
+ image = Image.open(image_path)
158
+ if image.mode != 'RGB':
159
+ image = image.convert("RGB")
160
+ image = preprocess_image(image)
161
+ except Exception as e:
162
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
163
+ continue
164
+ b_imgs.append((image_path, image))
165
+
166
+ if len(b_imgs) >= args.batch_size:
167
+ run_batch(b_imgs)
168
+ b_imgs.clear()
169
+
170
+ if len(b_imgs) > 0:
171
+ run_batch(b_imgs)
172
+
173
+ print("done!")
174
+
175
+
176
+ def setup_parser() -> argparse.ArgumentParser:
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
179
+ parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
180
+ help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
181
+ parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
182
+ help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
183
+ parser.add_argument("--force_download", action='store_true',
184
+ help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
185
+ parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
186
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
187
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
188
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
189
+ parser.add_argument("--caption_extention", type=str, default=None,
190
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
191
+ parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
192
+ parser.add_argument("--debug", action="store_true", help="debug mode")
193
+
194
+ return parser
195
+
196
+
197
+ if __name__ == '__main__':
198
+ parser = setup_parser()
199
+
200
+ args = parser.parse_args()
201
+
202
+ # carry over the value from the misspelled legacy option
203
+ if args.caption_extention is not None:
204
+ args.caption_extension = args.caption_extention
205
+
206
+ main(args)
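
As a quick illustration (not part of this commit), the tag-selection step inside run_batch() can be exercised on its own: the first four model outputs are rating classes and are skipped, and each remaining tag whose confidence reaches the threshold (default 0.35) is joined into the caption with ", ". The tag list and probabilities below are made up.

import numpy as np

def probs_to_caption(prob, tags, thresh=0.35):
    # mirrors run_batch(): skip the 4 rating outputs, keep tags at or above the threshold
    selected = [tags[i] for i, p in enumerate(prob[4:]) if p >= thresh and i < len(tags)]
    return ", ".join(selected)

tags = ["1girl", "solo", "smile"]                # illustrative tag list
prob = np.array([0.9, 0.05, 0.03, 0.02,          # rating outputs (ignored)
                 0.80, 0.20, 0.55])              # per-tag confidences
print(probs_to_caption(prob, tags))              # -> 1girl, smile
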
finetune_gui.py ADDED
@@ -0,0 +1,888 @@
1
+ import gradio as gr
2
+ import json
3
+ import math
4
+ import os
5
+ import subprocess
6
+ import pathlib
7
+ import argparse
8
+ from library.common_gui import (
9
+ get_folder_path,
10
+ get_file_path,
11
+ get_saveasfile_path,
12
+ save_inference_file,
13
+ gradio_advanced_training,
14
+ run_cmd_advanced_training,
15
+ gradio_training,
17
+ gradio_config,
18
+ gradio_source_model,
19
+ color_aug_changed,
20
+ run_cmd_training,
21
+ # set_legacy_8bitadam,
22
+ update_my_data,
23
+ check_if_model_exist,
24
+ )
25
+ from library.tensorboard_gui import (
26
+ gradio_tensorboard,
27
+ start_tensorboard,
28
+ stop_tensorboard,
29
+ )
30
+ from library.utilities import utilities_tab
31
+ from library.sampler_gui import sample_gradio_config, run_cmd_sample
32
+
33
+ folder_symbol = '\U0001f4c2' # 📂
34
+ refresh_symbol = '\U0001f504' # 🔄
35
+ save_style_symbol = '\U0001f4be' # 💾
36
+ document_symbol = '\U0001F4C4' # 📄
37
+
38
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
39
+
40
+
41
+ def save_configuration(
42
+ save_as,
43
+ file_path,
44
+ pretrained_model_name_or_path,
45
+ v2,
46
+ v_parameterization,
47
+ train_dir,
48
+ image_folder,
49
+ output_dir,
50
+ logging_dir,
51
+ max_resolution,
52
+ min_bucket_reso,
53
+ max_bucket_reso,
54
+ batch_size,
55
+ flip_aug,
56
+ caption_metadata_filename,
57
+ latent_metadata_filename,
58
+ full_path,
59
+ learning_rate,
60
+ lr_scheduler,
61
+ lr_warmup,
62
+ dataset_repeats,
63
+ train_batch_size,
64
+ epoch,
65
+ save_every_n_epochs,
66
+ mixed_precision,
67
+ save_precision,
68
+ seed,
69
+ num_cpu_threads_per_process,
70
+ train_text_encoder,
71
+ create_caption,
72
+ create_buckets,
73
+ save_model_as,
74
+ caption_extension,
75
+ # use_8bit_adam,
76
+ xformers,
77
+ clip_skip,
78
+ save_state,
79
+ resume,
80
+ gradient_checkpointing,
81
+ gradient_accumulation_steps,
82
+ mem_eff_attn,
83
+ shuffle_caption,
84
+ output_name,
85
+ max_token_length,
86
+ max_train_epochs,
87
+ max_data_loader_n_workers,
88
+ full_fp16,
89
+ color_aug,
90
+ model_list,
91
+ cache_latents,
92
+ use_latent_files,
93
+ keep_tokens,
94
+ persistent_data_loader_workers,
95
+ bucket_no_upscale,
96
+ random_crop,
97
+ bucket_reso_steps,
98
+ caption_dropout_every_n_epochs,
99
+ caption_dropout_rate,
100
+ optimizer,
101
+ optimizer_args,
102
+ noise_offset,
103
+ sample_every_n_steps,
104
+ sample_every_n_epochs,
105
+ sample_sampler,
106
+ sample_prompts,
107
+ additional_parameters,
108
+ vae_batch_size,
109
+ min_snr_gamma,
110
+ ):
111
+ # Get list of function parameters and values
112
+ parameters = list(locals().items())
113
+
114
+ original_file_path = file_path
115
+
116
+ save_as_bool = True if save_as.get('label') == 'True' else False
117
+
118
+ if save_as_bool:
119
+ print('Save as...')
120
+ file_path = get_saveasfile_path(file_path)
121
+ else:
122
+ print('Save...')
123
+ if file_path == None or file_path == '':
124
+ file_path = get_saveasfile_path(file_path)
125
+
126
+ # print(file_path)
127
+
128
+ if file_path == None or file_path == '':
129
+ return original_file_path # In case a file_path was provided and the user decide to cancel the open action
130
+
131
+ # Return the values of the variables as a dictionary
132
+ variables = {
133
+ name: value
134
+ for name, value in parameters # locals().items()
135
+ if name
136
+ not in [
137
+ 'file_path',
138
+ 'save_as',
139
+ ]
140
+ }
141
+
142
+ # Extract the destination directory from the file path
143
+ destination_directory = os.path.dirname(file_path)
144
+
145
+ # Create the destination directory if it doesn't exist
146
+ if not os.path.exists(destination_directory):
147
+ os.makedirs(destination_directory)
148
+
149
+ # Save the data to the selected file
150
+ with open(file_path, 'w') as file:
151
+ json.dump(variables, file, indent=2)
152
+
153
+ return file_path
154
+
155
+
156
+ def open_configuration(
157
+ ask_for_file,
158
+ file_path,
159
+ pretrained_model_name_or_path,
160
+ v2,
161
+ v_parameterization,
162
+ train_dir,
163
+ image_folder,
164
+ output_dir,
165
+ logging_dir,
166
+ max_resolution,
167
+ min_bucket_reso,
168
+ max_bucket_reso,
169
+ batch_size,
170
+ flip_aug,
171
+ caption_metadata_filename,
172
+ latent_metadata_filename,
173
+ full_path,
174
+ learning_rate,
175
+ lr_scheduler,
176
+ lr_warmup,
177
+ dataset_repeats,
178
+ train_batch_size,
179
+ epoch,
180
+ save_every_n_epochs,
181
+ mixed_precision,
182
+ save_precision,
183
+ seed,
184
+ num_cpu_threads_per_process,
185
+ train_text_encoder,
186
+ create_caption,
187
+ create_buckets,
188
+ save_model_as,
189
+ caption_extension,
190
+ # use_8bit_adam,
191
+ xformers,
192
+ clip_skip,
193
+ save_state,
194
+ resume,
195
+ gradient_checkpointing,
196
+ gradient_accumulation_steps,
197
+ mem_eff_attn,
198
+ shuffle_caption,
199
+ output_name,
200
+ max_token_length,
201
+ max_train_epochs,
202
+ max_data_loader_n_workers,
203
+ full_fp16,
204
+ color_aug,
205
+ model_list,
206
+ cache_latents,
207
+ use_latent_files,
208
+ keep_tokens,
209
+ persistent_data_loader_workers,
210
+ bucket_no_upscale,
211
+ random_crop,
212
+ bucket_reso_steps,
213
+ caption_dropout_every_n_epochs,
214
+ caption_dropout_rate,
215
+ optimizer,
216
+ optimizer_args,
217
+ noise_offset,
218
+ sample_every_n_steps,
219
+ sample_every_n_epochs,
220
+ sample_sampler,
221
+ sample_prompts,
222
+ additional_parameters,
223
+ vae_batch_size,
224
+ min_snr_gamma,
225
+ ):
226
+ # Get list of function parameters and values
227
+ parameters = list(locals().items())
228
+
229
+ ask_for_file = True if ask_for_file.get('label') == 'True' else False
230
+
231
+ original_file_path = file_path
232
+
233
+ if ask_for_file:
234
+ file_path = get_file_path(file_path)
235
+
236
+ if not file_path == '' and not file_path == None:
237
+ # load variables from JSON file
238
+ with open(file_path, 'r') as f:
239
+ my_data = json.load(f)
240
+ print('Loading config...')
241
+ # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
242
+ my_data = update_my_data(my_data)
243
+ else:
244
+ file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
245
+ my_data = {}
246
+
247
+ values = [file_path]
248
+ for key, value in parameters:
249
+ # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
250
+ if not key in ['ask_for_file', 'file_path']:
251
+ values.append(my_data.get(key, value))
252
+ return tuple(values)
253
+
254
+
255
+ def train_model(
256
+ pretrained_model_name_or_path,
257
+ v2,
258
+ v_parameterization,
259
+ train_dir,
260
+ image_folder,
261
+ output_dir,
262
+ logging_dir,
263
+ max_resolution,
264
+ min_bucket_reso,
265
+ max_bucket_reso,
266
+ batch_size,
267
+ flip_aug,
268
+ caption_metadata_filename,
269
+ latent_metadata_filename,
270
+ full_path,
271
+ learning_rate,
272
+ lr_scheduler,
273
+ lr_warmup,
274
+ dataset_repeats,
275
+ train_batch_size,
276
+ epoch,
277
+ save_every_n_epochs,
278
+ mixed_precision,
279
+ save_precision,
280
+ seed,
281
+ num_cpu_threads_per_process,
282
+ train_text_encoder,
283
+ generate_caption_database,
284
+ generate_image_buckets,
285
+ save_model_as,
286
+ caption_extension,
287
+ # use_8bit_adam,
288
+ xformers,
289
+ clip_skip,
290
+ save_state,
291
+ resume,
292
+ gradient_checkpointing,
293
+ gradient_accumulation_steps,
294
+ mem_eff_attn,
295
+ shuffle_caption,
296
+ output_name,
297
+ max_token_length,
298
+ max_train_epochs,
299
+ max_data_loader_n_workers,
300
+ full_fp16,
301
+ color_aug,
302
+ model_list, # Keep this. Yes, it is unused here but required given the common list used
303
+ cache_latents,
304
+ use_latent_files,
305
+ keep_tokens,
306
+ persistent_data_loader_workers,
307
+ bucket_no_upscale,
308
+ random_crop,
309
+ bucket_reso_steps,
310
+ caption_dropout_every_n_epochs,
311
+ caption_dropout_rate,
312
+ optimizer,
313
+ optimizer_args,
314
+ noise_offset,
315
+ sample_every_n_steps,
316
+ sample_every_n_epochs,
317
+ sample_sampler,
318
+ sample_prompts,
319
+ additional_parameters,
320
+ vae_batch_size,
321
+ min_snr_gamma,
322
+ ):
323
+ if check_if_model_exist(output_name, output_dir, save_model_as):
324
+ return
325
+
326
+ # create caption json file
327
+ if generate_caption_database:
328
+ if not os.path.exists(train_dir):
329
+ os.mkdir(train_dir)
330
+
331
+ run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py'
332
+ if caption_extension == '':
333
+ run_cmd += f' --caption_extension=".caption"'
334
+ else:
335
+ run_cmd += f' --caption_extension={caption_extension}'
336
+ run_cmd += f' "{image_folder}"'
337
+ run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
338
+ if full_path:
339
+ run_cmd += f' --full_path'
340
+
341
+ print(run_cmd)
342
+
343
+ # Run the command
344
+ if os.name == 'posix':
345
+ os.system(run_cmd)
346
+ else:
347
+ subprocess.run(run_cmd)
348
+
349
+ # create images buckets
350
+ if generate_image_buckets:
351
+ run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py'
352
+ run_cmd += f' "{image_folder}"'
353
+ run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
354
+ run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
355
+ run_cmd += f' "{pretrained_model_name_or_path}"'
356
+ run_cmd += f' --batch_size={batch_size}'
357
+ run_cmd += f' --max_resolution={max_resolution}'
358
+ run_cmd += f' --min_bucket_reso={min_bucket_reso}'
359
+ run_cmd += f' --max_bucket_reso={max_bucket_reso}'
360
+ run_cmd += f' --mixed_precision={mixed_precision}'
361
+ # if flip_aug:
362
+ # run_cmd += f' --flip_aug'
363
+ if full_path:
364
+ run_cmd += f' --full_path'
365
+
366
+ print(run_cmd)
367
+
368
+ # Run the command
369
+ if os.name == 'posix':
370
+ os.system(run_cmd)
371
+ else:
372
+ subprocess.run(run_cmd)
373
+
374
+ image_num = len(
375
+ [
376
+ f
377
+ for f, lower_f in (
378
+ (file, file.lower()) for file in os.listdir(image_folder)
379
+ )
380
+ if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
381
+ ]
382
+ )
383
+ print(f'image_num = {image_num}')
384
+
385
+ repeats = int(image_num) * int(dataset_repeats)
386
+ print(f'repeats = {str(repeats)}')
387
+
388
+ # calculate max_train_steps
389
+ max_train_steps = int(
390
+ math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
391
+ )
392
+
393
+ # Divide by two because flip augmentation creates two copies of each source image
394
+ if flip_aug:
395
+ max_train_steps = int(math.ceil(float(max_train_steps) / 2))
396
+
397
+ print(f'max_train_steps = {max_train_steps}')
398
+
399
+ lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
400
+ print(f'lr_warmup_steps = {lr_warmup_steps}')
401
+
402
+ run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
403
+ if v2:
404
+ run_cmd += ' --v2'
405
+ if v_parameterization:
406
+ run_cmd += ' --v_parameterization'
407
+ if train_text_encoder:
408
+ run_cmd += ' --train_text_encoder'
409
+ run_cmd += (
410
+ f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
411
+ )
412
+ if use_latent_files == 'Yes':
413
+ run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
414
+ else:
415
+ run_cmd += f' --in_json="{train_dir}/{caption_metadata_filename}"'
416
+ run_cmd += f' --train_data_dir="{image_folder}"'
417
+ run_cmd += f' --output_dir="{output_dir}"'
418
+ if not logging_dir == '':
419
+ run_cmd += f' --logging_dir="{logging_dir}"'
420
+ run_cmd += f' --dataset_repeats={dataset_repeats}'
421
+ run_cmd += f' --learning_rate={learning_rate}'
422
+
423
+ run_cmd += ' --enable_bucket'
424
+ run_cmd += f' --resolution={max_resolution}'
425
+ run_cmd += f' --min_bucket_reso={min_bucket_reso}'
426
+ run_cmd += f' --max_bucket_reso={max_bucket_reso}'
427
+
428
+ if not save_model_as == 'same as source model':
429
+ run_cmd += f' --save_model_as={save_model_as}'
430
+ if int(gradient_accumulation_steps) > 1:
431
+ run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
432
+ # if save_state:
433
+ # run_cmd += ' --save_state'
434
+ # if not resume == '':
435
+ # run_cmd += f' --resume={resume}'
436
+ if not output_name == '':
437
+ run_cmd += f' --output_name="{output_name}"'
438
+ if int(max_token_length) > 75:
439
+ run_cmd += f' --max_token_length={max_token_length}'
440
+
441
+ run_cmd += run_cmd_training(
442
+ learning_rate=learning_rate,
443
+ lr_scheduler=lr_scheduler,
444
+ lr_warmup_steps=lr_warmup_steps,
445
+ train_batch_size=train_batch_size,
446
+ max_train_steps=max_train_steps,
447
+ save_every_n_epochs=save_every_n_epochs,
448
+ mixed_precision=mixed_precision,
449
+ save_precision=save_precision,
450
+ seed=seed,
451
+ caption_extension=caption_extension,
452
+ cache_latents=cache_latents,
453
+ optimizer=optimizer,
454
+ optimizer_args=optimizer_args,
455
+ )
456
+
457
+ run_cmd += run_cmd_advanced_training(
458
+ max_train_epochs=max_train_epochs,
459
+ max_data_loader_n_workers=max_data_loader_n_workers,
460
+ max_token_length=max_token_length,
461
+ resume=resume,
462
+ save_state=save_state,
463
+ mem_eff_attn=mem_eff_attn,
464
+ clip_skip=clip_skip,
465
+ flip_aug=flip_aug,
466
+ color_aug=color_aug,
467
+ shuffle_caption=shuffle_caption,
468
+ gradient_checkpointing=gradient_checkpointing,
469
+ full_fp16=full_fp16,
470
+ xformers=xformers,
471
+ # use_8bit_adam=use_8bit_adam,
472
+ keep_tokens=keep_tokens,
473
+ persistent_data_loader_workers=persistent_data_loader_workers,
474
+ bucket_no_upscale=bucket_no_upscale,
475
+ random_crop=random_crop,
476
+ bucket_reso_steps=bucket_reso_steps,
477
+ caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
478
+ caption_dropout_rate=caption_dropout_rate,
479
+ noise_offset=noise_offset,
480
+ additional_parameters=additional_parameters,
481
+ vae_batch_size=vae_batch_size,
482
+ min_snr_gamma=min_snr_gamma,
483
+ )
484
+
485
+ run_cmd += run_cmd_sample(
486
+ sample_every_n_steps,
487
+ sample_every_n_epochs,
488
+ sample_sampler,
489
+ sample_prompts,
490
+ output_dir,
491
+ )
492
+
493
+ print(run_cmd)
494
+
495
+ # Run the command
496
+ if os.name == 'posix':
497
+ os.system(run_cmd)
498
+ else:
499
+ subprocess.run(run_cmd)
500
+
501
+ # check if output_dir/output_name is a folder (i.e. a diffusers model); if not, copy the inference config if required
502
+ last_dir = pathlib.Path(f'{output_dir}/{output_name}')
503
+
504
+ if not last_dir.is_dir():
505
+ # Copy inference model for v2 if required
506
+ save_inference_file(output_dir, v2, v_parameterization, output_name)
507
+
508
+
509
+ def remove_doublequote(file_path):
510
+ if file_path != None:
511
+ file_path = file_path.replace('"', '')
512
+
513
+ return file_path
514
+
515
+
516
+ def finetune_tab():
517
+ dummy_db_true = gr.Label(value=True, visible=False)
518
+ dummy_db_false = gr.Label(value=False, visible=False)
519
+ gr.Markdown('Train a custom model using kohya finetune python code...')
520
+
521
+ (
522
+ button_open_config,
523
+ button_save_config,
524
+ button_save_as_config,
525
+ config_file_name,
526
+ button_load_config,
527
+ ) = gradio_config()
528
+
529
+ (
530
+ pretrained_model_name_or_path,
531
+ v2,
532
+ v_parameterization,
533
+ save_model_as,
534
+ model_list,
535
+ ) = gradio_source_model()
536
+
537
+ with gr.Tab('Folders'):
538
+ with gr.Row():
539
+ train_dir = gr.Textbox(
540
+ label='Training config folder',
541
+ placeholder='folder where the training configuration files will be saved',
542
+ )
543
+ train_dir_folder = gr.Button(
544
+ folder_symbol, elem_id='open_folder_small'
545
+ )
546
+ train_dir_folder.click(
547
+ get_folder_path,
548
+ outputs=train_dir,
549
+ show_progress=False,
550
+ )
551
+
552
+ image_folder = gr.Textbox(
553
+ label='Training Image folder',
554
+ placeholder='folder where the training images are located',
555
+ )
556
+ image_folder_input_folder = gr.Button(
557
+ folder_symbol, elem_id='open_folder_small'
558
+ )
559
+ image_folder_input_folder.click(
560
+ get_folder_path,
561
+ outputs=image_folder,
562
+ show_progress=False,
563
+ )
564
+ with gr.Row():
565
+ output_dir = gr.Textbox(
566
+ label='Model output folder',
567
+ placeholder='folder where the model will be saved',
568
+ )
569
+ output_dir_input_folder = gr.Button(
570
+ folder_symbol, elem_id='open_folder_small'
571
+ )
572
+ output_dir_input_folder.click(
573
+ get_folder_path,
574
+ outputs=output_dir,
575
+ show_progress=False,
576
+ )
577
+
578
+ logging_dir = gr.Textbox(
579
+ label='Logging folder',
580
+ placeholder='Optional: enable logging and output TensorBoard log to this folder',
581
+ )
582
+ logging_dir_input_folder = gr.Button(
583
+ folder_symbol, elem_id='open_folder_small'
584
+ )
585
+ logging_dir_input_folder.click(
586
+ get_folder_path,
587
+ outputs=logging_dir,
588
+ show_progress=False,
589
+ )
590
+ with gr.Row():
591
+ output_name = gr.Textbox(
592
+ label='Model output name',
593
+ placeholder='Name of the model to output',
594
+ value='last',
595
+ interactive=True,
596
+ )
597
+ train_dir.change(
598
+ remove_doublequote,
599
+ inputs=[train_dir],
600
+ outputs=[train_dir],
601
+ )
602
+ image_folder.change(
603
+ remove_doublequote,
604
+ inputs=[image_folder],
605
+ outputs=[image_folder],
606
+ )
607
+ output_dir.change(
608
+ remove_doublequote,
609
+ inputs=[output_dir],
610
+ outputs=[output_dir],
611
+ )
612
+ with gr.Tab('Dataset preparation'):
613
+ with gr.Row():
614
+ max_resolution = gr.Textbox(
615
+ label='Resolution (width,height)', value='512,512'
616
+ )
617
+ min_bucket_reso = gr.Textbox(
618
+ label='Min bucket resolution', value='256'
619
+ )
620
+ max_bucket_reso = gr.Textbox(
621
+ label='Max bucket resolution', value='1024'
622
+ )
623
+ batch_size = gr.Textbox(label='Batch size', value='1')
624
+ with gr.Row():
625
+ create_caption = gr.Checkbox(
626
+ label='Generate caption metadata', value=True
627
+ )
628
+ create_buckets = gr.Checkbox(
629
+ label='Generate image buckets metadata', value=True
630
+ )
631
+ use_latent_files = gr.Dropdown(
632
+ label='Use latent files',
633
+ choices=[
634
+ 'No',
635
+ 'Yes',
636
+ ],
637
+ value='Yes',
638
+ )
639
+ with gr.Accordion('Advanced parameters', open=False):
640
+ with gr.Row():
641
+ caption_metadata_filename = gr.Textbox(
642
+ label='Caption metadata filename', value='meta_cap.json'
643
+ )
644
+ latent_metadata_filename = gr.Textbox(
645
+ label='Latent metadata filename', value='meta_lat.json'
646
+ )
647
+ full_path = gr.Checkbox(label='Use full path', value=True)
648
+ with gr.Tab('Training parameters'):
649
+ (
650
+ learning_rate,
651
+ lr_scheduler,
652
+ lr_warmup,
653
+ train_batch_size,
654
+ epoch,
655
+ save_every_n_epochs,
656
+ mixed_precision,
657
+ save_precision,
658
+ num_cpu_threads_per_process,
659
+ seed,
660
+ caption_extension,
661
+ cache_latents,
662
+ optimizer,
663
+ optimizer_args,
664
+ ) = gradio_training(learning_rate_value='1e-5')
665
+ with gr.Row():
666
+ dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
667
+ train_text_encoder = gr.Checkbox(
668
+ label='Train text encoder', value=True
669
+ )
670
+ with gr.Accordion('Advanced parameters', open=False):
671
+ with gr.Row():
672
+ gradient_accumulation_steps = gr.Number(
673
+ label='Gradient accumulate steps', value='1'
674
+ )
675
+ (
676
+ # use_8bit_adam,
677
+ xformers,
678
+ full_fp16,
679
+ gradient_checkpointing,
680
+ shuffle_caption,
681
+ color_aug,
682
+ flip_aug,
683
+ clip_skip,
684
+ mem_eff_attn,
685
+ save_state,
686
+ resume,
687
+ max_token_length,
688
+ max_train_epochs,
689
+ max_data_loader_n_workers,
690
+ keep_tokens,
691
+ persistent_data_loader_workers,
692
+ bucket_no_upscale,
693
+ random_crop,
694
+ bucket_reso_steps,
695
+ caption_dropout_every_n_epochs,
696
+ caption_dropout_rate,
697
+ noise_offset,
698
+ additional_parameters,
699
+ vae_batch_size,
700
+ min_snr_gamma,
701
+ ) = gradio_advanced_training()
702
+ color_aug.change(
703
+ color_aug_changed,
704
+ inputs=[color_aug],
705
+ outputs=[cache_latents], # Not applicable to fine_tune.py
706
+ )
707
+
708
+ (
709
+ sample_every_n_steps,
710
+ sample_every_n_epochs,
711
+ sample_sampler,
712
+ sample_prompts,
713
+ ) = sample_gradio_config()
714
+
715
+ button_run = gr.Button('Train model', variant='primary')
716
+
717
+ # Setup gradio tensorboard buttons
718
+ button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
719
+
720
+ button_start_tensorboard.click(
721
+ start_tensorboard,
722
+ inputs=logging_dir,
723
+ )
724
+
725
+ button_stop_tensorboard.click(
726
+ stop_tensorboard,
727
+ show_progress=False,
728
+ )
729
+
730
+ settings_list = [
731
+ pretrained_model_name_or_path,
732
+ v2,
733
+ v_parameterization,
734
+ train_dir,
735
+ image_folder,
736
+ output_dir,
737
+ logging_dir,
738
+ max_resolution,
739
+ min_bucket_reso,
740
+ max_bucket_reso,
741
+ batch_size,
742
+ flip_aug,
743
+ caption_metadata_filename,
744
+ latent_metadata_filename,
745
+ full_path,
746
+ learning_rate,
747
+ lr_scheduler,
748
+ lr_warmup,
749
+ dataset_repeats,
750
+ train_batch_size,
751
+ epoch,
752
+ save_every_n_epochs,
753
+ mixed_precision,
754
+ save_precision,
755
+ seed,
756
+ num_cpu_threads_per_process,
757
+ train_text_encoder,
758
+ create_caption,
759
+ create_buckets,
760
+ save_model_as,
761
+ caption_extension,
762
+ # use_8bit_adam,
763
+ xformers,
764
+ clip_skip,
765
+ save_state,
766
+ resume,
767
+ gradient_checkpointing,
768
+ gradient_accumulation_steps,
769
+ mem_eff_attn,
770
+ shuffle_caption,
771
+ output_name,
772
+ max_token_length,
773
+ max_train_epochs,
774
+ max_data_loader_n_workers,
775
+ full_fp16,
776
+ color_aug,
777
+ model_list,
778
+ cache_latents,
779
+ use_latent_files,
780
+ keep_tokens,
781
+ persistent_data_loader_workers,
782
+ bucket_no_upscale,
783
+ random_crop,
784
+ bucket_reso_steps,
785
+ caption_dropout_every_n_epochs,
786
+ caption_dropout_rate,
787
+ optimizer,
788
+ optimizer_args,
789
+ noise_offset,
790
+ sample_every_n_steps,
791
+ sample_every_n_epochs,
792
+ sample_sampler,
793
+ sample_prompts,
794
+ additional_parameters,
795
+ vae_batch_size,
796
+ min_snr_gamma,
797
+ ]
798
+
799
+ button_run.click(train_model, inputs=settings_list)
800
+
801
+ button_open_config.click(
802
+ open_configuration,
803
+ inputs=[dummy_db_true, config_file_name] + settings_list,
804
+ outputs=[config_file_name] + settings_list,
805
+ show_progress=False,
806
+ )
807
+
808
+ button_load_config.click(
809
+ open_configuration,
810
+ inputs=[dummy_db_false, config_file_name] + settings_list,
811
+ outputs=[config_file_name] + settings_list,
812
+ show_progress=False,
813
+ )
814
+
815
+ button_save_config.click(
816
+ save_configuration,
817
+ inputs=[dummy_db_false, config_file_name] + settings_list,
818
+ outputs=[config_file_name],
819
+ show_progress=False,
820
+ )
821
+
822
+ button_save_as_config.click(
823
+ save_configuration,
824
+ inputs=[dummy_db_true, config_file_name] + settings_list,
825
+ outputs=[config_file_name],
826
+ show_progress=False,
827
+ )
828
+
829
+
830
+ def UI(**kwargs):
831
+
832
+ css = ''
833
+
834
+ if os.path.exists('./style.css'):
835
+ with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
836
+ print('Load CSS...')
837
+ css += file.read() + '\n'
838
+
839
+ interface = gr.Blocks(css=css)
840
+
841
+ with interface:
842
+ with gr.Tab('Finetune'):
843
+ finetune_tab()
844
+ with gr.Tab('Utilities'):
845
+ utilities_tab(enable_dreambooth_tab=False)
846
+
847
+ # Show the interface
848
+ launch_kwargs = {}
849
+ if not kwargs.get('username', None) == '':
850
+ launch_kwargs['auth'] = (
851
+ kwargs.get('username', None),
852
+ kwargs.get('password', None),
853
+ )
854
+ if kwargs.get('server_port', 0) > 0:
855
+ launch_kwargs['server_port'] = kwargs.get('server_port', 0)
856
+ if kwargs.get('inbrowser', False):
857
+ launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
858
+ print(launch_kwargs)
859
+ interface.launch(**launch_kwargs)
860
+
861
+
862
+ if __name__ == '__main__':
863
+ # torch.cuda.set_per_process_memory_fraction(0.48)
864
+ parser = argparse.ArgumentParser()
865
+ parser.add_argument(
866
+ '--username', type=str, default='', help='Username for authentication'
867
+ )
868
+ parser.add_argument(
869
+ '--password', type=str, default='', help='Password for authentication'
870
+ )
871
+ parser.add_argument(
872
+ '--server_port',
873
+ type=int,
874
+ default=0,
875
+ help='Port to run the server listener on',
876
+ )
877
+ parser.add_argument(
878
+ '--inbrowser', action='store_true', help='Open in browser'
879
+ )
880
+
881
+ args = parser.parse_args()
882
+
883
+ UI(
884
+ username=args.username,
885
+ password=args.password,
886
+ inbrowser=args.inbrowser,
887
+ server_port=args.server_port,
888
+ )
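
To make the schedule arithmetic in train_model() concrete, here is a small worked example (not part of this commit) using made-up numbers for the image count and training settings.

import math

image_num = 100          # images found in the training folder
dataset_repeats = 40
train_batch_size = 4
epoch = 2
lr_warmup = 10           # warmup as a percentage of total steps
flip_aug = False

repeats = image_num * dataset_repeats                                   # 4000
max_train_steps = int(math.ceil(repeats / train_batch_size * epoch))    # 2000
if flip_aug:
    # flip augmentation doubles the effective dataset, so the GUI halves the step count
    max_train_steps = int(math.ceil(max_train_steps / 2))
lr_warmup_steps = round(lr_warmup * max_train_steps / 100)              # 200

print(max_train_steps, lr_warmup_steps)
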
gen_img_diffusers.py ADDED
The diff for this file is too large to render. See raw diff
 
gui.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # Activate the virtual environment
4
+ source ./venv/bin/activate
5
+
6
+ # If the requirements are validated, run the kohya_gui.py script with the command-line arguments
7
+ if python tools/validate_requirements.py; then
8
+ python kohya_gui.py "$@"
9
+ fi
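
tools/validate_requirements.py itself is not shown in this commit; the sketch below is only a hypothetical stand-in for the kind of gate gui.sh relies on, namely a script that exits with a non-zero status when a required package is missing so the "if python tools/validate_requirements.py; then" guard skips launching the GUI. The package names are an illustrative subset, not the real requirements list.

import importlib.util
import sys

REQUIRED = ["gradio", "torch", "transformers"]   # illustrative subset only

missing = [name for name in REQUIRED if importlib.util.find_spec(name) is None]
if missing:
    print("missing packages: " + ", ".join(missing) + "; run pip install -r requirements.txt")
    sys.exit(1)
sys.exit(0)
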
kohya_gui.py ADDED
@@ -0,0 +1,110 @@
1
+ import gradio as gr
2
+ import os
3
+ import argparse
4
+ from dreambooth_gui import dreambooth_tab
5
+ from finetune_gui import finetune_tab
6
+ from textual_inversion_gui import ti_tab
7
+ from library.utilities import utilities_tab
8
+ from library.extract_lora_gui import gradio_extract_lora_tab
9
+ from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
10
+ from library.merge_lora_gui import gradio_merge_lora_tab
11
+ from library.resize_lora_gui import gradio_resize_lora_tab
12
+ from lora_gui import lora_tab
13
+
14
+
15
+ def UI(**kwargs):
16
+ css = ''
17
+
18
+ if os.path.exists('./style.css'):
19
+ with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
20
+ print('Load CSS...')
21
+ css += file.read() + '\n'
22
+
23
+ interface = gr.Blocks(css=css, title='Kohya_ss GUI')
24
+
25
+ with interface:
26
+ with gr.Tab('Dreambooth'):
27
+ (
28
+ train_data_dir_input,
29
+ reg_data_dir_input,
30
+ output_dir_input,
31
+ logging_dir_input,
32
+ ) = dreambooth_tab()
33
+ with gr.Tab('Dreambooth LoRA'):
34
+ lora_tab()
35
+ with gr.Tab('Dreambooth TI'):
36
+ ti_tab()
37
+ with gr.Tab('Finetune'):
38
+ finetune_tab()
39
+ with gr.Tab('Utilities'):
40
+ utilities_tab(
41
+ train_data_dir_input=train_data_dir_input,
42
+ reg_data_dir_input=reg_data_dir_input,
43
+ output_dir_input=output_dir_input,
44
+ logging_dir_input=logging_dir_input,
45
+ enable_copy_info_button=True,
46
+ )
47
+ gradio_extract_lora_tab()
48
+ gradio_extract_lycoris_locon_tab()
49
+ gradio_merge_lora_tab()
50
+ gradio_resize_lora_tab()
51
+
52
+ # Show the interface
53
+ launch_kwargs = {}
54
+ username = kwargs.get('username')
55
+ password = kwargs.get('password')
56
+ server_port = kwargs.get('server_port', 0)
57
+ inbrowser = kwargs.get('inbrowser', False)
58
+ share = kwargs.get('share', False)
59
+ server_name = kwargs.get('listen')
60
+
61
+ launch_kwargs['server_name'] = server_name
62
+ if username and password:
63
+ launch_kwargs['auth'] = (username, password)
64
+ if server_port > 0:
65
+ launch_kwargs['server_port'] = server_port
66
+ if inbrowser:
67
+ launch_kwargs['inbrowser'] = inbrowser
68
+ if share:
69
+ launch_kwargs['share'] = share
70
+ interface.launch(**launch_kwargs)
71
+
72
+
73
+ if __name__ == '__main__':
74
+ # torch.cuda.set_per_process_memory_fraction(0.48)
75
+ parser = argparse.ArgumentParser()
76
+ parser.add_argument(
77
+ '--listen',
78
+ type=str,
79
+ default='127.0.0.1',
80
+ help='IP to listen on for connections to Gradio',
81
+ )
82
+ parser.add_argument(
83
+ '--username', type=str, default='', help='Username for authentication'
84
+ )
85
+ parser.add_argument(
86
+ '--password', type=str, default='', help='Password for authentication'
87
+ )
88
+ parser.add_argument(
89
+ '--server_port',
90
+ type=int,
91
+ default=0,
92
+ help='Port to run the server listener on',
93
+ )
94
+ parser.add_argument(
95
+ '--inbrowser', action='store_true', help='Open in browser'
96
+ )
97
+ parser.add_argument(
98
+ '--share', action='store_true', help='Share the gradio UI'
99
+ )
100
+
101
+ args = parser.parse_args()
102
+
103
+ UI(
104
+ username=args.username,
105
+ password=args.password,
106
+ inbrowser=args.inbrowser,
107
+ server_port=args.server_port,
108
+ share=args.share,
109
+ listen=args.listen,
110
+ )
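
For reference, the same options exposed by the command-line flags above can be passed to UI() directly; a minimal sketch (not part of this commit) with placeholder values:

from kohya_gui import UI

UI(
    username="admin",      # together with password, enables Gradio auth
    password="change-me",
    server_port=7860,      # 0 lets Gradio choose a port
    inbrowser=True,
    share=False,           # True requests a public gradio.live link
    listen="127.0.0.1",    # passed to launch() as server_name
)
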
kohya_ss_colab.ipynb ADDED
@@ -0,0 +1,448 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "MvroZ9rJ1iqN"
17
+ },
18
+ "source": [
19
+ "# Kohya SS WebUI Colab Setup\n",
20
+ "\n",
21
+ "This Colab workbook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS is a Python library that provides Stable Diffusion-based models for image, text, and audio generation tasks. This Colab workbook provides a convenient way for users to run Kohya SS without needing to install anything on their local machine.\n",
22
+ "\n",
23
+ "This workbook was inspired by the work of [Spaceginner](https://github.com/Spaceginner)'s original Colab workbook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab workbook was coded by [panguin6010](https://github.com/panguin6010) \n",
24
+ "\n",
25
+ "\n",
26
+ "## Tutorials\n",
27
+ "\n",
28
+ "Before running this code, make sure you are familiar with using Colab workbooks and have a basic understanding of Kohya SS and its usage. You can find tutorials for these online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project.\n"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {
34
+ "id": "DrAnm1um5vjh"
35
+ },
36
+ "source": [
37
+ "\n",
38
+ "\n",
39
+ "\n",
40
+ "---\n",
41
+ "\n"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {
48
+ "colab": {
49
+ "base_uri": "https://localhost:8080/"
50
+ },
51
+ "id": "vmoRnFQEqOeY",
52
+ "outputId": "09876c9a-d043-4881-d92f-6ed54313c390"
53
+ },
54
+ "outputs": [],
55
+ "source": [
56
+ "#@markdown #Step 1: Mounting Google Drive\n",
57
+ "\n",
58
+ "#@markdown The first step in setting up Kohya SS on Colab is to mount your Google Drive to the Colab notebook. This allows you to save and access files from your Google Drive in the Colab notebook.\n",
59
+ "\n",
60
+ "#@markdown To mount your Google Drive, run the following code block, which mounts your Google Drive to the /content/gdrive directory in the Colab notebook.\n",
61
+ "\n",
62
+ "\n",
63
+ "\n",
64
+ "from google.colab import drive\n",
65
+ "drive.mount('/content/gdrive')"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {
71
+ "id": "mvQwnr4354BM"
72
+ },
73
+ "source": [
74
+ "\n",
75
+ "\n",
76
+ "---\n",
77
+ "\n"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "cellView": "form",
85
+ "colab": {
86
+ "base_uri": "https://localhost:8080/",
87
+ "height": 49,
88
+ "referenced_widgets": [
89
+ "7ca7f6f727da46ac9a1149e69c16c81f",
90
+ "77e5e07552b641cf9c368fb3939cb1d1",
91
+ "235e01b92646444387ebd31ab945358e"
92
+ ]
93
+ },
94
+ "id": "jnhm7ycMrLWb",
95
+ "outputId": "63ba39ed-90c6-4b2d-f03e-61775587b083"
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "#@markdown #Kohya SS WebUI Installation\n",
100
+ "\n",
101
+ "#@markdown Now that your Google Drive is linked, we need to install the Kohya SS WebUI.\n",
102
+ "\n",
103
+ "#@markdown The code clones the [Kohya SS Google Colab](\"https://github.com/panguin6010/kohya_ss_google_colab\") repository and creates the necessary directories for Kohya SS to run. It then resets the git repository and pulls the latest changes. Finally, it displays a success message.\n",
104
+ "\n",
105
+ "#@markdown Note: If Google Drive is not connected, the code will use Colab storage instead.\n",
106
+ "\n",
107
+ "#@title\n",
108
+ "# Import necessary libraries\n",
109
+ "from IPython.display import clear_output\n",
110
+ "from IPython.utils import capture\n",
111
+ "from subprocess import getoutput\n",
112
+ "import ipywidgets as widgets\n",
113
+ "import sys\n",
114
+ "import fileinput\n",
115
+ "import os\n",
116
+ "import time\n",
117
+ "\n",
118
+ "# WebUI Installation\n",
119
+ "\n",
120
+ "# Check if Google Drive is connected\n",
121
+ "if not os.path.exists(\"/content/gdrive/MyDrive/\"):\n",
122
+ " print(\"Gdrive not connected, using colab storage ...\")\n",
123
+ " time.sleep(4)\n",
124
+ " !mkdir -p /content/gdrive/MyDrive/\n",
125
+ "\n",
126
+ "# Clone the repository and create necessary directories\n",
127
+ "with capture.capture_output() as cap:\n",
128
+ " def inf(msg, style, wdth):\n",
129
+ " inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth))\n",
130
+ " display(inf)\n",
131
+ "\n",
132
+ " %mkdir -p /content/gdrive/MyDrive/sd\n",
133
+ " %cd /content/gdrive/MyDrive/sd\n",
134
+ " !git clone https://github.com/panguin6010/kohya_ss_google_colab kohya_ss_colab\n",
135
+ " !mkdir -p /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface\n",
136
+ " !ln -s /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface /root/.cache/\n",
137
+ "\n",
138
+ "# Reset the git repository and pull the latest changes\n",
139
+ "with capture.capture_output() as cap:\n",
140
+ " %cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n",
141
+ " !git reset --hard\n",
142
+ " time.sleep(1)\n",
143
+ "\n",
144
+ "print(\"Updating the repository...\")\n",
145
+ "!git pull\n",
146
+ "\n",
147
+ "# Clear the output and display the success message\n",
148
+ "clear_output()\n",
149
+ "inf(\"✓ Done\", \"success\", \"50px\")"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "metadata": {
155
+ "id": "8SrMhmFz7Lt4"
156
+ },
157
+ "source": [
158
+ "---"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {
165
+ "cellView": "form",
166
+ "colab": {
167
+ "base_uri": "https://localhost:8080/",
168
+ "height": 49,
169
+ "referenced_widgets": [
170
+ "54e929bcb37e4997a696d0becdecfd84",
171
+ "43fbca3abb04401296967f819680f94f",
172
+ "6d87b2c916394932b1a53382fe3cdb4e"
173
+ ]
174
+ },
175
+ "id": "yjvkHRlDtDmV",
176
+ "outputId": "06e1e873-b1ed-4403-c9a4-19ac1caa961b"
177
+ },
178
+ "outputs": [],
179
+ "source": [
180
+ "#@markdown #Requirements Installation\n",
181
+ "\n",
182
+ "#@markdown Now that we have downloaded the Kohya SS WebUI, we need to install the necessary requirements.\n",
183
+ "\n",
184
+ "# Print the status message\n",
185
+ "print(\"Installing requirements...\")\n",
186
+ "\n",
187
+ "# Change the working directory to the project folder\n",
188
+ "%cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n",
189
+ "\n",
190
+ "# Install the requirements\n",
191
+ "with capture.capture_output() as cap:\n",
192
+ " # Uncomment the following line if you need to install specific versions of torch and torchvision\n",
193
+ " # !pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116\n",
194
+ " \n",
195
+ " # Install the requirements from the requirements.txt file\n",
196
+ " !pip install -r requirements.txt\n",
197
+ "\n",
198
+ "# Clear the output to keep the notebook clean\n",
199
+ "clear_output()\n",
200
+ "\n",
201
+ "# Print the success message\n",
202
+ "inf(\"✓ Done\", \"success\", \"50px\")"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "markdown",
207
+ "metadata": {
208
+ "id": "FLDvlHm1tYud"
209
+ },
210
+ "source": [
211
+ "\n",
212
+ "---\n",
213
+ "\n"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {
220
+ "colab": {
221
+ "base_uri": "https://localhost:8080/"
222
+ },
223
+ "id": "IzS3hvuTtTqW",
224
+ "outputId": "9e629e1f-c8eb-43a2-9639-2583937ba81a"
225
+ },
226
+ "outputs": [],
227
+ "source": [
228
+ "#@markdown # Start Kohya ss WebUI\n",
229
+ "\n",
230
+ "User = \"\" #@param {type:\"string\"}\n",
231
+ "Password = \"\" #@param {type:\"string\"}\n",
232
+ "\n",
233
+ "#@markdown - Adding a username and password is not necessary but it will improve the security of your Kohya instance.\n",
234
+ "#@markdown ______\n",
235
+ "#@markdown # Please click the link that concludes with ```gradio.live``` to access your instance\n",
236
+ "# Encourage users to contribute improvements\n",
237
+ "print(\"Please feel free to make any changes or improvements you think would enhance this setup. Your input and contributions are greatly appreciated!\")\n",
238
+ "# Check if the user has provided a username and password\n",
239
+ "if User and Password:\n",
240
+ " # Run the Kohya GUI with the provided credentials\n",
241
+ " !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --username $User --password $Password --share \n",
242
+ "else:\n",
243
+ " # Run the Kohya GUI without credentials\n",
244
+ " !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --share \n"
245
+ ]
246
+ }
247
+ ],
248
+ "metadata": {
249
+ "colab": {
250
+ "authorship_tag": "ABX9TyOZmOjfS55zOBmbTmRNOf3b",
251
+ "include_colab_link": true,
252
+ "provenance": []
253
+ },
254
+ "kernelspec": {
255
+ "display_name": "Python 3",
256
+ "name": "python3"
257
+ },
258
+ "language_info": {
259
+ "name": "python"
260
+ },
261
+ "widgets": {
262
+ "application/vnd.jupyter.widget-state+json": {
263
+ "235e01b92646444387ebd31ab945358e": {
264
+ "model_module": "@jupyter-widgets/controls",
265
+ "model_module_version": "1.5.0",
266
+ "model_name": "ButtonStyleModel",
267
+ "state": {
268
+ "_model_module": "@jupyter-widgets/controls",
269
+ "_model_module_version": "1.5.0",
270
+ "_model_name": "ButtonStyleModel",
271
+ "_view_count": null,
272
+ "_view_module": "@jupyter-widgets/base",
273
+ "_view_module_version": "1.2.0",
274
+ "_view_name": "StyleView",
275
+ "button_color": null,
276
+ "font_weight": ""
277
+ }
278
+ },
279
+ "43fbca3abb04401296967f819680f94f": {
280
+ "model_module": "@jupyter-widgets/base",
281
+ "model_module_version": "1.2.0",
282
+ "model_name": "LayoutModel",
283
+ "state": {
284
+ "_model_module": "@jupyter-widgets/base",
285
+ "_model_module_version": "1.2.0",
286
+ "_model_name": "LayoutModel",
287
+ "_view_count": null,
288
+ "_view_module": "@jupyter-widgets/base",
289
+ "_view_module_version": "1.2.0",
290
+ "_view_name": "LayoutView",
291
+ "align_content": null,
292
+ "align_items": null,
293
+ "align_self": null,
294
+ "border": null,
295
+ "bottom": null,
296
+ "display": null,
297
+ "flex": null,
298
+ "flex_flow": null,
299
+ "grid_area": null,
300
+ "grid_auto_columns": null,
301
+ "grid_auto_flow": null,
302
+ "grid_auto_rows": null,
303
+ "grid_column": null,
304
+ "grid_gap": null,
305
+ "grid_row": null,
306
+ "grid_template_areas": null,
307
+ "grid_template_columns": null,
308
+ "grid_template_rows": null,
309
+ "height": null,
310
+ "justify_content": null,
311
+ "justify_items": null,
312
+ "left": null,
313
+ "margin": null,
314
+ "max_height": null,
315
+ "max_width": null,
316
+ "min_height": null,
317
+ "min_width": "50px",
318
+ "object_fit": null,
319
+ "object_position": null,
320
+ "order": null,
321
+ "overflow": null,
322
+ "overflow_x": null,
323
+ "overflow_y": null,
324
+ "padding": null,
325
+ "right": null,
326
+ "top": null,
327
+ "visibility": null,
328
+ "width": null
329
+ }
330
+ },
331
+ "54e929bcb37e4997a696d0becdecfd84": {
332
+ "model_module": "@jupyter-widgets/controls",
333
+ "model_module_version": "1.5.0",
334
+ "model_name": "ButtonModel",
335
+ "state": {
336
+ "_dom_classes": [],
337
+ "_model_module": "@jupyter-widgets/controls",
338
+ "_model_module_version": "1.5.0",
339
+ "_model_name": "ButtonModel",
340
+ "_view_count": null,
341
+ "_view_module": "@jupyter-widgets/controls",
342
+ "_view_module_version": "1.5.0",
343
+ "_view_name": "ButtonView",
344
+ "button_style": "success",
345
+ "description": "✓ Done",
346
+ "disabled": true,
347
+ "icon": "",
348
+ "layout": "IPY_MODEL_43fbca3abb04401296967f819680f94f",
349
+ "style": "IPY_MODEL_6d87b2c916394932b1a53382fe3cdb4e",
350
+ "tooltip": ""
351
+ }
352
+ },
353
+ "6d87b2c916394932b1a53382fe3cdb4e": {
354
+ "model_module": "@jupyter-widgets/controls",
355
+ "model_module_version": "1.5.0",
356
+ "model_name": "ButtonStyleModel",
357
+ "state": {
358
+ "_model_module": "@jupyter-widgets/controls",
359
+ "_model_module_version": "1.5.0",
360
+ "_model_name": "ButtonStyleModel",
361
+ "_view_count": null,
362
+ "_view_module": "@jupyter-widgets/base",
363
+ "_view_module_version": "1.2.0",
364
+ "_view_name": "StyleView",
365
+ "button_color": null,
366
+ "font_weight": ""
367
+ }
368
+ },
369
+ "77e5e07552b641cf9c368fb3939cb1d1": {
370
+ "model_module": "@jupyter-widgets/base",
371
+ "model_module_version": "1.2.0",
372
+ "model_name": "LayoutModel",
373
+ "state": {
374
+ "_model_module": "@jupyter-widgets/base",
375
+ "_model_module_version": "1.2.0",
376
+ "_model_name": "LayoutModel",
377
+ "_view_count": null,
378
+ "_view_module": "@jupyter-widgets/base",
379
+ "_view_module_version": "1.2.0",
380
+ "_view_name": "LayoutView",
381
+ "align_content": null,
382
+ "align_items": null,
383
+ "align_self": null,
384
+ "border": null,
385
+ "bottom": null,
386
+ "display": null,
387
+ "flex": null,
388
+ "flex_flow": null,
389
+ "grid_area": null,
390
+ "grid_auto_columns": null,
391
+ "grid_auto_flow": null,
392
+ "grid_auto_rows": null,
393
+ "grid_column": null,
394
+ "grid_gap": null,
395
+ "grid_row": null,
396
+ "grid_template_areas": null,
397
+ "grid_template_columns": null,
398
+ "grid_template_rows": null,
399
+ "height": null,
400
+ "justify_content": null,
401
+ "justify_items": null,
402
+ "left": null,
403
+ "margin": null,
404
+ "max_height": null,
405
+ "max_width": null,
406
+ "min_height": null,
407
+ "min_width": "50px",
408
+ "object_fit": null,
409
+ "object_position": null,
410
+ "order": null,
411
+ "overflow": null,
412
+ "overflow_x": null,
413
+ "overflow_y": null,
414
+ "padding": null,
415
+ "right": null,
416
+ "top": null,
417
+ "visibility": null,
418
+ "width": null
419
+ }
420
+ },
421
+ "7ca7f6f727da46ac9a1149e69c16c81f": {
422
+ "model_module": "@jupyter-widgets/controls",
423
+ "model_module_version": "1.5.0",
424
+ "model_name": "ButtonModel",
425
+ "state": {
426
+ "_dom_classes": [],
427
+ "_model_module": "@jupyter-widgets/controls",
428
+ "_model_module_version": "1.5.0",
429
+ "_model_name": "ButtonModel",
430
+ "_view_count": null,
431
+ "_view_module": "@jupyter-widgets/controls",
432
+ "_view_module_version": "1.5.0",
433
+ "_view_name": "ButtonView",
434
+ "button_style": "success",
435
+ "description": "✓ Done",
436
+ "disabled": true,
437
+ "icon": "",
438
+ "layout": "IPY_MODEL_77e5e07552b641cf9c368fb3939cb1d1",
439
+ "style": "IPY_MODEL_235e01b92646444387ebd31ab945358e",
440
+ "tooltip": ""
441
+ }
442
+ }
443
+ }
444
+ }
445
+ },
446
+ "nbformat": 4,
447
+ "nbformat_minor": 0
448
+ }
library/__init__.py ADDED
File without changes
library/basic_caption_gui.py ADDED
@@ -0,0 +1,140 @@
+ import gradio as gr
+ from easygui import msgbox
+ import subprocess
+ from .common_gui import get_folder_path, add_pre_postfix, find_replace
+ import os
+
+
+ def caption_images(
+     caption_text,
+     images_dir,
+     overwrite,
+     caption_ext,
+     prefix,
+     postfix,
+     find_text,
+     replace_text,
+ ):
+     # Check for images_dir
+     if not images_dir:
+         msgbox('Image folder is missing...')
+         return
+
+     if not caption_ext:
+         msgbox('Please provide an extension for the caption files.')
+         return
+
+     if caption_text:
+         print(f'Captioning files in {images_dir} with {caption_text}...')
+         run_cmd = f'python "tools/caption.py"'
+         run_cmd += f' --caption_text="{caption_text}"'
+         if overwrite:
+             run_cmd += f' --overwrite'
+         if caption_ext:
+             run_cmd += f' --caption_file_ext="{caption_ext}"'
+         run_cmd += f' "{images_dir}"'
+
+         print(run_cmd)
+
+         # Run the command
+         if os.name == 'posix':
+             os.system(run_cmd)
+         else:
+             subprocess.run(run_cmd)
+
+     if overwrite:
+         if prefix or postfix:
+             # Add prefix and postfix
+             add_pre_postfix(
+                 folder=images_dir,
+                 caption_file_ext=caption_ext,
+                 prefix=prefix,
+                 postfix=postfix,
+             )
+         if find_text:
+             find_replace(
+                 folder_path=images_dir,
+                 caption_file_ext=caption_ext,
+                 search_text=find_text,
+                 replace_text=replace_text,
+             )
+     else:
+         if prefix or postfix:
+             msgbox(
+                 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
+             )
+
+     print('...captioning done')
+
+
+ # Gradio UI
+ def gradio_basic_caption_gui_tab():
+     with gr.Tab('Basic Captioning'):
+         gr.Markdown(
+             'This utility will allow the creation of simple caption files for each image in a folder.'
+         )
+         with gr.Row():
+             images_dir = gr.Textbox(
+                 label='Image folder to caption',
+                 placeholder='Directory containing the images to caption',
+                 interactive=True,
+             )
+             folder_button = gr.Button('📂', elem_id='open_folder_small')
+             folder_button.click(
+                 get_folder_path,
+                 outputs=images_dir,
+                 show_progress=False,
+             )
+             caption_ext = gr.Textbox(
+                 label='Caption file extension',
+                 placeholder='Extension for caption file. eg: .caption, .txt',
+                 value='.txt',
+                 interactive=True,
+             )
+             overwrite = gr.Checkbox(
+                 label='Overwrite existing captions in folder',
+                 interactive=True,
+                 value=False,
+             )
+         with gr.Row():
+             prefix = gr.Textbox(
+                 label='Prefix to add to caption',
+                 placeholder='(Optional)',
+                 interactive=True,
+             )
+             caption_text = gr.Textbox(
+                 label='Caption text',
+                 placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
+                 interactive=True,
+             )
+             postfix = gr.Textbox(
+                 label='Postfix to add to caption',
+                 placeholder='(Optional)',
+                 interactive=True,
+             )
+         with gr.Row():
+             find_text = gr.Textbox(
+                 label='Find text',
+                 placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
+                 interactive=True,
+             )
+             replace_text = gr.Textbox(
+                 label='Replacement text',
+                 placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
+                 interactive=True,
+             )
+         caption_button = gr.Button('Caption images')
+         caption_button.click(
+             caption_images,
+             inputs=[
+                 caption_text,
+                 images_dir,
+                 overwrite,
+                 caption_ext,
+                 prefix,
+                 postfix,
+                 find_text,
+                 replace_text,
+             ],
+             show_progress=False,
+         )
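For reference (not part of the commit), a minimal sketch of the command line that `caption_images` above assembles; the folder path and caption text are made-up examples:

    caption_images(
        caption_text='a photo of sks person',
        images_dir='/data/train/img',   # hypothetical folder
        overwrite=True,
        caption_ext='.txt',
        prefix='',
        postfix='',
        find_text='',
        replace_text='',
    )
    # run_cmd built by the function:
    #   python "tools/caption.py" --caption_text="a photo of sks person" --overwrite --caption_file_ext=".txt" "/data/train/img"
    # With overwrite=True and empty prefix/postfix/find_text, no caption files are modified afterwards.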
library/blip_caption_gui.py ADDED
@@ -0,0 +1,149 @@
+ import gradio as gr
+ from easygui import msgbox
+ import subprocess
+ import os
+ from .common_gui import get_folder_path, add_pre_postfix
+
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
+
+
+ def caption_images(
+     train_data_dir,
+     caption_file_ext,
+     batch_size,
+     num_beams,
+     top_p,
+     max_length,
+     min_length,
+     beam_search,
+     prefix,
+     postfix,
+ ):
+     # Check for caption_text_input
+     # if caption_text_input == "":
+     #     msgbox("Caption text is missing...")
+     #     return
+
+     # Check for images_dir_input
+     if train_data_dir == '':
+         msgbox('Image folder is missing...')
+         return
+
+     if caption_file_ext == '':
+         msgbox('Please provide an extension for the caption files.')
+         return
+
+     print(f'Captioning files in {train_data_dir}...')
+     run_cmd = f'{PYTHON} "finetune/make_captions.py"'
+     run_cmd += f' --batch_size="{int(batch_size)}"'
+     run_cmd += f' --num_beams="{int(num_beams)}"'
+     run_cmd += f' --top_p="{top_p}"'
+     run_cmd += f' --max_length="{int(max_length)}"'
+     run_cmd += f' --min_length="{int(min_length)}"'
+     if beam_search:
+         run_cmd += f' --beam_search'
+     if caption_file_ext != '':
+         run_cmd += f' --caption_extension="{caption_file_ext}"'
+     run_cmd += f' "{train_data_dir}"'
+     run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
+
+     print(run_cmd)
+
+     # Run the command
+     if os.name == 'posix':
+         os.system(run_cmd)
+     else:
+         subprocess.run(run_cmd)
+
+     # Add prefix and postfix
+     add_pre_postfix(
+         folder=train_data_dir,
+         caption_file_ext=caption_file_ext,
+         prefix=prefix,
+         postfix=postfix,
+     )
+
+     print('...captioning done')
+
+
+ ###
+ # Gradio UI
+ ###
+
+
+ def gradio_blip_caption_gui_tab():
+     with gr.Tab('BLIP Captioning'):
+         gr.Markdown(
+             'This utility will use BLIP to caption files for each images in a folder.'
+         )
+         with gr.Row():
+             train_data_dir = gr.Textbox(
+                 label='Image folder to caption',
+                 placeholder='Directory containing the images to caption',
+                 interactive=True,
+             )
+             button_train_data_dir_input = gr.Button(
+                 '📂', elem_id='open_folder_small'
+             )
+             button_train_data_dir_input.click(
+                 get_folder_path,
+                 outputs=train_data_dir,
+                 show_progress=False,
+             )
+         with gr.Row():
+             caption_file_ext = gr.Textbox(
+                 label='Caption file extension',
+                 placeholder='Extention for caption file. eg: .caption, .txt',
+                 value='.txt',
+                 interactive=True,
+             )
+
+             prefix = gr.Textbox(
+                 label='Prefix to add to BLIP caption',
+                 placeholder='(Optional)',
+                 interactive=True,
+             )
+
+             postfix = gr.Textbox(
+                 label='Postfix to add to BLIP caption',
+                 placeholder='(Optional)',
+                 interactive=True,
+             )
+
+             batch_size = gr.Number(
+                 value=1, label='Batch size', interactive=True
+             )
+
+         with gr.Row():
+             beam_search = gr.Checkbox(
+                 label='Use beam search', interactive=True, value=True
+             )
+             num_beams = gr.Number(
+                 value=1, label='Number of beams', interactive=True
+             )
+             top_p = gr.Number(value=0.9, label='Top p', interactive=True)
+             max_length = gr.Number(
+                 value=75, label='Max length', interactive=True
+             )
+             min_length = gr.Number(
+                 value=5, label='Min length', interactive=True
+             )
+
+         caption_button = gr.Button('Caption images')
+
+         caption_button.click(
+             caption_images,
+             inputs=[
+                 train_data_dir,
+                 caption_file_ext,
+                 batch_size,
+                 num_beams,
+                 top_p,
+                 max_length,
+                 min_length,
+                 beam_search,
+                 prefix,
+                 postfix,
+             ],
+             show_progress=False,
+         )
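Not part of the commit: a minimal sketch of how the two captioning tabs above might be mounted in a Gradio Blocks app, assuming the repository root is on PYTHONPATH and the Gradio 3.x environment this project targets:

    import gradio as gr

    from library.basic_caption_gui import gradio_basic_caption_gui_tab
    from library.blip_caption_gui import gradio_blip_caption_gui_tab

    with gr.Blocks() as demo:
        gradio_basic_caption_gui_tab()   # builds the 'Basic Captioning' tab
        gradio_blip_caption_gui_tab()    # builds the 'BLIP Captioning' tab

    demo.launch()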
library/common_gui.py ADDED
@@ -0,0 +1,978 @@
1
+ from tkinter import filedialog, Tk
2
+ from easygui import msgbox
3
+ import os
4
+ import gradio as gr
5
+ import easygui
6
+ import shutil
7
+
8
+ folder_symbol = '\U0001f4c2' # 📂
9
+ refresh_symbol = '\U0001f504' # 🔄
10
+ save_style_symbol = '\U0001f4be' # 💾
11
+ document_symbol = '\U0001F4C4' # 📄
12
+
13
+ # define a list of substrings to search for v2 base models
14
+ V2_BASE_MODELS = [
15
+ 'stabilityai/stable-diffusion-2-1-base',
16
+ 'stabilityai/stable-diffusion-2-base',
17
+ ]
18
+
19
+ # define a list of substrings to search for v_parameterization models
20
+ V_PARAMETERIZATION_MODELS = [
21
+ 'stabilityai/stable-diffusion-2-1',
22
+ 'stabilityai/stable-diffusion-2',
23
+ ]
24
+
25
+ # define a list of substrings to v1.x models
26
+ V1_MODELS = [
27
+ 'CompVis/stable-diffusion-v1-4',
28
+ 'runwayml/stable-diffusion-v1-5',
29
+ ]
30
+
31
+ # define a list of substrings to search for
32
+ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
33
+
34
+ FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
35
+
36
+
37
+ def check_if_model_exist(output_name, output_dir, save_model_as):
38
+ if save_model_as in ['diffusers', 'diffusers_safetendors']:
39
+ ckpt_folder = os.path.join(output_dir, output_name)
40
+ if os.path.isdir(ckpt_folder):
41
+ msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
42
+ if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
43
+ print(
44
+ 'Aborting training due to existing model with same name...'
45
+ )
46
+ return True
47
+ elif save_model_as in ['ckpt', 'safetensors']:
48
+ ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
49
+ if os.path.isfile(ckpt_file):
50
+ msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
51
+ if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
52
+ print(
53
+ 'Aborting training due to existing model with same name...'
54
+ )
55
+ return True
56
+ else:
57
+ print(
58
+ 'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
59
+ )
60
+ return False
61
+
62
+ return False
63
+
64
+
65
+ def update_my_data(my_data):
66
+ # Update the optimizer based on the use_8bit_adam flag
67
+ use_8bit_adam = my_data.get('use_8bit_adam', False)
68
+ my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
69
+
70
+ # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
71
+ model_list = my_data.get('model_list', [])
72
+ pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
73
+ if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
74
+ my_data['model_list'] = 'custom'
75
+
76
+ # Convert epoch and save_every_n_epochs values to int if they are strings
77
+ for key in ['epoch', 'save_every_n_epochs']:
78
+ value = my_data.get(key, -1)
79
+ if isinstance(value, str) and value.isdigit():
80
+ my_data[key] = int(value)
81
+ elif not value:
82
+ my_data[key] = -1
83
+
84
+ # Update LoRA_type if it is set to LoCon
85
+ if my_data.get('LoRA_type', 'Standard') == 'LoCon':
86
+ my_data['LoRA_type'] = 'LyCORIS/LoCon'
87
+
88
+ # Update model save choices due to changes for LoRA and TI training
89
+ if (
90
+ (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
91
+ and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
92
+ ):
93
+ message = (
94
+ 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
95
+ )
96
+ if my_data.get('LoRA_type'):
97
+ print(message.format('LoRA'))
98
+ if my_data.get('num_vectors_per_token'):
99
+ print(message.format('TI'))
100
+ my_data['save_model_as'] = 'safetensors'
101
+
102
+ return my_data
103
+
104
+
105
+ def get_dir_and_file(file_path):
106
+ dir_path, file_name = os.path.split(file_path)
107
+ return (dir_path, file_name)
108
+
109
+
110
+ # def has_ext_files(directory, extension):
111
+ # # Iterate through all the files in the directory
112
+ # for file in os.listdir(directory):
113
+ # # If the file name ends with extension, return True
114
+ # if file.endswith(extension):
115
+ # return True
116
+ # # If no extension files were found, return False
117
+ # return False
118
+
119
+
120
+ def get_file_path(
121
+ file_path='', default_extension='.json', extension_name='Config files'
122
+ ):
123
+ if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
124
+ current_file_path = file_path
125
+ # print(f'current file path: {current_file_path}')
126
+
127
+ initial_dir, initial_file = get_dir_and_file(file_path)
128
+
129
+ # Create a hidden Tkinter root window
130
+ root = Tk()
131
+ root.wm_attributes('-topmost', 1)
132
+ root.withdraw()
133
+
134
+ # Show the open file dialog and get the selected file path
135
+ file_path = filedialog.askopenfilename(
136
+ filetypes=(
137
+ (extension_name, f'*{default_extension}'),
138
+ ('All files', '*.*'),
139
+ ),
140
+ defaultextension=default_extension,
141
+ initialfile=initial_file,
142
+ initialdir=initial_dir,
143
+ )
144
+
145
+ # Destroy the hidden root window
146
+ root.destroy()
147
+
148
+ # If no file is selected, use the current file path
149
+ if not file_path:
150
+ file_path = current_file_path
151
+ current_file_path = file_path
152
+ # print(f'current file path: {current_file_path}')
153
+
154
+ return file_path
155
+
156
+
157
+ def get_any_file_path(file_path=''):
158
+ if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
159
+ current_file_path = file_path
160
+ # print(f'current file path: {current_file_path}')
161
+
162
+ initial_dir, initial_file = get_dir_and_file(file_path)
163
+
164
+ root = Tk()
165
+ root.wm_attributes('-topmost', 1)
166
+ root.withdraw()
167
+ file_path = filedialog.askopenfilename(
168
+ initialdir=initial_dir,
169
+ initialfile=initial_file,
170
+ )
171
+ root.destroy()
172
+
173
+ if file_path == '':
174
+ file_path = current_file_path
175
+
176
+ return file_path
177
+
178
+
179
+ def remove_doublequote(file_path):
180
+ if file_path != None:
181
+ file_path = file_path.replace('"', '')
182
+
183
+ return file_path
184
+
185
+
186
+ # def set_legacy_8bitadam(optimizer, use_8bit_adam):
187
+ # if optimizer == 'AdamW8bit':
188
+ # # use_8bit_adam = True
189
+ # return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
190
+ # value=True, interactive=False, visible=True
191
+ # )
192
+ # else:
193
+ # # use_8bit_adam = False
194
+ # return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
195
+ # value=False, interactive=False, visible=True
196
+ # )
197
+
198
+
199
+ def get_folder_path(folder_path=''):
200
+ if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
201
+ current_folder_path = folder_path
202
+
203
+ initial_dir, initial_file = get_dir_and_file(folder_path)
204
+
205
+ root = Tk()
206
+ root.wm_attributes('-topmost', 1)
207
+ root.withdraw()
208
+ folder_path = filedialog.askdirectory(initialdir=initial_dir)
209
+ root.destroy()
210
+
211
+ if folder_path == '':
212
+ folder_path = current_folder_path
213
+
214
+ return folder_path
215
+
216
+
217
+ def get_saveasfile_path(
218
+ file_path='', defaultextension='.json', extension_name='Config files'
219
+ ):
220
+ if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
221
+ current_file_path = file_path
222
+ # print(f'current file path: {current_file_path}')
223
+
224
+ initial_dir, initial_file = get_dir_and_file(file_path)
225
+
226
+ root = Tk()
227
+ root.wm_attributes('-topmost', 1)
228
+ root.withdraw()
229
+ save_file_path = filedialog.asksaveasfile(
230
+ filetypes=(
231
+ (f'{extension_name}', f'{defaultextension}'),
232
+ ('All files', '*'),
233
+ ),
234
+ defaultextension=defaultextension,
235
+ initialdir=initial_dir,
236
+ initialfile=initial_file,
237
+ )
238
+ root.destroy()
239
+
240
+ # print(save_file_path)
241
+
242
+ if save_file_path == None:
243
+ file_path = current_file_path
244
+ else:
245
+ print(save_file_path.name)
246
+ file_path = save_file_path.name
247
+
248
+ # print(file_path)
249
+
250
+ return file_path
251
+
252
+
253
+ def get_saveasfilename_path(
254
+ file_path='', extensions='*', extension_name='Config files'
255
+ ):
256
+ if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
257
+ current_file_path = file_path
258
+ # print(f'current file path: {current_file_path}')
259
+
260
+ initial_dir, initial_file = get_dir_and_file(file_path)
261
+
262
+ root = Tk()
263
+ root.wm_attributes('-topmost', 1)
264
+ root.withdraw()
265
+ save_file_path = filedialog.asksaveasfilename(
266
+ filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
267
+ defaultextension=extensions,
268
+ initialdir=initial_dir,
269
+ initialfile=initial_file,
270
+ )
271
+ root.destroy()
272
+
273
+ if save_file_path == '':
274
+ file_path = current_file_path
275
+ else:
276
+ # print(save_file_path)
277
+ file_path = save_file_path
278
+
279
+ return file_path
280
+
281
+
282
+ def add_pre_postfix(
283
+ folder: str = '',
284
+ prefix: str = '',
285
+ postfix: str = '',
286
+ caption_file_ext: str = '.caption',
287
+ ) -> None:
288
+ """
289
+ Add prefix and/or postfix to the content of caption files within a folder.
290
+ If no caption files are found, create one with the requested prefix and/or postfix.
291
+
292
+ Args:
293
+ folder (str): Path to the folder containing caption files.
294
+ prefix (str, optional): Prefix to add to the content of the caption files.
295
+ postfix (str, optional): Postfix to add to the content of the caption files.
296
+ caption_file_ext (str, optional): Extension of the caption files.
297
+ """
298
+
299
+ if prefix == '' and postfix == '':
300
+ return
301
+
302
+ image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
303
+ image_files = [
304
+ f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
305
+ ]
306
+
307
+ for image_file in image_files:
308
+ caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
309
+ caption_file_path = os.path.join(folder, caption_file_name)
310
+
311
+ if not os.path.exists(caption_file_path):
312
+ with open(caption_file_path, 'w') as f:
313
+ separator = ' ' if prefix and postfix else ''
314
+ f.write(f'{prefix}{separator}{postfix}')
315
+ else:
316
+ with open(caption_file_path, 'r+') as f:
317
+ content = f.read()
318
+ content = content.rstrip()
319
+ f.seek(0, 0)
320
+
321
+ prefix_separator = ' ' if prefix else ''
322
+ postfix_separator = ' ' if postfix else ''
323
+ f.write(
324
+ f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
325
+ )
326
+
327
+
328
+ def has_ext_files(folder_path: str, file_extension: str) -> bool:
329
+ """
330
+ Check if there are any files with the specified extension in the given folder.
331
+
332
+ Args:
333
+ folder_path (str): Path to the folder containing files.
334
+ file_extension (str): Extension of the files to look for.
335
+
336
+ Returns:
337
+ bool: True if files with the specified extension are found, False otherwise.
338
+ """
339
+ for file in os.listdir(folder_path):
340
+ if file.endswith(file_extension):
341
+ return True
342
+ return False
343
+
344
+
345
+ def find_replace(
346
+ folder_path: str = '',
347
+ caption_file_ext: str = '.caption',
348
+ search_text: str = '',
349
+ replace_text: str = '',
350
+ ) -> None:
351
+ """
352
+ Find and replace text in caption files within a folder.
353
+
354
+ Args:
355
+ folder_path (str, optional): Path to the folder containing caption files.
356
+ caption_file_ext (str, optional): Extension of the caption files.
357
+ search_text (str, optional): Text to search for in the caption files.
358
+ replace_text (str, optional): Text to replace the search text with.
359
+ """
360
+ print('Running caption find/replace')
361
+
362
+ if not has_ext_files(folder_path, caption_file_ext):
363
+ msgbox(
364
+ f'No files with extension {caption_file_ext} were found in {folder_path}...'
365
+ )
366
+ return
367
+
368
+ if search_text == '':
369
+ return
370
+
371
+ caption_files = [
372
+ f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
373
+ ]
374
+
375
+ for caption_file in caption_files:
376
+ with open(
377
+ os.path.join(folder_path, caption_file), 'r', errors='ignore'
378
+ ) as f:
379
+ content = f.read()
380
+
381
+ content = content.replace(search_text, replace_text)
382
+
383
+ with open(os.path.join(folder_path, caption_file), 'w') as f:
384
+ f.write(content)
385
+
386
+
387
+ def color_aug_changed(color_aug):
388
+ if color_aug:
389
+ msgbox(
390
+ 'Disabling "Cache latent" because "Color augmentation" has been selected...'
391
+ )
392
+ return gr.Checkbox.update(value=False, interactive=False)
393
+ else:
394
+ return gr.Checkbox.update(value=True, interactive=True)
395
+
396
+
397
+ def save_inference_file(output_dir, v2, v_parameterization, output_name):
398
+ # List all files in the directory
399
+ files = os.listdir(output_dir)
400
+
401
+ # Iterate over the list of files
402
+ for file in files:
403
+ # Check if the file starts with the value of output_name
404
+ if file.startswith(output_name):
405
+ # Check if it is a file or a directory
406
+ if os.path.isfile(os.path.join(output_dir, file)):
407
+ # Split the file name and extension
408
+ file_name, ext = os.path.splitext(file)
409
+
410
+ # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
411
+ if v2 and v_parameterization:
412
+ print(
413
+ f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
414
+ )
415
+ shutil.copy(
416
+ f'./v2_inference/v2-inference-v.yaml',
417
+ f'{output_dir}/{file_name}.yaml',
418
+ )
419
+ elif v2:
420
+ print(
421
+ f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
422
+ )
423
+ shutil.copy(
424
+ f'./v2_inference/v2-inference.yaml',
425
+ f'{output_dir}/{file_name}.yaml',
426
+ )
427
+
428
+
429
+ def set_pretrained_model_name_or_path_input(
430
+ model_list, pretrained_model_name_or_path, v2, v_parameterization
431
+ ):
432
+ # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
433
+ if str(model_list) in V2_BASE_MODELS:
434
+ print('SD v2 model detected. Setting --v2 parameter')
435
+ v2 = True
436
+ v_parameterization = False
437
+ pretrained_model_name_or_path = str(model_list)
438
+
439
+ # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
440
+ if str(model_list) in V_PARAMETERIZATION_MODELS:
441
+ print(
442
+ 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
443
+ )
444
+ v2 = True
445
+ v_parameterization = True
446
+ pretrained_model_name_or_path = str(model_list)
447
+
448
+ if str(model_list) in V1_MODELS:
449
+ v2 = False
450
+ v_parameterization = False
451
+ pretrained_model_name_or_path = str(model_list)
452
+
453
+ if model_list == 'custom':
454
+ if (
455
+ str(pretrained_model_name_or_path) in V1_MODELS
456
+ or str(pretrained_model_name_or_path) in V2_BASE_MODELS
457
+ or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
458
+ ):
459
+ pretrained_model_name_or_path = ''
460
+ v2 = False
461
+ v_parameterization = False
462
+ return model_list, pretrained_model_name_or_path, v2, v_parameterization
463
+
464
+
465
+ def set_v2_checkbox(model_list, v2, v_parameterization):
466
+ # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
467
+ if str(model_list) in V2_BASE_MODELS:
468
+ v2 = True
469
+ v_parameterization = False
470
+
471
+ # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
472
+ if str(model_list) in V_PARAMETERIZATION_MODELS:
473
+ v2 = True
474
+ v_parameterization = True
475
+
476
+ if str(model_list) in V1_MODELS:
477
+ v2 = False
478
+ v_parameterization = False
479
+
480
+ return v2, v_parameterization
481
+
482
+
483
+ def set_model_list(
484
+ model_list,
485
+ pretrained_model_name_or_path,
486
+ v2,
487
+ v_parameterization,
488
+ ):
489
+
490
+ if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
491
+ model_list = 'custom'
492
+ else:
493
+ model_list = pretrained_model_name_or_path
494
+
495
+ return model_list, v2, v_parameterization
496
+
497
+
498
+ ###
499
+ ### Gradio common GUI section
500
+ ###
501
+
502
+
503
+ def gradio_config():
504
+ with gr.Accordion('Configuration file', open=False):
505
+ with gr.Row():
506
+ button_open_config = gr.Button('Open 📂', elem_id='open_folder')
507
+ button_save_config = gr.Button('Save 💾', elem_id='open_folder')
508
+ button_save_as_config = gr.Button(
509
+ 'Save as... 💾', elem_id='open_folder'
510
+ )
511
+ config_file_name = gr.Textbox(
512
+ label='',
513
+ placeholder="type the configuration file path or use the 'Open' button above to select it...",
514
+ interactive=True,
515
+ )
516
+ button_load_config = gr.Button('Load 💾', elem_id='open_folder')
517
+ config_file_name.change(
518
+ remove_doublequote,
519
+ inputs=[config_file_name],
520
+ outputs=[config_file_name],
521
+ )
522
+ return (
523
+ button_open_config,
524
+ button_save_config,
525
+ button_save_as_config,
526
+ config_file_name,
527
+ button_load_config,
528
+ )
529
+
530
+
531
+ def get_pretrained_model_name_or_path_file(
532
+ model_list, pretrained_model_name_or_path
533
+ ):
534
+ pretrained_model_name_or_path = get_any_file_path(
535
+ pretrained_model_name_or_path
536
+ )
537
+ set_model_list(model_list, pretrained_model_name_or_path)
538
+
539
+
540
+ def gradio_source_model(save_model_as_choices = [
541
+ 'same as source model',
542
+ 'ckpt',
543
+ 'diffusers',
544
+ 'diffusers_safetensors',
545
+ 'safetensors',
546
+ ]):
547
+ with gr.Tab('Source model'):
548
+ # Define the input elements
549
+ with gr.Row():
550
+ pretrained_model_name_or_path = gr.Textbox(
551
+ label='Pretrained model name or path',
552
+ placeholder='enter the path to custom model or name of pretrained model',
553
+ value='runwayml/stable-diffusion-v1-5',
554
+ )
555
+ pretrained_model_name_or_path_file = gr.Button(
556
+ document_symbol, elem_id='open_folder_small'
557
+ )
558
+ pretrained_model_name_or_path_file.click(
559
+ get_any_file_path,
560
+ inputs=pretrained_model_name_or_path,
561
+ outputs=pretrained_model_name_or_path,
562
+ show_progress=False,
563
+ )
564
+ pretrained_model_name_or_path_folder = gr.Button(
565
+ folder_symbol, elem_id='open_folder_small'
566
+ )
567
+ pretrained_model_name_or_path_folder.click(
568
+ get_folder_path,
569
+ inputs=pretrained_model_name_or_path,
570
+ outputs=pretrained_model_name_or_path,
571
+ show_progress=False,
572
+ )
573
+ model_list = gr.Dropdown(
574
+ label='Model Quick Pick',
575
+ choices=[
576
+ 'custom',
577
+ 'stabilityai/stable-diffusion-2-1-base',
578
+ 'stabilityai/stable-diffusion-2-base',
579
+ 'stabilityai/stable-diffusion-2-1',
580
+ 'stabilityai/stable-diffusion-2',
581
+ 'runwayml/stable-diffusion-v1-5',
582
+ 'CompVis/stable-diffusion-v1-4',
583
+ ],
584
+ value='runwayml/stable-diffusion-v1-5',
585
+ )
586
+ save_model_as = gr.Dropdown(
587
+ label='Save trained model as',
588
+ choices=save_model_as_choices,
589
+ value='safetensors',
590
+ )
591
+
592
+ with gr.Row():
593
+ v2 = gr.Checkbox(label='v2', value=False)
594
+ v_parameterization = gr.Checkbox(
595
+ label='v_parameterization', value=False
596
+ )
597
+ v2.change(
598
+ set_v2_checkbox,
599
+ inputs=[model_list, v2, v_parameterization],
600
+ outputs=[v2, v_parameterization],
601
+ show_progress=False,
602
+ )
603
+ v_parameterization.change(
604
+ set_v2_checkbox,
605
+ inputs=[model_list, v2, v_parameterization],
606
+ outputs=[v2, v_parameterization],
607
+ show_progress=False,
608
+ )
609
+ model_list.change(
610
+ set_pretrained_model_name_or_path_input,
611
+ inputs=[
612
+ model_list,
613
+ pretrained_model_name_or_path,
614
+ v2,
615
+ v_parameterization,
616
+ ],
617
+ outputs=[
618
+ model_list,
619
+ pretrained_model_name_or_path,
620
+ v2,
621
+ v_parameterization,
622
+ ],
623
+ show_progress=False,
624
+ )
625
+ # Update the model list and parameters when user click outside the button or field
626
+ pretrained_model_name_or_path.change(
627
+ set_model_list,
628
+ inputs=[
629
+ model_list,
630
+ pretrained_model_name_or_path,
631
+ v2,
632
+ v_parameterization,
633
+ ],
634
+ outputs=[
635
+ model_list,
636
+ v2,
637
+ v_parameterization,
638
+ ],
639
+ show_progress=False,
640
+ )
641
+ return (
642
+ pretrained_model_name_or_path,
643
+ v2,
644
+ v_parameterization,
645
+ save_model_as,
646
+ model_list,
647
+ )
648
+
649
+
650
+ def gradio_training(
651
+ learning_rate_value='1e-6',
652
+ lr_scheduler_value='constant',
653
+ lr_warmup_value='0',
654
+ ):
655
+ with gr.Row():
656
+ train_batch_size = gr.Slider(
657
+ minimum=1,
658
+ maximum=64,
659
+ label='Train batch size',
660
+ value=1,
661
+ step=1,
662
+ )
663
+ epoch = gr.Number(label='Epoch', value=1, precision=0)
664
+ save_every_n_epochs = gr.Number(
665
+ label='Save every N epochs', value=1, precision=0
666
+ )
667
+ caption_extension = gr.Textbox(
668
+ label='Caption Extension',
669
+ placeholder='(Optional) Extension for caption files. default: .caption',
670
+ )
671
+ with gr.Row():
672
+ mixed_precision = gr.Dropdown(
673
+ label='Mixed precision',
674
+ choices=[
675
+ 'no',
676
+ 'fp16',
677
+ 'bf16',
678
+ ],
679
+ value='fp16',
680
+ )
681
+ save_precision = gr.Dropdown(
682
+ label='Save precision',
683
+ choices=[
684
+ 'float',
685
+ 'fp16',
686
+ 'bf16',
687
+ ],
688
+ value='fp16',
689
+ )
690
+ num_cpu_threads_per_process = gr.Slider(
691
+ minimum=1,
692
+ maximum=os.cpu_count(),
693
+ step=1,
694
+ label='Number of CPU threads per core',
695
+ value=2,
696
+ )
697
+ seed = gr.Textbox(label='Seed', placeholder='(Optional) eg:1234')
698
+ cache_latents = gr.Checkbox(label='Cache latent', value=True)
699
+ with gr.Row():
700
+ learning_rate = gr.Textbox(
701
+ label='Learning rate', value=learning_rate_value
702
+ )
703
+ lr_scheduler = gr.Dropdown(
704
+ label='LR Scheduler',
705
+ choices=[
706
+ 'adafactor',
707
+ 'constant',
708
+ 'constant_with_warmup',
709
+ 'cosine',
710
+ 'cosine_with_restarts',
711
+ 'linear',
712
+ 'polynomial',
713
+ ],
714
+ value=lr_scheduler_value,
715
+ )
716
+ lr_warmup = gr.Textbox(
717
+ label='LR warmup (% of steps)', value=lr_warmup_value
718
+ )
719
+ optimizer = gr.Dropdown(
720
+ label='Optimizer',
721
+ choices=[
722
+ 'AdamW',
723
+ 'AdamW8bit',
724
+ 'Adafactor',
725
+ 'DAdaptation',
726
+ 'Lion',
727
+ 'SGDNesterov',
728
+ 'SGDNesterov8bit',
729
+ ],
730
+ value='AdamW8bit',
731
+ interactive=True,
732
+ )
733
+ with gr.Row():
734
+ optimizer_args = gr.Textbox(
735
+ label='Optimizer extra arguments',
736
+ placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
737
+ )
738
+ return (
739
+ learning_rate,
740
+ lr_scheduler,
741
+ lr_warmup,
742
+ train_batch_size,
743
+ epoch,
744
+ save_every_n_epochs,
745
+ mixed_precision,
746
+ save_precision,
747
+ num_cpu_threads_per_process,
748
+ seed,
749
+ caption_extension,
750
+ cache_latents,
751
+ optimizer,
752
+ optimizer_args,
753
+ )
754
+
755
+
756
+ def run_cmd_training(**kwargs):
757
+ options = [
758
+ f' --learning_rate="{kwargs.get("learning_rate", "")}"'
759
+ if kwargs.get('learning_rate')
760
+ else '',
761
+ f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
762
+ if kwargs.get('lr_scheduler')
763
+ else '',
764
+ f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
765
+ if kwargs.get('lr_warmup_steps')
766
+ else '',
767
+ f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
768
+ if kwargs.get('train_batch_size')
769
+ else '',
770
+ f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
771
+ if kwargs.get('max_train_steps')
772
+ else '',
773
+ f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
774
+ if int(kwargs.get('save_every_n_epochs'))
775
+ else '',
776
+ f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
777
+ if kwargs.get('mixed_precision')
778
+ else '',
779
+ f' --save_precision="{kwargs.get("save_precision", "")}"'
780
+ if kwargs.get('save_precision')
781
+ else '',
782
+ f' --seed="{kwargs.get("seed", "")}"'
783
+ if kwargs.get('seed') != ''
784
+ else '',
785
+ f' --caption_extension="{kwargs.get("caption_extension", "")}"'
786
+ if kwargs.get('caption_extension')
787
+ else '',
788
+ ' --cache_latents' if kwargs.get('cache_latents') else '',
789
+ # ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
790
+ f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
791
+ f' --optimizer_args {kwargs.get("optimizer_args", "")}'
792
+ if not kwargs.get('optimizer_args') == ''
793
+ else '',
794
+ ]
795
+ run_cmd = ''.join(options)
796
+ return run_cmd
797
+
798
+
799
+ def gradio_advanced_training():
800
+ with gr.Row():
801
+ additional_parameters = gr.Textbox(
802
+ label='Additional parameters',
803
+ placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
804
+ )
805
+ with gr.Row():
806
+ keep_tokens = gr.Slider(
807
+ label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
808
+ )
809
+ clip_skip = gr.Slider(
810
+ label='Clip skip', value='1', minimum=1, maximum=12, step=1
811
+ )
812
+ max_token_length = gr.Dropdown(
813
+ label='Max Token Length',
814
+ choices=[
815
+ '75',
816
+ '150',
817
+ '225',
818
+ ],
819
+ value='75',
820
+ )
821
+ full_fp16 = gr.Checkbox(
822
+ label='Full fp16 training (experimental)', value=False
823
+ )
824
+ with gr.Row():
825
+ gradient_checkpointing = gr.Checkbox(
826
+ label='Gradient checkpointing', value=False
827
+ )
828
+ shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
829
+ persistent_data_loader_workers = gr.Checkbox(
830
+ label='Persistent data loader', value=False
831
+ )
832
+ mem_eff_attn = gr.Checkbox(
833
+ label='Memory efficient attention', value=False
834
+ )
835
+ with gr.Row():
836
+ # This use_8bit_adam element should be removed in a future release as it is no longer used
837
+ # use_8bit_adam = gr.Checkbox(
838
+ # label='Use 8bit adam', value=False, visible=False
839
+ # )
840
+ xformers = gr.Checkbox(label='Use xformers', value=True)
841
+ color_aug = gr.Checkbox(label='Color augmentation', value=False)
842
+ flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
843
+ min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
844
+ with gr.Row():
845
+ bucket_no_upscale = gr.Checkbox(
846
+ label="Don't upscale bucket resolution", value=True
847
+ )
848
+ bucket_reso_steps = gr.Number(
849
+ label='Bucket resolution steps', value=64
850
+ )
851
+ random_crop = gr.Checkbox(
852
+ label='Random crop instead of center crop', value=False
853
+ )
854
+ noise_offset = gr.Textbox(
855
+ label='Noise offset (0 - 1)', placeholder='(Oprional) eg: 0.1'
856
+ )
857
+
858
+ with gr.Row():
859
+ caption_dropout_every_n_epochs = gr.Number(
860
+ label='Dropout caption every n epochs', value=0
861
+ )
862
+ caption_dropout_rate = gr.Slider(
863
+ label='Rate of caption dropout', value=0, minimum=0, maximum=1
864
+ )
865
+ vae_batch_size = gr.Slider(
866
+ label='VAE batch size',
867
+ minimum=0,
868
+ maximum=32,
869
+ value=0,
870
+ step=1
871
+ )
872
+ with gr.Row():
873
+ save_state = gr.Checkbox(label='Save training state', value=False)
874
+ resume = gr.Textbox(
875
+ label='Resume from saved training state',
876
+ placeholder='path to "last-state" state folder to resume from',
877
+ )
878
+ resume_button = gr.Button('📂', elem_id='open_folder_small')
879
+ resume_button.click(
880
+ get_folder_path,
881
+ outputs=resume,
882
+ show_progress=False,
883
+ )
884
+ max_train_epochs = gr.Textbox(
885
+ label='Max train epoch',
886
+ placeholder='(Optional) Override number of epoch',
887
+ )
888
+ max_data_loader_n_workers = gr.Textbox(
889
+ label='Max num workers for DataLoader',
890
+ placeholder='(Optional) Override number of epoch. Default: 8',
891
+ value="0",
892
+ )
893
+ return (
894
+ # use_8bit_adam,
895
+ xformers,
896
+ full_fp16,
897
+ gradient_checkpointing,
898
+ shuffle_caption,
899
+ color_aug,
900
+ flip_aug,
901
+ clip_skip,
902
+ mem_eff_attn,
903
+ save_state,
904
+ resume,
905
+ max_token_length,
906
+ max_train_epochs,
907
+ max_data_loader_n_workers,
908
+ keep_tokens,
909
+ persistent_data_loader_workers,
910
+ bucket_no_upscale,
911
+ random_crop,
912
+ bucket_reso_steps,
913
+ caption_dropout_every_n_epochs,
914
+ caption_dropout_rate,
915
+ noise_offset,
916
+ additional_parameters,
917
+ vae_batch_size,
918
+ min_snr_gamma,
919
+ )
920
+
921
+
922
+ def run_cmd_advanced_training(**kwargs):
923
+ options = [
924
+ f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
925
+ if kwargs.get('max_train_epochs')
926
+ else '',
927
+ f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
928
+ if kwargs.get('max_data_loader_n_workers')
929
+ else '',
930
+ f' --max_token_length={kwargs.get("max_token_length", "")}'
931
+ if int(kwargs.get('max_token_length', 75)) > 75
932
+ else '',
933
+ f' --clip_skip={kwargs.get("clip_skip", "")}'
934
+ if int(kwargs.get('clip_skip', 1)) > 1
935
+ else '',
936
+ f' --resume="{kwargs.get("resume", "")}"'
937
+ if kwargs.get('resume')
938
+ else '',
939
+ f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
940
+ if int(kwargs.get('keep_tokens', 0)) > 0
941
+ else '',
942
+ f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
943
+ if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
944
+ else '',
945
+ f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
946
+ if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
947
+ else '',
948
+ f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"'
949
+ if int(kwargs.get('vae_batch_size', 0)) > 0
950
+ else '',
951
+ f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
952
+ if int(kwargs.get('bucket_reso_steps', 64)) >= 1
953
+ else '',
954
+ f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}'
955
+ if int(kwargs.get('min_snr_gamma', 0)) >= 1
956
+ else '',
957
+ ' --save_state' if kwargs.get('save_state') else '',
958
+ ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
959
+ ' --color_aug' if kwargs.get('color_aug') else '',
960
+ ' --flip_aug' if kwargs.get('flip_aug') else '',
961
+ ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
962
+ ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing')
963
+ else '',
964
+ ' --full_fp16' if kwargs.get('full_fp16') else '',
965
+ ' --xformers' if kwargs.get('xformers') else '',
966
+ # ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
967
+ ' --persistent_data_loader_workers'
968
+ if kwargs.get('persistent_data_loader_workers')
969
+ else '',
970
+ ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
971
+ ' --random_crop' if kwargs.get('random_crop') else '',
972
+ f' --noise_offset={float(kwargs.get("noise_offset", 0))}'
973
+ if not kwargs.get('noise_offset', '') == ''
974
+ else '',
975
+ f' {kwargs.get("additional_parameters", "")}',
976
+ ]
977
+ run_cmd = ''.join(options)
978
+ return run_cmd
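Not part of the commit: an illustrative call to `run_cmd_training` from common_gui.py above, assuming the repository root is on PYTHONPATH and the project's dependencies (gradio, easygui, tkinter) are installed; the values are made-up examples:

    from library.common_gui import run_cmd_training

    # Builds the argument fragment that the GUI appends to the training command.
    args = run_cmd_training(
        learning_rate='1e-5',
        lr_scheduler='cosine',
        train_batch_size=2,
        save_every_n_epochs=1,
        mixed_precision='fp16',
        save_precision='fp16',
        seed='1234',
        cache_latents=True,
        optimizer='AdamW8bit',
        optimizer_args='',
    )
    # args == (' --learning_rate="1e-5" --lr_scheduler="cosine" --train_batch_size="2"'
    #          ' --save_every_n_epochs="1" --mixed_precision="fp16" --save_precision="fp16"'
    #          ' --seed="1234" --cache_latents --optimizer_type="AdamW8bit"')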
library/config_util.py ADDED
@@ -0,0 +1,536 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+ # from toolz import curry
12
+ from typing import (
13
+ List,
14
+ Optional,
15
+ Sequence,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import toml
21
+ import voluptuous
22
+ from voluptuous import (
23
+ Any,
24
+ ExactSequence,
25
+ MultipleInvalid,
26
+ Object,
27
+ Required,
28
+ Schema,
29
+ )
30
+ from transformers import CLIPTokenizer
31
+
32
+ from . import train_util
33
+ from .train_util import (
34
+ DreamBoothSubset,
35
+ FineTuningSubset,
36
+ DreamBoothDataset,
37
+ FineTuningDataset,
38
+ DatasetGroup,
39
+ )
40
+
41
+
42
+ def add_config_arguments(parser: argparse.ArgumentParser):
43
+ parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
44
+
45
+ # TODO: inherit Params class in Subset, Dataset
46
+
47
+ @dataclass
48
+ class BaseSubsetParams:
49
+ image_dir: Optional[str] = None
50
+ num_repeats: int = 1
51
+ shuffle_caption: bool = False
52
+ keep_tokens: int = 0
53
+ color_aug: bool = False
54
+ flip_aug: bool = False
55
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
56
+ random_crop: bool = False
57
+ caption_dropout_rate: float = 0.0
58
+ caption_dropout_every_n_epochs: int = 0
59
+ caption_tag_dropout_rate: float = 0.0
60
+ token_warmup_min: int = 1
61
+ token_warmup_step: float = 0
62
+
63
+ @dataclass
64
+ class DreamBoothSubsetParams(BaseSubsetParams):
65
+ is_reg: bool = False
66
+ class_tokens: Optional[str] = None
67
+ caption_extension: str = ".caption"
68
+
69
+ @dataclass
70
+ class FineTuningSubsetParams(BaseSubsetParams):
71
+ metadata_file: Optional[str] = None
72
+
73
+ @dataclass
74
+ class BaseDatasetParams:
75
+ tokenizer: CLIPTokenizer = None
76
+ max_token_length: int = None
77
+ resolution: Optional[Tuple[int, int]] = None
78
+ debug_dataset: bool = False
79
+
80
+ @dataclass
81
+ class DreamBoothDatasetParams(BaseDatasetParams):
82
+ batch_size: int = 1
83
+ enable_bucket: bool = False
84
+ min_bucket_reso: int = 256
85
+ max_bucket_reso: int = 1024
86
+ bucket_reso_steps: int = 64
87
+ bucket_no_upscale: bool = False
88
+ prior_loss_weight: float = 1.0
89
+
90
+ @dataclass
91
+ class FineTuningDatasetParams(BaseDatasetParams):
92
+ batch_size: int = 1
93
+ enable_bucket: bool = False
94
+ min_bucket_reso: int = 256
95
+ max_bucket_reso: int = 1024
96
+ bucket_reso_steps: int = 64
97
+ bucket_no_upscale: bool = False
98
+
99
+ @dataclass
100
+ class SubsetBlueprint:
101
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
102
+
103
+ @dataclass
104
+ class DatasetBlueprint:
105
+ is_dreambooth: bool
106
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
107
+ subsets: Sequence[SubsetBlueprint]
108
+
109
+ @dataclass
110
+ class DatasetGroupBlueprint:
111
+ datasets: Sequence[DatasetBlueprint]
112
+ @dataclass
113
+ class Blueprint:
114
+ dataset_group: DatasetGroupBlueprint
115
+
116
+
117
+ class ConfigSanitizer:
118
+ # @curry
119
+ @staticmethod
120
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
121
+ Schema(ExactSequence([klass, klass]))(value)
122
+ return tuple(value)
123
+
124
+ # @curry
125
+ @staticmethod
126
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
127
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
128
+ try:
129
+ Schema(klass)(value)
130
+ return (value, value)
131
+ except:
132
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
133
+
134
+ # subset schema
135
+ SUBSET_ASCENDABLE_SCHEMA = {
136
+ "color_aug": bool,
137
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
138
+ "flip_aug": bool,
139
+ "num_repeats": int,
140
+ "random_crop": bool,
141
+ "shuffle_caption": bool,
142
+ "keep_tokens": int,
143
+ "token_warmup_min": int,
144
+ "token_warmup_step": Any(float,int),
145
+ }
146
+ # DO means DropOut
147
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
148
+ "caption_dropout_every_n_epochs": int,
149
+ "caption_dropout_rate": Any(float, int),
150
+ "caption_tag_dropout_rate": Any(float, int),
151
+ }
152
+ # DB means DreamBooth
153
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
154
+ "caption_extension": str,
155
+ "class_tokens": str,
156
+ }
157
+ DB_SUBSET_DISTINCT_SCHEMA = {
158
+ Required("image_dir"): str,
159
+ "is_reg": bool,
160
+ }
161
+ # FT means FineTuning
162
+ FT_SUBSET_DISTINCT_SCHEMA = {
163
+ Required("metadata_file"): str,
164
+ "image_dir": str,
165
+ }
166
+
167
+ # datasets schema
168
+ DATASET_ASCENDABLE_SCHEMA = {
169
+ "batch_size": int,
170
+ "bucket_no_upscale": bool,
171
+ "bucket_reso_steps": int,
172
+ "enable_bucket": bool,
173
+ "max_bucket_reso": int,
174
+ "min_bucket_reso": int,
175
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
176
+ }
177
+
178
+ # options handled by argparse but not handled by user config
179
+ ARGPARSE_SPECIFIC_SCHEMA = {
180
+ "debug_dataset": bool,
181
+ "max_token_length": Any(None, int),
182
+ "prior_loss_weight": Any(float, int),
183
+ }
184
+ # for handling default None value of argparse
185
+ ARGPARSE_NULLABLE_OPTNAMES = [
186
+ "face_crop_aug_range",
187
+ "resolution",
188
+ ]
189
+ # prepare map because option name may differ among argparse and user config
190
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
191
+ "train_batch_size": "batch_size",
192
+ "dataset_repeats": "num_repeats",
193
+ }
194
+
195
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
196
+ assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
197
+
198
+ self.db_subset_schema = self.__merge_dict(
199
+ self.SUBSET_ASCENDABLE_SCHEMA,
200
+ self.DB_SUBSET_DISTINCT_SCHEMA,
201
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
202
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
203
+ )
204
+
205
+ self.ft_subset_schema = self.__merge_dict(
206
+ self.SUBSET_ASCENDABLE_SCHEMA,
207
+ self.FT_SUBSET_DISTINCT_SCHEMA,
208
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
209
+ )
210
+
211
+ self.db_dataset_schema = self.__merge_dict(
212
+ self.DATASET_ASCENDABLE_SCHEMA,
213
+ self.SUBSET_ASCENDABLE_SCHEMA,
214
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
215
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
216
+ {"subsets": [self.db_subset_schema]},
217
+ )
218
+
219
+ self.ft_dataset_schema = self.__merge_dict(
220
+ self.DATASET_ASCENDABLE_SCHEMA,
221
+ self.SUBSET_ASCENDABLE_SCHEMA,
222
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
223
+ {"subsets": [self.ft_subset_schema]},
224
+ )
225
+
226
+ if support_dreambooth and support_finetuning:
227
+ def validate_flex_dataset(dataset_config: dict):
228
+ subsets_config = dataset_config.get("subsets", [])
229
+
230
+ # check dataset meets FT style
231
+ # NOTE: all FT subsets should have "metadata_file"
232
+ if all(["metadata_file" in subset for subset in subsets_config]):
233
+ return Schema(self.ft_dataset_schema)(dataset_config)
234
+ # check dataset meets DB style
235
+ # NOTE: all DB subsets should have no "metadata_file"
236
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
237
+ return Schema(self.db_dataset_schema)(dataset_config)
238
+ else:
239
+ raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
240
+
241
+ self.dataset_schema = validate_flex_dataset
242
+ elif support_dreambooth:
243
+ self.dataset_schema = self.db_dataset_schema
244
+ else:
245
+ self.dataset_schema = self.ft_dataset_schema
246
+
247
+ self.general_schema = self.__merge_dict(
248
+ self.DATASET_ASCENDABLE_SCHEMA,
249
+ self.SUBSET_ASCENDABLE_SCHEMA,
250
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
251
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
252
+ )
253
+
254
+ self.user_config_validator = Schema({
255
+ "general": self.general_schema,
256
+ "datasets": [self.dataset_schema],
257
+ })
258
+
259
+ self.argparse_schema = self.__merge_dict(
260
+ self.general_schema,
261
+ self.ARGPARSE_SPECIFIC_SCHEMA,
262
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
263
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
264
+ )
265
+
266
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
267
+
268
+ def sanitize_user_config(self, user_config: dict) -> dict:
269
+ try:
270
+ return self.user_config_validator(user_config)
271
+ except MultipleInvalid:
272
+ # TODO: エラー発生時のメッセージをわかりやすくする
273
+ print("Invalid user config / ユーザ設定の形式が正しくないようです")
274
+ raise
275
+
276
+ # NOTE: In nature, argument parser result is not needed to be sanitize
277
+ # However this will help us to detect program bug
278
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
279
+ try:
280
+ return self.argparse_config_validator(argparse_namespace)
281
+ except MultipleInvalid:
282
+ # XXX: this should be a bug
283
+ print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
284
+ raise
285
+
286
+ # NOTE: value would be overwritten by latter dict if there is already the same key
287
+ @staticmethod
288
+ def __merge_dict(*dict_list: dict) -> dict:
289
+ merged = {}
290
+ for schema in dict_list:
291
+ # merged |= schema
292
+ for k, v in schema.items():
293
+ merged[k] = v
294
+ return merged
295
+
296
+
297
+ class BlueprintGenerator:
298
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
299
+ }
300
+
301
+ def __init__(self, sanitizer: ConfigSanitizer):
302
+ self.sanitizer = sanitizer
303
+
304
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
305
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
306
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
307
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
308
+
309
+ # convert argparse namespace to dict like config
310
+ # NOTE: it is ok to have extra entries in dict
311
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
312
+ argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
313
+
314
+ general_config = sanitized_user_config.get("general", {})
315
+
316
+ dataset_blueprints = []
317
+ for dataset_config in sanitized_user_config.get("datasets", []):
318
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
319
+ subsets = dataset_config.get("subsets", [])
320
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
321
+ if is_dreambooth:
322
+ subset_params_klass = DreamBoothSubsetParams
323
+ dataset_params_klass = DreamBoothDatasetParams
324
+ else:
325
+ subset_params_klass = FineTuningSubsetParams
326
+ dataset_params_klass = FineTuningDatasetParams
327
+
328
+ subset_blueprints = []
329
+ for subset_config in subsets:
330
+ params = self.generate_params_by_fallbacks(subset_params_klass,
331
+ [subset_config, dataset_config, general_config, argparse_config, runtime_params])
332
+ subset_blueprints.append(SubsetBlueprint(params))
333
+
334
+ params = self.generate_params_by_fallbacks(dataset_params_klass,
335
+ [dataset_config, general_config, argparse_config, runtime_params])
336
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
337
+
338
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
339
+
340
+ return Blueprint(dataset_group_blueprint)
341
+
342
+ @staticmethod
343
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
344
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
345
+ search_value = BlueprintGenerator.search_value
346
+ default_params = asdict(param_klass())
347
+ param_names = default_params.keys()
348
+
349
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
350
+
351
+ return param_klass(**params)
352
+
353
+ @staticmethod
354
+ def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
355
+ for cand in fallbacks:
356
+ value = cand.get(key)
357
+ if value is not None:
358
+ return value
359
+
360
+ return default_value
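The fallback behaviour above can be summarised with a small illustrative sketch (not part of the diff; the dict contents are made up): the first dict that carries a non-None value for the key wins, otherwise the supplied default is returned.

fallbacks = [
    {"num_repeats": None},   # e.g. subset_config: option not set
    {"num_repeats": 10},     # e.g. dataset_config: first non-None value wins
    {"num_repeats": 1},      # e.g. general_config: never reached for this key
]
assert BlueprintGenerator.search_value("num_repeats", fallbacks, default_value=1) == 10
assert BlueprintGenerator.search_value("keep_tokens", fallbacks, default_value=0) == 0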
361
+
362
+
363
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
364
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
365
+
366
+ for dataset_blueprint in dataset_group_blueprint.datasets:
367
+ if dataset_blueprint.is_dreambooth:
368
+ subset_klass = DreamBoothSubset
369
+ dataset_klass = DreamBoothDataset
370
+ else:
371
+ subset_klass = FineTuningSubset
372
+ dataset_klass = FineTuningDataset
373
+
374
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
375
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
376
+ datasets.append(dataset)
377
+
378
+ # print info
379
+ info = ""
380
+ for i, dataset in enumerate(datasets):
381
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
382
+ info += dedent(f"""\
383
+ [Dataset {i}]
384
+ batch_size: {dataset.batch_size}
385
+ resolution: {(dataset.width, dataset.height)}
386
+ enable_bucket: {dataset.enable_bucket}
387
+ """)
388
+
389
+ if dataset.enable_bucket:
390
+ info += indent(dedent(f"""\
391
+ min_bucket_reso: {dataset.min_bucket_reso}
392
+ max_bucket_reso: {dataset.max_bucket_reso}
393
+ bucket_reso_steps: {dataset.bucket_reso_steps}
394
+ bucket_no_upscale: {dataset.bucket_no_upscale}
395
+ \n"""), " ")
396
+ else:
397
+ info += "\n"
398
+
399
+ for j, subset in enumerate(dataset.subsets):
400
+ info += indent(dedent(f"""\
401
+ [Subset {j} of Dataset {i}]
402
+ image_dir: "{subset.image_dir}"
403
+ image_count: {subset.img_count}
404
+ num_repeats: {subset.num_repeats}
405
+ shuffle_caption: {subset.shuffle_caption}
406
+ keep_tokens: {subset.keep_tokens}
407
+ caption_dropout_rate: {subset.caption_dropout_rate}
408
+ caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
409
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
410
+ color_aug: {subset.color_aug}
411
+ flip_aug: {subset.flip_aug}
412
+ face_crop_aug_range: {subset.face_crop_aug_range}
413
+ random_crop: {subset.random_crop}
414
+ token_warmup_min: {subset.token_warmup_min}
+ token_warmup_step: {subset.token_warmup_step}
416
+ """), " ")
417
+
418
+ if is_dreambooth:
419
+ info += indent(dedent(f"""\
420
+ is_reg: {subset.is_reg}
421
+ class_tokens: {subset.class_tokens}
422
+ caption_extension: {subset.caption_extension}
423
+ \n"""), " ")
424
+ else:
425
+ info += indent(dedent(f"""\
426
+ metadata_file: {subset.metadata_file}
427
+ \n"""), " ")
428
+
429
+ print(info)
430
+
431
+ # make buckets first because it determines the length of dataset
432
+ # and set the same seed for all datasets
433
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
434
+ for i, dataset in enumerate(datasets):
435
+ print(f"[Dataset {i}]")
436
+ dataset.make_buckets()
437
+ dataset.set_seed(seed)
438
+
439
+ return DatasetGroup(datasets)
440
+
441
+
442
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
443
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
444
+ tokens = name.split('_')
445
+ try:
446
+ n_repeats = int(tokens[0])
447
+ except ValueError:
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
449
+ return 0, ""
450
+ caption_by_folder = '_'.join(tokens[1:])
451
+ return n_repeats, caption_by_folder
452
+
453
+ def generate(base_dir: Optional[str], is_reg: bool):
454
+ if base_dir is None:
455
+ return []
456
+
457
+ base_dir: Path = Path(base_dir)
458
+ if not base_dir.is_dir():
459
+ return []
460
+
461
+ subsets_config = []
462
+ for subdir in base_dir.iterdir():
463
+ if not subdir.is_dir():
464
+ continue
465
+
466
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
467
+ if num_repeats < 1:
468
+ continue
469
+
470
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
471
+ subsets_config.append(subset_config)
472
+
473
+ return subsets_config
474
+
475
+ subsets_config = []
476
+ subsets_config += generate(train_data_dir, False)
477
+ subsets_config += generate(reg_data_dir, True)
478
+
479
+ return subsets_config
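As an illustration of the directory convention handled above (hypothetical paths, not part of the diff): sub-directories following the `<repeats>_<class tokens>` naming become DreamBooth subset configs, and sub-directories without a leading repeat count are skipped.

# train_data_dir/
#   20_sks dog/    -> num_repeats=20, class_tokens="sks dog"
#   misc/          -> ignored: no leading repeat count
subsets = generate_dreambooth_subsets_config_by_subdirs(train_data_dir="train_data_dir", reg_data_dir=None)
# subsets == [{"image_dir": "train_data_dir/20_sks dog", "num_repeats": 20,
#              "is_reg": False, "class_tokens": "sks dog"}]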
480
+
481
+
482
+ def load_user_config(file: str) -> dict:
483
+ file: Path = Path(file)
484
+ if not file.is_file():
485
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
486
+
487
+ if file.name.lower().endswith('.json'):
488
+ try:
489
+ config = json.loads(file.read_text(encoding='utf-8'))
490
+ except Exception:
491
+ print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
492
+ raise
493
+ elif file.name.lower().endswith('.toml'):
494
+ try:
495
+ config = toml.load(file)
496
+ except Exception:
497
+ print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
498
+ raise
499
+ else:
500
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
501
+
502
+ return config
503
+
504
+ # for config test
505
+ if __name__ == "__main__":
506
+ parser = argparse.ArgumentParser()
507
+ parser.add_argument("--support_dreambooth", action="store_true")
508
+ parser.add_argument("--support_finetuning", action="store_true")
509
+ parser.add_argument("--support_dropout", action="store_true")
510
+ parser.add_argument("dataset_config")
511
+ config_args, remain = parser.parse_known_args()
512
+
513
+ parser = argparse.ArgumentParser()
514
+ train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
515
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
516
+ argparse_namespace = parser.parse_args(remain)
517
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
518
+
519
+ print("[argparse_namespace]")
520
+ print(vars(argparse_namespace))
521
+
522
+ user_config = load_user_config(config_args.dataset_config)
523
+
524
+ print("\n[user_config]")
525
+ print(user_config)
526
+
527
+ sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
528
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
529
+
530
+ print("\n[sanitized_user_config]")
531
+ print(sanitized_user_config)
532
+
533
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
534
+
535
+ print("\n[blueprint]")
536
+ print(blueprint)
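For reference, a minimal user config of the kind this test driver loads might look like the sketch below. It is only a sketch: the option names used here are the ones appearing in this file, the authoritative set of accepted options is defined by the schemas earlier in the file (and documented in config_README-ja.md), and the path is a placeholder.

import toml

user_config = toml.loads('''
[general]

[[datasets]]

[[datasets.subsets]]
image_dir = "path/to/img/10_sks dog"
num_repeats = 10
class_tokens = "sks dog"
''')

# With DreamBooth support enabled this should pass the sanitizer:
sanitized = ConfigSanitizer(True, False, False).sanitize_user_config(user_config)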
library/convert_model_gui.py ADDED
@@ -0,0 +1,247 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ import shutil
6
+ from .common_gui import get_folder_path, get_file_path
7
+
8
+ folder_symbol = '\U0001f4c2' # 📂
9
+ refresh_symbol = '\U0001f504' # 🔄
10
+ save_style_symbol = '\U0001f4be' # 💾
11
+ document_symbol = '\U0001F4C4' # 📄
12
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
13
+
14
+
15
+ def convert_model(
16
+ source_model_input,
17
+ source_model_type,
18
+ target_model_folder_input,
19
+ target_model_name_input,
20
+ target_model_type,
21
+ target_save_precision_type,
22
+ ):
23
+ # Check that a source model type was selected
24
+ if source_model_type == '':
25
+ msgbox('Invalid source model type')
26
+ return
27
+
28
+ # Check if the source model exists
29
+ if os.path.isfile(source_model_input):
30
+ print('The provided source model is a file')
31
+ elif os.path.isdir(source_model_input):
32
+ print('The provided model is a folder')
33
+ else:
34
+ msgbox('The provided source model is neither a file nor a folder')
35
+ return
36
+
37
+ # Check if the target model folder exists
+ if os.path.isdir(target_model_folder_input):
+ print('The provided model folder exists')
40
+ else:
41
+ msgbox('The provided target folder does not exist')
42
+ return
43
+
44
+ run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"'
45
+
46
+ v1_models = [
47
+ 'runwayml/stable-diffusion-v1-5',
48
+ 'CompVis/stable-diffusion-v1-4',
49
+ ]
50
+
51
+ # check if v1 models
52
+ if str(source_model_type) in v1_models:
53
+ print('SD v1 model specified. Setting --v1 parameter')
54
+ run_cmd += ' --v1'
55
+ else:
56
+ print('SD v2 model specified. Setting --v2 parameter')
57
+ run_cmd += ' --v2'
58
+
59
+ if not target_save_precision_type == 'unspecified':
60
+ run_cmd += f' --{target_save_precision_type}'
61
+
62
+ if (
63
+ target_model_type == 'diffuser'
64
+ or target_model_type == 'diffuser_safetensors'
65
+ ):
66
+ run_cmd += f' --reference_model="{source_model_type}"'
67
+
68
+ if target_model_type == 'diffuser_safetensors':
69
+ run_cmd += ' --use_safetensors'
70
+
71
+ run_cmd += f' "{source_model_input}"'
72
+
73
+ if (
74
+ target_model_type == 'diffuser'
75
+ or target_model_type == 'diffuser_safetensors'
76
+ ):
77
+ target_model_path = os.path.join(
78
+ target_model_folder_input, target_model_name_input
79
+ )
80
+ run_cmd += f' "{target_model_path}"'
81
+ else:
82
+ target_model_path = os.path.join(
83
+ target_model_folder_input,
84
+ f'{target_model_name_input}.{target_model_type}',
85
+ )
86
+ run_cmd += f' "{target_model_path}"'
87
+
88
+ print(run_cmd)
89
+
90
+ # Run the command
91
+ if os.name == 'posix':
92
+ os.system(run_cmd)
93
+ else:
94
+ subprocess.run(run_cmd)
95
+
96
+ # Only checkpoint-style outputs need an inference yaml copied next to them
+ if target_model_type not in ['diffuser', 'diffuser_safetensors']:
100
+
101
+ v2_models = [
102
+ 'stabilityai/stable-diffusion-2-1-base',
103
+ 'stabilityai/stable-diffusion-2-base',
104
+ ]
105
+ v_parameterization = [
106
+ 'stabilityai/stable-diffusion-2-1',
107
+ 'stabilityai/stable-diffusion-2',
108
+ ]
109
+
110
+ if str(source_model_type) in v2_models:
111
+ inference_file = os.path.join(
112
+ target_model_folder_input, f'{target_model_name_input}.yaml'
113
+ )
114
+ print(f'Saving v2-inference.yaml as {inference_file}')
115
+ shutil.copy(
116
+ f'./v2_inference/v2-inference.yaml',
117
+ f'{inference_file}',
118
+ )
119
+
120
+ if str(source_model_type) in v_parameterization:
121
+ inference_file = os.path.join(
122
+ target_model_folder_input, f'{target_model_name_input}.yaml'
123
+ )
124
+ print(f'Saving v2-inference-v.yaml as {inference_file}')
125
+ shutil.copy(
126
+ f'./v2_inference/v2-inference-v.yaml',
127
+ f'{inference_file}',
128
+ )
129
+
130
+
131
+ # parser = argparse.ArgumentParser()
132
+ # parser.add_argument("--v1", action='store_true',
133
+ # help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
134
+ # parser.add_argument("--v2", action='store_true',
135
+ # help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
136
+ # parser.add_argument("--fp16", action='store_true',
137
+ # help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
138
+ # parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
139
+ # parser.add_argument("--float", action='store_true',
140
+ # help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
141
+ # parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
142
+ # parser.add_argument("--global_step", type=int, default=0,
143
+ # help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
144
+ # parser.add_argument("--reference_model", type=str, default=None,
145
+ # help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
146
+
147
+ # parser.add_argument("model_to_load", type=str, default=None,
148
+ # help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
149
+ # parser.add_argument("model_to_save", type=str, default=None,
150
+ # help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")
151
+
152
+
153
+ ###
154
+ # Gradio UI
155
+ ###
156
+
157
+
158
+ def gradio_convert_model_tab():
159
+ with gr.Tab('Convert model'):
160
+ gr.Markdown(
161
+ 'This utility can be used to convert from one stable diffusion model format to another.'
162
+ )
163
+ with gr.Row():
164
+ source_model_input = gr.Textbox(
165
+ label='Source model',
166
+ placeholder='path to source model folder or file to convert...',
167
+ interactive=True,
168
+ )
169
+ button_source_model_dir = gr.Button(
170
+ folder_symbol, elem_id='open_folder_small'
171
+ )
172
+ button_source_model_dir.click(
173
+ get_folder_path,
174
+ outputs=source_model_input,
175
+ show_progress=False,
176
+ )
177
+
178
+ button_source_model_file = gr.Button(
179
+ document_symbol, elem_id='open_folder_small'
180
+ )
181
+ button_source_model_file.click(
182
+ get_file_path,
183
+ inputs=[source_model_input],
184
+ outputs=source_model_input,
185
+ show_progress=False,
186
+ )
187
+
188
+ source_model_type = gr.Dropdown(
189
+ label='Source model type',
190
+ choices=[
191
+ 'stabilityai/stable-diffusion-2-1-base',
192
+ 'stabilityai/stable-diffusion-2-base',
193
+ 'stabilityai/stable-diffusion-2-1',
194
+ 'stabilityai/stable-diffusion-2',
195
+ 'runwayml/stable-diffusion-v1-5',
196
+ 'CompVis/stable-diffusion-v1-4',
197
+ ],
198
+ )
199
+ with gr.Row():
200
+ target_model_folder_input = gr.Textbox(
201
+ label='Target model folder',
202
+ placeholder='path to target model folder or file name to create...',
203
+ interactive=True,
204
+ )
205
+ button_target_model_folder = gr.Button(
206
+ folder_symbol, elem_id='open_folder_small'
207
+ )
208
+ button_target_model_folder.click(
209
+ get_folder_path,
210
+ outputs=target_model_folder_input,
211
+ show_progress=False,
212
+ )
213
+
214
+ target_model_name_input = gr.Textbox(
215
+ label='Target model name',
216
+ placeholder='target model name...',
217
+ interactive=True,
218
+ )
219
+ target_model_type = gr.Dropdown(
220
+ label='Target model type',
221
+ choices=[
222
+ 'diffuser',
223
+ 'diffuser_safetensors',
224
+ 'ckpt',
225
+ 'safetensors',
226
+ ],
227
+ )
228
+ target_save_precision_type = gr.Dropdown(
229
+ label='Target model precision',
230
+ choices=['unspecified', 'fp16', 'bf16', 'float'],
231
+ value='unspecified',
232
+ )
233
+
234
+ convert_button = gr.Button('Convert model')
235
+
236
+ convert_button.click(
237
+ convert_model,
238
+ inputs=[
239
+ source_model_input,
240
+ source_model_type,
241
+ target_model_folder_input,
242
+ target_model_name_input,
243
+ target_model_type,
244
+ target_save_precision_type,
245
+ ],
246
+ show_progress=False,
247
+ )
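To make the command assembly in convert_model above concrete, converting a hypothetical SD v1.5 checkpoint to an fp16 safetensors file builds roughly the following command (paths and names are placeholders; shown for the Linux code path where PYTHON is python3):

# source type 'runwayml/stable-diffusion-v1-5', target type 'safetensors',
# precision 'fp16', target folder 'models', target name 'converted'
expected_cmd = (
    'python3 "tools/convert_diffusers20_original_sd.py" --v1 --fp16'
    ' "/path/to/source.ckpt" "models/converted.safetensors"'
)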
library/custom_train_functions.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+ import argparse
3
+
4
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
5
+ alphas_cumprod = noise_scheduler.alphas_cumprod
6
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
7
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
8
+ alpha = sqrt_alphas_cumprod
9
+ sigma = sqrt_one_minus_alphas_cumprod
10
+ all_snr = (alpha / sigma) ** 2
11
+ snr = torch.stack([all_snr[t] for t in timesteps])
12
+ gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
13
+ snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
14
+ loss = loss * snr_weight
15
+ return loss
16
+
17
+ def add_custom_train_arguments(parser: argparse.ArgumentParser):
18
+ parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
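A minimal usage sketch for apply_snr_weight (not part of the diff): it assumes a diffusers DDPMScheduler, a per-sample loss of shape (batch,), and that the script runs from the repository root so the library package is importable.

import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import apply_snr_weight

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, 1000, (4,))   # one sampled timestep per batch element
per_sample_loss = torch.rand(4)            # e.g. F.mse_loss(..., reduction='none').mean(dim=(1, 2, 3))

# min(gamma / SNR, 1) down-weights low-noise (high-SNR) timesteps
weighted = apply_snr_weight(per_sample_loss, timesteps, noise_scheduler, gamma=5.0)
loss = weighted.mean()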
library/dataset_balancing_gui.py ADDED
@@ -0,0 +1,146 @@
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from easygui import msgbox, boolbox
5
+ from .common_gui import get_folder_path
6
+
7
+ # def select_folder():
8
+ # # Open a file dialog to select a directory
9
+ # folder = filedialog.askdirectory()
10
+
11
+ # # Update the GUI to display the selected folder
12
+ # selected_folder_label.config(text=folder)
13
+
14
+
15
+ def dataset_balancing(concept_repeats, folder, insecure):
16
+
17
+ if not concept_repeats > 0:
18
+ # Display an error message if the total number of repeats is not a positive number
19
+ msgbox('Please enter a valid integer for the total number of repeats.')
20
+ return
21
+
22
+ concept_repeats = int(concept_repeats)
23
+
24
+ # Check if folder exist
25
+ if folder == '' or not os.path.isdir(folder):
26
+ msgbox('Please enter a valid folder for balancing.')
27
+ return
28
+
29
+ pattern = re.compile(r'^\d+_.+$')
30
+
31
+ # Iterate over the subdirectories in the selected folder
32
+ for subdir in os.listdir(folder):
33
+ if pattern.match(subdir) or insecure:
34
+ # Calculate the number of repeats for the current subdirectory
35
+ # Get a list of all the files in the folder
36
+ files = os.listdir(os.path.join(folder, subdir))
37
+
38
+ # Filter the list to include only image files
39
+ image_files = [
40
+ f
41
+ for f in files
42
+ if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
43
+ ]
44
+
45
+ # Count the number of image files
46
+ images = len(image_files)
47
+
48
+ # Check if the subdirectory name starts with a number inside braces,
49
+ # indicating that the repeats value should be multiplied
50
+ match = re.match(r'^\{(\d+\.?\d*)\}', subdir)
51
+ if match:
52
+ # Multiply the repeats value by the number inside the braces
53
+ if not images == 0:
54
+ repeats = max(
55
+ 1,
56
+ round(
57
+ concept_repeats / images * float(match.group(1))
58
+ ),
59
+ )
60
+ else:
61
+ repeats = 0
62
+ subdir = subdir[match.end() :]
63
+ else:
64
+ if not images == 0:
65
+ repeats = max(1, round(concept_repeats / images))
66
+ else:
67
+ repeats = 0
68
+
69
+ # Check if the subdirectory name already has a number at the beginning
70
+ match = re.match(r'^\d+_', subdir)
71
+ if match:
72
+ # Replace the existing number with the new number
73
+ old_name = os.path.join(folder, subdir)
74
+ new_name = os.path.join(
75
+ folder, f'{repeats}_{subdir[match.end():]}'
76
+ )
77
+ else:
78
+ # Add the new number at the beginning of the name
79
+ old_name = os.path.join(folder, subdir)
80
+ new_name = os.path.join(folder, f'{repeats}_{subdir}')
81
+
82
+ os.rename(old_name, new_name)
83
+ else:
84
+ print(
85
+ f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
86
+ )
87
+
88
+ msgbox('Dataset balancing completed...')
89
+
90
+
91
+ def warning(insecure):
92
+ if insecure:
93
+ if boolbox(
94
+ f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
95
+ choices=('Yes, I like danger', 'No, get me out of here'),
96
+ ):
97
+ return True
98
+ else:
99
+ return False
100
+
101
+
102
+ def gradio_dataset_balancing_tab():
103
+ with gr.Tab('Dreambooth/LoRA Dataset balancing'):
104
+ gr.Markdown(
105
+ 'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
106
+ )
107
+ gr.Markdown(
108
+ 'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
109
+ )
110
+ with gr.Row():
111
+ select_dataset_folder_input = gr.Textbox(
112
+ label='Dataset folder',
113
+ placeholder='Folder containing the concepts folders to balance...',
114
+ interactive=True,
115
+ )
116
+
117
+ select_dataset_folder_button = gr.Button(
118
+ '📂', elem_id='open_folder_small'
119
+ )
120
+ select_dataset_folder_button.click(
121
+ get_folder_path,
122
+ outputs=select_dataset_folder_input,
123
+ show_progress=False,
124
+ )
125
+
126
+ total_repeats_number = gr.Number(
127
+ value=1000,
128
+ interactive=True,
129
+ label='Training steps per concept per epoch',
130
+ )
131
+ with gr.Accordion('Advanced options', open=False):
132
+ insecure = gr.Checkbox(
133
+ value=False,
134
+ label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
135
+ )
136
+ insecure.change(warning, inputs=insecure, outputs=insecure)
137
+ balance_button = gr.Button('Balance dataset')
138
+ balance_button.click(
139
+ dataset_balancing,
140
+ inputs=[
141
+ total_repeats_number,
142
+ select_dataset_folder_input,
143
+ insecure,
144
+ ],
145
+ show_progress=False,
146
+ )
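A worked example of the renaming rule above (illustrative folder name): with 'Training steps per concept per epoch' set to 1000, a concept folder named 5_myconcept that contains 25 images is renamed to 40_myconcept, because

concept_repeats, images = 1000, 25
repeats = max(1, round(concept_repeats / images))   # -> 40, so '5_myconcept' becomes '40_myconcept'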
library/dreambooth_folder_creation_gui.py ADDED
@@ -0,0 +1,210 @@
1
+ import gradio as gr
2
+ from easygui import diropenbox, msgbox
3
+ from .common_gui import get_folder_path
4
+ import shutil
5
+ import os
6
+
7
+
8
+ def copy_info_to_Folders_tab(training_folder):
9
+ img_folder = os.path.join(training_folder, 'img')
10
+ if os.path.exists(os.path.join(training_folder, 'reg')):
11
+ reg_folder = os.path.join(training_folder, 'reg')
12
+ else:
13
+ reg_folder = ''
14
+ model_folder = os.path.join(training_folder, 'model')
15
+ log_folder = os.path.join(training_folder, 'log')
16
+
17
+ return img_folder, reg_folder, model_folder, log_folder
18
+
19
+
20
+ def dreambooth_folder_preparation(
21
+ util_training_images_dir_input,
22
+ util_training_images_repeat_input,
23
+ util_instance_prompt_input,
24
+ util_regularization_images_dir_input,
25
+ util_regularization_images_repeat_input,
26
+ util_class_prompt_input,
27
+ util_training_dir_output,
28
+ ):
29
+
30
+ # Check if the input variables are empty
31
+ if not len(util_training_dir_output):
32
+ print(
33
+ "Destination training directory is missing... can't perform the required task..."
34
+ )
35
+ return
36
+ else:
37
+ # Create the util_training_dir_output directory if it doesn't exist
38
+ os.makedirs(util_training_dir_output, exist_ok=True)
39
+
40
+ # Check for instance prompt
41
+ if util_instance_prompt_input == '':
42
+ msgbox('Instance prompt missing...')
43
+ return
44
+
45
+ # Check for class prompt
46
+ if util_class_prompt_input == '':
47
+ msgbox('Class prompt missing...')
48
+ return
49
+
50
+ # Create the training_dir path
51
+ if util_training_images_dir_input == '':
52
+ print(
53
+ "Training images directory is missing... can't perform the required task..."
54
+ )
55
+ return
56
+ else:
57
+ training_dir = os.path.join(
58
+ util_training_dir_output,
59
+ f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
60
+ )
61
+
62
+ # Remove folders if they exist
63
+ if os.path.exists(training_dir):
64
+ print(f'Removing existing directory {training_dir}...')
65
+ shutil.rmtree(training_dir)
66
+
67
+ # Copy the training images to their respective directories
68
+ print(f'Copy {util_training_images_dir_input} to {training_dir}...')
69
+ shutil.copytree(util_training_images_dir_input, training_dir)
70
+
71
+ if not util_regularization_images_dir_input == '':
72
+ # Create the regularization_dir path
73
+ if not util_regularization_images_repeat_input > 0:
74
+ print('Repeats is missing... not copying regularisation images...')
75
+ else:
76
+ regularization_dir = os.path.join(
77
+ util_training_dir_output,
78
+ f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
79
+ )
80
+
81
+ # Remove folders if they exist
82
+ if os.path.exists(regularization_dir):
83
+ print(f'Removing existing directory {regularization_dir}...')
84
+ shutil.rmtree(regularization_dir)
85
+
86
+ # Copy the regularisation images to their respective directories
87
+ print(
88
+ f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
89
+ )
90
+ shutil.copytree(
91
+ util_regularization_images_dir_input, regularization_dir
92
+ )
93
+ else:
94
+ print(
95
+ 'Regularization images directory is missing... not copying regularisation images...'
96
+ )
97
+
98
+ # create log and model folder
99
+ # Check if the log folder exists and create it if it doesn't
100
+ if not os.path.exists(os.path.join(util_training_dir_output, 'log')):
101
+ os.makedirs(os.path.join(util_training_dir_output, 'log'))
102
+
103
+ # Check if the model folder exists and create it if it doesn't
104
+ if not os.path.exists(os.path.join(util_training_dir_output, 'model')):
105
+ os.makedirs(os.path.join(util_training_dir_output, 'model'))
106
+
107
+ print(
108
+ f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
109
+ )
110
+
111
+
112
+ def gradio_dreambooth_folder_creation_tab(
113
+ train_data_dir_input=gr.Textbox(),
114
+ reg_data_dir_input=gr.Textbox(),
115
+ output_dir_input=gr.Textbox(),
116
+ logging_dir_input=gr.Textbox(),
117
+ ):
118
+ with gr.Tab('Dreambooth/LoRA Folder preparation'):
119
+ gr.Markdown(
120
+ 'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohya_ss Dreambooth/LoRA method to function correctly.'
121
+ )
122
+ with gr.Row():
123
+ util_instance_prompt_input = gr.Textbox(
124
+ label='Instance prompt',
125
+ placeholder='Eg: asd',
126
+ interactive=True,
127
+ )
128
+ util_class_prompt_input = gr.Textbox(
129
+ label='Class prompt',
130
+ placeholder='Eg: person',
131
+ interactive=True,
132
+ )
133
+ with gr.Row():
134
+ util_training_images_dir_input = gr.Textbox(
135
+ label='Training images',
136
+ placeholder='Directory containing the training images',
137
+ interactive=True,
138
+ )
139
+ button_util_training_images_dir_input = gr.Button(
140
+ '📂', elem_id='open_folder_small'
141
+ )
142
+ button_util_training_images_dir_input.click(
143
+ get_folder_path,
144
+ outputs=util_training_images_dir_input,
145
+ show_progress=False,
146
+ )
147
+ util_training_images_repeat_input = gr.Number(
148
+ label='Repeats',
149
+ value=40,
150
+ interactive=True,
151
+ elem_id='number_input',
152
+ )
153
+ with gr.Row():
154
+ util_regularization_images_dir_input = gr.Textbox(
155
+ label='Regularisation images',
156
+ placeholder='(Optional) Directory containing the regularisation images',
157
+ interactive=True,
158
+ )
159
+ button_util_regularization_images_dir_input = gr.Button(
160
+ '📂', elem_id='open_folder_small'
161
+ )
162
+ button_util_regularization_images_dir_input.click(
163
+ get_folder_path,
164
+ outputs=util_regularization_images_dir_input,
165
+ show_progress=False,
166
+ )
167
+ util_regularization_images_repeat_input = gr.Number(
168
+ label='Repeats',
169
+ value=1,
170
+ interactive=True,
171
+ elem_id='number_input',
172
+ )
173
+ with gr.Row():
174
+ util_training_dir_output = gr.Textbox(
175
+ label='Destination training directory',
176
+ placeholder='Directory where formatted training and regularisation folders will be placed',
177
+ interactive=True,
178
+ )
179
+ button_util_training_dir_output = gr.Button(
180
+ '📂', elem_id='open_folder_small'
181
+ )
182
+ button_util_training_dir_output.click(
183
+ get_folder_path, outputs=util_training_dir_output
184
+ )
185
+ button_prepare_training_data = gr.Button('Prepare training data')
186
+ button_prepare_training_data.click(
187
+ dreambooth_folder_preparation,
188
+ inputs=[
189
+ util_training_images_dir_input,
190
+ util_training_images_repeat_input,
191
+ util_instance_prompt_input,
192
+ util_regularization_images_dir_input,
193
+ util_regularization_images_repeat_input,
194
+ util_class_prompt_input,
195
+ util_training_dir_output,
196
+ ],
197
+ show_progress=False,
198
+ )
199
+ button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
200
+ button_copy_info_to_Folders_tab.click(
201
+ copy_info_to_Folders_tab,
202
+ inputs=[util_training_dir_output],
203
+ outputs=[
204
+ train_data_dir_input,
205
+ reg_data_dir_input,
206
+ output_dir_input,
207
+ logging_dir_input,
208
+ ],
209
+ show_progress=False,
210
+ )
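Illustrative only: for instance prompt 'asd', class prompt 'person', 40 training repeats, 1 regularisation repeat and destination folder 'project', dreambooth_folder_preparation above produces a layout along these lines (paths are hypothetical):

expected_layout = [
    'project/img/40_asd person/',   # copies of the training images
    'project/reg/1_person/',        # copies of the regularisation images, if provided
    'project/log/',
    'project/model/',
]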
library/extract_lora_gui.py ADDED
@@ -0,0 +1,178 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import (
6
+ get_saveasfilename_path,
7
+ get_any_file_path,
8
+ get_file_path,
9
+ )
10
+
11
+ folder_symbol = '\U0001f4c2' # 📂
12
+ refresh_symbol = '\U0001f504' # 🔄
13
+ save_style_symbol = '\U0001f4be' # 💾
14
+ document_symbol = '\U0001F4C4' # 📄
15
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
16
+
17
+
18
+ def extract_lora(
19
+ model_tuned,
20
+ model_org,
21
+ save_to,
22
+ save_precision,
23
+ dim,
24
+ v2,
25
+ conv_dim,
26
+ device,
27
+ ):
28
+ # Check for caption_text_input
29
+ if model_tuned == '':
30
+ msgbox('Invalid finetuned model file')
31
+ return
32
+
33
+ if model_org == '':
34
+ msgbox('Invalid base model file')
35
+ return
36
+
37
+ # Check if source model exist
38
+ if not os.path.isfile(model_tuned):
39
+ msgbox('The provided finetuned model is not a file')
40
+ return
41
+
42
+ if not os.path.isfile(model_org):
43
+ msgbox('The provided base model is not a file')
44
+ return
45
+
46
+ run_cmd = (
47
+ f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"'
48
+ )
49
+ run_cmd += f' --save_precision {save_precision}'
50
+ run_cmd += f' --save_to "{save_to}"'
51
+ run_cmd += f' --model_org "{model_org}"'
52
+ run_cmd += f' --model_tuned "{model_tuned}"'
53
+ run_cmd += f' --dim {dim}'
54
+ run_cmd += f' --device {device}'
55
+ if conv_dim > 0:
56
+ run_cmd += f' --conv_dim {conv_dim}'
57
+ if v2:
58
+ run_cmd += f' --v2'
59
+
60
+ print(run_cmd)
61
+
62
+ # Run the command
63
+ if os.name == 'posix':
64
+ os.system(run_cmd)
65
+ else:
66
+ subprocess.run(run_cmd)
67
+
68
+
69
+ ###
70
+ # Gradio UI
71
+ ###
72
+
73
+
74
+ def gradio_extract_lora_tab():
75
+ with gr.Tab('Extract LoRA'):
76
+ gr.Markdown(
77
+ 'This utility can extract a LoRA network from a finetuned model.'
78
+ )
79
+ lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
80
+ lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
81
+ model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
82
+ model_ext_name = gr.Textbox(value='Model types', visible=False)
83
+
84
+ with gr.Row():
85
+ model_tuned = gr.Textbox(
86
+ label='Finetuned model',
87
+ placeholder='Path to the finetuned model to extract',
88
+ interactive=True,
89
+ )
90
+ button_model_tuned_file = gr.Button(
91
+ folder_symbol, elem_id='open_folder_small'
92
+ )
93
+ button_model_tuned_file.click(
94
+ get_file_path,
95
+ inputs=[model_tuned, model_ext, model_ext_name],
96
+ outputs=model_tuned,
97
+ show_progress=False,
98
+ )
99
+
100
+ model_org = gr.Textbox(
101
+ label='Stable Diffusion base model',
102
+ placeholder='Stable Diffusion original model: ckpt or safetensors file',
103
+ interactive=True,
104
+ )
105
+ button_model_org_file = gr.Button(
106
+ folder_symbol, elem_id='open_folder_small'
107
+ )
108
+ button_model_org_file.click(
109
+ get_file_path,
110
+ inputs=[model_org, model_ext, model_ext_name],
111
+ outputs=model_org,
112
+ show_progress=False,
113
+ )
114
+ with gr.Row():
115
+ save_to = gr.Textbox(
116
+ label='Save to',
117
+ placeholder='path where to save the extracted LoRA model...',
118
+ interactive=True,
119
+ )
120
+ button_save_to = gr.Button(
121
+ folder_symbol, elem_id='open_folder_small'
122
+ )
123
+ button_save_to.click(
124
+ get_saveasfilename_path,
125
+ inputs=[save_to, lora_ext, lora_ext_name],
126
+ outputs=save_to,
127
+ show_progress=False,
128
+ )
129
+ save_precision = gr.Dropdown(
130
+ label='Save precision',
131
+ choices=['fp16', 'bf16', 'float'],
132
+ value='float',
133
+ interactive=True,
134
+ )
135
+ with gr.Row():
136
+ dim = gr.Slider(
137
+ minimum=4,
138
+ maximum=1024,
139
+ label='Network Dimension (Rank)',
140
+ value=128,
141
+ step=1,
142
+ interactive=True,
143
+ )
144
+ conv_dim = gr.Slider(
145
+ minimum=0,
146
+ maximum=1024,
147
+ label='Conv Dimension (Rank)',
148
+ value=128,
149
+ step=1,
150
+ interactive=True,
151
+ )
152
+ v2 = gr.Checkbox(label='v2', value=False, interactive=True)
153
+ device = gr.Dropdown(
154
+ label='Device',
155
+ choices=[
156
+ 'cpu',
157
+ 'cuda',
158
+ ],
159
+ value='cuda',
160
+ interactive=True,
161
+ )
162
+
163
+ extract_button = gr.Button('Extract LoRA model')
164
+
165
+ extract_button.click(
166
+ extract_lora,
167
+ inputs=[
168
+ model_tuned,
169
+ model_org,
170
+ save_to,
171
+ save_precision,
172
+ dim,
173
+ v2,
174
+ conv_dim,
175
+ device
176
+ ],
177
+ show_progress=False,
178
+ )
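As with the converter above, an illustrative command line that extract_lora builds for the default rank of 128 on CUDA, saving as fp16 (file names are placeholders; Linux code path):

expected_cmd = (
    'python3 "networks/extract_lora_from_models.py" --save_precision fp16'
    ' --save_to "my_lora.safetensors" --model_org "base.safetensors"'
    ' --model_tuned "tuned.safetensors" --dim 128 --device cuda --conv_dim 128'
)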
library/extract_lycoris_locon_gui.py ADDED
@@ -0,0 +1,309 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import (
6
+ get_saveasfilename_path,
7
+ get_any_file_path,
8
+ get_file_path,
9
+ )
10
+
11
+ folder_symbol = '\U0001f4c2' # 📂
12
+ refresh_symbol = '\U0001f504' # 🔄
13
+ save_style_symbol = '\U0001f4be' # 💾
14
+ document_symbol = '\U0001F4C4' # 📄
15
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
16
+
17
+
18
+ def extract_lycoris_locon(
19
+ db_model,
20
+ base_model,
21
+ output_name,
22
+ device,
23
+ is_v2,
24
+ mode,
25
+ linear_dim,
26
+ conv_dim,
27
+ linear_threshold,
28
+ conv_threshold,
29
+ linear_ratio,
30
+ conv_ratio,
31
+ linear_quantile,
32
+ conv_quantile,
33
+ use_sparse_bias,
34
+ sparsity,
35
+ disable_cp,
36
+ ):
37
+ # Check for caption_text_input
38
+ if db_model == '':
39
+ msgbox('Invalid finetuned model file')
40
+ return
41
+
42
+ if base_model == '':
43
+ msgbox('Invalid base model file')
44
+ return
45
+
46
+ # Check if source model exist
47
+ if not os.path.isfile(db_model):
48
+ msgbox('The provided finetuned model is not a file')
49
+ return
50
+
51
+ if not os.path.isfile(base_model):
52
+ msgbox('The provided base model is not a file')
53
+ return
54
+
55
+ run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
56
+ if is_v2:
57
+ run_cmd += f' --is_v2'
58
+ run_cmd += f' --device {device}'
59
+ run_cmd += f' --mode {mode}'
60
+ run_cmd += f' --safetensors'
61
+ run_cmd += f' --linear_dim {linear_dim}'
62
+ run_cmd += f' --conv_dim {conv_dim}'
63
+ run_cmd += f' --linear_threshold {linear_threshold}'
64
+ run_cmd += f' --conv_threshold {conv_threshold}'
65
+ run_cmd += f' --linear_ratio {linear_ratio}'
66
+ run_cmd += f' --conv_ratio {conv_ratio}'
67
+ run_cmd += f' --linear_quantile {linear_quantile}'
68
+ run_cmd += f' --conv_quantile {conv_quantile}'
69
+ if use_sparse_bias:
70
+ run_cmd += f' --use_sparse_bias'
71
+ run_cmd += f' --sparsity {sparsity}'
72
+ if disable_cp:
73
+ run_cmd += f' --disable_cp'
74
+ run_cmd += f' "{base_model}"'
75
+ run_cmd += f' "{db_model}"'
76
+ run_cmd += f' "{output_name}"'
77
+
78
+ print(run_cmd)
79
+
80
+ # Run the command
81
+ if os.name == 'posix':
82
+ os.system(run_cmd)
83
+ else:
84
+ subprocess.run(run_cmd)
85
+
86
+
87
+ ###
88
+ # Gradio UI
89
+ ###
90
+ # def update_mode(mode):
91
+ # # 'fixed', 'threshold','ratio','quantile'
92
+ # if mode == 'fixed':
93
+ # return gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False)
94
+ # if mode == 'threshold':
95
+ # return gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False)
96
+ # if mode == 'ratio':
97
+ # return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False)
98
+ # if mode == 'threshold':
99
+ # return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True)
100
+
101
+
102
+ def update_mode(mode):
103
+ # Create a list of possible mode values
104
+ modes = ['fixed', 'threshold', 'ratio', 'quantile']
105
+
106
+ # Initialize an empty list to store visibility updates
107
+ updates = []
108
+
109
+ # Iterate through the possible modes
110
+ for m in modes:
111
+ # Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop
112
+ updates.append(gr.Row.update(visible=(mode == m)))
113
+
114
+ # Return the visibility updates as a tuple
115
+ return tuple(updates)
116
+
117
+
118
+ def gradio_extract_lycoris_locon_tab():
119
+ with gr.Tab('Extract LyCORIS LoCON'):
120
+ gr.Markdown(
121
+ 'This utility can extract a LyCORIS LoCon network from a finetuned model.'
122
+ )
123
+ lora_ext = gr.Textbox(
124
+ value='*.safetensors', visible=False
125
+ ) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
126
+ lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
127
+ model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
128
+ model_ext_name = gr.Textbox(value='Model types', visible=False)
129
+
130
+ with gr.Row():
131
+ db_model = gr.Textbox(
132
+ label='Finetuned model',
133
+ placeholder='Path to the finetuned model to extract',
134
+ interactive=True,
135
+ )
136
+ button_db_model_file = gr.Button(
137
+ folder_symbol, elem_id='open_folder_small'
138
+ )
139
+ button_db_model_file.click(
140
+ get_file_path,
141
+ inputs=[db_model, model_ext, model_ext_name],
142
+ outputs=db_model,
143
+ show_progress=False,
144
+ )
145
+
146
+ base_model = gr.Textbox(
147
+ label='Stable Diffusion base model',
148
+ placeholder='Stable Diffusion original model: ckpt or safetensors file',
149
+ interactive=True,
150
+ )
151
+ button_base_model_file = gr.Button(
152
+ folder_symbol, elem_id='open_folder_small'
153
+ )
154
+ button_base_model_file.click(
155
+ get_file_path,
156
+ inputs=[base_model, model_ext, model_ext_name],
157
+ outputs=base_model,
158
+ show_progress=False,
159
+ )
160
+ with gr.Row():
161
+ output_name = gr.Textbox(
162
+ label='Save to',
163
+ placeholder='path where to save the extracted LoRA model...',
164
+ interactive=True,
165
+ )
166
+ button_output_name = gr.Button(
167
+ folder_symbol, elem_id='open_folder_small'
168
+ )
169
+ button_output_name.click(
170
+ get_saveasfilename_path,
171
+ inputs=[output_name, lora_ext, lora_ext_name],
172
+ outputs=output_name,
173
+ show_progress=False,
174
+ )
175
+ device = gr.Dropdown(
176
+ label='Device',
177
+ choices=[
178
+ 'cpu',
179
+ 'cuda',
180
+ ],
181
+ value='cuda',
182
+ interactive=True,
183
+ )
184
+ is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
185
+ mode = gr.Dropdown(
186
+ label='Mode',
187
+ choices=['fixed', 'threshold', 'ratio', 'quantile'],
188
+ value='fixed',
189
+ interactive=True,
190
+ )
191
+ with gr.Row(visible=True) as fixed:
192
+ linear_dim = gr.Slider(
193
+ minimum=1,
194
+ maximum=1024,
195
+ label='Network Dimension',
196
+ value=1,
197
+ step=1,
198
+ interactive=True,
199
+ )
200
+ conv_dim = gr.Slider(
201
+ minimum=1,
202
+ maximum=1024,
203
+ label='Conv Dimension',
204
+ value=1,
205
+ step=1,
206
+ interactive=True,
207
+ )
208
+ with gr.Row(visible=False) as threshold:
209
+ linear_threshold = gr.Slider(
210
+ minimum=0,
211
+ maximum=1,
212
+ label='Linear threshold',
213
+ value=0,
214
+ step=0.01,
215
+ interactive=True,
216
+ )
217
+ conv_threshold = gr.Slider(
218
+ minimum=0,
219
+ maximum=1,
220
+ label='Conv threshold',
221
+ value=0,
222
+ step=0.01,
223
+ interactive=True,
224
+ )
225
+ with gr.Row(visible=False) as ratio:
226
+ linear_ratio = gr.Slider(
227
+ minimum=0,
228
+ maximum=1,
229
+ label='Linear ratio',
230
+ value=0,
231
+ step=0.01,
232
+ interactive=True,
233
+ )
234
+ conv_ratio = gr.Slider(
235
+ minimum=0,
236
+ maximum=1,
237
+ label='Conv ratio',
238
+ value=0,
239
+ step=0.01,
240
+ interactive=True,
241
+ )
242
+ with gr.Row(visible=False) as quantile:
243
+ linear_quantile = gr.Slider(
244
+ minimum=0,
245
+ maximum=1,
246
+ label='Linear quantile',
247
+ value=0.75,
248
+ step=0.01,
249
+ interactive=True,
250
+ )
251
+ conv_quantile = gr.Slider(
252
+ minimum=0,
253
+ maximum=1,
254
+ label='Conv quantile',
255
+ value=0.75,
256
+ step=0.01,
257
+ interactive=True,
258
+ )
259
+ with gr.Row():
260
+ use_sparse_bias = gr.Checkbox(
261
+ label='Use sparse bias', value=False, interactive=True
262
+ )
263
+ sparsity = gr.Slider(
264
+ minimum=0,
265
+ maximum=1,
266
+ label='Sparsity',
267
+ value=0.98,
268
+ step=0.01,
269
+ interactive=True,
270
+ )
271
+ disable_cp = gr.Checkbox(
272
+ label='Disable CP decomposition', value=False, interactive=True
273
+ )
274
+ mode.change(
275
+ update_mode,
276
+ inputs=[mode],
277
+ outputs=[
278
+ fixed,
279
+ threshold,
280
+ ratio,
281
+ quantile,
282
+ ],
283
+ )
284
+
285
+ extract_button = gr.Button('Extract LyCORIS LoCon')
286
+
287
+ extract_button.click(
288
+ extract_lycoris_locon,
289
+ inputs=[
290
+ db_model,
291
+ base_model,
292
+ output_name,
293
+ device,
294
+ is_v2,
295
+ mode,
296
+ linear_dim,
297
+ conv_dim,
298
+ linear_threshold,
299
+ conv_threshold,
300
+ linear_ratio,
301
+ conv_ratio,
302
+ linear_quantile,
303
+ conv_quantile,
304
+ use_sparse_bias,
305
+ sparsity,
306
+ disable_cp,
307
+ ],
308
+ show_progress=False,
309
+ )
library/git_caption_gui.py ADDED
@@ -0,0 +1,136 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import get_folder_path, add_pre_postfix
6
+
7
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
8
+
9
+
10
+ def caption_images(
11
+ train_data_dir,
12
+ caption_ext,
13
+ batch_size,
14
+ max_data_loader_n_workers,
15
+ max_length,
16
+ model_id,
17
+ prefix,
18
+ postfix,
19
+ ):
20
+ # Check for images_dir_input
21
+ if train_data_dir == '':
22
+ msgbox('Image folder is missing...')
23
+ return
24
+
25
+ if caption_ext == '':
26
+ msgbox('Please provide an extension for the caption files.')
27
+ return
28
+
29
+ print(f'GIT captioning files in {train_data_dir}...')
30
+ run_cmd = f'{PYTHON} "finetune/make_captions_by_git.py"'
33
+ if not model_id == '':
34
+ run_cmd += f' --model_id="{model_id}"'
35
+ run_cmd += f' --batch_size="{int(batch_size)}"'
36
+ run_cmd += (
37
+ f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
38
+ )
39
+ run_cmd += f' --max_length="{int(max_length)}"'
40
+ if caption_ext != '':
41
+ run_cmd += f' --caption_extension="{caption_ext}"'
42
+ run_cmd += f' "{train_data_dir}"'
43
+
44
+ print(run_cmd)
45
+
46
+ # Run the command
+ if os.name == 'posix':
+ os.system(run_cmd)
+ else:
+ subprocess.run(run_cmd)
48
+
49
+ # Add prefix and postfix
50
+ add_pre_postfix(
51
+ folder=train_data_dir,
52
+ caption_file_ext=caption_ext,
53
+ prefix=prefix,
54
+ postfix=postfix,
55
+ )
56
+
57
+ print('...captioning done')
58
+
59
+
60
+ ###
61
+ # Gradio UI
62
+ ###
63
+
64
+
65
+ def gradio_git_caption_gui_tab():
66
+ with gr.Tab('GIT Captioning'):
67
+ gr.Markdown(
68
+ 'This utility will use GIT to create caption files for each image in a folder.'
69
+ )
70
+ with gr.Row():
71
+ train_data_dir = gr.Textbox(
72
+ label='Image folder to caption',
73
+ placeholder='Directory containing the images to caption',
74
+ interactive=True,
75
+ )
76
+ button_train_data_dir_input = gr.Button(
77
+ '📂', elem_id='open_folder_small'
78
+ )
79
+ button_train_data_dir_input.click(
80
+ get_folder_path,
81
+ outputs=train_data_dir,
82
+ show_progress=False,
83
+ )
84
+ with gr.Row():
85
+ caption_ext = gr.Textbox(
86
+ label='Caption file extension',
87
+ placeholder='Extension for caption file. eg: .caption, .txt',
88
+ value='.txt',
89
+ interactive=True,
90
+ )
91
+
92
+ prefix = gr.Textbox(
93
+ label='Prefix to add to GIT caption',
94
+ placeholder='(Optional)',
95
+ interactive=True,
96
+ )
97
+
98
+ postfix = gr.Textbox(
99
+ label='Postfix to add to GIT caption',
100
+ placeholder='(Optional)',
101
+ interactive=True,
102
+ )
103
+
104
+ batch_size = gr.Number(
105
+ value=1, label='Batch size', interactive=True
106
+ )
107
+
108
+ with gr.Row():
109
+ max_data_loader_n_workers = gr.Number(
110
+ value=2, label='Number of workers', interactive=True
111
+ )
112
+ max_length = gr.Number(
113
+ value=75, label='Max length', interactive=True
114
+ )
115
+ model_id = gr.Textbox(
116
+ label='Model',
117
+ placeholder='(Optional) model id for GIT in Hugging Face',
118
+ interactive=True,
119
+ )
120
+
121
+ caption_button = gr.Button('Caption images')
122
+
123
+ caption_button.click(
124
+ caption_images,
125
+ inputs=[
126
+ train_data_dir,
127
+ caption_ext,
128
+ batch_size,
129
+ max_data_loader_n_workers,
130
+ max_length,
131
+ model_id,
132
+ prefix,
133
+ postfix,
134
+ ],
135
+ show_progress=False,
136
+ )
library/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1179 @@
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modify to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+
21
+ try:
22
+ from diffusers.utils import PIL_INTERPOLATION
23
+ except ImportError:
24
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
25
+ PIL_INTERPOLATION = {
26
+ "linear": PIL.Image.Resampling.BILINEAR,
27
+ "bilinear": PIL.Image.Resampling.BILINEAR,
28
+ "bicubic": PIL.Image.Resampling.BICUBIC,
29
+ "lanczos": PIL.Image.Resampling.LANCZOS,
30
+ "nearest": PIL.Image.Resampling.NEAREST,
31
+ }
32
+ else:
33
+ PIL_INTERPOLATION = {
34
+ "linear": PIL.Image.LINEAR,
35
+ "bilinear": PIL.Image.BILINEAR,
36
+ "bicubic": PIL.Image.BICUBIC,
37
+ "lanczos": PIL.Image.LANCZOS,
38
+ "nearest": PIL.Image.NEAREST,
39
+ }
40
+ # ------------------------------------------------------------------------------
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ re_attention = re.compile(
45
+ r"""
46
+ \\\(|
47
+ \\\)|
48
+ \\\[|
49
+ \\]|
50
+ \\\\|
51
+ \\|
52
+ \(|
53
+ \[|
54
+ :([+-]?[.\d]+)\)|
55
+ \)|
56
+ ]|
57
+ [^\\()\[\]:]+|
58
+ :
59
+ """,
60
+ re.X,
61
+ )
62
+
63
+
64
+ def parse_prompt_attention(text):
65
+ """
66
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
67
+ Accepted tokens are:
68
+ (abc) - increases attention to abc by a multiplier of 1.1
69
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
70
+ [abc] - decreases attention to abc by a multiplier of 1.1
71
+ \( - literal character '('
72
+ \[ - literal character '['
73
+ \) - literal character ')'
74
+ \] - literal character ']'
75
+ \\ - literal character '\'
76
+ anything else - just text
77
+ >>> parse_prompt_attention('normal text')
78
+ [['normal text', 1.0]]
79
+ >>> parse_prompt_attention('an (important) word')
80
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
81
+ >>> parse_prompt_attention('(unbalanced')
82
+ [['unbalanced', 1.1]]
83
+ >>> parse_prompt_attention('\(literal\]')
84
+ [['(literal]', 1.0]]
85
+ >>> parse_prompt_attention('(unnecessary)(parens)')
86
+ [['unnecessaryparens', 1.1]]
87
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
88
+ [['a ', 1.0],
89
+ ['house', 1.5730000000000004],
90
+ [' ', 1.1],
91
+ ['on', 1.0],
92
+ [' a ', 1.1],
93
+ ['hill', 0.55],
94
+ [', sun, ', 1.1],
95
+ ['sky', 1.4641000000000006],
96
+ ['.', 1.1]]
97
+ """
98
+
99
+ res = []
100
+ round_brackets = []
101
+ square_brackets = []
102
+
103
+ round_bracket_multiplier = 1.1
104
+ square_bracket_multiplier = 1 / 1.1
105
+
106
+ def multiply_range(start_position, multiplier):
107
+ for p in range(start_position, len(res)):
108
+ res[p][1] *= multiplier
109
+
110
+ for m in re_attention.finditer(text):
111
+ text = m.group(0)
112
+ weight = m.group(1)
113
+
114
+ if text.startswith("\\"):
115
+ res.append([text[1:], 1.0])
116
+ elif text == "(":
117
+ round_brackets.append(len(res))
118
+ elif text == "[":
119
+ square_brackets.append(len(res))
120
+ elif weight is not None and len(round_brackets) > 0:
121
+ multiply_range(round_brackets.pop(), float(weight))
122
+ elif text == ")" and len(round_brackets) > 0:
123
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
124
+ elif text == "]" and len(square_brackets) > 0:
125
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
126
+ else:
127
+ res.append([text, 1.0])
128
+
129
+ for pos in round_brackets:
130
+ multiply_range(pos, round_bracket_multiplier)
131
+
132
+ for pos in square_brackets:
133
+ multiply_range(pos, square_bracket_multiplier)
134
+
135
+ if len(res) == 0:
136
+ res = [["", 1.0]]
137
+
138
+ # merge runs of identical weights
139
+ i = 0
140
+ while i + 1 < len(res):
141
+ if res[i][1] == res[i + 1][1]:
142
+ res[i][0] += res[i + 1][0]
143
+ res.pop(i + 1)
144
+ else:
145
+ i += 1
146
+
147
+ return res
148
+
149
+
150
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
151
+ r"""
152
+ Tokenize a list of prompts and return its tokens with weights of each token.
153
+
154
+ No padding, starting or ending token is included.
155
+ """
156
+ tokens = []
157
+ weights = []
158
+ truncated = False
159
+ for text in prompt:
160
+ texts_and_weights = parse_prompt_attention(text)
161
+ text_token = []
162
+ text_weight = []
163
+ for word, weight in texts_and_weights:
164
+ # tokenize and discard the starting and the ending token
165
+ token = pipe.tokenizer(word).input_ids[1:-1]
166
+ text_token += token
167
+ # copy the weight by length of token
168
+ text_weight += [weight] * len(token)
169
+ # stop if the text is too long (longer than truncation limit)
170
+ if len(text_token) > max_length:
171
+ truncated = True
172
+ break
173
+ # truncate
174
+ if len(text_token) > max_length:
175
+ truncated = True
176
+ text_token = text_token[:max_length]
177
+ text_weight = text_weight[:max_length]
178
+ tokens.append(text_token)
179
+ weights.append(text_weight)
180
+ if truncated:
181
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
182
+ return tokens, weights
183
+
184
+
185
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
186
+ r"""
187
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
188
+ """
189
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
190
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
191
+ for i in range(len(tokens)):
192
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
193
+ if no_boseos_middle:
194
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
195
+ else:
196
+ w = []
197
+ if len(weights[i]) == 0:
198
+ w = [1.0] * weights_length
199
+ else:
200
+ for j in range(max_embeddings_multiples):
201
+ w.append(1.0) # weight for starting token in this chunk
202
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
203
+ w.append(1.0) # weight for ending token in this chunk
204
+ w += [1.0] * (weights_length - len(w))
205
+ weights[i] = w[:]
206
+
207
+ return tokens, weights
208
+
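+ # Chunk arithmetic used above, for reference: with chunk_length = 77 and
+ # max_length = (77 - 2) * 3 + 2 = 227, max_embeddings_multiples = (227 - 2) // (77 - 2) = 3,
+ # so each entry is padded to 227 ids: one BOS, at most 225 content ids, then EOS padding.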
209
+
210
+ def get_unweighted_text_embeddings(
211
+ pipe: StableDiffusionPipeline,
212
+ text_input: torch.Tensor,
213
+ chunk_length: int,
214
+ clip_skip: int,
215
+ eos: int,
216
+ pad: int,
217
+ no_boseos_middle: Optional[bool] = True,
218
+ ):
219
+ """
220
+ When the length of tokens is a multiple of the capacity of the text encoder,
221
+ it should be split into chunks and sent to the text encoder individually.
222
+ """
223
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
224
+ if max_embeddings_multiples > 1:
225
+ text_embeddings = []
226
+ for i in range(max_embeddings_multiples):
227
+ # extract the i-th chunk
228
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
229
+
230
+ # cover the head and the tail by the starting and the ending tokens
231
+ text_input_chunk[:, 0] = text_input[0, 0]
232
+ if pad == eos: # v1
233
+ text_input_chunk[:, -1] = text_input[0, -1]
234
+ else: # v2
235
+ for j in range(len(text_input_chunk)):
236
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is a normal token, not EOS/PAD
237
+ text_input_chunk[j, -1] = eos
238
+ if text_input_chunk[j, 1] == pad:  # the chunk contains only BOS, the rest is PAD
239
+ text_input_chunk[j, 1] = eos
240
+
241
+ if clip_skip is None or clip_skip == 1:
242
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
243
+ else:
244
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
245
+ text_embedding = enc_out["hidden_states"][-clip_skip]
246
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
247
+
253
+ if no_boseos_middle:
254
+ if i == 0:
255
+ # discard the ending token
256
+ text_embedding = text_embedding[:, :-1]
257
+ elif i == max_embeddings_multiples - 1:
258
+ # discard the starting token
259
+ text_embedding = text_embedding[:, 1:]
260
+ else:
261
+ # discard both starting and ending tokens
262
+ text_embedding = text_embedding[:, 1:-1]
263
+
264
+ text_embeddings.append(text_embedding)
265
+ text_embeddings = torch.concat(text_embeddings, axis=1)
266
+ else:
267
+ text_embeddings = pipe.text_encoder(text_input)[0]
268
+ return text_embeddings
269
+
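+ # Example of the chunking above: a 227-token input with chunk_length = 77 yields
+ # (227 - 2) // (77 - 2) = 3 chunks of 77 tokens each (BOS/EOS re-applied per chunk);
+ # the per-chunk embeddings are concatenated along the sequence axis (axis=1).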
270
+
271
+ def get_weighted_text_embeddings(
272
+ pipe: StableDiffusionPipeline,
273
+ prompt: Union[str, List[str]],
274
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
275
+ max_embeddings_multiples: Optional[int] = 3,
276
+ no_boseos_middle: Optional[bool] = False,
277
+ skip_parsing: Optional[bool] = False,
278
+ skip_weighting: Optional[bool] = False,
279
+ clip_skip=None,
280
+ ):
281
+ r"""
282
+ Prompts can be assigned with local weights using brackets. For example,
283
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
284
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
285
+
286
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
287
+
288
+ Args:
289
+ pipe (`StableDiffusionPipeline`):
290
+ Pipe to provide access to the tokenizer and the text encoder.
291
+ prompt (`str` or `List[str]`):
292
+ The prompt or prompts to guide the image generation.
293
+ uncond_prompt (`str` or `List[str]`):
294
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
295
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
296
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
297
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
298
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
299
+ If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the starting and
+ ending tokens in each of the middle chunks.
301
+ skip_parsing (`bool`, *optional*, defaults to `False`):
302
+ Skip the parsing of brackets.
303
+ skip_weighting (`bool`, *optional*, defaults to `False`):
304
+ Skip the weighting. When parsing is skipped, this is forced to True.
305
+ """
306
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
307
+ if isinstance(prompt, str):
308
+ prompt = [prompt]
309
+
310
+ if not skip_parsing:
311
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
312
+ if uncond_prompt is not None:
313
+ if isinstance(uncond_prompt, str):
314
+ uncond_prompt = [uncond_prompt]
315
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
316
+ else:
317
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
318
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
319
+ if uncond_prompt is not None:
320
+ if isinstance(uncond_prompt, str):
321
+ uncond_prompt = [uncond_prompt]
322
+ uncond_tokens = [
323
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
324
+ ]
325
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
326
+
327
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
328
+ max_length = max([len(token) for token in prompt_tokens])
329
+ if uncond_prompt is not None:
330
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
331
+
332
+ max_embeddings_multiples = min(
333
+ max_embeddings_multiples,
334
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
335
+ )
336
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
337
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
338
+
339
+ # pad the length of tokens and weights
340
+ bos = pipe.tokenizer.bos_token_id
341
+ eos = pipe.tokenizer.eos_token_id
342
+ pad = pipe.tokenizer.pad_token_id
343
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
344
+ prompt_tokens,
345
+ prompt_weights,
346
+ max_length,
347
+ bos,
348
+ eos,
349
+ no_boseos_middle=no_boseos_middle,
350
+ chunk_length=pipe.tokenizer.model_max_length,
351
+ )
352
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
353
+ if uncond_prompt is not None:
354
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
355
+ uncond_tokens,
356
+ uncond_weights,
357
+ max_length,
358
+ bos,
359
+ eos,
360
+ no_boseos_middle=no_boseos_middle,
361
+ chunk_length=pipe.tokenizer.model_max_length,
362
+ )
363
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
364
+
365
+ # get the embeddings
366
+ text_embeddings = get_unweighted_text_embeddings(
367
+ pipe,
368
+ prompt_tokens,
369
+ pipe.tokenizer.model_max_length,
370
+ clip_skip,
371
+ eos,
372
+ pad,
373
+ no_boseos_middle=no_boseos_middle,
374
+ )
375
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
376
+ if uncond_prompt is not None:
377
+ uncond_embeddings = get_unweighted_text_embeddings(
378
+ pipe,
379
+ uncond_tokens,
380
+ pipe.tokenizer.model_max_length,
381
+ clip_skip,
382
+ eos,
383
+ pad,
384
+ no_boseos_middle=no_boseos_middle,
385
+ )
386
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
387
+
388
+ # assign weights to the prompts and normalize in the sense of mean
389
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
390
+ if (not skip_parsing) and (not skip_weighting):
391
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
392
+ text_embeddings *= prompt_weights.unsqueeze(-1)
393
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
394
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
395
+ if uncond_prompt is not None:
396
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
397
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
398
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
399
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
400
+
401
+ if uncond_prompt is not None:
402
+ return text_embeddings, uncond_embeddings
403
+ return text_embeddings, None
404
+
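+ # Minimal usage sketch (assumes `pipe` is an instance of the pipeline class defined below):
+ # cond, uncond = get_weighted_text_embeddings(pipe, "a (cat:1.2) on a roof", uncond_prompt="low quality")
+ # `cond` has shape (1, padded_length, hidden_size); `uncond` is None when uncond_prompt is omitted.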
405
+
406
+ def preprocess_image(image):
407
+ w, h = image.size
408
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
409
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
410
+ image = np.array(image).astype(np.float32) / 255.0
411
+ image = image[None].transpose(0, 3, 1, 2)
412
+ image = torch.from_numpy(image)
413
+ return 2.0 * image - 1.0
414
+
415
+
416
+ def preprocess_mask(mask, scale_factor=8):
417
+ mask = mask.convert("L")
418
+ w, h = mask.size
419
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
420
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
421
+ mask = np.array(mask).astype(np.float32) / 255.0
422
+ mask = np.tile(mask, (4, 1, 1))
423
+ mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the identity-permutation transpose is a no-op
424
+ mask = 1 - mask # repaint white, keep black
425
+ mask = torch.from_numpy(mask)
426
+ return mask
427
+
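+ # Note on the mask convention above: after the `1 - mask` inversion, white input pixels become 0
+ # (regions to repaint) and black pixels become 1 (regions to keep); the single-channel mask is
+ # tiled to 4 channels so it can be broadcast against the 4-channel latents.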
428
+
429
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
430
+ r"""
431
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+ parsing weighting syntax in the prompt.
433
+
434
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
435
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
436
+
437
+ Args:
438
+ vae ([`AutoencoderKL`]):
439
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
440
+ text_encoder ([`CLIPTextModel`]):
441
+ Frozen text-encoder. Stable Diffusion uses the text portion of
442
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
443
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
444
+ tokenizer (`CLIPTokenizer`):
445
+ Tokenizer of class
446
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
447
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
448
+ scheduler ([`SchedulerMixin`]):
449
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
450
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
451
+ safety_checker ([`StableDiffusionSafetyChecker`]):
452
+ Classification module that estimates whether generated images could be considered offensive or harmful.
453
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
454
+ feature_extractor ([`CLIPFeatureExtractor`]):
455
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
456
+ """
457
+
458
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
459
+
460
+ def __init__(
461
+ self,
462
+ vae: AutoencoderKL,
463
+ text_encoder: CLIPTextModel,
464
+ tokenizer: CLIPTokenizer,
465
+ unet: UNet2DConditionModel,
466
+ scheduler: SchedulerMixin,
467
+ clip_skip: int,
468
+ safety_checker: StableDiffusionSafetyChecker,
469
+ feature_extractor: CLIPFeatureExtractor,
470
+ requires_safety_checker: bool = True,
471
+ ):
472
+ super().__init__(
473
+ vae=vae,
474
+ text_encoder=text_encoder,
475
+ tokenizer=tokenizer,
476
+ unet=unet,
477
+ scheduler=scheduler,
478
+ safety_checker=safety_checker,
479
+ feature_extractor=feature_extractor,
480
+ requires_safety_checker=requires_safety_checker,
481
+ )
482
+ self.clip_skip = clip_skip
483
+ self.__init__additional__()
484
+
485
+ # else:
486
+ # def __init__(
487
+ # self,
488
+ # vae: AutoencoderKL,
489
+ # text_encoder: CLIPTextModel,
490
+ # tokenizer: CLIPTokenizer,
491
+ # unet: UNet2DConditionModel,
492
+ # scheduler: SchedulerMixin,
493
+ # safety_checker: StableDiffusionSafetyChecker,
494
+ # feature_extractor: CLIPFeatureExtractor,
495
+ # ):
496
+ # super().__init__(
497
+ # vae=vae,
498
+ # text_encoder=text_encoder,
499
+ # tokenizer=tokenizer,
500
+ # unet=unet,
501
+ # scheduler=scheduler,
502
+ # safety_checker=safety_checker,
503
+ # feature_extractor=feature_extractor,
504
+ # )
505
+ # self.__init__additional__()
506
+
507
+ def __init__additional__(self):
508
+ if not hasattr(self, "vae_scale_factor"):
509
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
510
+
511
+ @property
512
+ def _execution_device(self):
513
+ r"""
514
+ Returns the device on which the pipeline's models will be executed. After calling
515
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
516
+ hooks.
517
+ """
518
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
519
+ return self.device
520
+ for module in self.unet.modules():
521
+ if (
522
+ hasattr(module, "_hf_hook")
523
+ and hasattr(module._hf_hook, "execution_device")
524
+ and module._hf_hook.execution_device is not None
525
+ ):
526
+ return torch.device(module._hf_hook.execution_device)
527
+ return self.device
528
+
529
+ def _encode_prompt(
530
+ self,
531
+ prompt,
532
+ device,
533
+ num_images_per_prompt,
534
+ do_classifier_free_guidance,
535
+ negative_prompt,
536
+ max_embeddings_multiples,
537
+ ):
538
+ r"""
539
+ Encodes the prompt into text encoder hidden states.
540
+
541
+ Args:
542
+ prompt (`str` or `list(int)`):
543
+ prompt to be encoded
544
+ device: (`torch.device`):
545
+ torch device
546
+ num_images_per_prompt (`int`):
547
+ number of images that should be generated per prompt
548
+ do_classifier_free_guidance (`bool`):
549
+ whether to use classifier free guidance or not
550
+ negative_prompt (`str` or `List[str]`):
551
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
552
+ if `guidance_scale` is less than `1`).
553
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
554
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
555
+ """
556
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
557
+
558
+ if negative_prompt is None:
559
+ negative_prompt = [""] * batch_size
560
+ elif isinstance(negative_prompt, str):
561
+ negative_prompt = [negative_prompt] * batch_size
562
+ if batch_size != len(negative_prompt):
563
+ raise ValueError(
564
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
565
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
566
+ " the batch size of `prompt`."
567
+ )
568
+
569
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
570
+ pipe=self,
571
+ prompt=prompt,
572
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
573
+ max_embeddings_multiples=max_embeddings_multiples,
574
+ clip_skip=self.clip_skip,
575
+ )
576
+ bs_embed, seq_len, _ = text_embeddings.shape
577
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
578
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
579
+
580
+ if do_classifier_free_guidance:
581
+ bs_embed, seq_len, _ = uncond_embeddings.shape
582
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
583
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
584
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
585
+
586
+ return text_embeddings
587
+
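+ # When classifier-free guidance is enabled, the tensor returned above is torch.cat([uncond, cond]),
+ # so the negative-prompt embeddings occupy the first half of the batch; this ordering matches
+ # `noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)` in the denoising loop below.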
588
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
589
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
590
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
591
+
592
+ if strength < 0 or strength > 1:
593
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
594
+
595
+ if height % 8 != 0 or width % 8 != 0:
596
+ print(height, width)
597
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
598
+
599
+ if (callback_steps is None) or (
600
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
601
+ ):
602
+ raise ValueError(
603
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
604
+ )
605
+
606
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
607
+ if is_text2img:
608
+ return self.scheduler.timesteps.to(device), num_inference_steps
609
+ else:
610
+ # get the original timestep using init_timestep
611
+ offset = self.scheduler.config.get("steps_offset", 0)
612
+ init_timestep = int(num_inference_steps * strength) + offset
613
+ init_timestep = min(init_timestep, num_inference_steps)
614
+
615
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
616
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
617
+ return timesteps, num_inference_steps - t_start
618
+
619
+ def run_safety_checker(self, image, device, dtype):
620
+ if self.safety_checker is not None:
621
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
622
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
623
+ else:
624
+ has_nsfw_concept = None
625
+ return image, has_nsfw_concept
626
+
627
+ def decode_latents(self, latents):
628
+ latents = 1 / 0.18215 * latents
629
+ image = self.vae.decode(latents).sample
630
+ image = (image / 2 + 0.5).clamp(0, 1)
631
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
632
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
633
+ return image
634
+
635
+ def prepare_extra_step_kwargs(self, generator, eta):
636
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
637
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
638
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
639
+ # and should be between [0, 1]
640
+
641
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
642
+ extra_step_kwargs = {}
643
+ if accepts_eta:
644
+ extra_step_kwargs["eta"] = eta
645
+
646
+ # check if the scheduler accepts generator
647
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
648
+ if accepts_generator:
649
+ extra_step_kwargs["generator"] = generator
650
+ return extra_step_kwargs
651
+
652
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
653
+ if image is None:
654
+ shape = (
655
+ batch_size,
656
+ self.unet.in_channels,
657
+ height // self.vae_scale_factor,
658
+ width // self.vae_scale_factor,
659
+ )
660
+
661
+ if latents is None:
662
+ if device.type == "mps":
663
+ # randn does not work reproducibly on mps
664
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
665
+ else:
666
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
667
+ else:
668
+ if latents.shape != shape:
669
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
670
+ latents = latents.to(device)
671
+
672
+ # scale the initial noise by the standard deviation required by the scheduler
673
+ latents = latents * self.scheduler.init_noise_sigma
674
+ return latents, None, None
675
+ else:
676
+ init_latent_dist = self.vae.encode(image).latent_dist
677
+ init_latents = init_latent_dist.sample(generator=generator)
678
+ init_latents = 0.18215 * init_latents
679
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
680
+ init_latents_orig = init_latents
681
+ shape = init_latents.shape
682
+
683
+ # add noise to latents using the timesteps
684
+ if device.type == "mps":
685
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
686
+ else:
687
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
688
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
689
+ return latents, init_latents_orig, noise
690
+
691
+ @torch.no_grad()
692
+ def __call__(
693
+ self,
694
+ prompt: Union[str, List[str]],
695
+ negative_prompt: Optional[Union[str, List[str]]] = None,
696
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
697
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
698
+ height: int = 512,
699
+ width: int = 512,
700
+ num_inference_steps: int = 50,
701
+ guidance_scale: float = 7.5,
702
+ strength: float = 0.8,
703
+ num_images_per_prompt: Optional[int] = 1,
704
+ eta: float = 0.0,
705
+ generator: Optional[torch.Generator] = None,
706
+ latents: Optional[torch.FloatTensor] = None,
707
+ max_embeddings_multiples: Optional[int] = 3,
708
+ output_type: Optional[str] = "pil",
709
+ return_dict: bool = True,
710
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
711
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
712
+ callback_steps: int = 1,
713
+ ):
714
+ r"""
715
+ Function invoked when calling the pipeline for generation.
716
+
717
+ Args:
718
+ prompt (`str` or `List[str]`):
719
+ The prompt or prompts to guide the image generation.
720
+ negative_prompt (`str` or `List[str]`, *optional*):
721
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
722
+ if `guidance_scale` is less than `1`).
723
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
724
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
725
+ process.
726
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
727
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
728
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
729
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
730
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
731
+ height (`int`, *optional*, defaults to 512):
732
+ The height in pixels of the generated image.
733
+ width (`int`, *optional*, defaults to 512):
734
+ The width in pixels of the generated image.
735
+ num_inference_steps (`int`, *optional*, defaults to 50):
736
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
737
+ expense of slower inference.
738
+ guidance_scale (`float`, *optional*, defaults to 7.5):
739
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
740
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
741
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
742
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
743
+ usually at the expense of lower image quality.
744
+ strength (`float`, *optional*, defaults to 0.8):
745
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
746
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
747
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
748
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
749
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
750
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
751
+ The number of images to generate per prompt.
752
+ eta (`float`, *optional*, defaults to 0.0):
753
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
754
+ [`schedulers.DDIMScheduler`], will be ignored for others.
755
+ generator (`torch.Generator`, *optional*):
756
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
757
+ deterministic.
758
+ latents (`torch.FloatTensor`, *optional*):
759
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
760
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
761
+ tensor will be generated by sampling using the supplied random `generator`.
762
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
763
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
764
+ output_type (`str`, *optional*, defaults to `"pil"`):
765
+ The output format of the generated image. Choose between
766
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
767
+ return_dict (`bool`, *optional*, defaults to `True`):
768
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
769
+ plain tuple.
770
+ callback (`Callable`, *optional*):
771
+ A function that will be called every `callback_steps` steps during inference. The function will be
772
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
773
+ is_cancelled_callback (`Callable`, *optional*):
774
+ A function that will be called every `callback_steps` steps during inference. If the function returns
775
+ `True`, the inference will be cancelled.
776
+ callback_steps (`int`, *optional*, defaults to 1):
777
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
778
+ called at every step.
779
+
780
+ Returns:
781
+ `None` if cancelled by `is_cancelled_callback`,
782
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
783
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
784
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
785
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
786
+ (nsfw) content, according to the `safety_checker`.
787
+ """
788
+ # 0. Default height and width to unet
789
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
790
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
791
+
792
+ # 1. Check inputs. Raise error if not correct
793
+ self.check_inputs(prompt, height, width, strength, callback_steps)
794
+
795
+ # 2. Define call parameters
796
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
797
+ device = self._execution_device
798
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
799
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
800
+ # corresponds to doing no classifier free guidance.
801
+ do_classifier_free_guidance = guidance_scale > 1.0
802
+
803
+ # 3. Encode input prompt
804
+ text_embeddings = self._encode_prompt(
805
+ prompt,
806
+ device,
807
+ num_images_per_prompt,
808
+ do_classifier_free_guidance,
809
+ negative_prompt,
810
+ max_embeddings_multiples,
811
+ )
812
+ dtype = text_embeddings.dtype
813
+
814
+ # 4. Preprocess image and mask
815
+ if isinstance(image, PIL.Image.Image):
816
+ image = preprocess_image(image)
817
+ if image is not None:
818
+ image = image.to(device=self.device, dtype=dtype)
819
+ if isinstance(mask_image, PIL.Image.Image):
820
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
821
+ if mask_image is not None:
822
+ mask = mask_image.to(device=self.device, dtype=dtype)
823
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
824
+ else:
825
+ mask = None
826
+
827
+ # 5. set timesteps
828
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
829
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
830
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
831
+
832
+ # 6. Prepare latent variables
833
+ latents, init_latents_orig, noise = self.prepare_latents(
834
+ image,
835
+ latent_timestep,
836
+ batch_size * num_images_per_prompt,
837
+ height,
838
+ width,
839
+ dtype,
840
+ device,
841
+ generator,
842
+ latents,
843
+ )
844
+
845
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
846
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
847
+
848
+ # 8. Denoising loop
849
+ for i, t in enumerate(self.progress_bar(timesteps)):
850
+ # expand the latents if we are doing classifier free guidance
851
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
852
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
853
+
854
+ # predict the noise residual
855
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
856
+
857
+ # perform guidance
858
+ if do_classifier_free_guidance:
859
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
860
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
861
+
862
+ # compute the previous noisy sample x_t -> x_t-1
863
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
864
+
865
+ if mask is not None:
866
+ # masking
867
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
868
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
869
+
870
+ # call the callback, if provided
871
+ if i % callback_steps == 0:
872
+ if callback is not None:
873
+ callback(i, t, latents)
874
+ if is_cancelled_callback is not None and is_cancelled_callback():
875
+ return None
876
+
877
+ # 9. Post-processing
878
+ image = self.decode_latents(latents)
879
+
880
+ # 10. Run safety checker
881
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
882
+
883
+ # 11. Convert to PIL
884
+ if output_type == "pil":
885
+ image = self.numpy_to_pil(image)
886
+
887
+ if not return_dict:
888
+ return image, has_nsfw_concept
889
+
890
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
891
+
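+ # Hedged usage sketch (prompt values are illustrative):
+ # output = pipe("a (very beautiful:1.3) landscape", negative_prompt="blurry",
+ #               num_inference_steps=30, guidance_scale=7.5)
+ # images = output.images
+ # The same __call__ also drives img2img (pass `image`) and inpaint (pass `image` and `mask_image`).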
892
+ def text2img(
893
+ self,
894
+ prompt: Union[str, List[str]],
895
+ negative_prompt: Optional[Union[str, List[str]]] = None,
896
+ height: int = 512,
897
+ width: int = 512,
898
+ num_inference_steps: int = 50,
899
+ guidance_scale: float = 7.5,
900
+ num_images_per_prompt: Optional[int] = 1,
901
+ eta: float = 0.0,
902
+ generator: Optional[torch.Generator] = None,
903
+ latents: Optional[torch.FloatTensor] = None,
904
+ max_embeddings_multiples: Optional[int] = 3,
905
+ output_type: Optional[str] = "pil",
906
+ return_dict: bool = True,
907
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
908
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
909
+ callback_steps: int = 1,
910
+ ):
911
+ r"""
912
+ Function for text-to-image generation.
913
+ Args:
914
+ prompt (`str` or `List[str]`):
915
+ The prompt or prompts to guide the image generation.
916
+ negative_prompt (`str` or `List[str]`, *optional*):
917
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
918
+ if `guidance_scale` is less than `1`).
919
+ height (`int`, *optional*, defaults to 512):
920
+ The height in pixels of the generated image.
921
+ width (`int`, *optional*, defaults to 512):
922
+ The width in pixels of the generated image.
923
+ num_inference_steps (`int`, *optional*, defaults to 50):
924
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
925
+ expense of slower inference.
926
+ guidance_scale (`float`, *optional*, defaults to 7.5):
927
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
928
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
929
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
930
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
931
+ usually at the expense of lower image quality.
932
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
933
+ The number of images to generate per prompt.
934
+ eta (`float`, *optional*, defaults to 0.0):
935
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
936
+ [`schedulers.DDIMScheduler`], will be ignored for others.
937
+ generator (`torch.Generator`, *optional*):
938
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
939
+ deterministic.
940
+ latents (`torch.FloatTensor`, *optional*):
941
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
942
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
943
+ tensor will be generated by sampling using the supplied random `generator`.
944
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
945
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
946
+ output_type (`str`, *optional*, defaults to `"pil"`):
947
+ The output format of the generated image. Choose between
948
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
949
+ return_dict (`bool`, *optional*, defaults to `True`):
950
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
951
+ plain tuple.
952
+ callback (`Callable`, *optional*):
953
+ A function that will be called every `callback_steps` steps during inference. The function will be
954
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
955
+ is_cancelled_callback (`Callable`, *optional*):
956
+ A function that will be called every `callback_steps` steps during inference. If the function returns
957
+ `True`, the inference will be cancelled.
958
+ callback_steps (`int`, *optional*, defaults to 1):
959
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
960
+ called at every step.
961
+ Returns:
962
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
963
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
964
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
965
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
966
+ (nsfw) content, according to the `safety_checker`.
967
+ """
968
+ return self.__call__(
969
+ prompt=prompt,
970
+ negative_prompt=negative_prompt,
971
+ height=height,
972
+ width=width,
973
+ num_inference_steps=num_inference_steps,
974
+ guidance_scale=guidance_scale,
975
+ num_images_per_prompt=num_images_per_prompt,
976
+ eta=eta,
977
+ generator=generator,
978
+ latents=latents,
979
+ max_embeddings_multiples=max_embeddings_multiples,
980
+ output_type=output_type,
981
+ return_dict=return_dict,
982
+ callback=callback,
983
+ is_cancelled_callback=is_cancelled_callback,
984
+ callback_steps=callback_steps,
985
+ )
986
+
987
+ def img2img(
988
+ self,
989
+ image: Union[torch.FloatTensor, PIL.Image.Image],
990
+ prompt: Union[str, List[str]],
991
+ negative_prompt: Optional[Union[str, List[str]]] = None,
992
+ strength: float = 0.8,
993
+ num_inference_steps: Optional[int] = 50,
994
+ guidance_scale: Optional[float] = 7.5,
995
+ num_images_per_prompt: Optional[int] = 1,
996
+ eta: Optional[float] = 0.0,
997
+ generator: Optional[torch.Generator] = None,
998
+ max_embeddings_multiples: Optional[int] = 3,
999
+ output_type: Optional[str] = "pil",
1000
+ return_dict: bool = True,
1001
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1002
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1003
+ callback_steps: int = 1,
1004
+ ):
1005
+ r"""
1006
+ Function for image-to-image generation.
1007
+ Args:
1008
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1009
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1010
+ process.
1011
+ prompt (`str` or `List[str]`):
1012
+ The prompt or prompts to guide the image generation.
1013
+ negative_prompt (`str` or `List[str]`, *optional*):
1014
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1015
+ if `guidance_scale` is less than `1`).
1016
+ strength (`float`, *optional*, defaults to 0.8):
1017
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1018
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1019
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1020
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1021
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1022
+ num_inference_steps (`int`, *optional*, defaults to 50):
1023
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1024
+ expense of slower inference. This parameter will be modulated by `strength`.
1025
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1026
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1027
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1028
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1029
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1030
+ usually at the expense of lower image quality.
1031
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1032
+ The number of images to generate per prompt.
1033
+ eta (`float`, *optional*, defaults to 0.0):
1034
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1035
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1036
+ generator (`torch.Generator`, *optional*):
1037
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1038
+ deterministic.
1039
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1040
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1041
+ output_type (`str`, *optional*, defaults to `"pil"`):
1042
+ The output format of the generated image. Choose between
1043
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1044
+ return_dict (`bool`, *optional*, defaults to `True`):
1045
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1046
+ plain tuple.
1047
+ callback (`Callable`, *optional*):
1048
+ A function that will be called every `callback_steps` steps during inference. The function will be
1049
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1050
+ is_cancelled_callback (`Callable`, *optional*):
1051
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1052
+ `True`, the inference will be cancelled.
1053
+ callback_steps (`int`, *optional*, defaults to 1):
1054
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1055
+ called at every step.
1056
+ Returns:
1057
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1058
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1059
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1060
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1061
+ (nsfw) content, according to the `safety_checker`.
1062
+ """
1063
+ return self.__call__(
1064
+ prompt=prompt,
1065
+ negative_prompt=negative_prompt,
1066
+ image=image,
1067
+ num_inference_steps=num_inference_steps,
1068
+ guidance_scale=guidance_scale,
1069
+ strength=strength,
1070
+ num_images_per_prompt=num_images_per_prompt,
1071
+ eta=eta,
1072
+ generator=generator,
1073
+ max_embeddings_multiples=max_embeddings_multiples,
1074
+ output_type=output_type,
1075
+ return_dict=return_dict,
1076
+ callback=callback,
1077
+ is_cancelled_callback=is_cancelled_callback,
1078
+ callback_steps=callback_steps,
1079
+ )
1080
+
1081
+ def inpaint(
1082
+ self,
1083
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1084
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1085
+ prompt: Union[str, List[str]],
1086
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1087
+ strength: float = 0.8,
1088
+ num_inference_steps: Optional[int] = 50,
1089
+ guidance_scale: Optional[float] = 7.5,
1090
+ num_images_per_prompt: Optional[int] = 1,
1091
+ eta: Optional[float] = 0.0,
1092
+ generator: Optional[torch.Generator] = None,
1093
+ max_embeddings_multiples: Optional[int] = 3,
1094
+ output_type: Optional[str] = "pil",
1095
+ return_dict: bool = True,
1096
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1097
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1098
+ callback_steps: int = 1,
1099
+ ):
1100
+ r"""
1101
+ Function for inpaint.
1102
+ Args:
1103
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1104
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1105
+ process. This is the image whose masked region will be inpainted.
1106
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1107
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1108
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1109
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1110
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1111
+ prompt (`str` or `List[str]`):
1112
+ The prompt or prompts to guide the image generation.
1113
+ negative_prompt (`str` or `List[str]`, *optional*):
1114
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1115
+ if `guidance_scale` is less than `1`).
1116
+ strength (`float`, *optional*, defaults to 0.8):
1117
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1118
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1119
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1120
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1121
+ num_inference_steps (`int`, *optional*, defaults to 50):
1122
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1123
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1124
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1125
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1126
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1127
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1128
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1129
+ usually at the expense of lower image quality.
1130
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1131
+ The number of images to generate per prompt.
1132
+ eta (`float`, *optional*, defaults to 0.0):
1133
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1134
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1135
+ generator (`torch.Generator`, *optional*):
1136
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1137
+ deterministic.
1138
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1139
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1140
+ output_type (`str`, *optional*, defaults to `"pil"`):
1141
+ The output format of the generated image. Choose between
1142
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1143
+ return_dict (`bool`, *optional*, defaults to `True`):
1144
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1145
+ plain tuple.
1146
+ callback (`Callable`, *optional*):
1147
+ A function that will be called every `callback_steps` steps during inference. The function will be
1148
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1149
+ is_cancelled_callback (`Callable`, *optional*):
1150
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1151
+ `True`, the inference will be cancelled.
1152
+ callback_steps (`int`, *optional*, defaults to 1):
1153
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1154
+ called at every step.
1155
+ Returns:
1156
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1157
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1158
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1159
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1160
+ (nsfw) content, according to the `safety_checker`.
1161
+ """
1162
+ return self.__call__(
1163
+ prompt=prompt,
1164
+ negative_prompt=negative_prompt,
1165
+ image=image,
1166
+ mask_image=mask_image,
1167
+ num_inference_steps=num_inference_steps,
1168
+ guidance_scale=guidance_scale,
1169
+ strength=strength,
1170
+ num_images_per_prompt=num_images_per_prompt,
1171
+ eta=eta,
1172
+ generator=generator,
1173
+ max_embeddings_multiples=max_embeddings_multiples,
1174
+ output_type=output_type,
1175
+ return_dict=return_dict,
1176
+ callback=callback,
1177
+ is_cancelled_callback=is_cancelled_callback,
1178
+ callback_steps=callback_steps,
1179
+ )
library/merge_lora_gui.py ADDED
@@ -0,0 +1,156 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import (
6
+ get_saveasfilename_path,
7
+ get_any_file_path,
8
+ get_file_path,
9
+ )
10
+
11
+ folder_symbol = '\U0001f4c2' # 📂
12
+ refresh_symbol = '\U0001f504' # 🔄
13
+ save_style_symbol = '\U0001f4be' # 💾
14
+ document_symbol = '\U0001F4C4' # 📄
15
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
16
+
17
+
18
+ def merge_lora(
19
+ lora_a_model,
20
+ lora_b_model,
21
+ ratio,
22
+ save_to,
23
+ precision,
24
+ save_precision,
25
+ ):
26
+ # Check that both LoRA model paths were provided
27
+ if lora_a_model == '':
28
+ msgbox('Invalid model A file')
29
+ return
30
+
31
+ if lora_b_model == '':
32
+ msgbox('Invalid model B file')
33
+ return
34
+
35
+ # Check that the provided LoRA model files exist
36
+ if not os.path.isfile(lora_a_model):
37
+ msgbox('The provided model A is not a file')
38
+ return
39
+
40
+ if not os.path.isfile(lora_b_model):
41
+ msgbox('The provided model B is not a file')
42
+ return
43
+
44
+ ratio_a = ratio
45
+ ratio_b = 1 - ratio
46
+
47
+ run_cmd = f'{PYTHON} "{os.path.join("networks","merge_lora.py")}"'
48
+ run_cmd += f' --save_precision {save_precision}'
49
+ run_cmd += f' --precision {precision}'
50
+ run_cmd += f' --save_to "{save_to}"'
51
+ run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
52
+ run_cmd += f' --ratios {ratio_a} {ratio_b}'
53
+
54
+ print(run_cmd)
55
+
56
+ # Run the command
57
+ if os.name == 'posix':
58
+ os.system(run_cmd)
59
+ else:
60
+ subprocess.run(run_cmd)
61
+
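+ # Example of the command assembled above (illustrative file names, ratio slider at 0.5, default precisions):
+ # python3 "networks/merge_lora.py" --save_precision float --precision float
+ # --save_to "merged.safetensors" --models "a.safetensors" "b.safetensors" --ratios 0.5 0.5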
62
+
63
+ ###
64
+ # Gradio UI
65
+ ###
66
+
67
+
68
+ def gradio_merge_lora_tab():
69
+ with gr.Tab('Merge LoRA'):
70
+ gr.Markdown('This utility can merge two LoRA networks together.')
71
+
72
+ lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
73
+ lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
74
+
75
+ with gr.Row():
76
+ lora_a_model = gr.Textbox(
77
+ label='LoRA model "A"',
78
+ placeholder='Path to the LoRA A model',
79
+ interactive=True,
80
+ )
81
+ button_lora_a_model_file = gr.Button(
82
+ folder_symbol, elem_id='open_folder_small'
83
+ )
84
+ button_lora_a_model_file.click(
85
+ get_file_path,
86
+ inputs=[lora_a_model, lora_ext, lora_ext_name],
87
+ outputs=lora_a_model,
88
+ show_progress=False,
89
+ )
90
+
91
+ lora_b_model = gr.Textbox(
92
+ label='LoRA model "B"',
93
+ placeholder='Path to the LoRA B model',
94
+ interactive=True,
95
+ )
96
+ button_lora_b_model_file = gr.Button(
97
+ folder_symbol, elem_id='open_folder_small'
98
+ )
99
+ button_lora_b_model_file.click(
100
+ get_file_path,
101
+ inputs=[lora_b_model, lora_ext, lora_ext_name],
102
+ outputs=lora_b_model,
103
+ show_progress=False,
104
+ )
105
+ with gr.Row():
106
+ ratio = gr.Slider(
107
+ label='Merge ratio (e.g. 0.7 means 70% of model A and 30% of model B)',
108
+ minimum=0,
109
+ maximum=1,
110
+ step=0.01,
111
+ value=0.5,
112
+ interactive=True,
113
+ )
114
+
115
+ with gr.Row():
116
+ save_to = gr.Textbox(
117
+ label='Save to',
118
+ placeholder='path for the file to save...',
119
+ interactive=True,
120
+ )
121
+ button_save_to = gr.Button(
122
+ folder_symbol, elem_id='open_folder_small'
123
+ )
124
+ button_save_to.click(
125
+ get_saveasfilename_path,
126
+ inputs=[save_to, lora_ext, lora_ext_name],
127
+ outputs=save_to,
128
+ show_progress=False,
129
+ )
130
+ precision = gr.Dropdown(
131
+ label='Merge precision',
132
+ choices=['fp16', 'bf16', 'float'],
133
+ value='float',
134
+ interactive=True,
135
+ )
136
+ save_precision = gr.Dropdown(
137
+ label='Save precision',
138
+ choices=['fp16', 'bf16', 'float'],
139
+ value='float',
140
+ interactive=True,
141
+ )
142
+
143
+ convert_button = gr.Button('Merge model')
144
+
145
+ convert_button.click(
146
+ merge_lora,
147
+ inputs=[
148
+ lora_a_model,
149
+ lora_b_model,
150
+ ratio,
151
+ save_to,
152
+ precision,
153
+ save_precision,
154
+ ],
155
+ show_progress=False,
156
+ )
library/model_util.py ADDED
@@ -0,0 +1,1165 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
8
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ # model parameters for the Diffusers version of Stable Diffusion
12
+ NUM_TRAIN_TIMESTEPS = 1000
13
+ BETA_START = 0.00085
14
+ BETA_END = 0.0120
15
+
16
+ UNET_PARAMS_MODEL_CHANNELS = 320
17
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
18
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
19
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
20
+ UNET_PARAMS_IN_CHANNELS = 4
21
+ UNET_PARAMS_OUT_CHANNELS = 4
22
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
23
+ UNET_PARAMS_CONTEXT_DIM = 768
24
+ UNET_PARAMS_NUM_HEADS = 8
25
+
26
+ VAE_PARAMS_Z_CHANNELS = 4
27
+ VAE_PARAMS_RESOLUTION = 256
28
+ VAE_PARAMS_IN_CHANNELS = 3
29
+ VAE_PARAMS_OUT_CH = 3
30
+ VAE_PARAMS_CH = 128
31
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
32
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
33
+
34
+ # V2
35
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
36
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
37
+
38
+ # reference model IDs used to load the Diffusers configuration
39
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
40
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
41
+
42
+
43
+ # region StableDiffusion->Diffusersの変換コード
44
+ # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
45
+
46
+
47
+ def shave_segments(path, n_shave_prefix_segments=1):
48
+ """
49
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
50
+ """
51
+ if n_shave_prefix_segments >= 0:
52
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
53
+ else:
54
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
55
+
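+ # Illustrative behaviour of shave_segments:
+ # shave_segments("input_blocks.1.0.in_layers.0.weight", 2) -> "0.in_layers.0.weight"
+ # shave_segments("input_blocks.1.0.in_layers.0.weight", -1) -> "input_blocks.1.0.in_layers.0"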
56
+
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
+
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
+
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
+
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
+
111
+ mapping.append({"old": old_item, "new": new_item})
112
+
113
+ return mapping
114
+
115
+
116
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
+
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
+
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
+
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
+
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
+
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
+
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
+
141
+ mapping.append({"old": old_item, "new": new_item})
142
+
143
+ return mapping
144
+
145
+
146
+ def assign_to_checkpoint(
147
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
+
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
+
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
+
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
+
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
+
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
+
175
+ for path in paths:
176
+ new_path = path["new"]
177
+
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
+
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
+
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
+
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
+
197
+
198
+ def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
+
209
+
210
+ def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
+
218
+
219
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
251
+ }
252
+
253
+ # Retrieves the keys for the middle blocks only
254
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
255
+ middle_blocks = {
256
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
257
+ }
258
+
259
+ # Retrieves the keys for the output blocks only
260
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
261
+ output_blocks = {
262
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
263
+ }
264
+
265
+ for i in range(1, num_input_blocks):
266
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
267
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
268
+
269
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
270
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
271
+
272
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
273
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
274
+ f"input_blocks.{i}.0.op.weight"
275
+ )
276
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
277
+
278
+ paths = renew_resnet_paths(resnets)
279
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
280
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
281
+
282
+ if len(attentions):
283
+ paths = renew_attention_paths(attentions)
284
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
285
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
286
+
287
+ resnet_0 = middle_blocks[0]
288
+ attentions = middle_blocks[1]
289
+ resnet_1 = middle_blocks[2]
290
+
291
+ resnet_0_paths = renew_resnet_paths(resnet_0)
292
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
293
+
294
+ resnet_1_paths = renew_resnet_paths(resnet_1)
295
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
296
+
297
+ attentions_paths = renew_attention_paths(attentions)
298
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
299
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
300
+
301
+ for i in range(num_output_blocks):
302
+ block_id = i // (config["layers_per_block"] + 1)
303
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
304
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
305
+ output_block_list = {}
306
+
307
+ for layer in output_block_layers:
308
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
309
+ if layer_id in output_block_list:
310
+ output_block_list[layer_id].append(layer_name)
311
+ else:
312
+ output_block_list[layer_id] = [layer_name]
313
+
314
+ if len(output_block_list) > 1:
315
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
316
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
317
+
318
+ resnet_0_paths = renew_resnet_paths(resnets)
319
+ paths = renew_resnet_paths(resnets)
320
+
321
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
322
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
323
+
324
+ # original:
325
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
326
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
327
+
328
+ # avoid depending on the order of bias and weight; there is probably a better way
329
+ for l in output_block_list.values():
330
+ l.sort()
331
+
332
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
333
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
334
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
335
+ f"output_blocks.{i}.{index}.conv.bias"
336
+ ]
337
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
338
+ f"output_blocks.{i}.{index}.conv.weight"
339
+ ]
340
+
341
+ # Clear attentions as they have been attributed above.
342
+ if len(attentions) == 2:
343
+ attentions = []
344
+
345
+ if len(attentions):
346
+ paths = renew_attention_paths(attentions)
347
+ meta_path = {
348
+ "old": f"output_blocks.{i}.1",
349
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
350
+ }
351
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
352
+ else:
353
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
354
+ for path in resnet_0_paths:
355
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
356
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
357
+
358
+ new_checkpoint[new_path] = unet_state_dict[old_path]
359
+
360
+ # in SD v2 the 1x1 conv2d layers were changed to linear, so convert linear -> conv
361
+ if v2:
362
+ linear_transformer_to_conv(new_checkpoint)
363
+
364
+ return new_checkpoint
365
+
366
+
367
+ def convert_ldm_vae_checkpoint(checkpoint, config):
368
+ # extract state dict for VAE
369
+ vae_state_dict = {}
370
+ vae_key = "first_stage_model."
371
+ keys = list(checkpoint.keys())
372
+ for key in keys:
373
+ if key.startswith(vae_key):
374
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
375
+ # if len(vae_state_dict) == 0:
376
+ # # the given checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
377
+ # vae_state_dict = checkpoint
378
+
379
+ new_checkpoint = {}
380
+
381
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
382
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
383
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
384
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
385
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
386
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
387
+
388
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
389
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
390
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
391
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
392
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
393
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
394
+
395
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
396
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
397
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
398
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
399
+
400
+ # Retrieves the keys for the encoder down blocks only
401
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
402
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
403
+
404
+ # Retrieves the keys for the decoder up blocks only
405
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
406
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
407
+
408
+ for i in range(num_down_blocks):
409
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
410
+
411
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
412
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
413
+ f"encoder.down.{i}.downsample.conv.weight"
414
+ )
415
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
416
+ f"encoder.down.{i}.downsample.conv.bias"
417
+ )
418
+
419
+ paths = renew_vae_resnet_paths(resnets)
420
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
421
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
422
+
423
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
424
+ num_mid_res_blocks = 2
425
+ for i in range(1, num_mid_res_blocks + 1):
426
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
427
+
428
+ paths = renew_vae_resnet_paths(resnets)
429
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
430
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
431
+
432
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
433
+ paths = renew_vae_attention_paths(mid_attentions)
434
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
435
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
436
+ conv_attn_to_linear(new_checkpoint)
437
+
438
+ for i in range(num_up_blocks):
439
+ block_id = num_up_blocks - 1 - i
440
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
441
+
442
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
443
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
444
+ f"decoder.up.{block_id}.upsample.conv.weight"
445
+ ]
446
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
447
+ f"decoder.up.{block_id}.upsample.conv.bias"
448
+ ]
449
+
450
+ paths = renew_vae_resnet_paths(resnets)
451
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
452
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
453
+
454
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
455
+ num_mid_res_blocks = 2
456
+ for i in range(1, num_mid_res_blocks + 1):
457
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
458
+
459
+ paths = renew_vae_resnet_paths(resnets)
460
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
461
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
462
+
463
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
464
+ paths = renew_vae_attention_paths(mid_attentions)
465
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
466
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
467
+ conv_attn_to_linear(new_checkpoint)
468
+ return new_checkpoint
469
+
470
+
471
+ def create_unet_diffusers_config(v2):
472
+ """
473
+ Creates a config for the diffusers based on the config of the LDM model.
474
+ """
475
+ # unet_params = original_config.model.params.unet_config.params
476
+
477
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
478
+
479
+ down_block_types = []
480
+ resolution = 1
481
+ for i in range(len(block_out_channels)):
482
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
483
+ down_block_types.append(block_type)
484
+ if i != len(block_out_channels) - 1:
485
+ resolution *= 2
486
+
487
+ up_block_types = []
488
+ for i in range(len(block_out_channels)):
489
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
490
+ up_block_types.append(block_type)
491
+ resolution //= 2
492
+
493
+ config = dict(
494
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
495
+ in_channels=UNET_PARAMS_IN_CHANNELS,
496
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
497
+ down_block_types=tuple(down_block_types),
498
+ up_block_types=tuple(up_block_types),
499
+ block_out_channels=tuple(block_out_channels),
500
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
501
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
502
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
503
+ )
504
+
505
+ return config
506
+
507
+
508
+ def create_vae_diffusers_config():
509
+ """
510
+ Creates a config for the diffusers based on the config of the LDM model.
511
+ """
512
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
513
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
514
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
515
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
516
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
517
+
518
+ config = dict(
519
+ sample_size=VAE_PARAMS_RESOLUTION,
520
+ in_channels=VAE_PARAMS_IN_CHANNELS,
521
+ out_channels=VAE_PARAMS_OUT_CH,
522
+ down_block_types=tuple(down_block_types),
523
+ up_block_types=tuple(up_block_types),
524
+ block_out_channels=tuple(block_out_channels),
525
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
526
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
527
+ )
528
+ return config
529
+
530
+
531
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
532
+ keys = list(checkpoint.keys())
533
+ text_model_dict = {}
534
+ for key in keys:
535
+ if key.startswith("cond_stage_model.transformer"):
536
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
537
+ return text_model_dict
538
+
539
+
540
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
541
+ # the key layout is annoyingly different!
542
+ def convert_key(key):
543
+ if not key.startswith("cond_stage_model"):
544
+ return None
545
+
546
+ # common conversion
547
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
548
+ key = key.replace("cond_stage_model.model.", "text_model.")
549
+
550
+ if "resblocks" in key:
551
+ # resblocks conversion
552
+ key = key.replace(".resblocks.", ".layers.")
553
+ if ".ln_" in key:
554
+ key = key.replace(".ln_", ".layer_norm")
555
+ elif ".mlp." in key:
556
+ key = key.replace(".c_fc.", ".fc1.")
557
+ key = key.replace(".c_proj.", ".fc2.")
558
+ elif ".attn.out_proj" in key:
559
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
560
+ elif ".attn.in_proj" in key:
561
+ key = None # special case, handled separately below
562
+ else:
563
+ raise ValueError(f"unexpected key in SD: {key}")
564
+ elif ".positional_embedding" in key:
565
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
566
+ elif ".text_projection" in key:
567
+ key = None # not used???
568
+ elif ".logit_scale" in key:
569
+ key = None # not used???
570
+ elif ".token_embedding" in key:
571
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
572
+ elif ".ln_final" in key:
573
+ key = key.replace(".ln_final", ".final_layer_norm")
574
+ return key
575
+
576
+ keys = list(checkpoint.keys())
577
+ new_sd = {}
578
+ for key in keys:
579
+ # remove resblocks 23
580
+ if ".resblocks.23." in key:
581
+ continue
582
+ new_key = convert_key(key)
583
+ if new_key is None:
584
+ continue
585
+ new_sd[new_key] = checkpoint[key]
586
+
587
+ # convert attention weights
588
+ for key in keys:
589
+ if ".resblocks.23." in key:
590
+ continue
591
+ if ".resblocks" in key and ".attn.in_proj_" in key:
592
+ # split into three (q, k, v)
593
+ values = torch.chunk(checkpoint[key], 3)
594
+
595
+ key_suffix = ".weight" if "weight" in key else ".bias"
596
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
597
+ key_pfx = key_pfx.replace("_weight", "")
598
+ key_pfx = key_pfx.replace("_bias", "")
599
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
600
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
601
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
602
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
603
+
604
+ # rename or add position_ids
605
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
606
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
607
+ # waifu diffusion v1.4
608
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
609
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
610
+ else:
611
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
612
+
613
+ new_sd["text_model.embeddings.position_ids"] = position_ids
614
+ return new_sd
615
+
616
+
617
+ # endregion
618
+
619
+
620
+ # region Diffusers -> StableDiffusion conversion code
621
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
622
+
623
+
624
+ def conv_transformer_to_linear(checkpoint):
625
+ keys = list(checkpoint.keys())
626
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
627
+ for key in keys:
628
+ if ".".join(key.split(".")[-2:]) in tf_keys:
629
+ if checkpoint[key].ndim > 2:
630
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
631
+
632
+
633
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
634
+ unet_conversion_map = [
635
+ # (stable-diffusion, HF Diffusers)
636
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
637
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
638
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
639
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
640
+ ("input_blocks.0.0.weight", "conv_in.weight"),
641
+ ("input_blocks.0.0.bias", "conv_in.bias"),
642
+ ("out.0.weight", "conv_norm_out.weight"),
643
+ ("out.0.bias", "conv_norm_out.bias"),
644
+ ("out.2.weight", "conv_out.weight"),
645
+ ("out.2.bias", "conv_out.bias"),
646
+ ]
647
+
648
+ unet_conversion_map_resnet = [
649
+ # (stable-diffusion, HF Diffusers)
650
+ ("in_layers.0", "norm1"),
651
+ ("in_layers.2", "conv1"),
652
+ ("out_layers.0", "norm2"),
653
+ ("out_layers.3", "conv2"),
654
+ ("emb_layers.1", "time_emb_proj"),
655
+ ("skip_connection", "conv_shortcut"),
656
+ ]
657
+
658
+ unet_conversion_map_layer = []
659
+ for i in range(4):
660
+ # loop over downblocks/upblocks
661
+
662
+ for j in range(2):
663
+ # loop over resnets/attentions for downblocks
664
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
665
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
666
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
667
+
668
+ if i < 3:
669
+ # no attention layers in down_blocks.3
670
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
671
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
672
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
673
+
674
+ for j in range(3):
675
+ # loop over resnets/attentions for upblocks
676
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
677
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
678
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
679
+
680
+ if i > 0:
681
+ # no attention layers in up_blocks.0
682
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
683
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
684
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
685
+
686
+ if i < 3:
687
+ # no downsample in down_blocks.3
688
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
689
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
690
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
691
+
692
+ # no upsample in up_blocks.3
693
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
694
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
695
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
696
+
697
+ hf_mid_atn_prefix = "mid_block.attentions.0."
698
+ sd_mid_atn_prefix = "middle_block.1."
699
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
700
+
701
+ for j in range(2):
702
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
703
+ sd_mid_res_prefix = f"middle_block.{2*j}."
704
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
705
+
706
+ # buyer beware: this is a *brittle* function,
707
+ # and correct output requires that all of these pieces interact in
708
+ # the exact order in which I have arranged them.
709
+ mapping = {k: k for k in unet_state_dict.keys()}
710
+ for sd_name, hf_name in unet_conversion_map:
711
+ mapping[hf_name] = sd_name
712
+ for k, v in mapping.items():
713
+ if "resnets" in k:
714
+ for sd_part, hf_part in unet_conversion_map_resnet:
715
+ v = v.replace(hf_part, sd_part)
716
+ mapping[k] = v
717
+ for k, v in mapping.items():
718
+ for sd_part, hf_part in unet_conversion_map_layer:
719
+ v = v.replace(hf_part, sd_part)
720
+ mapping[k] = v
721
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
722
+
723
+ if v2:
724
+ conv_transformer_to_linear(new_state_dict)
725
+
726
+ return new_state_dict
727
+
728
+
729
+ # ================#
730
+ # VAE Conversion #
731
+ # ================#
732
+
733
+
734
+ def reshape_weight_for_sd(w):
735
+ # convert HF linear weights to SD conv2d weights
736
+ return w.reshape(*w.shape, 1, 1)
737
+
738
+
739
+ def convert_vae_state_dict(vae_state_dict):
740
+ vae_conversion_map = [
741
+ # (stable-diffusion, HF Diffusers)
742
+ ("nin_shortcut", "conv_shortcut"),
743
+ ("norm_out", "conv_norm_out"),
744
+ ("mid.attn_1.", "mid_block.attentions.0."),
745
+ ]
746
+
747
+ for i in range(4):
748
+ # down_blocks have two resnets
749
+ for j in range(2):
750
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
751
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
752
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
753
+
754
+ if i < 3:
755
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
756
+ sd_downsample_prefix = f"down.{i}.downsample."
757
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
758
+
759
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
760
+ sd_upsample_prefix = f"up.{3-i}.upsample."
761
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
762
+
763
+ # up_blocks have three resnets
764
+ # also, up blocks in hf are numbered in reverse from sd
765
+ for j in range(3):
766
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
767
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
768
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
769
+
770
+ # this part accounts for mid blocks in both the encoder and the decoder
771
+ for i in range(2):
772
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
773
+ sd_mid_res_prefix = f"mid.block_{i+1}."
774
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
775
+
776
+ vae_conversion_map_attn = [
777
+ # (stable-diffusion, HF Diffusers)
778
+ ("norm.", "group_norm."),
779
+ ("q.", "query."),
780
+ ("k.", "key."),
781
+ ("v.", "value."),
782
+ ("proj_out.", "proj_attn."),
783
+ ]
784
+
785
+ mapping = {k: k for k in vae_state_dict.keys()}
786
+ for k, v in mapping.items():
787
+ for sd_part, hf_part in vae_conversion_map:
788
+ v = v.replace(hf_part, sd_part)
789
+ mapping[k] = v
790
+ for k, v in mapping.items():
791
+ if "attentions" in k:
792
+ for sd_part, hf_part in vae_conversion_map_attn:
793
+ v = v.replace(hf_part, sd_part)
794
+ mapping[k] = v
795
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
796
+ weights_to_convert = ["q", "k", "v", "proj_out"]
797
+ for k, v in new_state_dict.items():
798
+ for weight_name in weights_to_convert:
799
+ if f"mid.attn_1.{weight_name}.weight" in k:
800
+ # print(f"Reshaping {k} for SD format")
801
+ new_state_dict[k] = reshape_weight_for_sd(v)
802
+
803
+ return new_state_dict
804
+
805
+
806
+ # endregion
807
+
808
+ # region custom model loading/saving helpers
809
+
810
+
811
+ def is_safetensors(path):
812
+ return os.path.splitext(path)[1].lower() == ".safetensors"
813
+
814
+
815
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
816
+ # support models that store the text encoder keys differently (missing 'text_model')
817
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
818
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
819
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
820
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
821
+ ]
822
+
823
+ if is_safetensors(ckpt_path):
824
+ checkpoint = None
825
+ state_dict = load_file(ckpt_path) # , device) # may cause an error
826
+ else:
827
+ checkpoint = torch.load(ckpt_path, map_location=device)
828
+ if "state_dict" in checkpoint:
829
+ state_dict = checkpoint["state_dict"]
830
+ else:
831
+ state_dict = checkpoint
832
+ checkpoint = None
833
+
834
+ key_reps = []
835
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
836
+ for key in state_dict.keys():
837
+ if key.startswith(rep_from):
838
+ new_key = rep_to + key[len(rep_from) :]
839
+ key_reps.append((key, new_key))
840
+
841
+ for key, new_key in key_reps:
842
+ state_dict[new_key] = state_dict[key]
843
+ del state_dict[key]
844
+
845
+ return checkpoint, state_dict
846
+
847
+
848
+ # TODO the dtype handling looks unreliable and should be checked; whether text_encoder can be created in the requested dtype is unverified
849
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
850
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
851
+
852
+ # Convert the UNet2DConditionModel model.
853
+ unet_config = create_unet_diffusers_config(v2)
854
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
855
+
856
+ unet = UNet2DConditionModel(**unet_config).to(device)
857
+ info = unet.load_state_dict(converted_unet_checkpoint)
858
+ print("loading u-net:", info)
859
+
860
+ # Convert the VAE model.
861
+ vae_config = create_vae_diffusers_config()
862
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
863
+
864
+ vae = AutoencoderKL(**vae_config).to(device)
865
+ info = vae.load_state_dict(converted_vae_checkpoint)
866
+ print("loading vae:", info)
867
+
868
+ # convert text_model
869
+ if v2:
870
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
871
+ cfg = CLIPTextConfig(
872
+ vocab_size=49408,
873
+ hidden_size=1024,
874
+ intermediate_size=4096,
875
+ num_hidden_layers=23,
876
+ num_attention_heads=16,
877
+ max_position_embeddings=77,
878
+ hidden_act="gelu",
879
+ layer_norm_eps=1e-05,
880
+ dropout=0.0,
881
+ attention_dropout=0.0,
882
+ initializer_range=0.02,
883
+ initializer_factor=1.0,
884
+ pad_token_id=1,
885
+ bos_token_id=0,
886
+ eos_token_id=2,
887
+ model_type="clip_text_model",
888
+ projection_dim=512,
889
+ torch_dtype="float32",
890
+ transformers_version="4.25.0.dev0",
891
+ )
892
+ text_model = CLIPTextModel._from_config(cfg)
893
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
894
+ else:
895
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
896
+
897
+ logging.set_verbosity_error() # don't show annoying warning
898
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
899
+ logging.set_verbosity_warning()
900
+
901
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
902
+ print("loading text encoder:", info)
903
+
904
+ return text_model, vae, unet
905
+
906
+
907
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
908
+ def convert_key(key):
909
+ # drop position_ids
910
+ if ".position_ids" in key:
911
+ return None
912
+
913
+ # common
914
+ key = key.replace("text_model.encoder.", "transformer.")
915
+ key = key.replace("text_model.", "")
916
+ if "layers" in key:
917
+ # resblocks conversion
918
+ key = key.replace(".layers.", ".resblocks.")
919
+ if ".layer_norm" in key:
920
+ key = key.replace(".layer_norm", ".ln_")
921
+ elif ".mlp." in key:
922
+ key = key.replace(".fc1.", ".c_fc.")
923
+ key = key.replace(".fc2.", ".c_proj.")
924
+ elif ".self_attn.out_proj" in key:
925
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
926
+ elif ".self_attn." in key:
927
+ key = None # special case, handled separately below
928
+ else:
929
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
930
+ elif ".position_embedding" in key:
931
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
932
+ elif ".token_embedding" in key:
933
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
934
+ elif "final_layer_norm" in key:
935
+ key = key.replace("final_layer_norm", "ln_final")
936
+ return key
937
+
938
+ keys = list(checkpoint.keys())
939
+ new_sd = {}
940
+ for key in keys:
941
+ new_key = convert_key(key)
942
+ if new_key is None:
943
+ continue
944
+ new_sd[new_key] = checkpoint[key]
945
+
946
+ # convert attention weights
947
+ for key in keys:
948
+ if "layers" in key and "q_proj" in key:
949
+ # concatenate the three projections (q, k, v)
950
+ key_q = key
951
+ key_k = key.replace("q_proj", "k_proj")
952
+ key_v = key.replace("q_proj", "v_proj")
953
+
954
+ value_q = checkpoint[key_q]
955
+ value_k = checkpoint[key_k]
956
+ value_v = checkpoint[key_v]
957
+ value = torch.cat([value_q, value_k, value_v])
958
+
959
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
960
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
961
+ new_sd[new_key] = value
962
+
963
+ # whether to fabricate dummy weights for the final layer etc.
964
+ if make_dummy_weights:
965
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
966
+ keys = list(new_sd.keys())
967
+ for key in keys:
968
+ if key.startswith("transformer.resblocks.22."):
969
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # must clone, otherwise saving with safetensors fails
970
+
971
+ # create weights that are not present in the Diffusers model
972
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
973
+ new_sd["logit_scale"] = torch.tensor(1)
974
+
975
+ return new_sd
976
+
977
+
978
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
979
+ if ckpt_path is not None:
980
+ # read epoch/step from it; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
981
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
982
+ if checkpoint is None: # safetensors, or a ckpt that is only a state_dict
983
+ checkpoint = {}
984
+ strict = False
985
+ else:
986
+ strict = True
987
+ if "state_dict" in state_dict:
988
+ del state_dict["state_dict"]
989
+ else:
990
+ # build a new checkpoint from scratch
991
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
992
+ checkpoint = {}
993
+ state_dict = {}
994
+ strict = False
995
+
996
+ def update_sd(prefix, sd):
997
+ for k, v in sd.items():
998
+ key = prefix + k
999
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1000
+ if save_dtype is not None:
1001
+ v = v.detach().clone().to("cpu").to(save_dtype)
1002
+ state_dict[key] = v
1003
+
1004
+ # Convert the UNet model
1005
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1006
+ update_sd("model.diffusion_model.", unet_state_dict)
1007
+
1008
+ # Convert the text encoder model
1009
+ if v2:
1010
+ make_dummy = ckpt_path is None # without a reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the one before it)
1011
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1012
+ update_sd("cond_stage_model.model.", text_enc_dict)
1013
+ else:
1014
+ text_enc_dict = text_encoder.state_dict()
1015
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1016
+
1017
+ # Convert the VAE
1018
+ if vae is not None:
1019
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1020
+ update_sd("first_stage_model.", vae_dict)
1021
+
1022
+ # Put together new checkpoint
1023
+ key_count = len(state_dict.keys())
1024
+ new_ckpt = {"state_dict": state_dict}
1025
+
1026
+ # epoch and global_step are sometimes not int
1027
+ try:
1028
+ if "epoch" in checkpoint:
1029
+ epochs += checkpoint["epoch"]
1030
+ if "global_step" in checkpoint:
1031
+ steps += checkpoint["global_step"]
1032
+ except:
1033
+ pass
1034
+
1035
+ new_ckpt["epoch"] = epochs
1036
+ new_ckpt["global_step"] = steps
1037
+
1038
+ if is_safetensors(output_file):
1039
+ # TODO should non-Tensor values be removed from the dict?
1040
+ save_file(state_dict, output_file)
1041
+ else:
1042
+ torch.save(new_ckpt, output_file)
1043
+
1044
+ return key_count
1045
+
1046
+
1047
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1048
+ if pretrained_model_name_or_path is None:
1049
+ # load default settings for v1/v2
1050
+ if v2:
1051
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1052
+ else:
1053
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1054
+
1055
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1056
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1057
+ if vae is None:
1058
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1059
+
1060
+ pipeline = StableDiffusionPipeline(
1061
+ unet=unet,
1062
+ text_encoder=text_encoder,
1063
+ vae=vae,
1064
+ scheduler=scheduler,
1065
+ tokenizer=tokenizer,
1066
+ safety_checker=None,
1067
+ feature_extractor=None,
1068
+ requires_safety_checker=None,
1069
+ )
1070
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1071
+
1072
+
1073
+ VAE_PREFIX = "first_stage_model."
1074
+
1075
+
1076
+ def load_vae(vae_id, dtype):
1077
+ print(f"load VAE: {vae_id}")
1078
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1079
+ # Diffusers local/remote
1080
+ try:
1081
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1082
+ except EnvironmentError as e:
1083
+ print(f"exception occurs in loading vae: {e}")
1084
+ print("retry with subfolder='vae'")
1085
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1086
+ return vae
1087
+
1088
+ # local
1089
+ vae_config = create_vae_diffusers_config()
1090
+
1091
+ if vae_id.endswith(".bin"):
1092
+ # SD 1.5 VAE on Huggingface
1093
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1094
+ else:
1095
+ # StableDiffusion
1096
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1097
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1098
+
1099
+ # vae only or full model
1100
+ full_model = False
1101
+ for vae_key in vae_sd:
1102
+ if vae_key.startswith(VAE_PREFIX):
1103
+ full_model = True
1104
+ break
1105
+ if not full_model:
1106
+ sd = {}
1107
+ for key, value in vae_sd.items():
1108
+ sd[VAE_PREFIX + key] = value
1109
+ vae_sd = sd
1110
+ del sd
1111
+
1112
+ # Convert the VAE model.
1113
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1114
+
1115
+ vae = AutoencoderKL(**vae_config)
1116
+ vae.load_state_dict(converted_vae_checkpoint)
1117
+ return vae
1118
+
1119
+
1120
+ # endregion
1121
+
1122
+
1123
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1124
+ max_width, max_height = max_reso
1125
+ max_area = (max_width // divisible) * (max_height // divisible)
1126
+
1127
+ resos = set()
1128
+
1129
+ size = int(math.sqrt(max_area)) * divisible
1130
+ resos.add((size, size))
1131
+
1132
+ size = min_size
1133
+ while size <= max_size:
1134
+ width = size
1135
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1136
+ resos.add((width, height))
1137
+ resos.add((height, width))
1138
+
1139
+ # # make additional resos
1140
+ # if width >= height and width - divisible >= min_size:
1141
+ # resos.add((width - divisible, height))
1142
+ # resos.add((height, width - divisible))
1143
+ # if height >= width and height - divisible >= min_size:
1144
+ # resos.add((width, height - divisible))
1145
+ # resos.add((height - divisible, width))
1146
+
1147
+ size += divisible
1148
+
1149
+ resos = list(resos)
1150
+ resos.sort()
1151
+ return resos
1152
+
1153
+
1154
+ if __name__ == "__main__":
1155
+ resos = make_bucket_resolutions((512, 768))
1156
+ print(len(resos))
1157
+ print(resos)
1158
+ aspect_ratios = [w / h for w, h in resos]
1159
+ print(aspect_ratios)
1160
+
1161
+ ars = set()
1162
+ for ar in aspect_ratios:
1163
+ if ar in ars:
1164
+ print("error! duplicate ar:", ar)
1165
+ ars.add(ar)
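
A minimal usage sketch of the conversion helpers above, assuming hypothetical local checkpoint paths: it loads an SD 1.x checkpoint into Diffusers-style modules and writes it back as a single SD-format file in fp16. Only the paths and dtype are assumptions; the function names and signatures come from this file.

    import torch
    from library import model_util

    ckpt_in = "models/example-v1.safetensors"        # hypothetical input path
    ckpt_out = "models/example-v1-fp16.safetensors"  # hypothetical output path

    # v2=False selects the SD 1.x layout (768-dim context, 8 attention heads)
    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(False, ckpt_in)

    # write everything back into one SD-format checkpoint, casting weights to fp16
    model_util.save_stable_diffusion_checkpoint(
        False, ckpt_out, text_encoder, unet, ckpt_in,
        epochs=0, steps=0, save_dtype=torch.float16, vae=vae,
    )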
library/resize_lora_gui.py ADDED
@@ -0,0 +1,173 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import get_saveasfilename_path, get_file_path
6
+
7
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
8
+ folder_symbol = '\U0001f4c2' # 📂
9
+ refresh_symbol = '\U0001f504' # 🔄
10
+ save_style_symbol = '\U0001f4be' # 💾
11
+ document_symbol = '\U0001F4C4' # 📄
12
+
13
+
14
+ def resize_lora(
15
+ model,
16
+ new_rank,
17
+ save_to,
18
+ save_precision,
19
+ device,
20
+ dynamic_method,
21
+ dynamic_param,
22
+ verbose,
23
+ ):
24
+ # Check that a model file was provided
25
+ if model == '':
26
+ msgbox('Invalid model file')
27
+ return
28
+
29
+ # Check that the source model exists
30
+ if not os.path.isfile(model):
31
+ msgbox('The provided model is not a file')
32
+ return
33
+
34
+ if dynamic_method == 'sv_ratio':
35
+ if float(dynamic_param) < 2:
36
+ msgbox(
37
+ f'Dynamic parameter for {dynamic_method} needs to be 2 or greater...'
38
+ )
39
+ return
40
+
41
+ if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
42
+ if float(dynamic_param) < 0 or float(dynamic_param) > 1:
43
+ msgbox(
44
+ f'Dynamic parameter for {dynamic_method} needs to be between 0 and 1...'
45
+ )
46
+ return
47
+
48
+ # Check if save_to ends with one of the expected extensions. If not, add .safetensors.
49
+ if not save_to.endswith(('.pt', '.safetensors')):
50
+ save_to += '.safetensors'
51
+
52
+ if device == '':
53
+ device = 'cuda'
54
+
55
+ run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"'
56
+ run_cmd += f' --save_precision {save_precision}'
57
+ run_cmd += f' --save_to "{save_to}"'
58
+ run_cmd += f' --model "{model}"'
59
+ run_cmd += f' --new_rank {new_rank}'
60
+ run_cmd += f' --device {device}'
61
+ if not dynamic_method == 'None':
62
+ run_cmd += f' --dynamic_method {dynamic_method}'
63
+ run_cmd += f' --dynamic_param {dynamic_param}'
64
+ if verbose:
65
+ run_cmd += f' --verbose'
66
+
67
+ print(run_cmd)
68
+
69
+ # Run the command
70
+ if os.name == 'posix':
71
+ os.system(run_cmd)
72
+ else:
73
+ subprocess.run(run_cmd)
74
+
75
+
76
+ ###
77
+ # Gradio UI
78
+ ###
79
+
80
+
81
+ def gradio_resize_lora_tab():
82
+ with gr.Tab('Resize LoRA'):
83
+ gr.Markdown('This utility can resize a LoRA.')
84
+
85
+ lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
86
+ lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
87
+
88
+ with gr.Row():
89
+ model = gr.Textbox(
90
+ label='Source LoRA',
91
+ placeholder='Path to the LoRA to resize',
92
+ interactive=True,
93
+ )
94
+ button_lora_a_model_file = gr.Button(
95
+ folder_symbol, elem_id='open_folder_small'
96
+ )
97
+ button_lora_a_model_file.click(
98
+ get_file_path,
99
+ inputs=[model, lora_ext, lora_ext_name],
100
+ outputs=model,
101
+ show_progress=False,
102
+ )
103
+ with gr.Row():
104
+ new_rank = gr.Slider(
105
+ label='Desired LoRA rank',
106
+ minimum=1,
107
+ maximum=1024,
108
+ step=1,
109
+ value=4,
110
+ interactive=True,
111
+ )
112
+
113
+ with gr.Row():
114
+ dynamic_method = gr.Dropdown(
115
+ choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
116
+ value='sv_fro',
117
+ label='Dynamic method',
118
+ interactive=True,
119
+ )
120
+ dynamic_param = gr.Textbox(
121
+ label='Dynamic parameter',
122
+ value='0.9',
123
+ interactive=True,
124
+ placeholder='Value for the dynamic method selected.',
125
+ )
126
+ verbose = gr.Checkbox(label='Verbose', value=False)
127
+ with gr.Row():
128
+ save_to = gr.Textbox(
129
+ label='Save to',
130
+ placeholder='path for the LoRA file to save...',
131
+ interactive=True,
132
+ )
133
+ button_save_to = gr.Button(
134
+ folder_symbol, elem_id='open_folder_small'
135
+ )
136
+ button_save_to.click(
137
+ get_saveasfilename_path,
138
+ inputs=[save_to, lora_ext, lora_ext_name],
139
+ outputs=save_to,
140
+ show_progress=False,
141
+ )
142
+ save_precision = gr.Dropdown(
143
+ label='Save precision',
144
+ choices=['fp16', 'bf16', 'float'],
145
+ value='fp16',
146
+ interactive=True,
147
+ )
148
+ device = gr.Dropdown(
149
+ label='Device',
150
+ choices=[
151
+ 'cpu',
152
+ 'cuda',
153
+ ],
154
+ value='cuda',
155
+ interactive=True,
156
+ )
157
+
158
+ convert_button = gr.Button('Resize model')
159
+
160
+ convert_button.click(
161
+ resize_lora,
162
+ inputs=[
163
+ model,
164
+ new_rank,
165
+ save_to,
166
+ save_precision,
167
+ device,
168
+ dynamic_method,
169
+ dynamic_param,
170
+ verbose,
171
+ ],
172
+ show_progress=False,
173
+ )
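
The function above only composes and runs a command line for networks/resize_lora.py; a hedged sketch of calling it directly with illustrative values (the paths and numbers are hypothetical, the parameter names come from the signature above):

    from library.resize_lora_gui import resize_lora

    resize_lora(
        model="loras/example.safetensors",   # hypothetical LoRA to shrink
        new_rank=8,
        save_to="loras/example-rank8",       # '.safetensors' is appended automatically
        save_precision="fp16",
        device="cuda",
        dynamic_method="sv_fro",
        dynamic_param="0.9",                 # must be between 0 and 1 for sv_fro
        verbose=True,
    )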
library/sampler_gui.py ADDED
@@ -0,0 +1,102 @@
1
+ import tempfile
2
+ import os
3
+ import gradio as gr
4
+ from easygui import msgbox
5
+
6
+ folder_symbol = '\U0001f4c2' # 📂
7
+ refresh_symbol = '\U0001f504' # 🔄
8
+ save_style_symbol = '\U0001f4be' # 💾
9
+ document_symbol = '\U0001F4C4' # 📄
10
+
11
+
12
+ ###
13
+ ### Gradio common sampler GUI section
14
+ ###
15
+
16
+
17
+ def sample_gradio_config():
18
+ with gr.Accordion('Sample images config', open=False):
19
+ with gr.Row():
20
+ sample_every_n_steps = gr.Number(
21
+ label='Sample every n steps',
22
+ value=0,
23
+ precision=0,
24
+ interactive=True,
25
+ )
26
+ sample_every_n_epochs = gr.Number(
27
+ label='Sample every n epochs',
28
+ value=0,
29
+ precision=0,
30
+ interactive=True,
31
+ )
32
+ sample_sampler = gr.Dropdown(
33
+ label='Sample sampler',
34
+ choices=[
35
+ 'ddim',
36
+ 'pndm',
37
+ 'lms',
38
+ 'euler',
39
+ 'euler_a',
40
+ 'heun',
41
+ 'dpm_2',
42
+ 'dpm_2_a',
43
+ 'dpmsolver',
44
+ 'dpmsolver++',
45
+ 'dpmsingle',
46
+ 'k_lms',
47
+ 'k_euler',
48
+ 'k_euler_a',
49
+ 'k_dpm_2',
50
+ 'k_dpm_2_a',
51
+ ],
52
+ value='euler_a',
53
+ interactive=True,
54
+ )
55
+ with gr.Row():
56
+ sample_prompts = gr.Textbox(
57
+ lines=5,
58
+ label='Sample prompts',
59
+ interactive=True,
60
+ placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28',
61
+ )
62
+ return (
63
+ sample_every_n_steps,
64
+ sample_every_n_epochs,
65
+ sample_sampler,
66
+ sample_prompts,
67
+ )
68
+
69
+
70
+ def run_cmd_sample(
71
+ sample_every_n_steps,
72
+ sample_every_n_epochs,
73
+ sample_sampler,
74
+ sample_prompts,
75
+ output_dir,
76
+ ):
77
+ output_dir = os.path.join(output_dir, 'sample')
78
+
79
+ if not os.path.exists(output_dir):
80
+ os.makedirs(output_dir)
81
+
82
+ run_cmd = ''
83
+
84
+ if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
85
+ return run_cmd
86
+
87
+ # Create the prompt file and get its path
88
+ sample_prompts_path = os.path.join(output_dir, 'prompt.txt')
89
+
90
+ with open(sample_prompts_path, 'w') as f:
91
+ f.write(sample_prompts)
92
+
93
+ run_cmd += f' --sample_sampler={sample_sampler}'
94
+ run_cmd += f' --sample_prompts="{sample_prompts_path}"'
95
+
96
+ if not sample_every_n_epochs == 0:
97
+ run_cmd += f' --sample_every_n_epochs="{sample_every_n_epochs}"'
98
+
99
+ if not sample_every_n_steps == 0:
100
+ run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"'
101
+
102
+ return run_cmd
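
run_cmd_sample only writes the prompt file and returns extra CLI flags for the trainer; a sketch with illustrative values (output_dir and the prompt text are hypothetical):

    from library.sampler_gui import run_cmd_sample

    extra_args = run_cmd_sample(
        sample_every_n_steps=0,
        sample_every_n_epochs=1,
        sample_sampler="euler_a",
        sample_prompts="1girl, looking at viewer --w 512 --h 512 --s 28",
        output_dir="outputs/example-run",    # prompt.txt lands in outputs/example-run/sample/
    )
    # extra_args would look roughly like:
    # ' --sample_sampler=euler_a --sample_prompts="outputs/example-run/sample/prompt.txt" --sample_every_n_epochs="1"'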
library/svd_merge_lora_gui.py ADDED
@@ -0,0 +1,190 @@
1
+ import gradio as gr
2
+ from easygui import msgbox
3
+ import subprocess
4
+ import os
5
+ from .common_gui import (
6
+ get_saveasfilename_path,
7
+ get_any_file_path,
8
+ get_file_path,
9
+ )
10
+
11
+ folder_symbol = '\U0001f4c2' # 📂
12
+ refresh_symbol = '\U0001f504' # 🔄
13
+ save_style_symbol = '\U0001f4be' # 💾
14
+ document_symbol = '\U0001F4C4' # 📄
15
+ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
16
+
17
+
18
+ def svd_merge_lora(
19
+ lora_a_model,
20
+ lora_b_model,
21
+ ratio,
22
+ save_to,
23
+ precision,
24
+ save_precision,
25
+ new_rank,
26
+ new_conv_rank,
27
+ device,
28
+ ):
29
+ # Check that both model files were provided
30
+ if lora_a_model == '':
31
+ msgbox('Invalid model A file')
32
+ return
33
+
34
+ if lora_b_model == '':
35
+ msgbox('Invalid model B file')
36
+ return
37
+
38
+ # Check that the source models exist
39
+ if not os.path.isfile(lora_a_model):
40
+ msgbox('The provided model A is not a file')
41
+ return
42
+
43
+ if not os.path.isfile(lora_b_model):
44
+ msgbox('The provided model B is not a file')
45
+ return
46
+
47
+ ratio_a = ratio
48
+ ratio_b = 1 - ratio
49
+
50
+ run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"'
51
+ run_cmd += f' --save_precision {save_precision}'
52
+ run_cmd += f' --precision {precision}'
53
+ run_cmd += f' --save_to "{save_to}"'
54
+ run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
55
+ run_cmd += f' --ratios {ratio_a} {ratio_b}'
56
+ run_cmd += f' --device {device}'
57
+ run_cmd += f' --new_rank "{new_rank}"'
58
+ run_cmd += f' --new_conv_rank "{new_conv_rank}"'
59
+
60
+ print(run_cmd)
61
+
62
+ # Run the command
63
+ if os.name == 'posix':
64
+ os.system(run_cmd)
65
+ else:
66
+ subprocess.run(run_cmd)
67
+
68
+
69
+ ###
70
+ # Gradio UI
71
+ ###
72
+
73
+
74
+ def gradio_svd_merge_lora_tab():
75
+ with gr.Tab('Merge LoRA (SVD)'):
76
+ gr.Markdown('This utility can merge two LoRA networks together using SVD.')
77
+
78
+ lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
79
+ lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
80
+
81
+ with gr.Row():
82
+ lora_a_model = gr.Textbox(
83
+ label='LoRA model "A"',
84
+ placeholder='Path to the LoRA A model',
85
+ interactive=True,
86
+ )
87
+ button_lora_a_model_file = gr.Button(
88
+ folder_symbol, elem_id='open_folder_small'
89
+ )
90
+ button_lora_a_model_file.click(
91
+ get_file_path,
92
+ inputs=[lora_a_model, lora_ext, lora_ext_name],
93
+ outputs=lora_a_model,
94
+ show_progress=False,
95
+ )
96
+
97
+ lora_b_model = gr.Textbox(
98
+ label='LoRA model "B"',
99
+ placeholder='Path to the LoRA B model',
100
+ interactive=True,
101
+ )
102
+ button_lora_b_model_file = gr.Button(
103
+ folder_symbol, elem_id='open_folder_small'
104
+ )
105
+ button_lora_b_model_file.click(
106
+ get_file_path,
107
+ inputs=[lora_b_model, lora_ext, lora_ext_name],
108
+ outputs=lora_b_model,
109
+ show_progress=False,
110
+ )
111
+ with gr.Row():
112
+ ratio = gr.Slider(
113
+ label='Merge ratio (e.g. 0.7 means 70% of model A and 30% of model B)',
114
+ minimum=0,
115
+ maximum=1,
116
+ step=0.01,
117
+ value=0.5,
118
+ interactive=True,
119
+ )
120
+ new_rank = gr.Slider(
121
+ label='New Rank',
122
+ minimum=1,
123
+ maximum=1024,
124
+ step=1,
125
+ value=128,
126
+ interactive=True,
127
+ )
128
+ new_conv_rank = gr.Slider(
129
+ label='New Conv Rank',
130
+ minimum=1,
131
+ maximum=1024,
132
+ step=1,
133
+ value=128,
134
+ interactive=True,
135
+ )
136
+
137
+ with gr.Row():
138
+ save_to = gr.Textbox(
139
+ label='Save to',
140
+ placeholder='path for the file to save...',
141
+ interactive=True,
142
+ )
143
+ button_save_to = gr.Button(
144
+ folder_symbol, elem_id='open_folder_small'
145
+ )
146
+ button_save_to.click(
147
+ get_saveasfilename_path,
148
+ inputs=[save_to, lora_ext, lora_ext_name],
149
+ outputs=save_to,
150
+ show_progress=False,
151
+ )
152
+ precision = gr.Dropdown(
153
+ label='Merge precision',
154
+ choices=['fp16', 'bf16', 'float'],
155
+ value='float',
156
+ interactive=True,
157
+ )
158
+ save_precision = gr.Dropdown(
159
+ label='Save precision',
160
+ choices=['fp16', 'bf16', 'float'],
161
+ value='float',
162
+ interactive=True,
163
+ )
164
+ device = gr.Dropdown(
165
+ label='Device',
166
+ choices=[
167
+ 'cpu',
168
+ 'cuda',
169
+ ],
170
+ value='cuda',
171
+ interactive=True,
172
+ )
173
+
174
+ convert_button = gr.Button('Merge model')
175
+
176
+ convert_button.click(
177
+ svd_merge_lora,
178
+ inputs=[
179
+ lora_a_model,
180
+ lora_b_model,
181
+ ratio,
182
+ save_to,
183
+ precision,
184
+ save_precision,
185
+ new_rank,
186
+ new_conv_rank,
187
+ device,
188
+ ],
189
+ show_progress=False,
190
+ )
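
Like the other utility tabs, svd_merge_lora just shells out to networks/svd_merge_lora.py; a sketch with illustrative values (the model paths and ranks are hypothetical, the parameters match the signature above):

    from library.svd_merge_lora_gui import svd_merge_lora

    svd_merge_lora(
        lora_a_model="loras/style_a.safetensors",
        lora_b_model="loras/style_b.safetensors",
        ratio=0.7,               # 70% of A, 30% of B
        save_to="loras/merged.safetensors",
        precision="float",       # precision used during the merge
        save_precision="fp16",
        new_rank=64,
        new_conv_rank=64,
        device="cuda",
    )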