diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b3a0a02d27a2b369b4b2647aaae42765dd390c24 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/typos.yaml b/.github/workflows/typos.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a7c2a8f0646a6421b114be99cc84e4ec8becfeab --- /dev/null +++ b/.github/workflows/typos.yaml @@ -0,0 +1,21 @@ +--- +# yamllint disable rule:line-length +name: Typos + +on: # yamllint disable-line rule:truthy + push: + pull_request: + types: + - opened + - synchronize + - reopened + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: typos-action + uses: crate-ci/typos@v1.13.10 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..71fe116321a0413bdff0efd1d4ac92fec4c14392 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +venv +__pycache__ +cudnn_windows +.vscode +*.egg-info +build +wd14_tagger_model +.DS_Store +locon +gui-user.bat +gui-user.ps1 \ No newline at end of file diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d 
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..56765e795c61de9b5941ba9d9c6379f0b7922203 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. 
Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2022] [kohya-ss] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 6e0e69f27c789fc781df413326ec81ea35a0159f..4bf57aa2a17531fb4070bed8fec67efe2bb25423 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,21 @@ --- -title: Kohya Ss Colab -emoji: 📈 -colorFrom: indigo -colorTo: gray +title: kohya_ss_colab +app_file: dreambooth_gui.py sdk: gradio -sdk_version: 5.49.0 -app_file: app.py -pinned: false +sdk_version: 5.47.2 --- +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb) -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Kohya SS WebUI Colab Setup + +This Colab workbook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS is a set of training scripts with a Gradio-based GUI for fine-tuning Stable Diffusion models, including DreamBooth, LoRA, and textual inversion training. This Colab workbook provides a convenient way for users to run Kohya SS without needing to install anything on their local machine. + +This workbook was inspired by [Spaceginner](https://github.com/Spaceginner)'s original Colab workbook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab workbook was coded by [panguin6010](https://github.com/panguin6010). + + +## Tutorials + +Before running this code, make sure you are familiar with using Colab workbooks and have a basic understanding of Kohya SS and its usage. You can find tutorials for these online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project. + +## Link +```https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb``` \ No newline at end of file diff --git a/XTI_hijack.py b/XTI_hijack.py new file mode 100644 index 0000000000000000000000000000000000000000..f39cc8e7e8564ea34507a5ca63056c0fb001cf1f --- /dev/null +++ b/XTI_hijack.py @@ -0,0 +1,209 @@ +import torch +from typing import Union, List, Optional, Dict, Any, Tuple +from diffusers.models.unet_2d_condition import UNet2DConditionOutput + +def unet_forward_XTI(self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
+ """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + down_i = 0 + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], + ) + down_i += 2 + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) + + # 5. 
up + up_i = 7 + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], + upsample_size=upsample_size, + ) + up_i += 3 + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + +def downblock_forward_XTI( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None +): + output_states = () + i = 0 + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + output_states += (hidden_states,) + i += 1 + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + +def upblock_forward_XTI( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, +): + i = 0 + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + i += 1 + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return 
hidden_states \ No newline at end of file diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 0000000000000000000000000000000000000000..4902a59b4270e00984877f3637ce9a06a0dd85b5 --- /dev/null +++ b/_typos.toml @@ -0,0 +1,15 @@ +# Files for typos +# Instruction: https://github.com/marketplace/actions/typos-action#getting-started + +[default.extend-identifiers] + +[default.extend-words] +NIN="NIN" +parms="parms" +nin="nin" +extention="extention" # Intentionally left +nd="nd" + + +[files] +extend-exclude = ["_typos.toml"] \ No newline at end of file diff --git a/cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 b/cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 new file mode 100644 index 0000000000000000000000000000000000000000..8f0e467b5313aba1138c91ba7a4919dab2a68815 --- /dev/null +++ b/cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c791d1f047b41ff5885772fc4bf20b797c6059bbd82abb9e31de15e55d6a57c4 +size 11907224 diff --git a/config_README-ja.md b/config_README-ja.md new file mode 100644 index 0000000000000000000000000000000000000000..7f2b6c4c1e3859a6ce79a4b6ece2174b430d1d20 --- /dev/null +++ b/config_README-ja.md @@ -0,0 +1,279 @@ +For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. + +`--dataset_config` で枡すこずができる蚭定ファむルに関する説明です。 + +## 抂芁 + +蚭定ファむルを枡すこずにより、ナヌザが现かい蚭定を行えるようにしたす。 + +* 耇数のデヌタセットが蚭定可胜になりたす + * 䟋えば `resolution` をデヌタセットごずに蚭定しお、それらを混合しお孊習できたす。 + * DreamBooth の手法ず fine tuning の手法の䞡方に察応しおいる孊習方法では、DreamBooth 方匏ず fine tuning 方匏のデヌタセットを混合するこずが可胜です。 +* サブセットごずに蚭定を倉曎するこずが可胜になりたす + * デヌタセットを画像ディレクトリ別たたはメタデヌタ別に分割したものがサブセットです。いく぀かのサブセットが集たっおデヌタセットを構成したす。 + * `keep_tokens` や `flip_aug` 等のオプションはサブセットごずに蚭定可胜です。䞀方、`resolution` や `batch_size` ずいったオプションはデヌタセットごずに蚭定可胜で、同じデヌタセットに属するサブセットでは倀が共通になりたす。詳しくは埌述したす。 + +蚭定ファむルの圢匏は JSON か TOML を利甚できたす。蚘述のしやすさを考えるず [TOML](https://toml.io/ja/v1.0.0-rc.2) を利甚するのがオススメです。以䞋、TOML の利甚を前提に説明したす。 + +TOML で蚘述した蚭定ファむルの䟋です。 + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# これは DreamBooth 方匏のデヌタセット +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # このサブセットは keep_tokens = 2 所属する datasets の倀が䜿われる + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# これは fine tuning 方匏のデヌタセット +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # このサブセットは keep_tokens = 1 general の倀が䜿われる +``` + +この䟋では、3 ぀のディレクトリを DreamBooth 方匏のデヌタセットずしお 512x512 (batch size 4) で孊習させ、1 ぀のディレクトリを fine tuning 方匏のデヌタセットずしお 768x768 (batch size 2) で孊習させるこずになりたす。 + +## デヌタセット・サブセットに関する蚭定 + +デヌタセット・サブセットに関する蚭定は、登録可胜な箇所がいく぀かに分かれおいたす。 + +* `[general]` + * 党デヌタセットたたは党サブセットに適甚されるオプションを指定する箇所です。 + * デヌタセットごずの蚭定及びサブセットごずの蚭定に同名のオプションが存圚しおいた堎合には、デヌタセット・サブセットごずの蚭定が優先されたす。 +* `[[datasets]]` + * `datasets` はデヌタセットに関する蚭定の登録箇所になりたす。各デヌタセットに個別に適甚されるオプションを指定する箇所です。 + * サブセットごずの蚭定が存圚しおいた堎合には、サブセットごずの蚭定が優先されたす。 +* `[[datasets.subsets]]` + * `datasets.subsets` はサブセットに関する蚭定の登録箇所になりたす。各サブセットに個別に適甚されるオプションを指定する箇所です。 + +先皋の䟋における、画像ディレクトリず登録箇所の察応に関するむメヌゞ図です。 + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 
|-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +画像ディレクトリがそれぞれ1぀の `[[datasets.subsets]]` に察応しおいたす。そしお `[[datasets.subsets]]` が1぀以䞊組み合わさっお1぀の `[[datasets]]` を構成したす。`[general]` には党おの `[[datasets]]`, `[[datasets.subsets]]` が属したす。 + +登録箇所ごずに指定可胜なオプションは異なりたすが、同名のオプションが指定された堎合は䞋䜍の登録箇所にある倀が優先されたす。先皋の䟋の `keep_tokens` オプションの扱われ方を確認しおもらうず理解しやすいかず思いたす。 + +加えお、孊習方法が察応しおいる手法によっおも指定可胜なオプションが倉化したす。 + +* DreamBooth 方匏専甚のオプション +* fine tuning 方匏専甚のオプション +* caption dropout の手法が䜿える堎合のオプション + +DreamBooth の手法ず fine tuning の手法の䞡方ずも利甚可胜な孊習方法では、䞡者を䜵甚するこずができたす。 +䜵甚する際の泚意点ずしお、DreamBooth 方匏なのか fine tuning 方匏なのかはデヌタセット単䜍で刀別を行っおいるため、同じデヌタセット䞭に DreamBooth 方匏のサブセットず fine tuning 方匏のサブセットを混圚させるこずはできたせん。 +぀たり、これらを䜵甚したい堎合には異なる方匏のサブセットが異なるデヌタセットに所属するように蚭定する必芁がありたす。 + +プログラムの挙動ずしおは、埌述する `metadata_file` オプションが存圚しおいたら fine tuning 方匏のサブセットだず刀断したす。 +そのため、同䞀のデヌタセットに所属するサブセットに぀いお蚀うず、「党おが `metadata_file` オプションを持぀」か「党おが `metadata_file` オプションを持たない」かのどちらかになっおいれば問題ありたせん。 + +以䞋、利甚可胜なオプションを説明したす。コマンドラむン匕数ず名称が同䞀のオプションに぀いおは、基本的に説明を割愛したす。他の README を参照しおください。 + +### 党孊習方法で共通のオプション + +孊習方法によらずに指定可胜なオプションです。 + +#### デヌタセット向けオプション + +デヌタセットの蚭定に関わるオプションです。`datasets.subsets` には蚘述できたせん。 + +| オプション名 | 蚭定䟋 | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * コマンドラむン匕数の `--train_batch_size` ず同等です。 + +これらの蚭定はデヌタセットごずに固定です。 +぀たり、デヌタセットに所属するサブセットはこれらの蚭定を共有するこずになりたす。 +䟋えば解像床が異なるデヌタセットを甚意したい堎合は、䞊に挙げた䟋のように別々のデヌタセットずしお定矩すれば別々の解像床を蚭定可胜です。 + +#### サブセット向けオプション + +サブセットの蚭定に関わるオプションです。 + +| オプション名 | 蚭定䟋 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | + +* `num_repeats` + * サブセットの画像の繰り返し回数を指定したす。fine tuning における `--dataset_repeats` に盞圓したすが、`num_repeats` はどの孊習方法でも指定可胜です。 + +### DreamBooth 方匏専甚のオプション + +DreamBooth 方匏のオプションは、サブセット向けオプションのみ存圚したす。 + +#### サブセット向けオプション + +DreamBooth 方匏のサブセットの蚭定に関わるオプションです。 + +| オプション名 | 蚭定䟋 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o必須 | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `“sks girl”` | - | - | o | +| `is_reg` | `false` | - | - | o | + +たず泚意点ずしお、 `image_dir` には画像ファむルが盎䞋に眮かれおいるパスを指定する必芁がありたす。埓来の DreamBooth の手法ではサブディレクトリに画像を眮く必芁がありたしたが、そちらずは仕様に互換性がありたせん。たた、`5_cat` のようなフォルダ名にしおも、画像の繰り返し回数ずクラス名は反映されたせん。これらを個別に蚭定したい堎合、`num_repeats` ず `class_tokens` で明瀺的に指定する必芁があるこずに泚意しおください。 + +* `image_dir` + * 画像ディレクトリのパスを指定したす。指定必須オプションです。 + * 画像はディレクトリ盎䞋に眮かれおいる必芁がありたす。 +* `class_tokens` + * クラストヌクンを蚭定したす。 + * 画像に察応する caption ファむルが存圚しない堎合にのみ孊習時に利甚されたす。利甚するかどうかの刀定は画像ごずに行いたす。`class_tokens` を指定しなかった堎合に caption ファむルも芋぀からなかった堎合にぱラヌになりたす。 +* `is_reg` + * サブセットの画像が正芏化甚かどうかを指定したす。指定しなかった堎合は `false` ずしお、぀たり正芏化画像ではないずしお扱いたす。 + +### fine tuning 方匏専甚のオプション + +fine tuning 方匏のオプションは、サブセット向けオプションのみ存圚したす。 + +#### サブセット向けオプション + +fine tuning 方匏のサブセットの蚭定に関わるオプションです。 + +| オプション名 | 蚭定䟋 | `[general]` | `[[datasets]]` | 
`[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o必須 | + +* `image_dir` + * 画像ディレクトリのパスを指定したす。DreamBooth の手法の方ずは異なり指定は必須ではありたせんが、蚭定するこずを掚奚したす。 + * 指定する必芁がない状況ずしおは、メタデヌタファむルの生成時に `--full_path` を付䞎しお実行しおいた堎合です。 + * 画像はディレクトリ盎䞋に眮かれおいる必芁がありたす。 +* `metadata_file` + * サブセットで利甚されるメタデヌタファむルのパスを指定したす。指定必須オプションです。 + * コマンドラむン匕数の `--in_json` ず同等です。 + * サブセットごずにメタデヌタファむルを指定する必芁がある仕様䞊、ディレクトリを跚いだメタデヌタを1぀のメタデヌタファむルずしお䜜成するこずは避けた方が良いでしょう。画像ディレクトリごずにメタデヌタファむルを甚意し、それらを別々のサブセットずしお登録するこずを匷く掚奚したす。 + +### caption dropout の手法が䜿える堎合に指定可胜なオプション + +caption dropout の手法が䜿える堎合のオプションは、サブセット向けオプションのみ存圚したす。 +DreamBooth 方匏か fine tuning 方匏かに関わらず、caption dropout に察応しおいる孊習方法であれば指定可胜です。 + +#### サブセット向けオプション + +caption dropout が䜿えるサブセットの蚭定に関わるオプションです。 + +| オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## 重耇したサブセットが存圚する時の挙動 + +DreamBooth 方匏のデヌタセットの堎合、その䞭にある `image_dir` が同䞀のサブセットは重耇しおいるず芋なされたす。 +fine tuning 方匏のデヌタセットの堎合は、その䞭にある `metadata_file` が同䞀のサブセットは重耇しおいるず芋なされたす。 +デヌタセット䞭に重耇したサブセットが存圚する堎合、2個目以降は無芖されたす。 + +䞀方、異なるデヌタセットに所属しおいる堎合は、重耇しおいるずは芋なされたせん。 +䟋えば、以䞋のように同䞀の `image_dir` を持぀サブセットを別々のデヌタセットに入れた堎合には、重耇しおいないず芋なしたす。 +これは、同じ画像でも異なる解像床で孊習したい堎合に圹立ちたす。 + +```toml +# 別々のデヌタセットに存圚しおいる堎合は重耇ずは芋なされず、䞡方ずも孊習に䜿われる + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## コマンドラむン匕数ずの䜵甚 + +蚭定ファむルのオプションの䞭には、コマンドラむン匕数のオプションず圹割が重耇しおいるものがありたす。 + +以䞋に挙げるコマンドラむン匕数のオプションは、蚭定ファむルを枡した堎合には無芖されたす。 + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +以䞋に挙げるコマンドラむン匕数のオプションは、コマンドラむン匕数ず蚭定ファむルで同時に指定された堎合、コマンドラむン匕数の倀よりも蚭定ファむルの倀が優先されたす。特に断りがなければ同名のオプションずなりたす。 + +| コマンドラむン匕数のオプション | 優先される蚭定ファむルのオプション | +| ---------------------------------- | ---------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## ゚ラヌの手匕き + +珟圚、倖郚ラむブラリを利甚しお蚭定ファむルの蚘述が正しいかどうかをチェックしおいるのですが、敎備が行き届いおおらず゚ラヌメッセヌゞがわかりづらいずいう問題がありたす。 +将来的にはこの問題の改善に取り組む予定です。 + +次善策ずしお、頻出の゚ラヌずその察凊法に぀いお茉せおおきたす。 +正しいはずなのに゚ラヌが出る堎合、゚ラヌ内容がどうしおも分からない堎合は、バグかもしれないのでご連絡ください。 + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されおいないずいう゚ラヌです。指定を忘れおいるか、オプション名を間違っお蚘述しおいる可胜性が高いです。 + * `...` の箇所にぱラヌが発生した堎所が茉っおいたす。䟋えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のような゚ラヌが出たら、0 番目の `datasets` 䞭の 0 番目の `subsets` の蚭定に `image_dir` が存圚しないずいうこずになりたす。 +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する倀の圢匏が䞍正ずいう゚ラヌです。倀の圢匏が間違っおいる可胜性が高いです。`int` の郚分は察象ずなるオプションによっお倉わりたす。この README に茉っおいるオプションの「蚭定䟋」が圹立぀かもしれたせん。 +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 察応しおいないオプション名が存圚しおいる堎合に発生する゚ラヌです。オプション名を間違っお蚘述しおいるか、誀っお玛れ蟌んでいる可胜性が高いです。 + + diff --git a/config_files/accelerate/default_config.yaml 
b/config_files/accelerate/default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1198fe6d27943b28edb3c312f3ebfcac59ae164c --- /dev/null +++ b/config_files/accelerate/default_config.yaml @@ -0,0 +1,22 @@ +command_file: null +commands: null +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: 'NO' +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_process_ip: null +main_process_port: null +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_name: null +tpu_zone: null +use_cpu: false \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..e93f96e85ed2af39228f6d9ec8d338d0bd7bf06c --- /dev/null +++ b/dreambooth_gui.py @@ -0,0 +1,944 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captionning of images to utilities + +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +import argparse +from library.common_gui import ( + get_folder_path, + remove_doublequote, + get_file_path, + get_any_file_path, + get_saveasfile_path, + color_aug_changed, + save_inference_file, + gradio_advanced_training, + run_cmd_advanced_training, + run_cmd_training, + gradio_training, + gradio_config, + gradio_source_model, + # set_legacy_8bitadam, + update_my_data, + check_if_model_exist, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get('label') == 'True' else False + + if save_as_bool: + print('Save as...') + file_path = 
get_saveasfile_path(file_path) + else: + print('Save...') + if file_path == None or file_path == '': + file_path = get_saveasfile_path(file_path) + + # print(file_path) + + if file_path == None or file_path == '': + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Return the values of the variables as a dictionary + variables = { + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] + } + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + # Save the data to the selected file + with open(file_path, 'w') as file: + json.dump(variables, file, indent=2) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == '' and not file_path == None: + # load variables from JSON file + with open(file_path, 'r') as f: + my_data = json.load(f) + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['ask_for_file', 'file_path']: + values.append(my_data.get(key, value)) + return tuple(values) + + +def train_model( + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + 
stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. Yes, it is unused here but required given the common list used + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + if pretrained_model_name_or_path == '': + msgbox('Source model information is missing') + return + + if train_data_dir == '': + msgbox('Image folder path is missing') + return + + if not os.path.exists(train_data_dir): + msgbox('Image folder does not exist') + return + + if reg_data_dir != '': + if not os.path.exists(reg_data_dir): + msgbox('Regularisation folder does not exist') + return + + if output_dir == '': + msgbox('Output folder path is missing') + return + + if check_if_model_exist(output_name, output_dir, save_model_as): + return + + # Get a list of all subfolders in train_data_dir, excluding hidden folders + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) + and not f.startswith('.') + ] + + # Check if subfolders are present. If not let the user know and return + if not subfolders: + print( + '\033[33mNo subfolders were found in', + train_data_dir, + " can't train\...033[0m", + ) + return + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + # Extract the number of repeats from the folder name + try: + repeats = int(folder.split('_')[0]) + except ValueError: + print( + '\033[33mSubfolder', + folder, + "does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m", + ) + continue + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) + ) + if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + ) + + if num_images == 0: + print(f'{folder} folder contain no images, skipping...') + else: + # Calculate the total number of steps for this folder + steps = repeats * num_images + total_steps += steps + + # Print the result + print('\033[33mFolder', folder, ':', steps, 'steps\033[0m') + + if total_steps == 0: + print( + '\033[33mNo images were found in folder', + train_data_dir, + '... please rectify!\033[0m', + ) + return + + # Print the result + # print(f"{total_steps} total steps") + + if reg_data_dir == '': + reg_factor = 1 + else: + print( + '\033[94mRegularisation images are used... 
Will double the number of steps required...\033[0m' + ) + reg_factor = 2 + + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + * int(epoch) + * int(reg_factor) + ) + ) + print(f'max_train_steps = {max_train_steps}') + + # calculate stop encoder training + if int(stop_text_encoder_training_pct) == -1: + stop_text_encoder_training = -1 + elif stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + print(f'stop_text_encoder_training = {stop_text_encoder_training}') + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + print(f'lr_warmup_steps = {lr_warmup_steps}') + + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' + if v2: + run_cmd += ' --v2' + if v_parameterization: + run_cmd += ' --v_parameterization' + if enable_bucket: + run_cmd += ' --enable_bucket' + if no_token_padding: + run_cmd += ' --no_token_padding' + run_cmd += ( + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) + run_cmd += f' --train_data_dir="{train_data_dir}"' + if len(reg_data_dir): + run_cmd += f' --reg_data_dir="{reg_data_dir}"' + run_cmd += f' --resolution={max_resolution}' + run_cmd += f' --output_dir="{output_dir}"' + run_cmd += f' --logging_dir="{logging_dir}"' + if not stop_text_encoder_training == 0: + run_cmd += ( + f' --stop_text_encoder_training={stop_text_encoder_training}' + ) + if not save_model_as == 'same as source model': + run_cmd += f' --save_model_as={save_model_as}' + # if not resume == '': + # run_cmd += f' --resume={resume}' + if not float(prior_loss_weight) == 1.0: + run_cmd += f' --prior_loss_weight={prior_loss_weight}' + if not vae == '': + run_cmd += f' --vae="{vae}"' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' + if int(max_token_length) > 75: + run_cmd += f' --max_token_length={max_token_length}' + if not max_train_epochs == '': + run_cmd += f' --max_train_epochs="{max_train_epochs}"' + if not max_data_loader_n_workers == '': + run_cmd += ( + f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' + ) + if int(gradient_accumulation_steps) > 1: + run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + + run_cmd += run_cmd_training( + learning_rate=learning_rate, + lr_scheduler=lr_scheduler, + lr_warmup_steps=lr_warmup_steps, + train_batch_size=train_batch_size, + max_train_steps=max_train_steps, + save_every_n_epochs=save_every_n_epochs, + mixed_precision=mixed_precision, + save_precision=save_precision, + seed=seed, + caption_extension=caption_extension, + cache_latents=cache_latents, + optimizer=optimizer, + optimizer_args=optimizer_args, + ) + + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + # use_8bit_adam=use_8bit_adam, + keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, + bucket_no_upscale=bucket_no_upscale, + random_crop=random_crop, + bucket_reso_steps=bucket_reso_steps, + 
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + noise_offset=noise_offset, + additional_parameters=additional_parameters, + vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # check if output_dir/last is a folder... therefore it is a diffuser model + last_dir = pathlib.Path(f'{output_dir}/{output_name}') + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def dreambooth_tab( + train_data_dir=gr.Textbox(), + reg_data_dir=gr.Textbox(), + output_dir=gr.Textbox(), + logging_dir=gr.Textbox(), +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + gr.Markdown('Train a custom model using kohya dreambooth python code...') + ( + button_open_config, + button_save_config, + button_save_as_config, + config_file_name, + button_load_config, + ) = gradio_config() + + ( + pretrained_model_name_or_path, + v2, + v_parameterization, + save_model_as, + model_list, + ) = gradio_source_model() + + with gr.Tab('Folders'): + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder', + placeholder='Folder where the training folders containing the images are located', + ) + train_data_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + train_data_dir_input_folder.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + reg_data_dir = gr.Textbox( + label='Regularisation folder', + placeholder='(Optional) Folder where where the regularization folders containing the images are located', + ) + reg_data_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + reg_data_dir_input_folder.click( + get_folder_path, + outputs=reg_data_dir, + show_progress=False, + ) + with gr.Row(): + output_dir = gr.Textbox( + label='Model output folder', + placeholder='Folder to output trained model', + ) + output_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + output_dir_input_folder.click(get_folder_path, outputs=output_dir) + logging_dir = gr.Textbox( + label='Logging folder', + placeholder='Optional: enable logging and output TensorBoard log to this folder', + ) + logging_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + logging_dir_input_folder.click( + get_folder_path, + outputs=logging_dir, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) + train_data_dir.change( + remove_doublequote, + inputs=[train_data_dir], + outputs=[train_data_dir], + ) + reg_data_dir.change( + remove_doublequote, + inputs=[reg_data_dir], + outputs=[reg_data_dir], + ) + output_dir.change( + remove_doublequote, + inputs=[output_dir], + outputs=[output_dir], + ) + logging_dir.change( + remove_doublequote, + inputs=[logging_dir], + outputs=[logging_dir], + ) + with gr.Tab('Training parameters'): + ( + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + num_cpu_threads_per_process, + seed, + caption_extension, + cache_latents, + optimizer, + 
optimizer_args, + ) = gradio_training( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + with gr.Row(): + max_resolution = gr.Textbox( + label='Max resolution', + value='512,512', + placeholder='512,512', + ) + stop_text_encoder_training = gr.Slider( + minimum=-1, + maximum=100, + value=0, + step=1, + label='Stop text encoder training', + ) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True) + with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + no_token_padding = gr.Checkbox( + label='No token padding', value=False + ) + gradient_accumulation_steps = gr.Number( + label='Gradient accumulate steps', value='1' + ) + with gr.Row(): + prior_loss_weight = gr.Number( + label='Prior loss weight', value=1.0 + ) + vae = gr.Textbox( + label='VAE', + placeholder='(Optiona) path to checkpoint of vae to replace for training', + ) + vae_button = gr.Button('📂', elem_id='open_folder_small') + vae_button.click( + get_any_file_path, + outputs=vae, + show_progress=False, + ) + ( + # use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + noise_offset, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latents], + ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' 
+ ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=train_data_dir, + reg_data_dir_input=reg_data_dir, + output_dir_input=output_dir, + logging_dir_input=logging_dir, + ) + + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ] + + button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=settings_list, + show_progress=False, + ) + + return ( + train_data_dir, + reg_data_dir, + output_dir, + logging_dir, + ) + + +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Dreambooth'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = dreambooth_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs = {} + if not kwargs.get('username', None) == '': + launch_kwargs['auth'] = ( + kwargs.get('username', None), + kwargs.get('password', None), + ) + if kwargs.get('server_port', 0) > 0: + launch_kwargs['server_port'] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + 
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + ) diff --git a/fine_tune.py b/fine_tune.py new file mode 100644 index 0000000000000000000000000000000000000000..637a729a86800cd09c85ab5c09a02d0349c16ee8 --- /dev/null +++ b/fine_tune.py @@ -0,0 +1,430 @@ +# training with captions +# XXX dropped option: hypernetwork training + +import argparse +import gc +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer = train_util.load_tokenizer(args) + + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 蚭定ファむルが利甚されるため以䞋のオプションは無芖されたす: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありたせん。メタデヌタおよびtrain_data_dirオプションを確認しおください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするずきはcolor_augずrandom_cropは䜿えたせん" + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに察応した型を甚意しおおき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み蟌む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # Diffusers版のxformers䜿甚フラグを蚭定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリヌスでなくなりそう + # pipeが自動で再垰的にset_use_memory_efficient_attention_xformersを探すんだっお(;ŽД) + # U-Netだけ䜿う時にはどうすればいいのか  仕方ないからコピっお䜿うか + # 0.10.2でなんか巻き戻っお個別に指定するようになった(;^ω^) + + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers ずか memory efficient attention を組み蟌む + if args.diffusers_xformers: + print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで孊習できないのでxformersを䜿わない蚭定も可胜にしおおく必芁がある + print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 孊習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # 孊習を準備するモデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは孊習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + for m in training_models: + m.requires_grad_(True) + params = [] + for m in training_models: 
+ params.extend(m.parameters()) + params_to_optimize = params + + # 孊習に必芁なクラスを準備する + print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数0はメむンプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最倧で指定された数たで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 孊習ステップ数を蚈算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定゚ポックたでのステップ数: {args.max_train_steps}") + + # デヌタセット偎にも孊習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを甚意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実隓的機胜募配も含めたfp16孊習を行う モデル党䜓をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を䜿う堎合はmixed_precision='fp16'を指定しおください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやっおくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + # 実隓的機胜募配も含めたfp16孊習を行う PyTorchにパッチを圓おおfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を蚈算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 孊習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / å­Šç¿’é–‹å§‹") + print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサむズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサむズ䞊列孊習、募配合蚈含む: {total_batch_size}") + print(f" gradient accumulation steps / 募配を合蚈するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 孊習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("finetuning") + + for epoch in range(num_train_epochs): + 
print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(training_models[0]): # 耇数モデルに察応しおいない暡様だがずりあえずこうしおおく + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに倉換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + if args.min_snr_gamma: + # do not mean over batch dimension for snr weight + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # mean over batch dimension + else: + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + # TODO moving averageにする + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if 
global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この埌メモリを䜿うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを䜿甚する") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも孊習する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/fine_tune_README.md b/fine_tune_README.md new file mode 100644 index 0000000000000000000000000000000000000000..7ffd05d4ab7bd2532c69d68f1166b87607724f78 --- /dev/null +++ b/fine_tune_README.md @@ -0,0 +1,465 @@ +It is a fine tuning that corresponds to NovelAI's proposed learning method, automatic captioning, tagging, Windows + VRAM 12GB (for v1.4/1.5) environment, etc. + +## overview +Fine tuning of U-Net of Stable Diffusion using Diffusers. It corresponds to the following improvements in NovelAI's article (For Aspect Ratio Bucketing, I referred to NovelAI's code, but the final code is all original). + +* Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder). +* Learning at non-square resolutions (Aspect Ratio Bucketing). +* Extend token length from 75 to 225. +* Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger. +* Also supports Hypernetwork learning. +* Supports Stable Diffusion v2.0 (base and 768/v). +* By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning. + +Text Encoder is not trained by default. For fine tuning of the whole model, it seems common to learn only U-Net (NovelAI seems to be the same). Text Encoder can also be learned as an option. 
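+
+As a rough illustration of the latent-caching idea above, the following sketch encodes one image with the VAE and saves the scaled latents to an .npz file. It is only a minimal example, not the actual prepare_buckets_latents.py implementation described later; the model id and file paths are placeholders.
+
+```
+import numpy as np
+import torch
+from PIL import Image
+from diffusers import AutoencoderKL
+
+# Load only the VAE of a v1.x model (the model id here is just an example)
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+vae.eval()
+
+# Preprocess a single 512x512 image into a [-1, 1] tensor of shape (1, 3, H, W)
+image = Image.open("train_data/example.png").convert("RGB").resize((512, 512))
+pixels = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0
+pixels = pixels.permute(2, 0, 1).unsqueeze(0)
+
+# Encode once and store the scaled latents; training can then skip the VAE entirely
+with torch.no_grad():
+    latents = vae.encode(pixels).latent_dist.sample() * 0.18215
+np.savez("train_data/example.npz", latents=latents.squeeze(0).numpy())
+```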
+ +## Additional features +### Change CLIP output +CLIP (Text Encoder) converts the text into features in order to reflect the prompt in the image. Stable diffusion uses the output of the last layer of CLIP, but you can change it to use the output of the penultimate layer. According to NovelAI, this will reflect prompts more accurately. +It is also possible to use the output of the last layer as is. +*Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option. + +### Training in non-square resolutions +Stable Diffusion is trained at 512\*512, but also at resolutions such as 256\*1024 and 384\*640. It is expected that this will reduce the cropped portion and learn the relationship between prompts and images more correctly. +The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter. + +In machine learning, it is common to unify all input sizes, but there are no particular restrictions, and in fact it is okay as long as they are unified within the same batch. NovelAI's bucketing seems to refer to classifying training data in advance for each learning resolution according to the aspect ratio. And by creating a batch with the images in each bucket, the image size of the batch is unified. + +### Extending token length from 75 to 225 +Stable diffusion has a maximum of 75 tokens (77 tokens including the start and end), but we will extend it to 225 tokens. +However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens, we simply divide it into thirds, call CLIP, and then concatenate the results. + +*I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently. + +*Automatic1111's Web UI seems to divide the text with commas in mind, but in my case, it's a simple division. + +## Environmental arrangement + +See the [README](./README-en.md) in this repository. + +## Preparing teacher data + +Prepare the image data you want to learn and put it in any folder. No prior preparation such as resizing is required. +However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining the quality using super-resolution. + +It also supports multiple teacher data folders. Preprocessing will be executed for each folder. + +For example, store an image like this: + +![Teacher data folder screenshot](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) + +## Automatic captioning +Skip if you just want to learn tags without captions. + +Also, when preparing captions manually, prepare them in the same directory as the teacher data image, with the same file name, extension .caption, etc. Each file should be a text file with only one line. + +### Captioning with BLIP + +The latest version no longer requires BLIP downloads, weight downloads, and additional virtual environments. Works as-is. + +Run make_captions.py in the finetune folder. + +``` +python finetune\make_captions.py --batch_size +``` + +If the batch size is 8 and the training data is placed in the parent folder train_data, it will be as follows. 
+ +``` +python finetune\make_captions.py --batch_size 8 ..\train_data +``` + +A caption file is created in the same directory as the teacher data image with the same file name and extension .caption. + +Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). +You can specify the maximum length of the caption with the max_length option. Default is 75. It may be longer if the model is trained with a token length of 225. +You can change the caption extension with the caption_extension option. Default is .caption (.txt conflicts with DeepDanbooru described later). + +If there are multiple teacher data folders, execute for each folder. + +Note that the inference is random, so the results will change each time you run it. If you want to fix it, specify a random number seed like "--seed 42" with the --seed option. + +For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source). + +A caption file is generated with the extension .caption by default. + +![Folder where caption is generated](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) + +For example, with captions like: + +![captions and images](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) + +## Tagged by DeepDanbooru +If you do not want to tag the danbooru tag itself, please proceed to "Preprocessing of caption and tag information". + +Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter. + +### Environmental arrangement +Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it. +Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder. + +Download from below. Click to open Assets and download from there. + +![DeepDanbooru download page](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) + +Make a directory structure like this + +![DeepDanbooru directory structure](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) + +Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io). + +``` +pip install -r requirements.txt +``` + +Next, install DeepDanbooru itself. + +``` +pip install . +``` + +This completes the preparation of the environment for tagging. + +### Implementing tagging +Go to DeepDanbooru's folder and run deepdanbooru to tag. + +``` +deepdanbooru evaluate --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +If you put the training data in the parent folder train_data, it will be as follows. + +``` +deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. It is slow because it is processed one by one. + +If there are multiple teacher data folders, execute for each folder. 
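+
+If you have many folders, one possible approach (just a sketch; the folder names are examples) is to run the same command from a small Python loop:
+
+```
+import subprocess
+
+# Example teacher data folders; adjust to your own layout
+folders = ["../train_data1", "../train_data2"]
+
+for folder in folders:
+    # Same deepdanbooru command as above, executed once per folder
+    subprocess.run(
+        ["deepdanbooru", "evaluate", folder,
+         "--project-path", "deepdanbooru-v3-20211112-sgd-e28",
+         "--allow-folder", "--save-txt"],
+        check=True,
+    )
+```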
+ +It is generated as follows. + +![DeepDanbooru generated files](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) + +A tag is attached like this (great amount of information...). + +![Deep Danbooru tag and image](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) + +## Tagging with WD14Tagger +This procedure uses WD14Tagger instead of DeepDanbooru. + +Use the tagger used in Mr. Automatic1111's WebUI. I referred to the information on this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger). + +The modules required for the initial environment maintenance have already been installed. Weights are automatically downloaded from Hugging Face. + +### Implementing tagging +Run the script to do the tagging. +``` +python tag_images_by_wd14_tagger.py --batch_size +``` + +If you put the training data in the parent folder train_data, it will be as follows. +``` +python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data +``` + +The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows. + +![downloaded file](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) + +A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. + +![generated tag file](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![tags and images](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +With the thresh option, you can specify the number of confidences of the determined tag to attach the tag. The default is 0.35, same as the WD14Tagger sample. Lower values give more tags, but less accuracy. +Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). You can change the tag file extension with the caption_extension option. Default is .txt. +You can specify the folder where the model is saved with the model_dir option. +Also, if you specify the force_download option, the model will be re-downloaded even if there is a save destination folder. + +If there are multiple teacher data folders, execute for each folder. + +## Preprocessing caption and tag information + +Combine captions and tags into a single file as metadata for easy processing from scripts. + +### Caption preprocessing + +To put captions into the metadata, run the following in your working folder (if you don't use captions for learning, you don't need to run this) (it's actually a single line, and so on). + +``` +python merge_captions_to_metadata.py +--in_json + +``` + +The metadata file name is an arbitrary name. +If the training data is train_data, there is no metadata file to read, and the metadata file is meta_cap.json, it will be as follows. + +``` +python merge_captions_to_metadata.py train_data meta_cap.json +``` + +You can specify the caption extension with the caption_extension option. + +If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder. 
+ +``` +python merge_captions_to_metadata.py --full_path + train_data1 meta_cap1.json +python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json + train_data2 meta_cap2.json +``` + +If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. + +__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ + +### Tag preprocessing + +Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning). +``` +python merge_dd_tags_to_metadata.py + --in_json + +``` + +With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows. +``` +python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json +``` + +If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder. + +``` +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json + train_data1 meta_cap_dd1.json +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json + train_data2 meta_cap_dd2.json +``` + +If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. + +__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ + +### Cleaning captions and tags +Up to this point, captions and DeepDanbooru tags have been put together in the metadata file. However, captions with automatic captioning are subtle due to spelling variations (*), and tags include underscores and ratings (in the case of DeepDanbooru), so the editor's replacement function etc. You should use it to clean your captions and tags. + +*For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl". + +A script for cleaning is provided, so please edit the contents of the script according to the situation and use it. + +(It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.) + +``` +python clean_captions_and_tags.py +``` + +Please note that --in_json is not included. For example: + +``` +python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json +``` + +Preprocessing of captions and tags is now complete. + +## Get latents in advance + +In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed. + +In your working folder, type: +``` +python prepare_buckets_latents.py + + + --batch_size + --max_resolution + --mixed_precision +``` + +If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json: + +``` +python prepare_buckets_latents.py + train_data meta_clean.json meta_lat.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no +``` + +Latents are saved in numpy npz format in the teacher data folder. + +Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required). + +You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. 
The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will exclude resolutions such as 256\*1024 or 320\*768.
+If you increase the training resolution to something like 768\*768, specify a larger maximum size such as 1280.
+
+If you specify the --flip_aug option, horizontal flip augmentation (data augmentation) is performed. This can artificially double the amount of data, but if it is enabled for data that is not left-right symmetrical (for example, character appearance or hairstyle), training will not go well.
+(This is a simple implementation that also acquires the latents for the flipped image and saves them as \*\_flip.npz files. No options are required for fine_tune.py; if a file with \_flip exists, it is loaded at random in place of the file without it.)
+
+The batch size may be increased a little more even with 12GB of VRAM.
+The resolution must be divisible by 64 and is specified as "width,height". The resolution is directly linked to memory usage during fine tuning. 512,512 seems to be the limit with 12GB of VRAM (*). With 16GB it may be possible to raise it to 512,704 or 512,768. Even at 256,256 and the like, 8GB of VRAM appears to be difficult (parameters, optimizer state and so on require a certain amount of memory regardless of resolution).
+
+*There was also a report that training worked with batch size 1, 12GB of VRAM and 640,640.
+
+The result of bucketing is displayed as follows.
+
+![bucketing result](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
+
+If you have multiple teacher data folders, specify the full_path argument and run the script for each folder.
+```
+python prepare_buckets_latents.py --full_path
+    train_data1 meta_clean.json meta_lat1.json model.ckpt
+    --batch_size 4 --max_resolution 512,512 --mixed_precision no
+
+python prepare_buckets_latents.py --full_path
+    train_data2 meta_lat1.json meta_lat2.json model.ckpt
+    --batch_size 4 --max_resolution 512,512 --mixed_precision no
+```
+It is possible to make the read source and write destination the same, but using separate files is safer.
+
+__*It is safest to change the arguments each time and write to a separate metadata file.__
+
+
+## Run training
+For example, the following runs training with memory-saving settings:
+```
+accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
+    --pretrained_model_name_or_path=model.ckpt
+    --in_json meta_lat.json
+    --train_data_dir=train_data
+    --output_dir=fine_tuned
+    --shuffle_caption
+    --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
+    --use_8bit_adam --xformers --gradient_checkpointing
+    --mixed_precision=bf16
+    --save_every_n_epochs=4
+```
+
+It seems best to set num_cpu_threads_per_process of accelerate to the number of CPU cores.
+
+Specify the model to be trained in pretrained_model_name_or_path (a Stable Diffusion checkpoint or a Diffusers model). Stable Diffusion checkpoints support .ckpt and .safetensors (the format is determined automatically by the extension).
+
+Specify in in_json the metadata file created when caching the latents.
+
+Specify the training data folder in train_data_dir and the output folder for the trained model in output_dir.
+
+If shuffle_caption is specified, captions and tags are shuffled during training in units separated by commas (this is the method used in Waifu Diffusion v1.3).
+(You can keep some of the leading tokens fixed without shuffling; see keep_tokens and the other options.)
+
+Specify the batch size in train_batch_size. Specify 1 or 2 for 12GB of VRAM.
The number that can be specified also changes depending on the resolution. +The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly. + +Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6. +Specify the number of steps in max_train_steps. + +Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease. + +Specifying xformers replaces CrossAttention to save memory and speed up. +* As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers). + +Enable intermediate saving of gradients in gradient_checkpointing. It's slower, but uses less memory. + +Specifies whether to use mixed precision with mixed_precision. Specifying "fp16" or "bf16" saves memory, but accuracy is inferior. +"fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried). +If "no" is specified, it will not be used (it will be float32). + +* It seems that an error will occur when reading checkpoints learned with bf16 with Mr. AUTOMATIC1111's Web UI. This seems to be because the data type bfloat16 causes an error in the Web UI model safety checker. Save in fp16 or float32 format with the save_precision option. Or it seems to be good to store it in safetytensors format. + +Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed. + +### Supports Stable Diffusion 2.0 +Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt please. + +### Increase accuracy and speed when memory is available +First, removing gradient_checkpointing will speed it up. However, the batch size that can be set is reduced, so please set while looking at the balance between accuracy and speed. + +Increasing the batch size increases speed and accuracy. Increase the speed while checking the speed per data within the range where the memory is sufficient (the speed may actually decrease when the memory is at the limit). + +### Change CLIP output used +Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 or option is omitted, the last layer is used. +The learned model should be able to be inferred by Automatic1111's web UI. + +*SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0. + +If the model being trained was originally trained to use the second layer, 2 is a good value. + +If you were using the last layer instead, the entire model would have been trained on that assumption. Therefore, if you train again using the second layer, you may need a certain number of teacher data and longer learning to obtain the desired learning result. + +### Extending Token Length +You can learn by extending the token length by specifying 150 or 225 for max_token_length. +The learned model should be able to be inferred by Automatic1111's web UI. + +As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time. + +### Save learning log +Specify the log save destination folder in the logging_dir option. 
Logs in TensorBoard format are saved.
+
+For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in a date/time subfolder inside it.
+Also, if you specify the --log_prefix option, the specified string will be prepended to the date and time. Use something like "--logging_dir=logs --log_prefix=fine_tune_style1" for identification.
+
+To check the logs with TensorBoard, open another command prompt and enter the following in the working folder (tensorboard should be installed together with Diffusers; if it is not, install it with pip install tensorboard).
+```
+tensorboard --logdir=logs
+```
+
+### Learning Hypernetworks
+It will be explained in another article.
+
+### Learning with fp16 gradient (experimental feature)
+The full_fp16 option changes the gradients from the usual float32 to float16 (fp16) during training (this amounts to full fp16 training rather than mixed precision). With this, the SD1.x 512*512 size can apparently be trained with less than 8GB of VRAM, and the SD2.x 512*512 size with less than 12GB of VRAM.
+
+Specify fp16 in accelerate config beforehand and also set mixed_precision="fp16" as an option (it does not work with bf16).
+
+To minimize memory usage, use the xformers, use_8bit_adam and gradient_checkpointing options and set train_batch_size to 1.
+(If you can afford it, increasing train_batch_size step by step should improve accuracy a little.)
+
+It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably, and the probability of training failing along the way also increases. The learning rate and the number of steps also appear to be sensitive. Please be aware of this and use it at your own risk.
+
+### Other Options
+
+#### keep_tokens
+If a number is specified, that many tokens (comma-separated strings) from the beginning of the caption are kept fixed without being shuffled.
+
+If there are both captions and tags, the prompt during training is concatenated like "caption, tag 1, tag 2...", so if you set "--keep_tokens=1", the caption will always come first during training.
+
+#### dataset_repeats
+If the dataset is very small, each epoch ends quickly (and there is some overhead at each epoch boundary), so specify a number to repeat the data and make each epoch longer.
+
+#### train_text_encoder
+The Text Encoder is also made a training target. Memory usage increases slightly.
+
+In normal fine tuning the Text Encoder is not trained (probably because U-Net is trained to follow the Text Encoder's output), but when the amount of training data is small, training the Text Encoder as well, as DreamBooth does, also seems to be effective.
+
+#### save_precision
+The data format when saving checkpoints can be chosen from float, fp16 and bf16 (if not specified, it is the same as the data format used during training). fp16 and bf16 save disk space, but the generated results will change slightly. Also, if you specify float or fp16, you should be able to load the model in AUTOMATIC1111's Web UI.
+
+*For the VAE, the data format of the original checkpoint is kept, so the model size may not shrink to a little over 2GB even with fp16.
+
+#### save_model_as
+Specify the save format of the model: one of ckpt, safetensors, diffusers, diffusers_safetensors.
+
+When reading the Stable Diffusion format (ckpt or safetensors) and saving in the Diffusers format, missing information is filled in by downloading the v1.5 or v2.1 information from Hugging Face.
+
+#### use_safetensors
+This option saves checkpoints in safetensors format. The save format otherwise stays the default (the same format as the loaded model).
+
+#### save_state and resume
+The save_state option saves the training state (optimizer state, etc.) alongside the checkpoint, both at intermediate saves and at the final save. This avoids a drop in accuracy when training is resumed after being interrupted (since the optimizer keeps internal state while it optimizes, resetting that state would force optimization to start over from its initial state). Note that the step count is not saved, due to Accelerate's specifications.
+
+When starting the script, you can resume by specifying the folder where the state was saved with the resume option.
+
+Please note that the saved training state is about 5 GB per save, so watch your disk capacity.
+
+#### gradient_accumulation_steps
+Updates the gradients in batches over the specified number of steps. This has an effect similar to increasing the batch size, but consumes slightly more memory.
+
+*Because the Accelerate specification does not support multiple training models, an error may occur if you make the Text Encoder a training target and specify a value of 2 or more for this option.
+
+#### lr_scheduler / lr_warmup_steps
+With the lr_scheduler option you can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant and constant_with_warmup. The default is constant.
+
+With lr_warmup_steps you can specify the number of steps over which the scheduler warms up (gradually changes the learning rate). Please do your own research for the details.
+
+#### diffusers_xformers
+Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork training is no longer possible.
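+
+As a concrete illustration of the shuffle_caption and keep_tokens behavior described above, here is a rough sketch (not the actual implementation in this repository) of how a training prompt could be built from a combined caption/tag string:
+
+```
+import random
+
+def build_training_caption(caption: str, keep_tokens: int = 1) -> str:
+    # Split the combined "caption, tag 1, tag 2, ..." string on commas
+    tokens = [t.strip() for t in caption.split(",")]
+    # The first keep_tokens entries stay fixed at the front (e.g. the caption itself)
+    fixed, rest = tokens[:keep_tokens], tokens[keep_tokens:]
+    # The remaining tags are reshuffled every time the sample is used
+    random.shuffle(rest)
+    return ", ".join(fixed + rest)
+
+print(build_training_caption("a photo of a girl, blue hair, smile, outdoors"))
+```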
\ No newline at end of file diff --git a/fine_tune_README_ja.md b/fine_tune_README_ja.md new file mode 100644 index 0000000000000000000000000000000000000000..686947c952b19c016974792cdc5f4f903701cfc9 --- /dev/null +++ b/fine_tune_README_ja.md @@ -0,0 +1,140 @@ +NovelAIの提案した孊習手法、自動キャプションニング、タグ付け、WindowsVRAM 12GBSD v1.xの堎合環境等に察応したfine tuningです。ここでfine tuningずは、モデルを画像ずキャプションで孊習するこずを指したすLoRAやTextual Inversion、Hypernetworksは含みたせん + +[孊習に぀いおの共通ドキュメント](./train_README-ja.md) もあわせおご芧ください。 + +# 抂芁 + +Diffusersを甚いおStable DiffusionのU-Netのfine tuningを行いたす。NovelAIの蚘事にある以䞋の改善に察応しおいたすAspect Ratio Bucketingに぀いおはNovelAIのコヌドを参考にしたしたが、最終的なコヌドはすべおオリゞナルです。 + +* CLIPText Encoderの最埌の局ではなく最埌から二番目の局の出力を甚いる。 +* 正方圢以倖の解像床での孊習Aspect Ratio Bucketing 。 +* トヌクン長を75から225に拡匵する。 +* BLIPによるキャプショニングキャプションの自動䜜成、DeepDanbooruたたはWD14Taggerによる自動タグ付けを行う。 +* Hypernetworkの孊習にも察応する。 +* Stable Diffusion v2.0baseおよび768/vに察応。 +* VAEの出力をあらかじめ取埗しディスクに保存しおおくこずで、孊習の省メモリ化、高速化を図る。 + +デフォルトではText Encoderの孊習は行いたせん。モデル党䜓のfine tuningではU-Netだけを孊習するのが䞀般的なようですNovelAIもそのようです。オプション指定でText Encoderも孊習察象ずできたす。 + +# 远加機胜に぀いお + +## CLIPの出力の倉曎 + +プロンプトを画像に反映するため、テキストの特城量ぞの倉換を行うのがCLIPText Encoderです。Stable DiffusionではCLIPの最埌の局の出力を甚いおいたすが、それを最埌から二番目の局の出力を甚いるよう倉曎できたす。NovelAIによるず、これによりより正確にプロンプトが反映されるようになるずのこずです。 +元のたた、最埌の局の出力を甚いるこずも可胜です。 + +※Stable Diffusion 2.0では最埌から二番目の局をデフォルトで䜿いたす。clip_skipオプションを指定しないでください。 + +## 正方圢以倖の解像床での孊習 + +Stable Diffusionは512\*512で孊習されおいたすが、それに加えお256\*1024や384\*640ずいった解像床でも孊習したす。これによりトリミングされる郚分が枛り、より正しくプロンプトず画像の関係が孊習されるこずが期埅されたす。 +孊習解像床はパラメヌタずしお䞎えられた解像床の面積メモリ䜿甚量を超えない範囲で、64ピクセル単䜍で瞊暪に調敎、䜜成されたす。 + +機械孊習では入力サむズをすべお統䞀するのが䞀般的ですが、特に制玄があるわけではなく、実際は同䞀のバッチ内で統䞀されおいれば倧䞈倫です。NovelAIの蚀うbucketingは、あらかじめ教垫デヌタを、アスペクト比に応じた孊習解像床ごずに分類しおおくこずを指しおいるようです。そしおバッチを各bucket内の画像で䜜成するこずで、バッチの画像サむズを統䞀したす。 + +## トヌクン長の75から225ぞの拡匵 + +Stable Diffusionでは最倧75トヌクン開始・終了を含むず77トヌクンですが、それを225トヌクンたで拡匵したす。 +ただしCLIPが受け付ける最倧長は75トヌクンですので、225トヌクンの堎合、単玔に䞉分割しおCLIPを呌び出しおから結果を連結しおいたす。 + +※これが望たしい実装なのかどうかはいたひず぀わかりたせん。ずりあえず動いおはいるようです。特に2.0では䜕も参考になる実装がないので独自に実装しおありたす。 + +※Automatic1111氏のWeb UIではカンマを意識しお分割、ずいったこずもしおいるようですが、私の堎合はそこたでしおおらず単玔な分割です。 + +# 孊習の手順 + +あらかじめこのリポゞトリのREADMEを参照し、環境敎備を行っおください。 + +## デヌタの準備 + +[孊習デヌタの準備に぀いお](./train_README-ja.md) を参照しおください。fine tuningではメタデヌタを甚いるfine tuning方匏のみ察応しおいたす。 + +## 孊習の実行 +たずえば以䞋のように実行したす。以䞋は省メモリ化のための蚭定です。それぞれの行を必芁に応じお曞き換えおください。 + +``` +accelerate launch --num_cpu_threads_per_process 1 fine_tune.py + --pretrained_model_name_or_path=<.ckptたたは.safetensordたたはDiffusers版モデルのディレクトリ> + --output_dir=<孊習したモデルの出力先フォルダ> + --output_name=<孊習したモデル出力時のファむル名> + --dataset_config=<デヌタ準備で䜜成した.tomlファむル> + --save_model_as=safetensors + --learning_rate=5e-6 --max_train_steps=10000 + --use_8bit_adam --xformers --gradient_checkpointing + --mixed_precision=fp16 +``` + +`num_cpu_threads_per_process` には通垞は1を指定するずよいようです。 + +`pretrained_model_name_or_path` に远加孊習を行う元ずなるモデルを指定したす。Stable Diffusionのcheckpointファむル.ckptたたは.safetensors、Diffusersのロヌカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できたす。 + +`output_dir` に孊習埌のモデルを保存するフォルダを指定したす。`output_name` にモデルのファむル名を拡匵子を陀いお指定したす。`save_model_as` でsafetensors圢匏での保存を指定しおいたす。 + +`dataset_config` に `.toml` ファむルを指定したす。ファむル内でのバッチサむズ指定は、圓初はメモリ消費を抑えるために `1` ずしおください。 + +孊習させるステップ数 `max_train_steps` を10000ずしたす。孊習率 `learning_rate` はここでは5e-6を指定しおいたす。 + +省メモリ化のため `mixed_precision="fp16"` を指定したすRTX30 シリヌズ以降では `bf16` も指定できたす。環境敎備時にaccelerateに行った蚭定ず合わせおください。たた `gradient_checkpointing` を指定したす。 + +オプティマむザモデルを孊習デヌタにあうように最適化孊習させるクラスにメモリ消費の少ない 8bit AdamW を䜿うため、 `optimizer_type="AdamW8bit"` を指定したす。 + 
+`xformers` オプションを指定し、xformersのCrossAttentionを甚いたす。xformersをむンストヌルしおいない堎合や゚ラヌずなる堎合環境にもよりたすが `mixed_precision="no"` の堎合など、代わりに `mem_eff_attn` オプションを指定するず省メモリ版CrossAttentionを䜿甚したす速床は遅くなりたす。 + +ある皋床メモリがある堎合は、`.toml` ファむルを線集しおバッチサむズをたずえば `4` くらいに増やしおください高速化ず粟床向䞊の可胜性がありたす。 + +### よく䜿われるオプションに぀いお + +以䞋の堎合にはオプションに関するドキュメントを参照しおください。 + +- Stable Diffusion 2.xたたはそこからの掟生モデルを孊習する +- clip skipを2以䞊を前提ずしたモデルを孊習する +- 75トヌクンを超えたキャプションで孊習する + +### バッチサむズに぀いお + +モデル党䜓を孊習するためLoRA等の孊習に比べるずメモリ消費量は倚くなりたすDreamBoothず同じ。 + +### 孊習率に぀いお + +1e-6から5e-6皋床が䞀般的なようです。他のfine tuningの䟋なども参照しおみおください。 + +### 以前の圢匏のデヌタセット指定をした堎合のコマンドラむン + +解像床やバッチサむズをオプションで指定したす。コマンドラむンの䟋は以䞋の通りです。 + +``` +accelerate launch --num_cpu_threads_per_process 1 fine_tune.py + --pretrained_model_name_or_path=model.ckpt + --in_json meta_lat.json + --train_data_dir=train_data + --output_dir=fine_tuned + --shuffle_caption + --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 + --use_8bit_adam --xformers --gradient_checkpointing + --mixed_precision=bf16 + --save_every_n_epochs=4 +``` + + + +# fine tuning特有のその他の䞻なオプション + +すべおのオプションに぀いおは別文曞を参照しおください。 + +## `train_text_encoder` +Text Encoderも孊習察象ずしたす。メモリ䜿甚量が若干増加したす。 + +通垞のfine tuningではText Encoderは孊習察象ずしたせんが恐らくText Encoderの出力に埓うようにU-Netを孊習するため、孊習デヌタ数が少ない堎合には、DreamBoothのようにText Encoder偎に孊習させるのも有効的なようです。 + +## `diffusers_xformers` +スクリプト独自のxformers眮換機胜ではなくDiffusersのxformers機胜を利甚したす。Hypernetworkの孊習はできなくなりたす。 diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py new file mode 100644 index 0000000000000000000000000000000000000000..7851fb08b21d15c93aab2a1d109f5018423b4e6b --- /dev/null +++ b/finetune/blip/blip.py @@ -0,0 +1,240 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import warnings +warnings.filterwarnings("ignore") + +# from models.vit import VisionTransformer, interpolate_pos_embed +# from models.med import BertConfig, BertModel, BertLMHeadModel +from blip.vit import VisionTransformer, interpolate_pos_embed +from blip.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +import os +from urllib.parse import urlparse +from timm.models.hub import download_cached_file + +class BLIP_Base(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + + def forward(self, image, caption, mode): + + assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" + text = self.tokenizer(caption, return_tensors="pt").to(image.device) + + if mode=='image': + # return image features + image_embeds = self.visual_encoder(image) + return image_embeds + + elif mode=='text': + # return text features + text_output = 
self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + return text_output.last_hidden_state + + elif mode=='multimodal': + # return multimodel features + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text.input_ids[:,0] = self.tokenizer.enc_token_id + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + return output.last_hidden_state + + + +class BLIP_Decoder(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + prompt = 'a picture of ', + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel(config=med_config) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 + + + def forward(self, image, caption): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) + + text.input_ids[:,0] = self.tokenizer.bos_token_id + + decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) + decoder_targets[:,:self.prompt_length] = -100 + + decoder_output = self.text_decoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + loss_lm = decoder_output.loss + + return loss_lm + + def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): + image_embeds = self.visual_encoder(image) + + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} + + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) + input_ids[:,0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + #nucleus sampling + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + #beam search + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = 
self.tokenizer.decode(output, skip_special_tokens=True) + captions.append(caption[len(self.prompt):]) + return captions + + +def blip_decoder(pretrained='',**kwargs): + model = BLIP_Decoder(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def blip_feature_extractor(pretrained='',**kwargs): + model = BLIP_Base(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer.add_special_tokens({'bos_token':'[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit=='base': + vision_width = 768 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, + num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate + ) + elif vit=='large': + vision_width = 1024 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, + num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate + ) + return visual_encoder, vision_width + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], + model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape!=model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + diff --git a/finetune/blip/med.py b/finetune/blip/med.py new file mode 100644 index 0000000000000000000000000000000000000000..7b00a35450b736180a805d4f4664b4fb95aeba01 --- /dev/null +++ b/finetune/blip/med.py @@ -0,0 +1,955 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / 
config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
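+ # Shapes (descriptive note): query_layer is [batch, num_heads, query_len, head_dim]; key_layer and
+ # value_layer are [batch, num_heads, key_len, head_dim], where key_len may additionally include the
+ # cached past keys (decoder self-attention) or equal the image feature length (cross-attention).
+ # The matmul below therefore yields raw scores of shape [batch, num_heads, query_len, key_len],
+ # which are scaled by 1/sqrt(attention_head_size) further down before the softmax.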
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
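+ # nn.Dropout returns a new tensor, so the un-dropped attention_probs are what get returned when
+ # output_attentions=True and what save_attention_map() stored above for visualization;
+ # attention_probs_dropped is only used to weight the value vectors.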
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
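+ # The returned mask is additive: 0.0 at positions that may be attended to and -10000.0 at masked
+ # positions, so it can simply be added to the raw attention scores before the softmax. For example
+ # (illustrative values), a padding mask [[1, 1, 0]] for a 3-token entry becomes [[[[0., 0., -10000.]]]]
+ # here when is_decoder is False; when is_decoder is True (as in the BLIP caption decoder), the causal
+ # mask described in get_extended_attention_mask is folded in as well.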
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + 
return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
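+ Note that the loss computation below shifts ``prediction_scores`` and ``labels`` by one position
+ (next-token prediction), so ``labels`` can be a copy of ``input_ids`` with ignored positions set to
+ ``-100``; BLIP_Decoder masks out the prompt and padding tokens this way before calling this model.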
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/finetune/blip/med_config.json b/finetune/blip/med_config.json new file mode 100644 index 0000000000000000000000000000000000000000..dc12b99cf539b751d442b4ca7785c9f6a4f8306e --- /dev/null +++ b/finetune/blip/med_config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 
512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true + } + \ No newline at end of file diff --git a/finetune/blip/vit.py b/finetune/blip/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899 --- /dev/null +++ b/finetune/blip/vit.py @@ -0,0 +1,305 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + 
self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + 
return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# 
model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..68839ecccbbaf056204be1d6fb0d204e104091e6 --- /dev/null +++ b/finetune/clean_captions_and_tags.py @@ -0,0 +1,190 @@ +# このスクリプトのラむセンスは、Apache License 2.0ずしたす +# (c) 2022 Kohya S. 
@kohya_ss + +import argparse +import glob +import os +import json +import re + +from tqdm import tqdm + +PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') +PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') +PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') +PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') + +# 耇数人がいるずき、耇数の髪色や目の色が定矩されおいれば削陀する +PATTERNS_REMOVE_IN_MULTI = [ + PATTERN_HAIR_LENGTH, + PATTERN_HAIR_CUT, + re.compile(r', [\w\-]+ eyes, '), + re.compile(r', ([\w\-]+ sleeves|sleeveless), '), + # 耇数の髪型定矩がある堎合は削陀する + re.compile( + r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), +] + + +def clean_tags(image_key, tags): + # replace '_' to ' ' + tags = tags.replace('^_^', '^@@@^') + tags = tags.replace('_', ' ') + tags = tags.replace('^@@@^', '^_^') + + # remove rating: deepdanbooruのみ + tokens = tags.split(", rating") + if len(tokens) == 1: + # WD14 taggerのずきはこちらになるのでメッセヌゞは出さない + # print("no rating:") + # print(f"{image_key} {tags}") + pass + else: + if len(tokens) > 2: + print("multiple ratings:") + print(f"{image_key} {tags}") + tags = tokens[0] + + tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで怜玢をするための身も蓋もない察策 + + # 耇数の人物がいる堎合は髪色等のタグを削陀する + if 'girls' in tags or 'boys' in tags: + for pat in PATTERNS_REMOVE_IN_MULTI: + found = pat.findall(tags) + if len(found) > 1: # 二぀以䞊、タグがある + tags = pat.sub("", tags) + + # 髪の特殊察応 + srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは䟋倖なので避けおおく党員が同じ髪の長さの堎合 + if srch_hair_len: + org = srch_hair_len.group() + tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) + + found = PATTERN_HAIR.findall(tags) + if len(found) > 1: + tags = PATTERN_HAIR.sub("", tags) + + if srch_hair_len: + tags = tags.replace(", @@@, ", org) # 戻す + + # white shirtずshirtみたいな重耇タグの削陀 + found = PATTERN_WORD.findall(tags) + for word in found: + if re.search(f", ((\w+) )+{word}, ", tags): + tags = tags.replace(f", {word}, ", "") + + tags = tags.replace(", , ", ", ") + assert tags.startswith(", ") and tags.endswith(", ") + tags = tags[2:-2] + return tags + + +# 䞊から順に怜玢、眮換される +# ('眮換元文字列', '眮換埌文字列') +CAPTION_REPLACEMENTS = [ + ('anime anime', 'anime'), + ('young ', ''), + ('anime girl', 'girl'), + ('cartoon female', 'girl'), + ('cartoon lady', 'girl'), + ('cartoon character', 'girl'), # a or ~s + ('cartoon woman', 'girl'), + ('cartoon women', 'girls'), + ('cartoon girl', 'girl'), + ('anime female', 'girl'), + ('anime lady', 'girl'), + ('anime character', 'girl'), # a or ~s + ('anime woman', 'girl'), + ('anime women', 'girls'), + ('lady', 'girl'), + ('female', 'girl'), + ('woman', 'girl'), + ('women', 'girls'), + ('people', 'girls'), + ('person', 'girl'), + ('a cartoon figure', 'a figure'), + ('a cartoon image', 'an image'), + ('a cartoon picture', 'a picture'), + ('an anime cartoon image', 'an image'), + ('a cartoon anime drawing', 'a drawing'), + ('a cartoon drawing', 'a drawing'), + ('girl girl', 'girl'), +] + + +def clean_caption(caption): + for rf, rt in CAPTION_REPLACEMENTS: + replaced = True + while replaced: + bef = caption + caption = caption.replace(rf, rt) + replaced = bef != caption + return caption + + +def main(args): + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + print("no metadata / メタデヌタファむルがありたせん") + return + + print("cleaning captions and tags.") + image_keys = list(metadata.keys()) + for image_key in 
tqdm(image_keys): + tags = metadata[image_key].get('tags') + if tags is None: + print(f"image does not have tags / メタデヌタにタグがありたせん: {image_key}") + else: + org = tags + tags = clean_tags(image_key, tags) + metadata[image_key]['tags'] = tags + if args.debug and org != tags: + print("FROM: " + org) + print("TO: " + tags) + + caption = metadata[image_key].get('caption') + if caption is None: + print(f"image does not have caption / メタデヌタにキャプションがありたせん: {image_key}") + else: + org = caption + caption = clean_caption(caption) + metadata[image_key]['caption'] = caption + if args.debug and org != caption: + print("FROM: " + org) + print("TO: " + caption) + + # metadataを曞き出しお終わり + print(f"writing metadata: {args.out_json}") + with open(args.out_json, "wt", encoding='utf-8') as f: + json.dump(metadata, f, indent=2) + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + # parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("in_json", type=str, help="metadata file to input / 読み蟌むメタデヌタファむル") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデヌタファむル曞き出し先") + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args, unknown = parser.parse_known_args() + if len(unknown) == 1: + print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") + print("All captions and tags in the metadata are processed.") + print("譊告: train_data_dir匕数は䞍芁になりたした。将来的には䞉぀の匕数を指定するず動かなくなる予定です。読み蟌み元のメタデヌタず曞き出し先の二぀の匕数だけ指定しおください。") + print("メタデヌタ内のすべおのキャプションずタグが凊理されたす。") + args.in_json = args.out_json + args.out_json = unknown[0] + elif len(unknown) > 0: + raise ValueError(f"error: unrecognized arguments: {unknown}") + + main(args) diff --git a/finetune/hypernetwork_nai.py b/finetune/hypernetwork_nai.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaaa714a08bb2cfc417d827e8bdd01c8c1ad367 --- /dev/null +++ b/finetune/hypernetwork_nai.py @@ -0,0 +1,96 @@ +# NAI compatible + +import torch + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, multiplier=1.0): + super().__init__() + + linear1 = torch.nn.Linear(dim, dim * 2) + linear2 = torch.nn.Linear(dim * 2, dim) + linear1.weight.data.normal_(mean=0.0, std=0.01) + linear1.bias.data.zero_() + linear2.weight.data.normal_(mean=0.0, std=0.01) + linear2.bias.data.zero_() + linears = [linear1, linear2] + + self.linear = torch.nn.Sequential(*linears) + self.multiplier = multiplier + + def forward(self, x): + return x + self.linear(x) * self.multiplier + + +class Hypernetwork(torch.nn.Module): + enable_sizes = [320, 640, 768, 1280] + # return self.modules[Hypernetwork.enable_sizes.index(size)] + + def __init__(self, multiplier=1.0) -> None: + super().__init__() + self.modules = [] + for size in Hypernetwork.enable_sizes: + self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) + self.register_module(f"{size}_0", self.modules[-1][0]) + self.register_module(f"{size}_1", self.modules[-1][1]) + + def apply_to_stable_diffusion(self, text_encoder, vae, unet): + blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks + for block in blocks: + for subblk in block: + if 'SpatialTransformer' in str(type(subblk)): + for tf_block in subblk.transformer_blocks: + for 
attn in [tf_block.attn1, tf_block.attn2]: + size = attn.context_dim + if size in Hypernetwork.enable_sizes: + attn.hypernetwork = self + else: + attn.hypernetwork = None + + def apply_to_diffusers(self, text_encoder, vae, unet): + blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks + for block in blocks: + if hasattr(block, 'attentions'): + for subblk in block.attentions: + if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ + for tf_block in subblk.transformer_blocks: + for attn in [tf_block.attn1, tf_block.attn2]: + size = attn.to_k.in_features + if size in Hypernetwork.enable_sizes: + attn.hypernetwork = self + else: + attn.hypernetwork = None + return True # TODO error checking + + def forward(self, x, context): + size = context.shape[-1] + assert size in Hypernetwork.enable_sizes + module = self.modules[Hypernetwork.enable_sizes.index(size)] + return module[0].forward(context), module[1].forward(context) + + def load_from_state_dict(self, state_dict): + # old ver to new ver + changes = { + 'linear1.bias': 'linear.0.bias', + 'linear1.weight': 'linear.0.weight', + 'linear2.bias': 'linear.1.bias', + 'linear2.weight': 'linear.1.weight', + } + for key_from, key_to in changes.items(): + if key_from in state_dict: + state_dict[key_to] = state_dict[key_from] + del state_dict[key_from] + + for size, sd in state_dict.items(): + if type(size) == int: + self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) + self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) + return True + + def get_state_dict(self): + state_dict = {} + for i, size in enumerate(Hypernetwork.enable_sizes): + sd0 = self.modules[i][0].state_dict() + sd1 = self.modules[i][1].state_dict() + state_dict[size] = [sd0, sd1] + return state_dict diff --git a/finetune/make_captions.py b/finetune/make_captions.py new file mode 100644 index 0000000000000000000000000000000000000000..e690349a23c0151e38188f8720765cad53e13a10 --- /dev/null +++ b/finetune/make_captions.py @@ -0,0 +1,168 @@ +import argparse +import glob +import os +import json +import random + +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from blip.blip import blip_decoder +import library.train_util as train_util + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +IMAGE_SIZE = 384 + +# 正方圢でいいのか ずいう気がするが゜ヌスがそうなので +IMAGE_TRANSFORM = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) +]) + +# 共通化したいが埮劙に凊理が異なる   +class ImageLoadingTransformDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor = IMAGE_TRANSFORM(image) + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {img_path}, error: {e}") + return None + + return (tensor, img_path) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. 
It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def main(args): + # fix the seed for reproducibility + seed = args.seed # + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + if not os.path.exists("blip"): + args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path + + cwd = os.getcwd() + print('Current Working Directory is: ', cwd) + os.chdir('finetune') + + print(f"load images from {args.train_data_dir}") + image_paths = train_util.glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + print(f"loading BLIP caption: {args.caption_weights}") + model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") + model.eval() + model = model.to(DEVICE) + print("BLIP loaded") + + # captioningする + def run_batch(path_imgs): + imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) + + with torch.no_grad(): + if args.beam_search: + captions = model.generate(imgs, sample=False, num_beams=args.num_beams, + max_length=args.max_length, min_length=args.min_length) + else: + captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) + + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) + + # 読み蟌みの高速化のためにDataLoaderを䜿うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingTransformDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + img_tensor, image_path = data + if img_tensor is None: + try: + raw_image = Image.open(image_path) + if raw_image.mode != 'RGB': + raw_image = raw_image.convert("RGB") + img_tensor = IMAGE_TRANSFORM(raw_image) + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, img_tensor)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + if len(b_imgs) > 0: + run_batch(b_imgs) + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", + help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファむル(model_large_caption.pth)") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファむルの拡匵子スペルミスしおいたのを残しおありたす") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファむルの拡匵子") + parser.add_argument("--beam_search", action="store_true", + help="use beam search (default Nucleus 
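A toy check (my addition) of how collate_fn_remove_corrupted interacts with the DataLoader option above: samples whose __getitem__ returned None (unreadable images) are silently dropped, so a batch can come back smaller than batch_size instead of raising.

import torch

class _ToyDataset(torch.utils.data.Dataset):
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]          # None stands in for a corrupted image

loader = torch.utils.data.DataLoader(
    _ToyDataset([1, None, 2, None]), batch_size=4,
    collate_fn=collate_fn_remove_corrupted)
print(next(iter(loader)))               # -> [1, 2]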
sampling) / beam searchを䜿うこのオプション未指定時はNucleus sampling") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 掚論時のバッチサむズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み蟌みを有効にしおこのワヌカヌ数を適甚する読み蟌みを高速化") + parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビヌム数倚いず粟床が䞊がるが時間がかかる") + parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") + parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最倧長") + parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") + parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再珟性を確保するための乱数seed') + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしおいたオプションを埩元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py new file mode 100644 index 0000000000000000000000000000000000000000..06af559878d2109e5643e01bfd623622903f2489 --- /dev/null +++ b/finetune/make_captions_by_git.py @@ -0,0 +1,151 @@ +import argparse +import os +import re + +from PIL import Image +from tqdm import tqdm +import torch +from transformers import AutoProcessor, AutoModelForCausalLM +from transformers.generation.utils import GenerationMixin + +import library.train_util as train_util + + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +PATTERN_REPLACE = [ + re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), + re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), + re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), + re.compile(r'with the number \d+ on (it|\w+ \w+)'), + re.compile(r'with the words "'), + re.compile(r'word \w+ on it'), + re.compile(r'that says the word \w+ on it'), + re.compile('that says\'the word "( on it)?'), +] + +# 誀怜知したくりの with the word xxxx を消す + + +def remove_words(captions, debug): + removed_caps = [] + for caption in captions: + cap = caption + for pat in PATTERN_REPLACE: + cap = pat.sub("", cap) + if debug and cap != caption: + print(caption) + print(cap) + removed_caps.append(cap) + return removed_caps + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. 
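A worked example (my addition; the caption is made up) of remove_words defined above. The exact leftover whitespace depends on the match, but the hallucinated "with the words ..." phrase is stripped:

caps = ['a t-shirt with the words "hello world" on it']
print(remove_words(caps, debug=False))
# expected roughly: ['a t-shirt ']  (the matched phrase is removed, surrounding text kept)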
+ """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def main(args): + # GITにバッチサむズが1より倧きくおも動くようにパッチを圓おる: transformers 4.26.0甹 + org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation + curr_batch_size = [args.batch_size] # ルヌプの最埌で件数がbatch_size未満になるので入れ替えられるように + + # input_idsがバッチサむズず同じ件数である必芁があるバッチサむズはこの関数から参照できないので倖から枡す + # ここより䞊で眮き換えようずするずすごく倧倉 + def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): + input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) + if input_ids.size()[0] != curr_batch_size[0]: + input_ids = input_ids.repeat(curr_batch_size[0], 1) + return input_ids + GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + + print(f"load images from {args.train_data_dir}") + image_paths = train_util.glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + # できればcacheに䟝存せず明瀺的にダりンロヌドしたい + print(f"loading GIT: {args.model_id}") + git_processor = AutoProcessor.from_pretrained(args.model_id) + git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) + print("GIT loaded") + + # captioningする + def run_batch(path_imgs): + imgs = [im for _, im in path_imgs] + + curr_batch_size[0] = len(path_imgs) + inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil圢匏 + generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) + captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) + + if args.remove_words: + captions = remove_words(captions, args.debug) + + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) + + # 読み蟌みの高速化のためにDataLoaderを䜿うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is None: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, image)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + run_batch(b_imgs) + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファむルの拡匵子") + parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", + help="model id for GIT in Hugging Face / 䜿甚するGITのHugging FaceのモデルID") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 掚論時のバッチサむズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + 
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み蟌みを有効にしおこのワヌカヌ数を適甚する読み蟌みを高速化") + parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最倧長") + parser.add_argument("--remove_words", action="store_true", + help="remove like `with the words xxx` from caption / `with the words xxx`のような郚分をキャプションから削陀する") + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..241f6f902867dfe8c8ba2cd8bed9f3553fd5b07f --- /dev/null +++ b/finetune/merge_captions_to_metadata.py @@ -0,0 +1,76 @@ +import argparse +import json +from pathlib import Path +from typing import List +from tqdm import tqdm +import library.train_util as train_util +import os + +def main(args): + assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathず同時に指定しおください" + + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") + + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json + + if args.in_json is not None: + print(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) + print("captions for existing images will be overwritten / 既存の画像のキャプションは䞊曞きされたす") + else: + print("new metadata will be created / 新しいメタデヌタファむルが䜜成されたす") + metadata = {} + + print("merge caption texts to metadata json.") + for image_path in tqdm(image_paths): + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding='utf-8').strip() + + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) + + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} + + metadata[image_key]['caption'] = caption + if args.debug: + print(image_key, caption) + + # metadataを曞き出しお終わり + print(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデヌタファむル曞き出し先") + parser.add_argument("--in_json", type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み蟌むメタデヌタファむル省略時、out_jsonが存圚すればそれを読み蟌む") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 読み蟌むキャプションファむルの拡匵子スペルミスしおいたのを残しおありたす") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み蟌むキャプションファむルの拡匵子") + parser.add_argument("--full_path", action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデヌタで画像キヌをフルパスにする耇数の孊習画像ディレクトリに察応") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in 
all child folders of train_data_dir / train_data_dirのすべおの子フォルダにある孊習タグを再垰的に探す") + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしおいたオプションを埩元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..db1bff6da7a2227e2e04558de8c8f93e8523d2f9 --- /dev/null +++ b/finetune/merge_dd_tags_to_metadata.py @@ -0,0 +1,71 @@ +import argparse +import json +from pathlib import Path +from typing import List +from tqdm import tqdm +import library.train_util as train_util +import os + +def main(args): + assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathず同時に指定しおください" + + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") + + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json + + if args.in_json is not None: + print(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) + print("tags data for existing images will be overwritten / 既存の画像のタグは䞊曞きされたす") + else: + print("new metadata will be created / 新しいメタデヌタファむルが䜜成されたす") + metadata = {} + + print("merge tags to metadata json.") + for image_path in tqdm(image_paths): + tags_path = image_path.with_suffix(args.caption_extension) + tags = tags_path.read_text(encoding='utf-8').strip() + + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) + + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} + + metadata[image_key]['tags'] = tags + if args.debug: + print(image_key, tags) + + # metadataを曞き出しお終わり + print(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデヌタファむル曞き出し先") + parser.add_argument("--in_json", type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み蟌むメタデヌタファむル省略時、out_jsonが存圚すればそれを読み蟌む") + parser.add_argument("--full_path", action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデヌタで画像キヌをフルパスにする耇数の孊習画像ディレクトリに察応") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべおの子フォルダにある孊習タグを再垰的に探す") + parser.add_argument("--caption_extension", type=str, default=".txt", + help="extension of caption (tag) file / 読み蟌むキャプションタグファむルの拡匵子") + parser.add_argument("--debug", action="store_true", help="debug mode, print tags") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py new file mode 100644 index 
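For orientation (my addition; the key and values are hypothetical), the metadata JSON these scripts build up looks like the following: merge_captions_to_metadata fills "caption", merge_dd_tags_to_metadata fills "tags", and prepare_buckets_latents below adds "train_resolution". With --full_path the key is the full image path instead of the file stem.

metadata = {
    "img_0001": {
        "caption": "a photo of a cat sitting on a windowsill",
        "tags": "cat, window, indoors",
        "train_resolution": [512, 640],
    }
}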
0000000000000000000000000000000000000000..8d9a38ab391fe6748b1b7e51848342f66570539e --- /dev/null +++ b/finetune/prepare_buckets_latents.py @@ -0,0 +1,267 @@ +import argparse +import os +import json + +from tqdm import tqdm +import numpy as np +from PIL import Image +import cv2 +import torch +from torchvision import transforms + +import library.model_util as model_util +import library.train_util as train_util + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def get_latents(vae, images, weight_dtype): + img_tensors = [IMAGE_TRANSFORMS(image) for image in images] + img_tensors = torch.stack(img_tensors) + img_tensors = img_tensors.to(DEVICE, weight_dtype) + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() + return latents + + +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): + if is_full_path: + base_name = os.path.splitext(os.path.basename(image_key))[0] + else: + base_name = image_key + if flip: + base_name += '_flip' + return os.path.join(data_dir, base_name) + + +def main(args): + # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必芁がありたす" + if args.bucket_reso_steps % 8 > 0: + print(f"resolution of buckets in training time is a multiple of 8 / 孊習時の各bucketの解像床は8単䜍になりたす") + + image_paths = train_util.glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + print(f"no metadata / メタデヌタファむルがありたせん: {args.in_json}") + return + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae = model_util.load_vae(args.model_name_or_path, weight_dtype) + vae.eval() + vae.to(DEVICE, dtype=weight_dtype) + + # bucketのサむズを蚈算する + max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) + assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サむズに誀りがありたす。'幅,高さ'で指定しおください: {args.max_resolution}" + + bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso, + args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps) + if not args.bucket_no_upscale: + bucket_manager.make_buckets() + else: + print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された堎合は、bucketの解像床は画像サむズから自動蚈算されるため、min_bucket_resoずmax_bucket_resoは無芖されたす") + + # 画像をひず぀ず぀適切なbucketに割り圓おながらlatentを蚈算する + img_ar_errors = [] + + def process_batch(is_last): + for bucket in bucket_manager.buckets: + if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: + latents = get_latents(vae, [img for _, img in bucket], weight_dtype) + assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \ + 
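A quick example (my addition, paths hypothetical) of the latent-cache naming produced by get_npz_filename_wo_ext; the _flip variant is only written when --flip_aug is enabled, and np.savez appends the .npz extension.

print(get_npz_filename_wo_ext("/data/train", "img_0001", is_full_path=False, flip=False))
# -> /data/train/img_0001        (saved as img_0001.npz)
print(get_npz_filename_wo_ext("/data/train", "img_0001", is_full_path=False, flip=True))
# -> /data/train/img_0001_flip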
f"latent shape {latents.shape}, {bucket[0][1].shape}" + + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + np.savez(npz_file_name, latent) + + # flip + if args.flip_aug: + latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないずTensor倉換できない + + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + np.savez(npz_file_name, latent) + else: + # remove existing flipped npz + for image_key, _ in bucket: + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" + if os.path.isfile(npz_file_name): + print(f"remove existing flipped npz / 既存のflipされたnpzファむルを削陀したす: {npz_file_name}") + os.remove(npz_file_name) + + bucket.clear() + + # 読み蟌みの高速化のためにDataLoaderを䜿うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + bucket_counts = {} + for data_entry in tqdm(data, smoothing=0.0): + if data_entry[0] is None: + continue + + img_tensor, image_path = data_entry[0] + if img_tensor is not None: + image = transforms.functional.to_pil_image(img_tensor) + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {image_path}, error: {e}") + continue + + image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] + if image_key not in metadata: + metadata[image_key] = {} + + # 本圓はこのあずの郚分もDataSetに持っおいけば高速化できるがいろいろ倧倉 + + reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) + img_ar_errors.append(abs(ar_error)) + bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 + + # メタデヌタに蚘録する解像床はlatent単䜍ずするので、8単䜍で切り捚お + metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) + + if not args.bucket_no_upscale: + # upscaleを行わないずきには、resize埌のサむズは、bucketのサむズず、瞊暪どちらかが同じであるこずを確認する + assert resized_size[0] == reso[0] or resized_size[1] == reso[ + 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" + assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ + 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" + + assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ + 1], f"internal error resized size is small: {resized_size}, {reso}" + + # 既に存圚するファむルがあればshapeを確認しお同じならskipする + if args.skip_existing: + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] + if args.flip_aug: + npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") + + found = True + for npz_file in npz_files: + if not os.path.exists(npz_file): + found = False + break + + dat = np.load(npz_file)['arr_0'] + if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 + found = False + break + if found: + continue + + # 画像をリサむズしおトリミングする + # PILにinter_areaがないのでcv2で   + image = np.array(image) + if resized_size[0] != image.shape[1] or resized_size[1] != 
image.shape[0]: # リサむズ凊理が必芁 + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) + + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + + if resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + + # # debug + # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) + + # バッチぞ远加 + bucket_manager.add_image(reso, (image_key, image)) + + # バッチを掚論するか刀定しお掚論する + process_batch(False) + + # 残りを凊理する + process_batch(True) + + bucket_manager.sort() + for i, reso in enumerate(bucket_manager.resos): + count = bucket_counts.get(reso, 0) + if count > 0: + print(f"bucket {i} {reso}: {count}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error: {np.mean(img_ar_errors)}") + + # metadataを曞き出しお終わり + print(f"writing metadata: {args.out_json}") + with open(args.out_json, "wt", encoding='utf-8') as f: + json.dump(metadata, f, indent=2) + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("in_json", type=str, help="metadata file to input / 読み蟌むメタデヌタファむル") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデヌタファむル曞き出し先") + parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取埗するためのモデル") + parser.add_argument("--v2", action='store_true', + help='not used (for backward compatibility) / 䜿甚されたせん互換性のため残しおありたす') + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 掚論時のバッチサむズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み蟌みを有効にしおこのワヌカヌ数を適甚する読み蟌みを高速化") + parser.add_argument("--max_resolution", type=str, default="512,512", + help="max resolution in fine tuning (width,height) / fine tuning時の最倧画像サむズ 「幅,高さ」䜿甚メモリ量に関係したす") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像床") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像床") + parser.add_argument("--bucket_reso_steps", type=int, default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像床の単䜍、8で割り切れる倀を掚奚したす") + parser.add_argument("--bucket_no_upscale", action="store_true", + help="make bucket for each image without upscaling / 画像を拡倧せずbucketを䜜成したす") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合粟床を䜿う堎合、その粟床") + parser.add_argument("--full_path", action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデヌタで画像キヌをフルパスにする耇数の孊習画像ディレクトリに察応") + parser.add_argument("--flip_aug", action="store_true", + help="flip augmentation, save latents for flipped images / 巊右反転した画像もlatentを取埗、保存する") + parser.add_argument("--skip_existing", action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存圚する画像をスキップするflip_aug有効時は通垞、反転の䞡方が存圚する画像をスキップ") + + return parser + + +if __name__ == 
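Worked numbers (my addition, sizes hypothetical) for the resize-and-trim step above: when the resized image is taller than the bucket, a centered crop brings it down to the bucket height; the width case is handled symmetrically.

import numpy as np

resized = np.zeros((704, 512, 3), dtype=np.uint8)    # (H, W, C) after cv2.resize
reso = (512, 640)                                     # (width, height) of the bucket
trim = resized.shape[0] - reso[1]                     # 64 surplus rows
cropped = resized[trim // 2: trim // 2 + reso[1]]     # keep rows 32..671
assert cropped.shape[:2] == (640, 512)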
'__main__': + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py new file mode 100644 index 0000000000000000000000000000000000000000..2286115ec5e7d1e1caad8204f71a74a33ea858c3 --- /dev/null +++ b/finetune/tag_images_by_wd14_tagger.py @@ -0,0 +1,206 @@ +import argparse +import csv +import glob +import os + +from PIL import Image +import cv2 +from tqdm import tqdm +import numpy as np +from tensorflow.keras.models import load_model +from huggingface_hub import hf_hub_download +import torch + +import library.train_util as train_util + +# from wd14 tagger +IMAGE_SIZE = 448 + +# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 +DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' +FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] +SUB_DIR = "variables" +SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] +CSV_FILE = FILES[-1] + + +def preprocess_image(image): + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR + + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + + image = image.astype(np.float32) + return image + + +class ImageLoadingPrepDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {img_path}, error: {e}") + return None + + return (tensor, img_path) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def main(args): + # hf_hub_downloadをそのたた䜿うずsymlink関係で問題があるらしいので、キャッシュディレクトリずforce_filenameを指定しおなんずかする + # depreacatedの譊告が出るけどなくなったらその時 + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(args.model_dir) or args.force_download: + print(f"downloading wd14 tagger model from hf_hub. 
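A quick sanity check (my addition) of preprocess_image above: the input is converted to BGR, padded to a white square, and resized, so any RGB image comes out as a 448x448x3 float32 array ready for the tagger.

import numpy as np
from PIL import Image

img = Image.new("RGB", (300, 200), (0, 0, 0))   # arbitrary non-square test image
out = preprocess_image(img)
assert out.shape == (448, 448, 3) and out.dtype == np.float32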
id: {args.repo_id}") + for file in FILES: + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + for file in SUB_DIR_FILES: + hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( + args.model_dir, SUB_DIR), force_download=True, force_filename=file) + else: + print("using existing wd14 tagger model") + + # 画像を読み蟌む + image_paths = train_util.glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + print("loading model and labels") + model = load_model(args.model_dir) + + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") + # 䟝存ラむブラリを増やしたくないので自力で読むよ + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + l = [row for row in reader] + header = l[0] # tag_id,name,category,count + rows = l[1:] + assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" + + tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、぀たり通垞のタグのみ + + # 掚論する + def run_batch(path_imgs): + imgs = np.array([im for _, im in path_imgs]) + + probs = model(imgs, training=False) + probs = probs.numpy() + + for (image_path, _), prob in zip(path_imgs, probs): + # 最初の4぀はratingなので無芖する + # # First 4 labels are actually ratings: pick one with argmax + # ratings_names = label_names[:4] + # rating_index = ratings_names["probs"].argmax() + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + + # それ以降はタグなのでconfidenceがthresholdより高いものを远加する + # Everything else is tags: pick any where prediction confidence > threshold + tag_text = "" + for i, p in enumerate(prob[4:]): # numpyずか䜿うのが良いけど、たあそれほど数も倚くないのでルヌプで + if p >= args.thresh and i < len(tags): + tag_text += ", " + tags[i] + + if len(tag_text) > 0: + tag_text = tag_text[2:] # 最初の ", " を消す + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(tag_text + '\n') + if args.debug: + print(image_path, tag_text) + + # 読み蟌みの高速化のためにDataLoaderを䜿うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) + + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + run_batch(b_imgs) + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポゞトリID") + parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", + help="directory to store wd14 tagger model / 
wd14 taggerのモデルを栌玍するディレクトリ") + parser.add_argument("--force_download", action='store_true', + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダりンロヌドしたす") + parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを远加するか刀定する閟倀") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 掚論時のバッチサむズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み蟌みを有効にしおこのワヌカヌ数を適甚する読み蟌みを高速化") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファむルの拡匵子スペルミスしおいたのを残しおありたす") + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファむルの拡匵子") + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしおいたオプションを埩元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune_gui.py b/finetune_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..b0859288dbfed07105d5352eb43d765d9b7ea475 --- /dev/null +++ b/finetune_gui.py @@ -0,0 +1,888 @@ +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +import argparse +from library.common_gui import ( + get_folder_path, + get_file_path, + get_saveasfile_path, + save_inference_file, + gradio_advanced_training, + run_cmd_advanced_training, + gradio_training, + run_cmd_advanced_training, + gradio_config, + gradio_source_model, + color_aug_changed, + run_cmd_training, + # set_legacy_8bitadam, + update_my_data, + check_if_model_exist, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + create_caption, + create_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, + cache_latents, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + 
sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get('label') == 'True' else False + + if save_as_bool: + print('Save as...') + file_path = get_saveasfile_path(file_path) + else: + print('Save...') + if file_path == None or file_path == '': + file_path = get_saveasfile_path(file_path) + + # print(file_path) + + if file_path == None or file_path == '': + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Return the values of the variables as a dictionary + variables = { + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] + } + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + # Save the data to the selected file + with open(file_path, 'w') as file: + json.dump(variables, file, indent=2) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + create_caption, + create_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, + cache_latents, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == '' and not file_path == None: + # load variables from JSON file + with open(file_path, 'r') as f: + my_data = json.load(f) + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['ask_for_file', 'file_path']: + values.append(my_data.get(key, value)) + return tuple(values) + + 
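A minimal sketch (my addition; save_cfg / load_cfg are hypothetical names) of the convention save_configuration and open_configuration rely on: the Gradio component values arrive in a fixed parameter order, the parameter names become the JSON keys via locals().items(), and loading walks the same order so each saved value lands back on the matching component, falling back to the current value when a key is missing.

import json

def save_cfg(path, parameters):
    # parameters: ordered (name, value) pairs captured from the GUI
    with open(path, "w") as f:
        json.dump(dict(parameters), f, indent=2)

def load_cfg(path, parameters):
    with open(path, "r") as f:
        saved = json.load(f)
    # same order in, same order out; missing keys keep their current value
    return [saved.get(name, value) for name, value in parameters]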
+def train_model( + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + generate_caption_database, + generate_image_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, # Keep this. Yes, it is unused here but required given the common list used + cache_latents, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + if check_if_model_exist(output_name, output_dir, save_model_as): + return + + # create caption json file + if generate_caption_database: + if not os.path.exists(train_dir): + os.mkdir(train_dir) + + run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py' + if caption_extension == '': + run_cmd += f' --caption_extension=".caption"' + else: + run_cmd += f' --caption_extension={caption_extension}' + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + if full_path: + run_cmd += f' --full_path' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # create images buckets + if generate_image_buckets: + run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py' + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + run_cmd += f' "{train_dir}/{latent_metadata_filename}"' + run_cmd += f' "{pretrained_model_name_or_path}"' + run_cmd += f' --batch_size={batch_size}' + run_cmd += f' --max_resolution={max_resolution}' + run_cmd += f' --min_bucket_reso={min_bucket_reso}' + run_cmd += f' --max_bucket_reso={max_bucket_reso}' + run_cmd += f' --mixed_precision={mixed_precision}' + # if flip_aug: + # run_cmd += f' --flip_aug' + if full_path: + run_cmd += f' --full_path' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + image_num = len( + [ + f + for f, lower_f in ( + (file, file.lower()) for file in os.listdir(image_folder) + ) + if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + ) + print(f'image_num = {image_num}') + + repeats = int(image_num) * int(dataset_repeats) + print(f'repeats = {str(repeats)}') + + # calculate max_train_steps + max_train_steps = int( + math.ceil(float(repeats) / int(train_batch_size) * int(epoch)) + ) + + # Divide by two because flip augmentation create two copied of the source images + if flip_aug: + max_train_steps = int(math.ceil(float(max_train_steps) / 2)) + + print(f'max_train_steps = {max_train_steps}') + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + print(f'lr_warmup_steps 
= {lr_warmup_steps}') + + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"' + if v2: + run_cmd += ' --v2' + if v_parameterization: + run_cmd += ' --v_parameterization' + if train_text_encoder: + run_cmd += ' --train_text_encoder' + run_cmd += ( + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) + if use_latent_files == 'Yes': + run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"' + else: + run_cmd += f' --in_json="{train_dir}/{caption_metadata_filename}"' + run_cmd += f' --train_data_dir="{image_folder}"' + run_cmd += f' --output_dir="{output_dir}"' + if not logging_dir == '': + run_cmd += f' --logging_dir="{logging_dir}"' + run_cmd += f' --dataset_repeats={dataset_repeats}' + run_cmd += f' --learning_rate={learning_rate}' + + run_cmd += ' --enable_bucket' + run_cmd += f' --resolution={max_resolution}' + run_cmd += f' --min_bucket_reso={min_bucket_reso}' + run_cmd += f' --max_bucket_reso={max_bucket_reso}' + + if not save_model_as == 'same as source model': + run_cmd += f' --save_model_as={save_model_as}' + if int(gradient_accumulation_steps) > 1: + run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + # if save_state: + # run_cmd += ' --save_state' + # if not resume == '': + # run_cmd += f' --resume={resume}' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' + if int(max_token_length) > 75: + run_cmd += f' --max_token_length={max_token_length}' + + run_cmd += run_cmd_training( + learning_rate=learning_rate, + lr_scheduler=lr_scheduler, + lr_warmup_steps=lr_warmup_steps, + train_batch_size=train_batch_size, + max_train_steps=max_train_steps, + save_every_n_epochs=save_every_n_epochs, + mixed_precision=mixed_precision, + save_precision=save_precision, + seed=seed, + caption_extension=caption_extension, + cache_latents=cache_latents, + optimizer=optimizer, + optimizer_args=optimizer_args, + ) + + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + # use_8bit_adam=use_8bit_adam, + keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, + bucket_no_upscale=bucket_no_upscale, + random_crop=random_crop, + bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + noise_offset=noise_offset, + additional_parameters=additional_parameters, + vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # check if output_dir/last is a folder... 
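Worked numbers (my addition, values hypothetical) for the schedule arithmetic in train_model above:

import math

image_num, dataset_repeats, train_batch_size, epoch, lr_warmup = 200, 40, 4, 2, 10
repeats = image_num * dataset_repeats                                          # 8000
max_train_steps = int(math.ceil(float(repeats) / train_batch_size * epoch))    # 4000
# with flip_aug the step count would be halved again, as in the code above
lr_warmup_steps = round(float(lr_warmup * max_train_steps / 100))              # 400
assert (max_train_steps, lr_warmup_steps) == (4000, 400)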
therefore it is a diffuser model + last_dir = pathlib.Path(f'{output_dir}/{output_name}') + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def remove_doublequote(file_path): + if file_path != None: + file_path = file_path.replace('"', '') + + return file_path + + +def finetune_tab(): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + gr.Markdown('Train a custom model using kohya finetune python code...') + + ( + button_open_config, + button_save_config, + button_save_as_config, + config_file_name, + button_load_config, + ) = gradio_config() + + ( + pretrained_model_name_or_path, + v2, + v_parameterization, + save_model_as, + model_list, + ) = gradio_source_model() + + with gr.Tab('Folders'): + with gr.Row(): + train_dir = gr.Textbox( + label='Training config folder', + placeholder='folder where the training configuration files will be saved', + ) + train_dir_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + train_dir_folder.click( + get_folder_path, + outputs=train_dir, + show_progress=False, + ) + + image_folder = gr.Textbox( + label='Training Image folder', + placeholder='folder where the training images are located', + ) + image_folder_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + image_folder_input_folder.click( + get_folder_path, + outputs=image_folder, + show_progress=False, + ) + with gr.Row(): + output_dir = gr.Textbox( + label='Model output folder', + placeholder='folder where the model will be saved', + ) + output_dir_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + output_dir_input_folder.click( + get_folder_path, + outputs=output_dir, + show_progress=False, + ) + + logging_dir = gr.Textbox( + label='Logging folder', + placeholder='Optional: enable logging and output TensorBoard log to this folder', + ) + logging_dir_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + logging_dir_input_folder.click( + get_folder_path, + outputs=logging_dir, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) + train_dir.change( + remove_doublequote, + inputs=[train_dir], + outputs=[train_dir], + ) + image_folder.change( + remove_doublequote, + inputs=[image_folder], + outputs=[image_folder], + ) + output_dir.change( + remove_doublequote, + inputs=[output_dir], + outputs=[output_dir], + ) + with gr.Tab('Dataset preparation'): + with gr.Row(): + max_resolution = gr.Textbox( + label='Resolution (width,height)', value='512,512' + ) + min_bucket_reso = gr.Textbox( + label='Min bucket resolution', value='256' + ) + max_bucket_reso = gr.Textbox( + label='Max bucket resolution', value='1024' + ) + batch_size = gr.Textbox(label='Batch size', value='1') + with gr.Row(): + create_caption = gr.Checkbox( + label='Generate caption metadata', value=True + ) + create_buckets = gr.Checkbox( + label='Generate image buckets metadata', value=True + ) + use_latent_files = gr.Dropdown( + label='Use latent files', + choices=[ + 'No', + 'Yes', + ], + value='Yes', + ) + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + caption_metadata_filename = gr.Textbox( + label='Caption metadata filename', value='meta_cap.json' + ) + latent_metadata_filename = gr.Textbox( + label='Latent metadata filename', 
value='meta_lat.json' + ) + full_path = gr.Checkbox(label='Use full path', value=True) + with gr.Tab('Training parameters'): + ( + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + num_cpu_threads_per_process, + seed, + caption_extension, + cache_latents, + optimizer, + optimizer_args, + ) = gradio_training(learning_rate_value='1e-5') + with gr.Row(): + dataset_repeats = gr.Textbox(label='Dataset repeats', value=40) + train_text_encoder = gr.Checkbox( + label='Train text encoder', value=True + ) + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + gradient_accumulation_steps = gr.Number( + label='Gradient accumulate steps', value='1' + ) + ( + # use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + noise_offset, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latents], # Not applicable to fine_tune.py + ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + create_caption, + create_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, + cache_latents, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ] + + button_run.click(train_model, inputs=settings_list) + + button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + 
button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + +def UI(**kwargs): + + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Finetune'): + finetune_tab() + with gr.Tab('Utilities'): + utilities_tab(enable_dreambooth_tab=False) + + # Show the interface + launch_kwargs = {} + if not kwargs.get('username', None) == '': + launch_kwargs['auth'] = ( + kwargs.get('username', None), + kwargs.get('password', None), + ) + if kwargs.get('server_port', 0) > 0: + launch_kwargs['server_port'] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + ) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..225de33c80cbe35e78987d0c8fc9572bc7962f38 --- /dev/null +++ b/gen_img_diffusers.py @@ -0,0 +1,3206 @@ +""" +VGG( + (features): Sequential( + (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (1): ReLU(inplace=True) + (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (3): ReLU(inplace=True) + (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (6): ReLU(inplace=True) + (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (8): ReLU(inplace=True) + (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (11): ReLU(inplace=True) + (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (13): ReLU(inplace=True) + (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (15): ReLU(inplace=True) + (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (18): ReLU(inplace=True) + (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (20): ReLU(inplace=True) + (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (22): ReLU(inplace=True) + (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (25): ReLU(inplace=True) + (26): 
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (27): ReLU(inplace=True) + (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (29): ReLU(inplace=True) + (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + ) + (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) + (classifier): Sequential( + (0): Linear(in_features=25088, out_features=4096, bias=True) + (1): ReLU(inplace=True) + (2): Dropout(p=0.5, inplace=False) + (3): Linear(in_features=4096, out_features=4096, bias=True) + (4): ReLU(inplace=True) + (5): Dropout(p=0.5, inplace=False) + (6): Linear(in_features=4096, out_features=1000, bias=True) + ) +) +""" + +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from torch import einsum +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo + +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + +# Tokenizer: checkpointから読み蟌むのではなくあらかじめ提䟛されおいるものを䜿う +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ䜿う + +DEFAULT_TOKEN_LENGTH = 75 + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の蚭定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336" + +# CLIP guided SD関連 +CLIP_MODEL_PATH = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" +FEATURE_EXTRACTOR_SIZE = (224, 224) +FEATURE_EXTRACTOR_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073] +FEATURE_EXTRACTOR_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711] + +VGG16_IMAGE_MEAN = [0.485, 0.456, 0.406] +VGG16_IMAGE_STD = [0.229, 0.224, 0.225] +VGG16_INPUT_RESIZE_DIV = 4 + +# CLIP特城量の取埗時にcutoutを䜿うか䜿う堎合には゜ヌスを曞き換えおください +NUM_CUTOUTS = 4 +USE_CUTOUTS = False + +# region モゞュヌル入れ替え郚 +""" +高速化のためのモゞュヌル入れ替え +""" + +# FlashAttentionを䜿うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class 
FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... 
j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがむンストヌルされおいないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを遞んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers + + +# endregion + +# region 画像生成の本䜓lpw_stable_diffusion.py ASLからコピヌしお修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立しお䜿えないのず機胜远加するのずでコピヌしお修正 + + +class PipelineLike: + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. 
Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + clip_model: CLIPModel, + clip_guidance_scale: float, + clip_image_guidance_scale: float, + vgg16_model: torchvision.models.VGG, + vgg16_guidance_scale: float, + vgg16_layer_no: int, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.unet = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion + self.token_replacements = {} + + # XTI + self.token_replacements_XTI = {} + + # CLIP guidance + self.clip_guidance_scale = clip_guidance_scale + self.clip_image_guidance_scale = clip_image_guidance_scale + self.clip_model = clip_model + self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) + self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + + # VGG16 guidance + self.vgg16_guidance_scale = vgg16_guidance_scale + if self.vgg16_guidance_scale > 0.0: + return_layers = {f"{vgg16_layer_no}": "feat"} + self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter( + vgg16_model.features, return_layers=return_layers + ) + self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] + + # Textual Inversion + def add_token_replacement(self, target_token_id, rep_token_ids): + self.token_replacements[target_token_id] = rep_token_ids + + def replace_token(self, tokens, layer=None): + new_tokens = [] + for token in tokens: + if token in self.token_replacements: + replacer_ = self.token_replacements[token] + if layer: + replacer = [] + for r in replacer_: + if r in self.token_replacements_XTI: + replacer.append(self.token_replacements_XTI[r][layer]) + else: + replacer = replacer_ + new_tokens.extend(replacer) + else: + new_tokens.append(token) + return new_tokens + + def add_token_replacement_XTI(self, target_token_id, rep_token_ids): + self.token_replacements_XTI[target_token_id] = rep_token_ids + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + # region xformersずか䜿う郚分独自に曞き換えるので関係なし + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. 
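+                For example (illustrative): with `attention_head_dim` = 8 and `slice_size` = 2, attention is computed in 8 // 2 = 4 slices.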
In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + # accelerateが必芁になるのでずりあえず省略 + raise NotImplementedError("cpu_offload is omitted.") + # if is_accelerate_available(): + # from accelerate import cpu_offload + # else: + # raise ImportError("Please install accelerate via `pip install accelerate`") + + # device = self.device + + # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + # if cpu_offloaded_model is not None: + # cpu_offload(cpu_offloaded_model, device) + + # endregion + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_prompts=None, + clip_guide_images=None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. 
+ height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. 
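+            Note (illustrative, mirroring the guidance step later in this method): with classifier-free
+            guidance enabled, the final noise prediction is combined as
+            `noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)`;
+            when `negative_scale` is also given, an additional term
+            `- negative_scale * (noise_pred_negative - noise_pred_uncond)` is subtracted, using a third,
+            truly unconditional (empty-prompt) embedding batch.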
+ Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + num_images_per_prompt = 1 # fixed + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + if not self.token_replacements_XTI: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, # こちらのトヌクン長に合わせおuncondを䜜るので75トヌクン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + + if self.token_replacements_XTI: + text_embeddings_concat = [] + for layer in [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ]: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, + ) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) + else: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) + text_embeddings = torch.stack(text_embeddings_concat) + else: + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + # CLIP guidanceで䜿甚するembeddingsを取埗する + if self.clip_guidance_scale > 0: + clip_text_input = prompt_tokens + if clip_text_input.shape[1] > self.tokenizer.model_max_length: + # TODO 75文字を超えたら譊告を出す + print("trim text input", clip_text_input.shape) + clip_text_input = torch.cat( + [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 + ) + print("trimmed", clip_text_input.shape) + + for i, clip_prompt in enumerate(clip_prompts): + if clip_prompt is not None: # clip_promptがあれば䞊曞きする + clip_text_input[i] = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) + + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt耇数件でもOK + + if ( + self.clip_image_guidance_scale > 0 + or self.vgg16_guidance_scale > 0 + and clip_guide_images is not None + or self.control_nets + ): + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + + if self.clip_image_guidance_scale > 0: + clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) + + clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + if len(image_embeddings_clip) == 1: + image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) + elif self.vgg16_guidance_scale > 0: + size = (width // 
VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # ずりあえず1/4に小さいか? + clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) + + clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"] + if len(image_embeddings_vgg16) == 1: + image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) + else: + # ControlNetのhintにguide imageを流甚する + # 前凊理はControlNet偎で行う + pass + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + init_latent_dist = self.vae.encode( + init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = 0.18215 * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 
1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.control_nets: + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + ).sample + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # perform clip guidance + if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings + ) + + if self.clip_guidance_scale > 0: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + self.clip_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + image_embeddings_clip, + self.clip_image_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.vgg16_guidance_scale > 0 and 
clip_guide_images is not None: + noise_pred, latents = self.cond_fn_vgg16( + latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + + latents = 1 / 0.18215 * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + # if not return_dict: + return (image, has_nsfw_concept) + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
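+            Example (illustrative sketch; assumes `pipe` is a fully constructed `PipelineLike` whose
+            components are already on the target device):
+                >>> images, _ = pipe.text2img(
+                ...     "a photograph of an astronaut riding a horse",
+                ...     width=512, height=512, num_inference_steps=50, guidance_scale=7.5,
+                ... )  # doctest: +SKIP
+                >>> images[0].save("out.png")  # doctest: +SKIP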
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def img2img( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for image-to-image generation. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. 
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def inpaint( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for inpaint. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). 
+ strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
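+            Example (illustrative sketch; this implementation expects the caller to supply the
+            latent-space noise used by the img2img/inpaint path through the `img2img_noise` keyword,
+            shaped `(batch, 4, height // 8, width // 8)`; a 512x512 init image is assumed here):
+                >>> init = Image.open("photo.png").convert("RGB")
+                >>> mask = Image.open("mask.png").convert("L")
+                >>> noise = torch.randn((1, 4, 512 // 8, 512 // 8), device=pipe.device)
+                >>> images, _ = pipe.inpaint(init, mask, "a cat sitting on a sofa",
+                ...                          strength=0.8, img2img_noise=noise)  # doctest: +SKIP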
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + # CLIP guidance StableDiffusion + # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py + + # バッチを分解しお1件ず぀凊理する + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + if len(latents) == 1: + return self.cond_fn1( + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) + + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings_clip[i].unsqueeze(0) + npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) + noise_pred.append(npr1) + cond_latents.append(cla1) + + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents + + @torch.enable_grad() + def cond_fn1( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) + image = self.normalize(image).to(latents.dtype) + + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) + dists = 
dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + # バッチサむズが耇数だず正しく動くかわからない + loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + # バッチを分解しお䞀件ず぀凊理する + def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + if len(latents) == 1: + return self.cond_fn_vgg16_b1( + latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale + ) + + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings[i].unsqueeze(0) + npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) + noise_pred.append(npr1) + cond_latents.append(cla1) + + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents + + # 1件だけ凊理する + @torch.enable_grad() + def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) + image = self.vgg16_normalize(image).to(latents.dtype) + + image_embeddings = self.vgg16_feat_model(image)["feat"] + + # バッチサむズが耇数だず正しく動くかわからない + loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損倱はMSEなので + + grads = -torch.autograd.grad(loss, latents)[0] + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + +class MakeCutouts(torch.nn.Module): + def __init__(self, cut_size, 
cut_power=1.0): + super().__init__() + + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = torch.nn.functional.normalize(x, dim=-1) + y = torch.nn.functional.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], 
max_length: int, layer=None): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + + token = pipe.replace_token(token, layer=layer) + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: PipelineLike, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
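The chunking behaviour described in this docstring can be illustrated with a small self-contained sketch. It is not the function below, just the splitting idea; the BOS/EOS ids are placeholders, not the real CLIP tokenizer values.

```python
# Sketch of splitting a long token sequence into encoder-sized chunks.
# BOS/EOS ids are placeholders; a real CLIP tokenizer defines its own.
def split_into_chunks(token_ids, bos=0, eos=2, chunk_length=77):
    body = chunk_length - 2  # content tokens per chunk, leaving room for BOS and EOS
    return [[bos] + token_ids[i:i + body] + [eos] for i in range(0, len(token_ids), body)]

chunks = split_into_chunks(list(range(150)))
print(len(chunks), [len(c) for c in chunks])  # 2 chunks of 77 tokens each
```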
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最埌に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであずはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: PipelineLike, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + layer=None, + **kwargs, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. 
+ skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
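+ # note: the weighting below is mean-preserving -- after each token embedding is multiplied by its
+ # weight, the whole embedding is rescaled by previous_mean / current_mean, so relative emphasis
+ # between tokens changes while the overall magnitude stays roughly the same (e.g. if the weights
+ # doubled the mean, everything is scaled back by 0.5)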
+ # → normalizing over the whole thing is probably fine
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= prompt_weights.unsqueeze(-1)
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings, prompt_tokens
+ return text_embeddings, None, prompt_tokens
+
+
+ def preprocess_guide_image(image):
+ image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # keep consistent with cond_fn
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2) # nchw
+ image = torch.from_numpy(image)
+ return image # 0 to 1
+
+
+ # VGG16 accepts inputs of any size, so resize the input image as appropriate
+ def preprocess_vgg16_guide_image(image, size):
+ image = image.resize(size, resample=Image.NEAREST) # keep consistent with cond_fn
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2) # nchw
+ image = torch.from_numpy(image)
+ return image # 0 to 1
+
+
+ def preprocess_image(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+ def preprocess_mask(mask):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? → [None] adds the batch dimension; transpose(0, 1, 2, 3) is an identity permutation kept from the reference implementation
+ mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# endregion + + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必芁ないデヌタ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + + +class BatchDataExt(NamedTuple): + # バッチ分割が必芁なデヌタ + width: int + height: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgず同時に䜿えたせん" + + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを䜿甚するこずは想定されおいたせん") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを䜿甚するこずは想定されおいたせん") + + # モデルを読み蟌む + if not os.path.isfile(args.ckpt): # ファむルがないならパタヌンで探し、䞀぀だけ該圓すればそれを䜿う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + use_stable_diffusion_format = os.path.isfile(args.ckpt) + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + # VAEを読み蟌む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") + + # # 眮換するCLIPを読み蟌む + # if args.replace_clip_l14_336: + # text_encoder = load_clip_l14_336(dtype) + # print(f"large clip {CLIP_ID_L14_336} is loaded") + + if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: + print("prepare clip model") + clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) + else: + clip_model = None + + if args.vgg16_guidance_scale > 0.0: + print("prepare resnet model") + vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) + else: + vgg16_model = None + + # xformers、Hypernetwork察応 + if not args.diffusers_xformers: + replace_unet_modules(unet, not args.xformers, args.xformers) + + # tokenizerを読み蟌む + print("loading tokenizer") + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + + # schedulerを甚意する + sched_init_args = {} + scheduler_num_noises_per_step = 1 + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから倖しおある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + 
scheduler_module = diffusers.schedulers.scheduling_lms_discrete + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # samplerの乱数をあらかじめ指定するための凊理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + print("set clip_sample to True") + scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量しおない + + # custom pipelineをコピったや぀を生成する + vae.to(dtype).to(device) + text_encoder.to(dtype).to(device) + unet.to(dtype).to(device) + if clip_model is not None: + clip_model.to(dtype).to(device) + if vgg16_model is not None: + vgg16_model.to(dtype).to(device) + + # 
networkを組み蟌む + if args.network_module: + networks = [] + network_default_muls = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + if not args.network_merge: + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + network.merge_to(text_encoder, unet, dtype, device) + + else: + networks = [] + + # ControlNetの凊理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if clip_model is not None: + clip_model.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + if vgg16_model is not None: + vgg16_model.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.unet.to(memory_format=torch.channels_last) + cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + args.clip_skip, + clip_model, + args.clip_guidance_scale, + args.clip_image_guidance_scale, + vgg16_model, + args.vgg16_guidance_scale, + args.vgg16_guidance_layer, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + if args.XTI_embeddings: + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + 
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + + # Textual Inversionを凊理する + if args.textual_inversion_embeddings: + token_ids_embeds = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + embeds = next(iter(data.values())) + + if type(embeds) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファむルのデヌタがTensorではありたせん: {embeds_file}") + + num_vectors_per_token = embeds.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前ファむル名のトヌクンが既に存圚したす。ファむルをリネヌムしおください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_ids_embeds.append((token_ids, embeds)) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + if args.XTI_embeddings: + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + token_ids_embeds_XTI = [] + for embeds_file in args.XTI_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + if set(data.keys()) != set(XTI_layers): + raise ValueError("NOT XTI") + embeds = torch.concat(list(data.values())) + num_vectors_per_token = data["MID"].size()[0] + + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前ファむル名のトヌクンが既に存圚したす。ファむルをリネヌムしおください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"XTI embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") + + # if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_strings_XTI = [] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + token_ids_embeds_XTI.append((token_ids_XTI, embeds)) + for t in token_ids: + t_XTI_dic = {} + for i, layer_name in enumerate(XTI_layers): + t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens + pipe.add_token_replacement_XTI(t, t_XTI_dic) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds_XTI: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + # promptを取埗する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前凊理、画像の読み蟌みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない堎合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありたせん: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありたせん: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないずき、画像のPngInfoから取埗する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトず画像を䞀臎させるため指定回数だけ繰り返す画像を増幅する + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サむズにオプション指定があるずきはリサむズする + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) + if mask_images is not None: + print(f"resize img2img mask images to 
{args.W}*{args.H}") + mask_images = resize_images(mask_images, (args.W, args.H)) + + if networks and mask_images: + # mask を領域情報ずしお流甚する、珟圚は1枚だけ察応 + # TODO 耇数のnetwork classの混圚時の考慮 + print("use mask as region") + # import cv2 + # for i in range(3): + # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) + # cv2.waitKey() + # cv2.destroyAllWindows() + networks[0].__class__.set_regions(networks, np.array(mask_images[0])) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガむド画像がありたせん。盎前に生成した画像を䜿いたす: {args.image_path}") + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めおおく + if args.seed is not None: + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サむズを蚭定するimg2imgではこれらの倀は無芖されるたたはW*Hにリサむズ枈み + if args.W is None: + args.W = 512 + if args.H is None: + args.H = 512 + + # 画像生成のルヌプ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ凊理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの凊理 + if highres_fix and not highres_1st: + # 1st stageのバッチを䜜成しお呌び出すサむズを小さくしお呌び出す + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + width_1st = int(ext.width * args.highres_fix_scale + 0.5) + height_1st = int(ext.height * args.highres_fix_scale + 0.5) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + ext_1st = BatchDataExt( + width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls + ) + batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを䜜成しお以䞋凊理する + print("process 2nd stage") + if args.highres_fix_latents_upscaling: + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポヌトしおいない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + if not args.highres_fix_latents_upscaling: + image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgずしお蚭定 + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + (width, height, steps, scale, negative_scale, strength, network_muls), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // 
DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の䜍眮に関わらず同じ乱数を䜿うためにここで乱数を生成しおおく。あわせおimage/maskがbatch内で同䞀かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべおの画像が同じなら1枚だけpipeに枡すこずでpipe偎で凊理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet䜿甚時はguide imageをリサむズする + if control_nets: + # TODO resampleのメ゜ッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + )[0] + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", 
time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いず死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがむンストヌルされおいないためプレビュヌできたせん") + + return images + + # 画像生成のプロンプトが䞀呚するたでのルヌプ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + prompt = input() + except EOFError: + break + + valid = len(prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + prompt = prompt_list[prompt_index] + + # parse prompt + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", 
parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析゚ラヌ: {parg}") + print(ex) + + if seeds is not None: + # 数が足りないなら繰り返す + if len(seeds) < args.images_per_prompt: + seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) + seeds = seeds[: args.images_per_prompt] + else: + if predefined_seeds is not None: + seeds = predefined_seeds[-args.images_per_prompt :] + predefined_seeds = predefined_seeds[: -args.images_per_prompt] + elif args.iter_same_seed: + seeds = [iter_seed] * args.images_per_prompt + else: + seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if args.interactive: + print(f"seed: {seeds}") + + init_image = mask_image = guide_image = None + for seed in seeds: # images_per_promptの数だけ + # 同䞀むメヌゞを䜿うずき、本圓はlatentに倉換しおおくず無駄がないが面倒なのでずりあえず毎回凊理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # 32単䜍に䞞めたや぀にresizeされるので螏襲する + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サむズが32で割り切れないためリサむズされたす。画像が歪みたす" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 耇数件の堎合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: + if prev_image is None: + print("Generate 1st image without guide image.") + else: + print("Use previous image as guide image.") + guide_image = prev_image + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必芁 + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み蟌む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization孊習を有効にする" + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, 
default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファむルから読み蟌む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 察話モヌド生成される画像は1枚になりたす" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 察話モヌドで画像を衚瀺しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgたたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファむル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファむル名を生成画像のファむル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサむズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE凊理時のバッチサむズ、1未満の倀の堎合は通垞バッチサむズの比率", + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラヌスケゞュヌラの皮類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファむルたたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える堎合、VAEのcheckpointファむルたたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの孊習のため", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、たたは耇数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないずき繰り返し内はすべお同じseedを䜿うプロンプト間の差異の比范甚", + ) + parser.add_argument("--fp16", 
action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを䜿甚し高速化する") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを䜿甚するHypernetwork利甚䞍可", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 远加ネットワヌクを䜿う時そのモゞュヌル名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 远加ネットワヌクの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 远加ネットワヌクの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワヌクぞの远加の匕数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワヌクモデルのメタデヌタを衚瀺する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワヌクの重みをマヌゞする") + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--XTI_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの埌ろからn局目の出力を䜿う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トヌクン長をデフォルトの䜕倍ずするか 75*この倀 がトヌクン長ずなる", + ) + parser.add_argument( + "--clip_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしおこのscaleを適甚するサンプラヌはDDIM、PNDM、LMSのみ", + ) + parser.add_argument( + "--clip_image_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしおこのscaleを適甚する", + ) + parser.add_argument( + "--vgg16_guidance_scale", + type=float, + default=0.0, + help="enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしおこのscaleを適甚する", + ) + parser.add_argument( + "--vgg16_guidance_layer", + type=int, + default=20, + help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに䜿うレむダヌ番号 (1~30、20はconv4_2)", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガむドに䜿う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にしお最初の解像床をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステヌゞのステップ数" + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステヌゞの画像を保存する" + ) + parser.add_argument( + 
"--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡倧する", + ) + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) + + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 䜿甚するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 䜿甚するControlNetのプリプロセス名" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガむドするステップ比率", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/gui.sh b/gui.sh new file mode 100644 index 0000000000000000000000000000000000000000..b780839a7e1dad01f664bdeeeb2ada6bf9fb9de7 --- /dev/null +++ b/gui.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +# Activate the virtual environment +source ./venv/bin/activate + +# If the requirements are validated, run the kohya_gui.py script with the command-line arguments +if python tools/validate_requirements.py; then + python kohya_gui.py "$@" +fi \ No newline at end of file diff --git a/kohya_gui.py b/kohya_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e0d8ca38b94bc7b8bb3ae36f89048641739ba3 --- /dev/null +++ b/kohya_gui.py @@ -0,0 +1,110 @@ +import gradio as gr +import os +import argparse +from dreambooth_gui import dreambooth_tab +from finetune_gui import finetune_tab +from textual_inversion_gui import ti_tab +from library.utilities import utilities_tab +from library.extract_lora_gui import gradio_extract_lora_tab +from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab +from library.merge_lora_gui import gradio_merge_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from lora_gui import lora_tab + + +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css, title='Kohya_ss GUI') + + with interface: + with gr.Tab('Dreambooth'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = dreambooth_tab() + with gr.Tab('Dreambooth LoRA'): + lora_tab() + with gr.Tab('Dreambooth TI'): + ti_tab() + with gr.Tab('Finetune'): + finetune_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + gradio_extract_lora_tab() + gradio_extract_lycoris_locon_tab() + gradio_merge_lora_tab() + gradio_resize_lora_tab() + + # Show the interface + launch_kwargs = {} + username = kwargs.get('username') + password = kwargs.get('password') + server_port = kwargs.get('server_port', 0) + inbrowser = kwargs.get('inbrowser', False) + share = kwargs.get('share', False) + server_name = kwargs.get('listen') + + launch_kwargs['server_name'] = server_name + if username and password: + launch_kwargs['auth'] = (username, password) + if server_port > 0: + launch_kwargs['server_port'] = 
server_port + if inbrowser: + launch_kwargs['inbrowser'] = inbrowser + if share: + launch_kwargs['share'] = share + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--listen', + type=str, + default='127.0.0.1', + help='IP to listen on for connections to Gradio', + ) + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + parser.add_argument( + '--share', action='store_true', help='Share the gradio UI' + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + share=args.share, + listen=args.listen, + ) diff --git a/kohya_ss_colab.ipynb b/kohya_ss_colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..641f50e96f81fc4f0c4faf4687885355bf567403 --- /dev/null +++ b/kohya_ss_colab.ipynb @@ -0,0 +1,448 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MvroZ9rJ1iqN" + }, + "source": [ + "# Kohya SS WebUI Colab Setup\n", + "\n", + "This Colab workbook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS is a Python library that provides Stable Diffusion-based models for image, text, and audio generation tasks. This Colab workbook provides a convenient way for users to run Kohya SS without needing to install anything on their local machine.\n", + "\n", + "This workbook was inspired by the work of [Spaceginner](https://github.com/Spaceginner)'s original Colab workbook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab workbook was coded by [panguin6010](https://github.com/panguin6010) \n", + "\n", + "\n", + "## Tutorials\n", + "\n", + "Before running this code, make sure you are familiar with using Colab workbooks and have a basic understanding of Kohya SS and its usage. You can find tutorials for these online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DrAnm1um5vjh" + }, + "source": [ + "\n", + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vmoRnFQEqOeY", + "outputId": "09876c9a-d043-4881-d92f-6ed54313c390" + }, + "outputs": [], + "source": [ + "#@markdown #Step 1: Mounting Google Drive\n", + "\n", + "#@markdown The first step in setting up Kohya SS on Colab is to mount your Google Drive to the Colab notebook. 
This allows you to save and access files from your Google Drive in the Colab notebook.\n", + "\n", + "#@markdown To mount your Google Drive, run the following code block, which mounts your Google Drive to the /content/gdrive directory in the Colab notebook.\n", + "\n", + "\n", + "\n", + "from google.colab import drive\n", + "drive.mount('/content/gdrive')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mvQwnr4354BM" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "7ca7f6f727da46ac9a1149e69c16c81f", + "77e5e07552b641cf9c368fb3939cb1d1", + "235e01b92646444387ebd31ab945358e" + ] + }, + "id": "jnhm7ycMrLWb", + "outputId": "63ba39ed-90c6-4b2d-f03e-61775587b083" + }, + "outputs": [], + "source": [ + "#@markdown #Kohya SS WebUI Installation\n", + "\n", + "#@markdown Now that your Google Drive is linked, we need to install the Kohya SS WebUI.\n", + "\n", + "#@markdown The code clones the [Kohya SS Google Colab](\"https://github.com/panguin6010/kohya_ss_google_colab\") repository and creates the necessary directories for Kohya SS to run. It then resets the git repository and pulls the latest changes. Finally, it displays a success message.\n", + "\n", + "#@markdown Note: If Google Drive is not connected, the code will use Colab storage instead.\n", + "\n", + "#@title\n", + "# Import necessary libraries\n", + "from IPython.display import clear_output\n", + "from IPython.utils import capture\n", + "from subprocess import getoutput\n", + "import ipywidgets as widgets\n", + "import sys\n", + "import fileinput\n", + "import os\n", + "import time\n", + "\n", + "# WebUI Installation\n", + "\n", + "# Check if Google Drive is connected\n", + "if not os.path.exists(\"/content/gdrive/MyDrive/\"):\n", + " print(\"Gdrive not connected, using colab storage ...\")\n", + " time.sleep(4)\n", + " !mkdir -p /content/gdrive/MyDrive/\n", + "\n", + "# Clone the repository and create necessary directories\n", + "with capture.capture_output() as cap:\n", + " def inf(msg, style, wdth):\n", + " inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth))\n", + " display(inf)\n", + "\n", + " %mkdir -p /content/gdrive/MyDrive/sd\n", + " %cd /content/gdrive/MyDrive/sd\n", + " !git clone https://github.com/panguin6010/kohya_ss_google_colab kohya_ss_colab\n", + " !mkdir -p /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface\n", + " !ln -s /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface /root/.cache/\n", + "\n", + "# Reset the git repository and pull the latest changes\n", + "with capture.capture_output() as cap:\n", + " %cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n", + " !git reset --hard\n", + " time.sleep(1)\n", + "\n", + "print(\"Updating the repository...\")\n", + "!git pull\n", + "\n", + "# Clear the output and display the success message\n", + "clear_output()\n", + "inf(\"✓ Done\", \"success\", \"50px\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8SrMhmFz7Lt4" + }, + "source": [ + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "54e929bcb37e4997a696d0becdecfd84", + "43fbca3abb04401296967f819680f94f", + "6d87b2c916394932b1a53382fe3cdb4e" + ] + }, + "id": 
"yjvkHRlDtDmV", + "outputId": "06e1e873-b1ed-4403-c9a4-19ac1caa961b" + }, + "outputs": [], + "source": [ + "#@markdown #Requirements Installation\n", + "\n", + "#@markdown Now that we have downloaded the Kohya SS WebUI, we need to install the necessary requirements.\n", + "\n", + "# Print the status message\n", + "print(\"Installing requirements...\")\n", + "\n", + "# Change the working directory to the project folder\n", + "%cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n", + "\n", + "# Install the requirements\n", + "with capture.capture_output() as cap:\n", + " # Uncomment the following line if you need to install specific versions of torch and torchvision\n", + " # !pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116\n", + " \n", + " # Install the requirements from the requirements.txt file\n", + " !pip install -r requirements.txt\n", + "\n", + "# Clear the output to keep the notebook clean\n", + "clear_output()\n", + "\n", + "# Print the success message\n", + "inf(\"✓ Done\", \"success\", \"50px\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FLDvlHm1tYud" + }, + "source": [ + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IzS3hvuTtTqW", + "outputId": "9e629e1f-c8eb-43a2-9639-2583937ba81a" + }, + "outputs": [], + "source": [ + "#@markdown # Start Kohya ss WebUI\n", + "\n", + "User = \"\" #@param {type:\"string\"}\n", + "Password = \"\" #@param {type:\"string\"}\n", + "\n", + "#@markdown - Adding a username and password is not necessary but it will improve the security of your Kohya instance.\n", + "#@markdown ______\n", + "#@markdown # Please click the link that concludes with ```gradio.live``` to access your instance\n", + "# Encourage users to contribute improvements\n", + "print(\"Please feel free to make any changes or improvements you think would enhance this setup. 
Your input and contributions are greatly appreciated!\")\n", + "# Check if the user has provided a username and password\n", + "if User and Password:\n", + " # Run the Kohya GUI with the provided credentials\n", + " !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --username $User --password $Password --share \n", + "else:\n", + " # Run the Kohya GUI without credentials\n", + " !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --share \n" + ] + } + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyOZmOjfS55zOBmbTmRNOf3b", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "235e01b92646444387ebd31ab945358e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "43fbca3abb04401296967f819680f94f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": "50px", + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "54e929bcb37e4997a696d0becdecfd84": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "success", + "description": "✓ Done", + "disabled": true, + "icon": "", + "layout": "IPY_MODEL_43fbca3abb04401296967f819680f94f", + "style": "IPY_MODEL_6d87b2c916394932b1a53382fe3cdb4e", + "tooltip": "" + } + }, + "6d87b2c916394932b1a53382fe3cdb4e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "77e5e07552b641cf9c368fb3939cb1d1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": "50px", + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7ca7f6f727da46ac9a1149e69c16c81f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "success", + "description": "✓ Done", + "disabled": true, + "icon": "", + "layout": "IPY_MODEL_77e5e07552b641cf9c368fb3939cb1d1", + "style": "IPY_MODEL_235e01b92646444387ebd31ab945358e", + "tooltip": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/library/__init__.py b/library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d208d1ef7ea51140d1ab9520ba5b8a1a57dc84 --- /dev/null +++ b/library/basic_caption_gui.py @@ -0,0 +1,140 @@ +import gradio as gr +from easygui import msgbox +import subprocess +from .common_gui import get_folder_path, add_pre_postfix, find_replace +import os + + +def caption_images( + caption_text, + images_dir, + overwrite, + caption_ext, + prefix, + postfix, + find_text, + replace_text, +): + # Check for images_dir + if not images_dir: + msgbox('Image folder is missing...') + return + + if not caption_ext: + msgbox('Please provide an extension for the caption files.') + return + + if caption_text: + print(f'Captioning files in {images_dir} with {caption_text}...') + run_cmd = f'python "tools/caption.py"' + run_cmd += f' --caption_text="{caption_text}"' + if overwrite: + run_cmd += f' --overwrite' + if caption_ext: + run_cmd += f' --caption_file_ext="{caption_ext}"' + run_cmd += f' "{images_dir}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + if overwrite: + if prefix or postfix: + # Add prefix and postfix + add_pre_postfix( + folder=images_dir, + 
caption_file_ext=caption_ext, + prefix=prefix, + postfix=postfix, + ) + if find_text: + find_replace( + folder_path=images_dir, + caption_file_ext=caption_ext, + search_text=find_text, + replace_text=replace_text, + ) + else: + if prefix or postfix: + msgbox( + 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' + ) + + print('...captioning done') + + +# Gradio UI +def gradio_basic_caption_gui_tab(): + with gr.Tab('Basic Captioning'): + gr.Markdown( + 'This utility will allow the creation of simple caption files for each image in a folder.' + ) + with gr.Row(): + images_dir = gr.Textbox( + label='Image folder to caption', + placeholder='Directory containing the images to caption', + interactive=True, + ) + folder_button = gr.Button('📂', elem_id='open_folder_small') + folder_button.click( + get_folder_path, + outputs=images_dir, + show_progress=False, + ) + caption_ext = gr.Textbox( + label='Caption file extension', + placeholder='Extension for caption file. eg: .caption, .txt', + value='.txt', + interactive=True, + ) + overwrite = gr.Checkbox( + label='Overwrite existing captions in folder', + interactive=True, + value=False, + ) + with gr.Row(): + prefix = gr.Textbox( + label='Prefix to add to caption', + placeholder='(Optional)', + interactive=True, + ) + caption_text = gr.Textbox( + label='Caption text', + placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', + interactive=True, + ) + postfix = gr.Textbox( + label='Postfix to add to caption', + placeholder='(Optional)', + interactive=True, + ) + with gr.Row(): + find_text = gr.Textbox( + label='Find text', + placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', + interactive=True, + ) + replace_text = gr.Textbox( + label='Replacement text', + placeholder='Eg: , by some artist. 
Leave empty if you just want to replace with nothing', + interactive=True, + ) + caption_button = gr.Button('Caption images') + caption_button.click( + caption_images, + inputs=[ + caption_text, + images_dir, + overwrite, + caption_ext, + prefix, + postfix, + find_text, + replace_text, + ], + show_progress=False, + ) diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0081ddc438f41129501cf1e5d435c6d6dcbdec --- /dev/null +++ b/library/blip_caption_gui.py @@ -0,0 +1,149 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_folder_path, add_pre_postfix + +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def caption_images( + train_data_dir, + caption_file_ext, + batch_size, + num_beams, + top_p, + max_length, + min_length, + beam_search, + prefix, + postfix, +): + # Check for caption_text_input + # if caption_text_input == "": + # msgbox("Caption text is missing...") + # return + + # Check for images_dir_input + if train_data_dir == '': + msgbox('Image folder is missing...') + return + + if caption_file_ext == '': + msgbox('Please provide an extension for the caption files.') + return + + print(f'Captioning files in {train_data_dir}...') + run_cmd = f'{PYTHON} "finetune/make_captions.py"' + run_cmd += f' --batch_size="{int(batch_size)}"' + run_cmd += f' --num_beams="{int(num_beams)}"' + run_cmd += f' --top_p="{top_p}"' + run_cmd += f' --max_length="{int(max_length)}"' + run_cmd += f' --min_length="{int(min_length)}"' + if beam_search: + run_cmd += f' --beam_search' + if caption_file_ext != '': + run_cmd += f' --caption_extension="{caption_file_ext}"' + run_cmd += f' "{train_data_dir}"' + run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # Add prefix and postfix + add_pre_postfix( + folder=train_data_dir, + caption_file_ext=caption_file_ext, + prefix=prefix, + postfix=postfix, + ) + + print('...captioning done') + + +### +# Gradio UI +### + + +def gradio_blip_caption_gui_tab(): + with gr.Tab('BLIP Captioning'): + gr.Markdown( + 'This utility will use BLIP to caption files for each images in a folder.' + ) + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder to caption', + placeholder='Directory containing the images to caption', + interactive=True, + ) + button_train_data_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_train_data_dir_input.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + with gr.Row(): + caption_file_ext = gr.Textbox( + label='Caption file extension', + placeholder='Extention for caption file. 
eg: .caption, .txt', + value='.txt', + interactive=True, + ) + + prefix = gr.Textbox( + label='Prefix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) + + postfix = gr.Textbox( + label='Postfix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) + + batch_size = gr.Number( + value=1, label='Batch size', interactive=True + ) + + with gr.Row(): + beam_search = gr.Checkbox( + label='Use beam search', interactive=True, value=True + ) + num_beams = gr.Number( + value=1, label='Number of beams', interactive=True + ) + top_p = gr.Number(value=0.9, label='Top p', interactive=True) + max_length = gr.Number( + value=75, label='Max length', interactive=True + ) + min_length = gr.Number( + value=5, label='Min length', interactive=True + ) + + caption_button = gr.Button('Caption images') + + caption_button.click( + caption_images, + inputs=[ + train_data_dir, + caption_file_ext, + batch_size, + num_beams, + top_p, + max_length, + min_length, + beam_search, + prefix, + postfix, + ], + show_progress=False, + ) diff --git a/library/common_gui.py b/library/common_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..b08ac9cd49936de7f0dc9ee6a7983a6b6a2a570a --- /dev/null +++ b/library/common_gui.py @@ -0,0 +1,978 @@ +from tkinter import filedialog, Tk +from easygui import msgbox +import os +import gradio as gr +import easygui +import shutil + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + +# define a list of substrings to search for v2 base models +V2_BASE_MODELS = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', +] + +# define a list of substrings to search for v_parameterization models +V_PARAMETERIZATION_MODELS = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', +] + +# define a list of substrings to v1.x models +V1_MODELS = [ + 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', +] + +# define a list of substrings to search for +ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + +FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] + + +def check_if_model_exist(output_name, output_dir, save_model_as): + if save_model_as in ['diffusers', 'diffusers_safetendors']: + ckpt_folder = os.path.join(output_dir, output_name) + if os.path.isdir(ckpt_folder): + msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?' + if not easygui.ynbox(msg, 'Overwrite Existing Model?'): + print( + 'Aborting training due to existing model with same name...' + ) + return True + elif save_model_as in ['ckpt', 'safetensors']: + ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as) + if os.path.isfile(ckpt_file): + msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?' + if not easygui.ynbox(msg, 'Overwrite Existing Model?'): + print( + 'Aborting training due to existing model with same name...' + ) + return True + else: + print( + 'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...' 
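+        # With "same as source model" the output format, and therefore the target file
+        # name, is not known here, so no overwrite check can be made and training continues.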
+ ) + return False + + return False + + +def update_my_data(my_data): + # Update the optimizer based on the use_8bit_adam flag + use_8bit_adam = my_data.get('use_8bit_adam', False) + my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW') + + # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model + model_list = my_data.get('model_list', []) + pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '') + if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS: + my_data['model_list'] = 'custom' + + # Convert epoch and save_every_n_epochs values to int if they are strings + for key in ['epoch', 'save_every_n_epochs']: + value = my_data.get(key, -1) + if isinstance(value, str) and value.isdigit(): + my_data[key] = int(value) + elif not value: + my_data[key] = -1 + + # Update LoRA_type if it is set to LoCon + if my_data.get('LoRA_type', 'Standard') == 'LoCon': + my_data['LoRA_type'] = 'LyCORIS/LoCon' + + # Update model save choices due to changes for LoRA and TI training + if ( + (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) + and my_data.get('save_model_as') not in ['safetensors', 'ckpt'] + ): + message = ( + 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}' + ) + if my_data.get('LoRA_type'): + print(message.format('LoRA')) + if my_data.get('num_vectors_per_token'): + print(message.format('TI')) + my_data['save_model_as'] = 'safetensors' + + return my_data + + +def get_dir_and_file(file_path): + dir_path, file_name = os.path.split(file_path) + return (dir_path, file_name) + + +# def has_ext_files(directory, extension): +# # Iterate through all the files in the directory +# for file in os.listdir(directory): +# # If the file name ends with extension, return True +# if file.endswith(extension): +# return True +# # If no extension files were found, return False +# return False + + +def get_file_path( + file_path='', default_extension='.json', extension_name='Config files' +): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) + + # Create a hidden Tkinter root window + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + + # Show the open file dialog and get the selected file path + file_path = filedialog.askopenfilename( + filetypes=( + (extension_name, f'*{default_extension}'), + ('All files', '*.*'), + ), + defaultextension=default_extension, + initialfile=initial_file, + initialdir=initial_dir, + ) + + # Destroy the hidden root window + root.destroy() + + # If no file is selected, use the current file path + if not file_path: + file_path = current_file_path + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + return file_path + + +def get_any_file_path(file_path=''): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + file_path = filedialog.askopenfilename( + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + if file_path == '': + file_path = current_file_path + + return file_path + + +def remove_doublequote(file_path): + if file_path != None: + file_path = 
file_path.replace('"', '') + + return file_path + + +# def set_legacy_8bitadam(optimizer, use_8bit_adam): +# if optimizer == 'AdamW8bit': +# # use_8bit_adam = True +# return gr.Dropdown.update(value=optimizer), gr.Checkbox.update( +# value=True, interactive=False, visible=True +# ) +# else: +# # use_8bit_adam = False +# return gr.Dropdown.update(value=optimizer), gr.Checkbox.update( +# value=False, interactive=False, visible=True +# ) + + +def get_folder_path(folder_path=''): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_folder_path = folder_path + + initial_dir, initial_file = get_dir_and_file(folder_path) + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + folder_path = filedialog.askdirectory(initialdir=initial_dir) + root.destroy() + + if folder_path == '': + folder_path = current_folder_path + + return folder_path + + +def get_saveasfile_path( + file_path='', defaultextension='.json', extension_name='Config files' +): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + save_file_path = filedialog.asksaveasfile( + filetypes=( + (f'{extension_name}', f'{defaultextension}'), + ('All files', '*'), + ), + defaultextension=defaultextension, + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + # print(save_file_path) + + if save_file_path == None: + file_path = current_file_path + else: + print(save_file_path.name) + file_path = save_file_path.name + + # print(file_path) + + return file_path + + +def get_saveasfilename_path( + file_path='', extensions='*', extension_name='Config files' +): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + save_file_path = filedialog.asksaveasfilename( + filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), + defaultextension=extensions, + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + if save_file_path == '': + file_path = current_file_path + else: + # print(save_file_path) + file_path = save_file_path + + return file_path + + +def add_pre_postfix( + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption', +) -> None: + """ + Add prefix and/or postfix to the content of caption files within a folder. + If no caption files are found, create one with the requested prefix and/or postfix. + + Args: + folder (str): Path to the folder containing caption files. + prefix (str, optional): Prefix to add to the content of the caption files. + postfix (str, optional): Postfix to add to the content of the caption files. + caption_file_ext (str, optional): Extension of the caption files. 
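+
+    Example (illustrative; folder and prefix are placeholders):
+        add_pre_postfix(folder='train/img', prefix='photo of sks', caption_file_ext='.txt')
+        prepends 'photo of sks ' to every existing .txt caption next to an image in the
+        folder, and creates a caption file containing only the prefix where none exists.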
+ """ + + if prefix == '' and postfix == '': + return + + image_extensions = ('.jpg', '.jpeg', '.png', '.webp') + image_files = [ + f for f in os.listdir(folder) if f.lower().endswith(image_extensions) + ] + + for image_file in image_files: + caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext + caption_file_path = os.path.join(folder, caption_file_name) + + if not os.path.exists(caption_file_path): + with open(caption_file_path, 'w') as f: + separator = ' ' if prefix and postfix else '' + f.write(f'{prefix}{separator}{postfix}') + else: + with open(caption_file_path, 'r+') as f: + content = f.read() + content = content.rstrip() + f.seek(0, 0) + + prefix_separator = ' ' if prefix else '' + postfix_separator = ' ' if postfix else '' + f.write( + f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}' + ) + + +def has_ext_files(folder_path: str, file_extension: str) -> bool: + """ + Check if there are any files with the specified extension in the given folder. + + Args: + folder_path (str): Path to the folder containing files. + file_extension (str): Extension of the files to look for. + + Returns: + bool: True if files with the specified extension are found, False otherwise. + """ + for file in os.listdir(folder_path): + if file.endswith(file_extension): + return True + return False + + +def find_replace( + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '', +) -> None: + """ + Find and replace text in caption files within a folder. + + Args: + folder_path (str, optional): Path to the folder containing caption files. + caption_file_ext (str, optional): Extension of the caption files. + search_text (str, optional): Text to search for in the caption files. + replace_text (str, optional): Text to replace the search text with. + """ + print('Running caption find/replace') + + if not has_ext_files(folder_path, caption_file_ext): + msgbox( + f'No files with extension {caption_file_ext} were found in {folder_path}...' + ) + return + + if search_text == '': + return + + caption_files = [ + f for f in os.listdir(folder_path) if f.endswith(caption_file_ext) + ] + + for caption_file in caption_files: + with open( + os.path.join(folder_path, caption_file), 'r', errors='ignore' + ) as f: + content = f.read() + + content = content.replace(search_text, replace_text) + + with open(os.path.join(folder_path, caption_file), 'w') as f: + f.write(content) + + +def color_aug_changed(color_aug): + if color_aug: + msgbox( + 'Disabling "Cache latent" because "Color augmentation" has been selected...' 
+ ) + return gr.Checkbox.update(value=False, interactive=False) + else: + return gr.Checkbox.update(value=True, interactive=True) + + +def save_inference_file(output_dir, v2, v_parameterization, output_name): + # List all files in the directory + files = os.listdir(output_dir) + + # Iterate over the list of files + for file in files: + # Check if the file starts with the value of output_name + if file.startswith(output_name): + # Check if it is a file or a directory + if os.path.isfile(os.path.join(output_dir, file)): + # Split the file name and extension + file_name, ext = os.path.splitext(file) + + # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension + if v2 and v_parameterization: + print( + f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml' + ) + shutil.copy( + f'./v2_inference/v2-inference-v.yaml', + f'{output_dir}/{file_name}.yaml', + ) + elif v2: + print( + f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml' + ) + shutil.copy( + f'./v2_inference/v2-inference.yaml', + f'{output_dir}/{file_name}.yaml', + ) + + +def set_pretrained_model_name_or_path_input( + model_list, pretrained_model_name_or_path, v2, v_parameterization +): + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + if str(model_list) in V2_BASE_MODELS: + print('SD v2 model detected. Setting --v2 parameter') + v2 = True + v_parameterization = False + pretrained_model_name_or_path = str(model_list) + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + if str(model_list) in V_PARAMETERIZATION_MODELS: + print( + 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' + ) + v2 = True + v_parameterization = True + pretrained_model_name_or_path = str(model_list) + + if str(model_list) in V1_MODELS: + v2 = False + v_parameterization = False + pretrained_model_name_or_path = str(model_list) + + if model_list == 'custom': + if ( + str(pretrained_model_name_or_path) in V1_MODELS + or str(pretrained_model_name_or_path) in V2_BASE_MODELS + or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS + ): + pretrained_model_name_or_path = '' + v2 = False + v_parameterization = False + return model_list, pretrained_model_name_or_path, v2, v_parameterization + + +def set_v2_checkbox(model_list, v2, v_parameterization): + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + if str(model_list) in V2_BASE_MODELS: + v2 = True + v_parameterization = False + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + if str(model_list) in V_PARAMETERIZATION_MODELS: + v2 = True + v_parameterization = True + + if str(model_list) in V1_MODELS: + v2 = False + v_parameterization = False + + return v2, v_parameterization + + +def set_model_list( + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, +): + + if not pretrained_model_name_or_path in ALL_PRESET_MODELS: + model_list = 'custom' + else: + model_list = pretrained_model_name_or_path + + return model_list, v2, v_parameterization + + +### +### Gradio common GUI section +### + + +def gradio_config(): + with gr.Accordion('Configuration file', open=False): + with gr.Row(): + button_open_config = gr.Button('Open 📂', elem_id='open_folder') + 
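+            # Open/Save/Save as/Load operate on a single JSON configuration file; the
+            # textbox below holds its path, and remove_doublequote strips stray double
+            # quotes from a pasted path.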
button_save_config = gr.Button('Save 💟', elem_id='open_folder') + button_save_as_config = gr.Button( + 'Save as... 💟', elem_id='open_folder' + ) + config_file_name = gr.Textbox( + label='', + placeholder="type the configuration file path or use the 'Open' button above to select it...", + interactive=True, + ) + button_load_config = gr.Button('Load 💟', elem_id='open_folder') + config_file_name.change( + remove_doublequote, + inputs=[config_file_name], + outputs=[config_file_name], + ) + return ( + button_open_config, + button_save_config, + button_save_as_config, + config_file_name, + button_load_config, + ) + + +def get_pretrained_model_name_or_path_file( + model_list, pretrained_model_name_or_path +): + pretrained_model_name_or_path = get_any_file_path( + pretrained_model_name_or_path + ) + set_model_list(model_list, pretrained_model_name_or_path) + + +def gradio_source_model(save_model_as_choices = [ + 'same as source model', + 'ckpt', + 'diffusers', + 'diffusers_safetensors', + 'safetensors', + ]): + with gr.Tab('Source model'): + # Define the input elements + with gr.Row(): + pretrained_model_name_or_path = gr.Textbox( + label='Pretrained model name or path', + placeholder='enter the path to custom model or name of pretrained model', + value='runwayml/stable-diffusion-v1-5', + ) + pretrained_model_name_or_path_file = gr.Button( + document_symbol, elem_id='open_folder_small' + ) + pretrained_model_name_or_path_file.click( + get_any_file_path, + inputs=pretrained_model_name_or_path, + outputs=pretrained_model_name_or_path, + show_progress=False, + ) + pretrained_model_name_or_path_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + pretrained_model_name_or_path_folder.click( + get_folder_path, + inputs=pretrained_model_name_or_path, + outputs=pretrained_model_name_or_path, + show_progress=False, + ) + model_list = gr.Dropdown( + label='Model Quick Pick', + choices=[ + 'custom', + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', + ], + value='runwayml/stable-diffusion-v1-5', + ) + save_model_as = gr.Dropdown( + label='Save trained model as', + choices=save_model_as_choices, + value='safetensors', + ) + + with gr.Row(): + v2 = gr.Checkbox(label='v2', value=False) + v_parameterization = gr.Checkbox( + label='v_parameterization', value=False + ) + v2.change( + set_v2_checkbox, + inputs=[model_list, v2, v_parameterization], + outputs=[v2, v_parameterization], + show_progress=False, + ) + v_parameterization.change( + set_v2_checkbox, + inputs=[model_list, v2, v_parameterization], + outputs=[v2, v_parameterization], + show_progress=False, + ) + model_list.change( + set_pretrained_model_name_or_path_input, + inputs=[ + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, + ], + outputs=[ + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, + ], + show_progress=False, + ) + # Update the model list and parameters when user click outside the button or field + pretrained_model_name_or_path.change( + set_model_list, + inputs=[ + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, + ], + outputs=[ + model_list, + v2, + v_parameterization, + ], + show_progress=False, + ) + return ( + pretrained_model_name_or_path, + v2, + v_parameterization, + save_model_as, + model_list, + ) + + +def gradio_training( + learning_rate_value='1e-6', + 
lr_scheduler_value='constant', + lr_warmup_value='0', +): + with gr.Row(): + train_batch_size = gr.Slider( + minimum=1, + maximum=64, + label='Train batch size', + value=1, + step=1, + ) + epoch = gr.Number(label='Epoch', value=1, precision=0) + save_every_n_epochs = gr.Number( + label='Save every N epochs', value=1, precision=0 + ) + caption_extension = gr.Textbox( + label='Caption Extension', + placeholder='(Optional) Extension for caption files. default: .caption', + ) + with gr.Row(): + mixed_precision = gr.Dropdown( + label='Mixed precision', + choices=[ + 'no', + 'fp16', + 'bf16', + ], + value='fp16', + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=[ + 'float', + 'fp16', + 'bf16', + ], + value='fp16', + ) + num_cpu_threads_per_process = gr.Slider( + minimum=1, + maximum=os.cpu_count(), + step=1, + label='Number of CPU threads per core', + value=2, + ) + seed = gr.Textbox(label='Seed', placeholder='(Optional) eg:1234') + cache_latents = gr.Checkbox(label='Cache latent', value=True) + with gr.Row(): + learning_rate = gr.Textbox( + label='Learning rate', value=learning_rate_value + ) + lr_scheduler = gr.Dropdown( + label='LR Scheduler', + choices=[ + 'adafactor', + 'constant', + 'constant_with_warmup', + 'cosine', + 'cosine_with_restarts', + 'linear', + 'polynomial', + ], + value=lr_scheduler_value, + ) + lr_warmup = gr.Textbox( + label='LR warmup (% of steps)', value=lr_warmup_value + ) + optimizer = gr.Dropdown( + label='Optimizer', + choices=[ + 'AdamW', + 'AdamW8bit', + 'Adafactor', + 'DAdaptation', + 'Lion', + 'SGDNesterov', + 'SGDNesterov8bit', + ], + value='AdamW8bit', + interactive=True, + ) + with gr.Row(): + optimizer_args = gr.Textbox( + label='Optimizer extra arguments', + placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True', + ) + return ( + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + num_cpu_threads_per_process, + seed, + caption_extension, + cache_latents, + optimizer, + optimizer_args, + ) + + +def run_cmd_training(**kwargs): + options = [ + f' --learning_rate="{kwargs.get("learning_rate", "")}"' + if kwargs.get('learning_rate') + else '', + f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"' + if kwargs.get('lr_scheduler') + else '', + f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"' + if kwargs.get('lr_warmup_steps') + else '', + f' --train_batch_size="{kwargs.get("train_batch_size", "")}"' + if kwargs.get('train_batch_size') + else '', + f' --max_train_steps="{kwargs.get("max_train_steps", "")}"' + if kwargs.get('max_train_steps') + else '', + f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"' + if int(kwargs.get('save_every_n_epochs')) + else '', + f' --mixed_precision="{kwargs.get("mixed_precision", "")}"' + if kwargs.get('mixed_precision') + else '', + f' --save_precision="{kwargs.get("save_precision", "")}"' + if kwargs.get('save_precision') + else '', + f' --seed="{kwargs.get("seed", "")}"' + if kwargs.get('seed') != '' + else '', + f' --caption_extension="{kwargs.get("caption_extension", "")}"' + if kwargs.get('caption_extension') + else '', + ' --cache_latents' if kwargs.get('cache_latents') else '', + # ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '', + f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"', + f' --optimizer_args {kwargs.get("optimizer_args", "")}' + if not kwargs.get('optimizer_args') == '' + else '', + ] + run_cmd = ''.join(options) + 
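+    # run_cmd is one string of CLI flags; most options are emitted only when the
+    # corresponding kwarg is truthy, so the fragment can be appended directly to the
+    # training command. Illustrative result (placeholder values):
+    #   --learning_rate="1e-5" --lr_scheduler="cosine" --train_batch_size="1"
+    #   --mixed_precision="fp16" --cache_latents --optimizer_type="AdamW8bit"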
return run_cmd + + +def gradio_advanced_training(): + with gr.Row(): + additional_parameters = gr.Textbox( + label='Additional parameters', + placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"', + ) + with gr.Row(): + keep_tokens = gr.Slider( + label='Keep n tokens', value='0', minimum=0, maximum=32, step=1 + ) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) + max_token_length = gr.Dropdown( + label='Max Token Length', + choices=[ + '75', + '150', + '225', + ], + value='75', + ) + full_fp16 = gr.Checkbox( + label='Full fp16 training (experimental)', value=False + ) + with gr.Row(): + gradient_checkpointing = gr.Checkbox( + label='Gradient checkpointing', value=False + ) + shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False) + persistent_data_loader_workers = gr.Checkbox( + label='Persistent data loader', value=False + ) + mem_eff_attn = gr.Checkbox( + label='Memory efficient attention', value=False + ) + with gr.Row(): + # This use_8bit_adam element should be removed in a future release as it is no longer used + # use_8bit_adam = gr.Checkbox( + # label='Use 8bit adam', value=False, visible=False + # ) + xformers = gr.Checkbox(label='Use xformers', value=True) + color_aug = gr.Checkbox(label='Color augmentation', value=False) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) + with gr.Row(): + bucket_no_upscale = gr.Checkbox( + label="Don't upscale bucket resolution", value=True + ) + bucket_reso_steps = gr.Number( + label='Bucket resolution steps', value=64 + ) + random_crop = gr.Checkbox( + label='Random crop instead of center crop', value=False + ) + noise_offset = gr.Textbox( + label='Noise offset (0 - 1)', placeholder='(Oprional) eg: 0.1' + ) + + with gr.Row(): + caption_dropout_every_n_epochs = gr.Number( + label='Dropout caption every n epochs', value=0 + ) + caption_dropout_rate = gr.Slider( + label='Rate of caption dropout', value=0, minimum=0, maximum=1 + ) + vae_batch_size = gr.Slider( + label='VAE batch size', + minimum=0, + maximum=32, + value=0, + step=1 + ) + with gr.Row(): + save_state = gr.Checkbox(label='Save training state', value=False) + resume = gr.Textbox( + label='Resume from saved training state', + placeholder='path to "last-state" state folder to resume from', + ) + resume_button = gr.Button('📂', elem_id='open_folder_small') + resume_button.click( + get_folder_path, + outputs=resume, + show_progress=False, + ) + max_train_epochs = gr.Textbox( + label='Max train epoch', + placeholder='(Optional) Override number of epoch', + ) + max_data_loader_n_workers = gr.Textbox( + label='Max num workers for DataLoader', + placeholder='(Optional) Override number of epoch. 
Default: 8', + value="0", + ) + return ( + # use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + noise_offset, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ) + + +def run_cmd_advanced_training(**kwargs): + options = [ + f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"' + if kwargs.get('max_train_epochs') + else '', + f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"' + if kwargs.get('max_data_loader_n_workers') + else '', + f' --max_token_length={kwargs.get("max_token_length", "")}' + if int(kwargs.get('max_token_length', 75)) > 75 + else '', + f' --clip_skip={kwargs.get("clip_skip", "")}' + if int(kwargs.get('clip_skip', 1)) > 1 + else '', + f' --resume="{kwargs.get("resume", "")}"' + if kwargs.get('resume') + else '', + f' --keep_tokens="{kwargs.get("keep_tokens", "")}"' + if int(kwargs.get('keep_tokens', 0)) > 0 + else '', + f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"' + if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 + else '', + f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"' + if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 + else '', + f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"' + if int(kwargs.get('vae_batch_size', 0)) > 0 + else '', + f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' + if int(kwargs.get('bucket_reso_steps', 64)) >= 1 + else '', + f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}' + if int(kwargs.get('min_snr_gamma', 0)) >= 1 + else '', + ' --save_state' if kwargs.get('save_state') else '', + ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', + ' --color_aug' if kwargs.get('color_aug') else '', + ' --flip_aug' if kwargs.get('flip_aug') else '', + ' --shuffle_caption' if kwargs.get('shuffle_caption') else '', + ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') + else '', + ' --full_fp16' if kwargs.get('full_fp16') else '', + ' --xformers' if kwargs.get('xformers') else '', + # ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '', + ' --persistent_data_loader_workers' + if kwargs.get('persistent_data_loader_workers') + else '', + ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '', + ' --random_crop' if kwargs.get('random_crop') else '', + f' --noise_offset={float(kwargs.get("noise_offset", 0))}' + if not kwargs.get('noise_offset', '') == '' + else '', + f' {kwargs.get("additional_parameters", "")}', + ] + run_cmd = ''.join(options) + return run_cmd diff --git a/library/config_util.py b/library/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..97bbb4a8dc5b7e15c1ac305f56dcfafc3c9b6366 --- /dev/null +++ b/library/config_util.py @@ -0,0 +1,536 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +import functools +import random +from textwrap import dedent, indent +import json +from pathlib import Path +# from toolz import curry +from typing import ( + List, + Optional, + Sequence, + Tuple, + Union, +) + +import toml +import voluptuous +from voluptuous import ( + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + 
Schema, +) +from transformers import CLIPTokenizer + +from . import train_util +from .train_util import ( + DreamBoothSubset, + FineTuningSubset, + DreamBoothDataset, + FineTuningDataset, + DatasetGroup, +) + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳现な蚭定甚の蚭定ファむル") + +# TODO: inherit Params class in Subset, Dataset + +@dataclass +class BaseSubsetParams: + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + keep_tokens: int = 0 + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + +@dataclass +class DreamBoothSubsetParams(BaseSubsetParams): + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + +@dataclass +class FineTuningSubsetParams(BaseSubsetParams): + metadata_file: Optional[str] = None + +@dataclass +class BaseDatasetParams: + tokenizer: CLIPTokenizer = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + debug_dataset: bool = False + +@dataclass +class DreamBoothDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + +@dataclass +class FineTuningDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + +@dataclass +class SubsetBlueprint: + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + +@dataclass +class DatasetBlueprint: + is_dreambooth: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "token_warmup_min": int, + "token_warmup_step": Any(float,int), + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning 
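+    # For orientation, a minimal user config that the combined schemas are meant to
+    # accept looks roughly like the following TOML (illustrative sketch; the paths and
+    # token strings are placeholders, not recommendations):
+    #
+    #   [general]
+    #   enable_bucket = true
+    #
+    #   [[datasets]]
+    #   resolution = 512
+    #   batch_size = 2
+    #
+    #   [[datasets.subsets]]
+    #   image_dir = "/path/to/train/images"   # DreamBooth style: no metadata_file
+    #   class_tokens = "sks character"
+    #   num_repeats = 10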
+ FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モヌドか fine tuning モヌドのどちらも指定されおいたせん。1぀以䞊指定しおください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + if all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットずfine tuninのサブセットを同䞀のデヌタセットに混圚させるこずはできたせん。別々のデヌタセットに分割しおください。") + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + else: + self.dataset_schema = self.ft_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema({ + "general": self.general_schema, + "datasets": [self.dataset_schema], + }) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: ゚ラヌ発生時のメッセヌゞをわかりやすくする + print("Invalid user config / ナヌザ蚭定の圢匏が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラむンのパヌス結果が正しくないようです。プログラムのバグの可胜性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { + } + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + if is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = 
self.generate_params_by_fallbacks(subset_params_klass, + [subset_config, dataset_config, general_config, argparse_config, runtime_params]) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks(dataset_params_klass, + [dataset_config, general_config, argparse_config, runtime_params]) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value = None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): + datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + else: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed 
= random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return DatasetGroup(datasets) + + +def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無芖したす: {dir}") + return 0, "" + caption_by_folder = '_'.join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) + + return subsets_config + + +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファむルが芋぀かりたせん: {file}") + + if file.name.lower().endswith('.json'): + try: + config = json.load(file) + except Exception: + print(f"Error on parsing JSON config file. Please check the format. / JSON 圢匏の蚭定ファむルの読み蟌みに倱敗したした。文法が正しいか確認しおください。: {file}") + raise + elif file.name.lower().endswith('.toml'): + try: + config = toml.load(file) + except Exception: + print(f"Error on parsing TOML config file. Please check the format. 
/ TOML 圢匏の蚭定ファむルの読み蟌みに倱敗したした。文法が正しいか確認しおください。: {file}") + raise + else: + raise ValueError(f"not supported config file format / 察応しおいない蚭定ファむルの圢匏です: {file}") + + return config + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + + print("[argparse_namespace]") + print(vars(argparse_namespace)) + + user_config = load_user_config(config_args.dataset_config) + + print("\n[user_config]") + print(user_config) + + sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + print("\n[sanitized_user_config]") + print(sanitized_user_config) + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + print("\n[blueprint]") + print(blueprint) diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa39b87dfc9976d5ffd5e5f05bb4eda80a9ebc1 --- /dev/null +++ b/library/convert_model_gui.py @@ -0,0 +1,247 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +import shutil +from .common_gui import get_folder_path, get_file_path + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def convert_model( + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, +): + # Check for caption_text_input + if source_model_type == '': + msgbox('Invalid source model type') + return + + # Check if source model exist + if os.path.isfile(source_model_input): + print('The provided source model is a file') + elif os.path.isdir(source_model_input): + print('The provided model is a folder') + else: + msgbox('The provided source model is neither a file nor a folder') + return + + # Check if source model exist + if os.path.isdir(target_model_folder_input): + print('The provided model folder exist') + else: + msgbox('The provided target folder does not exist') + return + + run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"' + + v1_models = [ + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', + ] + + # check if v1 models + if str(source_model_type) in v1_models: + print('SD v1 model specified. Setting --v1 parameter') + run_cmd += ' --v1' + else: + print('SD v2 model specified. 
Setting --v2 parameter') + run_cmd += ' --v2' + + if not target_save_precision_type == 'unspecified': + run_cmd += f' --{target_save_precision_type}' + + if ( + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): + run_cmd += f' --reference_model="{source_model_type}"' + + if target_model_type == 'diffuser_safetensors': + run_cmd += ' --use_safetensors' + + run_cmd += f' "{source_model_input}"' + + if ( + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): + target_model_path = os.path.join( + target_model_folder_input, target_model_name_input + ) + run_cmd += f' "{target_model_path}"' + else: + target_model_path = os.path.join( + target_model_folder_input, + f'{target_model_name_input}.{target_model_type}', + ) + run_cmd += f' "{target_model_path}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + if ( + not target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): + + v2_models = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + ] + v_parameterization = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + ] + + if str(source_model_type) in v2_models: + inference_file = os.path.join( + target_model_folder_input, f'{target_model_name_input}.yaml' + ) + print(f'Saving v2-inference.yaml as {inference_file}') + shutil.copy( + f'./v2_inference/v2-inference.yaml', + f'{inference_file}', + ) + + if str(source_model_type) in v_parameterization: + inference_file = os.path.join( + target_model_folder_input, f'{target_model_name_input}.yaml' + ) + print(f'Saving v2-inference-v.yaml as {inference_file}') + shutil.copy( + f'./v2_inference/v2-inference-v.yaml', + f'{inference_file}', + ) + + +# parser = argparse.ArgumentParser() +# parser.add_argument("--v1", action='store_true', +# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み蟌む') +# parser.add_argument("--v2", action='store_true', +# help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み蟌む') +# parser.add_argument("--fp16", action='store_true', +# help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16圢匏で読み蟌みDiffusers圢匏のみ察応、保存するcheckpointのみ察応') +# parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16圢匏で保存するcheckpointのみ察応') +# parser.add_argument("--float", action='store_true', +# help='save as float (checkpoint only) / float(float32)圢匏で保存するcheckpointのみ察応') +# parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに蚘録するepoch数の倀') +# parser.add_argument("--global_step", type=int, default=0, +# help='global_step to write to checkpoint / checkpointに蚘録するglobal_stepの倀') +# parser.add_argument("--reference_model", type=str, default=None, +# help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピヌ元のDiffusersモデル、Diffusers圢匏で保存するずきに必芁") + +# parser.add_argument("model_to_load", type=str, default=None, +# help="model to load: checkpoint file or Diffusers model's directory / 読み蟌むモデル、checkpointかDiffusers圢匏モデルのディレクトリ") +# parser.add_argument("model_to_save", type=str, default=None, +# help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 倉換埌のモデル、拡匵子がある堎合はcheckpoint、ない堎合はDiffusesモデルずしお保存") + + +### +# Gradio UI +### + + +def 
gradio_convert_model_tab(): + with gr.Tab('Convert model'): + gr.Markdown( + 'This utility can be used to convert from one stable diffusion model format to another.' + ) + with gr.Row(): + source_model_input = gr.Textbox( + label='Source model', + placeholder='path to source model folder of file to convert...', + interactive=True, + ) + button_source_model_dir = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_source_model_dir.click( + get_folder_path, + outputs=source_model_input, + show_progress=False, + ) + + button_source_model_file = gr.Button( + document_symbol, elem_id='open_folder_small' + ) + button_source_model_file.click( + get_file_path, + inputs=[source_model_input], + outputs=source_model_input, + show_progress=False, + ) + + source_model_type = gr.Dropdown( + label='Source model type', + choices=[ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', + ], + ) + with gr.Row(): + target_model_folder_input = gr.Textbox( + label='Target model folder', + placeholder='path to target model folder of file name to create...', + interactive=True, + ) + button_target_model_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_target_model_folder.click( + get_folder_path, + outputs=target_model_folder_input, + show_progress=False, + ) + + target_model_name_input = gr.Textbox( + label='Target model name', + placeholder='target model name...', + interactive=True, + ) + target_model_type = gr.Dropdown( + label='Target model type', + choices=[ + 'diffuser', + 'diffuser_safetensors', + 'ckpt', + 'safetensors', + ], + ) + target_save_precision_type = gr.Dropdown( + label='Target model precision', + choices=['unspecified', 'fp16', 'bf16', 'float'], + value='unspecified', + ) + + convert_button = gr.Button('Convert model') + + convert_button.click( + convert_model, + inputs=[ + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, + ], + show_progress=False, + ) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..4d844be3fb706db1287ef858028142e88241bf5a --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,18 @@ +import torch +import argparse + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + snr = torch.stack([all_snr[t] for t in timesteps]) + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight + return loss + +def add_custom_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. 
/ 䜎いタむムステップでの高いlossに察しお重みを枛らすためのgamma倀、䜎いほど効果が匷く、論文では5が掚奚") diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6bc984c3a86ebde7464321456da57d402e5319 --- /dev/null +++ b/library/dataset_balancing_gui.py @@ -0,0 +1,146 @@ +import os +import re +import gradio as gr +from easygui import msgbox, boolbox +from .common_gui import get_folder_path + +# def select_folder(): +# # Open a file dialog to select a directory +# folder = filedialog.askdirectory() + +# # Update the GUI to display the selected folder +# selected_folder_label.config(text=folder) + + +def dataset_balancing(concept_repeats, folder, insecure): + + if not concept_repeats > 0: + # Display an error message if the total number of repeats is not a valid integer + msgbox('Please enter a valid integer for the total number of repeats.') + return + + concept_repeats = int(concept_repeats) + + # Check if folder exist + if folder == '' or not os.path.isdir(folder): + msgbox('Please enter a valid folder for balancing.') + return + + pattern = re.compile(r'^\d+_.+$') + + # Iterate over the subdirectories in the selected folder + for subdir in os.listdir(folder): + if pattern.match(subdir) or insecure: + # Calculate the number of repeats for the current subdirectory + # Get a list of all the files in the folder + files = os.listdir(os.path.join(folder, subdir)) + + # Filter the list to include only image files + image_files = [ + f + for f in files + if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp')) + ] + + # Count the number of image files + images = len(image_files) + + # Check if the subdirectory name starts with a number inside braces, + # indicating that the repeats value should be multiplied + match = re.match(r'^\{(\d+\.?\d*)\}', subdir) + if match: + # Multiply the repeats value by the number inside the braces + if not images == 0: + repeats = max( + 1, + round( + concept_repeats / images * float(match.group(1)) + ), + ) + else: + repeats = 0 + subdir = subdir[match.end() :] + else: + if not images == 0: + repeats = max(1, round(concept_repeats / images)) + else: + repeats = 0 + + # Check if the subdirectory name already has a number at the beginning + match = re.match(r'^\d+_', subdir) + if match: + # Replace the existing number with the new number + old_name = os.path.join(folder, subdir) + new_name = os.path.join( + folder, f'{repeats}_{subdir[match.end():]}' + ) + else: + # Add the new number at the beginning of the name + old_name = os.path.join(folder, subdir) + new_name = os.path.join(folder, f'{repeats}_{subdir}') + + os.rename(old_name, new_name) + else: + print( + f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' + ) + + msgbox('Dataset balancing completed...') + + +def warning(insecure): + if insecure: + if boolbox( + f'WARNING!!! You have asked to rename non kohya_ss _ folders...\n\nAre you sure you want to do that?', + choices=('Yes, I like danger', 'No, get me out of here'), + ): + return True + else: + return False + + +def gradio_dataset_balancing_tab(): + with gr.Tab('Dreambooth/LoRA Dataset balancing'): + gr.Markdown( + 'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.' + ) + gr.Markdown( + 'WARNING! 
The use of this utility on the wrong folder can lead to unexpected folder renaming!!!' + ) + with gr.Row(): + select_dataset_folder_input = gr.Textbox( + label='Dataset folder', + placeholder='Folder containing the concepts folders to balance...', + interactive=True, + ) + + select_dataset_folder_button = gr.Button( + '📂', elem_id='open_folder_small' + ) + select_dataset_folder_button.click( + get_folder_path, + outputs=select_dataset_folder_input, + show_progress=False, + ) + + total_repeats_number = gr.Number( + value=1000, + interactive=True, + label='Training steps per concept per epoch', + ) + with gr.Accordion('Advanced options', open=False): + insecure = gr.Checkbox( + value=False, + label='DANGER!!! -- Insecure folder renaming -- DANGER!!!', + ) + insecure.change(warning, inputs=insecure, outputs=insecure) + balance_button = gr.Button('Balance dataset') + balance_button.click( + dataset_balancing, + inputs=[ + total_repeats_number, + select_dataset_folder_input, + insecure, + ], + show_progress=False, + ) diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d5ff49abb0a2c8b3db7fa5063bf47f8c93e16b --- /dev/null +++ b/library/dreambooth_folder_creation_gui.py @@ -0,0 +1,210 @@ +import gradio as gr +from easygui import diropenbox, msgbox +from .common_gui import get_folder_path +import shutil +import os + + +def copy_info_to_Folders_tab(training_folder): + img_folder = os.path.join(training_folder, 'img') + if os.path.exists(os.path.join(training_folder, 'reg')): + reg_folder = os.path.join(training_folder, 'reg') + else: + reg_folder = '' + model_folder = os.path.join(training_folder, 'model') + log_folder = os.path.join(training_folder, 'log') + + return img_folder, reg_folder, model_folder, log_folder + + +def dreambooth_folder_preparation( + util_training_images_dir_input, + util_training_images_repeat_input, + util_instance_prompt_input, + util_regularization_images_dir_input, + util_regularization_images_repeat_input, + util_class_prompt_input, + util_training_dir_output, +): + + # Check if the input variables are empty + if not len(util_training_dir_output): + print( + "Destination training directory is missing... can't perform the required task..." + ) + return + else: + # Create the util_training_dir_output directory if it doesn't exist + os.makedirs(util_training_dir_output, exist_ok=True) + + # Check for instance prompt + if util_instance_prompt_input == '': + msgbox('Instance prompt missing...') + return + + # Check for class prompt + if util_class_prompt_input == '': + msgbox('Class prompt missing...') + return + + # Create the training_dir path + if util_training_images_dir_input == '': + print( + "Training images directory is missing... can't perform the required task..." 
+ ) + return + else: + training_dir = os.path.join( + util_training_dir_output, + f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}', + ) + + # Remove folders if they exist + if os.path.exists(training_dir): + print(f'Removing existing directory {training_dir}...') + shutil.rmtree(training_dir) + + # Copy the training images to their respective directories + print(f'Copy {util_training_images_dir_input} to {training_dir}...') + shutil.copytree(util_training_images_dir_input, training_dir) + + if not util_regularization_images_dir_input == '': + # Create the regularization_dir path + if not util_regularization_images_repeat_input > 0: + print('Repeats is missing... not copying regularisation images...') + else: + regularization_dir = os.path.join( + util_training_dir_output, + f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}', + ) + + # Remove folders if they exist + if os.path.exists(regularization_dir): + print(f'Removing existing directory {regularization_dir}...') + shutil.rmtree(regularization_dir) + + # Copy the regularisation images to their respective directories + print( + f'Copy {util_regularization_images_dir_input} to {regularization_dir}...' + ) + shutil.copytree( + util_regularization_images_dir_input, regularization_dir + ) + else: + print( + 'Regularization images directory is missing... not copying regularisation images...' + ) + + # create log and model folder + # Check if the log folder exists and create it if it doesn't + if not os.path.exists(os.path.join(util_training_dir_output, 'log')): + os.makedirs(os.path.join(util_training_dir_output, 'log')) + + # Check if the model folder exists and create it if it doesn't + if not os.path.exists(os.path.join(util_training_dir_output, 'model')): + os.makedirs(os.path.join(util_training_dir_output, 'model')) + + print( + f'Done creating kohya_ss training folder structure at {util_training_dir_output}...' + ) + + +def gradio_dreambooth_folder_creation_tab( + train_data_dir_input=gr.Textbox(), + reg_data_dir_input=gr.Textbox(), + output_dir_input=gr.Textbox(), + logging_dir_input=gr.Textbox(), +): + with gr.Tab('Dreambooth/LoRA Folder preparation'): + gr.Markdown( + 'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly.' 
+ ) + with gr.Row(): + util_instance_prompt_input = gr.Textbox( + label='Instance prompt', + placeholder='Eg: asd', + interactive=True, + ) + util_class_prompt_input = gr.Textbox( + label='Class prompt', + placeholder='Eg: person', + interactive=True, + ) + with gr.Row(): + util_training_images_dir_input = gr.Textbox( + label='Training images', + placeholder='Directory containing the training images', + interactive=True, + ) + button_util_training_images_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_util_training_images_dir_input.click( + get_folder_path, + outputs=util_training_images_dir_input, + show_progress=False, + ) + util_training_images_repeat_input = gr.Number( + label='Repeats', + value=40, + interactive=True, + elem_id='number_input', + ) + with gr.Row(): + util_regularization_images_dir_input = gr.Textbox( + label='Regularisation images', + placeholder='(Optional) Directory containing the regularisation images', + interactive=True, + ) + button_util_regularization_images_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_util_regularization_images_dir_input.click( + get_folder_path, + outputs=util_regularization_images_dir_input, + show_progress=False, + ) + util_regularization_images_repeat_input = gr.Number( + label='Repeats', + value=1, + interactive=True, + elem_id='number_input', + ) + with gr.Row(): + util_training_dir_output = gr.Textbox( + label='Destination training directory', + placeholder='Directory where formatted training and regularisation folders will be placed', + interactive=True, + ) + button_util_training_dir_output = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_util_training_dir_output.click( + get_folder_path, outputs=util_training_dir_output + ) + button_prepare_training_data = gr.Button('Prepare training data') + button_prepare_training_data.click( + dreambooth_folder_preparation, + inputs=[ + util_training_images_dir_input, + util_training_images_repeat_input, + util_instance_prompt_input, + util_regularization_images_dir_input, + util_regularization_images_repeat_input, + util_class_prompt_input, + util_training_dir_output, + ], + show_progress=False, + ) + button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab') + button_copy_info_to_Folders_tab.click( + copy_info_to_Folders_tab, + inputs=[util_training_dir_output], + outputs=[ + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ], + show_progress=False, + ) diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..53292d3542d88fd1e20af05c6117f8103caf0177 --- /dev/null +++ b/library/extract_lora_gui.py @@ -0,0 +1,178 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def extract_lora( + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + device, +): + # Check for caption_text_input + if model_tuned == '': + msgbox('Invalid finetuned model file') + return + + if model_org == '': + msgbox('Invalid base model file') + return + + # Check if source model exist + if not os.path.isfile(model_tuned): + msgbox('The provided finetuned model is not 
a file') + return + + if not os.path.isfile(model_org): + msgbox('The provided base model is not a file') + return + + run_cmd = ( + f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"' + ) + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --model_org "{model_org}"' + run_cmd += f' --model_tuned "{model_tuned}"' + run_cmd += f' --dim {dim}' + run_cmd += f' --device {device}' + if conv_dim > 0: + run_cmd += f' --conv_dim {conv_dim}' + if v2: + run_cmd += f' --v2' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_extract_lora_tab(): + with gr.Tab('Extract LoRA'): + gr.Markdown( + 'This utility can extract a LoRA network from a finetuned model.' + ) + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) + model_ext_name = gr.Textbox(value='Model types', visible=False) + + with gr.Row(): + model_tuned = gr.Textbox( + label='Finetuned model', + placeholder='Path to the finetuned model to extract', + interactive=True, + ) + button_model_tuned_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_model_tuned_file.click( + get_file_path, + inputs=[model_tuned, model_ext, model_ext_name], + outputs=model_tuned, + show_progress=False, + ) + + model_org = gr.Textbox( + label='Stable Diffusion base model', + placeholder='Stable Diffusion original model: ckpt or safetensors file', + interactive=True, + ) + button_model_org_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_model_org_file.click( + get_file_path, + inputs=[model_org, model_ext, model_ext_name], + outputs=model_org, + show_progress=False, + ) + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path where to save the extracted LoRA model...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + with gr.Row(): + dim = gr.Slider( + minimum=4, + maximum=1024, + label='Network Dimension (Rank)', + value=128, + step=1, + interactive=True, + ) + conv_dim = gr.Slider( + minimum=0, + maximum=1024, + label='Conv Dimension (Rank)', + value=128, + step=1, + interactive=True, + ) + v2 = gr.Checkbox(label='v2', value=False, interactive=True) + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + 'cuda', + ], + value='cuda', + interactive=True, + ) + + extract_button = gr.Button('Extract LoRA model') + + extract_button.click( + extract_lora, + inputs=[ + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + device + ], + show_progress=False, + ) diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..13575bbcb84ec066b0d1813186229c6a8d6a7792 --- /dev/null +++ b/library/extract_lycoris_locon_gui.py @@ -0,0 +1,309 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + 
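
Before the module-level constants below, a note on the pattern shared by this file and the other extraction/conversion helpers above (extract_lora_gui.py, convert_model_gui.py): each GUI callback validates its inputs, assembles a command line for a script under tools/ or networks/ by string concatenation, prints it, and then runs it with os.system on POSIX or subprocess.run elsewhere. The following is a rough, self-contained sketch of that pattern, not code from the patch; the helper name run_script and the example arguments are hypothetical, and an argument list is used instead of a single quoted string.

import os
import subprocess

# Same interpreter selection as the PYTHON constant defined just below.
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'

def run_script(script_path, *args):
    # Build the command as a list so paths containing spaces need no manual quoting.
    cmd = [PYTHON, script_path] + [str(a) for a in args]
    print(' '.join(cmd))
    subprocess.run(cmd)

# Example with hypothetical values, following the argument order used in this file:
# run_script(os.path.join('tools', 'lycoris_locon_extract.py'),
#            '--mode', 'fixed', '--linear_dim', 64,
#            'base_model.safetensors', 'db_model.safetensors', 'output.safetensors')
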
+folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def extract_lycoris_locon( + db_model, + base_model, + output_name, + device, + is_v2, + mode, + linear_dim, + conv_dim, + linear_threshold, + conv_threshold, + linear_ratio, + conv_ratio, + linear_quantile, + conv_quantile, + use_sparse_bias, + sparsity, + disable_cp, +): + # Check for caption_text_input + if db_model == '': + msgbox('Invalid finetuned model file') + return + + if base_model == '': + msgbox('Invalid base model file') + return + + # Check if source model exist + if not os.path.isfile(db_model): + msgbox('The provided finetuned model is not a file') + return + + if not os.path.isfile(base_model): + msgbox('The provided base model is not a file') + return + + run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' + if is_v2: + run_cmd += f' --is_v2' + run_cmd += f' --device {device}' + run_cmd += f' --mode {mode}' + run_cmd += f' --safetensors' + run_cmd += f' --linear_dim {linear_dim}' + run_cmd += f' --conv_dim {conv_dim}' + run_cmd += f' --linear_threshold {linear_threshold}' + run_cmd += f' --conv_threshold {conv_threshold}' + run_cmd += f' --linear_ratio {linear_ratio}' + run_cmd += f' --conv_ratio {conv_ratio}' + run_cmd += f' --linear_quantile {linear_quantile}' + run_cmd += f' --conv_quantile {conv_quantile}' + if use_sparse_bias: + run_cmd += f' --use_sparse_bias' + run_cmd += f' --sparsity {sparsity}' + if disable_cp: + run_cmd += f' --disable_cp' + run_cmd += f' "{base_model}"' + run_cmd += f' "{db_model}"' + run_cmd += f' "{output_name}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### +# def update_mode(mode): +# # 'fixed', 'threshold','ratio','quantile' +# if mode == 'fixed': +# return gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False) +# if mode == 'threshold': +# return gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False) +# if mode == 'ratio': +# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False) +# if mode == 'threshold': +# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True) + + +def update_mode(mode): + # Create a list of possible mode values + modes = ['fixed', 'threshold', 'ratio', 'quantile'] + + # Initialize an empty list to store visibility updates + updates = [] + + # Iterate through the possible modes + for m in modes: + # Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop + updates.append(gr.Row.update(visible=(mode == m))) + + # Return the visibility updates as a tuple + return tuple(updates) + + +def gradio_extract_lycoris_locon_tab(): + with gr.Tab('Extract LyCORIS LoCON'): + gr.Markdown( + 'This utility can extract a LyCORIS LoCon network from a finetuned model.' 
+ ) + lora_ext = gr.Textbox( + value='*.safetensors', visible=False + ) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) + model_ext_name = gr.Textbox(value='Model types', visible=False) + + with gr.Row(): + db_model = gr.Textbox( + label='Finetuned model', + placeholder='Path to the finetuned model to extract', + interactive=True, + ) + button_db_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_db_model_file.click( + get_file_path, + inputs=[db_model, model_ext, model_ext_name], + outputs=db_model, + show_progress=False, + ) + + base_model = gr.Textbox( + label='Stable Diffusion base model', + placeholder='Stable Diffusion original model: ckpt or safetensors file', + interactive=True, + ) + button_base_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_base_model_file.click( + get_file_path, + inputs=[base_model, model_ext, model_ext_name], + outputs=base_model, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Save to', + placeholder='path where to save the extracted LoRA model...', + interactive=True, + ) + button_output_name = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_output_name.click( + get_saveasfilename_path, + inputs=[output_name, lora_ext, lora_ext_name], + outputs=output_name, + show_progress=False, + ) + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + 'cuda', + ], + value='cuda', + interactive=True, + ) + is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) + mode = gr.Dropdown( + label='Mode', + choices=['fixed', 'threshold', 'ratio', 'quantile'], + value='fixed', + interactive=True, + ) + with gr.Row(visible=True) as fixed: + linear_dim = gr.Slider( + minimum=1, + maximum=1024, + label='Network Dimension', + value=1, + step=1, + interactive=True, + ) + conv_dim = gr.Slider( + minimum=1, + maximum=1024, + label='Conv Dimension', + value=1, + step=1, + interactive=True, + ) + with gr.Row(visible=False) as threshold: + linear_threshold = gr.Slider( + minimum=0, + maximum=1, + label='Linear threshold', + value=0, + step=0.01, + interactive=True, + ) + conv_threshold = gr.Slider( + minimum=0, + maximum=1, + label='Conv threshold', + value=0, + step=0.01, + interactive=True, + ) + with gr.Row(visible=False) as ratio: + linear_ratio = gr.Slider( + minimum=0, + maximum=1, + label='Linear ratio', + value=0, + step=0.01, + interactive=True, + ) + conv_ratio = gr.Slider( + minimum=0, + maximum=1, + label='Conv ratio', + value=0, + step=0.01, + interactive=True, + ) + with gr.Row(visible=False) as quantile: + linear_quantile = gr.Slider( + minimum=0, + maximum=1, + label='Linear quantile', + value=0.75, + step=0.01, + interactive=True, + ) + conv_quantile = gr.Slider( + minimum=0, + maximum=1, + label='Conv quantile', + value=0.75, + step=0.01, + interactive=True, + ) + with gr.Row(): + use_sparse_bias = gr.Checkbox( + label='Use sparse biais', value=False, interactive=True + ) + sparsity = gr.Slider( + minimum=0, + maximum=1, + label='Sparsity', + value=0.98, + step=0.01, + interactive=True, + ) + disable_cp = gr.Checkbox( + label='Disable CP decomposition', value=False, interactive=True + ) + mode.change( + update_mode, + inputs=[mode], + outputs=[ + fixed, + threshold, + ratio, + quantile, + ], + ) + + extract_button = gr.Button('Extract LyCORIS LoCon') + + extract_button.click( + 
extract_lycoris_locon, + inputs=[ + db_model, + base_model, + output_name, + device, + is_v2, + mode, + linear_dim, + conv_dim, + linear_threshold, + conv_threshold, + linear_ratio, + conv_ratio, + linear_quantile, + conv_quantile, + use_sparse_bias, + sparsity, + disable_cp, + ], + show_progress=False, + ) diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..9aaf3d9bee98ce02a4618e6ebcffb16c3f007f53 --- /dev/null +++ b/library/git_caption_gui.py @@ -0,0 +1,136 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_folder_path, add_pre_postfix + +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def caption_images( + train_data_dir, + caption_ext, + batch_size, + max_data_loader_n_workers, + max_length, + model_id, + prefix, + postfix, +): + # Check for images_dir_input + if train_data_dir == '': + msgbox('Image folder is missing...') + return + + if caption_ext == '': + msgbox('Please provide an extension for the caption files.') + return + + print(f'GIT captioning files in {train_data_dir}...') + run_cmd = ( + f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"' + ) + if not model_id == '': + run_cmd += f' --model_id="{model_id}"' + run_cmd += f' --batch_size="{int(batch_size)}"' + run_cmd += ( + f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' + ) + run_cmd += f' --max_length="{int(max_length)}"' + if caption_ext != '': + run_cmd += f' --caption_extension="{caption_ext}"' + run_cmd += f' "{train_data_dir}"' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + + # Add prefix and postfix + add_pre_postfix( + folder=train_data_dir, + caption_file_ext=caption_ext, + prefix=prefix, + postfix=postfix, + ) + + print('...captioning done') + + +### +# Gradio UI +### + + +def gradio_git_caption_gui_tab(): + with gr.Tab('GIT Captioning'): + gr.Markdown( + 'This utility will use GIT to caption files for each images in a folder.' + ) + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder to caption', + placeholder='Directory containing the images to caption', + interactive=True, + ) + button_train_data_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_train_data_dir_input.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + with gr.Row(): + caption_ext = gr.Textbox( + label='Caption file extension', + placeholder='Extention for caption file. 
eg: .caption, .txt', + value='.txt', + interactive=True, + ) + + prefix = gr.Textbox( + label='Prefix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) + + postfix = gr.Textbox( + label='Postfix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) + + batch_size = gr.Number( + value=1, label='Batch size', interactive=True + ) + + with gr.Row(): + max_data_loader_n_workers = gr.Number( + value=2, label='Number of workers', interactive=True + ) + max_length = gr.Number( + value=75, label='Max length', interactive=True + ) + model_id = gr.Textbox( + label='Model', + placeholder='(Optional) model id for GIT in Hugging Face', + interactive=True, + ) + + caption_button = gr.Button('Caption images') + + caption_button.click( + caption_images, + inputs=[ + train_data_dir, + caption_ext, + batch_size, + max_data_loader_n_workers, + max_length, + model_id, + prefix, + postfix, + ], + show_progress=False, + ) diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e04b8876f39086e22368b736d00e047b09f0fe3 --- /dev/null +++ b/library/lpw_stable_diffusion.py @@ -0,0 +1,1179 @@ +# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# and modify to support SD2.x + +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.utils import logging + + +try: + from diffusers.utils import PIL_INTERPOLATION +except ImportError: + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } + else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples")
+    return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+    r"""
+    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+    """
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+    for i in range(len(tokens)):
+        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        if no_boseos_middle:
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+        else:
+            w = []
+            if len(weights[i]) == 0:
+                w = [1.0] * weights_length
+            else:
+                for j in range(max_embeddings_multiples):
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
+            weights[i] = w[:]
+
+    return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+    pipe: StableDiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    clip_skip: int,
+    eos: int,
+    pad: int,
+    no_boseos_middle: Optional[bool] = True,
+):
+    """
+    When the length of tokens exceeds the capacity of the text encoder, the input is split into
+    chunks and each chunk is sent to the text encoder individually.
+    """
+    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+    if max_embeddings_multiples > 1:
+        text_embeddings = []
+        for i in range(max_embeddings_multiples):
+            # extract the i-th chunk
+            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+            # cover the head and the tail by the starting and the ending tokens
+            text_input_chunk[:, 0] = text_input[0, 0]
+            if pad == eos:  # v1
+                text_input_chunk[:, -1] = text_input[0, -1]
+            else:  # v2
+                for j in range(len(text_input_chunk)):
+                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # a normal token sits at the end
+                        text_input_chunk[j, -1] = eos
+                    if text_input_chunk[j, 1] == pad:  # the chunk is only BOS followed by PAD
+                        text_input_chunk[j, 1] = eos
+
+            if clip_skip is None or clip_skip == 1:
+                text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            else:
+                enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
+                text_embedding = enc_out["hidden_states"][-clip_skip]
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
+
+            if no_boseos_middle:
+                if i == 0:
+                    # discard the ending token
+                    text_embedding = text_embedding[:, :-1]
+                elif i == max_embeddings_multiples - 1:
+                    # discard the starting token
+                    text_embedding = text_embedding[:, 1:]
+                else:
+                    # discard both starting and ending tokens
+                    text_embedding = text_embedding[:, 1:-1]
+
+            text_embeddings.append(text_embedding)
+        text_embeddings = torch.concat(text_embeddings, axis=1)
+    else:
+        # apply clip_skip consistently with the chunked branch above
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = pipe.text_encoder(text_input)[0]
+        else:
+            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
+    return text_embeddings
+
+
+def get_weighted_text_embeddings(
+    pipe: StableDiffusionPipeline,
+    prompt: Union[str, List[str]],
+    uncond_prompt: Optional[Union[str, List[str]]] = None,
+    max_embeddings_multiples: Optional[int] = 3,
+    no_boseos_middle: Optional[bool] = False,
+    skip_parsing: Optional[bool] = False,
+    skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+):
+    r"""
+    Prompts can be assigned with local weights using brackets. For example,
+    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+    Args:
+        pipe (`StableDiffusionPipeline`):
+            Pipe to provide access to the tokenizer and the text encoder.
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to guide the image generation.
+        uncond_prompt (`str` or `List[str]`):
+            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+            is provided, the embeddings of prompt and uncond_prompt are concatenated.
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+            The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
+            the starting and ending tokens of each chunk in the middle.
+        skip_parsing (`bool`, *optional*, defaults to `False`):
+            Skip the parsing of brackets.
+        skip_weighting (`bool`, *optional*, defaults to `False`):
+            Skip the weighting. When parsing is skipped, this is forced to `True`.
+    """
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    if not skip_parsing:
+        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+    else:
+        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
+        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens = [
+                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+            ]
+            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+    # round up the longest length of tokens to a multiple of (model_max_length - 2)
+    max_length = max([len(token) for token in prompt_tokens])
+    if uncond_prompt is not None:
+        max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+    max_embeddings_multiples = min(
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+    )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+    # pad the length of tokens and weights
+    bos = pipe.tokenizer.bos_token_id
+    eos = pipe.tokenizer.eos_token_id
+    pad = pipe.tokenizer.pad_token_id
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=pipe.tokenizer.model_max_length,
+    )
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+    if uncond_prompt is not None:
+        uncond_tokens, uncond_weights = pad_tokens_and_weights(
+            uncond_tokens,
+            uncond_weights,
+            max_length,
bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + clip_skip: int, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, + ) + self.clip_skip = clip_skip + self.__init__additional__() + + # else: + # def __init__( + # self, + # vae: AutoencoderKL, + # text_encoder: CLIPTextModel, + # tokenizer: CLIPTokenizer, + # unet: UNet2DConditionModel, + # scheduler: SchedulerMixin, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + # ): + # super().__init__( + # vae=vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + # unet=unet, + # scheduler=scheduler, + # safety_checker=safety_checker, + # feature_extractor=feature_extractor, + # ) + # self.__init__additional__() + + def __init__additional__(self): + if not hasattr(self, "vae_scale_factor"): + setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+            clip_skip=self.clip_skip,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
+ ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents_orig = init_latents + shape = init_latents.shape + + # add noise to latents using the timesteps + if device.type == "mps": + noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + return latents, 
init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ) + dtype = text_embeddings.dtype + + # 4. 
Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) + else: + mask = None + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return image, has_nsfw_concept + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. 
+ negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
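+
+        Example:
+            A minimal, illustrative sketch; it assumes `pipe` is an already loaded instance of this
+            pipeline on the desired device, and the prompt text is arbitrary:
+
+                output = pipe.text2img("a (very beautiful) landscape, ((sunset))", num_inference_steps=30)
+                image = output.images[0]  # `images` holds the generated PIL images when `return_dict=True`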
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def img2img( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for image-to-image generation. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. 
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def inpaint( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for inpaint. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..21cd16ad364204b4e2522838fe7f2bc0d69f3ff7 --- /dev/null +++ b/library/merge_lora_gui.py @@ -0,0 +1,156 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def merge_lora( + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, +): + # Check for caption_text_input + if lora_a_model == '': + msgbox('Invalid model A file') + return + + if lora_b_model == '': + msgbox('Invalid model B file') + return + + # Check if source model exist + if not os.path.isfile(lora_a_model): + msgbox('The provided model A is not a file') + return + + if not os.path.isfile(lora_b_model): + msgbox('The provided model B is not a file') + return + + ratio_a = ratio + ratio_b = 1 - ratio + + run_cmd = f'{PYTHON} "{os.path.join("networks","merge_lora.py")}"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --precision {precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"' + run_cmd += f' --ratios {ratio_a} {ratio_b}' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_merge_lora_tab(): + with gr.Tab('Merge LoRA'): + gr.Markdown('This utility can merge two LoRA networks together.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_a_model = gr.Textbox( + label='LoRA model "A"', + placeholder='Path to the LoRA A model', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], + outputs=lora_a_model, + show_progress=False, + ) + + lora_b_model = gr.Textbox( + label='LoRA model "B"', + placeholder='Path to the LoRA B model', + interactive=True, + ) + button_lora_b_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_b_model_file.click( + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], + outputs=lora_b_model, + show_progress=False, + ) + with gr.Row(): + ratio = gr.Slider( + label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B', + minimum=0, 
+                maximum=1,
+                step=0.01,
+                value=0.5,
+                interactive=True,
+            )
+
+        with gr.Row():
+            save_to = gr.Textbox(
+                label='Save to',
+                placeholder='path for the file to save...',
+                interactive=True,
+            )
+            button_save_to = gr.Button(
+                folder_symbol, elem_id='open_folder_small'
+            )
+            button_save_to.click(
+                get_saveasfilename_path,
+                inputs=[save_to, lora_ext, lora_ext_name],
+                outputs=save_to,
+                show_progress=False,
+            )
+            precision = gr.Dropdown(
+                label='Merge precision',
+                choices=['fp16', 'bf16', 'float'],
+                value='float',
+                interactive=True,
+            )
+            save_precision = gr.Dropdown(
+                label='Save precision',
+                choices=['fp16', 'bf16', 'float'],
+                value='float',
+                interactive=True,
+            )
+
+        convert_button = gr.Button('Merge model')
+
+        convert_button.click(
+            merge_lora,
+            inputs=[
+                lora_a_model,
+                lora_b_model,
+                ratio,
+                save_to,
+                precision,
+                save_precision,
+            ],
+            show_progress=False,
+        )
diff --git a/library/model_util.py b/library/model_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b0b6afe82734e194d533de0bcbb73c704ba65e
--- /dev/null
+++ b/library/model_util.py
@@ -0,0 +1,1165 @@
+# v1: split from train_db_fixed.py.
+# v2: support safetensors
+
+import math
+import os
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+
+# model parameters of the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+# reference models used to load the Diffusers configuration
+DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
+DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
+
+
+# region conversion code: StableDiffusion -> Diffusers
+# copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+    """
+    Removes segments. Positive values shave the first segments, negative shave the last segments.
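+
+    Example (illustrative):
+        shave_segments("input_blocks.3.0.op.weight", 1)   # -> "3.0.op.weight"
+        shave_segments("input_blocks.3.0.op.weight", -1)  # -> "input_blocks.3.0.op"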
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
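+    # (When `attention_paths_to_split` is given, each entry maps a fused q/k/v tensor in the old
+    # checkpoint to separate query/key/value keys; the tensor is reshaped per attention head and
+    # split into three equal parts below.)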
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." 
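+    # every UNet weight in an LDM/SD checkpoint lives under this prefix; it is stripped while copying
+    # so the keys can then be renamed to the diffusers layout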
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + # オリゞナル + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasずweightの順番に䟝存しないようにするもっずいいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + 
f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに倉わっおいるので、linear->convに倉換する + if v2: + linear_transformer_to_conv(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 枡されたcheckpointは.ckptから読み蟌んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + 
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので埌で凊理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 䜿われない??? + elif ".logit_scale" in key: + key = None # 䜿われない??? 
+ elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの倉換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 䞉぀に分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + + +# endregion + + +# region Diffusers->StableDiffusion の倉換コヌド +# convert_diffusers_to_original_stable_diffusion をコピヌしお修正しおいるASL 2.0 + + +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." 
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + + +# endregion + +# region 自䜜のモデル読み曞きなど + + +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの栌玍圢匏が違うモデルに察応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +# TODO dtype指定の動䜜が怪しいので確認する text_encoderを指定圢匏で䜜れるか未確認 +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) + + # Convert the VAE model. 
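# --- Editor's note: illustrative aside, not part of this patch ----------------
# load_checkpoint_with_text_encoder_conversion() above renames text-encoder
# keys that are missing the "text_model." segment. A minimal sketch of that
# remapping on a toy state dict (key and helper name invented for illustration):
def _sketch_remap_text_encoder_keys(state_dict):
    replacements = [
        ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
        ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
        ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
    ]
    remapped = {}
    for key, value in state_dict.items():
        for old, new in replacements:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        remapped[key] = value
    return remapped
# _sketch_remap_text_encoder_keys({"cond_stage_model.transformer.encoder.layers.0.mlp.fc1.weight": 0})
# -> {"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": 0}
# ------------------------------------------------------------------------------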
+ vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() + + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) + + return text_model, vae, unet + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの陀去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." 
in key: + key = None # 特殊なので埌で凊理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの倉換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 䞉぀を結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最埌の局などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないずsafetensorsの保存で萜ちる + + # Diffusersに含たれない重みを䜜っおおく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): + if ckpt_path is not None: + # epoch/stepを参照する。たたVAEがメモリ䞊にないずきなど、もう䞀床VAEを含めお読み蟌む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors たたは state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく䜜る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない堎合は最埌の局を前の局から耇補しお䜜るなどダミヌの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] = 
epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + # TODO Tensor以倖のdictの倀を削陀したほうがいいか + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + + +VAE_PREFIX = "first_stage_model." + + +def load_vae(vae_id, dtype): + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. 
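# --- Editor's note: illustrative aside, not part of this patch ----------------
# Typical round trip through the helpers above: load a .ckpt/.safetensors file
# into Diffusers modules, then write it back in StableDiffusion layout. Paths
# and the helper name are placeholders; torch is already imported by this
# module. Passing ckpt_path lets the saver reuse epoch/global_step and check
# keys strictly; passing ckpt_path=None (with vae=...) builds one from scratch.
def _sketch_roundtrip(src="input.safetensors", dst="output.safetensors", v2=False):
    text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2, src)
    key_count = save_stable_diffusion_checkpoint(
        v2, dst, text_encoder, unet, src, epochs=0, steps=0, save_dtype=torch.float16, vae=vae
    )
    return key_count
# ------------------------------------------------------------------------------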
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +# endregion + + +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) + + resos = set() + + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) + + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) + + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) + + size += divisible + + resos = list(resos) + resos.sort() + return resos + + +if __name__ == "__main__": + resos = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) + + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! duplicate ar:", ar) + ars.add(ar) diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf1b4569f36bf148ff591d63b4888cdea1b77e4 --- /dev/null +++ b/library/resize_lora_gui.py @@ -0,0 +1,173 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_saveasfilename_path, get_file_path + +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + + +def resize_lora( + model, + new_rank, + save_to, + save_precision, + device, + dynamic_method, + dynamic_param, + verbose, +): + # Check for caption_text_input + if model == '': + msgbox('Invalid model file') + return + + # Check if source model exist + if not os.path.isfile(model): + msgbox('The provided model is not a file') + return + + if dynamic_method == 'sv_ratio': + if float(dynamic_param) < 2: + msgbox( + f'Dynamic parameter for {dynamic_method} need to be 2 or greater...' + ) + return + + if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative': + if float(dynamic_param) < 0 or float(dynamic_param) > 1: + msgbox( + f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...' + ) + return + + # Check if save_to end with one of the defines extension. If not add .safetensors. 
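# --- Editor's note: illustrative aside, not part of this patch ----------------
# make_bucket_resolutions() in library/model_util.py above enumerates candidate
# (width, height) buckets whose sides are multiples of `divisible` and whose
# pixel area never exceeds that of `max_reso`. A quick sanity check of those
# two invariants (helper name invented for illustration):
def _sketch_check_buckets(max_reso=(512, 768), divisible=64):
    resos = make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=divisible)
    max_area = (max_reso[0] // divisible) * (max_reso[1] // divisible) * divisible * divisible
    for w, h in resos:
        assert w % divisible == 0 and h % divisible == 0, (w, h)
        assert w * h <= max_area, (w, h)
    return resos
# ------------------------------------------------------------------------------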
+ if not save_to.endswith(('.pt', '.safetensors')): + save_to += '.safetensors' + + if device == '': + device = 'cuda' + + run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --model "{model}"' + run_cmd += f' --new_rank {new_rank}' + run_cmd += f' --device {device}' + if not dynamic_method == 'None': + run_cmd += f' --dynamic_method {dynamic_method}' + run_cmd += f' --dynamic_param {dynamic_param}' + if verbose: + run_cmd += f' --verbose' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_resize_lora_tab(): + with gr.Tab('Resize LoRA'): + gr.Markdown('This utility can resize a LoRA.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + model = gr.Textbox( + label='Source LoRA', + placeholder='Path to the LoRA to resize', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[model, lora_ext, lora_ext_name], + outputs=model, + show_progress=False, + ) + with gr.Row(): + new_rank = gr.Slider( + label='Desired LoRA rank', + minimum=1, + maximum=1024, + step=1, + value=4, + interactive=True, + ) + + with gr.Row(): + dynamic_method = gr.Dropdown( + choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'], + value='sv_fro', + label='Dynamic method', + interactive=True, + ) + dynamic_param = gr.Textbox( + label='Dynamic parameter', + value='0.9', + interactive=True, + placeholder='Value for the dynamic method selected.', + ) + verbose = gr.Checkbox(label='Verbose', value=False) + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path for the LoRA file to save...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='fp16', + interactive=True, + ) + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + 'cuda', + ], + value='cuda', + interactive=True, + ) + + convert_button = gr.Button('Resize model') + + convert_button.click( + resize_lora, + inputs=[ + model, + new_rank, + save_to, + save_precision, + device, + dynamic_method, + dynamic_param, + verbose, + ], + show_progress=False, + ) diff --git a/library/sampler_gui.py b/library/sampler_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..ce953138b4dd49fb61879b971a6b7ac8a87979b6 --- /dev/null +++ b/library/sampler_gui.py @@ -0,0 +1,102 @@ +import tempfile +import os +import gradio as gr +from easygui import msgbox + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + + +### +### Gradio common sampler GUI section +### + + +def sample_gradio_config(): + with gr.Accordion('Sample images config', open=False): + with gr.Row(): + sample_every_n_steps = gr.Number( + label='Sample every n steps', + value=0, + precision=0, + interactive=True, + ) + sample_every_n_epochs = gr.Number( + label='Sample every n epochs', + value=0, + precision=0, + interactive=True, + ) 
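# --- Editor's note: illustrative aside, not part of this patch ----------------
# resize_lora() above assembles a single shell string and runs it with
# os.system on POSIX or subprocess.run on Windows. A minimal sketch of the same
# invocation built as an argument list instead (placeholder values; only flags
# that appear in the command string above are used), which sidesteps quoting
# differences between the two platforms:
def _sketch_run_resize(python_bin, model, save_to, new_rank=4, save_precision="fp16", device="cuda"):
    argv = [
        python_bin, os.path.join("networks", "resize_lora.py"),
        "--save_precision", save_precision,
        "--save_to", save_to,
        "--model", model,
        "--new_rank", str(new_rank),
        "--device", device,
    ]
    subprocess.run(argv, check=True)  # os and subprocess are imported at the top of resize_lora_gui.py
# ------------------------------------------------------------------------------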
+ sample_sampler = gr.Dropdown( + label='Sample sampler', + choices=[ + 'ddim', + 'pndm', + 'lms', + 'euler', + 'euler_a', + 'heun', + 'dpm_2', + 'dpm_2_a', + 'dpmsolver', + 'dpmsolver++', + 'dpmsingle', + 'k_lms', + 'k_euler', + 'k_euler_a', + 'k_dpm_2', + 'k_dpm_2_a', + ], + value='euler_a', + interactive=True, + ) + with gr.Row(): + sample_prompts = gr.Textbox( + lines=5, + label='Sample prompts', + interactive=True, + placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28', + ) + return ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + + +def run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, +): + output_dir = os.path.join(output_dir, 'sample') + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + run_cmd = '' + + if sample_every_n_epochs == 0 and sample_every_n_steps == 0: + return run_cmd + + # Create the prompt file and get its path + sample_prompts_path = os.path.join(output_dir, 'prompt.txt') + + with open(sample_prompts_path, 'w') as f: + f.write(sample_prompts) + + run_cmd += f' --sample_sampler={sample_sampler}' + run_cmd += f' --sample_prompts="{sample_prompts_path}"' + + if not sample_every_n_epochs == 0: + run_cmd += f' --sample_every_n_epochs="{sample_every_n_epochs}"' + + if not sample_every_n_steps == 0: + run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"' + + return run_cmd diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..042be2ecc443c5ecefa58bf44d026bd12700ad90 --- /dev/null +++ b/library/svd_merge_lora_gui.py @@ -0,0 +1,190 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def svd_merge_lora( + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, +): + # Check for caption_text_input + if lora_a_model == '': + msgbox('Invalid model A file') + return + + if lora_b_model == '': + msgbox('Invalid model B file') + return + + # Check if source model exist + if not os.path.isfile(lora_a_model): + msgbox('The provided model A is not a file') + return + + if not os.path.isfile(lora_b_model): + msgbox('The provided model B is not a file') + return + + ratio_a = ratio + ratio_b = 1 - ratio + + run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --precision {precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"' + run_cmd += f' --ratios {ratio_a} {ratio_b}' + run_cmd += f' --device {device}' + run_cmd += f' --new_rank "{new_rank}"' + run_cmd += f' --new_conv_rank "{new_conv_rank}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_svd_merge_lora_tab(): + with gr.Tab('Merge LoRA (SVD)'): + gr.Markdown('This utility can merge 
two LoRA networks together.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_a_model = gr.Textbox( + label='LoRA model "A"', + placeholder='Path to the LoRA A model', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], + outputs=lora_a_model, + show_progress=False, + ) + + lora_b_model = gr.Textbox( + label='LoRA model "B"', + placeholder='Path to the LoRA B model', + interactive=True, + ) + button_lora_b_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_b_model_file.click( + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], + outputs=lora_b_model, + show_progress=False, + ) + with gr.Row(): + ratio = gr.Slider( + label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B', + minimum=0, + maximum=1, + step=0.01, + value=0.5, + interactive=True, + ) + new_rank = gr.Slider( + label='New Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + new_conv_rank = gr.Slider( + label='New Conv Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path for the file to save...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + precision = gr.Dropdown( + label='Merge precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + 'cuda', + ], + value='cuda', + interactive=True, + ) + + convert_button = gr.Button('Merge model') + + convert_button.click( + svd_merge_lora, + inputs=[ + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, + ], + show_progress=False, + ) diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..d08a02d94f22ec41aa58fb4ed7968184dd66bef3 --- /dev/null +++ b/library/tensorboard_gui.py @@ -0,0 +1,53 @@ +import os +import gradio as gr +from easygui import msgbox +import subprocess +import time + +tensorboard_proc = None # I know... bad but heh +TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe' + + +def start_tensorboard(logging_dir): + global tensorboard_proc + + if not os.listdir(logging_dir): + print('Error: log folder is empty') + msgbox(msg='Error: log folder is empty') + return + + run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}'] + + print(run_cmd) + if tensorboard_proc is not None: + print( + 'Tensorboard is already running. Terminating existing process before starting new one...' 
+ ) + stop_tensorboard() + + # Start background process + print('Starting tensorboard...') + tensorboard_proc = subprocess.Popen(run_cmd) + + # Wait for some time to allow TensorBoard to start up + time.sleep(5) + + # Open the TensorBoard URL in the default browser + print('Opening tensorboard url in browser...') + import webbrowser + + webbrowser.open('http://localhost:6006') + + +def stop_tensorboard(): + print('Stopping tensorboard process...') + tensorboard_proc.kill() + print('...process stopped') + + +def gradio_tensorboard(): + with gr.Row(): + button_start_tensorboard = gr.Button('Start tensorboard') + button_stop_tensorboard = gr.Button('Stop tensorboard') + + return (button_start_tensorboard, button_stop_tensorboard) diff --git a/library/train_util.py b/library/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..59dbc44c71a5390148c04743dc1982ae12cb6b1d --- /dev/null +++ b/library/train_util.py @@ -0,0 +1,3100 @@ +# common functions for training + +import argparse +import ast +import importlib +import json +import pathlib +import re +import shutil +import time +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) +from accelerate import Accelerator +import glob +import math +import os +import random +import hashlib +import subprocess +from io import BytesIO +import toml + +from tqdm import tqdm +import torch +from torch.optim import Optimizer +from torchvision import transforms +from transformers import CLIPTokenizer +import transformers +import diffusers +from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers import ( + StableDiffusionPipeline, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, +) +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum +import safetensors.torch +from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +import library.model_util as model_util + +# Tokenizer: checkpointから読み蟌むのではなくあらかじめ提䟛されおいるものを䜿う +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ䜿う v2ずv2.1はtokenizer仕様は同じ + +# checkpointファむル名 +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" + +# region dataset + +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] + + +class ImageInfo: + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketManager: + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + self.no_upscale = 
no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None + else: + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps + + self.resos = [] + self.reso_to_id = {} + self.buckets = [] # 前凊理時は (image_key, image)、孊習時は image_key + + def add_image(self, reso, image): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image) + + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) + + def sort(self): + # 解像床順に゜ヌトする衚瀺時、メタデヌタ栌玍時の芋栄えをよくするためだけ。bucketsも入れ替えおreso_to_idも振り盎す + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + # 芏定サむズから遞ぶ堎合の解像床、aspect ratioの情報を栌玍しおおく + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # print(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + 0.5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 同じaspect ratioがあるかもしれないのでfine tuningで、no_upscale=Trueで前凊理した堎合、解像床が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 圓該解像床以倖でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 暪が長い→瞊を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) + # print("use predef", image_width, image_height, reso, resized_size) + else: + if image_width * image_height > self.max_area: + # 画像が倧きすぎるのでアスペクト比を保ったたた瞮小するこずを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" + + # リサむズ埌の短蟺たたは長蟺をreso_steps単䜍にするaspect ratioの差が少ないほうを遞ぶ + # 元のbucketingず同じロゞック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 
0.5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) + # print(resized_size) + else: + resized_size = (image_width, image_height) # リサむズは䞍芁 + + # 画像のサむズ未満をbucketのサむズずするpaddingせずにcroppingする + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + + reso = (bucket_width, bucket_height) + + self.add_if_new_reso(reso) + + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + bucket_batch_size: int + batch_index: int + + +class AugHelper: + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf( + [ + albu.HueSaturationValue(8, 0, 0, p=0.5), + albu.RandomGamma((95, 105), p=0.5), + ], + p=0.33, + ) + flip_aug_method = albu.HorizontalFlip(p=0.5) + + # key: (use_color_aug, use_flip_aug) + self.augmentors = { + (True, True): albu.Compose( + [ + color_aug_method, + flip_aug_method, + ], + p=1.0, + ), + (True, False): albu.Compose( + [ + color_aug_method, + ], + p=1.0, + ), + (False, True): albu.Compose( + [ + flip_aug_method, + ], + p=1.0, + ), + (False, False): None, + } + + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] + + +class BaseSubset: + def __init__( + self, + image_dir: Optional[str], + num_repeats: int, + shuffle_caption: bool, + keep_tokens: int, + color_aug: bool, + flip_aug: bool, + face_crop_aug_range: Optional[Tuple[float, float]], + random_crop: bool, + caption_dropout_rate: float, + caption_dropout_every_n_epochs: int, + caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float, int], + ) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate + + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最倧になる + + self.img_count = 0 + + +class DreamBoothSubset(BaseSubset): + def __init__( + self, + image_dir: str, + is_reg: bool, + class_tokens: Optional[str], + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == 
other.image_dir + + +class FineTuningSubset(BaseSubset): + def __init__( + self, + image_dir, + metadata_file: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.metadata_file = metadata_file + + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file + + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.max_token_length = max_token_length + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.debug_dataset = debug_dataset + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + + self.token_padding_disabled = False + self.tag_frequency = {} + self.XTI_layers = None + self.token_strings = None + + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + self.current_epoch: int = 0 # むンスタンスがepochごずに新しく䜜られるようなので倖偎から枡さないずダメ + + self.current_step: int = 0 + self.max_train_steps: int = 0 + self.seed: int = 0 + + # augmentation + self.aug_helper = AugHelper() + + self.image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} + + self.replacements = {} + + def set_seed(self, seed): + self.seed = seed + + def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする + self.shuffle_buckets() + self.current_epoch = epoch + + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + + def disable_token_padding(self): + self.token_padding_disabled = True + + def enable_XTI(self, layers=None, token_strings=None): + self.XTI_layers = layers + self.token_strings = token_strings + + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to + + def process_caption(self, subset: BaseSubset, caption): + # dropoutの決定tag 
dropがこのメ゜ッド内にあるのでここで行うのが良い + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = ( + is_drop_out + or subset.caption_dropout_every_n_epochs > 0 + and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + ) + + if is_drop_out: + caption = "" + else: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + tokens = [t.strip() for t in caption.strip().split(",")] + if subset.token_warmup_step < 1: # 初回に䞊曞きする + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = ( + math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) + tokens = tokens[:tokens_len] + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + fixed_tokens = [] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens) + + # textual inversion察応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer( + caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" + ).input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以䞊の時は " .... " でトヌタル227ずかになっおいるので、"..."の䞉連に倉換する + # 1111氏のや぀は , で区切る、ずかしおいるようだが ずりあえず単玔に + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以䞊の時は " .... ..." でトヌタル227ずかになっおいるので、"... ..."の䞉連に倉換する + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尟が たたは の堎合は、䜕もしなくおよい + # 末尟が x の堎合は末尟を に倉えるx なら結果的に倉化なし + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の堎合は ... 
に倉える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + """ + bucketingを行わない堎合も呌び出し必須ひず぀だけbucketを䜜る + min_size and max_size are ignored when enable_bucket is False + """ + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketを䜜成し、画像をbucketに振り分ける + if self.enable_bucket: + if self.bucket_manager is None: # fine tuningの堎合でmetadataに定矩がある堎合は、すでに初期化枈み + self.bucket_manager = BucketManager( + self.bucket_no_upscale, + (self.width, self.height), + self.min_bucket_reso, + self.max_bucket_reso, + self.bucket_reso_steps, + ) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された堎合は、bucketの解像床は画像サむズから自動蚈算されるため、min_bucket_resoずmax_bucket_resoは無芖されたす" + ) + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) + + # print(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひず぀の固定サむズbucketのみ + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # bucket情報を衚瀺、栌玍する + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + + # デヌタ参照甚indexを䜜る。このindexはdatasetのshuffleに甚いられる + self.buckets_indices: List(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + # ↓以䞋はbucketごずのbatch件数があたりにも増えお混乱を招くので元に戻す + #  孊習時はステップ数がランダムなので、同䞀画像が同䞀batch内にあっおもそれほど悪圱響はないであろう、ず考えられる + # + # # bucketが现分化されるこずにより、ひず぀のbucketに䞀皮類の画像のみずいうケヌスが増え、぀たりそれは + # # 
ひず぀のbatchが同じ画像で占められるこずになるので、さすがに良くないであろう + # # そのためバッチサむズを画像皮類たでに制限する + # # ただそれでも同䞀画像が同䞀バッチに含たれる可胜性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなるこずは間違いない + # # TO DO 正則化画像をepochたたがりで利甚する仕組み + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + # ↑ここたで + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサむズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサむズ + + image_height, image_width = image.shape[0:2] + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def cache_latents(self, vae, vae_batch_size=1): + # ちょっず速くした + print("caching latents.") + + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + for info in image_infos: + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # iterate batches + for batch in tqdm(batches, smoothing=1, total=len(batches)): + images = [] + for info in batch: + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + image = self.image_transforms(image) + images.append(image) + + img_tensors = torch.stack(images, dim=0) + 
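+            # Encode the stacked batch with the VAE: move it onto the VAE's device/dtype, sample from the
+            # latent distribution, and keep the result on the CPU so cached latents do not occupy VRAM.
+            # If flip_aug is enabled for the subset, the horizontally flipped batch is encoded as well so
+            # that flipped latents can be served during training without re-encoding.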
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents = latent + + if subset.flip_aug: + img_tensors = torch.flip(img_tensors, dims=[3]) + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents_flipped = latent + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サむズはsizeより倧きいのでリサむズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サむズぎったりになる倍率最小の倍率 + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サむズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最倧サむズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + 0.5) + nw = int(width * scale + 0.5) + assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + 0.5) + face_cy = int(face_cy * scale + 0.5) + height, width = nh, nw + + # 顔を䞭心ずしお448*640ずかぞ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を䞭心に持っおくるための切り出し䜍眮 + + if subset.random_crop: + # 背景も含めるために顔を䞭心に眮く確率を高め぀぀ずらす + range = max(length - face_p, face_p) # 画像の端から顔䞭心たでの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range たでのいい感じの乱数 + else: + # range指定があるずきのみ、すこしだけランダムにわりず適圓 + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1 : p1 + target_size, :] + else: + image = image[:, p1 : p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)["arr_0"] + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを凊理する + if image_info.latents is not None: + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み蟌み、必芁ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + if face_cx > 0: # 顔䜍眮情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サむズが倧きいのでface_crop_aug_rangeかrandom_crop、たたはbucketを有効にしおください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サむズが小さいようです: {image_info.absolute_path}" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + 
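+            # Exactly one of `image` / `latents` is set per sample at this point: cached or npz latents skip
+            # the image pipeline entirely, otherwise the augmented image tensor is used and latents stays None.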
latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + if self.XTI_layers: + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) + else: + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer) + else: + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + example["captions"] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[DreamBoothSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + prior_loss_weight: float, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + assert resolution is not None, f"resolution is required / resolution解像床指定は必須です" + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像床より倧きくできたせん。解像床を倧きくするかmin_bucket_resoを小さくしおください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最倧解像床より小さくできたせん。解像床を小さくするかmin_bucket_resoを倧きくしおください" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は䜿われない + self.bucket_no_upscale = False + + def read_caption(img_path, caption_extension): + # captionの候補ファむル名を䜜る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファむルにUTF-8以倖の文字がありたす: {cap_path}") + raise e + assert 
len(lines) > 0, f"caption file is empty / キャプションファむルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] + + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + + # 画像ファむルごずにプロンプトを読み蟌み、もしあればそちらを䜿う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻床を蚘録 + + return img_paths, captions + + print("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を䞋回っおいるためサブセットを無芖したす: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されおいるため、重耇した埌発のサブセットを無芖したす" + ) + continue + + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が芋぀からないためサブセットを無芖したす") + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が倚いので、䞀郚䜿甚されない正則化画像がありたす") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が芋぀かりたせんでした") + else: + # num_repeatsを蚈算するどうせ倧した数ではないのでルヌプで凊理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 # rewrite registered info + n += 1 + if n >= num_train_images: + break + first_loop = False + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[FineTuningSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + self.batch_size = batch_size + + self.num_train_images = 0 + self.num_reg_images = 0 + + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を䞋回っおいるためサブセットを無芖したす: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: 
+ print( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されおいるため、重耇した埌発のサブセットを無芖したす" + ) + continue + + # メタデヌタを読み蟌む + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデヌタファむルがありたせん: {subset.metadata_file}") + + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデヌタが芋぀からないためサブセットを無芖したす") + continue + + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を䜜る + if os.path.exists(image_key): + abs_path = image_key + elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりずいい加枛だがいい方法が思い぀かん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありたせん: {image_key}" + abs_path = abs_path[0] + + caption = img_md.get("caption") + tags = img_md.get("tags") + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ", " + tags + tags_list.append(tags) + + if caption is None: + caption = "" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get("train_resolution") + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) + + # check existence of all npz files + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True + + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + use_npz_latents = False + print(f"npz file does not exist. ignore npz files / npzファむルが芋぀からないためnpzファむルを無芖したす") + elif not npz_all: + use_npz_latents = False + print(f"some of npz file does not exist. ignore npz files / いく぀かのnpzファむルが芋぀からないためnpzファむルを無芖したす") + if flip_aug_in_subset: + print("maybe no flipped files / 反転されたnpzファむルがないのかもしれたせん") + # else: + # print("npz files are not used with color_aug and/or random_crop / color_augたたはrandom_cropが指定されおいるためnpzファむルは䜿甚されたせん") + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + if use_npz_latents: + use_npz_latents = False + print(f"npz files exist, but no bucket info in metadata. 
ignore npz files / メタデヌタにbucket情報がないためnpzファむルを無芖したす") + + assert ( + resolution is not None + ), "if metadata doesn't have bucket info, resolution is required / メタデヌタにbucket情報がない堎合はresolutionを指定しおください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデヌタにbucket情報があるためbucketを有効にしたす") + print("using bucket info in metadata / メタデヌタ内のbucket情報を䜿いたす") + self.enable_bucket = True + + assert ( + not bucket_no_upscale + ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデヌタ内にbucket情報がある堎合はbucketの解像床は蚈算枈みのため、bucket_no_upscaleは䜿えたせん" + + # bucket情報を初期化しおおく、make_bucketsで再䜜成しない + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) + + # npz情報をきれいにしおおく + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + ".npz" + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + "_flip.npz" + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # if not full path, check image_dir. if image_dir is None, return None + if subset.image_dir is None: + return None, None + + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") + npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + + +# behave as Dataset mock +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + + super().__init__(datasets) + + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 + + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. 
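+        # Aggregate per-dataset bookkeeping so callers can treat the group as a single dataset:
+        # image_data is merged and the train/reg image counts are summed across the wrapped datasets.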
+ for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images + + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) + + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() + + def enable_XTI(self, *args, **kwargs): + for dataset in self.datasets: + dataset.enable_XTI(*args, **kwargs) + + def cache_latents(self, vae, vae_batch_size=1): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_latents(vae, vae_batch_size) + + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() + + +def debug_dataset(train_dataset, show_input_ids=False): + print(f"Total dataset length (steps) / デヌタセットの長さステップ数: {len(train_dataset)}") + print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキヌで次のステップ、Eキヌで次の゚ポック、Escキヌで䞭断、終了したす") + + epoch = 1 + while True: + print(f"epoch: {epoch}") + + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: + break + + epoch += 1 + + +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # 重耇を排陀 + img_paths.sort() + return img_paths + + +def glob_images_pathlib(dir_path, recursive): + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob("*" + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob("*" + ext)) + image_paths = list(set(image_paths)) # 重耇を排陀 + image_paths.sort() + return image_paths + + 
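+
+# Illustrative sketch of how the dataset classes above are typically wired together (not part of the
+# training code path). Argument names follow the constructor signatures defined in this file; the
+# tokenizer id, file names and hyperparameter values below are placeholders, and `vae` is assumed to
+# come from the loaded Stable Diffusion model. A minimal fine-tuning setup might look like:
+#
+#     from transformers import CLIPTokenizer
+#
+#     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed SD v1 tokenizer
+#     subset = FineTuningSubset(
+#         "train_images", "meta.json", num_repeats=1, shuffle_caption=True, keep_tokens=1,
+#         color_aug=False, flip_aug=False, face_crop_aug_range=None, random_crop=False,
+#         caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.0,
+#         token_warmup_min=1, token_warmup_step=0,
+#     )
+#     dataset = FineTuningDataset(
+#         [subset], batch_size=4, tokenizer=tokenizer, max_token_length=225, resolution=(512, 512),
+#         enable_bucket=True, min_bucket_reso=256, max_bucket_reso=1024, bucket_reso_steps=64,
+#         bucket_no_upscale=False, debug_dataset=False,
+#     )
+#     dataset.make_buckets()                      # required even when bucketing is disabled
+#     train_group = DatasetGroup([dataset])
+#     if train_group.is_latent_cacheable():
+#         train_group.cache_latents(vae, vae_batch_size=1)
+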
+# endregion + + +# region モゞュヌル入れ替え郚 +""" +高速化のためのモゞュヌル入れ替え +""" + +# FlashAttentionを䜿うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + try: + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = 
(None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # diffusers 0.7.0~ わざわざ倉えるなよ (;ŽД) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがむンストヌルされおいないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを遞んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers + + +# endregion + + +# region arguments + + +def add_sd_models_arguments(parser: argparse.ArgumentParser): + # for pretrained models 
+ parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み蟌む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization孊習を有効にする" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 孊習元モデル、Diffusers圢匏モデルのディレクトリたたはStableDiffusionのckptファむル", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの孊習のため", + ) + + +def add_optimizer_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマむザの皮類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + ) + + # backward compatibility + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマむザを䜿うbitsandbytesのむンストヌルが必芁", + ) + parser.add_argument( + "--use_lion_optimizer", + action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマむザを䜿う lion-pytorch のむンストヌルが必芁", + ) + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 孊習率") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 募配正芏化の最倧norm、0でclippingを行わない" + ) + + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマむザの远加匕数䟋 "weight_decay=0.01 betas=0.9,0.999 ..."', + ) + + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 䜿甚するスケゞュヌラ") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100") / スケゞュヌラの远加匕数䟋 "T_max100"', + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 孊習率のスケゞュヌラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 孊習率のスケゞュヌラをりォヌムアップするステップ数デフォルト0", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケゞュヌラでのリスタヌト回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケゞュヌラでのpolynomial power", + ) + + +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 孊習埌のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 孊習埌のモデルの拡匵子を陀くファむル名") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving / 保存時に粟床を倉曎しお保存する", + ) + parser.add_argument( + "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 
孊習䞭のモデルを指定゚ポックごずに保存する" + ) + parser.add_argument( + "--save_n_epoch_ratio", + type=int, + default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 孊習䞭のモデルを指定の゚ポック割合で保存するたずえば5を指定するず最䜎5個のファむルが保存される", + ) + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最倧N゚ポック保存する") + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最倧N゚ポックstateを保存する(--save_last_n_epochsの指定を䞊曞きしたす)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど孊習状態も含めたstateを远加で保存する", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 孊習再開するモデルのstate") + + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 孊習時のバッチサむズ") + parser.add_argument( + "--max_token_length", + type=int, + default=None, + choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトヌクンの最倧長未指定で75、150たたは225が指定可", + ) + parser.add_argument( + "--mem_eff_attn", + action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを䜿う", + ) + parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを䜿う") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える堎合、VAEのcheckpointファむルたたはディレクトリ" + ) + + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 孊習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 孊習゚ポック数max_train_stepsを䞊曞きしたす", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最倧プロセス数小さい倀ではメむンメモリの䜿甚量が枛り゚ポック間の埅ち時間が枛りたすが、デヌタ読み蟌みは遅くなりたす", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワヌカヌを持続させる (゚ポック間の時間差を少なくするのに有効だが、より倚くのメモリを消費する可胜性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 孊習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 孊習時に逆䌝播をする前に募配を合蚈するステップ数", + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合粟床を䜿う堎合、その粟床" + ) + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 募配も含めおfp16で孊習する") + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの埌ろからn番目の局の出力を甚いるnは1以䞊", + ) + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしおこのディレクトリにTensorBoard甚のログを出力する", 
+ ) + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に远加する文字列") + parser.add_argument( + "--noise_offset", + type=float, + default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしおこの倀を蚭定する有効にする堎合は0.1皋床を掚奚", + ) + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メむンメモリが少ない環境向け最適化を有効にする。たずえばVRAMにモデルを読み蟌むなどColabやKaggleなどRAMに比べおVRAMが倚い環境向け", + ) + + parser.add_argument( + "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 孊習䞭のモデルで指定ステップごずにサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 孊習䞭のモデルで指定゚ポックごずにサンプル出力するステップ数指定を䞊曞きしたす", + ) + parser.add_argument( + "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 孊習䞭モデルのサンプル出力甚プロンプトのファむル" + ) + parser.add_argument( + "--sample_sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラヌスケゞュヌラの皮類", + ) + + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハむパヌパラメヌタを匕数ではなく.tomlファむルで枡す", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file / 匕数を.tomlファむルに出力する" + ) + + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" + ) + + +def verify_training_args(args: argparse.Namespace): + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを䜿甚するこずは想定されおいたせん") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを䜿甚するこずは想定されおいたせん") + + +def add_dataset_arguments( + parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool +): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 孊習画像デヌタのディレクトリ") + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各芁玠をshuffleする" + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み蟌むcaptionファむルの拡匵子" + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption files (backward compatibility) / 読み蟌むcaptionファむルの拡匵子スペルミスを残しおありたす", + ) + parser.add_argument( + "--keep_tokens", + type=int, + default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトヌクンをシャッフルしないで残すトヌクンはカンマ区切りの各郚分を意味する", + ) + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 孊習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 
孊習時に巊右反転のaugmentationを有効にする") + parser.add_argument( + "--face_crop_aug_range", + type=str, + default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 孊習時に顔を䞭心ずした切り出しaugmentationを有効にするずきは倍率を指定する䟋2.0,4.0", + ) + parser.add_argument( + "--random_crop", + action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする顔を䞭心ずしたaugmentationを行うずきに画颚の孊習甚に指定する", + ) + parser.add_argument( + "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ甚に孊習デヌタを画面衚瀺する孊習は行わない" + ) + parser.add_argument( + "--resolution", + type=str, + default=None, + help="resolution in training ('size' or 'width,height') / 孊習時の画像解像床'サむズ'指定、たたは'幅,高さ'指定", + ) + parser.add_argument( + "--cache_latents", + action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削枛のためにlatentをcacheするaugmentationは䜿甚䞍可", + ) + parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサむズ") + parser.add_argument( + "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 耇数解像床孊習のためのbucketを有効にする" + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像床") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最倧解像床") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像床の単䜍、8で割り切れる倀を掚奚したす", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡倧せずbucketを䜜成したす" + ) + + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら孊習する", + ) + + parser.add_argument( + "--token_warmup_step", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最倧になる。デフォルトは0最初から最倧", + ) + + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutず玛らわしいのでprefixにcaptionを付けおおく every_n_epochsは他ず平仄を合わせおdefault Noneに + parser.add_argument( + "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合" + ) + parser.add_argument( + "--caption_dropout_every_n_epochs", + type=int, + default=0, + help="Dropout all captions every N epochs / captionを指定゚ポックごずにdropoutする", + ) + parser.add_argument( + "--caption_tag_dropout_rate", + type=float, + default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合", + ) + + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像デヌタのディレクトリ") + + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / デヌタセットのmetadataのjsonファむル") + parser.add_argument( + "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの孊習時にデヌタセットを繰り返す回数" + ) + + +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--save_model_as", + type=str, + default=None, + choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], 
+ help="format to save the model (default is same to original) / モデル保存時の圢匏未指定時は元モデルず同じ", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors圢匏で保存するsave_model_as未指定時", + ) + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + print(f"Config file already exists. Aborting... / 出力先の蚭定ファむルが既に存圚したす: {config_path}") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config"]: + if key in args_dict: + del args_dict[key] + + # get default args from parser + default_args = vars(parser.parse_args([])) + + # remove default values: cannot use args_dict.items directly because it will be changed during iteration + for key, value in list(args_dict.items()): + if key in default_args and value == default_args[key]: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + print(f"Saved config file / 蚭定ファむルを保存したした: {config_path}") + exit(0) + + if not os.path.exists(config_path): + print(f"{config_path} not found.") + exit(1) + + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + print(args.config_file) + + return args + + +# endregion + +# region utils + + +def get_optimizer(args, trainable_params): + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert ( + not args.use_lion_optimizer + ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamずuse_lion_optimizerの䞡方のオプションが指定されおいたす" + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamずoptimizer_typeの䞡方のオプションが指定されおいたす" + optimizer_type = "AdamW8bit" + + elif args.use_lion_optimizer: + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerずoptimizer_typeの䞡方のオプションが指定されおいたす" + optimizer_type = "Lion" + + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() + + # 匕数を分解する + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = 
ast.literal_eval(value) + + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.float(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = tuple(value) + + optimizer_kwargs[key] = value + # print("optkwargs:", optimizer_kwargs) + + lr = args.learning_rate + + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがむンストヌルされおいないようです") + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがむンストヌルされおいないようです") + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に蚭定したす" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がむンストヌルされおいないようです") + print(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に蚭定したす") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "DAdaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がむンストヌルされおいないようです") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 孊習率が䜎すぎるようです。1.0前埌の倀を指定しおください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 掚奚は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. 
for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで耇数の孊習率を指定した堎合Text EncoderずU-Netなど、最初の孊習率のみが有効になりたす: lr={actual_lr}" + ) + + optimizer_class = dadaptation.DAdaptAdam + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # 匕数を確認しお適宜補正する + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにしたす") + optimizer_kwargs["relative_step"] = True + print(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + print(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrずしお䜿甚されたす") + args.learning_rate = None + + # trainable_paramsがgroupだった時の凊理lrを削陀する + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # 䞀応argsを無効にしおおく TODO 䟝存関係が逆転しおるのであたり望たしくない + print(f"unet_lr and text_encoder_lr are ignored / unet_lrずtext_encoder_lrは無芖されたす") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + print(f"use adafactor_scheduler / スケゞュヌラにadafactor_schedulerを䜿甚したす") + args.lr_scheduler = f"adafactor:{lr}" # ちょっず埮劙だけど + + lr = None + else: + if args.max_grad_norm != 0.0: + print( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが蚭定されおいるためclip_grad_normが有効になりたす。0に蚭定しお無効にしたほうがいいかもしれたせん" + ) + if args.lr_scheduler != "constant_with_warmup": + print(f"constant_with_warmup will be good / スケゞュヌラはconstant_with_warmupが良いかもしれたせん") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれたせん") + + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "AdamW".lower(): + print(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + else: + # 任意のoptimizerを䜿う + optimizer_type = args.optimizer_type # lowerでないや぀埮劙 + print(f"use {optimizer_type} | {optimizer_kwargs}") + if "." not in optimizer_type: + optimizer_module = torch.optim + else: + values = optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + optimizer_name = optimizer_class.__module__ + "." 
+ optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + + return optimizer_name, optimizer_args, optimizer + + +# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler +# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6 +# Which is a newer release of diffusers than currently packaged with sd-scripts +# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts + + +def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): + """ + Unified API to get any scheduler from its name. + """ + name = args.lr_scheduler + num_warmup_steps = args.lr_warmup_steps + num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split("=") + + value = ast.literal_eval(value) + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.literal_eval(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = list(value) # some may use list? + + lr_scheduler_kwargs[key] = value + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." 
not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler + + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマむザず同時に䜿っおください" + initial_lr = float(name.split(":")[1]) + # print("adafactor scheduler init lr", initial_lr) + return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + + +def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None + + # assert args.resolution is not None, f"resolution is required / resolution解像床を指定しおください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(",")]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert ( + len(args.resolution) == 2 + ), f"resolution must be 'size' or 'width,height' / resolution解像床は'サむズ'たたは'幅','高さ'で指定しおください: {args.resolution}" + + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) + assert ( + len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] + ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'例限,侊限'で指定しおください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None + + if support_metadata: + if args.in_json is not None and (args.color_aug or args.random_crop): + print( + f"latents in npz is ignored when color_aug or random_crop is True / color_augたたはrandom_cropを有効にした堎合、npzファむルのlatentsは無芖されたす" + ) + + +def load_tokenizer(args: argparse.Namespace): + print("prepare tokenizer") + original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH + + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if 
os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + +def prepare_accelerator(args: argparse.Namespace): + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + logging_dir=logging_dir, + ) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + return accelerator, unwrap_model + + +def prepare_dtype(args: argparse.Namespace): + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + return weight_dtype, save_dtype + + +def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'): + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) + else: + # Diffusers model is loaded to CPU + print("load Diffusers pretrained models") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
/ 指定したモデル名のファむル、たたはHugging Faceのモデルが芋぀かりたせん。ファむル名が誀っおいるかもしれたせん: {name_or_path}" + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAEを読み蟌む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + return text_encoder, vae, unet, load_stable_diffusion_format + + +def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] + + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の䞉連を ... ... ぞ戻す 正盎この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # の埌から 最埌の前たで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、぀たり ...のパタヌン + chunk[j, 0] = chunk[j, 1] # 次の の倀をコピヌする + states_list.append(chunk) # の埌から の前たで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # か のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の䞉連を ... 
ぞ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の埌から の前たで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + return encoder_hidden_states + + +def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name + + +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() + + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving + + +def save_sd_model_on_epoch_end( + args: argparse.Namespace, + accelerator, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder, + unet, + vae, +): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + + if save_stable_diffusion_format: + + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) + + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + save_func = save_sd + remove_old_func = remove_sd + else: + + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + + +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - 
args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder, + unet, + vae, +): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae + ) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + + +def sample_images( + accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None +): + """ + StableDiffusionLongPromptWeightingPipelineの改造版を䜿うようにしたので、clip skipおよびプロンプトの重みづけに察応した + """ + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無芖する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + print(f"No prompt file / プロンプトファむルがありたせん: {args.sample_prompts}") + return + + org_vae_device = vae.device # CPUにいるはず + vae.to(device) + + # read prompts + with open(args.sample_prompts, "rt", encoding="utf-8") as f: + prompts = f.readlines() + + # schedulerを甚意する + sched_init_args = {} + if args.sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから倖しおある + scheduler_cls = DDPMScheduler + elif args.sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + 
sched_init_args["algorithm_type"] = args.sample_sampler + elif args.sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + scheduler.config.clip_sample = True + + pipeline = StableDiffusionLongPromptWeightingPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + clip_skip=args.clip_skip, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + pipeline.to(device) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue + prompt = prompt.strip() + if len(prompt) == 0 or prompt[0] == "#": + continue + + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析゚ラヌ: {parg}") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + ).images[0] + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch 
is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + torch.set_rng_state(rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) + + +# endregion + +# region 前凊理甚 + + +class ImageLoadingDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み蟌めたせん: {img_path}, error: {e}") + return None + + return (tensor_pil, img_path) + + +# endregion + + +# collate_fn甹 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] diff --git a/library/utilities.py b/library/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..8c45bff6d9d542e9f9244d55553658c221b3c2d5 --- /dev/null +++ b/library/utilities.py @@ -0,0 +1,93 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captionning of images to utilities + +import gradio as gr +import os +import argparse +from library.basic_caption_gui import gradio_basic_caption_gui_tab +from library.convert_model_gui import gradio_convert_model_tab +from library.blip_caption_gui import gradio_blip_caption_gui_tab +from library.git_caption_gui import gradio_git_caption_gui_tab +from library.wd14_caption_gui import gradio_wd14_caption_gui_tab + + +def utilities_tab( + train_data_dir_input=gr.Textbox(), + reg_data_dir_input=gr.Textbox(), + output_dir_input=gr.Textbox(), + logging_dir_input=gr.Textbox(), + enable_copy_info_button=bool(False), + enable_dreambooth_tab=True, +): + with gr.Tab('Captioning'): + gradio_basic_caption_gui_tab() + gradio_blip_caption_gui_tab() + gradio_git_caption_gui_tab() + gradio_wd14_caption_gui_tab() + gradio_convert_model_tab() + + return ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) + + +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + utilities_tab() + + # Show the interface + launch_kwargs = {} + if not kwargs.get('username', None) == '': + launch_kwargs['auth'] = ( + kwargs.get('username', None), + kwargs.get('password', None), + ) + if kwargs.get('server_port', 0) > 0: + launch_kwargs['server_port'] = 
kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + ) diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a0bf9ef575437fb1c48c6c688bd5c42692d181 --- /dev/null +++ b/library/verify_lora_gui.py @@ -0,0 +1,102 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + + +def verify_lora( + lora_model, +): + # verify for caption_text_input + if lora_model == '': + msgbox('Invalid model A file') + return + + # verify if source model exist + if not os.path.isfile(lora_model): + msgbox('The provided model A is not a file') + return + + run_cmd = [ + PYTHON, + os.path.join('networks', 'check_lora_weights.py'), + f'{lora_model}', + ] + + print(' '.join(run_cmd)) + + # Run the command + process = subprocess.Popen( + run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + output, error = process.communicate() + + return (output.decode(), error.decode()) + + +### +# Gradio UI +### + + +def gradio_verify_lora_tab(): + with gr.Tab('Verify LoRA'): + gr.Markdown( + 'This utility can verify a LoRA network to make sure it is properly trained.' 
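            # Clicking "Verify" runs, roughly:  <PYTHON> networks/check_lora_weights.py <lora_model>
            # and the script's stdout / stderr are shown in the Output and Error boxes below.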
+ ) + + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_model = gr.Textbox( + label='LoRA model', + placeholder='Path to the LoRA model to verify', + interactive=True, + ) + button_lora_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_model_file.click( + get_file_path, + inputs=[lora_model, lora_ext, lora_ext_name], + outputs=lora_model, + show_progress=False, + ) + verify_button = gr.Button('Verify', variant='primary') + + lora_model_verif_output = gr.Textbox( + label='Output', + placeholder='Verification output', + interactive=False, + lines=1, + max_lines=10, + ) + + lora_model_verif_error = gr.Textbox( + label='Error', + placeholder='Verification error', + interactive=False, + lines=1, + max_lines=10, + ) + + verify_button.click( + verify_lora, + inputs=[ + lora_model, + ], + outputs=[lora_model_verif_output, lora_model_verif_error], + show_progress=False, + ) diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..1970849bddd27108fd5011a007e14e0b5e83af0c --- /dev/null +++ b/library/wd14_caption_gui.py @@ -0,0 +1,111 @@ +import gradio as gr +from easygui import msgbox +import subprocess +from .common_gui import get_folder_path +import os + + +def replace_underscore_with_space(folder_path, file_extension): + for file_name in os.listdir(folder_path): + if file_name.endswith(file_extension): + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'r') as file: + file_content = file.read() + new_file_content = file_content.replace('_', ' ') + with open(file_path, 'w') as file: + file.write(new_file_content) + +def caption_images( + train_data_dir, caption_extension, batch_size, thresh, replace_underscores +): + # Check for caption_text_input + # if caption_text_input == "": + # msgbox("Caption text is missing...") + # return + + # Check for images_dir_input + if train_data_dir == '': + msgbox('Image folder is missing...') + return + + if caption_extension == '': + msgbox('Please provide an extension for the caption files.') + return + + print(f'Captioning files in {train_data_dir}...') + run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' + run_cmd += f' --batch_size="{int(batch_size)}"' + run_cmd += f' --thresh="{thresh}"' + run_cmd += f' --caption_extension="{caption_extension}"' + run_cmd += f' "{train_data_dir}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + if replace_underscores: + replace_underscore_with_space(train_data_dir, caption_extension) + + print('...captioning done') + + +### +# Gradio UI +### + + +def gradio_wd14_caption_gui_tab(): + with gr.Tab('WD14 Captioning'): + gr.Markdown( + 'This utility will use WD14 to caption files for each images in a folder.' + ) + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder to caption', + placeholder='Directory containing the images to caption', + interactive=True, + ) + button_train_data_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_train_data_dir_input.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + + caption_extension = gr.Textbox( + label='Caption file extension', + placeholder='Extention for caption file. 
eg: .caption, .txt', + value='.txt', + interactive=True, + ) + thresh = gr.Number(value=0.35, label='Threshold') + + batch_size = gr.Number( + value=1, label='Batch size', interactive=True + ) + + replace_underscores = gr.Checkbox( + label='Replace underscores in filenames with spaces', + value=False, + interactive=True, + ) + + caption_button = gr.Button('Caption images') + + caption_button.click( + caption_images, + inputs=[ + train_data_dir, + caption_extension, + batch_size, + thresh, + replace_underscores, + ], + show_progress=False, + ) diff --git a/lora_gui.py b/lora_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..ccca9473f5960d7eb2f3c65624d28643ae9f8311 --- /dev/null +++ b/lora_gui.py @@ -0,0 +1,1156 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captionning of images to utilities + +import gradio as gr +import easygui +import json +import math +import os +import subprocess +import pathlib +import argparse +from library.common_gui import ( + get_folder_path, + remove_doublequote, + get_file_path, + get_any_file_path, + get_saveasfile_path, + color_aug_changed, + save_inference_file, + gradio_advanced_training, + run_cmd_advanced_training, + gradio_training, + gradio_config, + gradio_source_model, + run_cmd_training, + # set_legacy_8bitadam, + update_my_data, + check_if_model_exist, +) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from library.dataset_balancing_gui import gradio_dataset_balancing_tab +from library.utilities import utilities_tab +from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab +from library.verify_lora_gui import gradio_verify_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 +path_of_this_folder = os.getcwd() + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, + clip_skip, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + 
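    # The function body below dumps these parameters (minus file_path / save_as) into a
    # flat JSON config, e.g. (illustrative values only):
    #   {"LoRA_type": "Standard", "network_dim": 8, "network_alpha": 1,
    #    "learning_rate": "0.0001", "sample_sampler": "euler_a", ...}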
sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get('label') == 'True' else False + + if save_as_bool: + print('Save as...') + file_path = get_saveasfile_path(file_path) + else: + print('Save...') + if file_path == None or file_path == '': + file_path = get_saveasfile_path(file_path) + + # print(file_path) + + if file_path == None or file_path == '': + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Return the values of the variables as a dictionary + variables = { + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] + } + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + # Save the data to the selected file + with open(file_path, 'w') as file: + json.dump(variables, file, indent=2) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, + clip_skip, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == '' and not file_path == None: + # load variables from JSON file + with open(file_path, 'r') as f: + my_data = json.load(f) + print('Loading config...') + + # Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc. 
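            # Assumed example of that fix-up: a legacy config such as
            #   {"use_8bit_adam": true, ...}
            # would presumably be rewritten to the current option, e.g.
            #   {"optimizer": "AdamW8bit", ...}
            # so that old saved configs keep loading with the present settings list.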
+ my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['ask_for_file', 'file_path']: + values.append(my_data.get(key, value)) + + # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' + if my_data.get('LoRA_type', 'Standard') == 'LoCon': + values.append(gr.Row.update(visible=True)) + else: + values.append(gr.Row.update(visible=False)) + + return tuple(values) + + +def train_model( + print_only, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, + clip_skip, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, # Keep this. Yes, it is unused here but required given the common list used + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + print_only_bool = True if print_only.get('label') == 'True' else False + + if pretrained_model_name_or_path == '': + msgbox('Source model information is missing') + return + + if train_data_dir == '': + msgbox('Image folder path is missing') + return + + if not os.path.exists(train_data_dir): + msgbox('Image folder does not exist') + return + + if reg_data_dir != '': + if not os.path.exists(reg_data_dir): + msgbox('Regularisation folder does not exist') + return + + if output_dir == '': + msgbox('Output folder path is missing') + return + + if int(bucket_reso_steps) < 1: + msgbox('Bucket resolution steps need to be greater than 0') + return + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if stop_text_encoder_training_pct > 0: + msgbox( + 'Output "stop text encoder training" is not yet supported. Ignoring' + ) + stop_text_encoder_training_pct = 0 + + if check_if_model_exist(output_name, output_dir, save_model_as): + return + + # If string is empty set string to 0. 
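    # An empty box becomes 0, and the zero/non-zero combination decides what is trained
    # further below: e.g. text_encoder_lr='' with unet_lr='0.0001' ends up adding
    # --unet_lr=0.0001 --network_train_unet_only to the generated command.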
+ if text_encoder_lr == '': + text_encoder_lr = 0 + if unet_lr == '': + unet_lr = 0 + + # if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): + # msgbox( + # 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided' + # ) + # return + + # Get a list of all subfolders in train_data_dir + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) + ] + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + # Extract the number of repeats from the folder name + repeats = int(folder.split('_')[0]) + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) + ) + if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + ) + + print(f'Folder {folder}: {num_images} images found') + + # Calculate the total number of steps for this folder + steps = repeats * num_images + + # Print the result + print(f'Folder {folder}: {steps} steps') + + total_steps += steps + + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + * int(epoch) + # * int(reg_factor) + ) + ) + print(f'max_train_steps = {max_train_steps}') + + # calculate stop encoder training + if stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + print(f'stop_text_encoder_training = {stop_text_encoder_training}') + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + print(f'lr_warmup_steps = {lr_warmup_steps}') + + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' + + # run_cmd += f' --caption_dropout_rate="0.1" --caption_dropout_every_n_epochs=1' # --random_crop' + + if v2: + run_cmd += ' --v2' + if v_parameterization: + run_cmd += ' --v_parameterization' + if enable_bucket: + run_cmd += ' --enable_bucket' + if no_token_padding: + run_cmd += ' --no_token_padding' + run_cmd += ( + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) + run_cmd += f' --train_data_dir="{train_data_dir}"' + if len(reg_data_dir): + run_cmd += f' --reg_data_dir="{reg_data_dir}"' + run_cmd += f' --resolution={max_resolution}' + run_cmd += f' --output_dir="{output_dir}"' + run_cmd += f' --logging_dir="{logging_dir}"' + run_cmd += f' --network_alpha="{network_alpha}"' + if not training_comment == '': + run_cmd += f' --training_comment="{training_comment}"' + if not stop_text_encoder_training == 0: + run_cmd += ( + f' --stop_text_encoder_training={stop_text_encoder_training}' + ) + if not save_model_as == 'same as source model': + run_cmd += f' --save_model_as={save_model_as}' + if not float(prior_loss_weight) == 1.0: + run_cmd += f' --prior_loss_weight={prior_loss_weight}' + if LoRA_type == 'LoCon' or LoRA_type == 'LyCORIS/LoCon': + try: + import lycoris + except ModuleNotFoundError: + print( + "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program." 
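            # (If the import succeeds, the generated command gains, with this tab's
            #  slider values, something like:
            #    --network_module=lycoris.kohya --network_args "conv_dim=1" "conv_alpha=1" "algo=lora")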
+ ) + return + run_cmd += f' --network_module=lycoris.kohya' + run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"' + if LoRA_type == 'LyCORIS/LoHa': + try: + import lycoris + except ModuleNotFoundError: + print( + "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program." + ) + return + run_cmd += f' --network_module=lycoris.kohya' + run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"' + if LoRA_type == 'Kohya LoCon': + run_cmd += f' --network_module=networks.lora' + run_cmd += ( + f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"' + ) + if LoRA_type == 'Standard': + run_cmd += f' --network_module=networks.lora' + + if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): + if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): + run_cmd += f' --text_encoder_lr={text_encoder_lr}' + run_cmd += f' --unet_lr={unet_lr}' + elif not (float(text_encoder_lr) == 0): + run_cmd += f' --text_encoder_lr={text_encoder_lr}' + run_cmd += f' --network_train_text_encoder_only' + else: + run_cmd += f' --unet_lr={unet_lr}' + run_cmd += f' --network_train_unet_only' + else: + if float(text_encoder_lr) == 0: + msgbox('Please input learning rate values.') + return + + run_cmd += f' --network_dim={network_dim}' + + if not lora_network_weights == '': + run_cmd += f' --network_weights="{lora_network_weights}"' + if int(gradient_accumulation_steps) > 1: + run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' + if not lr_scheduler_num_cycles == '': + run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"' + else: + run_cmd += f' --lr_scheduler_num_cycles="{epoch}"' + if not lr_scheduler_power == '': + run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"' + + run_cmd += run_cmd_training( + learning_rate=learning_rate, + lr_scheduler=lr_scheduler, + lr_warmup_steps=lr_warmup_steps, + train_batch_size=train_batch_size, + max_train_steps=max_train_steps, + save_every_n_epochs=save_every_n_epochs, + mixed_precision=mixed_precision, + save_precision=save_precision, + seed=seed, + caption_extension=caption_extension, + cache_latents=cache_latents, + optimizer=optimizer, + optimizer_args=optimizer_args, + ) + + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + # use_8bit_adam=use_8bit_adam, + keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, + bucket_no_upscale=bucket_no_upscale, + random_crop=random_crop, + bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + noise_offset=noise_offset, + additional_parameters=additional_parameters, + vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + if print_only_bool: + print( + '\033[93m\nHere is 
the trainer command as a reference. It will not be executed:\033[0m\n' + ) + print('\033[96m' + run_cmd + '\033[0m\n') + else: + print(run_cmd) + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # check if output_dir/last is a folder... therefore it is a diffuser model + last_dir = pathlib.Path(f'{output_dir}/{output_name}') + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file( + output_dir, v2, v_parameterization, output_name + ) + + +def lora_tab( + train_data_dir_input=gr.Textbox(), + reg_data_dir_input=gr.Textbox(), + output_dir_input=gr.Textbox(), + logging_dir_input=gr.Textbox(), +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + gr.Markdown( + 'Train a custom model using kohya train network LoRA python code...' + ) + ( + button_open_config, + button_save_config, + button_save_as_config, + config_file_name, + button_load_config, + ) = gradio_config() + + ( + pretrained_model_name_or_path, + v2, + v_parameterization, + save_model_as, + model_list, + ) = gradio_source_model( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ] + ) + + with gr.Tab('Folders'): + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder', + placeholder='Folder where the training folders containing the images are located', + ) + train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small') + train_data_dir_folder.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + reg_data_dir = gr.Textbox( + label='Regularisation folder', + placeholder='(Optional) Folder where where the regularization folders containing the images are located', + ) + reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small') + reg_data_dir_folder.click( + get_folder_path, + outputs=reg_data_dir, + show_progress=False, + ) + with gr.Row(): + output_dir = gr.Textbox( + label='Output folder', + placeholder='Folder to output trained model', + ) + output_dir_folder = gr.Button('📂', elem_id='open_folder_small') + output_dir_folder.click( + get_folder_path, + outputs=output_dir, + show_progress=False, + ) + logging_dir = gr.Textbox( + label='Logging folder', + placeholder='Optional: enable logging and output TensorBoard log to this folder', + ) + logging_dir_folder = gr.Button('📂', elem_id='open_folder_small') + logging_dir_folder.click( + get_folder_path, + outputs=logging_dir, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='(Name of the model to output)', + value='last', + interactive=True, + ) + training_comment = gr.Textbox( + label='Training comment', + placeholder='(Optional) Add training comment to be included in metadata', + interactive=True, + ) + train_data_dir.change( + remove_doublequote, + inputs=[train_data_dir], + outputs=[train_data_dir], + ) + reg_data_dir.change( + remove_doublequote, + inputs=[reg_data_dir], + outputs=[reg_data_dir], + ) + output_dir.change( + remove_doublequote, + inputs=[output_dir], + outputs=[output_dir], + ) + logging_dir.change( + remove_doublequote, + inputs=[logging_dir], + outputs=[logging_dir], + ) + with gr.Tab('Training parameters'): + with gr.Row(): + LoRA_type = gr.Dropdown( + label='LoRA type', + choices=[ + 'Kohya LoCon', + # 'LoCon', + 'LyCORIS/LoCon', + 'LyCORIS/LoHa', + 'Standard', + ], + value='Standard', + ) + lora_network_weights = gr.Textbox( + label='LoRA network weights', + placeholder='{Optional) Path to 
existing LoRA network weights to resume training', + ) + lora_network_weights_file = gr.Button( + document_symbol, elem_id='open_folder_small' + ) + lora_network_weights_file.click( + get_any_file_path, + inputs=[lora_network_weights], + outputs=lora_network_weights, + show_progress=False, + ) + ( + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + num_cpu_threads_per_process, + seed, + caption_extension, + cache_latents, + optimizer, + optimizer_args, + ) = gradio_training( + learning_rate_value='0.0001', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + + with gr.Row(): + text_encoder_lr = gr.Textbox( + label='Text Encoder learning rate', + value='5e-5', + placeholder='Optional', + ) + unet_lr = gr.Textbox( + label='Unet learning rate', + value='0.0001', + placeholder='Optional', + ) + network_dim = gr.Slider( + minimum=1, + maximum=1024, + label='Network Rank (Dimension)', + value=8, + step=1, + interactive=True, + ) + network_alpha = gr.Slider( + minimum=0.1, + maximum=1024, + label='Network Alpha', + value=1, + step=0.1, + interactive=True, + ) + + with gr.Row(visible=False) as LoCon_row: + + # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False) + conv_dim = gr.Slider( + minimum=1, + maximum=512, + value=1, + step=1, + label='Convolution Rank (Dimension)', + ) + conv_alpha = gr.Slider( + minimum=0.1, + maximum=512, + value=1, + step=0.1, + label='Convolution Alpha', + ) + # Show of hide LoCon conv settings depending on LoRA type selection + def LoRA_type_change(LoRA_type): + print('LoRA type changed...') + if ( + LoRA_type == 'LoCon' + or LoRA_type == 'Kohya LoCon' + or LoRA_type == 'LyCORIS/LoHa' + or LoRA_type == 'LyCORIS/LoCon' + ): + return gr.Group.update(visible=True) + else: + return gr.Group.update(visible=False) + + LoRA_type.change( + LoRA_type_change, inputs=[LoRA_type], outputs=[LoCon_row] + ) + with gr.Row(): + max_resolution = gr.Textbox( + label='Max resolution', + value='512,512', + placeholder='512,512', + ) + stop_text_encoder_training = gr.Slider( + minimum=0, + maximum=100, + value=0, + step=1, + label='Stop text encoder training', + ) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True) + with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + no_token_padding = gr.Checkbox( + label='No token padding', value=False + ) + gradient_accumulation_steps = gr.Number( + label='Gradient accumulate steps', value='1' + ) + with gr.Row(): + prior_loss_weight = gr.Number( + label='Prior loss weight', value=1.0 + ) + lr_scheduler_num_cycles = gr.Textbox( + label='LR number of cycles', + placeholder='(Optional) For Cosine with restart and polynomial only', + ) + + lr_scheduler_power = gr.Textbox( + label='LR power', + placeholder='(Optional) For Cosine with restart and polynomial only', + ) + ( + # use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + noise_offset, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + 
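                # color_aug_changed presumably un-checks "Cache latents" when color
                # augmentation is enabled, since latents cannot be cached while
                # augmentations are active (see the --cache_latents help text above).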
inputs=[color_aug], + outputs=[cache_latents], + ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=train_data_dir, + reg_data_dir_input=reg_data_dir, + output_dir_input=output_dir, + logging_dir_input=logging_dir, + ) + gradio_dataset_balancing_tab() + gradio_merge_lora_tab() + gradio_svd_merge_lora_tab() + gradio_resize_lora_tab() + gradio_verify_lora_tab() + + button_run = gr.Button('Train model', variant='primary') + + button_print = gr.Button('Print training command') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, + clip_skip, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ] + + button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list + [LoCon_row], + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list + [LoCon_row], + show_progress=False, + ) + + button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_db_false] + settings_list, + show_progress=False, + ) + + button_print.click( + train_model, + inputs=[dummy_db_true] + settings_list, + show_progress=False, + ) + + return ( + train_data_dir, + reg_data_dir, + output_dir, + logging_dir, + ) + + +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + 
print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('LoRA'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = lora_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs = {} + if not kwargs.get('username', None) == '': + launch_kwargs['auth'] = ( + kwargs.get('username', None), + kwargs.get('password', None), + ) + if kwargs.get('server_port', 0) > 0: + launch_kwargs['server_port'] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) + if kwargs.get('listen', True): + launch_kwargs['server_name'] = '0.0.0.0' + print(launch_kwargs) + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + parser.add_argument( + '--listen', + action='store_true', + help='Launch gradio with server name 0.0.0.0, allowing LAN access', + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + ) diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8dcd6ba4f393fe0a40ae668c530e18b87aea16 --- /dev/null +++ b/networks/check_lora_weights.py @@ -0,0 +1,39 @@ +import argparse +import os +import torch +from safetensors.torch import load_file + + +def main(file): + print(f"loading: {file}") + if os.path.splitext(file)[1] == '.safetensors': + sd = load_file(file) + else: + sd = torch.load(file, map_location='cpu') + + values = [] + + keys = list(sd.keys()) + for key in keys: + if 'lora_up' in key or 'lora_down' in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファむル") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + main(args.file) diff --git a/networks/extract_lora_from_models copy.py b/networks/extract_lora_from_models copy.py new file mode 100644 index 0000000000000000000000000000000000000000..aacd21b5aa3b3b726621a192d6fe593f3fafe47e --- /dev/null +++ b/networks/extract_lora_from_models copy.py @@ -0,0 +1,194 @@ +# extract approximating LoRA by svd from two SD models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! 
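+# Overview: for every targeted Linear/Conv2d layer this script computes the weight
+# difference delta = W_tuned - W_orig and approximates it with a truncated SVD,
+# keeping only the top `dim` singular values, so that
+#     delta ~= (U[:, :r] @ diag(S[:r])) @ Vh[:r, :]  ==  lora_up @ lora_down
+# A minimal sketch of that factorisation on a single weight matrix (illustrative
+# tensor names, not variables from this script):
+#
+#     import torch
+#     delta = w_tuned - w_orig                   # shape (out_dim, in_dim)
+#     U, S, Vh = torch.linalg.svd(delta)
+#     r = 4                                      # LoRA rank (--dim)
+#     lora_up = U[:, :r] @ torch.diag(S[:r])     # (out_dim, r)
+#     lora_down = Vh[:r, :]                      # (r, in_dim)
+#     approx = lora_up @ lora_down               # best rank-r approximation of delta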
+ +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora +import numpy as np + + +CLAMP_QUANTILE = 1 # 0.99 +MIN_DIFF = 1e-6 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + print(f"loading SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + print(f"loading SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + + # create LoRA network to extract weights: Use dim (rank) as alpha + lora_network_o = lora.create_network(1.0, args.dim, args.dim * 1.5, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, args.dim * 1.5, None, text_encoder_t, unet_t) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバヌゞョンが違いたすSD1.xベヌスずSD2.xベヌス " + + # get diffs + diffs = {} + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. 
Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with SVD + print("calculating by SVD") + rank = args.dim + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + conv2d = (len(mat.size()) == 4) + if conv2d: + mat = mat.squeeze() + + U, S, Vt = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vt = Vt[:rank, :] + + lora_weights[lora_name] = (U, Vt) + + # # make LoRA with svd + # print("calculating by svd") + # rank = args.dim + # lora_weights = {} + # with torch.no_grad(): + # for lora_name, mat in tqdm(list(diffs.items())): + # conv2d = (len(mat.size()) == 4) + # if conv2d: + # mat = mat.squeeze() + + # U, S, Vh = torch.linalg.svd(mat) + + # U = U[:, :rank] + # S = S[:rank] + # U = U @ torch.diag(S) + + # Vh = Vh[:rank, :] + + # # create new tensors directly from the numpy arrays + # U = torch.as_tensor(U) + # Vh = torch.as_tensor(Vh) + + # # dist = torch.cat([U.flatten(), Vh.flatten()]) + # # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # # low_val = -hi_val + + # # U = U.clamp(low_val, hi_val) + # # Vh = Vh.clamp(low_val, hi_val) + + # # # soft thresholding + # # alpha = S[-1] / 1000.0 # adjust this parameter as needed + # # U = torch.sign(U) * torch.nn.functional.relu(torch.abs(U) - alpha) + # # Vh = torch.sign(Vh) * torch.nn.functional.relu(torch.abs(Vh) - alpha) + + # lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict + lora_sd = lora_network_o.state_dict() + print(f"LoRA has {len(lora_sd)} weights.") + + for key in list(lora_sd.keys()): + if "alpha" in key: + continue + + lora_name = key.split('.')[0] + i = 0 if "lora_up" in key else 1 + + weights = lora_weights[lora_name][i] + # print(key, i, weights.size(), lora_sd[key].size()) + if len(lora_sd[key].size()) == 4: + weights = weights.unsqueeze(2).unsqueeze(3) + + assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" + lora_sd[key] = weights + + # load state dict to LoRA and save it + info = lora_network_o.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + # minimum metadata + metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim * 1.5)} + + lora_network_o.save_weights(args.save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {args.save_to}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み蟌む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に粟床を倉曎しお保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptたたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable 
Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 掟生モデル生成されるLoRAは元→掟生の差分になりたす、ckptたたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 蚈算を行うデバむス、cuda でGPUを䜿う") + + args = parser.parse_args() + svd(args) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py new file mode 100644 index 0000000000000000000000000000000000000000..783fa1b3377a180e98605fa8792ba0a27dafd920 --- /dev/null +++ b/networks/extract_lora_from_models.py @@ -0,0 +1,189 @@ +# extract approximating LoRA by svd from two SD models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora + + +CLAMP_QUANTILE = 1 +MIN_DIFF = 1e-8 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + print(f"loading SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + print(f"loading SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + + # create LoRA network to extract weights: Use dim (rank) as alpha + if args.conv_dim is None: + kwargs = {} + else: + kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバヌゞョンが違いたすSD1.xベヌスずSD2.xベヌス " + + # get diffs + diffs = {} + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. 
Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with svd + print("calculating by svd") + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + conv2d = (len(mat.size()) == 4) + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + + rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + out_dim, in_dim = mat.size()[0:2] + + if args.device: + mat = mat.to(args.device) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + if conv2d: + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + # dist = torch.cat([U.flatten(), Vh.flatten()]) + # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # low_val = -hi_val + + # U = U.clamp(low_val, hi_val) + # Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.reshape(out_dim, rank, 1, 1) + Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) + + U = U.to("cpu").contiguous() + Vh = Vh.to("cpu").contiguous() + + lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + '.lora_up.weight'] = up_weight + lora_sd[lora_name + '.lora_down.weight'] = down_weight + lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) + + # load state dict to LoRA and save it + lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + # minimum metadata + metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + + lora_network_save.save_weights(args.save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {args.save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み蟌む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に粟床を倉曎しお保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptたたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 
掟生モデル生成されるLoRAは元→掟生の差分になりたす、ckptたたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4") + parser.add_argument("--conv_dim", type=int, default=None, + help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適甚なし") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 蚈算を行うデバむス、cuda でGPUを䜿う") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + svd(args) diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf7851180d74a1036afd01eed948bc7efb6ae9f --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,483 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import List +import numpy as np +import torch + +from library import train_util + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数ずしお扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.region = None + self.region_mask = None + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # 
linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def set_region(self, region): + self.region = region + self.region_mask = None + + def forward(self, x): + if self.region is None: + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + # regional LoRA FIXME same as additional-network extension + if x.size()[1] % 77 == 0: + # print(f"LoRA for context: {self.lora_name}") + self.region = None + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + # calculate region mask first time + if self.region_mask is None: + if len(x.size()) == 4: + h, w = x.size()[2:4] + else: + seq_len = x.size()[1] + ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) + h = int(self.region.size()[0] / ratio + 0.5) + w = seq_len // h + + r = self.region.to(x.device) + if r.dtype == torch.bfloat16: + r = r.to(torch.float) + r = r.unsqueeze(0).unsqueeze(1) + # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) + r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear") + r = r.to(x.dtype) + + if len(x.size()) == 3: + r = torch.reshape(r, (1, x.size()[1], -1)) + + self.region_mask = r + + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask + + +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + """ + block_dims = kwargs.get("block_dims") + block_alphas = None + + if block_dims is not None: + block_dims = [int(d) for d in block_dims.split(',')] + assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" + block_alphas = kwargs.get("block_alphas") + if block_alphas is None: + block_alphas = [1] * len(block_dims) + else: + block_alphas = [int(a) for a in block_alphas(',')] + assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" + + conv_block_dims = kwargs.get("conv_block_dims") + conv_block_alphas = None + + if conv_block_dims is not None: + conv_block_dims = [int(d) for d in conv_block_dims.split(',')] + assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" + conv_block_alphas = kwargs.get("conv_block_alphas") + if conv_block_alphas is None: + conv_block_alphas = [1] * len(conv_block_dims) + else: + conv_block_alphas = [int(a) for a in conv_block_alphas(',')] + assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" + """ + + network = LoRANetwork( + text_encoder, + unet, + 
multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + ) + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) + network.weights_sd = weights_sd + return network + + +class LoRANetwork(torch.nn.Module): + # is it possible to apply conv_in and conv_out? + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + conv_lora_dim=None, + conv_alpha=None, + modules_dim=None, + modules_alpha=None, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + + if modules_dim is not None: + print(f"create LoRA network from weights") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + + self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None + if self.apply_to_conv2d_3x3: + if self.conv_alpha is None: + self.conv_alpha = self.alpha + print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + # TODO get block index here + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + + if modules_dim is not None: + if lora_name not in modules_dim: + continue # no LoRA module in this weights file + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.apply_to_conv2d_3x3: + dim = self.conv_lora_dim + alpha = self.conv_alpha + else: + continue + + lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) + loras.append(lora) + return loras + + self.text_encoder_loras = create_modules( + LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") + + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + if self.weights_sd: + weights_has_text_encoder = weights_has_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + weights_has_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + weights_has_unet = True + + if apply_text_encoder is None: + apply_text_encoder = weights_has_text_encoder + else: + assert ( + apply_text_encoder == weights_has_text_encoder + ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みずText Encoderのフラグが矛盟しおいたす" + + if apply_unet is None: + apply_unet = weights_has_unet + else: + assert ( + apply_unet == weights_has_unet + ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みずU-Netのフラグが矛盟しおいたす" + else: + assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + print(f"weights are loaded: {info}") + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, dtype, device): + assert self.weights_sd is not None, "weights are not loaded" + + apply_text_encoder = apply_unet = 
False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in self.weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + print(f"weights are merged") + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + all_params = [] + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + @staticmethod + def set_regions(networks, image): + image = image.astype(np.float32) / 255.0 + for i, network in enumerate(networks[:3]): + # NOTE: consider averaging overwrapping area + region = image[:, :, i] + if region.max() == 0: + continue + region = torch.tensor(region) + network.set_region(region) + + def set_region(self, region): + for lora in self.unet_loras: + lora.set_region(region) diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py new file mode 100644 index 0000000000000000000000000000000000000000..2891798b7dd16ddf9129213927866aca2d35840a --- /dev/null +++ b/networks/lora_interrogator.py @@ -0,0 +1,128 @@ + + +from tqdm import tqdm +from library import model_util +import argparse +from transformers import CLIPTokenizer +import torch + +import library.model_util as model_util +import lora + +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ䜿う + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +def interrogate(args): + # いろいろ準備する + print(f"loading SD model: {args.sd_model}") + text_encoder, vae, unet = 
model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + print(f"loading LoRA: {args.model}") + network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) + + # text encoder向けの重みがあるかチェックする本圓はlora偎でやるのがいい + has_te_weight = False + for key in network.weights_sd.keys(): + if 'lora_te' in key: + has_te_weight = True + break + if not has_te_weight: + print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモゞュヌルがないため調査できたせん") + return + del vae + + print("loading tokenizer") + if args.v2: + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) + + text_encoder.to(DEVICE) + text_encoder.eval() + unet.to(DEVICE) + unet.eval() # U-Netは呌び出さないので䞍芁だけど + + # トヌクンをひず぀ひず぀圓たっおいく + token_id_start = 0 + token_id_end = max(tokenizer.all_special_ids) + print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + + def get_all_embeddings(text_encoder): + embs = [] + with torch.no_grad(): + for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): + batch = [] + for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): + tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] + # tokens = [tid] # こちらは結果がいたひず぀ + batch.append(tokens) + + # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] + # clip skip察応 + batch = torch.tensor(batch).to(DEVICE) + if args.clip_skip is None: + encoder_hidden_states = text_encoder(batch)[0] + else: + enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.to("cpu") + + embs.extend(encoder_hidden_states) + return torch.stack(embs) + + print("get original text encoder embeddings.") + orig_embs = get_all_embeddings(text_encoder) + + network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) + network.to(DEVICE) + network.eval() + + print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の譊告は無芖しお構いたせん以前のスクリプトで孊習されたLoRAモデルのためalphaの定矩がありたせん") + print("get text encoder embeddings with lora.") + lora_embs = get_all_embeddings(text_encoder) + + # 比べるずりあえず単玔に差分の絶察倀で + print("comparing...") + diffs = {} + for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): + diff = torch.mean(torch.abs(orig_emb - lora_emb)) + # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うたく怜出できない + diff = float(diff.detach().to('cpu').numpy()) + diffs[token_id_start + i] = diff + + diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) + + # 結果を衚瀺する + print("top 100:") + for i, (token, diff) in enumerate(diffs_sorted[:100]): + # if diff < 1e-6: + # break + string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) + print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み蟌む') + parser.add_argument("--sd_model", type=str, 
default=None, + help="Stable Diffusion model to load: ckpt or safetensors file / 読み蟌むSDのモデル、ckptたたはsafetensors") + parser.add_argument("--model", type=str, default=None, + help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptたたはsafetensors") + parser.add_argument("--batch_size", type=int, default=16, + help="batch size for processing with Text Encoder / Text Encoderで凊理するずきのバッチサむズ") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの埌ろからn番目の局の出力を甚いるnは1以䞊") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + interrogate(args) diff --git a/networks/merge_lora.py b/networks/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa8861bca3233a89640b9f02e75c8646a84171a --- /dev/null +++ b/networks/merge_lora.py @@ -0,0 +1,243 @@ +import math +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + print(f"merging...") + for key in lora_sd.keys(): + if "alpha" in key: + continue + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() + ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサむズが合いたせん。v1ずv2、たたは次元数の異なるモデルはマヌゞできたせん" + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + return merged_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数ず重みの数は合わせおください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み蟌む") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に粟床を倉曎しお保存する、省略時はマヌゞ時の粟床ず同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マヌゞの蚈算時の粟床floatを掚奚", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み蟌むモデル、ckptたたはsafetensors。省略時はLoRAモデル同士をマヌゞする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マヌゞするLoRAモデル、ckptたたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b6efce316b901787f14dc3fa9ccbf7e06cff68 --- /dev/null +++ b/networks/merge_lora_old.py @@ -0,0 +1,185 @@ + + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = 
sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[:key.index("lora_down")] + 'alpha' + + # find original module for this lora + module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + else: + # conv2d + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + merged_sd = {} + + alpha = None + dim = None + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if 'alpha' in key: + if key in merged_sd: + assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる堎合、珟時点ではマヌゞできたせん" + else: + alpha = lora_sd[key].detach().numpy() + merged_sd[key] = lora_sd[key] + else: + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサむズが合いたせん。v1ずv2、たたは次元数の異なるモデルはマヌゞできたせん" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + if "lora_down" in key: + dim = lora_sd[key].size()[0] + merged_sd[key] = lora_sd[key] * ratio + + print(f"dim (rank): {dim}, alpha: {alpha}") + if alpha is None: + alpha = dim + + return merged_sd, dim, alpha + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数ず重みの数は合わせおください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, + args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み蟌む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に粟床を倉曎しお保存する、省略時はマヌゞ時の粟床ず同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マヌゞの蚈算時の粟床floatを掚奚") + parser.add_argument("--sd_model", type=str, default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み蟌むモデル、ckptたたはsafetensors。省略時はLoRAモデル同士をマヌゞする") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マヌゞするLoRAモデル、ckptたたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/resize_lora.py b/networks/resize_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7406347d77d05b8bbf3797577a28d07d60b975 --- /dev/null +++ b/networks/resize_lora.py @@ -0,0 +1,359 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo + +import argparse +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +from library import train_util, model_util +import numpy as np + +MIN_SV = 1e-6 + +# Model save and load functions + +def load_state_dict(file_name, dtype): + if 
model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location='cpu') + metadata = None + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, model, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) + + +# Indexing functions + +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0)/original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S)-1)) + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S)-1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv/target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S)-1)) + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +# Calculate new rank + 
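+# rank_resize picks the new dim (rank) for each module and reports how much of the
+# original weight is retained. With --dynamic_method the rank is chosen from the
+# singular values S of the merged lora_up @ lora_down weight, roughly as follows:
+#   sv_ratio      - keep every singular value within a factor `dynamic_param` of
+#                   the largest one (S > S[0] / dynamic_param)
+#   sv_cumulative - keep the smallest rank whose share of sum(S) reaches dynamic_param
+#   sv_fro        - keep the smallest rank retaining a dynamic_param fraction of the
+#                   Frobenius norm (cumulative sum of S**2)
+# The result is capped at --new_rank, and alpha is rescaled so alpha/dim stays the
+# same as in the input LoRA (new_alpha = scale * new_rank).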
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method=="sv_ratio": + # Calculate new dim and alpha based off ratio + new_rank = index_sv_ratio(S, dynamic_param) + 1 + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + 1 + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + 1 + new_alpha = float(scale*new_rank) + else: + new_rank = rank + new_alpha = float(scale*new_rank) + + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale*new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale*new_rank) + + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro/s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0]/S[new_rank - 1] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha/network_dim + + if dynamic_method: + print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if 'lora_down' in key: + block_down_name = key.split(".")[0] + weight_name = key.split(".")[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' 
+ weight_name, None) + lora_alpha = lora_sd.get(block_down_name + '.alpha', None) + + weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + + if weights_loaded: + + conv2d = (len(lora_down_weight.size()) == 4) + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha/lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict['max_ratio'] + sum_retained = param_dict['sum_retained'] + fro_retained = param_dict['fro_retained'] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + + if verbose and dynamic_method: + verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str+=f"\n" + + new_alpha = param_dict['new_alpha'] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + print("resizing complete") + return o_lora_sd, network_dim, new_alpha + + +def resize(args): + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + print("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + print(f"saving model to: {args.save_to}") + 
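[Editor's note — illustration only, not part of the diff.] The block above records how the LoRA was resized so the provenance survives in the safetensors metadata that `save_to_file` writes next. A sketch of what that dict might contain after a fixed-rank resize; every value below is invented, only the key names come from the code above:

```python
# Invented example values -- only the key names are taken from the code above.
metadata_example = {
    "ss_network_dim": "8",                # str(args.new_rank) for a fixed-rank resize
    "ss_network_alpha": "8.0",            # str(new_alpha) returned by resize_lora_model
    "ss_training_comment": "dimension is resized from 128 to 8; ",
    "sshs_model_hash": "0123abcd...",     # placeholder; real value comes from
    "sshs_legacy_hash": "4567ef89...",    # train_util.precalculate_safetensors_hashes
}
# With --dynamic_method set, ss_network_dim and ss_network_alpha are both the string "Dynamic".
```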
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の粟床、未指定時はfloat") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--model", type=str, default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み蟌むLoRAモデル、ckptたたはsafetensors") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 蚈算を行うデバむス、cuda でGPUを䜿う") + parser.add_argument("--verbose", action="store_true", + help="Display verbose resizing information / rank倉曎時の詳现情報を出力する") + parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") + parser.add_argument("--dynamic_param", type=float, default=None, + help="Specify target for dynamic reduction") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + resize(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..9d17efba5e80fe0b75fcdb43e0ef6a7bb2ca5d7e --- /dev/null +++ b/networks/svd_merge_lora.py @@ -0,0 +1,192 @@ + +import math +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora + + +CLAMP_QUANTILE = 0.99 + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(state_dict, file_name) + else: + torch.save(state_dict, file_name) + + +def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): + print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + # merge + print(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if 'lora_down' not in key: + continue + + lora_module_name = key[:key.rfind(".lora_down")] + + down_weight = lora_sd[key] + network_dim = down_weight.size()[0] + + up_weight = lora_sd[lora_module_name + '.lora_up.weight'] + alpha = lora_sd.get(lora_module_name + '.alpha', network_dim) + + in_dim = down_weight.size()[1] + out_dim = up_weight.size()[0] + conv2d = len(down_weight.size()) == 4 + kernel_size = None if not conv2d else down_weight.size()[2:4] + # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + + # make original weight if not exist + if lora_module_name not in merged_sd: + weight = 
torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) + if device: + weight = weight.to(device) + else: + weight = merged_sd[lora_module_name] + + # merge to weight + if device: + up_weight = up_weight.to(device) + down_weight = down_weight.to(device) + + # W <- W + U * D + scale = (alpha / network_dim) + + if device: # and isinstance(scale, torch.Tensor): + scale = scale.to(device) + + if not conv2d: # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif kernel_size == (1, 1): + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * scale + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + merged_sd[lora_module_name] = weight + + # extract from merged weights + print("extract new lora...") + merged_lora_sd = {} + with torch.no_grad(): + for lora_module_name, mat in tqdm(list(merged_sd.items())): + conv2d = (len(mat.size()) == 4) + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = mat.size()[0:2] + + if conv2d: + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() + + module_new_rank = new_conv_rank if conv2d_3x3 else new_rank + module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :module_new_rank] + S = S[:module_new_rank] + U = U @ torch.diag(S) + + Vh = Vh[:module_new_rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.reshape(out_dim, module_new_rank, 1, 1) + Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) + + up_weight = U + down_weight = Vh + + merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank) + + return merged_lora_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数ず重みの数は合わせおください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に粟床を倉曎しお保存する、省略時はマヌゞ時の粟床ず同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マヌゞの蚈算時の粟床floatを掚奚") + parser.add_argument("--save_to", type=str, 
default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マヌゞするLoRAモデル、ckptたたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--new_conv_rank", type=int, default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankず同じ") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 蚈算を行うデバむス、cuda でGPUを䜿う") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/presets/finetune/adafactor.json b/presets/finetune/adafactor.json new file mode 100644 index 0000000000000000000000000000000000000000..0e0149dc71266611a198625230ba90260feffa15 --- /dev/null +++ b/presets/finetune/adafactor.json @@ -0,0 +1,61 @@ +{ + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "v2": false, + "v_parameterization": false, + "train_dir": "D:/dataset/paige_spiranac/ft", + "image_folder": "D:\\dataset\\paige_spiranac\\lora\\img4_g8\\16_paige_spiranac", + "output_dir": "D:/models/test", + "logging_dir": "D:/dataset/paige_spiranac/ft/logs", + "max_resolution": "512,512", + "min_bucket_reso": "256", + "max_bucket_reso": "1024", + "batch_size": "1", + "flip_aug": false, + "caption_metadata_filename": "meta_cap.json", + "latent_metadata_filename": "meta_lat.json", + "full_path": true, + "learning_rate": "1e-6", + "lr_scheduler": "adafactor", + "lr_warmup": "10", + "dataset_repeats": "10", + "train_batch_size": 4, + "epoch": "2", + "save_every_n_epochs": "1", + "mixed_precision": "bf16", + "save_precision": "fp16", + "seed": "1234", + "num_cpu_threads_per_process": 2, + "train_text_encoder": true, + "create_caption": true, + "create_buckets": false, + "save_model_as": "safetensors", + "caption_extension": ".txt", + "use_8bit_adam": false, + "xformers": true, + "clip_skip": 1, + "save_state": false, + "resume": "", + "gradient_checkpointing": false, + "gradient_accumulation_steps": 1.0, + "mem_eff_attn": false, + "shuffle_caption": true, + "output_name": "paige_spiranac_v1.5e", + "max_token_length": "150", + "max_train_epochs": "", + "max_data_loader_n_workers": "0", + "full_fp16": false, + "color_aug": false, + "model_list": "runwayml/stable-diffusion-v1-5", + "cache_latents": true, + "use_latent_files": "No", + "keep_tokens": 1, + "persistent_data_loader_workers": false, + "bucket_no_upscale": true, + "random_crop": false, + "bucket_reso_steps": 1.0, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.1, + "optimizer": "Adafactor", + "optimizer_args": "scale_parameter=True relative_step=True warmup_init=True weight_decay=2", + "noise_offset": "" +} \ No newline at end of file diff --git a/presets/finetune/lion.json b/presets/finetune/lion.json new file mode 100644 index 0000000000000000000000000000000000000000..982c8a869f807acdfec0e936eeef22e19fe093d9 --- /dev/null +++ b/presets/finetune/lion.json @@ -0,0 +1,61 @@ +{ + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "v2": false, + "v_parameterization": false, + "train_dir": "D:/dataset/paige_spiranac/ft", + "image_folder": 
"D:\\dataset\\paige_spiranac\\lora\\img4_g8\\16_paige_spiranac", + "output_dir": "D:/models/test", + "logging_dir": "D:/dataset/paige_spiranac/ft/logs", + "max_resolution": "512,512", + "min_bucket_reso": "256", + "max_bucket_reso": "1024", + "batch_size": "1", + "flip_aug": false, + "caption_metadata_filename": "meta_cap.json", + "latent_metadata_filename": "meta_lat.json", + "full_path": true, + "learning_rate": "0.0000166666666", + "lr_scheduler": "cosine", + "lr_warmup": "10", + "dataset_repeats": "10", + "train_batch_size": 4, + "epoch": "2", + "save_every_n_epochs": "1", + "mixed_precision": "bf16", + "save_precision": "fp16", + "seed": "1234", + "num_cpu_threads_per_process": 2, + "train_text_encoder": true, + "create_caption": true, + "create_buckets": false, + "save_model_as": "safetensors", + "caption_extension": ".txt", + "use_8bit_adam": false, + "xformers": true, + "clip_skip": 1, + "save_state": false, + "resume": "", + "gradient_checkpointing": false, + "gradient_accumulation_steps": 1.0, + "mem_eff_attn": false, + "shuffle_caption": true, + "output_name": "paige_spiranac_v1.5e", + "max_token_length": "150", + "max_train_epochs": "", + "max_data_loader_n_workers": "0", + "full_fp16": false, + "color_aug": false, + "model_list": "runwayml/stable-diffusion-v1-5", + "cache_latents": true, + "use_latent_files": "No", + "keep_tokens": 1, + "persistent_data_loader_workers": false, + "bucket_no_upscale": true, + "random_crop": false, + "bucket_reso_steps": 1.0, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.1, + "optimizer": "Lion", + "optimizer_args": "", + "noise_offset": "" +} \ No newline at end of file diff --git a/presets/lora/lion_optimizer.json b/presets/lora/lion_optimizer.json new file mode 100644 index 0000000000000000000000000000000000000000..77ffa4de2c40379ff485fde76c51f26e606b084b --- /dev/null +++ b/presets/lora/lion_optimizer.json @@ -0,0 +1,59 @@ +{ + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "v2": false, + "v_parameterization": false, + "logging_dir": "D:\\dataset\\marty_mcfly\\1985\\lora/log", + "train_data_dir": "D:\\dataset\\marty_mcfly\\1985\\lora\\img_gan", + "reg_data_dir": "", + "output_dir": "D:/lora/sd1.5/marty_mcfly", + "max_resolution": "512,512", + "learning_rate": "0.00003333", + "lr_scheduler": "cosine", + "lr_warmup": "0", + "train_batch_size": 8, + "epoch": "1", + "save_every_n_epochs": "1", + "mixed_precision": "bf16", + "save_precision": "fp16", + "seed": "1234", + "num_cpu_threads_per_process": 2, + "cache_latents": false, + "caption_extension": "", + "enable_bucket": true, + "gradient_checkpointing": false, + "full_fp16": false, + "no_token_padding": false, + "stop_text_encoder_training": 0, + "use_8bit_adam": false, + "xformers": true, + "save_model_as": "safetensors", + "shuffle_caption": false, + "save_state": false, + "resume": "", + "prior_loss_weight": 1.0, + "text_encoder_lr": "0.000016666", + "unet_lr": "0.00003333", + "network_dim": 128, + "lora_network_weights": "", + "color_aug": false, + "flip_aug": false, + "clip_skip": "1", + "gradient_accumulation_steps": 1.0, + "mem_eff_attn": false, + "output_name": "mrtmcfl_v2.0", + "model_list": "runwayml/stable-diffusion-v1-5", + "max_token_length": "75", + "max_train_epochs": "", + "max_data_loader_n_workers": "0", + "network_alpha": 128, + "training_comment": "", + "keep_tokens": "0", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "persistent_data_loader_workers": false, + "bucket_no_upscale": true, + "random_crop": true, + 
"bucket_reso_steps": 64.0, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.1 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..acdcfbbbd12a0a99a27e05c189d87d1e54f872d9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +accelerate==0.15.0 +albumentations==1.3.0 +altair==4.2.2 +bitsandbytes==0.35.0 +dadaptation==1.5 +diffusers[torch]==0.10.2 +easygui==0.98.3 +einops==0.6.0 +ftfy==6.1.1 +gradio==3.19.1; sys_platform != 'darwin' +gradio==3.23.0; sys_platform == 'darwin' +lion-pytorch==0.0.6 +opencv-python==4.7.0.68 +pytorch-lightning==1.9.0 +safetensors==0.2.6 +tensorboard==2.10.1 ; sys_platform != 'darwin' +tensorboard==2.12.1 ; sys_platform == 'darwin' +tk==0.1.0 +toml==0.10.2 +transformers==4.26.0 +voluptuous==0.13.1 +# for BLIP captioning +fairscale==0.4.13 +requests==2.28.2 +timm==0.6.12 +# tensorflow<2.11 +huggingface-hub==0.12.0; sys_platform != 'darwin' +huggingface-hub==0.13.0; sys_platform == 'darwin' +tensorflow==2.10.1; sys_platform != 'darwin' +# For locon support +lycoris-lora @ git+https://github.com/KohakuBlueleaf/LyCORIS.git@c3d925421209a22a60d863ffa3de0b3e7e89f047 +# lycoris_lora==0.1.4 +# for kohya_ss library +. \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa7c075d47079f3a60a560edca428c33200bf95 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from setuptools import setup, find_packages +import subprocess +import os +import sys + +# Call the create_user_files.py script +script_path = os.path.join("tools", "create_user_files.py") +subprocess.run([sys.executable, script_path]) + +setup(name="library", version="1.0.3", packages=find_packages()) diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..607f2bcb1e322e153e7355fa19a729a8a809a5b6 --- /dev/null +++ b/setup.sh @@ -0,0 +1,609 @@ +#!/usr/bin/env bash + +# This file will be the host environment setup file for all operating systems other than base Windows. + +# Set the required package versions here. +# They will be appended to the requirements.txt file in the installation directory. +TENSORFLOW_VERSION="2.12.0" +TENSORFLOW_MACOS_VERSION="2.12.0" +TENSORFLOW_METAL_VERSION="0.8.0" + +display_help() { + cat <&2" #Don't change anything higher than the maximum verbosity allowed. +done + +for v in $( #From the verbosity level one higher than requested, through the maximum; + seq $((VERBOSITY + 1)) $MAXVERBOSITY +); do + (("$v" > "2")) && eval exec "$v>/dev/null" #Redirect these to bitbucket, provided that they don't match stdout and stderr. +done + +# Example of how to use the verbosity levels. +# printf "%s\n" "This message is seen at verbosity level 1 and above." >&3 +# printf "%s\n" "This message is seen at verbosity level 2 and above." >&4 +# printf "%s\n" "This message is seen at verbosity level 3 and above." >&5 + +# Debug variable dump at max verbosity +echo "BRANCH: $BRANCH +DIR: $DIR +GIT_REPO: $GIT_REPO +INTERACTIVE: $INTERACTIVE +PUBLIC: $PUBLIC +RUNPOD: $RUNPOD +SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK +VERBOSITY: $VERBOSITY +Script directory is ${SCRIPT_DIR}." >&5 + +# This must be set after the getopts loop to account for $DIR changes. +PARENT_DIR="$(dirname "${DIR}")" +VENV_DIR="$DIR/venv" + +if [ -w "$PARENT_DIR" ] && [ ! -d "$DIR" ]; then + echo "Creating install folder ${DIR}." + mkdir "$DIR" +fi + +if [ ! -w "$DIR" ]; then + echo "We cannot write to ${DIR}." 
+ echo "Please ensure the install directory is accurate and you have the correct permissions." + exit 1 +fi + +# Shared functions +# This checks for free space on the installation drive and returns that in Gb. +size_available() { + local folder + if [ -d "$DIR" ]; then + folder="$DIR" + elif [ -d "$PARENT_DIR" ]; then + folder="$PARENT_DIR" + elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then + folder="$(echo "$DIR" | cut -d "/" -f2)" + else + echo "We are assuming a root drive install for space-checking purposes." + folder='/' + fi + + local FREESPACEINKB + FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" + echo "Detected available space in Kb: $FREESPACEINKB" >&5 + local FREESPACEINGB + FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) + echo "$FREESPACEINGB" +} + +# The expected usage is create_symlinks symlink target_file +create_symlinks() { + echo "Checking symlinks now." + # Next line checks for valid symlink + if [ -L "$1" ]; then + # Check if the linked file exists and points to the expected file + if [ -e "$1" ] && [ "$(readlink "$1")" == "$2" ]; then + echo "$(basename "$1") symlink looks fine. Skipping." + else + if [ -f "$2" ]; then + echo "Broken symlink detected. Recreating $(basename "$1")." + rm "$1" && + ln -s "$2" "$1" + else + echo "$2 does not exist. Nothing to link." + fi + fi + else + echo "Linking $(basename "$1")." + ln -s "$2" "$1" + fi +} + +install_python_dependencies() { + # Switch to local virtual env + echo "Switching to virtual Python environment." + if command -v python3 >/dev/null; then + python3 -m venv "$DIR/venv" + elif command -v python3.10 >/dev/null; then + python3.10 -m venv "$DIR/venv" + else + echo "Valid python3 or python3.10 binary not found." + echo "Cannot proceed with the python steps." + return 1 + fi + + # Activate the virtual environment + source "$DIR/venv/bin/activate" + + # Updating pip if there is one + echo "Checking for pip updates before Python operations." + pip install --upgrade pip >&3 + + echo "Installing python dependencies. This could take a few minutes as it downloads files." + echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." + case "$OSTYPE" in + "linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \ + --extra-index-url https://download.pytorch.org/whl/cu116 >&3 && + pip install -U -I --no-deps \ + https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;; + "darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \ + -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;; + "cygwin") + : + ;; + "msys") + : + ;; + esac + + if [ "$RUNPOD" = true ]; then + echo "Installing tenssort." + pip install tensorrt >&3 + fi + + # DEBUG ONLY (Update this version number to whatever PyCharm recommends) + # pip install pydevd-pycharm~=223.8836.43 + + #This will copy our requirements.txt file out and make the khoya_ss lib a dynamic location then cleanup. + local TEMP_REQUIREMENTS_FILE="$DIR/requirements_tmp_for_setup.txt" + echo "Copying $DIR/requirements.txt to $TEMP_REQUIREMENTS_FILE" >&3 + echo "Replacing the . for lib to our DIR variable in $TEMP_REQUIREMENTS_FILE." >&3 + awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >"$TEMP_REQUIREMENTS_FILE" + + # This will check if macOS is running then determine if M1+ or Intel CPU. + # It will append the appropriate packages to the requirements.txt file. 
+ # Other OSs won't be affected and the version variables are at the top of this file. + if [[ "$(uname)" == "Darwin" ]]; then + # Check if the processor is Apple Silicon (arm64) + if [[ "$(uname -m)" == "arm64" ]]; then + echo "tensorflow-macos==$TENSORFLOW_MACOS_VERSION" >>"$TEMP_REQUIREMENTS_FILE" + echo "tensorflow-metal==$TENSORFLOW_METAL_VERSION" >>"$TEMP_REQUIREMENTS_FILE" + # Check if the processor is Intel (x86_64) + elif [[ "$(uname -m)" == "x86_64" ]]; then + echo "tensorflow==$TENSORFLOW_VERSION" >>"$TEMP_REQUIREMENTS_FILE" + fi + fi + + if [ $VERBOSITY == 2 ]; then + python -m pip install --quiet --use-pep517 --upgrade -r "$TEMP_REQUIREMENTS_FILE" >&3 + else + python -m pip install --use-pep517 --upgrade -r "$TEMP_REQUIREMENTS_FILE" >&3 + fi + + echo "Removing the temp requirements file." + if [ -f "$TEMP_REQUIREMENTS_FILE" ]; then + rm -f "$TEMP_REQUIREMENTS_FILE" + fi + + if [ -n "$VIRTUAL_ENV" ]; then + if command -v deactivate >/dev/null; then + echo "Exiting Python virtual environment." + deactivate + else + echo "deactivate command not found. Could still be in the Python virtual environment." + fi + fi +} + +# Attempt to non-interactively install a default accelerate config file unless specified otherwise. +# Documentation for order of precedence locations for configuration file for automated installation: +# https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations +configure_accelerate() { + echo "Source accelerate config location: $DIR/config_files/accelerate/default_config.yaml" >&3 + if [ "$INTERACTIVE" = true ]; then + accelerate config + else + if env_var_exists HF_HOME; then + if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then + mkdir -p "$HF_HOME/accelerate/" && + echo "Target accelerate config location: $HF_HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" + fi + elif env_var_exists XDG_CACHE_HOME; then + if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then + mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && + echo "Target accelerate config location: $XDG_CACHE_HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" + fi + elif env_var_exists HOME; then + if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then + mkdir -p "$HOME/.cache/huggingface/accelerate" && + echo "Target accelerate config location: $HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" && + echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" + fi + else + echo "Could not place the accelerate configuration file. Please configure manually." + sleep 2 + accelerate config + fi + fi +} + +# Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected +check_storage_space() { + if [ "$SKIP_SPACE_CHECK" = false ]; then + if [ "$(size_available)" -lt 10 ]; then + echo "You have less than 10Gb of free space. This installation may fail." + MSGTIMEOUT=10 # In seconds + MESSAGE="Continuing in..." + echo "Press control-c to cancel the installation." 
+ for ((i = MSGTIMEOUT; i >= 0; i--)); do + printf "\r${MESSAGE} %ss. " "${i}" + sleep 1 + done + fi + fi +} + +# These are the git operations that will run to update or clone the repo +update_kohya_ss() { + if [ "$SKIP_GIT_UPDATE" = false ]; then + if command -v git >/dev/null; then + # First, we make sure there are no changes that need to be made in git, so no work is lost. + if [ "$(git -C "$DIR" status --porcelain=v1 2>/dev/null | wc -l)" -gt 0 ] && + echo "These files need to be committed or discarded: " >&4 && + git -C "$DIR" status >&4; then + echo "There are changes that need to be committed or discarded in the repo in $DIR." + echo "Commit those changes or run this script with -n to skip git operations entirely." + exit 1 + fi + + echo "Attempting to clone $GIT_REPO." + if [ ! -d "$DIR/.git" ]; then + echo "Cloning and switching to $GIT_REPO:$BRANCH" >&4 + git -C "$PARENT_DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 + git -C "$DIR" switch "$BRANCH" >&4 + else + echo "git repo detected. Attempting to update repository instead." + echo "Updating: $GIT_REPO" + git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3 + if ! git -C "$DIR" switch "$BRANCH" >&4; then + echo "Branch $BRANCH did not exist. Creating it." >&4 + git -C "$DIR" switch -c "$BRANCH" >&4 + fi + fi + else + echo "You need to install git." + echo "Rerun this after installing git or run this script with -n to skip the git operations." + fi + else + echo "Skipping git operations." + fi +} + +# Start OS-specific detection and work +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + # Check if root or sudo + root=false + if [ "$EUID" = 0 ]; then + root=true + elif command -v id >/dev/null && [ "$(id -u)" = 0 ]; then + root=true + elif [ "$UID" = 0 ]; then + root=true + fi + + get_distro_name() { + local line + if [ -f /etc/os-release ]; then + # We search for the line starting with ID= + # Then we remove the ID= prefix to get the name itself + line="$(grep -Ei '^ID=' /etc/os-release)" + echo "Raw detected os-release distro line: $line" >&5 + line=${line##*=} + echo "$line" + return 0 + elif command -v python >/dev/null; then + line="$(python -mplatform)" + echo "$line" + return 0 + elif command -v python3 >/dev/null; then + line="$(python3 -mplatform)" + echo "$line" + return 0 + else + line="None" + echo "$line" + return 1 + fi + } + + # We search for the line starting with ID_LIKE= + # Then we remove the ID_LIKE= prefix to get the name itself + # This is the "type" of distro. For example, Ubuntu returns "debian". + get_distro_family() { + local line + if [ -f /etc/os-release ]; then + if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then + line="$(grep -Ei '^ID_LIKE=' /etc/os-release)" + echo "Raw detected os-release distro family line: $line" >&5 + line=${line##*=} + echo "$line" + return 0 + else + line="None" + echo "$line" + return 1 + fi + else + line="None" + echo "$line" + return 1 + fi + } + + check_storage_space + update_kohya_ss + + distro=get_distro_name + family=get_distro_family + echo "Raw detected distro string: $distro" >&4 + echo "Raw detected distro family string: $family" >&4 + + echo "Installing Python TK if not found on the system." + if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then + echo "Ubuntu detected." + if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then + if [ "$root" = true ]; then + apt update -y >&3 && apt install -y python3-tk >&3 + else + echo "This script needs to be run as root or via sudo to install packages." 
+ exit 1 + fi + else + echo "Python TK found! Skipping install!" + fi + elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then + echo "Redhat or Redhat base detected." + if ! rpm -qa | grep -qi python3-tkinter; then + if [ "$root" = true ]; then + dnf install python3-tkinter -y >&3 + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then + echo "Arch Linux or Arch base detected." + if ! pacman -Qi tk >/dev/null; then + if [ "$root" = true ]; then + pacman --noconfirm -S tk >&3 + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then + echo "OpenSUSE detected." + if ! rpm -qa | grep -qi python-tk; then + if [ "$root" = true ]; then + zypper install -y python-tk >&3 + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + elif [ "$distro" = "None" ] || [ "$family" = "None" ]; then + if [ "$distro" = "None" ]; then + echo "We could not detect your distribution of Linux. Please file a bug report on github with the contents of your /etc/os-release file." + fi + + if [ "$family" = "None" ]; then + echo "We could not detect the family of your Linux distribution. Please file a bug report on github with the contents of your /etc/os-release file." + fi + fi + + install_python_dependencies + + # We need just a little bit more setup for non-interactive environments + if [ "$RUNPOD" = true ]; then + # Symlink paths + libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" + libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" + libcudart_symlink="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0" + + #Target file paths + libnvinfer_plugin_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" + libnvinfer_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8" + libcudart_target="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12" + + echo "Checking symlinks now." + create_symlinks "$libnvinfer_plugin_symlink" "$libnvinfer_plugin_target" + create_symlinks "$libnvinfer_symlink" "$libnvinfer_target" + create_symlinks "$libcudart_symlink" "$libcudart_target" + + if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" + else + echo "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/ not found; not linking library." + fi + + if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" + else + echo "${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ not found; not linking library." + fi + + configure_accelerate + + # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. + if command -v bash >/dev/null; then + if [ "$PUBLIC" = false ]; then + bash "$DIR"/gui.sh + else + bash "$DIR"/gui.sh --share + fi + else + # This shouldn't happen, but we're going to try to help. + if [ "$PUBLIC" = false ]; then + sh "$DIR"/gui.sh + else + sh "$DIR"/gui.sh --share + fi + fi + fi + + echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." 
+ echo "Please note if you'd like to expose your public server you need to run ./gui.sh --share" +elif [[ "$OSTYPE" == "darwin"* ]]; then + # The initial setup script to prep the environment on macOS + # xformers has been omitted as that is for Nvidia GPUs only + + if ! command -v brew >/dev/null; then + echo "Please install homebrew first. This is a requirement for the remaining setup." + echo "You can find that here: https://brew.sh" + #shellcheck disable=SC2016 + echo 'The "brew" command should be in $PATH to be detected.' + exit 1 + fi + + check_storage_space + + # Install base python packages + echo "Installing Python 3.10 if not found." + if ! brew ls --versions python@3.10 >/dev/null; then + echo "Installing Python 3.10." + brew install python@3.10 >&3 + else + echo "Python 3.10 found!" + fi + echo "Installing Python-TK 3.10 if not found." + if ! brew ls --versions python-tk@3.10 >/dev/null; then + echo "Installing Python TK 3.10." + brew install python-tk@3.10 >&3 + else + echo "Python Tkinter 3.10 found!" + fi + + update_kohya_ss + + if ! install_python_dependencies; then + echo "You may need to install Python. The command for this is brew install python@3.10." + fi + + configure_accelerate + echo -e "Setup finished! Run ./gui.sh to start." +elif [[ "$OSTYPE" == "cygwin" ]]; then + # Cygwin is a standalone suite of Linux utilities on Windows + echo "This hasn't been validated on cygwin yet." +elif [[ "$OSTYPE" == "msys" ]]; then + # MinGW has the msys environment which is a standalone suite of Linux utilities on Windows + # "git bash" on Windows may also be detected as msys. + echo "This hasn't been validated in msys (mingw) on Windows yet." +fi diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..754673f12e5ee658c432ddf64a335f4923e58925 --- /dev/null +++ b/style.css @@ -0,0 +1,21 @@ +#open_folder_small{ + height: auto; + min-width: auto; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} + +#open_folder{ + height: auto; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} + +#number_input{ + min-width: min-content; + flex-grow: 0.3; + padding-left: 0.75em; + padding-right: 0.75em; +} \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..da5467d20261b01e0cc30c60aa45144ff794b973 --- /dev/null +++ b/textual_inversion_gui.py @@ -0,0 +1,1010 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captionning of images to utilities + +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +import argparse +from library.common_gui import ( + get_folder_path, + remove_doublequote, + get_file_path, + get_any_file_path, + get_saveasfile_path, + color_aug_changed, + save_inference_file, + gradio_advanced_training, + run_cmd_advanced_training, + run_cmd_training, + gradio_training, + gradio_config, + gradio_source_model, + # set_legacy_8bitadam, + update_my_data, + check_if_model_exist, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = 
'\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💟 +document_symbol = '\U0001F4C4' # 📄 + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get('label') == 'True' else False + + if save_as_bool: + print('Save as...') + file_path = get_saveasfile_path(file_path) + else: + print('Save...') + if file_path == None or file_path == '': + file_path = get_saveasfile_path(file_path) + + # print(file_path) + + if file_path == None or file_path == '': + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Return the values of the variables as a dictionary + variables = { + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] + } + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + # Save the data to the selected file + with open(file_path, 'w') as file: + json.dump(variables, file, indent=2) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + 
caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == '' and not file_path == None: + # load variables from JSON file + with open(file_path, 'r') as f: + my_data = json.load(f) + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['ask_for_file', 'file_path']: + values.append(my_data.get(key, value)) + return tuple(values) + + +def train_model( + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. 
Yes, it is unused here but required given the common list used + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, +): + if pretrained_model_name_or_path == '': + msgbox('Source model information is missing') + return + + if train_data_dir == '': + msgbox('Image folder path is missing') + return + + if not os.path.exists(train_data_dir): + msgbox('Image folder does not exist') + return + + if reg_data_dir != '': + if not os.path.exists(reg_data_dir): + msgbox('Regularisation folder does not exist') + return + + if output_dir == '': + msgbox('Output folder path is missing') + return + + if token_string == '': + msgbox('Token string is missing') + return + + if init_word == '': + msgbox('Init word is missing') + return + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if check_if_model_exist(output_name, output_dir, save_model_as): + return + + # Get a list of all subfolders in train_data_dir + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) + ] + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + # Extract the number of repeats from the folder name + repeats = int(folder.split('_')[0]) + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) + ) + if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + ) + + # Calculate the total number of steps for this folder + steps = repeats * num_images + total_steps += steps + + # Print the result + print(f'Folder {folder}: {steps} steps') + + # Print the result + # print(f"{total_steps} total steps") + + if reg_data_dir == '': + reg_factor = 1 + else: + print( + 'Regularisation images are used... Will double the number of steps required...' 
+ ) + reg_factor = 2 + + # calculate max_train_steps + if max_train_steps == '': + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + * int(epoch) + * int(reg_factor) + ) + ) + else: + max_train_steps = int(max_train_steps) + + print(f'max_train_steps = {max_train_steps}') + + # calculate stop encoder training + if stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + print(f'stop_text_encoder_training = {stop_text_encoder_training}') + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + print(f'lr_warmup_steps = {lr_warmup_steps}') + + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"' + if v2: + run_cmd += ' --v2' + if v_parameterization: + run_cmd += ' --v_parameterization' + if enable_bucket: + run_cmd += ' --enable_bucket' + if no_token_padding: + run_cmd += ' --no_token_padding' + run_cmd += ( + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) + run_cmd += f' --train_data_dir="{train_data_dir}"' + if len(reg_data_dir): + run_cmd += f' --reg_data_dir="{reg_data_dir}"' + run_cmd += f' --resolution={max_resolution}' + run_cmd += f' --output_dir="{output_dir}"' + run_cmd += f' --logging_dir="{logging_dir}"' + if not stop_text_encoder_training == 0: + run_cmd += ( + f' --stop_text_encoder_training={stop_text_encoder_training}' + ) + if not save_model_as == 'same as source model': + run_cmd += f' --save_model_as={save_model_as}' + # if not resume == '': + # run_cmd += f' --resume={resume}' + if not float(prior_loss_weight) == 1.0: + run_cmd += f' --prior_loss_weight={prior_loss_weight}' + if not vae == '': + run_cmd += f' --vae="{vae}"' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' + if int(max_token_length) > 75: + run_cmd += f' --max_token_length={max_token_length}' + if not max_train_epochs == '': + run_cmd += f' --max_train_epochs="{max_train_epochs}"' + if not max_data_loader_n_workers == '': + run_cmd += ( + f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' + ) + if int(gradient_accumulation_steps) > 1: + run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + + run_cmd += run_cmd_training( + learning_rate=learning_rate, + lr_scheduler=lr_scheduler, + lr_warmup_steps=lr_warmup_steps, + train_batch_size=train_batch_size, + max_train_steps=max_train_steps, + save_every_n_epochs=save_every_n_epochs, + mixed_precision=mixed_precision, + save_precision=save_precision, + seed=seed, + caption_extension=caption_extension, + cache_latents=cache_latents, + optimizer=optimizer, + optimizer_args=optimizer_args, + ) + + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + # use_8bit_adam=use_8bit_adam, + keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, + bucket_no_upscale=bucket_no_upscale, + random_crop=random_crop, + bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + 
caption_dropout_rate=caption_dropout_rate, + noise_offset=noise_offset, + additional_parameters=additional_parameters, + vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, + ) + run_cmd += f' --token_string="{token_string}"' + run_cmd += f' --init_word="{init_word}"' + run_cmd += f' --num_vectors_per_token={num_vectors_per_token}' + if not weights == '': + run_cmd += f' --weights="{weights}"' + if template == 'object template': + run_cmd += f' --use_object_template' + elif template == 'style template': + run_cmd += f' --use_style_template' + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # check if output_dir/last is a folder... therefore it is a diffuser model + last_dir = pathlib.Path(f'{output_dir}/{output_name}') + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def ti_tab( + train_data_dir=gr.Textbox(), + reg_data_dir=gr.Textbox(), + output_dir=gr.Textbox(), + logging_dir=gr.Textbox(), +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + gr.Markdown('Train a TI using kohya textual inversion python code...') + ( + button_open_config, + button_save_config, + button_save_as_config, + config_file_name, + button_load_config, + ) = gradio_config() + + ( + pretrained_model_name_or_path, + v2, + v_parameterization, + save_model_as, + model_list, + ) = gradio_source_model( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ] + ) + + with gr.Tab('Folders'): + with gr.Row(): + train_data_dir = gr.Textbox( + label='Image folder', + placeholder='Folder where the training folders containing the images are located', + ) + train_data_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + train_data_dir_input_folder.click( + get_folder_path, + outputs=train_data_dir, + show_progress=False, + ) + reg_data_dir = gr.Textbox( + label='Regularisation folder', + placeholder='(Optional) Folder where where the regularization folders containing the images are located', + ) + reg_data_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + reg_data_dir_input_folder.click( + get_folder_path, + outputs=reg_data_dir, + show_progress=False, + ) + with gr.Row(): + output_dir = gr.Textbox( + label='Model output folder', + placeholder='Folder to output trained model', + ) + output_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + output_dir_input_folder.click( + get_folder_path, + outputs=output_dir, + show_progress=False, + ) + logging_dir = gr.Textbox( + label='Logging folder', + placeholder='Optional: enable logging and output TensorBoard log to this folder', + ) + logging_dir_input_folder = gr.Button( + '📂', elem_id='open_folder_small' + ) + logging_dir_input_folder.click( + get_folder_path, + outputs=logging_dir, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) + train_data_dir.change( + remove_doublequote, + inputs=[train_data_dir], + outputs=[train_data_dir], + ) + reg_data_dir.change( + remove_doublequote, + inputs=[reg_data_dir], + outputs=[reg_data_dir], + ) + output_dir.change( + remove_doublequote, + inputs=[output_dir], + 
outputs=[output_dir], + ) + logging_dir.change( + remove_doublequote, + inputs=[logging_dir], + outputs=[logging_dir], + ) + with gr.Tab('Training parameters'): + with gr.Row(): + weights = gr.Textbox( + label='Resume TI training', + placeholder='(Optional) Path to existing TI embeding file to keep training', + ) + weights_file_input = gr.Button('📂', elem_id='open_folder_small') + weights_file_input.click( + get_file_path, + outputs=weights, + show_progress=False, + ) + with gr.Row(): + token_string = gr.Textbox( + label='Token string', + placeholder='eg: cat', + ) + init_word = gr.Textbox( + label='Init word', + value='*', + ) + num_vectors_per_token = gr.Slider( + minimum=1, + maximum=75, + value=1, + step=1, + label='Vectors', + ) + max_train_steps = gr.Textbox( + label='Max train steps', + placeholder='(Optional) Maximum number of steps', + ) + template = gr.Dropdown( + label='Template', + choices=[ + 'caption', + 'object template', + 'style template', + ], + value='caption', + ) + ( + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + num_cpu_threads_per_process, + seed, + caption_extension, + cache_latents, + optimizer, + optimizer_args, + ) = gradio_training( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + with gr.Row(): + max_resolution = gr.Textbox( + label='Max resolution', + value='512,512', + placeholder='512,512', + ) + stop_text_encoder_training = gr.Slider( + minimum=0, + maximum=100, + value=0, + step=1, + label='Stop text encoder training', + ) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True) + with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + no_token_padding = gr.Checkbox( + label='No token padding', value=False + ) + gradient_accumulation_steps = gr.Number( + label='Gradient accumulate steps', value='1' + ) + with gr.Row(): + prior_loss_weight = gr.Number( + label='Prior loss weight', value=1.0 + ) + vae = gr.Textbox( + label='VAE', + placeholder='(Optiona) path to checkpoint of vae to replace for training', + ) + vae_button = gr.Button('📂', elem_id='open_folder_small') + vae_button.click( + get_any_file_path, + outputs=vae, + show_progress=False, + ) + ( + # use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + noise_offset, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latents], + ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' 
+ ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=train_data_dir, + reg_data_dir_input=reg_data_dir, + output_dir_input=output_dir, + logging_dir_input=logging_dir, + ) + + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + ] + + button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=settings_list, + show_progress=False, + ) + + return ( + train_data_dir, + reg_data_dir, + output_dir, + logging_dir, + ) + + +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Dreambooth TI'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = ti_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs = {} + if not kwargs.get('username', None) == '': + launch_kwargs['auth'] = ( + kwargs.get('username', None), + kwargs.get('password', None), + ) + if kwargs.get('server_port', 0) > 0: + 
launch_kwargs['server_port'] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + + args = parser.parse_args() + + UI( + username=args.username, + password=args.password, + inbrowser=args.inbrowser, + server_port=args.server_port, + ) diff --git a/tools/canny.py b/tools/canny.py new file mode 100644 index 0000000000000000000000000000000000000000..5e0806898786e5251d2e715e33896bb4958a35e8 --- /dev/null +++ b/tools/canny.py @@ -0,0 +1,30 @@ +import argparse +import cv2 + + +def canny(args): + img = cv2.imread(args.input) + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + canny_img = cv2.Canny(img, args.thres1, args.thres2) + # canny_img = 255 - canny_img + + cv2.imwrite(args.output, canny_img) + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, default=None, help="input path") + parser.add_argument("--output", type=str, default=None, help="output path") + parser.add_argument("--thres1", type=int, default=32, help="thres1") + parser.add_argument("--thres2", type=int, default=224, help="thres2") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + canny(args) diff --git a/tools/caption.py b/tools/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9dd53a966786cadc372a2577c4102aa231aa8b --- /dev/null +++ b/tools/caption.py @@ -0,0 +1,69 @@ +# This script will create the caption text files in the specified folder using the specified file pattern and caption text. 
+#
+# eg: python caption.py D:\some\folder\location "*.png, *.jpg, *.webp" "some caption text"
+
+import argparse
+# import glob
+# import os
+from pathlib import Path
+
+def create_caption_files(image_folder: str, file_pattern: str, caption_text: str, caption_file_ext: str, overwrite: bool):
+    # Split the file patterns string and strip whitespace from each pattern
+    patterns = [pattern.strip() for pattern in file_pattern.split(",")]
+
+    # Create a Path object for the image folder
+    folder = Path(image_folder)
+
+    # Iterate over the file patterns
+    for pattern in patterns:
+        # Use the glob method to match the file patterns
+        files = folder.glob(pattern)
+
+        # Iterate over the matched files
+        for file in files:
+            # Check if a text file with the same name as the current file exists in the folder
+            txt_file = file.with_suffix(caption_file_ext)
+            if not txt_file.exists() or overwrite:
+                # Create a text file with the caption text in the folder, if it does not already exist
+                # or if the overwrite argument is True
+                with open(txt_file, "w") as f:
+                    f.write(caption_text)
+
+def main():
+    # Define command-line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("image_folder", type=str, help="the folder where the image files are located")
+    parser.add_argument("--file_pattern", type=str, default="*.png, *.jpg, *.jpeg, *.webp", help="the pattern to match the image file names")
+    parser.add_argument("--caption_file_ext", type=str, default=".caption", help="the caption file extension")
+    parser.add_argument("--overwrite", action="store_true", default=False, help="whether to overwrite existing caption files")
+
+    # Create a mutually exclusive group for the caption_text and caption_file arguments;
+    # one of the two is required, otherwise caption_text would be undefined below
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--caption_text", type=str, help="the text to include in the caption files")
+    group.add_argument("--caption_file", type=argparse.FileType("r"), help="the file containing the text to include in the caption files")
+
+    # Parse the command-line arguments
+    args = parser.parse_args()
+    image_folder = args.image_folder
+    file_pattern = args.file_pattern
+    caption_file_ext = args.caption_file_ext
+    overwrite = args.overwrite
+
+    # Get the caption text from either the caption_text or caption_file argument
+    if args.caption_text:
+        caption_text = args.caption_text
+    elif args.caption_file:
+        caption_text = args.caption_file.read()
+
+    # Create a Path object for the image folder
+    folder = Path(image_folder)
+
+    # Check if the image folder exists and is a directory
+    if not folder.is_dir():
+        raise ValueError(f"{image_folder} is not a valid directory.")
+
+    # Create the caption files
+    create_caption_files(image_folder, file_pattern, caption_text, caption_file_ext, overwrite)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tools/convert_diffusers20_original_sd.md b/tools/convert_diffusers20_original_sd.md
new file mode 100644
index 0000000000000000000000000000000000000000..4763e5fd563e603a31ac5b3b9e085d865be74d32
--- /dev/null
+++ b/tools/convert_diffusers20_original_sd.md
@@ -0,0 +1,46 @@
+# How to use
+
+## Diffusers to Stable Diffusion .ckpt conversion
+
+Specify the folder of the source model and the destination .ckpt file as follows (actually written on one line). The v1/v2 version is automatically determined.
+
+```
+python convert_diffusers20_original_sd.py ..\models\diffusers_model
+    ..\models\sd.ckpt
+```
+
+Note that the v2 Diffusers Text Encoder has only 22 layers. Converting it to Stable Diffusion as-is would leave the weights one layer short, so the weights of layer 22 are copied as the 23rd layer. The 23rd layer's weights are not used during image generation, so this has no effect. Dummy weights are likewise added for text_projection and logit_scale (they do not appear to be used for image generation either).
+
+## Stable Diffusion .ckpt to Diffusers
+
+Enter the following:
+
+```
+python convert_diffusers20_original_sd.py ..\models\sd.ckpt
+    ..\models\diffusers_model
+    --v2 --reference_model stabilityai/stable-diffusion-2
+```
+
+Specify the .ckpt file and the destination folder as arguments.
+The model version cannot be detected automatically from a .ckpt file, so pass either the `--v1` or the `--v2` option to match the model.
+
+Also, because a `.ckpt` file does not contain scheduler and tokenizer information, these must be copied from an existing Diffusers model. Specify one with `--reference_model`; you can pass a HuggingFace id or a local model directory.
+
+If you don't have a local model, "stabilityai/stable-diffusion-2" or "stabilityai/stable-diffusion-2-base" works for v2.
+For v1.4/1.5, "CompVis/stable-diffusion-v1-4" is fine (v1.4 and v1.5 appear to use the same scheduler and tokenizer).
+
+## What can you do?
+
+`--fp16 / --bf16 / --float`
+
+Specifies the data type used when saving a checkpoint. Only `--fp16` also applies when loading a Diffusers model.
+
+`--epoch / --global_step`
+
+Writes the given values for epoch and global_step into the saved checkpoint. If not specified, both default to 0.
+
+## Conclusion
+
+If the Diffusers format is awkward for your inference environment, I hope this script helps a little.
+
+(Converting only the data format from one checkpoint to another checkpoint is also possible, although it has not been tested.)
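+
+As a concrete, untested sketch of that checkpoint-to-checkpoint case, the following command reads a v1 checkpoint and saves it again with fp16 weights; the file names are placeholders, and `--v1` or `--v2` must be passed to match the source model since it cannot be auto-detected from a .ckpt file:
+
+```
+python convert_diffusers20_original_sd.py ..\models\sd.ckpt
+    ..\models\sd-fp16.ckpt
+    --v1 --fp16
+```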
 \ No newline at end of file diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7cc1c58b5b0b78f69e2a719beeb839bd13d087 --- /dev/null +++ b/tools/convert_diffusers20_original_sd.py @@ -0,0 +1,94 @@ +# convert Diffusers v1.x/v2.0 model to original Stable Diffusion + +import argparse +import os +import torch +from diffusers import StableDiffusionPipeline + +import library.model_util as model_util + + +def convert(args): + # 匕数を確認する + load_dtype = torch.float16 if args.fp16 else None + + save_dtype = None + if args.fp16: + save_dtype = torch.float16 + elif args.bf16: + save_dtype = torch.bfloat16 + elif args.float: + save_dtype = torch.float + + is_load_ckpt = os.path.isfile(args.model_to_load) + is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 + + assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み蟌みにはv1/v2指定が必芁です" + assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers圢匏での保存には参照モデルが必芁です" + + # モデルを読み蟌む + msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) + print(f"loading {msg}: {args.model_to_load}") + + if is_load_ckpt: + v2_model = args.v2 + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) + else: + pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + + if args.v1 == args.v2: + # 自動刀定する + v2_model = unet.config.cross_attention_dim == 1024 + print("checking model version: model is " + ('v2' if v2_model else 'v1')) + else: + v2_model = not args.v1 + + # 倉換しお保存する + msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" + print(f"converting and saving as {msg}: {args.model_to_save}") + + if is_save_ckpt: + original_model = args.model_to_load if is_load_ckpt else None + key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, + original_model, args.epoch, args.global_step, save_dtype, vae) + print(f"model saved. 
total converted state_dict keys: {key_count}") + else: + print(f"copy scheduler/tokenizer config from: {args.reference_model}") + model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) + print(f"model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v1", action='store_true', + help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み蟌む') + parser.add_argument("--v2", action='store_true', + help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み蟌む') + parser.add_argument("--fp16", action='store_true', + help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16圢匏で読み蟌みDiffusers圢匏のみ察応、保存するcheckpointのみ察応') + parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16圢匏で保存するcheckpointのみ察応') + parser.add_argument("--float", action='store_true', + help='save as float (checkpoint only) / float(float32)圢匏で保存するcheckpointのみ察応') + parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに蚘録するepoch数の倀') + parser.add_argument("--global_step", type=int, default=0, + help='global_step to write to checkpoint / checkpointに蚘録するglobal_stepの倀') + parser.add_argument("--reference_model", type=str, default=None, + help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピヌ元のDiffusersモデル、Diffusers圢匏で保存するずきに必芁") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors圢匏で保存するcheckpointは拡匵子で自動刀定") + + parser.add_argument("model_to_load", type=str, default=None, + help="model to load: checkpoint file or Diffusers model's directory / 読み蟌むモデル、checkpointかDiffusers圢匏モデルのディレクトリ") + parser.add_argument("model_to_save", type=str, default=None, + help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 倉換埌のモデル、拡匵子がある堎合はcheckpoint、ない堎合はDiffusesモデルずしお保存") + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + convert(args) diff --git a/tools/convert_images_to_hq_jpg.py b/tools/convert_images_to_hq_jpg.py new file mode 100644 index 0000000000000000000000000000000000000000..efc40477892194a49ae58f116c6a3db9c46a9cd2 --- /dev/null +++ b/tools/convert_images_to_hq_jpg.py @@ -0,0 +1,57 @@ +import argparse +import glob +import os +from pathlib import Path +from PIL import Image + + +def main(): + # Define the command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("directory", type=str, + help="the directory containing the images to be converted") + parser.add_argument("--in_ext", type=str, default="webp", + help="the input file extension") + parser.add_argument("--quality", type=int, default=95, + help="the JPEG quality (0-100)") + parser.add_argument("--delete_originals", action="store_true", + help="whether to delete the original files after conversion") + + # Parse the command-line arguments + args = parser.parse_args() + directory = args.directory + in_ext = args.in_ext + out_ext = "jpg" + quality = args.quality + delete_originals = args.delete_originals + + # Create the file pattern string using the input file extension + file_pattern = f"*.{in_ext}" + + # Get the list of files in the directory that match the 
file pattern + files = glob.glob(os.path.join(directory, file_pattern)) + + # Iterate over the list of files + for file in files: + # Open the image file + img = Image.open(file) + + # Create a new file path with the output file extension + new_path = Path(file).with_suffix(f".{out_ext}") + + # Check if the output file already exists + if new_path.exists(): + # Skip the conversion if the output file already exists + print(f"Skipping {file} because {new_path} already exists") + continue + + # Save the image to the new file as high-quality JPEG + img.save(new_path, quality=quality, optimize=True) + + # Optionally, delete the original file + if delete_originals: + os.remove(file) + + +if __name__ == "__main__": + main() diff --git a/tools/convert_images_to_webp.py b/tools/convert_images_to_webp.py new file mode 100644 index 0000000000000000000000000000000000000000..4833459e107d016db39092f63a9fb14a3b8935bf --- /dev/null +++ b/tools/convert_images_to_webp.py @@ -0,0 +1,57 @@ +import argparse +import glob +import os +from pathlib import Path +from PIL import Image + + +def main(): + # Define the command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("directory", type=str, + help="the directory containing the images to be converted") + parser.add_argument("--in_ext", type=str, default="webp", + help="the input file extension") + parser.add_argument("--delete_originals", action="store_true", + help="whether to delete the original files after conversion") + + # Parse the command-line arguments + args = parser.parse_args() + directory = args.directory + in_ext = args.in_ext + delete_originals = args.delete_originals + + # Set the output file extension to .webp + out_ext = "webp" + + # Create the file pattern string using the input file extension + file_pattern = f"*.{in_ext}" + + # Get the list of files in the directory that match the file pattern + files = glob.glob(os.path.join(directory, file_pattern)) + + # Iterate over the list of files + for file in files: + # Open the image file + img = Image.open(file) + + # Create a new file path with the output file extension + new_path = Path(file).with_suffix(f".{out_ext}") + print(new_path) + + # Check if the output file already exists + if new_path.exists(): + # Skip the conversion if the output file already exists + print(f"Skipping {file} because {new_path} already exists") + continue + + # Save the image to the new file as lossless + img.save(new_path, lossless=True) + + # Optionally, delete the original file + if delete_originals: + os.remove(file) + + +if __name__ == "__main__": + main() diff --git a/tools/create_user_files.py b/tools/create_user_files.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3d00302dbe407cd1bc5b681cce34253872cd8a --- /dev/null +++ b/tools/create_user_files.py @@ -0,0 +1,37 @@ +import os + +bat_content = r'''@echo off +REM Example of how to start the GUI with custom arguments. In this case how to auto launch the browser: +REM call gui.bat --inbrowser +REM +REM You can add many arguments on the same line +REM +call gui.bat --inbrowser +''' + +ps1_content = r'''# Example of how to start the GUI with custom arguments. 
In this case how to auto launch the browser: +# .\gui.ps1 --inbrowser +# +# You can add many arguments on the same line +# +# & .\gui.ps1 --inbrowser --server_port 2345 + +& .\gui.ps1 --inbrowser +''' + +bat_filename = 'gui-user.bat' +ps1_filename = 'gui-user.ps1' + +if not os.path.exists(bat_filename): + with open(bat_filename, 'w') as bat_file: + bat_file.write(bat_content) + print(f"File created: {bat_filename}") +else: + print(f"File already exists: {bat_filename}") + +if not os.path.exists(ps1_filename): + with open(ps1_filename, 'w') as ps1_file: + ps1_file.write(ps1_content) + print(f"File created: {ps1_filename}") +else: + print(f"File already exists: {ps1_filename}") diff --git a/tools/crop_images_to_n_buckets.py b/tools/crop_images_to_n_buckets.py new file mode 100644 index 0000000000000000000000000000000000000000..688b42b5940460cffacf30c669a639d34ba6eea9 --- /dev/null +++ b/tools/crop_images_to_n_buckets.py @@ -0,0 +1,208 @@ +# This code sorts a collection of images in a given directory by their aspect ratio, groups +# them into batches of a given size, crops each image in a batch to the average aspect ratio +# of that batch, and saves the cropped images in a specified directory. The user provides +# the paths to the input directory and the output directory, as well as the desired batch +# size. The program drops any images that do not fit exactly into the batches. + +import os +import cv2 +import argparse +import shutil + +def aspect_ratio(img_path): + """Return aspect ratio of an image""" + image = cv2.imread(img_path) + height, width = image.shape[:2] + aspect_ratio = float(width) / float(height) + return aspect_ratio + +def sort_images_by_aspect_ratio(path): + """Sort all images in a folder by aspect ratio""" + images = [] + for filename in os.listdir(path): + if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png") or filename.endswith(".webp"): + img_path = os.path.join(path, filename) + images.append((img_path, aspect_ratio(img_path))) + # sort the list of tuples based on the aspect ratio + sorted_images = sorted(images, key=lambda x: x[1]) + return sorted_images + +def create_groups(sorted_images, n_groups): + """Create n groups from sorted list of images""" + n = len(sorted_images) + size = n // n_groups + groups = [sorted_images[i * size : (i + 1) * size] for i in range(n_groups - 1)] + groups.append(sorted_images[(n_groups - 1) * size:]) + return groups + +def average_aspect_ratio(group): + """Calculate average aspect ratio for a group""" + aspect_ratios = [aspect_ratio for _, aspect_ratio in group] + avg_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios) + print(f"Average aspect ratio for group: {avg_aspect_ratio}") + return avg_aspect_ratio + +def center_crop_image(image, target_aspect_ratio): + """Crop the input image to the target aspect ratio. + + The function calculates the crop region for the input image based on its current aspect ratio and the target aspect ratio. + + Args: + image: A numpy array representing the input image. + target_aspect_ratio: A float representing the target aspect ratio. + + Returns: + A numpy array representing the cropped image. 
+ + """ + height, width = image.shape[:2] + current_aspect_ratio = float(width) / float(height) + + if current_aspect_ratio == target_aspect_ratio: + return image + + if current_aspect_ratio > target_aspect_ratio: + new_width = int(target_aspect_ratio * height) + x_start = (width - new_width) // 2 + cropped_image = image[:, x_start:x_start+new_width] + else: + new_height = int(width / target_aspect_ratio) + y_start = (height - new_height) // 2 + cropped_image = image[y_start:y_start+new_height, :] + + return cropped_image + +def copy_related_files(img_path, save_path): + """ + Copy all files in the same directory as the input image that have the same base name as the input image to the + output directory with the corresponding new filename. + :param img_path: Path to the input image. + :param save_path: Path to the output image. + """ + # Get the base filename and directory + img_dir, img_basename = os.path.split(img_path) + img_base, img_ext = os.path.splitext(img_basename) + + save_dir, save_basename = os.path.split(save_path) + save_base, save_ext = os.path.splitext(save_basename) + + # Create the output directory if it does not exist + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + # Loop over all files in the same directory as the input image + try: + for filename in os.listdir(img_dir): + # Skip files with the same name as the input image + if filename == img_basename: + continue + + # Check if the file has the same base name as the input image + file_base, file_ext = os.path.splitext(filename) + if file_base == img_base: + # Build the new filename and copy the file + new_filename = os.path.join(save_dir, f"{save_base}{file_ext}") + shutil.copy2(os.path.join(img_dir, filename), new_filename) + except OSError as e: + print(f"Error: {e}") # Handle errors from os.listdir() + +def save_resized_cropped_images(group, folder_name, group_number, avg_aspect_ratio, use_original_name=False): + """Crop and resize all images in the input group to the smallest resolution, and save them to a folder. + + Args: + group: A list of tuples, where each tuple contains the path to an image and its aspect ratio. + folder_name: A string representing the name of the folder to save the images to. + group_number: An integer representing the group number. + avg_aspect_ratio: A float representing the average aspect ratio of the images in the group. + use_original_name: A boolean indicating whether to save the images with their original file names. 
+ + """ + if not os.path.exists(folder_name): + os.makedirs(folder_name) + + # get the smallest size of the images + smallest_res = float("inf") + for img_path, _ in group: + image = cv2.imread(img_path) + cropped_image = center_crop_image(image, avg_aspect_ratio) + height, width = cropped_image.shape[:2] + image_res = height * width + if image_res < smallest_res: + smallest_res = image_res + small_height, small_width = height, width + + # resize all images to the smallest resolution of the images in the group + for i, (img_path, aspect_ratio) in enumerate(group): + image = cv2.imread(img_path) + cropped_image = center_crop_image(image, avg_aspect_ratio) + resized_image = cv2.resize(cropped_image, (small_width, small_height)) + if use_original_name: + save_name = os.path.basename(img_path) + else: + save_name = f"group_{group_number}_{i}.jpg" + save_path = os.path.join(folder_name, save_name) + cv2.imwrite(save_path, resized_image) + + # Copy matching files named the same as img_path to + copy_related_files(img_path, save_path) + + print(f"Saved {save_name} to {folder_name}") + + +def main(): + parser = argparse.ArgumentParser(description='Sort images and crop them based on aspect ratio') + parser.add_argument('input_dir', type=str, help='Path to the directory containing images') + parser.add_argument('output_dir', type=str, help='Path to the directory to save the cropped images') + parser.add_argument('batch_size', type=int, help='Size of the batches to create') + parser.add_argument('--use_original_name', action='store_true', help='Whether to use original file names for the saved images') + + args = parser.parse_args() + + print(f"Sorting images by aspect ratio in {args.input_dir}...") + if not os.path.exists(args.input_dir): + print(f"Error: Input directory does not exist: {args.input_dir}") + return + + if not os.path.exists(args.output_dir): + try: + os.makedirs(args.output_dir) + except OSError: + print(f"Error: Failed to create output directory: {args.output_dir}") + return + + sorted_images = sort_images_by_aspect_ratio(args.input_dir) + total_images = len(sorted_images) + print(f'Total images: {total_images}') + + if args.batch_size <= 0: + print("Error: Batch size must be greater than 0") + return + + group_size = total_images // args.batch_size + + print(f'Train batch size: {args.batch_size}, image group size: {group_size}') + remainder = total_images % args.batch_size + + if remainder != 0: + print(f'Dropping {remainder} images that do not fit in groups...') + sorted_images = sorted_images[:-remainder] + total_images = len(sorted_images) + group_size = total_images // args.batch_size + + print('Creating groups...') + groups = create_groups(sorted_images, group_size) + print(f"Created {len(groups)} groups") + + print('Saving cropped and resize images...') + for i, group in enumerate(groups): + avg_aspect_ratio = average_aspect_ratio(group) + print(f"Processing group {i+1} with {len(group)} images...") + try: + save_resized_cropped_images(group, args.output_dir, i+1, avg_aspect_ratio, args.use_original_name) + except Exception as e: + print(f"Error: Failed to save images in group {i+1}: {e}") + + print('Done') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/cudann_1.8_install.py b/tools/cudann_1.8_install.py new file mode 100644 index 0000000000000000000000000000000000000000..dec38a17e572fcd14cdf6f1fa1e6ef6e4be03df5 --- /dev/null +++ b/tools/cudann_1.8_install.py @@ -0,0 +1,106 @@ +import filecmp +import importlib.util +import os +import 
shutil +import sys +import sysconfig +import subprocess +from pathlib import Path +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../requirements.txt") + +def run(command, desc=None, errdesc=None, custom_env=None): + if desc is not None: + print(desc) + + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) + + if result.returncode != 0: + + message = f"""{errdesc or 'Error running command'}. +Command: {command} +Error code: {result.returncode} +stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''} +stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''} +""" + raise RuntimeError(message) + + return result.stdout.decode(encoding="utf8", errors="ignore") + +def check_versions(): + global req_file + reqs = open(req_file, 'r') + lines = reqs.readlines() + reqs_dict = {} + for line in lines: + splits = line.split("==") + if len(splits) == 2: + key = splits[0] + if "torch" not in key: + if "diffusers" in key: + key = "diffusers" + reqs_dict[key] = splits[1].replace("\n", "").strip() + if os.name == "nt": + reqs_dict["torch"] = "1.12.1+cu116" + reqs_dict["torchvision"] = "0.13.1+cu116" + + checks = ["xformers","bitsandbytes", "diffusers", "transformers", "torch", "torchvision"] + for check in checks: + check_ver = "N/A" + status = "[ ]" + try: + check_available = importlib.util.find_spec(check) is not None + if check_available: + check_ver = importlib_metadata.version(check) + if check in reqs_dict: + req_version = reqs_dict[check] + if str(check_ver) == str(req_version): + status = "[+]" + else: + status = "[!]" + except importlib_metadata.PackageNotFoundError: + check_available = False + if not check_available: + status = "[!]" + print(f"{status} {check} NOT installed.") + if check == 'xformers': + x_cmd = "-U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl" + print(f"Installing xformers with: pip install {x_cmd}") + run(f"pip install {x_cmd}", desc="Installing xformers") + + else: + print(f"{status} {check} version {check_ver} installed.") + +base_dir = os.path.dirname(os.path.realpath(__file__)) +#repo = git.Repo(base_dir) +#revision = repo.rev_parse("HEAD") +#print(f"Dreambooth revision is {revision}") +check_versions() +# Check for "different" B&B Files and copy only if necessary +if os.name == "nt": + python = sys.executable + bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\bitsandbytes_windows") + bnb_dest = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes") + cudnn_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\cudnn_windows") + cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib") + + print(f"Checking for CUDNN files in {cudnn_dest}") + if os.path.exists(cudnn_src): + if os.path.exists(cudnn_dest): + # check for different files + filecmp.clear_cache() + for file in os.listdir(cudnn_src): + src_file = os.path.join(cudnn_src, file) + dest_file = os.path.join(cudnn_dest, file) + #if dest file exists, check if it's different + if os.path.exists(dest_file): + shutil.copy2(src_file, cudnn_dest) + print("Copied CUDNN 8.6 files to destination") + else: + print(f"Installation Failed: \"{cudnn_src}\" could not be 
found. ") + + \ No newline at end of file diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py new file mode 100644 index 0000000000000000000000000000000000000000..68dec6cae932e827e79c49992238be7fd2edf21c --- /dev/null +++ b/tools/detect_face_rotate.py @@ -0,0 +1,246 @@ +# このスクリプトのラむセンスは、train_dreambooth.pyず同じくApache License 2.0ずしたす +# (c) 2022 Kohya S. @kohya_ss + +# 暪長の画像から顔怜出しお正立するように回転し、そこを䞭心に正方圢に切り出す + +# v2: extract max face if multiple faces are found +# v3: add crop_ratio option +# v4: add multiple faces extraction and min/max size + +import argparse +import math +import cv2 +import glob +import os +from anime_face_detector import create_detector +from tqdm import tqdm +import numpy as np + +KP_REYE = 11 +KP_LEYE = 19 + +SCORE_THRES = 0.90 + + +def detect_faces(detector, image, min_size): + preds = detector(image) # bgr + # print(len(preds)) + + faces = [] + for pred in preds: + bb = pred['bbox'] + score = bb[-1] + if score < SCORE_THRES: + continue + + left, top, right, bottom = bb[:4] + cx = int((left + right) / 2) + cy = int((top + bottom) / 2) + fw = int(right - left) + fh = int(bottom - top) + + lex, ley = pred['keypoints'][KP_LEYE, 0:2] + rex, rey = pred['keypoints'][KP_REYE, 0:2] + angle = math.atan2(ley - rey, lex - rex) + angle = angle / math.pi * 180 + + faces.append((cx, cy, fw, fh, angle)) + + faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 倧きい順 + return faces + + +def rotate_image(image, angle, cx, cy): + h, w = image.shape[0:2] + rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + + # # 回転する分、すこし画像サむズを倧きくする→ずりあえず無効化 + # nh = max(h, int(w * math.sin(angle))) + # nw = max(w, int(h * math.sin(angle))) + # if nh > h or nw > w: + # pad_y = nh - h + # pad_t = pad_y // 2 + # pad_x = nw - w + # pad_l = pad_x // 2 + # m = np.array([[0, 0, pad_l], + # [0, 0, pad_t]]) + # rot_mat = rot_mat + m + # h, w = nh, nw + # cx += pad_l + # cy += pad_t + + result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + return result, cx, cy + + +def process(args): + assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitずresize_face_sizeはどちらか片方しか指定できたせん" + assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できたせん" + + # アニメ顔怜出モデルを読み蟌む + print("loading face detector.") + detector = create_detector('yolov3') + + # cropの匕数を解析する + if args.crop_size is None: + crop_width = crop_height = None + else: + tokens = args.crop_size.split(',') + assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定しおください" + crop_width, crop_height = [int(t) for t in tokens] + + if args.crop_ratio is None: + crop_h_ratio = crop_v_ratio = None + else: + tokens = args.crop_ratio.split(',') + assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定しおください" + crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] + + # 画像を凊理する + print("processing.") + output_extension = ".png" + + os.makedirs(args.dst_dir, exist_ok=True) + paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ + glob.glob(os.path.join(args.src_dir, "*.webp")) + for path in tqdm(paths): + basename = os.path.splitext(os.path.basename(path))[0] + + # image = cv2.imread(path) # 日本語ファむル名で゚ラヌになる + image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + if 
image.shape[2] == 4: + print(f"image has alpha. ignore / 画像の透明床が蚭定されおいるため無芖したす: {path}") + image = image[:, :, :3].copy() # copyをしないず内郚的に透明床情報が付いたたたになるらしい + + h, w = image.shape[:2] + + faces = detect_faces(detector, image, args.multiple_faces) + for i, face in enumerate(faces): + cx, cy, fw, fh, angle = face + face_size = max(fw, fh) + if args.min_size is not None and face_size < args.min_size: + continue + if args.max_size is not None and face_size >= args.max_size: + continue + face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" + + # オプション指定があれば回転する + face_img = image + if args.rotate: + face_img, cx, cy = rotate_image(face_img, angle, cx, cy) + + # オプション指定があれば顔を䞭心に切り出す + if crop_width is not None or crop_h_ratio is not None: + cur_crop_width, cur_crop_height = crop_width, crop_height + if crop_h_ratio is not None: + cur_crop_width = int(face_size * crop_h_ratio + .5) + cur_crop_height = int(face_size * crop_v_ratio + .5) + + # リサむズを必芁なら行う + scale = 1.0 + if args.resize_face_size is not None: + # 顔サむズを基準にリサむズする + scale = args.resize_face_size / face_size + if scale < cur_crop_width / w: + print( + f"image width too small in face size based resizing / 顔を基準にリサむズするず画像の幅がcrop sizeより小さい顔が盞察的に倧きすぎるので顔サむズが倉わりたす: {path}") + scale = cur_crop_width / w + if scale < cur_crop_height / h: + print( + f"image height too small in face size based resizing / 顔を基準にリサむズするず画像の高さがcrop sizeより小さい顔が盞察的に倧きすぎるので顔サむズが倉わりたす: {path}") + scale = cur_crop_height / h + elif crop_h_ratio is not None: + # 倍率指定の時にはリサむズしない + pass + else: + # 切り出しサむズ指定あり + if w < cur_crop_width: + print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化したす: {path}") + scale = cur_crop_width / w + if h < cur_crop_height: + print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化したす: {path}") + scale = cur_crop_height / h + if args.resize_fit: + scale = max(cur_crop_width / w, cur_crop_height / h) + + if scale != 1.0: + w = int(w * scale + .5) + h = int(h * scale + .5) + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + cx = int(cx * scale + .5) + cy = int(cy * scale + .5) + fw = int(fw * scale + .5) + fh = int(fh * scale + .5) + + cur_crop_width = min(cur_crop_width, face_img.shape[1]) + cur_crop_height = min(cur_crop_height, face_img.shape[0]) + + x = cx - cur_crop_width // 2 + cx = cur_crop_width // 2 + if x < 0: + cx = cx + x + x = 0 + elif x + cur_crop_width > w: + cx = cx + (x + cur_crop_width - w) + x = w - cur_crop_width + face_img = face_img[:, x:x+cur_crop_width] + + y = cy - cur_crop_height // 2 + cy = cur_crop_height // 2 + if y < 0: + cy = cy + y + y = 0 + elif y + cur_crop_height > h: + cy = cy + (y + cur_crop_height - h) + y = h - cur_crop_height + face_img = face_img[y:y + cur_crop_height] + + # # debug + # print(path, cx, cy, angle) + # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) + # cv2.imshow("image", crp) + # if cv2.waitKey() == 27: + # break + # cv2.destroyAllWindows() + + # debug + if args.debug: + cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) + + _, buf = cv2.imencode(output_extension, face_img) + with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: + buf.tofile(f) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み蟌むディレクトリ") + parser.add_argument("--dst_dir", type=str, 
help="directory to save images / 画像を保存するディレクトリ") + parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") + parser.add_argument("--resize_fit", action="store_true", + help="resize to fit smaller side after cropping / 切り出し埌の画像の短蟺がcrop_sizeにあうようにリサむズする") + parser.add_argument("--resize_face_size", type=int, default=None, + help="resize image before cropping by face size / 切り出し前に顔がこのサむズになるようにリサむズする") + parser.add_argument("--crop_size", type=str, default=None, + help="crop images with 'width,height' pixels, face centered / 顔を䞭心ずしお'幅,高さ'のサむズで切り出す") + parser.add_argument("--crop_ratio", type=str, default=None, + help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を䞭心ずしお顔サむズの'幅倍率,高さ倍率'のサむズで切り出す") + parser.add_argument("--min_size", type=int, default=None, + help="minimum face size to output (included) / 凊理察象ずする顔の最小サむズこの倀以䞊") + parser.add_argument("--max_size", type=int, default=None, + help="maximum face size to output (excluded) / 凊理察象ずする顔の最倧サむズこの倀未満") + parser.add_argument("--multiple_faces", action="store_true", + help="output each faces / 耇数の顔が芋぀かった堎合、それぞれを切り出す") + parser.add_argument("--debug", action="store_true", help="render rect for face / 凊理埌画像の顔䜍眮に矩圢を描画したす") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + process(args) diff --git a/tools/extract_locon.py b/tools/extract_locon.py new file mode 100644 index 0000000000000000000000000000000000000000..ca72166518a74df0ecf3882e89ecf6b429ba5c0f --- /dev/null +++ b/tools/extract_locon.py @@ -0,0 +1,106 @@ +# +# From: https://raw.githubusercontent.com/KohakuBlueleaf/LoCon/main/extract_locon.py +# + +import argparse + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "base_model", help="The model which use it to train the dreambooth model", + default='', type=str + ) + parser.add_argument( + "db_model", help="the dreambooth model you want to extract the locon", + default='', type=str + ) + parser.add_argument( + "output_name", help="the output model", + default='./out.pt', type=str + ) + parser.add_argument( + "--is_v2", help="Your base/db model is sd v2 or not", + default=False, action="store_true" + ) + parser.add_argument( + "--device", help="Which device you want to use to extract the locon", + default='cpu', type=str + ) + parser.add_argument( + "--mode", + help=( + 'extraction mode, can be "fixed", "threshold", "ratio", "percentile". 
' + 'If not "fixed", network_dim and conv_dim will be ignored' + ), + default='fixed', type=str + ) + parser.add_argument( + "--linear_dim", help="network dim for linear layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--conv_dim", help="network dim for conv layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--linear_threshold", help="singular value threshold for linear layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--conv_threshold", help="singular value threshold for conv layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--linear_ratio", help="singular ratio for linear layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--conv_ratio", help="singular ratio for conv layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--linear_percentile", help="singular value percentile for linear layer percentile mode", + default=1., type=float + ) + parser.add_argument( + "--conv_percentile", help="singular value percentile for conv layer percentile mode", + default=1., type=float + ) + return parser.parse_args() +ARGS = get_args() + +from locon.utils import extract_diff +from locon.kohya_model_utils import load_models_from_stable_diffusion_checkpoint + +import torch + + +def main(): + args = ARGS + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) + + linear_mode_param = { + 'fixed': args.linear_dim, + 'threshold': args.linear_threshold, + 'ratio': args.linear_ratio, + 'percentile': args.linear_percentile, + }[args.mode] + conv_mode_param = { + 'fixed': args.conv_dim, + 'threshold': args.conv_threshold, + 'ratio': args.conv_ratio, + 'percentile': args.conv_percentile, + }[args.mode] + + state_dict = extract_diff( + base, db, + args.mode, + linear_mode_param, conv_mode_param, + args.device + ) + torch.save(state_dict, args.output_name) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/lycoris_locon_extract.py b/tools/lycoris_locon_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..75b55490b16b684ff2abe76e989831188483b1f8 --- /dev/null +++ b/tools/lycoris_locon_extract.py @@ -0,0 +1,129 @@ +import os, sys +sys.path.insert(0, os.getcwd()) +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "base_model", help="The model which use it to train the dreambooth model", + default='', type=str + ) + parser.add_argument( + "db_model", help="the dreambooth model you want to extract the locon", + default='', type=str + ) + parser.add_argument( + "output_name", help="the output model", + default='./out.pt', type=str + ) + parser.add_argument( + "--is_v2", help="Your base/db model is sd v2 or not", + default=False, action="store_true" + ) + parser.add_argument( + "--device", help="Which device you want to use to extract the locon", + default='cpu', type=str + ) + parser.add_argument( + "--mode", + help=( + 'extraction mode, can be "fixed", "threshold", "ratio", "quantile". 
' + 'If not "fixed", network_dim and conv_dim will be ignored' + ), + default='fixed', type=str + ) + parser.add_argument( + "--safetensors", help='use safetensors to save locon model', + default=False, action="store_true" + ) + parser.add_argument( + "--linear_dim", help="network dim for linear layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--conv_dim", help="network dim for conv layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--linear_threshold", help="singular value threshold for linear layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--conv_threshold", help="singular value threshold for conv layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--linear_ratio", help="singular ratio for linear layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--conv_ratio", help="singular ratio for conv layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--linear_quantile", help="singular value quantile for linear layer quantile mode", + default=1., type=float + ) + parser.add_argument( + "--conv_quantile", help="singular value quantile for conv layer quantile mode", + default=1., type=float + ) + parser.add_argument( + "--use_sparse_bias", help="enable sparse bias", + default=False, action="store_true" + ) + parser.add_argument( + "--sparsity", help="sparsity for sparse bias", + default=0.98, type=float + ) + parser.add_argument( + "--disable_cp", help="don't use cp decomposition", + default=False, action="store_true" + ) + return parser.parse_args() +ARGS = get_args() + + +from lycoris.utils import extract_diff +from lycoris.kohya_model_utils import load_models_from_stable_diffusion_checkpoint + +import torch +from safetensors.torch import save_file + + +def main(): + args = ARGS + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) + + linear_mode_param = { + 'fixed': args.linear_dim, + 'threshold': args.linear_threshold, + 'ratio': args.linear_ratio, + 'quantile': args.linear_quantile, + }[args.mode] + conv_mode_param = { + 'fixed': args.conv_dim, + 'threshold': args.conv_threshold, + 'ratio': args.conv_ratio, + 'quantile': args.conv_quantile, + }[args.mode] + + state_dict = extract_diff( + base, db, + args.mode, + linear_mode_param, conv_mode_param, + args.device, + args.use_sparse_bias, args.sparsity, + not args.disable_cp + ) + + if args.safetensors: + save_file(state_dict, args.output_name) + else: + torch.save(state_dict, args.output_name) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/merge_lycoris.py b/tools/merge_lycoris.py new file mode 100644 index 0000000000000000000000000000000000000000..570fa2b4b0987f5aee8bc5083867706ee37c3426 --- /dev/null +++ b/tools/merge_lycoris.py @@ -0,0 +1,80 @@ +import os +import sys +import argparse +import torch +from lycoris.utils import merge_loha, merge_locon +from lycoris.kohya_model_utils import ( + load_models_from_stable_diffusion_checkpoint, + save_stable_diffusion_checkpoint, + load_file +) +import gradio as gr + + +def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight): + base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model) + if lycoris_model.rsplit('.', 1)[-1] == 'safetensors': + lyco = load_file(lycoris_model) + else: + lyco = torch.load(lycoris_model) + + algo = None + for key in lyco: + if 
'hada' in key: + algo = 'loha' + break + elif 'lora_up' in key: + algo = 'lora' + break + else: + raise NotImplementedError('Cannot find the algo for this lycoris model file.') + + dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat') + dtype = { + 'float': torch.float, + 'float16': torch.float16, + 'float32': torch.float32, + 'float64': torch.float64, + 'bfloat': torch.bfloat16, + 'bfloat16': torch.bfloat16, + }.get(dtype_str, None) + if dtype is None: + raise ValueError(f'Cannot Find the dtype "{dtype}"') + + if algo == 'loha': + merge_loha(base, lyco, weight, device) + elif algo == 'lora': + merge_locon(base, lyco, weight, device) + + save_stable_diffusion_checkpoint( + is_v2, output_name, + base[0], base[2], + None, 0, 0, dtype, + base[1] + ) + + return output_name + + +def main(): + iface = gr.Interface( + fn=merge_models, + inputs=[ + gr.inputs.Textbox(label="Base Model Path"), + gr.inputs.Textbox(label="Lycoris Model Path"), + gr.inputs.Textbox(label="Output Model Path", default='./out.pt'), + gr.inputs.Checkbox(label="Is base model SD V2?", default=False), + gr.inputs.Textbox(label="Device", default='cpu'), + gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'), + gr.inputs.Number(label="Weight", default=1.0) + ], + outputs=gr.outputs.Textbox(label="Merged Model Path"), + title="Model Merger", + description="Merge Lycoris and Stable Diffusion models", + ) + + iface.launch() + + +if __name__ == '__main__': + main() diff --git a/tools/original_control_net.py b/tools/original_control_net.py new file mode 100644 index 0000000000000000000000000000000000000000..4484ce9cd0484d855159a329f67714991f5f2b8f --- /dev/null +++ b/tools/original_control_net.py @@ -0,0 +1,320 @@ +from typing import List, NamedTuple, Any +import numpy as np +import cv2 +import torch +from safetensors.torch import load_file + +from diffusers import UNet2DConditionModel +from diffusers.models.unet_2d_condition import UNet2DConditionOutput + +import library.model_util as model_util + + +class ControlNetInfo(NamedTuple): + unet: Any + net: Any + prep: Any + weight: float + ratio: float + + +class ControlNet(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + # make control model + self.control_model = torch.nn.Module() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] + zero_convs = torch.nn.ModuleList() + for i, dim in enumerate(dims): + sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) + zero_convs.append(sub_list) + self.control_model.add_module("zero_convs", zero_convs) + + middle_block_out = torch.nn.Conv2d(1280, 1280, 1) + self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) + + dims = [16, 16, 32, 32, 96, 96, 256, 320] + strides = [1, 1, 2, 1, 2, 1, 2, 1] + prev_dim = 3 + input_hint_block = torch.nn.Sequential() + for i, (dim, stride) in enumerate(zip(dims, strides)): + input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) + if i < len(dims) - 1: + input_hint_block.append(torch.nn.SiLU()) + prev_dim = dim + self.control_model.add_module("input_hint_block", input_hint_block) + + +def load_control_net(v2, unet, model): + device = unet.device + + # control sdからキヌ倉換し぀぀U-Netに察応する郚分のみ取り出し、DiffusersのU-Netに読み蟌む + # state dictを読み蟌む + print(f"ControlNet: loading control SD model : {model}") + + if model_util.is_safetensors(model): + ctrl_sd_sd = load_file(model) + else: + ctrl_sd_sd = torch.load(model, map_location='cpu') + 
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) + + # 重みをU-Netに読み蟌めるようにする。ControlNetはSD版のstate dictなので、それを読み蟌む + is_difference = "difference" in ctrl_sd_sd + print("ControlNet: loading difference") + + # ControlNetには存圚しないキヌがあるので、たず珟圚のU-NetでSD版の党keyを䜜っおおく + # たたTransfer Controlの元weightずなる + ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) + + # 元のU-Netに圱響しないようにコピヌする。たたprefixが付いおいないので付ける + for key in list(ctrl_unet_sd_sd.keys()): + ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone() + + zero_conv_sd = {} + for key in list(ctrl_sd_sd.keys()): + if key.startswith("control_"): + unet_key = "model.diffusion_" + key[len("control_"):] + if unet_key not in ctrl_unet_sd_sd: # zero conv + zero_conv_sd[key] = ctrl_sd_sd[key] + continue + if is_difference: # Transfer Control + ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) + else: + ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) + + unet_config = model_util.create_unet_diffusers_config(v2) + ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict + + # ControlNetのU-Netを䜜成する + ctrl_unet = UNet2DConditionModel(**unet_config) + info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) + print("ControlNet: loading Control U-Net:", info) + + # U-Net以倖のControlNetを䜜成する + # TODO support middle only + ctrl_net = ControlNet() + info = ctrl_net.load_state_dict(zero_conv_sd) + print("ControlNet: loading ControlNet:", info) + + ctrl_unet.to(unet.device, dtype=unet.dtype) + ctrl_net.to(unet.device, dtype=unet.dtype) + return ctrl_unet, ctrl_net + + +def load_preprocess(prep_type: str): + if prep_type is None or prep_type.lower() == "none": + return None + + if prep_type.startswith("canny"): + args = prep_type.split("_") + th1 = int(args[1]) if len(args) >= 2 else 63 + th2 = int(args[2]) if len(args) >= 3 else 191 + + def canny(img): + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + return cv2.Canny(img, th1, th2) + return canny + + print("Unsupported prep type:", prep_type) + return None + + +def preprocess_ctrl_net_hint_image(image): + image = np.array(image).astype(np.float32) / 255.0 + image = image[:, :, ::-1].copy() # rgb to bgr + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 + + +def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): + guided_hints = [] + for i, cnet_info in enumerate(control_nets): + # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... 
ず䞊んでいるこず + b_hints = [] + if len(hints) == 1: # すべお同じ画像をhintずしお䜿う + hint = hints[0] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints = [hint for _ in range(b_size)] + else: + for bi in range(b_size): + hint = hints[(bi * len(control_nets) + i) % len(hints)] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints.append(hint) + b_hints = torch.cat(b_hints, dim=0) + b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) + + guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) + guided_hints.append(guided_hint) + return guided_hints + + +def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states): + # ControlNet + # 耇数のControlNetの堎合は、出力をマヌゞするのではなく亀互に適甚する + cnet_cnt = len(control_nets) + cnet_idx = step % cnet_cnt + cnet_info = control_nets[cnet_idx] + + # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + if cnet_info.ratio < current_ratio: + return original_unet(sample, timestep, encoder_hidden_states) + + guided_hint = guided_hints[cnet_idx] + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) + outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) + outs = [o * cnet_info.weight for o in outs] + + # U-Net + return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) + + +""" + # これはmergeのバヌゞョン + # ControlNet + cnet_outs_list = [] + for i, cnet_info in enumerate(control_nets): + # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + if cnet_info.ratio < current_ratio: + continue + guided_hint = guided_hints[i] + outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) + for i in range(len(outs)): + outs[i] *= cnet_info.weight + + cnet_outs_list.append(outs) + + count = len(cnet_outs_list) + if count == 0: + return original_unet(sample, timestep, encoder_hidden_states) + + # sum of controlnets + for i in range(1, count): + cnet_outs_list[0] += cnet_outs_list[i] + + # U-Net + return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states) +""" + + +def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): + # copy from UNet2DConditionModel + default_overall_up_factor = 2**unet.num_upsamplers + + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + print("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if unet.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = unet.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=unet.dtype) + emb = unet.time_embedding(t_emb) + + outs = [] # output of ControlNet + zc_idx = 0 + + # 2. pre-process + sample = unet.conv_in(sample) + if is_control_net: + sample += guided_hint + outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in unet.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_control_net: + for rs in res_samples: + outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + down_block_res_samples += res_samples + + # 4. mid + sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + if is_control_net: + outs.append(control_net.control_model.middle_block_out[0](sample)) + return outs + + if not is_control_net: + sample += ctrl_outs.pop() + + # 5. up + for i, upsample_block in enumerate(unet.up_blocks): + is_final_block = i == len(unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if not is_control_net and len(ctrl_outs) > 0: + res_samples = list(res_samples) + apply_ctrl_outs = ctrl_outs[-len(res_samples):] + ctrl_outs = ctrl_outs[:-len(res_samples)] + for j in range(len(res_samples)): + res_samples[j] = res_samples[j] + apply_ctrl_outs[j] + res_samples = tuple(res_samples) + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. 
post-process + sample = unet.conv_norm_out(sample) + sample = unet.conv_act(sample) + sample = unet.conv_out(sample) + + return UNet2DConditionOutput(sample=sample) diff --git a/tools/prune.py b/tools/prune.py new file mode 100644 index 0000000000000000000000000000000000000000..6493bb3d5138c770ba06439bb35257638f4281fa --- /dev/null +++ b/tools/prune.py @@ -0,0 +1,37 @@ +import argparse +import torch +from tqdm import tqdm + +parser = argparse.ArgumentParser(description="Prune a model") +parser.add_argument("model_prune", type=str, help="Path to model to prune") +parser.add_argument("prune_output", type=str, help="Path to pruned ckpt output") +parser.add_argument("--half", action="store_true", help="Save weights in half precision.") +args = parser.parse_args() + +print("Loading model...") +model_prune = torch.load(args.model_prune) +theta_prune = model_prune["state_dict"] +theta = {} + +print("Pruning model...") +for key in tqdm(theta_prune.keys(), desc="Pruning keys"): + if "model" in key: + theta.update({key: theta_prune[key]}) + +del theta_prune + +if args.half: + print("Halving model...") + state_dict = {k: v.half() for k, v in tqdm(theta.items(), desc="Halving weights")} +else: + state_dict = theta + +del theta + +print("Saving pruned model...") + +torch.save({"state_dict": state_dict}, args.prune_output) + +del state_dict + +print("Done pruning!") \ No newline at end of file diff --git a/tools/rename_depth_mask.py b/tools/rename_depth_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..97efdea411bf4524e7a124a0c945c94850c4c1d4 --- /dev/null +++ b/tools/rename_depth_mask.py @@ -0,0 +1,21 @@ +import os +import argparse + +# Define the command line arguments +parser = argparse.ArgumentParser(description='Rename files in a folder') +parser.add_argument('folder', metavar='folder', type=str, help='the folder containing the files to rename') + +# Parse the arguments +args = parser.parse_args() + +# Get the list of files in the folder +files = os.listdir(args.folder) + +# Loop through each file in the folder +for file in files: + # Check if the file has the expected format + if file.endswith('-0000.png'): + # Get the new file name + new_file_name = file[:-9] + '.mask' + # Rename the file + os.rename(os.path.join(args.folder, file), os.path.join(args.folder, new_file_name)) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3224c4e28aaad71113e3bab8140da78a69bc2b --- /dev/null +++ b/tools/resize_images_to_resolution.py @@ -0,0 +1,128 @@ +import glob +import os +import cv2 +import argparse +import shutil +import math +from PIL import Image +import numpy as np + + +def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): + # Split the max_resolution string by "," and strip any whitespaces + max_resolutions = [res.strip() for res in max_resolution.split(',')] + + # # Calculate max_pixels from max_resolution string + # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Create destination folder if it does not exist + if not os.path.exists(dst_img_folder): + os.makedirs(dst_img_folder) + + # Select interpolation method + if interpolation == 'lanczos4': + cv2_interpolation = cv2.INTER_LANCZOS4 + elif interpolation == 'cubic': + cv2_interpolation = cv2.INTER_CUBIC + else: + cv2_interpolation = cv2.INTER_AREA + + # Iterate through 
all files in src_img_folder + img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py + for filename in os.listdir(src_img_folder): + # Check if the image is png, jpg or webp etc... + if not filename.endswith(img_exts): + # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) + shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) + continue + + # Load image + # img = cv2.imread(os.path.join(src_img_folder, filename)) + image = Image.open(os.path.join(src_img_folder, filename)) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + + base, _ = os.path.splitext(filename) + for max_resolution in max_resolutions: + # Calculate max_pixels from max_resolution string + max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Calculate current number of pixels + current_pixels = img.shape[0] * img.shape[1] + + # Check if the image needs resizing + if current_pixels > max_pixels: + # Calculate scaling factor + scale_factor = max_pixels / current_pixels + + # Calculate new dimensions + new_height = int(img.shape[0] * math.sqrt(scale_factor)) + new_width = int(img.shape[1] * math.sqrt(scale_factor)) + + # Resize image + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + new_height, new_width = img.shape[0:2] + + # Calculate the new height and width that are divisible by divisible_by (with/without resizing) + new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by + new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by + + # Center crop the image to the calculated dimensions + y = int((img.shape[0] - new_height) / 2) + x = int((img.shape[1] - new_width) / 2) + img = img[y:y + new_height, x:x + new_width] + + # Split filename into base and extension + new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') + + # Save resized image in dst_img_folder + # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) + image = Image.fromarray(img) + image.save(os.path.join(dst_img_folder, new_filename), quality=100) + + proc = "Resized" if current_pixels > max_pixels else "Saved" + print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + + # If other files with same basename, copy them with resolution suffix + if copy_associated_files: + asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) + for asoc_file in asoc_files: + ext = os.path.splitext(asoc_file)[1] + if ext in img_exts: + continue + for max_resolution in max_resolutions: + new_asoc_file = base + '+' + max_resolution + ext + print(f"Copy {asoc_file} as {new_asoc_file}") + shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最倧画像サむズ面積以䞋にアスペクト比を維持したたたリサむズしたす') + parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') + parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサむズ埌の画像を保存するフォルダ') + parser.add_argument('--max_resolution', type=str, + help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 
最倧画像サむズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") + parser.add_argument('--divisible_by', type=int, + help='Ensure new dimensions are divisible by this value / リサむズ埌の画像のサむズをこの倀で割り切れるようにしたす', default=1) + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], + default='area', help='Interpolation method for resizing / リサむズ時の補完方法') + parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png圢匏で保存') + parser.add_argument('--copy_associated_files', action='store_true', + help='Copy files with same base name to images (captions etc) / 画像ず同じファむル名拡匵子を陀くのファむルもコピヌする') + + return parser + + +def main(): + parser = setup_parser() + + args = parser.parse_args() + resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, + args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) + + +if __name__ == '__main__': + main() diff --git a/tools/resize_lora.py b/tools/resize_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..b99bb5bd1f485f233dea9d14889aadb3fdaa5e26 --- /dev/null +++ b/tools/resize_lora.py @@ -0,0 +1,339 @@ +# +# File from: https://raw.githubusercontent.com/mgz-dev/sd-scripts/main/networks/resize_lora.py +# + +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo and kohya + +import argparse +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +from library import train_util, model_util +import numpy as np + +MIN_SV = 1e-6 + +def load_state_dict(file_name, dtype): + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location='cpu') + metadata = None + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, model, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) + + +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0)/original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + 
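# S was folded into U above, so (lora_up @ lora_down) reproduces the rank-truncated weight
+    # Vh becomes the (rank, in, k, k) lora_down kernel and U the (out, rank, 1, 1) lora_up kernel
+    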
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method=="sv_ratio": + # Calculate new dim and alpha based off ratio + max_sv = S[0] + min_sv = max_sv/dynamic_param + new_rank = max(torch.sum(S > min_sv).item(),1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + new_rank = max(new_rank, 1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + new_rank = min(max(new_rank, 1), len(S)-1) + new_alpha = float(scale*new_rank) + else: + new_rank = rank + new_alpha = float(scale*new_rank) + + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale*new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale*new_rank) + + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro/s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0]/S[new_rank] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + 
network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha/network_dim + + if dynamic_method: + print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + if 'lora_down' in key: + block_down_name = key.split(".")[0] + lora_down_weight = value + if 'lora_up' in key: + block_up_name = key.split(".")[0] + lora_up_weight = value + + weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + + if (block_down_name == block_up_name) and weights_loaded: + + conv2d = (len(lora_down_weight.size()) == 4) + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict['max_ratio'] + sum_retained = param_dict['sum_retained'] + fro_retained = param_dict['fro_retained'] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + + if verbose and dynamic_method: + verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str+=f"\n" + + new_alpha = param_dict['new_alpha'] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." 
"alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + print("resizing complete") + return o_lora_sd, network_dim, new_alpha + + +def resize(args): + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + print("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の粟床、未指定時はfloat") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファむル名、ckptたたはsafetensors") + parser.add_argument("--model", type=str, default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み蟌むLoRAモデル、ckptたたはsafetensors") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 蚈算を行うデバむス、cuda でGPUを䜿う") + parser.add_argument("--verbose", action="store_true", + help="Display verbose resizing information / rank倉曎時の詳现情報を出力する") + parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") + parser.add_argument("--dynamic_param", type=float, default=None, + help="Specify target for dynamic reduction") + + + args = parser.parse_args() + resize(args) \ No newline at end of file diff --git a/tools/validate_requirements.py b/tools/validate_requirements.py new file mode 100644 index 
0000000000000000000000000000000000000000..86f09d55a74de387c1ca64501b40e34a62a88929 --- /dev/null +++ b/tools/validate_requirements.py @@ -0,0 +1,61 @@ +import os +import sys +import pkg_resources +import argparse + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.") +parser.add_argument('-r', '--requirements', type=str, default='requirements.txt', help="Path to the requirements file.") +args = parser.parse_args() + +print("Validating that requirements are satisfied.") + +# Load the requirements from the specified requirements file +with open(args.requirements) as f: + requirements = f.readlines() + +# Check each requirement against the installed packages +missing_requirements = [] +wrong_version_requirements = [] +for requirement in requirements: + requirement = requirement.strip() + if requirement == ".": + # Skip the current requirement if it is a dot (.) + continue + try: + pkg_resources.require(requirement) + except pkg_resources.DistributionNotFound: + # Check if the requirement contains a VCS URL + if "@" in requirement: + # If it does, split the requirement into two parts: the package name and the VCS URL + package_name, vcs_url = requirement.split("@", 1) + # Use pip to install the package from the VCS URL + os.system(f"pip install -e {vcs_url}") + # Try to require the package again + try: + pkg_resources.require(package_name) + except pkg_resources.DistributionNotFound: + missing_requirements.append(requirement) + else: + missing_requirements.append(requirement) + except pkg_resources.VersionConflict as e: + wrong_version_requirements.append((requirement, str(e.req), e.dist.version)) + +# If there are any missing or wrong version requirements, print an error message and exit with a non-zero exit code +if missing_requirements or wrong_version_requirements: + if missing_requirements: + print("Error: The following packages are missing:") + for requirement in missing_requirements: + print(f" - {requirement}") + if wrong_version_requirements: + print("Error: The following packages have the wrong version:") + for requirement, expected_version, actual_version in wrong_version_requirements: + print(f" - {requirement} (expected version {expected_version}, found version {actual_version})") + upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh" + print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r {args.requirements}\033[0m to resolve the missing requirements listed above...") + + sys.exit(1) + +# All requirements satisfied +print("All requirements satisfied.") +sys.exit(0) diff --git a/train_README-ja.md b/train_README-ja.md new file mode 100644 index 0000000000000000000000000000000000000000..d5f1b5fc86112638befd5fdf932dd98f87f9101d --- /dev/null +++ b/train_README-ja.md @@ -0,0 +1,936 @@ +__ドキュメント曎新䞭のため蚘述に誀りがあるかもしれたせん。__ + +# 孊習に぀いお、共通線 + +圓リポゞトリではモデルのfine tuning、DreamBooth、およびLoRAずTextual Inversionの孊習をサポヌトしたす。この文曞ではそれらに共通する、孊習デヌタの準備方法やオプション等に぀いお説明したす。 + +# 抂芁 + +あらかじめこのリポゞトリのREADMEを参照し、環境敎備を行っおください。 + + +以䞋に぀いお説明したす。 + +1. 孊習デヌタの準備に぀いお蚭定ファむルを甚いる新圢匏 +1. 孊習で䜿われる甚語のごく簡単な解説 +1. 以前の指定圢匏蚭定ファむルを甚いずコマンドラむンから指定 +1. 孊習途䞭のサンプル画像生成 +1. 各スクリプトで共通の、よく䜿われるオプション +1. 
fine tuning 方匏のメタデヌタ準備キャプションニングなど + +1.だけ実行すればずりあえず孊習は可胜です孊習に぀いおは各スクリプトのドキュメントを参照。2.以降は必芁に応じお参照しおください。 + + +# 孊習デヌタの準備に぀いお + +任意のフォルダ耇数でも可に孊習デヌタの画像ファむルを甚意しおおきたす。`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` をサポヌトしたす。リサむズなどの前凊理は基本的に必芁ありたせん。 + +ただし孊習解像床埌述よりも極端に小さい画像は䜿わないか、あらかじめ超解像AIなどで拡倧しおおくこずをお勧めしたす。たた極端に倧きな画像3000x3000ピクセル皋床よりも倧きな画像ぱラヌになる堎合があるようですので事前に瞮小しおください。 + +孊習時には、モデルに孊ばせる画像デヌタを敎理し、スクリプトに察しお指定する必芁がありたす。孊習デヌタの数、孊習察象、キャプション画像の説明が甚意できるか吊かなどにより、いく぀かの方法で孊習デヌタを指定できたす。以䞋の方匏がありたすそれぞれの名前は䞀般的なものではなく、圓リポゞトリ独自の定矩です。正則化画像に぀いおは埌述したす。 + +1. DreamBooth、class+identifier方匏正則化画像䜿甚可 + + 特定の単語 (identifier) に孊習察象を玐づけるように孊習したす。キャプションを甚意する必芁はありたせん。たずえば特定のキャラを孊ばせる堎合に䜿うずキャプションを甚意する必芁がない分、手軜ですが、髪型や服装、背景など孊習デヌタの党芁玠が identifier に玐づけられお孊習されるため、生成時のプロンプトで服が倉えられない、ずいった事態も起こりえたす。 + +1. DreamBooth、キャプション方匏正則化画像䜿甚可 + + 画像ごずにキャプションが蚘録されたテキストファむルを甚意しお孊習したす。たずえば特定のキャラを孊ばせるず、画像の詳现をキャプションに蚘述するこずで癜い服を着たキャラA、赀い服を着たキャラA、などキャラずそれ以倖の芁玠が分離され、より厳密にモデルがキャラだけを孊ぶこずが期埅できたす。 + +1. fine tuning方匏正則化画像䜿甚䞍可 + + あらかじめキャプションをメタデヌタファむルにたずめたす。タグずキャプションを分けお管理したり、孊習を高速化するためlatentsを事前キャッシュしたりなどの機胜をサポヌトしたすいずれも別文曞で説明しおいたす。fine tuning方匏ずいう名前ですが fine tuning 以倖でも䜿えたす。 + +孊習したいものず䜿甚できる指定方法の組み合わせは以䞋の通りです。 + +| 孊習察象たたは方法 | スクリプト | DB / class+identifier | DB / キャプション | fine tuning | +| ----- | ----- | ----- | ----- | ----- | +| モデルをfine tuning | `fine_tune.py`| x | x | o | +| モデルをDreamBooth | `train_db.py`| o | o | x | +| LoRA | `train_network.py`| o | o | o | +| Textual Invesion | `train_textual_inversion.py`| o | o | o | + +## どれを遞ぶか + +LoRA、Textual Inversionに぀いおは、手軜にキャプションファむルを甚意せずに孊習したい堎合はDreamBooth class+identifier、甚意できるならDreamBooth キャプション方匏がよいでしょう。孊習デヌタの枚数が倚く、か぀正則化画像を䜿甚しない堎合はfine tuning方匏も怜蚎しおください。 + +DreamBoothに぀いおも同様ですが、fine tuning方匏は䜿えたせん。fine tuningの堎合はfine tuning方匏のみです。 + +# 各方匏の指定方法に぀いお + +ここではそれぞれの指定方法で兞型的なパタヌンに぀いおだけ説明したす。より詳现な指定方法に぀いおは [デヌタセット蚭定](./config_README-ja.md) をご芧ください。 + +# DreamBooth、class+identifier方匏正則化画像䜿甚可 + +この方匏では、各画像は `class identifier` ずいうキャプションで孊習されたのず同じこずになりたす`shs dog` など。 + +## step 1. identifierずclassを決める + +孊ばせたい察象を結び぀ける単語identifierず、察象の属するclassを決めたす。 + +instanceなどいろいろな呌び方がありたすが、ずりあえず元の論文に合わせたす。 + +以䞋ごく簡単に説明したす詳しくは調べおください。 + +classは孊習察象の䞀般的な皮別です。たずえば特定の犬皮を孊ばせる堎合には、classはdogになりたす。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。 + +identifierは孊習察象を識別しお孊習するためのものです。任意の単語で構いたせんが、元論文によるず「tokinizerで1トヌクンになる3文字以䞋でレアな単語」が良いずのこずです。 + +identifierずclassを䜿い、たずえば「shs dog」などでモデルを孊習するこずで、孊習させたい察象をclassから識別しお孊習できたす。 + +画像生成時には「shs dog」ずすれば孊ばせた犬皮の画像が生成されたす。 + +identifierずしお私が最近䜿っおいるものを参考たでに挙げるず、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。本圓は Danbooru Tag に含たれないや぀がより望たしいです。 + +## step 2. 正則化画像を䜿うか吊かを決め、䜿う堎合には正則化画像を生成する + +正則化画像ずは、前述のclass党䜓が、孊習察象に匕っ匵られるこずを防ぐための画像ですlanguage drift。正則化画像を䜿わないず、たずえば `shs 1girl` で特定のキャラクタを孊ばせるず、単なる `1girl` ずいうプロンプトで生成しおもそのキャラに䌌おきたす。これは `1girl` が孊習時のキャプションに含たれおいるためです。 + +孊習察象の画像ず正則化画像を同時に孊ばせるこずで、class は class のたたで留たり、identifier をプロンプトに぀けた時だけ孊習察象が生成されるようになりたす。 + +LoRAやDreamBoothで特定のキャラだけ出おくればよい堎合は、正則化画像を甚いなくおも良いずいえたす。 + +Textual Inversionでは甚いなくおよいでしょう孊ばせる token string がキャプションに含たれない堎合はなにも孊習されないため。 + +正則化画像ずしおは、孊習察象のモデルで、class 名だけで生成した画像を甚いるのが䞀般的ですたずえば `1girl`。ただし生成画像の品質が悪い堎合には、プロンプトを工倫したり、ネットから別途ダりンロヌドした画像を甚いるこずもできたす。 + +正則化画像も孊習されるため、その品質はモデルに圱響したす。 + +䞀般的には数癟枚皋床、甚意するのが望たしいようです枚数が少ないず class 画像が䞀般化されずそれらの特城を孊んでしたいたす。 + +生成画像を䜿う堎合、通垞、生成画像のサむズは孊習解像床より正確にはbucketの解像床、埌述にあわせおください。 + +## step 2. 
蚭定ファむルの蚘述 + +テキストファむルを䜜成し、拡匵子を `.toml` にしたす。たずえば以䞋のように蚘述したす。 + +`#` で始たっおいる郚分はコメントですので、このたたコピペしおそのたたでもよいですし、削陀しおも問題ありたせん。 + +```toml +[general] +enable_bucket = true # Aspect Ratio Bucketingを䜿うか吊か + +[[datasets]] +resolution = 512 # 孊習解像床 +batch_size = 4 # バッチサむズ + + [[datasets.subsets]] + image_dir = 'C:\hoge' # 孊習甚画像を入れたフォルダを指定 + class_tokens = 'hoge girl' # identifier class を指定 + num_repeats = 10 # 孊習甚画像の繰り返し回数 + + # 以䞋は正則化画像を甚いる堎合のみ蚘述する。甚いない堎合は削陀する + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定 + class_tokens = 'girl' # class を指定 + num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい +``` + +基本的には以䞋の堎所のみ曞き換えれば孊習できたす。 + +1. 孊習解像床 + + 数倀1぀を指定するず正方圢`512`なら512x512、鍵カッコカンマ区切りで2぀指定するず暪×瞊`[512,768]`なら512x768になりたす。SD1.x系ではもずもずの孊習解像床は512です。`[512,768]` 等の倧きめの解像床を指定するず瞊長、暪長画像生成時の砎綻を小さくできるかもしれたせん。SD2.x 768系では `768` です。 + +1. バッチサむズ + + 同時に䜕件のデヌタを孊習するかを指定したす。GPUのVRAMサむズ、孊習解像床によっお倉わっおきたす。詳しくは埌述したす。たたfine tuning/DreamBooth/LoRA等でも倉わっおきたすので各スクリプトの説明もご芧ください。 + +1. フォルダ指定 + + 孊習甚画像、正則化画像䜿甚する堎合のみのフォルダを指定したす。画像デヌタが含たれおいるフォルダそのものを指定したす。 + +1. identifier ず class の指定 + + 前述のサンプルの通りです。 + +1. 繰り返し回数 + + 埌述したす。 + +### 繰り返し回数に぀いお + +繰り返し回数は、正則化画像の枚数ず孊習甚画像の枚数を調敎するために甚いられたす。正則化画像の枚数は孊習甚画像よりも倚いため、孊習甚画像を繰り返しお枚数を合わせ、1察1の比率で孊習できるようにしたす。 + +繰り返し回数は「 __孊習甚画像の繰り返し回数×孊習甚画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」ずなるように指定しおください。 + +1 epochデヌタが䞀呚するず1 epochのデヌタ数が「孊習甚画像の繰り返し回数×孊習甚画像の枚数」ずなりたす。正則化画像の枚数がそれより倚いず、䜙った郚分の正則化画像は䜿甚されたせん。 + +## step 3. å­Šç¿’ + +それぞれのドキュメントを参考に孊習を行っおください。 + +# DreamBooth、キャプション方匏正則化画像䜿甚可 + +この方匏では各画像はキャプションで孊習されたす。 + +## step 1. キャプションファむルを準備する + +孊習甚画像のフォルダに、画像ず同じファむル名で、拡匵子 `.caption`蚭定で倉えられたすのファむルを眮いおください。それぞれのファむルは1行のみずしおください。゚ンコヌディングは `UTF-8` です。 + +## step 2. 正則化画像を䜿うか吊かを決め、䜿う堎合には正則化画像を生成する + +class+identifier圢匏ず同様です。なお正則化画像にもキャプションを付けるこずができたすが、通垞は䞍芁でしょう。 + +## step 2. 蚭定ファむルの蚘述 + +テキストファむルを䜜成し、拡匵子を `.toml` にしたす。たずえば以䞋のように蚘述したす。 + +```toml +[general] +enable_bucket = true # Aspect Ratio Bucketingを䜿うか吊か + +[[datasets]] +resolution = 512 # 孊習解像床 +batch_size = 4 # バッチサむズ + + [[datasets.subsets]] + image_dir = 'C:\hoge' # 孊習甚画像を入れたフォルダを指定 + caption_extension = '.caption' # キャプションファむルの拡匵子 .txt を䜿う堎合には曞き換える + num_repeats = 10 # 孊習甚画像の繰り返し回数 + + # 以䞋は正則化画像を甚いる堎合のみ蚘述する。甚いない堎合は削陀する + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定 + class_tokens = 'girl' # class を指定 + num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい +``` + +基本的には以䞋を堎所のみ曞き換えれば孊習できたす。特に蚘述がない郚分は class+identifier 方匏ず同じです。 + +1. 孊習解像床 +1. バッチサむズ +1. フォルダ指定 +1. キャプションファむルの拡匵子 + + 任意の拡匵子を指定できたす。 +1. 繰り返し回数 + +## step 3. å­Šç¿’ + +それぞれのドキュメントを参考に孊習を行っおください。 + +# fine tuning 方匏 + +## step 1. メタデヌタを準備する + +キャプションやタグをたずめた管理甚ファむルをメタデヌタず呌びたす。json圢匏で拡匵子は `.json` + です。䜜成方法は長くなりたすのでこの文曞の末尟に曞きたした。 + +## step 2. 蚭定ファむルの蚘述 + +テキストファむルを䜜成し、拡匵子を `.toml` にしたす。たずえば以䞋のように蚘述したす。 + +```toml +[general] +shuffle_caption = true +keep_tokens = 1 + +[[datasets]] +resolution = 512 # 孊習解像床 +batch_size = 4 # バッチサむズ + + [[datasets.subsets]] + image_dir = 'C:\piyo' # 孊習甚画像を入れたフォルダを指定 + metadata_file = 'C:\piyo\piyo_md.json' # メタデヌタファむル名 +``` + +基本的には以䞋を堎所のみ曞き換えれば孊習できたす。特に蚘述がない郚分は DreamBooth, class+identifier 方匏ず同じです。 + +1. 孊習解像床 +1. バッチサむズ +1. フォルダ指定 +1. メタデヌタファむル名 + + 埌述の方法で䜜成したメタデヌタファむルを指定したす。 + + +## step 3. 
å­Šç¿’ + +それぞれのドキュメントを参考に孊習を行っおください。 + +# 孊習で䜿われる甚語のごく簡単な解説 + +现かいこずは省略しおいたすし私も完党には理解しおいないため、詳しくは各自お調べください。 + +## fine tuningファむンチュヌニング + +モデルを孊習しお埮調敎するこずを指したす。䜿われ方によっお意味が異なっおきたすが、狭矩のfine tuningはStable Diffusionの堎合、モデルを画像ずキャプションで孊習するこずです。DreamBoothは狭矩のfine tuningのひず぀の特殊なやり方ず蚀えたす。広矩のfine tuningは、LoRAやTextual Inversion、Hypernetworksなどを含み、モデルを孊習するこずすべおを含みたす。 + +## ステップ + +ざっくりいうず孊習デヌタで1回蚈算するず1ステップです。「孊習デヌタのキャプションを今のモデルに流しおみお、出おくる画像を孊習デヌタの画像ず比范し、孊習デヌタに近づくようにモデルをわずかに倉曎する」のが1ステップです。 + +## バッチサむズ + +バッチサむズは1ステップで䜕件のデヌタをたずめお蚈算するかを指定する倀です。たずめお蚈算するため速床は盞察的に向䞊したす。たた䞀般的には粟床も高くなるずいわれおいたす。 + +`バッチサむズ×ステップ数` が孊習に䜿われるデヌタの件数になりたす。そのため、バッチサむズを増やした分だけステップ数を枛らすずよいでしょう。 + +ただし、たずえば「バッチサむズ1で1600ステップ」ず「バッチサむズ4で400ステップ」は同じ結果にはなりたせん。同じ孊習率の堎合、䞀般的には埌者のほうが孊習䞍足になりたす。孊習率を倚少倧きくするかたずえば `2e-6` など、ステップ数をたずえば500ステップにするなどしお工倫しおください。 + +バッチサむズを倧きくするずその分だけGPUメモリを消費したす。メモリが足りなくなるず゚ラヌになりたすし、゚ラヌにならないギリギリでは孊習速床が䜎䞋したす。タスクマネヌゞャヌや `nvidia-smi` コマンドで䜿甚メモリ量を確認しながら調敎するずよいでしょう。 + +なお、バッチは「䞀塊のデヌタ」䜍の意味です。 + +## 孊習率 + +ざっくりいうず1ステップごずにどのくらい倉化させるかを衚したす。倧きな倀を指定するずそれだけ速く孊習が進みたすが、倉化しすぎおモデルが壊れたり、最適な状態にたで至れない堎合がありたす。小さい倀を指定するず孊習速床は遅くなり、たた最適な状態にやはり至れない堎合がありたす。 + +fine tuning、DreamBoooth、LoRAそれぞれで倧きく異なり、たた孊習デヌタや孊習させたいモデル、バッチサむズやステップ数によっおも倉わっおきたす。䞀般的な倀から初めお孊習状態を芋ながら増枛しおください。 + +デフォルトでは孊習党䜓を通しお孊習率は固定です。スケゞュヌラの指定で孊習率をどう倉化させるか決められたすので、それらによっおも結果は倉わっおきたす。 + +## ゚ポックepoch + +孊習デヌタが䞀通り孊習されるずデヌタが䞀呚するず1 epochです。繰り返し回数を指定した堎合は、その繰り返し埌のデヌタが䞀呚するず1 epochです。 + +1 epochのステップ数は、基本的には `デヌタ件数÷バッチサむズ` ですが、Aspect Ratio Bucketing を䜿うず埮劙に増えたす異なるbucketのデヌタは同じバッチにできないため、ステップ数が増えたす。 + +## Aspect Ratio Bucketing + +Stable Diffusion のv1は512\*512で孊習されおいたすが、それに加えお256\*1024や384\*640ずいった解像床でも孊習したす。これによりトリミングされる郚分が枛り、より正しくキャプションず画像の関係が孊習されるこずが期埅されたす。 + +たた任意の解像床で孊習するため、事前に画像デヌタの瞊暪比を統䞀しおおく必芁がなくなりたす。 + +蚭定で有効、向こうが切り替えられたすが、ここたでの蚭定ファむルの蚘述䟋では有効になっおいたす`true` が蚭定されおいたす。 + +孊習解像床はパラメヌタずしお䞎えられた解像床の面積メモリ䜿甚量を超えない範囲で、64ピクセル単䜍デフォルト、倉曎可で瞊暪に調敎、䜜成されたす。 + +機械孊習では入力サむズをすべお統䞀するのが䞀般的ですが、特に制玄があるわけではなく、実際は同䞀のバッチ内で統䞀されおいれば倧䞈倫です。NovelAIの蚀うbucketingは、あらかじめ教垫デヌタを、アスペクト比に応じた孊習解像床ごずに分類しおおくこずを指しおいるようです。そしおバッチを各bucket内の画像で䜜成するこずで、バッチの画像サむズを統䞀したす。 + +# 以前の指定圢匏蚭定ファむルを甚いずコマンドラむンから指定 + +`.toml` ファむルを指定せずコマンドラむンオプションで指定する方法です。DreamBooth class+identifier方匏、DreamBooth キャプション方匏、fine tuning方匏がありたす。 + +## DreamBooth、class+identifier方匏 + +フォルダ名で繰り返し回数を指定したす。たた `train_data_dir` オプションず `reg_data_dir` オプションを甚いたす。 + +### step 1. 孊習甚画像の準備 + +孊習甚画像を栌玍するフォルダを䜜成したす。 __さらにその䞭に__ 、以䞋の名前でディレクトリを䜜成したす。 + +``` +<繰り返し回数>_ +``` + +間の``_``を忘れないでください。 + +たずえば「sls frog」ずいうプロンプトで、デヌタを20回繰り返す堎合、「20_sls frog」ずなりたす。以䞋のようになりたす。 + +![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) + +### 耇数class、耇数察象identifierの孊習 + +方法は単玔で、孊習甚画像のフォルダ内に ``繰り返し回数_ `` のフォルダを耇数、正則化画像フォルダにも同様に ``繰り返し回数_`` のフォルダを耇数、甚意しおください。 + +たずえば「sls frog」ず「cpc rabbit」を同時に孊習する堎合、以䞋のようになりたす。 + +![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) + +classがひず぀で察象が耇数の堎合、正則化画像フォルダはひず぀で構いたせん。たずえば1girlにキャラAずキャラBがいる堎合は次のようにしたす。 + +- train_girls + - 10_sls 1girl + - 10_cpc 1girl +- reg_girls + - 1_1girl + +### step 2. 正則化画像の準備 + +正則化画像を䜿う堎合の手順です。 + +正則化画像を栌玍するフォルダを䜜成したす。 __さらにその䞭に__ ``<繰り返し回数>_`` ずいう名前でディレクトリを䜜成したす。 + +たずえば「frog」ずいうプロンプトで、デヌタを繰り返さない1回だけ堎合、以䞋のようになりたす。 + +![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) + + +### step 3. 
孊習の実行 + +各孊習スクリプトを実行したす。 `--train_data_dir` オプションで前述の孊習甚デヌタのフォルダを__画像を含むフォルダではなく、その芪フォルダ__、`--reg_data_dir` オプションで正則化画像のフォルダ__画像を含むフォルダではなく、その芪フォルダ__を指定しおください。 + +## DreamBooth、キャプション方匏 + +孊習甚画像、正則化画像のフォルダに、画像ず同じファむル名で、拡匵子.captionオプションで倉えられたすのファむルを眮くず、そのファむルからキャプションを読み蟌みプロンプトずしお孊習したす。 + +※それらの画像の孊習に、フォルダ名identifier classは䜿甚されなくなりたす。 + +キャプションファむルの拡匵子はデフォルトで.captionです。孊習スクリプトの `--caption_extension` オプションで倉曎できたす。`--shuffle_caption` オプションで孊習時のキャプションに぀いお、カンマ区切りの各郚分をシャッフルしながら孊習したす。 + +## fine tuning 方匏 + +メタデヌタを䜜るずころたでは蚭定ファむルを䜿う堎合ず同様です。`in_json` オプションでメタデヌタファむルを指定したす。 + +# 孊習途䞭でのサンプル出力 + +孊習䞭のモデルで詊しに画像生成するこずで孊習の進み方を確認できたす。孊習スクリプトに以䞋のオプションを指定したす。 + +- `--sample_every_n_steps` / `--sample_every_n_epochs` + + サンプル出力するステップ数たたぱポック数を指定したす。この数ごずにサンプル出力したす。䞡方指定するず゚ポック数が優先されたす。 + +- `--sample_prompts` + + サンプル出力甚プロンプトのファむルを指定したす。 + +- `--sample_sampler` + + サンプル出力に䜿うサンプラヌを指定したす。 + `'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が遞べたす。 + +サンプル出力を行うにはあらかじめプロンプトを蚘述したテキストファむルを甚意しおおく必芁がありたす。1行に぀き1プロンプトで蚘述したす。 + +たずえば以䞋のようになりたす。 + +```txt +# prompt 1 +masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 + +# prompt 2 +masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 +``` + +先頭が `#` の行はコメントになりたす。`--n` のように 「`--` + 英小文字」で生成画像ぞのオプションを指定できたす。以䞋が䜿えたす。 + +- `--n` 次のオプションたでをネガティブプロンプトずしたす。 +- `--w` 生成画像の暪幅を指定したす。 +- `--h` 生成画像の高さを指定したす。 +- `--d` 生成画像のseedを指定したす。 +- `--l` 生成画像のCFG scaleを指定したす。 +- `--s` 生成時のステップ数を指定したす。 + + +# 各スクリプトで共通の、よく䜿われるオプション + +スクリプトの曎新埌、ドキュメントの曎新が远い付いおいない堎合がありたす。その堎合は `--help` オプションで䜿甚できるオプションを確認しおください。 + +## 孊習に䜿うモデル指定 + +- `--v2` / `--v_parameterization` + + 孊習察象モデルずしおHugging Faceのstable-diffusion-2-base、たたはそこからのfine tuningモデルを䜿う堎合掚論時に `v2-inference.yaml` を䜿うように指瀺されおいるモデルの堎合は `--v2` オプションを、stable-diffusion-2や768-v-ema.ckpt、およびそれらのfine tuningモデルを䜿う堎合掚論時に `v2-inference-v.yaml` を䜿うモデルの堎合は `--v2` ず `--v_parameterization` の䞡方のオプションを指定しおください。 + + Stable Diffusion 2.0では倧きく以䞋の点が倉わっおいたす。 + + 1. 䜿甚するTokenizer + 2. 䜿甚するText Encoderおよび䜿甚する出力局2.0は最埌から二番目の局を䜿う + 3. Text Encoderの出力次元数768->1024 + 4. U-Netの構造CrossAttentionのhead数など + 5. 
v-parameterizationサンプリング方法が倉曎されおいるらしい + + このうちbaseでは14が、baseの぀かない方768-vでは15が採甚されおいたす。14を有効にするのがv2オプション、5を有効にするのがv_parameterizationオプションです。 + +- `--pretrained_model_name_or_path` + + 远加孊習を行う元ずなるモデルを指定したす。Stable Diffusionのcheckpointファむル.ckptたたは.safetensors、Diffusersのロヌカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できたす。 + +## 孊習に関する蚭定 + +- `--output_dir` + + 孊習埌のモデルを保存するフォルダを指定したす。 + +- `--output_name` + + モデルのファむル名を拡匵子を陀いお指定したす。 + +- `--dataset_config` + + デヌタセットの蚭定を蚘述した `.toml` ファむルを指定したす。 + +- `--max_train_steps` / `--max_train_epochs` + + 孊習するステップ数や゚ポック数を指定したす。䞡方指定するず゚ポック数のほうが優先されたす。 + +- `--mixed_precision` + + 省メモリ化のため mixed precision 混合粟床で孊習したす。`--mixed_precision="fp16"` のように指定したす。mixed precision なしデフォルトず比べお粟床が䜎くなる可胜性がありたすが、孊習に必芁なGPUメモリ量が倧きく枛りたす。 + + RTX30 シリヌズ以降では `bf16` も指定できたす。環境敎備時にaccelerateに行った蚭定ず合わせおください。 + +- `--gradient_checkpointing` + + 孊習時の重みの蚈算をたずめお行うのではなく少しず぀行うこずで、孊習に必芁なGPUメモリ量を枛らしたす。オンオフは粟床には圱響したせんが、オンにするずバッチサむズを倧きくできるため、そちらでの圱響はありたす。 + + たた䞀般的にはオンにするず速床は䜎䞋したすが、バッチサむズを倧きくできるので、トヌタルでの孊習時間はむしろ速くなるかもしれたせん。 + +- `--xformers` / `--mem_eff_attn` + + xformersオプションを指定するずxformersのCrossAttentionを甚いたす。xformersをむンストヌルしおいない堎合や゚ラヌずなる堎合環境にもよりたすが `mixed_precision="no"` の堎合など、代わりに `mem_eff_attn` オプションを指定するず省メモリ版CrossAttentionを䜿甚したすxformersよりも速床は遅くなりたす。 + +- `--save_precision` + + 保存時のデヌタ粟床を指定したす。save_precisionオプションにfloat、fp16、bf16のいずれかを指定するず、その圢匏でモデルを保存したすDreamBooth、fine tuningでDiffusers圢匏でモデルを保存する堎合は無効です。モデルのサむズを削枛したい堎合などにお䜿いください。 + +- `--save_every_n_epochs` / `--save_state` / `--resume` + save_every_n_epochsオプションに数倀を指定するず、その゚ポックごずに孊習途䞭のモデルを保存したす。 + + save_stateオプションを同時に指定するず、optimizer等の状態も含めた孊習状態を合わせお保存したす保存したモデルからも孊習再開できたすが、それに比べるず粟床の向䞊、孊習時間の短瞮が期埅できたす。保存先はフォルダになりたす。 + + 孊習状態は保存先フォルダに `-??????-state`??????ぱポック数ずいう名前のフォルダで出力されたす。長時間にわたる孊習時にご利甚ください。 + + 保存された孊習状態から孊習を再開するにはresumeオプションを䜿いたす。孊習状態のフォルダ`output_dir` ではなくその䞭のstateのフォルダを指定しおください。 + + なおAcceleratorの仕様により、゚ポック数、global stepは保存されおおらず、resumeしたずきにも1からになりたすがご容赊ください。 + +- `--save_model_as` DreamBooth, fine tuning のみ + + モデルの保存圢匏を`ckpt, safetensors, diffusers, diffusers_safetensors` から遞べたす。 + + `--save_model_as=safetensors` のように指定したす。Stable Diffusion圢匏ckptたたはsafetensorsを読み蟌み、Diffusers圢匏で保存する堎合、䞍足する情報はHugging Faceからv1.5たたはv2.1の情報を萜ずしおきお補完したす。 + +- `--clip_skip` + + `2` を指定するず、Text Encoder (CLIP) の埌ろから二番目の局の出力を甚いたす。1たたはオプション省略時は最埌の局を甚いたす。 + + ※SD2.0はデフォルトで埌ろから二番目の局を䜿うため、SD2.0の孊習では指定しないでください。 + + 孊習察象のモデルがもずもず二番目の局を䜿うように孊習されおいる堎合は、2を指定するずよいでしょう。 + + そうではなく最埌の局を䜿甚しおいた堎合はモデル党䜓がそれを前提に孊習されおいたす。そのため改めお二番目の局を䜿甚しお孊習するず、望たしい孊習結果を埗るにはある皋床の枚数の教垫デヌタ、長めの孊習が必芁になるかもしれたせん。 + +- `--max_token_length` + + デフォルトは75です。`150` たたは `225` を指定するこずでトヌクン長を拡匵しお孊習できたす。長いキャプションで孊習する堎合に指定しおください。 + + ただし孊習時のトヌクン拡匵の仕様は Automatic1111 氏のWeb UIずは埮劙に異なるため分割の仕様など、必芁なければ75で孊習するこずをお勧めしたす。 + + clip_skipず同様に、モデルの孊習状態ず異なる長さで孊習するには、ある皋床の教垫デヌタ枚数、長めの孊習時間が必芁になるず思われたす。 + +- `--persistent_data_loader_workers` + + Windows環境で指定するず゚ポック間の埅ち時間が倧幅に短瞮されたす。 + +- `--max_data_loader_n_workers` + + デヌタ読み蟌みのプロセス数を指定したす。プロセス数が倚いずデヌタ読み蟌みが速くなりGPUを効率的に利甚できたすが、メむンメモリを消費したす。デフォルトは「`8` たたは `CPU同時実行スレッド数-1` の小さいほう」なので、メむンメモリに䜙裕がない堎合や、GPU䜿甚率が90%皋床以䞊なら、それらの数倀を芋ながら `2` たたは `1` 皋床たで䞋げおください。 + +- `--logging_dir` / `--log_prefix` + + 孊習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定しおください。TensorBoard圢匏のログが保存されたす。 + + たずえば--logging_dir=logsず指定するず、䜜業フォルダにlogsフォルダが䜜成され、その䞭の日時フォルダにログが保存されたす。 + たた--log_prefixオプションを指定するず、日時の前に指定した文字列が远加されたす。「--logging_dir=logs --log_prefix=db_style1_」などずしお識別甚にお䜿いください。 + + TensorBoardでログを確認するには、別のコマンドプロンプトを開き、䜜業フォルダで以䞋のように入力したす。 
+ + ``` + tensorboard --logdir=logs + ``` + + tensorboardは環境敎備時にあわせおむンストヌルされるず思いたすが、もし入っおいないなら `pip install tensorboard` で入れおください。 + + その埌ブラりザを開き、http://localhost:6006/ ぞアクセスするず衚瀺されたす。 + +- `--noise_offset` + + こちらの蚘事の実装になりたす: https://www.crosslabs.org//blog/diffusion-with-offset-noise + + 党䜓的に暗い、明るい画像の生成結果が良くなる可胜性があるようです。LoRA孊習でも有効なようです。`0.1` 皋床の倀を指定するずよいようです。 + +- `--debug_dataset` + + このオプションを付けるこずで孊習を行う前に事前にどのような画像デヌタ、キャプションで孊習されるかを確認できたす。Escキヌを抌すず終了しおコマンドラむンに戻りたす。 + + ※Linux環境Colabを含むでは画像は衚瀺されたせん。 + +- `--vae` + + vaeオプションにStable Diffusionのcheckpoint、VAEのcheckpointファむル、DiffusesのモデルたたはVAEずもにロヌカルたたはHugging FaceのモデルIDが指定できたすのいずれかを指定するず、そのVAEを䜿っお孊習したすlatentsのキャッシュ時たたは孊習䞭のlatents取埗時。 + + DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み蟌んだものになりたす。 + + +## オプティマむザ関係 + +- `--optimizer_type` + --オプティマむザの皮類を指定したす。以䞋が指定できたす。 + - AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) + - 過去のバヌゞョンのオプション未指定時ず同じ + - AdamW8bit : 匕数は同䞊 + - 過去のバヌゞョンの--use_8bit_adam指定時ず同じ + - Lion : https://github.com/lucidrains/lion-pytorch + - 過去のバヌゞョンの--use_lion_optimizer指定時ず同じ + - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True + - SGDNesterov8bit : 匕数は同䞊 + - DAdaptation : https://github.com/facebookresearch/dadaptation + - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) + - 任意のオプティマむザ + +- `--learning_rate` + + 孊習率を指定したす。適切な孊習率は孊習スクリプトにより異なりたすので、それぞれの説明を参照しおください。 + +- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power` + + 孊習率のスケゞュヌラ関連の指定です。 + + lr_schedulerオプションで孊習率のスケゞュヌラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから遞べたす。デフォルトはconstantです。 + + lr_warmup_stepsでスケゞュヌラのりォヌムアップだんだん孊習率を倉えおいくステップ数を指定できたす。 + + lr_scheduler_num_cycles は cosine with restartsスケゞュヌラでのリスタヌト回数、lr_scheduler_power は polynomialスケゞュヌラでのpolynomial power です。 + + 詳现に぀いおは各自お調べください。 + +### オプティマむザの指定に぀いお + +オプティマむザのオプション匕数は--optimizer_argsオプションで指定しおください。key=valueの圢匏で、耇数の倀が指定できたす。たた、valueはカンマ区切りで耇数の倀が指定できたす。たずえばAdamWオプティマむザに匕数を指定する堎合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになりたす。 + +オプション匕数を指定する堎合は、それぞれのオプティマむザの仕様をご確認ください。 + +䞀郚のオプティマむザでは必須の匕数があり、省略するず自動的に远加されたすSGDNesterovのmomentumなど。コン゜ヌルの出力を確認しおください。 + +D-Adaptationオプティマむザは孊習率を自動調敎したす。孊習率のオプションに指定した倀は孊習率そのものではなくD-Adaptationが決定した孊習率の適甚率になりたすので、通垞は1.0を指定しおください。Text EncoderにU-Netの半分の孊習率を指定したい堎合は、``--text_encoder_lr=0.5 --unet_lr=1.0``ず指定したす。 + +AdaFactorオプティマむザはrelative_step=Trueを指定するず孊習率を自動調敎できたす省略時はデフォルトで远加されたす。自動調敎する堎合は孊習率のスケゞュヌラにはadafactor_schedulerが匷制的に䜿甚されたす。たたscale_parameterずwarmup_initを指定するずよいようです。 + +自動調敎する堎合のオプション指定はたずえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになりたす。 + +孊習率を自動調敎しない堎合はオプション匕数 ``relative_step=False`` を远加しおください。その堎合、孊習率のスケゞュヌラにはconstant_with_warmupが、たた募配のclip normをしないこずが掚奚されおいるようです。そのため匕数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになりたす。 + +### 任意のオプティマむザを䜿う + +``torch.optim`` のオプティマむザを䜿う堎合にはクラス名のみを``--optimizer_type=RMSprop``など、他のモゞュヌルのオプティマむザを䜿う時は「モゞュヌル名.クラス名」を指定しおください``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など。 + +内郚でimportlibしおいるだけで動䜜は未確認です。必芁ならパッケヌゞをむンストヌルしおください。 + + + + +# メタデヌタファむルの䜜成 + +## 教垫デヌタの甚意 + +前述のように孊習させたい画像デヌタを甚意し、任意のフォルダに入れおください。 + +たずえば以䞋のように画像を栌玍したす。 + 
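+ディレクトリ構成のむメヌゞは次の通りですファむル名は仮のもので、任意の名前で構いたせん。
+
+```
+train_data/
+  001.jpg
+  002.jpg
+  003.png
+```
+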
+![教垫デヌタフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) + +## 自動キャプショニング + +キャプションを䜿わずタグだけで孊習する堎合はスキップしおください。 + +たた手動でキャプションを甚意する堎合、キャプションは教垫デヌタ画像ず同じディレクトリに、同じファむル名、拡匵子.caption等で甚意しおください。各ファむルは1行のみのテキストファむルずしたす。 + +### BLIPによるキャプショニング + +最新版ではBLIPのダりンロヌド、重みのダりンロヌド、仮想環境の远加は䞍芁になりたした。そのたたで動䜜したす。 + +finetuneフォルダ内のmake_captions.pyを実行したす。 + +``` +python finetune\make_captions.py --batch_size <バッチサむズ> <教垫デヌタフォルダ> +``` + +バッチサむズ8、教垫デヌタを芪フォルダのtrain_dataに眮いた堎合、以䞋のようになりたす。 + +``` +python finetune\make_captions.py --batch_size 8 ..\train_data +``` + +キャプションファむルが教垫デヌタ画像ず同じディレクトリに、同じファむル名、拡匵子.captionで䜜成されたす。 + +batch_sizeはGPUのVRAM容量に応じお増枛しおください。倧きいほうが速くなりたすVRAM 12GBでももう少し増やせるず思いたす。 +max_lengthオプションでキャプションの最倧長を指定できたす。デフォルトは75です。モデルをトヌクン長225で孊習する堎合には長くしおも良いかもしれたせん。 +caption_extensionオプションでキャプションの拡匵子を倉曎できたす。デフォルトは.captionです.txtにするず埌述のDeepDanbooruず競合したす。 + +耇数の教垫デヌタフォルダがある堎合には、それぞれのフォルダに察しお実行しおください。 + +なお、掚論にランダム性があるため、実行するたびに結果が倉わりたす。固定する堎合には--seedオプションで `--seed 42` のように乱数seedを指定しおください。 + +その他のオプションは `--help` でヘルプをご参照くださいパラメヌタの意味に぀いおはドキュメントがたずたっおいないようで、゜ヌスを芋るしかないようです。 + +デフォルトでは拡匵子.captionでキャプションファむルが生成されたす。 + +![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) + +たずえば以䞋のようなキャプションが付きたす。 + +![キャプションず画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) + +## DeepDanbooruによるタグ付け + +danbooruタグのタグ付け自䜓を行わない堎合は「キャプションずタグ情報の前凊理」に進んでください。 + +タグ付けはDeepDanbooruたたはWD14Taggerで行いたす。WD14Taggerのほうが粟床が良いようです。WD14Taggerでタグ付けする堎合は、次の章ぞ進んでください。 + +### 環境敎備 + +DeepDanbooru https://github.com/KichangKim/DeepDanbooru を䜜業フォルダにcloneしおくるか、zipをダりンロヌドしお展開したす。私はzipで展開したした。 +たたDeepDanbooruのReleasesのペヌゞ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダりンロヌドしおきおDeepDanbooruのフォルダに展開したす。 + +以䞋からダりンロヌドしたす。Assetsをクリックしお開き、そこからダりンロヌドしたす。 + +![DeepDanbooruダりンロヌドペヌゞ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) + +以䞋のようなこういうディレクトリ構造にしおください + +![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) + +Diffusersの環境に必芁なラむブラリをむンストヌルしたす。DeepDanbooruのフォルダに移動しおむンストヌルしたす実質的にはtensorflow-ioが远加されるだけだず思いたす。 + +``` +pip install -r requirements.txt +``` + +続いおDeepDanbooru自䜓をむンストヌルしたす。 + +``` +pip install . 
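+# DeepDanbooru のフォルダ内で実行したすカレントディレクトリのパッケヌゞがむンストヌルされたす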
+``` + +以䞊でタグ付けの環境敎備は完了です。 + +### タグ付けの実斜 +DeepDanbooruのフォルダに移動し、deepdanbooruを実行しおタグ付けを行いたす。 + +``` +deepdanbooru evaluate <教垫デヌタフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +教垫デヌタを芪フォルダのtrain_dataに眮いた堎合、以䞋のようになりたす。 + +``` +deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +タグファむルが教垫デヌタ画像ず同じディレクトリに、同じファむル名、拡匵子.txtで䜜成されたす。1件ず぀凊理されるためわりず遅いです。 + +耇数の教垫デヌタフォルダがある堎合には、それぞれのフォルダに察しお実行しおください。 + +以䞋のように生成されたす。 + +![DeepDanbooruの生成ファむル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) + +こんな感じにタグが付きたすすごい情報量  。 + +![DeepDanbooruタグず画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) + +## WD14Taggerによるタグ付け + +DeepDanbooruの代わりにWD14Taggerを甚いる手順です。 + +Automatic1111氏のWebUIで䜿甚しおいるtaggerを利甚したす。こちらのgithubペヌゞhttps://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger の情報を参考にさせおいただきたした。 + +最初の環境敎備で必芁なモゞュヌルはむンストヌル枈みです。たた重みはHugging Faceから自動的にダりンロヌドしおきたす。 + +### タグ付けの実斜 + +スクリプトを実行しおタグ付けを行いたす。 +``` +python tag_images_by_wd14_tagger.py --batch_size <バッチサむズ> <教垫デヌタフォルダ> +``` + +教垫デヌタを芪フォルダのtrain_dataに眮いた堎合、以䞋のようになりたす。 +``` +python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data +``` + +初回起動時にはモデルファむルがwd14_tagger_modelフォルダに自動的にダりンロヌドされたすフォルダはオプションで倉えられたす。以䞋のようになりたす。 + +![ダりンロヌドされたファむル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) + +タグファむルが教垫デヌタ画像ず同じディレクトリに、同じファむル名、拡匵子.txtで䜜成されたす。 + +![生成されたタグファむル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![タグず画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +threshオプションで、刀定されたタグのconfidence確信床がいく぀以䞊でタグを぀けるかが指定できたす。デフォルトはWD14Taggerのサンプルず同じ0.35です。倀を䞋げるずより倚くのタグが付䞎されたすが、粟床は䞋がりたす。 + +batch_sizeはGPUのVRAM容量に応じお増枛しおください。倧きいほうが速くなりたすVRAM 12GBでももう少し増やせるず思いたす。caption_extensionオプションでタグファむルの拡匵子を倉曎できたす。デフォルトは.txtです。 + +model_dirオプションでモデルの保存先フォルダを指定できたす。 + +たたforce_downloadオプションを指定するず保存先フォルダがあっおもモデルを再ダりンロヌドしたす。 + +耇数の教垫デヌタフォルダがある堎合には、それぞれのフォルダに察しお実行しおください。 + +## キャプションずタグ情報の前凊理 + +スクリプトから凊理しやすいようにキャプションずタグをメタデヌタずしおひず぀のファむルにたずめたす。 + +### キャプションの前凊理 + +キャプションをメタデヌタに入れるには、䜜業フォルダ内で以䞋を実行しおくださいキャプションを孊習に䜿わない堎合は実行䞍芁です実際は1行で蚘述したす、以䞋同様。`--full_path` オプションを指定しおメタデヌタに画像ファむルの堎所をフルパスで栌玍したす。このオプションを省略するず盞察パスで蚘録されたすが、フォルダ指定が `.toml` ファむル内で別途必芁になりたす。 + +``` +python merge_captions_to_metadata.py --full_apth <教垫デヌタフォルダ> +  --in_json <読み蟌むメタデヌタファむル名> <メタデヌタファむル名> +``` + +メタデヌタファむル名は任意の名前です。 +教垫デヌタがtrain_data、読み蟌むメタデヌタファむルなし、メタデヌタファむルがmeta_cap.jsonの堎合、以䞋のようになりたす。 + +``` +python merge_captions_to_metadata.py --full_path train_data meta_cap.json +``` + +caption_extensionオプションでキャプションの拡匵子を指定できたす。 + +耇数の教垫デヌタフォルダがある堎合には、full_path匕数を指定し぀぀、それぞれのフォルダに察しお実行しおください。 + +``` +python merge_captions_to_metadata.py --full_path + train_data1 meta_cap1.json +python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json + train_data2 meta_cap2.json +``` + +in_jsonを省略するず曞き蟌み先メタデヌタファむルがあるずそこから読み蟌み、そこに䞊曞きしたす。 + +__※in_jsonオプションず曞き蟌み先を郜床曞き換えお、別のメタデヌタファむルぞ曞き出すようにするず安党です。__ + +### タグの前凊理 + +同様にタグもメタデヌタにたずめたすタグを孊習に䜿わない堎合は実行䞍芁です。 +``` +python merge_dd_tags_to_metadata.py --full_path <教垫デヌタフォルダ> + --in_json <読み蟌むメタデヌタファむル名> <曞き蟌むメタデヌタファむル名> +``` + +先ず同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに曞きだす堎合、以䞋ずなりたす。 +``` +python merge_dd_tags_to_metadata.py 
--full_path train_data --in_json meta_cap.json meta_cap_dd.json +``` + +耇数の教垫デヌタフォルダがある堎合には、full_path匕数を指定し぀぀、それぞれのフォルダに察しお実行しおください。 + +``` +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json + train_data1 meta_cap_dd1.json +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json + train_data2 meta_cap_dd2.json +``` + +in_jsonを省略するず曞き蟌み先メタデヌタファむルがあるずそこから読み蟌み、そこに䞊曞きしたす。 + +__※in_jsonオプションず曞き蟌み先を郜床曞き換えお、別のメタデヌタファむルぞ曞き出すようにするず安党です。__ + +### キャプションずタグのクリヌニング + +ここたででメタデヌタファむルにキャプションずDeepDanbooruのタグがたずめられおいたす。ただ自動キャプショニングにしたキャプションは衚蚘ゆれなどがあり埮劙※ですし、タグにはアンダヌスコアが含たれおいたりratingが付いおいたりしたすのでDeepDanbooruの堎合、゚ディタの眮換機胜などを甚いおキャプションずタグのクリヌニングをしたほうがいいでしょう。 + +※たずえばアニメ絵の少女を孊習する堎合、キャプションにはgirl/girls/woman/womenなどのばら぀きがありたす。たた「anime girl」なども単に「girl」ずしたほうが適切かもしれたせん。 + +クリヌニング甚のスクリプトが甚意しおありたすので、スクリプトの内容を状況に応じお線集しおお䜿いください。 + +教垫デヌタフォルダの指定は䞍芁になりたした。メタデヌタ内の党デヌタをクリヌニングしたす。 + +``` +python clean_captions_and_tags.py <読み蟌むメタデヌタファむル名> <曞き蟌むメタデヌタファむル名> +``` + +--in_jsonは付きたせんのでご泚意ください。たずえば次のようになりたす。 + +``` +python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json +``` + +以䞊でキャプションずタグの前凊理は完了です。 + +## latentsの事前取埗 + +※ このステップは必須ではありたせん。省略しおも孊習時にlatentsを取埗しながら孊習できたす。 +たた孊習時に `random_crop` や `color_aug` などを行う堎合にはlatentsの事前取埗はできたせん画像を毎回倉えながら孊習するため。事前取埗をしない堎合、ここたでのメタデヌタで孊習できたす。 + +あらかじめ画像の朜圚衚珟を取埗しディスクに保存しおおきたす。それにより、孊習を高速に進めるこずができたす。あわせおbucketing教垫デヌタをアスペクト比に応じお分類するを行いたす。 + +䜜業フォルダで以䞋のように入力しおください。 +``` +python prepare_buckets_latents.py --full_path <教垫デヌタフォルダ> + <読み蟌むメタデヌタファむル名> <曞き蟌むメタデヌタファむル名> + + --batch_size <バッチサむズ> + --max_resolution <解像床 幅,高さ> + --mixed_precision <粟床> +``` + +モデルがmodel.ckpt、バッチサむズ4、孊習解像床は512\*512、粟床nofloat32で、meta_clean.jsonからメタデヌタを読み蟌み、meta_lat.jsonに曞き蟌む堎合、以䞋のようになりたす。 + +``` +python prepare_buckets_latents.py --full_path + train_data meta_clean.json meta_lat.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no +``` + +教垫デヌタフォルダにnumpyのnpz圢匏でlatentsが保存されたす。 + +解像床の最小サむズを--min_bucket_resoオプションで、最倧サむズを--max_bucket_resoで指定できたす。デフォルトはそれぞれ256、1024です。たずえば最小サむズに384を指定するず、256\*1024や320\*768などの解像床は䜿わなくなりたす。 +解像床を768\*768のように倧きくした堎合、最倧サむズに1280などを指定するず良いでしょう。 + +--flip_augオプションを指定するず巊右反転のaugmentationデヌタ拡匵を行いたす。疑䌌的にデヌタ量を二倍に増やすこずができたすが、デヌタが巊右察称でない堎合に指定するず䟋えばキャラクタの倖芋、髪型など孊習がうたく行かなくなりたす。 + + +反転した画像に぀いおもlatentsを取埗し、\*\_flip.npzファむルを保存する単玔な実装です。fline_tune.pyには特にオプション指定は必芁ありたせん。\_flip付きのファむルがある堎合、flip付き・なしのファむルを、ランダムに読み蟌みたす。 + +バッチサむズはVRAM 12GBでももう少し増やせるかもしれたせん。 +解像床は64で割り切れる数字で、"幅,高さ"で指定したす。解像床はfine tuning時のメモリサむズに盎結したす。VRAM 12GBでは512,512が限界ず思われたす※。16GBなら512,704や512,768たで䞊げられるかもしれたせん。なお256,256等にしおもVRAM 8GBでは厳しいようですパラメヌタやoptimizerなどは解像床に関係せず䞀定のメモリが必芁なため。 + +※batch size 1の孊習で12GB VRAM、640,640で動いたずの報告もありたした。 + +以䞋のようにbucketingの結果が衚瀺されたす。 + +![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) + +耇数の教垫デヌタフォルダがある堎合には、full_path匕数を指定し぀぀、それぞれのフォルダに察しお実行しおください。 +``` +python prepare_buckets_latents.py --full_path + train_data1 meta_clean.json meta_lat1.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +python prepare_buckets_latents.py --full_path + train_data2 meta_lat1.json meta_lat2.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +``` +読み蟌み元ず曞き蟌み先を同じにするこずも可胜ですが別々の方が安党です。 + +__※匕数を郜床曞き換えお、別のメタデヌタファむルに曞き蟌むず安党です。__ + diff --git a/train_db.py b/train_db.py new file mode 100644 index 0000000000000000000000000000000000000000..b3eead94150361db27af572eaada24421e12cf90 --- /dev/null +++ b/train_db.py @@ -0,0 
+1,429 @@ +# DreamBooth training +# XXX dropped option: fine_tune + +import gc +import time +import argparse +import itertools +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer = train_util.load_tokenizer(args) + + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 蚭定ファむルが利甚されるため以䞋のオプションは無芖されたす: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.no_token_padding: + train_dataset_group.disable_token_padding() + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするずきはcolor_augずrandom_cropは䜿えたせん" + + # acceleratorを準備する + print("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + print( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. 
accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に蚭定されおいたす。accelerateは耇数モデルU-NetおよびText Encoderの孊習時にgradient_accumulation_stepsをサポヌトしおいないため結果は未知数です" + ) + + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに察応した型を甚意しおおき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み蟌む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # モデルに xformers ずか memory efficient attention を組み蟌む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 孊習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # 孊習を準備するモデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため远加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 孊習に必芁なクラスを準備する + print("prepare optimizer, data loader etc.") + if train_text_encoder: + trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + trainable_params = unet.parameters() + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数0はメむンプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最倧で指定された数たで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 孊習ステップ数を蚈算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定゚ポックたでのステップ数: {args.max_train_steps}") + + # デヌタセット偎にも孊習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + # lr schedulerを甚意する TODO gradient_accumulation_stepsの扱いが䜕かおかしいかもしれない。埌で確認する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実隓的機胜募配も含めたfp16孊習を行う モデル党䜓をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を䜿う堎合はmixed_precision='fp16'を指定しおください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやっおくれるらしい + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + + # 実隓的機胜募配も含めたfp16孊習を行う PyTorchにパッチを圓おおfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を蚈算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 孊習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / å­Šç¿’é–‹å§‹") + print(f" num train images * repeats / 孊習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサむズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサむズ䞊列孊習、募配合蚈含む: {total_batch_size}") + print(f" gradient ccumulation steps / 募配を合蚈するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 孊習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth") + + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + # 指定したステップ数たでText Encoderを孊習するepoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = 
global_step + # 指定したステップ数でText Encoderの孊習を止める + if global_step == args.stop_text_encoder_training: + print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) + + with accelerator.accumulate(unet): + with torch.no_grad(): + # latentに倉換 + if cache_latents: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごずのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必芁なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": 
lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この埌メモリを䜿うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トヌクンのpaddingを無効にするDiffusers版DreamBoothず同じ動䜜", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの孊習を止めるステップ数、-1で最初から孊習しない", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_db_README-ja.md b/train_db_README-ja.md new file mode 100644 index 0000000000000000000000000000000000000000..0d0747bb41223a52a4f609f58eb1314639924913 --- /dev/null +++ b/train_db_README-ja.md @@ -0,0 +1,167 @@ +DreamBoothのガむドです。 + +[孊習に぀いおの共通ドキュメント](./train_README-ja.md) もあわせおご芧ください。 + +# 抂芁 + +DreamBoothずは、画像生成モデルに特定の䞻題を远加孊習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。 + +具䜓的には、Stable Diffusionのモデルにキャラや画颚などを孊ばせ、それを `shs` のような特定の単語で呌び出せる生成画像に出珟させるこずができたす。 + +スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしおいたすが、以䞋のような機胜远加を行っおいたすいく぀かの機胜は元のスクリプト偎もその埌察応しおいたす。 + +スクリプトの䞻な機胜は以䞋の通りです。 + +- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化[Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)ず同様。 +- xformersによる省メモリ化。 +- 512x512だけではなく任意サむズでの孊習。 +- augmentationによる品質の向䞊。 +- DreamBoothだけではなくText Encoder+U-Netのfine tuningに察応。 +- Stable Diffusion圢匏でのモデルの読み曞き。 +- Aspect Ratio Bucketing。 +- Stable Diffusion v2.0察応。 + +# 孊習の手順 + +あらかじめこのリポゞトリのREADMEを参照し、環境敎備を行っおください。 + +## デヌタの準備 + 
+[孊習デヌタの準備に぀いお](./train_README-ja.md) を参照しおください。 + +## 孊習の実行 + +スクリプトを実行したす。最倧限、メモリを節玄したコマンドは以䞋のようになりたす実際には1行で入力したす。それぞれの行を必芁に応じお曞き換えおください。12GB皋床のVRAMで動䜜するようです。 + +``` +accelerate launch --num_cpu_threads_per_process 1 train_db.py + --pretrained_model_name_or_path=<.ckptたたは.safetensordたたはDiffusers版モデルのディレクトリ> + --dataset_config=<デヌタ準備で䜜成した.tomlファむル> + --output_dir=<孊習したモデルの出力先フォルダ> + --output_name=<孊習したモデル出力時のファむル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=1600 + --learning_rate=1e-6 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing +``` + +`num_cpu_threads_per_process` には通垞は1を指定するずよいようです。 + +`pretrained_model_name_or_path` に远加孊習を行う元ずなるモデルを指定したす。Stable Diffusionのcheckpointファむル.ckptたたは.safetensors、Diffusersのロヌカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できたす。 + +`output_dir` に孊習埌のモデルを保存するフォルダを指定したす。`output_name` にモデルのファむル名を拡匵子を陀いお指定したす。`save_model_as` でsafetensors圢匏での保存を指定しおいたす。 + +`dataset_config` に `.toml` ファむルを指定したす。ファむル内でのバッチサむズ指定は、圓初はメモリ消費を抑えるために `1` ずしおください。 + +`prior_loss_weight` は正則化画像のlossの重みです。通垞は1.0を指定したす。 + +孊習させるステップ数 `max_train_steps` を1600ずしたす。孊習率 `learning_rate` はここでは1e-6を指定しおいたす。 + +省メモリ化のため `mixed_precision="fp16"` を指定したすRTX30 シリヌズ以降では `bf16` も指定できたす。環境敎備時にaccelerateに行った蚭定ず合わせおください。たた `gradient_checkpointing` を指定したす。 + +オプティマむザモデルを孊習デヌタにあうように最適化孊習させるクラスにメモリ消費の少ない 8bit AdamW を䜿うため、 `optimizer_type="AdamW8bit"` を指定したす。 + +`xformers` オプションを指定し、xformersのCrossAttentionを甚いたす。xformersをむンストヌルしおいない堎合や゚ラヌずなる堎合環境にもよりたすが `mixed_precision="no"` の堎合など、代わりに `mem_eff_attn` オプションを指定するず省メモリ版CrossAttentionを䜿甚したす速床は遅くなりたす。 + +省メモリ化のため `cache_latents` オプションを指定しおVAEの出力をキャッシュしたす。 + +ある皋床メモリがある堎合は、`.toml` ファむルを線集しおバッチサむズをたずえば `4` くらいに増やしおください高速化ず粟床向䞊の可胜性がありたす。たた `cache_latents` を倖すこずで augmentation が可胜になりたす。 + +### よく䜿われるオプションに぀いお + +以䞋の堎合には [孊習の共通ドキュメント](./train_README-ja.md) の「よく䜿われるオプション」を参照しおください。 + +- Stable Diffusion 2.xたたはそこからの掟生モデルを孊習する +- clip skipを2以䞊を前提ずしたモデルを孊習する +- 75トヌクンを超えたキャプションで孊習する + +### DreamBoothでのステップ数に぀いお + +圓スクリプトでは省メモリ化のため、ステップ圓たりの孊習回数が元のスクリプトの半分になっおいたす察象の画像ず正則化画像を同䞀のバッチではなく別のバッチに分割しお孊習するため。 + +元のDiffusers版やXavierXiao氏のStable Diffusion版ずほが同じ孊習を行うには、ステップ数を倍にしおください。 + +孊習画像ず正則化画像をたずめおから shuffle するため厳密にはデヌタの順番が倉わっおしたいたすが、孊習には倧きな圱響はないず思いたす。 + +### DreamBoothでのバッチサむズに぀いお + +モデル党䜓を孊習するためLoRA等の孊習に比べるずメモリ消費量は倚くなりたすfine tuningず同じ。 + +### 孊習率に぀いお + +Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、䞊のサンプルでは1e-6を指定しおいたす。 + +### 以前の圢匏のデヌタセット指定をした堎合のコマンドラむン + +解像床やバッチサむズをオプションで指定したす。コマンドラむンの䟋は以䞋の通りです。 + +``` +accelerate launch --num_cpu_threads_per_process 1 train_db.py + --pretrained_model_name_or_path=<.ckptたたは.safetensordたたはDiffusers版モデルのディレクトリ> + --train_data_dir=<孊習甚デヌタのディレクトリ> + --reg_data_dir=<正則化画像のディレクトリ> + --output_dir=<孊習したモデルの出力先ディレクトリ> + --output_name=<孊習したモデル出力時のファむル名> + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=1 + --learning_rate=1e-6 + --max_train_steps=1600 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents + --gradient_checkpointing +``` + +## 孊習したモデルで画像生成する + +孊習が終わるず指定したフォルダに指定した名前でsafetensorsファむルが出力されたす。 + +v1.4/1.5およびその他の掟生モデルの堎合、このモデルでAutomatic1111氏のWebUIなどで掚論できたす。models\Stable-diffusionフォルダに眮いおください。 + +v2.xモデルでWebUIで画像生成する堎合、モデルの仕様が蚘述された.yamlファむルが別途必芁になりたす。v2.x baseの堎合はv2-inference.yamlを、768/vの堎合はv2-inference-v.yamlを、同じフォルダに眮き、拡匵子の前の郚分をモデルず同じ名前にしおください。 + +![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) + 
+各yamlファむルは[Stability AIのSD2.0のリポゞトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にありたす。 + +# DreamBooth特有のその他の䞻なオプション + +すべおのオプションに぀いおは別文曞を参照しおください。 + +## Text Encoderの孊習を途䞭から行わない --stop_text_encoder_training + +stop_text_encoder_trainingオプションに数倀を指定するず、そのステップ数以降はText Encoderの孊習を行わずU-Netだけ孊習したす。堎合によっおは粟床の向䞊が期埅できるかもしれたせん。 + +恐らくText Encoderだけ先に過孊習するこずがあり、それを防げるのではないかず掚枬しおいたすが、詳现な圱響は䞍明です。 + +## Tokenizerのパディングをしない --no_token_padding +no_token_paddingオプションを指定するずTokenizerの出力をpaddingしたせんDiffusers版の旧DreamBoothず同じ動きになりたす。 + + + diff --git a/train_db_README.md b/train_db_README.md new file mode 100644 index 0000000000000000000000000000000000000000..2367d29ae3180e4e92d10943baa9aeafd9ad4e8b --- /dev/null +++ b/train_db_README.md @@ -0,0 +1,295 @@ +A guide to DreamBooth. The same procedure is used for training additional networks such as LoRA. + +# overview + +The main functions of the script are as follows. + +- Memory saving by 8bit Adam optimizer and latent cache (similar to ShivamShirao's version). +- Saved memory by xformers. +- Study in any size, not just 512x512. +- Quality improvement with augmentation. +- Supports fine tuning of Text Encoder+U-Net as well as DreamBooth. +- Read and write models in StableDiffusion format. +- Aspect Ratio Bucketing. +- Supports Stable Diffusion v2.0. + +# learning procedure + +## step 1. Environment improvement + +See the README in this repository. + + +## step 2. Determine identifier and class + +Decide the word identifier that connects the target you want to learn and the class to which the target belongs. + +(There are various names such as instance, but for the time being I will stick to the original paper.) + +Here's a very brief explanation (look it up for more details). + +class is the general type to learn. For example, if you want to learn a specific breed of dog, the class will be dog. Anime characters will be boy, girl, 1boy or 1girl depending on the model. + +The identifier is for identifying and learning the learning target. Any word is fine, but according to the original paper, ``a rare word with 3 letters or less that becomes one token with tokinizer'' is good. + +By using the identifier and class to train the model, for example, "shs dog", you can learn by identifying the object you want to learn from the class. + +When generating an image, if you say "shs dog", an image of the learned dog breed will be generated. + +(For reference, the identifier I use these days is ``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny``.) + +## step 3. Prepare images for training +Create a folder to store training images. __In addition, create a directory with the following name: + +``` +_ +``` + +Don't forget the ``_`` between them. + +The number of repetitions is specified to match the number of regularized images (described later). + +For example, at the prompt "sls frog", to repeat the data 20 times, it would be "20_sls frog". It will be as follows. + +![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) + +## step 4. Preparing regularized images +This is the procedure when using a regularized image. It is also possible to learn without using the regularization image (the whole target class is affected because it is impossible to distinguish without using the regularization image). + +Create a folder to store the regularized images. __In addition, __ create a directory named ``_``. 
+ +For example, with the prompt "frog" and without repeating the data (just once): + +![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) + +Specify the number of iterations so that " __ number of iterations of training images x number of training images ≥ number of iterations of regularization images x number of regularization images __". + +(The number of data in one epoch is "number of repetitions of training images x number of training images". If the number of regularization images is more than that, the remaining regularization images will not be used.) + +## step 5. Run training +Run the script. The maximally memory-saving command looks like this (actually typed on one line): + +*The command for learning additional networks such as LoRA is ``train_network.py`` instead of ``train_db.py``. You will also need additional network_\* options, so please refer to LoRA's guide. + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=1 + --learning_rate=1e-6 + --max_train_steps=1600 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents + --gradient_checkpointing +``` + +It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process. + +Specify the model to perform additional training in pretrained_model_name_or_path. You can specify a Stable Diffusion checkpoint file (.ckpt or .safetensors), a model directory on the Diffusers local disk, or a Diffusers model ID (such as "stabilityai/stable-diffusion-2"). The saved model after training will be saved in the same format as the original model by default (can be changed with the save_model_as option). + +prior_loss_weight is the loss weight of the regularized image. Normally, specify 1.0. + +resolution will be the size of the image (resolution, width and height). If bucketing (described later) is not used, use this size for training images and regularization images. + +train_batch_size is the training batch size. Set max_train_steps to 1600. The learning rate learning_rate is 5e-6 in the diffusers version and 1e-6 in the StableDiffusion version, so 1e-6 is specified here. + +Specify mixed_precision="bf16" (or "fp16") and gradient_checkpointing for memory saving. + +Specify the xformers option and use xformers' CrossAttention. If you don't have xformers installed, if you get an error (without mixed_precision, it was an error in my environment), specify the mem_eff_attn option instead to use the memory-saving version of CrossAttention (speed will be slower) . + +Cache VAE output with cache_latents option to save memory. + +If you have a certain amount of memory, specify it as follows, for example. + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=4 + --learning_rate=1e-6 + --max_train_steps=400 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents +``` + +Remove gradient_checkpointing to speed up (memory usage will increase). Increase the batch size to improve speed and accuracy. 
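+
+To get a rough feel for how max_train_steps, train_batch_size and the repeat count in the folder name relate, here is a small back-of-the-envelope sketch. It is not part of the repository scripts; the image count and repeat value are made-up examples, and regularization-image batches are ignored for simplicity. It only applies the rule above that one epoch is "number of repetitions of training images x number of training images".
+
+```
+# Illustrative estimate only (not part of the repository scripts).
+import math
+
+num_train_images = 20    # images inside the "20_sls frog" folder (example value)
+repeats = 20             # the number before the underscore in the folder name
+train_batch_size = 1
+max_train_steps = 1600
+
+# One epoch processes "repeats x number of training images" samples.
+images_per_epoch = num_train_images * repeats
+steps_per_epoch = math.ceil(images_per_epoch / train_batch_size)
+
+print(f"{steps_per_epoch} steps per epoch -> about {max_train_steps / steps_per_epoch:.1f} epochs")
+# 400 steps per epoch -> about 4.0 epochs
+```
+
+With these example numbers, raising train_batch_size to 4 cuts the steps per epoch to 100, which matches the larger-memory command above using train_batch_size=4 with max_train_steps=400.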
+ +An example that uses bucketing (see below) and augmentation (see below) looks like this: + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --resolution=768,512 + --train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800 + --use_8bit_adam --xformers --mixed_precision="bf16" + --save_every_n_epochs=1 --save_state --save_precision="bf16" + --logging_dir=logs + --enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280 + --color_aug --flip_aug --gradient_checkpointing --seed 42 +``` + +### About the number of steps +To save memory, the amount of training done per step is half that of train_dreambooth.py (because the target images and the regularization images are split into separate batches instead of sharing one batch). +Double the number of steps to get almost the same training as the original Diffusers version and XavierXiao's StableDiffusion version. + +(Strictly speaking, the order of the data changes due to shuffle=True, but I don't think it has a big impact on learning.) + +## Generate an image with the trained model + +When training is complete, a checkpoint named last.ckpt is written to the specified folder (if you trained a Diffusers-format model, it is written as a last folder instead). + +For v1.4/1.5 and other derived models, you can run inference with this model in Automatic1111's WebUI, etc. Place it in the models\Stable-diffusion folder. + +When generating images with the WebUI using a v2.x model, a separate .yaml file that describes the model specifications is required. Place v2-inference.yaml for v2.x base and v2-inference-v.yaml for 768/v in the same folder and make the part before the extension the same name as the model. + +![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) + +Each yaml file can be found in the [Stability AI SD2.0 repository](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion). + +# Other training options + +## Supports Stable Diffusion 2.0 --v2 / --v_parameterization +Specify the v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the v2 and v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt. + +In addition, training SD 2.0 seems to be difficult with 12GB of VRAM because the Text Encoder is larger. + +The following points have changed significantly in Stable Diffusion 2.0. + +1. Tokenizer to use +2. Which Text Encoder to use and which output layer to use (2.0 uses the penultimate layer) +3. Output dimensionality of Text Encoder (768->1024) +4. Structure of U-Net (number of heads of CrossAttention, etc.) +5. v-parameterization (the sampling method seems to have changed) + +Among these, 1 to 4 are adopted for base, and 1 to 5 are adopted for the one without base (768-v). Enabling 1-4 is the v2 option, and enabling 5 is the v_parameterization option. + +## Check training data --debug_dataset +By adding this option, you can check in advance what kind of image data and captions will be used for training. Press Esc to exit and return to the command line. + +*Please note that it seems to hang when executed in an environment without a display, such as Colab. + +## Stop training Text Encoder --stop_text_encoder_training +If you specify a numerical value for the stop_text_encoder_training option, after that number of steps, only the U-Net will be trained without training the Text Encoder.
In some cases, the accuracy may be improved. + +(The Text Encoder probably tends to overfit first, and stopping its training may prevent that, but the detailed impact is unknown.) + +## Load and train with a separate VAE --vae +If you specify a Stable Diffusion checkpoint, a VAE checkpoint file, a Diffusers model, or a Diffusers VAE in the vae option (the latter two can be a local path or a Hugging Face model ID), that VAE is used for training (when caching latents, or when obtaining latents during training). +The saved model will incorporate this VAE. + +## Save during training --save_every_n_epochs / --save_state / --resume +Specifying a number for the save_every_n_epochs option saves the model every that many epochs during training. + +If you specify the save_state option at the same time, the training state including the optimizer state etc. will be saved together (compared to resuming from just the checkpoint, you can expect better accuracy and a shorter training time when restarting). The training state is output to a folder named "epoch-??????-state" (?????? is the number of epochs) in the destination folder. Please use it for long training runs. + +Use the resume option to resume training from a saved training state. Please specify the training state folder. + +Please note that due to the specifications of Accelerator (?), the number of epochs and the global step are not saved, and they start from 1 again even when you resume. + +## No tokenizer padding --no_token_padding +The no_token_padding option does not pad the output of the Tokenizer (same behavior as the old Diffusers version of DreamBooth). + +## Training with arbitrary size images --resolution +You can train with non-square images. Specify "width,height" like "448,640" in resolution. Width and height must be divisible by 64. Match the size of the training images and the regularization images. + +Personally, I often generate vertically long images, so I sometimes train with "448,640". + +## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso +It is enabled by specifying the enable_bucket option. With bucketing, Stable Diffusion is trained not only at 512x512 but also at resolutions such as 256x768 and 384x640. + +If you specify this option, you do not need to unify the training images and regularization images to a specific resolution. The script chooses from several prepared resolutions (aspect ratios) for each image and trains at that resolution. A small illustrative sketch of the bucket selection idea is included after the augmentation section below. +Since resolutions change in 64-pixel increments, the aspect ratio may not be exactly the same as the original image. + +You can specify the minimum size of the resolutions with the min_bucket_reso option and the maximum size with max_bucket_reso. The defaults are 256 and 1024 respectively. +For example, specifying a minimum size of 384 will not use resolutions such as 256x1024 or 320x768. +If you increase the resolution to 768x768, you may want to specify 1280 as the maximum size. + +When Aspect Ratio Bucketing is enabled, it may be better to prepare regularization images with various resolutions that are similar to the training images. + +(This is because then the images within one batch are not biased toward either training images or regularization images.) + +## Augmentation --color_aug / --flip_aug +Augmentation is a method of improving model performance by dynamically changing the data during training. color_aug subtly changes the hue and flip_aug flips the image left and right during training. + +Since the data changes dynamically, it cannot be specified together with the cache_latents option.
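+
+To illustrate the idea behind Aspect Ratio Bucketing described above, here is a small standalone sketch. It is only an approximation for illustration; the actual bucket construction and image assignment live in the repository's library code and differ in detail. It enumerates width/height pairs in 64-pixel steps between min_bucket_reso and max_bucket_reso whose area does not exceed the base training area, then picks the bucket whose aspect ratio is closest to the image (preferring the larger bucket on ties).
+
+```
+# Illustrative sketch only -- not the repository's actual bucketing code.
+base_reso = 512                    # --resolution
+min_reso, max_reso = 256, 1024     # --min_bucket_reso / --max_bucket_reso defaults
+max_area = base_reso * base_reso
+
+# Candidate buckets: multiples of 64 within the limits whose area fits the base area.
+buckets = [
+    (w, h)
+    for w in range(min_reso, max_reso + 1, 64)
+    for h in range(min_reso, max_reso + 1, 64)
+    if w * h <= max_area
+]
+
+def closest_bucket(img_w, img_h):
+    # Closest aspect ratio first; among ties, prefer the larger bucket.
+    ar = img_w / img_h
+    return min(buckets, key=lambda b: (abs(b[0] / b[1] - ar), -(b[0] * b[1])))
+
+print(closest_bucket(1000, 1500))  # portrait image -> (384, 576)
+```
+
+This also shows why raising min_bucket_reso to 384 removes buckets such as 256x1024 or 320x768, as mentioned above.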
+ +## Specify data precision when saving --save_precision +Specifying float, fp16, or bf16 for the save_precision option will save the checkpoint in that format (only when saving in Stable Diffusion format). Please use it when you want to reduce the size of the checkpoint. + +## Save in any format --save_model_as +Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors. + +When reading a Stable Diffusion format model (ckpt or safetensors) and saving it in Diffusers format, missing information is supplemented by downloading the v1.5 or v2.1 information from Hugging Face. + +## Save training logs --logging_dir / --log_prefix +Specify the log destination folder in the logging_dir option. Logs in TensorBoard format are saved. + +For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in a date/time subfolder inside it. +Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use something like "--logging_dir=logs --log_prefix=db_style1_" for identification. + +To check the logs with TensorBoard, open another command prompt and enter the following in the working folder (tensorboard should be installed together with Diffusers; if it is not, install it with pip install tensorboard). + +``` +tensorboard --logdir=logs +``` + +Then open your browser and go to http://localhost:6006/ to see it. + +## Learning rate scheduler settings --lr_scheduler / --lr_warmup_steps +You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. The default is constant. With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details. + +## Training with fp16 gradients (experimental feature) --full_fp16 +The full_fp16 option changes the gradients from the usual float32 to float16 (fp16) during training (this is full fp16 training rather than mixed precision). +As a result, it seems that SD1.x at 512x512 can be trained with less than 8GB of VRAM, and SD2.x at 512x512 with less than 12GB of VRAM. + +Specify fp16 in the accelerate config beforehand and pass the ``mixed_precision="fp16"`` option (bf16 does not work). + +To minimize memory usage, use the xformers, use_8bit_adam, cache_latents and gradient_checkpointing options and set train_batch_size to 1. + +(If you can afford it, increasing train_batch_size step by step should improve the accuracy a little.) + +It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably, and the probability of training failing partway through also increases. +The learning rate and the number of steps appear to be sensitive settings. Please be aware of this and use it at your own risk. + +# Other training methods + +## Training multiple classes, multiple identifiers +The method is simple, multiple folders with ``Repetition count_ `` in the training image folder, and a folder with ``Repetition count_`` in the regularization image folder.
Please prepare multiple + +For example, learning "sls frog" and "cpc rabbit" at the same time would look like this: + +![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) + +If you have one class and multiple targets, you can have only one regularized image folder. For example, if 1girl has character A and character B, do as follows. + +- train_girls + - 10_sls 1girl + - 10_cpc 1girl +- reg_girls + -1_1girl + +If the number of data varies, it seems that good results can be obtained by adjusting the number of repetitions to unify the number of sheets for each class and identifier. + +## Use captions in DreamBooth +If you put a file with the same file name as the image and the extension .caption (you can change it in the option) in the training image and regularization image folders, the caption will be read from that file and learned as a prompt. + +* The folder name (identifier class) will no longer be used for training those images. + +Adding captions to each image (you can use BLIP, etc.) may help clarify the attributes you want to learn. + +Caption files have a .caption extension by default. You can change it with --caption_extension. With the --shuffle_caption option, study captions during learning while shuffling each part separated by commas. \ No newline at end of file diff --git a/train_network.py b/train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..476f76dfc01635a011a85b26a09e9feca18fd443 --- /dev/null +++ b/train_network.py @@ -0,0 +1,724 @@ +from torch.nn.parallel import DistributedDataParallel as DDP +import importlib +import argparse +import gc +import math +import os +import random +import time +import json +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +from library.train_util import ( + DreamBoothDataset, +) +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + + +# TODO 他のスクリプトず共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if args.network_train_unet_only: + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + else: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder + + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # デヌタセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 蚭定ファむルが利甚されるため以䞋のオプションは無芖されたす: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありたせん。匕数指定を確認しおくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの芪フォルダを指定する必芁がありたす" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするずきはcolor_augずrandom_cropは䜿えたせん" + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに察応した型を甚意しおおき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み蟌む + for pi in range(accelerator.state.num_processes): + # TODO: modify other training scripts as well + if pi == accelerator.state.local_process_index: + print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator.device if args.lowram else "cpu" + ) + + # work on low-ram device + if args.lowram: + text_encoder.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + gc.collect() + torch.cuda.empty_cache() + accelerator.wait_for_everyone() + + + # モデルに xformers ずか memory efficient attention を組み蟌む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 孊習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # prepare network + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if network is None: + return + + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 孊習に必芁なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数0はメむンプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最倧で指定された数たで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 孊習ステップ数を蚈算する + if args.max_train_epochs is not None: + args.max_train_steps = 
args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定゚ポックたでのステップ数: {args.max_train_steps}") + + # デヌタセット偎にも孊習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを甚意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実隓的機胜募配も含めたfp16孊習を行う モデル党䜓をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を䜿う堎合はmixed_precision='fp16'を指定しおください。" + print("enable full fp16 training.") + network.to(weight_dtype) + + # acceleratorがなんかよろしくやっおくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + + # set top parameter requires_grad = True for gradient checkpointing works + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) + else: + unet.eval() + text_encoder.eval() + + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実隓的機胜募配も含めたfp16孊習を行う PyTorchにパッチを圓おおfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を蚈算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 孊習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + if is_main_process: + print("running training / å­Šç¿’é–‹å§‹") + print(f" num train images * repeats / 孊習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: 
{num_train_epochs}") + print(f" batch size per device / バッチサむズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサむズ䞊列孊習、募配合蚈含む: {total_batch_size}") + print(f" gradient accumulation steps / 募配を合蚈するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 孊習ステップ数: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + } + + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? 
+ datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが耇数のdatasetで䜿甚されおいる堎合、䞀床だけ数える + # もずもず繰り返し回数を指定しおいるので、キャプション内でのタグの出珟回数ず、それが孊習で䜕床䜿われるかは䞀臎しない + # なので、ここで耇数datasetの回数を合算しおもあたり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. 
/ デヌタセットは1個だけ存圚するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれたせん。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + for epoch in range(num_train_epochs): + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch+1 + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # 
latentに倉換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごずのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必芁なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." 
+ args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + metadata["ss_training_finished_at"] = str(time.time()) + print(f"saving checkpoint: {ckpt_file}") + unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) + + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この埌メモリを䜿うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデヌタを出力先モデルに保存しない") + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の圢匏デフォルトはsafetensors", + ) + + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの孊習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの孊習率") + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 孊習するネットワヌクの初期重み") + parser.add_argument("--network_module", type=str, default=None, help="network module to train / 孊習察象のネットワヌクのモゞュヌル") + parser.add_argument( + "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モゞュヌルの次元数ネットワヌクにより定矩は異なりたす" + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調敎のalpha倀、デフォルト1旧バヌゞョンず同じ動䜜をするにはnetwork_dimず同じ倀を指定", + ) + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワヌクぞの远加の匕数" + ) + 
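+    # Each --network_args entry is a "key=value" string, e.g. --network_args "conv_dim=1" "conv_alpha=1"
+    # for Conv2d 3x3 LoRA; the parsed values end up in the net_kwargs dict that is passed to the
+    # network module and recorded in the metadata as "ss_network_args" above.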
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連郚分のみ孊習する") + parser.add_argument( + "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連郚分のみ孊習する" + ) + parser.add_argument( + "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデヌタに蚘録する任意のコメント文字列" + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_network_README-ja.md b/train_network_README-ja.md new file mode 100644 index 0000000000000000000000000000000000000000..79d1709f40b0724db3d8030cda7487c208f5e6d2 --- /dev/null +++ b/train_network_README-ja.md @@ -0,0 +1,269 @@ +# LoRAの孊習に぀いお + +[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)arxiv、[LoRA](https://github.com/microsoft/LoRA)githubをStable Diffusionに適甚したものです。 + +[cloneofsimo氏のリポゞトリ](https://github.com/cloneofsimo/lora)を倧いに参考にさせおいただきたした。ありがずうございたす。 + +通垞のLoRAは Linear およぎカヌネルサむズ 1x1 の Conv2d にのみ適甚されたすが、カヌネルサむズ 3x3 のConv2dに適甚を拡倧するこずもできたす。 + +Conv2d 3x3ぞの拡倧は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリヌスし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝したす。 + +8GB VRAMでもぎりぎり動䜜するようです。 + +[孊習に぀いおの共通ドキュメント](./train_README-ja.md) もあわせおご芧ください。 + +## 孊習したモデルに関する泚意 + +cloneofsimo氏のリポゞトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)ずは、珟時点では互換性がありたせん。いく぀かの機胜拡匵を行っおいるためです埌述。 + +WebUI等で画像生成する堎合には、孊習したLoRAのモデルを孊習元のStable Diffusionのモデルにこのリポゞトリ内のスクリプトであらかじめマヌゞしおおくか、こちらの[WebUI甹extension](https://github.com/kohya-ss/sd-webui-additional-networks)を䜿っおください。 + +# 孊習の手順 + +あらかじめこのリポゞトリのREADMEを参照し、環境敎備を行っおください。 + +## デヌタの準備 + +[孊習デヌタの準備に぀いお](./train_README-ja.md) を参照しおください。 + + +## 孊習の実行 + +`train_network.py`を甚いたす。 + +`train_network.py`では `--network_module` オプションに、孊習察象のモゞュヌル名を指定したす。LoRAに察応するのはnetwork.loraずなりたすので、それを指定しおください。 + +なお孊習率は通垞のDreamBoothやfine tuningよりも高めの、1e-4皋床を指定するずよいようです。 + +以䞋はコマンドラむンの䟋です。 + +``` +accelerate launch --num_cpu_threads_per_process 1 train_network.py + --pretrained_model_name_or_path=<.ckptたたは.safetensordたたはDiffusers版モデルのディレクトリ> + --dataset_config=<デヌタ準備で䜜成した.tomlファむル> + --output_dir=<孊習したモデルの出力先フォルダ> + --output_name=<孊習したモデル出力時のファむル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=400 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing + --save_every_n_epochs=1 + --network_module=networks.lora +``` + +`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されたす。他のオプション、オプティマむザ等に぀いおは [孊習の共通ドキュメント](./train_README-ja.md) の「よく䜿われるオプション」も参照しおください。 + +その他、以䞋のオプションが指定できたす。 + +* `--network_dim` + * LoRAのRANKを指定したす``--networkdim=4``など。省略時は4になりたす。数が倚いほど衚珟力は増したすが、孊習に必芁なメモリ、時間は増えたす。たた闇雲に増やしおも良くないようです。 +* `--network_alpha` + * アンダヌフロヌを防ぎ安定しお孊習するための ``alpha`` 倀を指定したす。デフォルトは1です。``network_dim``ず同じ倀を指定するず以前のバヌゞョンず同じ動䜜になりたす。 +* `--persistent_data_loader_workers` + * Windows環境で指定するず゚ポック間の埅ち時間が倧幅に短瞮されたす。 +* `--max_data_loader_n_workers` + * デヌタ読み蟌みのプロセス数を指定したす。プロセス数が倚いずデヌタ読み蟌みが速くなりGPUを効率的に利甚できたすが、メむンメモリを消費したす。デフォルトは「`8` たたは `CPU同時実行スレッド数-1` の小さいほう」なので、メむンメモリに䜙裕がない堎合や、GPU䜿甚率が90%皋床以䞊なら、それらの数倀を芋ながら `2` たたは `1` 皋床たで䞋げおください。 +* `--network_weights` + * 孊習前に孊習枈みのLoRAの重みを読み蟌み、そこから远加で孊習したす。 +* 
`--network_train_unet_only` + * U-Netに関連するLoRAモゞュヌルのみ有効ずしたす。fine tuning的な孊習で指定するずよいかもしれたせん。 +* `--network_train_text_encoder_only` + * Text Encoderに関連するLoRAモゞュヌルのみ有効ずしたす。Textual Inversion的な効果が期埅できるかもしれたせん。 +* `--unet_lr` + * U-Netに関連するLoRAモゞュヌルに、通垞の孊習率--learning_rateオプションで指定ずは異なる孊習率を䜿う時に指定したす。 +* `--text_encoder_lr` + * Text Encoderに関連するLoRAモゞュヌルに、通垞の孊習率--learning_rateオプションで指定ずは異なる孊習率を䜿う時に指定したす。Text Encoderのほうを若干䜎めの孊習率5e-5などにしたほうが良い、ずいう話もあるようです。 +* `--network_args` + * 耇数の匕数を指定できたす。埌述したす。 + +`--network_train_unet_only` ず `--network_train_text_encoder_only` の䞡方ずも未指定時デフォルトはText EncoderずU-Netの䞡方のLoRAモゞュヌルを有効にしたす。 + +## LoRA を Conv2d に拡倧しお適甚する + +通垞のLoRAは Linear およぎカヌネルサむズ 1x1 の Conv2d にのみ適甚されたすが、カヌネルサむズ 3x3 のConv2dに適甚を拡倧するこずもできたす。 + +`--network_args` に以䞋のように指定しおください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定しおください。 + +``` +--network_args "conv_dim=1" "conv_alpha=1" +``` + +以䞋のように alpha 省略時は1になりたす。 + +``` +--network_args "conv_dim=1" +``` + +## マヌゞスクリプトに぀いお + +merge_lora.pyでStable DiffusionのモデルにLoRAの孊習結果をマヌゞしたり、耇数のLoRAモデルをマヌゞしたりできたす。 + +### Stable DiffusionのモデルにLoRAのモデルをマヌゞする + +マヌゞ埌のモデルは通垞のStable Diffusionのckptず同様に扱えたす。たずえば以䞋のようなコマンドラむンになりたす。 + +``` +python networks\merge_lora.py --sd_model ..\model\model.ckpt + --save_to ..\lora_train1\model-char1-merged.safetensors + --models ..\lora_train1\last.safetensors --ratios 0.8 +``` + +Stable Diffusion v2.xのモデルで孊習し、それにマヌゞする堎合は、--v2オプションを指定しおください。 + +--sd_modelオプションにマヌゞの元ずなるStable Diffusionのモデルファむルを指定したす.ckptたたは.safetensorsのみ察応で、Diffusersは今のずころ察応しおいたせん。 + +--save_toオプションにマヌゞ埌のモデルの保存先を指定したす.ckptたたは.safetensors、拡匵子で自動刀定。 + +--modelsに孊習したLoRAのモデルファむルを指定したす。耇数指定も可胜で、その時は順にマヌゞしたす。 + +--ratiosにそれぞれのモデルの適甚率どのくらい重みを元モデルに反映するかを0~1.0の数倀で指定したす。䟋えば過孊習に近いような堎合は、適甚率を䞋げるずマシになるかもしれたせん。モデルの数ず同じだけ指定しおください。 + +耇数指定時は以䞋のようになりたす。 + +``` +python networks\merge_lora.py --sd_model ..\model\model.ckpt + --save_to ..\lora_train1\model-char1-merged.safetensors + --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5 +``` + +### 耇数のLoRAのモデルをマヌゞする + +耇数のLoRAモデルをひず぀ず぀SDモデルに適甚する堎合ず、耇数のLoRAモデルをマヌゞしおからSDモデルにマヌゞする堎合ずは、蚈算順序の関連で埮劙に異なる結果になりたす。 + +たずえば以䞋のようなコマンドラむンになりたす。 + +``` +python networks\merge_lora.py + --save_to ..\lora_train1\model-char1-style1-merged.safetensors + --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4 +``` + +--sd_modelオプションは指定䞍芁です。 + +--save_toオプションにマヌゞ埌のLoRAモデルの保存先を指定したす.ckptたたは.safetensors、拡匵子で自動刀定。 + +--modelsに孊習したLoRAのモデルファむルを指定したす。䞉぀以䞊も指定可胜です。 + +--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数倀で指定したす。二぀のモデルを䞀察䞀でマヌゞす堎合は、「0.5 0.5」になりたす。「1.0 1.0」では合蚈の重みが倧きくなりすぎお、恐らく結果はあたり望たしくないものになるず思われたす。 + +v1で孊習したLoRAずv2で孊習したLoRA、rank次元数や``alpha``の異なるLoRAはマヌゞできたせん。U-NetだけのLoRAずU-Net+Text EncoderのLoRAはマヌゞできるはずですが、結果は未知数です。 + + +### その他のオプション + +* precision + * マヌゞ蚈算時の粟床をfloat、fp16、bf16から指定できたす。省略時は粟床を確保するためfloatになりたす。メモリ䜿甚量を枛らしたい堎合はfp16/bf16を指定しおください。 +* save_precision + * モデル保存時の粟床をfloat、fp16、bf16から指定できたす。省略時はprecisionず同じ粟床になりたす。 + + +## 耇数のrankが異なるLoRAのモデルをマヌゞする + +耇数のLoRAをひず぀のLoRAで近䌌したす完党な再珟はできたせん。`svd_merge_lora.py`を甚いたす。たずえば以䞋のようなコマンドラむンになりたす。 + +``` +python networks\svd_merge_lora.py + --save_to ..\lora_train1\model-char1-style1-merged.safetensors + --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors + --ratios 0.6 0.4 --new_rank 32 --device cuda +``` + +`merge_lora.py` ず䞻なオプションは同䞀です。以䞋のオプションが远加されおいたす。 + +- `--new_rank` + - 䜜成するLoRAのrankを指定したす。 +- `--new_conv_rank` + - 䜜成する Conv2d 3x3 LoRA の rank を指定したす。省略時は `new_rank` ず同じになりたす。 +- 
`--device` + - `--device cuda`ずしおcudaを指定するず蚈算をGPU䞊で行いたす。凊理が速くなりたす。 + +## 圓リポゞトリ内の画像生成スクリプトで生成する + +gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを远加しおください。意味は孊習時ず同様です。 + +--network_mulオプションで0~1.0の数倀を指定するず、LoRAの適甚率を倉えられたす。 + +## 二぀のモデルの差分からLoRAモデルを䜜成する + +[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数匏はそのたた䜿わせおいただきたしたよく理解しおいたせんが近䌌には特異倀分解を甚いるようです。 + +二぀のモデルたずえばfine tuningの元モデルずfine tuning埌のモデルの差分を、LoRAで近䌌したす。 + +### スクリプトの実行方法 + +以䞋のように指定しおください。 +``` +python networks\extract_lora_from_models.py --model_org base-model.ckpt + --model_tuned fine-tuned-model.ckpt + --save_to lora-weights.safetensors --dim 4 +``` + +--model_orgオプションに元のStable Diffusionモデルを指定したす。䜜成したLoRAモデルを適甚する堎合は、このモデルを指定しお適甚するこずになりたす。.ckptたたは.safetensorsが指定できたす。 + +--model_tunedオプションに差分を抜出する察象のStable Diffusionモデルを指定したす。たずえばfine tuningやDreamBooth埌のモデルを指定したす。.ckptたたは.safetensorsが指定できたす。 + +--save_toにLoRAモデルの保存先を指定したす。--dimにLoRAの次元数を指定したす。 + +生成されたLoRAモデルは、孊習したLoRAモデルず同様に䜿甚できたす。 + +Text Encoderが二぀のモデルで同じ堎合にはLoRAはU-NetのみのLoRAずなりたす。 + +### その他のオプション + +- `--v2` + - v2.xのStable Diffusionモデルを䜿う堎合に指定しおください。 +- `--device` + - ``--device cuda``ずしおcudaを指定するず蚈算をGPU䞊で行いたす。凊理が速くなりたすCPUでもそこたで遅くないため、せいぜい倍数倍皋床のようです。 +- `--save_precision` + - LoRAの保存圢匏を"float", "fp16", "bf16"から指定したす。省略時はfloatになりたす。 +- `--conv_dim` + - 指定するずLoRAの適甚範囲を Conv2d 3x3 ぞ拡倧したす。Conv2d 3x3 の rank を指定したす。 + +## 画像リサむズスクリプト + +のちほどドキュメントを敎理したすがずりあえずここに説明を曞いおおきたす。 + +Aspect Ratio Bucketingの機胜拡匵で、小さな画像に぀いおは拡倧しないでそのたた教垫デヌタずするこずが可胜になりたした。元の教垫画像を瞮小した画像を、教垫デヌタに加えるず粟床が向䞊したずいう報告ずずもに前凊理甚のスクリプトをいただきたしたので敎備しお远加したした。bmaltais氏に感謝したす。 + +### スクリプトの実行方法 + +以䞋のように指定しおください。元の画像そのたた、およびリサむズ埌の画像が倉換先フォルダに保存されたす。リサむズ埌の画像には、ファむル名に ``+512x512`` のようにリサむズ先の解像床が付け加えられたす画像サむズずは異なりたす。リサむズ先の解像床より小さい画像は拡倧されるこずはありたせん。 + +``` +python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png + --copy_associated_files 元画像フォルダ 倉換先フォルダ +``` + +元画像フォルダ内の画像ファむルが、指定した解像床耇数指定可ず同じ面積になるようにリサむズされ、倉換先フォルダに保存されたす。画像以倖のファむルはそのたたコピヌされたす。 + +``--max_resolution`` オプションにリサむズ先のサむズを䟋のように指定しおください。面積がそのサむズになるようにリサむズしたす。耇数指定するず、それぞれの解像床でリサむズされたす。``512x512,384x384,256x256``なら、倉換先フォルダの画像は、元サむズずリサむズ埌サむズ×3の蚈4枚になりたす。 + +``--save_as_png`` オプションを指定するずpng圢匏で保存したす。省略するずjpeg圢匏quality=100で保存されたす。 + +``--copy_associated_files`` オプションを指定するず、拡匵子を陀き画像ず同じファむル名たずえばキャプションなどのファむルが、リサむズ埌の画像のファむル名ず同じ名前でコピヌされたす。 + + +### その他のオプション + +- divisible_by + - リサむズ埌の画像のサむズ瞊、暪のそれぞれがこの倀で割り切れるように、画像䞭心を切り出したす。 +- interpolation + - 瞮小時の補完方法を指定したす。``area, cubic, lanczos4``から遞択可胜で、デフォルトは``area``です。 + + +## 远加情報 + +### cloneofsimo氏のリポゞトリずの違い + +2022/12/25時点では、圓リポゞトリはLoRAの適甚個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡倧し、衚珟力が増しおいたす。ただその代わりメモリ䜿甚量は増え、8GBぎりぎりになりたした。 + +たたモゞュヌル入れ替え機構は党く異なりたす。 + +### 将来拡匵に぀いお + +LoRAだけでなく他の拡匵にも察応可胜ですので、それらも远加予定です。 diff --git a/train_network_README.md b/train_network_README.md new file mode 100644 index 0000000000000000000000000000000000000000..b0363a68b66f4fb8ddd7efc5ccaecc730e1666e9 --- /dev/null +++ b/train_network_README.md @@ -0,0 +1,189 @@ +## About learning LoRA + +[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) (arxiv), [LoRA](https://github.com/microsoft/LoRA) (github) to Stable Applied to Diffusion. + +[cloneofsimo's repository](https://github.com/cloneofsimo/lora) was a great reference. Thank you very much. + +8GB VRAM seems to work just fine. 
+
+## A Note about Trained Models
+
+This repository is currently incompatible with cloneofsimo's repository and with d8ahazard's [DreamBooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension), because it adds several enhancements (see below).
+
+To generate images with the WebUI or similar tools, either merge the trained LoRA model into the original Stable Diffusion model in advance with the script in this repository, or use the [extension for the WebUI](https://github.com/kohya-ss/sd-webui-additional-networks).
+
+## Learning method
+
+Use train_network.py.
+
+You can train with either the DreamBooth method (using identifiers such as sks, classes, and optionally regularization images) or the fine-tuning method that uses captions.
+
+Both methods work in much the same way as the existing scripts. The differences are discussed later.
+
+### Using the DreamBooth method
+
+Please refer to the [DreamBooth guide](./train_db_README-en.md) and prepare the data.
+
+Specify train_network.py instead of train_db.py when training.
+
+Almost all options are available (except those related to saving the Stable Diffusion model), but stop_text_encoder_training is not supported.
+
+### Using captions
+
+Please refer to the [fine-tuning guide](./fine_tune_README_en.md) and perform each step.
+
+Specify train_network.py instead of fine_tune.py when training. Almost all options (except those related to model saving) can be used as is.
+
+It also works even if you skip the "pre-acquire latents" step. Since the latents are then obtained from the VAE during training (or caching), training is slower, but color_aug can be used instead.
+
+### Options for training LoRA
+
+In train_network.py, specify the name of the module to be trained with the --network_module option. networks.lora corresponds to LoRA, so specify that.
+
+The learning rate should be set to about 1e-4, which is higher than for normal DreamBooth or fine tuning.
+
+Below is an example command line (DreamBooth method).
+
+```
+accelerate launch --num_cpu_threads_per_process 12 train_network.py
+    --pretrained_model_name_or_path=..\models\model.ckpt
+    --train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
+    --reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
+    --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
+    --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
+    --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
+    --network_module=networks.lora
+```
+
+The LoRA model will be saved in the directory specified by the --output_dir option.
+
+In addition, the following options can be specified.
+
+* --network_dim
+  * Specify the rank (number of dimensions) of LoRA (such as ``--network_dim=4``). Default is 4. A larger value increases expressive power, but also the memory and time required for training. Increasing it blindly does not seem to help.
+* --network_weights
+  * Load pretrained LoRA weights before training and continue training from them.
+* --network_train_unet_only
+  * Enable only the LoRA modules related to U-Net. This may be useful for fine-tuning-style training.
+* --network_train_text_encoder_only
+  * Enable only the LoRA modules related to the Text Encoder. A Textual Inversion-like effect may be expected.
+* --unet_lr
+  * Specify this to use a learning rate for the U-Net LoRA modules that differs from the normal learning rate (set with the --learning_rate option).
+* --text_encoder_lr
+  * Specify this to use a learning rate for the Text Encoder LoRA modules that differs from the normal learning rate (set with the --learning_rate option). Some reports suggest that a slightly lower learning rate for the Text Encoder (such as 5e-5) works better.
+
+When neither --network_train_unet_only nor --network_train_text_encoder_only is specified (the default), both the Text Encoder and U-Net LoRA modules are enabled.
+
+## About the merge script
+
+merge_lora.py lets you merge LoRA training results into a Stable Diffusion model, or merge multiple LoRA models together.
+
+### Merge a LoRA model into a Stable Diffusion model
+
+The merged model can be handled in the same way as a normal Stable Diffusion checkpoint. For example, the command line looks like this:
+
+```
+python networks\merge_lora.py --sd_model ..\model\model.ckpt
+    --save_to ..\lora_train1\model-char1-merged.safetensors
+    --models ..\lora_train1\last.safetensors --ratios 0.8
+```
+
+If the LoRA was trained with a Stable Diffusion v2.x model and you are merging into one, specify the --v2 option.
+
+Specify the Stable Diffusion model file to merge into with the --sd_model option (only .ckpt or .safetensors are supported; Diffusers format is not currently supported).
+
+Specify where to save the merged model with the --save_to option (.ckpt or .safetensors, determined automatically by the extension).
+
+Specify the trained LoRA model files with --models. More than one can be specified, in which case they are merged in order.
+
+With --ratios, specify the application rate of each model (how strongly its weights are reflected in the original model) as a value from 0 to 1.0. For example, if a model is close to overfitting, lowering its application rate may help. Specify as many ratios as there are models.
+
+When specifying multiple models, the command looks like this:
+
+```
+python networks\merge_lora.py --sd_model ..\model\model.ckpt
+    --save_to ..\lora_train1\model-char1-merged.safetensors
+    --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
+```
+
+### Merge multiple LoRA models
+
+Applying multiple LoRA models to the SD model one by one, and merging the LoRA models first and then merging the result into the SD model, give slightly different results because of the order of calculation.
+
+For example, the command line looks like this:
+
+```
+python networks\merge_lora.py
+    --save_to ..\lora_train1\model-char1-style1-merged.safetensors
+    --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
+```
+
+The --sd_model option does not need to be specified.
+
+Specify where to save the merged LoRA model with the --save_to option (.ckpt or .safetensors, determined automatically by the extension).
+
+Specify the trained LoRA model files with --models. Three or more can be specified.
+
+With --ratios, specify the ratio of each model (how strongly its weights are reflected in the result) as a value from 0 to 1.0. To merge two models one-to-one, use "0.5 0.5". With "1.0 1.0" the combined weights become too large, and the result will probably be undesirable.
+
+A LoRA trained on v1 and a LoRA trained on v2, or LoRAs with different ranks (dimensions) or ``alpha`` values, cannot be merged. A U-Net-only LoRA and a U-Net+Text Encoder LoRA should be mergeable, but the result is untested.
+
+
+### Other Options
+
+* precision
+  * The precision used for the merge calculation can be chosen from float, fp16, and bf16. If omitted, float is used to preserve accuracy.
Specify fp16/bf16 if you want to reduce memory usage.
+* save_precision
+  * The precision used when saving the model can be chosen from float, fp16, and bf16. If omitted, it is the same as precision.
+
+## Generate with the image generation script in this repository
+
+Add the --network_module, --network_weights, and (optionally) --network_dim options to gen_img_diffusers.py. Their meanings are the same as during training.
+
+You can change the LoRA application rate by specifying a value between 0 and 1.0 with the --network_mul option.
+
+## Create a LoRA model from the difference between two models
+
+This was implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). The formula is used as is (I don't fully understand it, but singular value decomposition appears to be used for the approximation).
+
+The difference between two models (for example, the original model before fine tuning and the model after fine tuning) is approximated with a LoRA.
+
+### How to run the script
+
+Specify as follows.
+```
+python networks\extract_lora_from_models.py --model_org base-model.ckpt
+    --model_tuned fine-tuned-model.ckpt
+    --save_to lora-weights.safetensors --dim 4
+```
+
+Specify the original Stable Diffusion model with the --model_org option. The created LoRA model should later be applied to this model. .ckpt or .safetensors can be specified.
+
+Specify the Stable Diffusion model from which to extract the difference with the --model_tuned option, for example a model after fine tuning or DreamBooth. .ckpt or .safetensors can be specified.
+
+Specify where to save the LoRA model with --save_to. Specify the number of dimensions of the LoRA with --dim.
+
+The generated LoRA model can be used in the same way as a trained LoRA model.
+
+If the Text Encoder is identical in the two models, the resulting LoRA will be a U-Net-only LoRA.
+
+### Other Options
+
+--v2
+  - Specify this when using a v2.x Stable Diffusion model.
+--device
+  - Specifying ``--device cuda`` performs the calculation on the GPU. Processing is faster, though the CPU is not that slow either, so the speedup is at most around two to several times.
+--save_precision
+  - Specify the LoRA save precision from "float", "fp16", or "bf16". Default is float.
+
+## Additional Information
+
+### Differences from cloneofsimo's repository
+
+As of 2022/12/25, this repository has expanded the LoRA application points to the Text Encoder's MLP, the U-Net's FFN, and the Transformer's in/out projections, increasing expressiveness. In exchange, memory usage has increased and is now close to the 8GB limit.
+
+Also, the module swapping mechanism is completely different.
+
+### About Future Expansion
+
+It is possible to support not only LoRA but also other extensions, so we plan to add them as well.
\ No newline at end of file diff --git a/train_textual_inversion.py b/train_textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..f279370a9635891b6b5c4dcc82ecaa1bd6431c50 --- /dev/null +++ b/train_textual_inversion.py @@ -0,0 +1,590 @@ +import importlib +import argparse +import gc +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +def train(args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template + + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに察応した型を甚意しおおき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み蟌む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトヌクン長がnum_vectors_per_tokenず合わないため、繰り返したたは切り捚おが発生したす: length {len(init_token_ids)}" + ) + else: + init_token_ids 
= None + + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存圚したす。別の単語を䜿っおください: {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みずnum_vectors_per_tokenの倀が異なりたす: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # デヌタセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 蚭定ファむルが利甚されるため以䞋のオプションは無芖されたす: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn ずいう文字列に曞き換える超乱暎な実装 + if use_template: + print("use template for training captions. 
is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありたせん。匕数指定を確認しおください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするずきはcolor_augずrandom_cropは䜿えたせん" + + # モデルに xformers ずか memory efficient attention を組み蟌む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 孊習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # 孊習に必芁なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数0はメむンプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最倧で指定された数たで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 孊習ステップ数を蚈算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定゚ポックたでのステップ数: {args.max_train_steps}") + + # デヌタセット偎にも孊習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを甚意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやっおくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実隓的機胜募配も含めたfp16孊習を行う PyTorchにパッチを圓おおfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を蚈算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 孊習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / å­Šç¿’é–‹å§‹") + print(f" num train images * repeats / 孊習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサむズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサむズ䞊列孊習、募配合蚈含む: {total_batch_size}") + print(f" gradient ccumulation steps / 募配を合蚈するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 孊習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch+1 + + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with 
accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに倉換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss_weights = batch["loss_weights"] # 各sampleごずのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必芁なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, 
step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + + del accelerator # この埌メモリを䜿うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + print("model saved.") + + +def save_weights(file, updated_embs, save_dtype): + state_dict = {"emb_params": updated_embs} + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI + + +def load_weights(file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファむルがdict圢匏ではありたせん: {file}") + + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? 
+ data = getattr(data, "_parameters") + + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファむルのデヌタがTensorではありたせん: {file}") + + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) + + return emb + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の圢匏デフォルトはpt", + ) + + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 孊習するネットワヌクの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トヌクンに割り圓おるembeddingsの芁玠数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 孊習時に䜿甚されるトヌクン文字列、tokenizerに存圚しない文字であるこず", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に䜿甚する単語、耇数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは䜿わずデフォルトの物䜓甚テンプレヌトで孊習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは䜿わずデフォルトのスタむル甚テンプレヌトで孊習する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py new file mode 100644 index 0000000000000000000000000000000000000000..74e9bc2e320ebe13385323aa2994dcc2166a317a --- /dev/null +++ b/train_textual_inversion_XTI.py @@ -0,0 +1,644 @@ +import importlib +import argparse +import gc +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of 
a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +def train(args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template + + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: + print( + "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsずsample_every_n_epochsは珟圚このスクリプトではサポヌトされおいたせん" + ) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに察応した型を甚意しおおき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み蟌む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトヌクン長がnum_vectors_per_tokenず合わないため、繰り返したたは切り捚おが発生したす: length {len(init_token_ids)}" + ) + else: + init_token_ids = None + + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存圚したす。別の単語を䜿っおください: {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + token_strings_XTI = [] + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + print(f"tokens are added (XTI): {token_ids_XTI}") + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids_XTI): + token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みずnum_vectors_per_tokenの倀が異なりたす: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids_XTI, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # デヌタセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 蚭定ファむルが利甚されるため以䞋のオプションは無芖されたす: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # make captions: tokenstring 
tokenstring1 tokenstring2 ...tokenstringn ずいう文字列に曞き換える超乱暎な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありたせん。匕数指定を確認しおください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするずきはcolor_augずrandom_cropは䜿えたせん" + + # モデルに xformers ずか memory efficient attention を組み蟌む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + + # 孊習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # 孊習に必芁なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数0はメむンプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最倧で指定された数たで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 孊習ステップ数を蚈算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定゚ポックたでのステップ数: {args.max_train_steps}") + + # デヌタセット偎にも孊習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを甚意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやっおくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実隓的機胜募配も含めたfp16孊習を行う PyTorchにパッチを圓おおfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を蚈算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 孊習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / å­Šç¿’é–‹å§‹") + print(f" num train images * repeats / 孊習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサむズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサむズ䞊列孊習、募配合蚈含む: {total_batch_size}") + print(f" gradient ccumulation steps / 募配を合蚈するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 孊習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step 
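+ # The block below runs inside accelerator.accumulate(), so gradients are accumulated over
+ # args.gradient_accumulation_steps batches and accelerator.sync_gradients is only True on
+ # the step where the optimizer update is actually applied.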
+ with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = torch.stack( + [ + train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) + for s in torch.split(input_ids, 1, dim=1) + ] + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: +
break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "."
+ args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + print("model saved.") + + +def save_weights(file, updated_embs, save_dtype): + updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1]) + updated_embs = updated_embs.chunk(16) + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + state_dict = {} + for i, layer_name in enumerate(XTI_layers): + state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype) + + # if save_dtype is not None: + # for key in list(state_dict.keys()): + # v = state_dict[key] + # v = v.detach().clone().to("cpu").to(save_dtype) + # state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI + + +def load_weights(file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + raise ValueError(f"NOT XTI: {file}") + + if len(data.values()) != 16: + raise ValueError(f"NOT XTI: {file}") + + emb = torch.concat([x for x in data.values()]) + + return emb + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式（デフォルトはpt）", + ) + + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_ti_README-ja.md b/train_ti_README-ja.md new file mode 100644 index 0000000000000000000000000000000000000000..908736961202a4bbf14f92be9efd27193541c186 --- /dev/null +++ b/train_ti_README-ja.md @@ -0,0 +1,105 @@ +[Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。 + +[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 + +実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 +
+学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。 + +# 学習の手順 + +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 + +## データの準備 + +[学習データの準備について](./train_README-ja.md) を参照してください。 + +## 学習の実行 + +``train_textual_inversion.py`` を用います。以下はコマンドラインの例です（DreamBooth手法）。 + +``` +accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py + --dataset_config=<データ準備で作成した.tomlファイル> + --output_dir=<学習したモデルの出力先フォルダ> + --output_name=<学習したモデル出力時のファイル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=1600 + --learning_rate=1e-6 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing + --token_string=mychar4 --init_word=cute --num_vectors_per_token=4 +``` + +``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください（token_stringがmychar4なら、``mychar4 1girl`` など）__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。 + +プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。 + +``` +input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407]]) +``` + +tokenizerがすでに持っている単語（一般的な単語）は使用できません。 + +``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。 + +``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は（一般的なプロンプトの77トークン制限のうち）8トークンを消費します。 + +以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。 + +`num_cpu_threads_per_process` には通常は1を指定するとよいようです。 + +`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル（.ckptまたは.safetensors）、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID（"stabilityai/stable-diffusion-2"など）が指定できます。 + +`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 + +`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 + +学習させるステップ数 `max_train_steps` はここでは1600としています。学習率 `learning_rate` はここでは1e-6を指定しています。 + +省メモリ化のため `mixed_precision="fp16"` を指定します（RTX30 シリーズ以降では `bf16` も指定できます）。環境整備時にaccelerateに行った設定と合わせてください。また `gradient_checkpointing` を指定します。 + +オプティマイザ（モデルを学習データにあうように最適化＝学習させるクラス）にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 + +`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合（環境にもよりますが `mixed_precision="no"` の場合など）、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します（速度は遅くなります）。 + +ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください（高速化と精度向上の可能性があります）。 + +### よく使われるオプションについて + +以下の場合にはオプションに関するドキュメントを参照してください。 + +- Stable Diffusion 2.xまたはそこからの派生モデルを学習する +- clip skipが2以上であることを前提としたモデルを学習する +- 75トークンを超えたキャプションで学習する + +### Textual Inversionでのバッチサイズについて + +モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。 + +# Textual Inversionのその他の主なオプション + +すべてのオプションについては別文書を参照してください。 + +* `--weights` + * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。 +* 
`--use_object_template` + * キャプションではなく既定の物体用テンプレート文字列（``a photo of a {}`` など）で学習します。公式実装と同じになります。キャプションは無視されます。 +* `--use_style_template` + * キャプションではなく既定のスタイル用テンプレート文字列で学習します（``a painting in the style of {}`` など）。公式実装と同じになります。キャプションは無視されます。 + +## 当リポジトリ内の画像生成スクリプトで生成する + +gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください（複数可）。プロンプトでembeddingsファイルのファイル名（拡張子を除く）を使うと、そのembeddingsが適用されます。 + diff --git a/train_ti_README.md b/train_ti_README.md new file mode 100644 index 0000000000000000000000000000000000000000..ba03d555870fa1a4093a1d105cf01c3febfec6b9 --- /dev/null +++ b/train_ti_README.md @@ -0,0 +1,62 @@ +## About learning Textual Inversion + +This is an explanation of [Textual Inversion](https://textual-inversion.github.io/) training. I heavily referenced https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion for the implementation. + +The trained model can be used as is on the Web UI. + +In addition, it is probably compatible with SD2.x, but it has not been tested at this time. + +## Learning method + +Use ``train_textual_inversion.py``. + +Data preparation is exactly the same as for ``train_network.py``, so please refer to [that document](./train_network_README-en.md). + +## Options + +Below is an example command line (DreamBooth technique). + +``` +accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py + --pretrained_model_name_or_path=..\models\model.ckpt + --train_data_dir=..\data\db\char1 --output_dir=..\ti_train1 + --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4 + --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16 + --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug + --token_string=mychar4 --init_word=cute --num_vectors_per_token=4 +``` + +``--token_string`` specifies the token string for learning. __The learning prompt should contain this string (e.g. ``mychar4 1girl`` if token_string is mychar4)__. This part of the prompt is replaced with the new Textual Inversion tokens and learned. + +``--debug_dataset`` will display the token ids after substitution, so you can check whether tokens from ``49408`` onward are present, as shown below. + +``` +input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, + 49407, 49407, 49407, 49407, 49407, 49407, 49407]]) +``` + +Words that the tokenizer already has (common words) cannot be used. + +``--init_word`` specifies the string of the source token used to initialize the embeddings. It seems to be a good idea to choose something that has a similar concept to what you want to learn. You cannot specify a string that becomes two or more tokens. + +``--num_vectors_per_token`` specifies how many tokens to use for this training. The higher the number, the more expressive it is, but it consumes more tokens. For example, if num_vectors_per_token=8, then the specified token string will consume 8 tokens (out of the 77 token limit for a typical prompt). + + +In addition, the following options can be specified.
+ +* --weights + * Load previously trained embeddings before training and continue learning from them. +* --use_object_template + * Learn with default object template strings (such as ``a photo of a {}``) instead of captions. It will be the same as the official implementation. Captions are ignored. +* --use_style_template + * Learn with default style template strings instead of captions (such as ``a painting in the style of {}``). It will be the same as the official implementation. Captions are ignored. + +## Generate with the image generation script in this repository + +In gen_img_diffusers.py, specify the learned embeddings file with the ``--textual_inversion_embeddings`` option (multiple files are allowed). Using the filename (without the extension) of the embeddings file in the prompt will apply the embeddings. \ No newline at end of file diff --git a/v2_inference/v2-inference-v.yaml b/v2_inference/v2-inference-v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ec8dfbfefe94ae8522c93017668fea78d580acf --- /dev/null +++ b/v2_inference/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/v2_inference/v2-inference.yaml b/v2_inference/v2-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..152c4f3c2b36c3b246a9cb10eb8166134b0d2e1c --- /dev/null +++ b/v2_inference/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 +
attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/v2_inference/v2-inpainting-inference.yaml b/v2_inference/v2-inpainting-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32a9471d71b828c51bcbbabfe34c5f6c8282c803 --- /dev/null +++ b/v2_inference/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough 
shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/v2_inference/v2-midas-inference.yaml b/v2_inference/v2-midas-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f20c30f618b81091e31c2c4cf15325fa38638af4 --- /dev/null +++ b/v2_inference/v2-midas-inference.yaml @@ -0,0 +1,74 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + diff --git a/v2_inference/x4-upscaling.yaml b/v2_inference/x4-upscaling.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2db0964af699f86d1891c761710a2d53f59b842c --- /dev/null +++ b/v2_inference/x4-upscaling.yaml @@ -0,0 +1,76 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + 
num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" +