	Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +1 -0
- .github/workflows/typos.yaml +21 -0
- .gitignore +11 -0
- .gradio/certificate.pem +31 -0
- LICENSE.md +201 -0
- README.md +17 -8
- XTI_hijack.py +209 -0
- _typos.toml +15 -0
- cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 +3 -0
- config_README-ja.md +279 -0
- config_files/accelerate/default_config.yaml +22 -0
- dreambooth_gui.py +944 -0
- fine_tune.py +430 -0
- fine_tune_README.md +465 -0
- fine_tune_README_ja.md +140 -0
- finetune/blip/blip.py +240 -0
- finetune/blip/med.py +955 -0
- finetune/blip/med_config.json +22 -0
- finetune/blip/vit.py +305 -0
- finetune/clean_captions_and_tags.py +190 -0
- finetune/hypernetwork_nai.py +96 -0
- finetune/make_captions.py +168 -0
- finetune/make_captions_by_git.py +151 -0
- finetune/merge_captions_to_metadata.py +76 -0
- finetune/merge_dd_tags_to_metadata.py +71 -0
- finetune/prepare_buckets_latents.py +267 -0
- finetune/tag_images_by_wd14_tagger.py +206 -0
- finetune_gui.py +888 -0
- gen_img_diffusers.py +0 -0
- gui.sh +9 -0
- kohya_gui.py +110 -0
- kohya_ss_colab.ipynb +448 -0
- library/__init__.py +0 -0
- library/basic_caption_gui.py +140 -0
- library/blip_caption_gui.py +149 -0
- library/common_gui.py +978 -0
- library/config_util.py +536 -0
- library/convert_model_gui.py +247 -0
- library/custom_train_functions.py +18 -0
- library/dataset_balancing_gui.py +146 -0
- library/dreambooth_folder_creation_gui.py +210 -0
- library/extract_lora_gui.py +178 -0
- library/extract_lycoris_locon_gui.py +309 -0
- library/git_caption_gui.py +136 -0
- library/lpw_stable_diffusion.py +1179 -0
- library/merge_lora_gui.py +156 -0
- library/model_util.py +1165 -0
- library/resize_lora_gui.py +173 -0
- library/sampler_gui.py +102 -0
- library/svd_merge_lora_gui.py +190 -0
    	
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 filter=lfs diff=lfs merge=lfs -text
    	
.github/workflows/typos.yaml ADDED
@@ -0,0 +1,21 @@
+---
+# yamllint disable rule:line-length
+name: Typos
+
+on:  # yamllint disable-line rule:truthy
+  push:
+  pull_request:
+    types:
+      - opened
+      - synchronize
+      - reopened
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: typos-action
+        uses: crate-ci/typos@v1.13.10
    	
.gitignore ADDED
@@ -0,0 +1,11 @@
+venv
+__pycache__
+cudnn_windows
+.vscode
+*.egg-info
+build
+wd14_tagger_model
+.DS_Store
+locon
+gui-user.bat
+gui-user.ps1
    	
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
    	
LICENSE.md ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2022] [kohya-ss]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
README.md CHANGED
@@ -1,12 +1,21 @@
 ---
-title: 
-
-colorFrom: indigo
-colorTo: gray
+title: kohya_ss_colab
+app_file: dreambooth_gui.py
 sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned: false
+sdk_version: 5.47.2
 ---
+[](https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb)
 
-
+# Kohya SS WebUI Colab Setup
+
+This Colab workbook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS is a Python library that provides Stable Diffusion-based models for image, text, and audio generation tasks. This Colab workbook provides a convenient way for users to run Kohya SS without needing to install anything on their local machine.
+
+This workbook was inspired by the work of [Spaceginner](https://github.com/Spaceginner)'s original Colab workbook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab workbook was coded by [panguin6010](https://github.com/panguin6010)
+
+
+## Tutorials
+
+Before running this code, make sure you are familiar with using Colab workbooks and have a basic understanding of Kohya SS and its usage. You can find tutorials for these online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project.
+
+## Link
+```https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb```
    	
XTI_hijack.py ADDED
@@ -0,0 +1,209 @@
+import torch
+from typing import Union, List, Optional, Dict, Any, Tuple
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+
+def unet_forward_XTI(self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+        emb = self.time_embedding(t_emb)
+
+        if self.config.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            emb = emb + class_emb
+
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        # 3. down
+        down_block_res_samples = (sample,)
+        down_i = 0
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
+                )
+                down_i += 2
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
+
+        # 5. up
+        up_i = 7
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
+                    upsample_size=upsample_size,
+                )
+                up_i += 3
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+        # 6. post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet2DConditionOutput(sample=sample)
+
+def downblock_forward_XTI(
+    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+):
+    output_states = ()
+    i = 0
+
+    for resnet, attn in zip(self.resnets, self.attentions):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
+            )[0]
+        else:
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
+
+        output_states += (hidden_states,)
+        i += 1
+
+    if self.downsamplers is not None:
+        for downsampler in self.downsamplers:
+            hidden_states = downsampler(hidden_states)
+
+        output_states += (hidden_states,)
+
+    return hidden_states, output_states
+
+def upblock_forward_XTI(
+    self,
+    hidden_states,
+    res_hidden_states_tuple,
+    temb=None,
+    encoder_hidden_states=None,
+    upsample_size=None,
+):
+    i = 0
+    for resnet, attn in zip(self.resnets, self.attentions):
+        # pop res hidden states
+        res_hidden_states = res_hidden_states_tuple[-1]
+        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
+            )[0]
+        else:
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
+
+        i += 1
+
+    if self.upsamplers is not None:
+        for upsampler in self.upsamplers:
+            hidden_states = upsampler(hidden_states, upsample_size)
+
+    return hidden_states
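The three functions in `XTI_hijack.py` are replacement forward passes for a diffusers UNet: they let `encoder_hidden_states` be an indexable sequence of per-block text embeddings (slices `[down_i:down_i+2]`, index `6` for the mid block, `[up_i:up_i+3]`) rather than a single tensor, as needed for XTI-style per-layer textual inversion. Below is a minimal sketch, not part of this commit, of how they might be monkey-patched onto a UNet. The diffusers class names and import paths are assumptions that differ across diffusers versions, the model id is only an example, and note that `unet_forward_XTI` references a module-level `logger` that the file itself does not define (it is only hit when the latent size is not a multiple of the upsampling factor).

```python
# Sketch only: bind the XTI hooks onto a diffusers UNet via monkey-patching.
# Assumes these diffusers class/import names; exact paths vary by version.
import types

from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

# Example model id; any Stable Diffusion UNet with the standard block layout would do.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Top-level forward now slices a list of per-block embeddings instead of reusing one tensor.
unet.forward = types.MethodType(unet_forward_XTI, unet)

# Block-level forwards: each cross-attention block consumes its own embedding entry.
for block in list(unet.down_blocks) + list(unet.up_blocks):
    if isinstance(block, CrossAttnDownBlock2D):
        block.forward = types.MethodType(downblock_forward_XTI, block)
    elif isinstance(block, CrossAttnUpBlock2D):
        block.forward = types.MethodType(upblock_forward_XTI, block)

# After patching, `encoder_hidden_states` must be a sequence of text embeddings
# (one per cross-attention level), not a single (batch, seq_len, dim) tensor.
```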
    	
_typos.toml ADDED
@@ -0,0 +1,15 @@
+# Files for typos
+# Instruction:  https://github.com/marketplace/actions/typos-action#getting-started
+
+[default.extend-identifiers]
+
+[default.extend-words]
+NIN="NIN"
+parms="parms"
+nin="nin"
+extention="extention" # Intentionally left
+nd="nd"
+
+
+[files]
+extend-exclude = ["_typos.toml"]
    	
cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c791d1f047b41ff5885772fc4bf20b797c6059bbd82abb9e31de15e55d6a57c4
+size 11907224
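This file is not the binary itself but a Git LFS pointer: the repository stores only the `version`, `oid`, and `size` fields, and the roughly 11 MB `frpc` binary lives in LFS storage (the matching filter rule is the line added to `.gitattributes` earlier in this diff). A small illustrative sketch of reading such a pointer, assuming the standard space-separated key/value pointer format:

```python
# Illustration only: read a Git LFS pointer file such as the one above.
from pathlib import Path


def parse_lfs_pointer(path: str) -> dict:
    """Return the key/value fields (version, oid, size) of an LFS pointer."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        if line.strip():
            key, _, value = line.partition(" ")
            fields[key] = value
    return fields


info = parse_lfs_pointer("cache/huggingface/gradio/frpc/frpc_linux_amd64_v0.3")
print(info["oid"], int(info["size"]))  # sha256:c791d1f0..., 11907224
```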
    	
        config_README-ja.md
    ADDED
    
    | @@ -0,0 +1,279 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
For non-Japanese speakers: this README is currently provided only in Japanese. Sorry for the inconvenience; we will provide an English version in the near future.

This document describes the configuration file that can be passed with `--dataset_config`.

## Overview

Passing a configuration file lets the user make fine-grained settings.

* Multiple datasets can be configured
    * For example, you can set `resolution` per dataset and train on a mixture of them.
    * In training methods that support both the DreamBooth approach and the fine tuning approach, DreamBooth-style and fine tuning-style datasets can be mixed.
* Settings can be changed per subset
    * A subset is a dataset split by image directory or by metadata file. Several subsets together make up a dataset.
    * Options such as `keep_tokens` and `flip_aug` can be set per subset. Options such as `resolution` and `batch_size`, on the other hand, are set per dataset, and subsets belonging to the same dataset share the same value. Details are given below.

The configuration file can be written in JSON or TOML. For ease of writing, [TOML](https://toml.io/ja/v1.0.0-rc.2) is recommended. The explanation below assumes TOML.

Here is an example configuration file written in TOML.

```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1

# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2

  [[datasets.subsets]]
  image_dir = 'C:\hoge'
  class_tokens = 'hoge girl'
  # This subset uses keep_tokens = 2 (the value of the dataset it belongs to)

  [[datasets.subsets]]
  image_dir = 'C:\fuga'
  class_tokens = 'fuga boy'
  keep_tokens = 3

  [[datasets.subsets]]
  is_reg = true
  image_dir = 'C:\reg'
  class_tokens = 'human'
  keep_tokens = 1

# This is a fine tuning-style dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2

  [[datasets.subsets]]
  image_dir = 'C:\piyo'
  metadata_file = 'C:\piyo\piyo_md.json'
  # This subset uses keep_tokens = 1 (the value from [general])
```

In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine tuning-style dataset at 768x768 (batch size 2).

## Dataset and subset settings

Settings for datasets and subsets can be registered in several places.

* `[general]`
    * Options that apply to all datasets or all subsets.
    * If an option with the same name also exists in a per-dataset or per-subset setting, the dataset- or subset-level setting takes precedence.
* `[[datasets]]`
    * `datasets` is where dataset-level settings are registered. Options that apply to each individual dataset go here.
    * If a per-subset setting exists, the subset-level setting takes precedence.
* `[[datasets.subsets]]`
    * `datasets.subsets` is where subset-level settings are registered. Options that apply to each individual subset go here.

Here is a rough illustration of how the image directories in the example above correspond to the registration places.

```
C:\
├─ hoge  ->  [[datasets.subsets]] No.1  ┐                        ┐
├─ fuga  ->  [[datasets.subsets]] No.2  |->  [[datasets]] No.1   |->  [general]
├─ reg   ->  [[datasets.subsets]] No.3  ┘                        |
└─ piyo  ->  [[datasets.subsets]] No.4  -->  [[datasets]] No.2   ┘
```

Each image directory corresponds to one `[[datasets.subsets]]`. One or more `[[datasets.subsets]]` combine to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.

The options that can be specified differ per registration place, but when the same option is specified in more than one place, the value in the lower-level place takes precedence. Checking how the `keep_tokens` option is handled in the example above should make this easier to understand.

In addition, the options you can specify depend on which approaches the training method supports.

* Options specific to the DreamBooth style
* Options specific to the fine tuning style
* Options available when the caption dropout technique can be used

In training methods that support both the DreamBooth approach and the fine tuning approach, the two can be used together.
Note that whether a dataset is DreamBooth-style or fine tuning-style is determined per dataset, so DreamBooth-style subsets and fine tuning-style subsets cannot be mixed within the same dataset.
In other words, to use both, subsets of different styles must be assigned to different datasets.

In terms of program behavior, a subset is judged to be fine tuning-style if the `metadata_file` option described below is present.
Therefore, for subsets belonging to the same dataset, it is fine as long as either all of them have the `metadata_file` option or none of them do.

The available options are described below. For options that share a name with a command-line argument, the explanation is basically omitted; see the other READMEs.

### Options common to all training methods

Options that can be specified regardless of the training method.

#### Dataset-level options

Options related to dataset settings. They cannot be written in `datasets.subsets`.

| Option name | Example | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |

* `batch_size`
    * Equivalent to the command-line argument `--train_batch_size`.

These settings are fixed per dataset.
That is, all subsets belonging to the dataset share them.
For example, to use datasets with different resolutions, define them as separate datasets, as in the example above, and each can then be given its own resolution.

#### Subset-level options

Options related to subset settings.

| Option name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |

* `num_repeats`
    * Specifies how many times the images in the subset are repeated. It corresponds to `--dataset_repeats` in fine tuning, but `num_repeats` can be specified with any training method (see the sketch below).

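To make the layering concrete, here is a minimal sketch (not taken from the original examples; the directory names are placeholders) of how a subset-level value overrides the `[general]` default:

```toml
[general]
flip_aug = true      # default for every subset
num_repeats = 10

[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = 'C:\subset_a'
  # caption files are assumed to exist next to the images
  # inherits flip_aug = true and num_repeats = 10 from [general]

  [[datasets.subsets]]
  image_dir = 'C:\subset_b'
  flip_aug = false     # overrides [general] for this subset only
  num_repeats = 20     # this subset's images are repeated twice as often
```
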
### Options specific to the DreamBooth style

DreamBooth-style options exist only as subset-level options.

#### Subset-level options

Options related to configuring DreamBooth-style subsets.

| Option name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `is_reg` | `false` | - | - | o |

First, note that `image_dir` must point to a path with the image files placed directly under it. In the traditional DreamBooth method, images had to be placed in subdirectories; that layout is not compatible with this specification. Also, naming a folder something like `5_cat` does not set the repeat count or the class name from the folder name. If you want to set these individually, you must specify them explicitly with `num_repeats` and `class_tokens`, as in the sketch after this section.

* `image_dir`
    * Specifies the path to the image directory. Required.
    * Images must be placed directly under the directory.
* `class_tokens`
    * Sets the class tokens.
    * They are used during training only when no caption file exists for an image; the decision is made per image. If `class_tokens` is not specified and no caption file is found either, an error occurs.
* `is_reg`
    * Specifies whether the images in the subset are regularization images. If not specified, it is treated as `false`, i.e. the images are not regularization images.

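As a minimal sketch of the point above, the information that a `5_cat`-style folder name used to carry is written out explicitly instead; the paths and token values below are placeholders, not part of the original example:

```toml
[[datasets]]
resolution = 512

  # previously expressed by naming the folder "5_cat"
  [[datasets.subsets]]
  image_dir = 'C:\train\cat'
  class_tokens = 'cat'
  num_repeats = 5

  # regularization images are marked with is_reg instead of a special folder name
  [[datasets.subsets]]
  is_reg = true
  image_dir = 'C:\train\reg_animal'
  class_tokens = 'animal'
  num_repeats = 1
```
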
### Options specific to the fine tuning style

Fine tuning-style options exist only as subset-level options.

#### Subset-level options

Options related to configuring fine tuning-style subsets.

| Option name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |

* `image_dir`
    * Specifies the path to the image directory. Unlike the DreamBooth approach it is not required, but setting it is recommended.
        * The one case where it does not need to be specified is when the metadata file was generated with `--full_path`.
    * Images must be placed directly under the directory.
* `metadata_file`
    * Specifies the path to the metadata file used by the subset. Required.
        * Equivalent to the command-line argument `--in_json`.
    * Because a metadata file has to be specified per subset, it is best to avoid creating a single metadata file that spans multiple directories. It is strongly recommended to prepare one metadata file per image directory and register them as separate subsets (see the sketch below).

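A minimal sketch of the recommended layout, assuming one metadata file per image directory (the paths are placeholders):

```toml
[[datasets]]
resolution = [768, 768]
batch_size = 2

  [[datasets.subsets]]
  image_dir = 'C:\dataset_a'
  metadata_file = 'C:\dataset_a\meta.json'

  [[datasets.subsets]]
  image_dir = 'C:\dataset_b'
  metadata_file = 'C:\dataset_b\meta.json'
```
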
### Options available when caption dropout can be used

Options for the caption dropout technique exist only as subset-level options.
Regardless of DreamBooth style or fine tuning style, they can be specified as long as the training method supports caption dropout.

#### Subset-level options

Options related to configuring subsets that can use caption dropout.

| Option name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |

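For illustration only, a sketch of how these three options might be combined on one subset. The numeric values are arbitrary placeholders, and the comments paraphrase the usual meaning of the options; see the training READMEs for the authoritative description.

```toml
[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = 'C:\hoge'
  caption_dropout_rate = 0.1            # drop the whole caption for a fraction of images
  caption_tag_dropout_rate = 0.2        # drop individual comma-separated tags at this rate
  caption_dropout_every_n_epochs = 3    # additionally drop captions on every 3rd epoch
```
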
## Behavior when duplicate subsets exist

For DreamBooth-style datasets, subsets within the same dataset that have the same `image_dir` are considered duplicates.
For fine tuning-style datasets, subsets within the same dataset that have the same `metadata_file` are considered duplicates.
If duplicate subsets exist in a dataset, the second and subsequent ones are ignored.

On the other hand, subsets belonging to different datasets are not considered duplicates.
For example, if you put subsets with the same `image_dir` into different datasets as shown below, they are not treated as duplicates.
This is useful when you want to train on the same images at different resolutions.

```toml
# When they are placed in different datasets, they are not considered duplicates and both are used for training

[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = 'C:\hoge'

[[datasets]]
resolution = 768

  [[datasets.subsets]]
  image_dir = 'C:\hoge'
```

## Combining with command-line arguments

Some options in the configuration file overlap in role with command-line options.

The following command-line options are ignored when a configuration file is passed.

* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`

For the following command-line options, if an option is specified both on the command line and in the configuration file, the value in the configuration file takes precedence over the command-line value. Unless noted otherwise, the option names are the same.

| Command-line option                | Configuration file option that takes precedence |
| ---------------------------------- | ---------------------------------- |
| `--bucket_no_upscale`              |                                    |
| `--bucket_reso_steps`              |                                    |
| `--caption_dropout_every_n_epochs` |                                    |
| `--caption_dropout_rate`           |                                    |
| `--caption_extension`              |                                    |
| `--caption_tag_dropout_rate`       |                                    |
| `--color_aug`                      |                                    |
| `--dataset_repeats`                | `num_repeats`                      |
| `--enable_bucket`                  |                                    |
| `--face_crop_aug_range`            |                                    |
| `--flip_aug`                       |                                    |
| `--keep_tokens`                    |                                    |
| `--min_bucket_reso`                |                                    |
| `--random_crop`                    |                                    |
| `--resolution`                     |                                    |
| `--shuffle_caption`                |                                    |
| `--train_batch_size`               | `batch_size`                       |

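As a small illustration of this precedence rule (values chosen only for the example): if training is launched with `--keep_tokens=1` while the configuration file contains the following, the value 2 from the file is the one that takes effect.

```toml
[general]
keep_tokens = 2   # this value wins over --keep_tokens given on the command line
```
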
## Error guide

Currently an external library is used to check whether the configuration file is written correctly, but the checks are not well polished and the error messages can be hard to understand.
We plan to improve this in the future.

As a stopgap, frequent errors and how to deal with them are listed below.
If you get an error even though the file should be correct, or if you cannot make sense of an error message, it may be a bug, so please contact us.

* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: An error saying that a required option was not provided. You may have forgotten to specify it, or written the option name incorrectly.
  * The `...` part shows where the error occurred. For example, the error `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` means that `image_dir` is missing from the settings of the 0th `subsets` of the 0th `datasets`.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: An error saying that the format of a value is wrong. The value itself is most likely written in the wrong format. The `int` part changes depending on the option in question. The "Example" columns in this README may help.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: An error that occurs when an unsupported option name is present. You may have misspelled an option name, or an unrelated key may have slipped in.

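For illustration, assuming the validation behaves as described above, a configuration like the following would typically trigger the first and third errors (a missing required `image_dir` and a misspelled key):

```toml
[[datasets]]
resolutin = 512   # misspelled "resolution" -> "extra keys not allowed"

  [[datasets.subsets]]
  class_tokens = 'sks girl'
  # image_dir is missing -> "required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']"
```
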
    	
        config_files/accelerate/default_config.yaml
    ADDED
    
    | @@ -0,0 +1,22 @@ | |
command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: 'NO'
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false
    	
        dreambooth_gui.py
    ADDED
    
    | @@ -0,0 +1,944 @@ | |
| 1 | 
            +
            # v1: initial release
         | 
| 2 | 
            +
            # v2: add open and save folder icons
         | 
| 3 | 
            +
            # v3: Add new Utilities tab for Dreambooth folder preparation
         | 
| 4 | 
            +
            # v3.1: Adding captioning of images to utilities
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import gradio as gr
         | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            import subprocess
         | 
| 11 | 
            +
            import pathlib
         | 
| 12 | 
            +
            import argparse
         | 
| 13 | 
            +
            from library.common_gui import (
         | 
| 14 | 
            +
                get_folder_path,
         | 
| 15 | 
            +
                remove_doublequote,
         | 
| 16 | 
            +
                get_file_path,
         | 
| 17 | 
            +
                get_any_file_path,
         | 
| 18 | 
            +
                get_saveasfile_path,
         | 
| 19 | 
            +
                color_aug_changed,
         | 
| 20 | 
            +
                save_inference_file,
         | 
| 21 | 
            +
                gradio_advanced_training,
         | 
| 22 | 
            +
                run_cmd_advanced_training,
         | 
| 23 | 
            +
                run_cmd_training,
         | 
| 24 | 
            +
                gradio_training,
         | 
| 25 | 
            +
                gradio_config,
         | 
| 26 | 
            +
                gradio_source_model,
         | 
| 27 | 
            +
                # set_legacy_8bitadam,
         | 
| 28 | 
            +
                update_my_data,
         | 
| 29 | 
            +
                check_if_model_exist,
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
            from library.tensorboard_gui import (
         | 
| 32 | 
            +
                gradio_tensorboard,
         | 
| 33 | 
            +
                start_tensorboard,
         | 
| 34 | 
            +
                stop_tensorboard,
         | 
| 35 | 
            +
            )
         | 
| 36 | 
            +
            from library.dreambooth_folder_creation_gui import (
         | 
| 37 | 
            +
                gradio_dreambooth_folder_creation_tab,
         | 
| 38 | 
            +
            )
         | 
| 39 | 
            +
            from library.utilities import utilities_tab
         | 
| 40 | 
            +
            from library.sampler_gui import sample_gradio_config, run_cmd_sample
         | 
| 41 | 
            +
            from easygui import msgbox
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            folder_symbol = '\U0001f4c2'  # 📂
         | 
| 44 | 
            +
            refresh_symbol = '\U0001f504'  # 🔄
         | 
| 45 | 
            +
            save_style_symbol = '\U0001f4be'  # 💾
         | 
| 46 | 
            +
            document_symbol = '\U0001F4C4'   # 📄
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def save_configuration(
         | 
| 50 | 
            +
                save_as,
         | 
| 51 | 
            +
                file_path,
         | 
| 52 | 
            +
                pretrained_model_name_or_path,
         | 
| 53 | 
            +
                v2,
         | 
| 54 | 
            +
                v_parameterization,
         | 
| 55 | 
            +
                logging_dir,
         | 
| 56 | 
            +
                train_data_dir,
         | 
| 57 | 
            +
                reg_data_dir,
         | 
| 58 | 
            +
                output_dir,
         | 
| 59 | 
            +
                max_resolution,
         | 
| 60 | 
            +
                learning_rate,
         | 
| 61 | 
            +
                lr_scheduler,
         | 
| 62 | 
            +
                lr_warmup,
         | 
| 63 | 
            +
                train_batch_size,
         | 
| 64 | 
            +
                epoch,
         | 
| 65 | 
            +
                save_every_n_epochs,
         | 
| 66 | 
            +
                mixed_precision,
         | 
| 67 | 
            +
                save_precision,
         | 
| 68 | 
            +
                seed,
         | 
| 69 | 
            +
                num_cpu_threads_per_process,
         | 
| 70 | 
            +
                cache_latents,
         | 
| 71 | 
            +
                caption_extension,
         | 
| 72 | 
            +
                enable_bucket,
         | 
| 73 | 
            +
                gradient_checkpointing,
         | 
| 74 | 
            +
                full_fp16,
         | 
| 75 | 
            +
                no_token_padding,
         | 
| 76 | 
            +
                stop_text_encoder_training,
         | 
| 77 | 
            +
                # use_8bit_adam,
         | 
| 78 | 
            +
                xformers,
         | 
| 79 | 
            +
                save_model_as,
         | 
| 80 | 
            +
                shuffle_caption,
         | 
| 81 | 
            +
                save_state,
         | 
| 82 | 
            +
                resume,
         | 
| 83 | 
            +
                prior_loss_weight,
         | 
| 84 | 
            +
                color_aug,
         | 
| 85 | 
            +
                flip_aug,
         | 
| 86 | 
            +
                clip_skip,
         | 
| 87 | 
            +
                vae,
         | 
| 88 | 
            +
                output_name,
         | 
| 89 | 
            +
                max_token_length,
         | 
| 90 | 
            +
                max_train_epochs,
         | 
| 91 | 
            +
                max_data_loader_n_workers,
         | 
| 92 | 
            +
                mem_eff_attn,
         | 
| 93 | 
            +
                gradient_accumulation_steps,
         | 
| 94 | 
            +
                model_list,
         | 
| 95 | 
            +
                keep_tokens,
         | 
| 96 | 
            +
                persistent_data_loader_workers,
         | 
| 97 | 
            +
                bucket_no_upscale,
         | 
| 98 | 
            +
                random_crop,
         | 
| 99 | 
            +
                bucket_reso_steps,
         | 
| 100 | 
            +
                caption_dropout_every_n_epochs,
         | 
| 101 | 
            +
                caption_dropout_rate,
         | 
| 102 | 
            +
                optimizer,
         | 
| 103 | 
            +
                optimizer_args,
         | 
| 104 | 
            +
                noise_offset,
         | 
| 105 | 
            +
                sample_every_n_steps,
         | 
| 106 | 
            +
                sample_every_n_epochs,
         | 
| 107 | 
            +
                sample_sampler,
         | 
| 108 | 
            +
                sample_prompts,
         | 
| 109 | 
            +
                additional_parameters,
         | 
| 110 | 
            +
                vae_batch_size,
         | 
| 111 | 
            +
                min_snr_gamma,
         | 
| 112 | 
            +
            ):
         | 
| 113 | 
            +
                # Get list of function parameters and values
         | 
| 114 | 
            +
                parameters = list(locals().items())
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                original_file_path = file_path
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                save_as_bool = True if save_as.get('label') == 'True' else False
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                if save_as_bool:
         | 
| 121 | 
            +
                    print('Save as...')
         | 
| 122 | 
            +
                    file_path = get_saveasfile_path(file_path)
         | 
| 123 | 
            +
                else:
         | 
| 124 | 
            +
                    print('Save...')
         | 
| 125 | 
            +
                    if file_path == None or file_path == '':
         | 
| 126 | 
            +
                        file_path = get_saveasfile_path(file_path)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # print(file_path)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                if file_path == None or file_path == '':
         | 
| 131 | 
            +
                    return original_file_path  # In case a file_path was provided and the user decide to cancel the open action
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # Return the values of the variables as a dictionary
         | 
| 134 | 
            +
                variables = {
         | 
| 135 | 
            +
                    name: value
         | 
| 136 | 
            +
                    for name, value in parameters  # locals().items()
         | 
| 137 | 
            +
                    if name
         | 
| 138 | 
            +
                    not in [
         | 
| 139 | 
            +
                        'file_path',
         | 
| 140 | 
            +
                        'save_as',
         | 
| 141 | 
            +
                    ]
         | 
| 142 | 
            +
                }
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                # Extract the destination directory from the file path
         | 
| 145 | 
            +
                destination_directory = os.path.dirname(file_path)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                # Create the destination directory if it doesn't exist
         | 
| 148 | 
            +
                if not os.path.exists(destination_directory):
         | 
| 149 | 
            +
                    os.makedirs(destination_directory)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                # Save the data to the selected file
         | 
| 152 | 
            +
                with open(file_path, 'w') as file:
         | 
| 153 | 
            +
                    json.dump(variables, file, indent=2)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                return file_path
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def open_configuration(
         | 
| 159 | 
            +
                ask_for_file,
         | 
| 160 | 
            +
                file_path,
         | 
| 161 | 
            +
                pretrained_model_name_or_path,
         | 
| 162 | 
            +
                v2,
         | 
| 163 | 
            +
                v_parameterization,
         | 
| 164 | 
            +
                logging_dir,
         | 
| 165 | 
            +
                train_data_dir,
         | 
| 166 | 
            +
                reg_data_dir,
         | 
| 167 | 
            +
                output_dir,
         | 
| 168 | 
            +
                max_resolution,
         | 
| 169 | 
            +
                learning_rate,
         | 
| 170 | 
            +
                lr_scheduler,
         | 
| 171 | 
            +
                lr_warmup,
         | 
| 172 | 
            +
                train_batch_size,
         | 
| 173 | 
            +
                epoch,
         | 
| 174 | 
            +
                save_every_n_epochs,
         | 
| 175 | 
            +
                mixed_precision,
         | 
| 176 | 
            +
                save_precision,
         | 
| 177 | 
            +
                seed,
         | 
| 178 | 
            +
                num_cpu_threads_per_process,
         | 
| 179 | 
            +
                cache_latents,
         | 
| 180 | 
            +
                caption_extension,
         | 
| 181 | 
            +
                enable_bucket,
         | 
| 182 | 
            +
                gradient_checkpointing,
         | 
| 183 | 
            +
                full_fp16,
         | 
| 184 | 
            +
                no_token_padding,
         | 
| 185 | 
            +
                stop_text_encoder_training,
         | 
| 186 | 
            +
                # use_8bit_adam,
         | 
| 187 | 
            +
                xformers,
         | 
| 188 | 
            +
                save_model_as,
         | 
| 189 | 
            +
                shuffle_caption,
         | 
| 190 | 
            +
                save_state,
         | 
| 191 | 
            +
                resume,
         | 
| 192 | 
            +
                prior_loss_weight,
         | 
| 193 | 
            +
                color_aug,
         | 
| 194 | 
            +
                flip_aug,
         | 
| 195 | 
            +
                clip_skip,
         | 
| 196 | 
            +
                vae,
         | 
| 197 | 
            +
                output_name,
         | 
| 198 | 
            +
                max_token_length,
         | 
| 199 | 
            +
                max_train_epochs,
         | 
| 200 | 
            +
                max_data_loader_n_workers,
         | 
| 201 | 
            +
                mem_eff_attn,
         | 
| 202 | 
            +
                gradient_accumulation_steps,
         | 
| 203 | 
            +
                model_list,
         | 
| 204 | 
            +
                keep_tokens,
         | 
| 205 | 
            +
                persistent_data_loader_workers,
         | 
| 206 | 
            +
                bucket_no_upscale,
         | 
| 207 | 
            +
                random_crop,
         | 
| 208 | 
            +
                bucket_reso_steps,
         | 
| 209 | 
            +
                caption_dropout_every_n_epochs,
         | 
| 210 | 
            +
                caption_dropout_rate,
         | 
| 211 | 
            +
                optimizer,
         | 
| 212 | 
            +
                optimizer_args,
         | 
| 213 | 
            +
                noise_offset,
         | 
| 214 | 
            +
                sample_every_n_steps,
         | 
| 215 | 
            +
                sample_every_n_epochs,
         | 
| 216 | 
            +
                sample_sampler,
         | 
| 217 | 
            +
                sample_prompts,
         | 
| 218 | 
            +
                additional_parameters,
         | 
| 219 | 
            +
                vae_batch_size,
         | 
| 220 | 
            +
                min_snr_gamma,
         | 
| 221 | 
            +
            ):
         | 
| 222 | 
            +
                # Get list of function parameters and values
         | 
| 223 | 
            +
                parameters = list(locals().items())
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                ask_for_file = True if ask_for_file.get('label') == 'True' else False
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                original_file_path = file_path
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                if ask_for_file:
         | 
| 230 | 
            +
                    file_path = get_file_path(file_path)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                if not file_path == '' and not file_path == None:
         | 
| 233 | 
            +
                    # load variables from JSON file
         | 
| 234 | 
            +
                    with open(file_path, 'r') as f:
         | 
| 235 | 
            +
                        my_data = json.load(f)
         | 
| 236 | 
            +
                        print('Loading config...')
         | 
| 237 | 
            +
                        # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
         | 
| 238 | 
            +
                        my_data = update_my_data(my_data)
         | 
| 239 | 
            +
                else:
         | 
| 240 | 
            +
                    file_path = original_file_path  # In case a file_path was provided and the user decide to cancel the open action
         | 
| 241 | 
            +
                    my_data = {}
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                values = [file_path]
         | 
| 244 | 
            +
                for key, value in parameters:
         | 
| 245 | 
            +
                    # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
         | 
| 246 | 
            +
                    if not key in ['ask_for_file', 'file_path']:
         | 
| 247 | 
            +
                        values.append(my_data.get(key, value))
         | 
| 248 | 
            +
                return tuple(values)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
            def train_model(
         | 
| 252 | 
            +
                pretrained_model_name_or_path,
         | 
| 253 | 
            +
                v2,
         | 
| 254 | 
            +
                v_parameterization,
         | 
| 255 | 
            +
                logging_dir,
         | 
| 256 | 
            +
                train_data_dir,
         | 
| 257 | 
            +
                reg_data_dir,
         | 
| 258 | 
            +
                output_dir,
         | 
| 259 | 
            +
                max_resolution,
         | 
| 260 | 
            +
                learning_rate,
         | 
| 261 | 
            +
                lr_scheduler,
         | 
| 262 | 
            +
                lr_warmup,
         | 
| 263 | 
            +
                train_batch_size,
         | 
| 264 | 
            +
                epoch,
         | 
| 265 | 
            +
                save_every_n_epochs,
         | 
| 266 | 
            +
                mixed_precision,
         | 
| 267 | 
            +
                save_precision,
         | 
| 268 | 
            +
                seed,
         | 
| 269 | 
            +
                num_cpu_threads_per_process,
         | 
| 270 | 
            +
                cache_latents,
         | 
| 271 | 
            +
                caption_extension,
         | 
| 272 | 
            +
                enable_bucket,
         | 
| 273 | 
            +
                gradient_checkpointing,
         | 
| 274 | 
            +
                full_fp16,
         | 
| 275 | 
            +
                no_token_padding,
         | 
| 276 | 
            +
                stop_text_encoder_training_pct,
         | 
| 277 | 
            +
                # use_8bit_adam,
         | 
| 278 | 
            +
                xformers,
         | 
| 279 | 
            +
                save_model_as,
         | 
| 280 | 
            +
                shuffle_caption,
         | 
| 281 | 
            +
                save_state,
         | 
| 282 | 
            +
                resume,
         | 
| 283 | 
            +
                prior_loss_weight,
         | 
| 284 | 
            +
                color_aug,
         | 
| 285 | 
            +
                flip_aug,
         | 
| 286 | 
            +
                clip_skip,
         | 
| 287 | 
            +
                vae,
         | 
| 288 | 
            +
                output_name,
         | 
| 289 | 
            +
                max_token_length,
         | 
| 290 | 
            +
                max_train_epochs,
         | 
| 291 | 
            +
                max_data_loader_n_workers,
         | 
| 292 | 
            +
                mem_eff_attn,
         | 
| 293 | 
            +
                gradient_accumulation_steps,
         | 
| 294 | 
            +
                model_list,  # Keep this. Yes, it is unused here but required given the common list used
         | 
| 295 | 
            +
                keep_tokens,
         | 
| 296 | 
            +
                persistent_data_loader_workers,
         | 
| 297 | 
            +
                bucket_no_upscale,
         | 
| 298 | 
            +
                random_crop,
         | 
| 299 | 
            +
                bucket_reso_steps,
         | 
| 300 | 
            +
                caption_dropout_every_n_epochs,
         | 
| 301 | 
            +
                caption_dropout_rate,
         | 
| 302 | 
            +
                optimizer,
         | 
| 303 | 
            +
                optimizer_args,
         | 
| 304 | 
            +
                noise_offset,
         | 
| 305 | 
            +
                sample_every_n_steps,
         | 
| 306 | 
            +
                sample_every_n_epochs,
         | 
| 307 | 
            +
                sample_sampler,
         | 
| 308 | 
            +
                sample_prompts,
         | 
| 309 | 
            +
                additional_parameters,
         | 
| 310 | 
            +
                vae_batch_size,
         | 
| 311 | 
            +
                min_snr_gamma,
         | 
| 312 | 
            +
            ):
         | 
| 313 | 
            +
                if pretrained_model_name_or_path == '':
         | 
| 314 | 
            +
                    msgbox('Source model information is missing')
         | 
| 315 | 
            +
                    return
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                if train_data_dir == '':
         | 
| 318 | 
            +
                    msgbox('Image folder path is missing')
         | 
| 319 | 
            +
                    return
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                if not os.path.exists(train_data_dir):
         | 
| 322 | 
            +
                    msgbox('Image folder does not exist')
         | 
| 323 | 
            +
                    return
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                if reg_data_dir != '':
         | 
| 326 | 
            +
                    if not os.path.exists(reg_data_dir):
         | 
| 327 | 
            +
                        msgbox('Regularisation folder does not exist')
         | 
| 328 | 
            +
                        return
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                if output_dir == '':
         | 
| 331 | 
            +
                    msgbox('Output folder path is missing')
         | 
| 332 | 
            +
                    return
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                if check_if_model_exist(output_name, output_dir, save_model_as):
         | 
| 335 | 
            +
                    return
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                # Get a list of all subfolders in train_data_dir, excluding hidden folders
         | 
| 338 | 
            +
                subfolders = [
         | 
| 339 | 
            +
                    f
         | 
| 340 | 
            +
                    for f in os.listdir(train_data_dir)
         | 
| 341 | 
            +
                    if os.path.isdir(os.path.join(train_data_dir, f))
         | 
| 342 | 
            +
                    and not f.startswith('.')
         | 
| 343 | 
            +
                ]
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                # Check if subfolders are present. If not let the user know and return
         | 
| 346 | 
            +
                if not subfolders:
         | 
| 347 | 
            +
                    print(
         | 
| 348 | 
            +
                        '\033[33mNo subfolders were found in',
         | 
| 349 | 
            +
                        train_data_dir,
         | 
| 350 | 
            +
                        " can't train\...033[0m",
         | 
| 351 | 
            +
                    )
         | 
| 352 | 
            +
                    return
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                total_steps = 0
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                # Loop through each subfolder and extract the number of repeats
         | 
| 357 | 
            +
                for folder in subfolders:
         | 
| 358 | 
            +
                    # Extract the number of repeats from the folder name
         | 
| 359 | 
            +
                    try:
         | 
| 360 | 
            +
                        repeats = int(folder.split('_')[0])
         | 
| 361 | 
            +
                    except ValueError:
         | 
| 362 | 
            +
                        print(
         | 
| 363 | 
            +
                            '\033[33mSubfolder',
         | 
| 364 | 
            +
                            folder,
         | 
| 365 | 
            +
                            "does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m",
         | 
| 366 | 
            +
                        )
         | 
| 367 | 
            +
                        continue
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    # Count the number of images in the folder
         | 
| 370 | 
            +
                    num_images = len(
         | 
| 371 | 
            +
                        [
         | 
| 372 | 
            +
                            f
         | 
| 373 | 
            +
                            for f, lower_f in (
         | 
| 374 | 
            +
                                (file, file.lower())
         | 
| 375 | 
            +
                                for file in os.listdir(
         | 
| 376 | 
            +
                                    os.path.join(train_data_dir, folder)
         | 
| 377 | 
            +
                                )
         | 
| 378 | 
            +
                            )
         | 
| 379 | 
            +
                            if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
         | 
| 380 | 
            +
                        ]
         | 
| 381 | 
            +
                    )
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    if num_images == 0:
         | 
| 384 | 
            +
                print(f'{folder} folder contains no images, skipping...')
         | 
| 385 | 
            +
                    else:
         | 
| 386 | 
            +
                        # Calculate the total number of steps for this folder
         | 
| 387 | 
            +
                        steps = repeats * num_images
         | 
| 388 | 
            +
                        total_steps += steps
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                        # Print the result
         | 
| 391 | 
            +
                        print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                if total_steps == 0:
         | 
| 394 | 
            +
                    print(
         | 
| 395 | 
            +
                        '\033[33mNo images were found in folder',
         | 
| 396 | 
            +
                        train_data_dir,
         | 
| 397 | 
            +
                        '... please rectify!\033[0m',
         | 
| 398 | 
            +
                    )
         | 
| 399 | 
            +
                    return
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                # Print the result
         | 
| 402 | 
            +
                # print(f"{total_steps} total steps")
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                if reg_data_dir == '':
         | 
| 405 | 
            +
                    reg_factor = 1
         | 
| 406 | 
            +
                else:
         | 
| 407 | 
            +
                    print(
         | 
| 408 | 
            +
                        '\033[94mRegularisation images are used... Will double the number of steps required...\033[0m'
         | 
| 409 | 
            +
                    )
         | 
| 410 | 
            +
                    reg_factor = 2
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                # calculate max_train_steps
         | 
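    # total_steps already includes the per-image repeat counts, so it is divided by the
    # batch size and scaled by the number of epochs and the regularisation factor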
| 413 | 
            +
                max_train_steps = int(
         | 
| 414 | 
            +
                    math.ceil(
         | 
| 415 | 
            +
                        float(total_steps)
         | 
| 416 | 
            +
                        / int(train_batch_size)
         | 
| 417 | 
            +
                        * int(epoch)
         | 
| 418 | 
            +
                        * int(reg_factor)
         | 
| 419 | 
            +
                    )
         | 
| 420 | 
            +
                )
         | 
| 421 | 
            +
                print(f'max_train_steps = {max_train_steps}')
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                # calculate stop encoder training
         | 
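    # stop_text_encoder_training_pct is interpreted as a percentage of max_train_steps;
    # a value of -1 is passed through unchanged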
| 424 | 
            +
                if int(stop_text_encoder_training_pct) == -1:
         | 
| 425 | 
            +
                    stop_text_encoder_training = -1
         | 
| 426 | 
            +
    elif stop_text_encoder_training_pct is None:
         | 
| 427 | 
            +
                    stop_text_encoder_training = 0
         | 
| 428 | 
            +
                else:
         | 
| 429 | 
            +
                    stop_text_encoder_training = math.ceil(
         | 
| 430 | 
            +
                        float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
         | 
| 431 | 
            +
                    )
         | 
| 432 | 
            +
                print(f'stop_text_encoder_training = {stop_text_encoder_training}')
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
         | 
| 435 | 
            +
                print(f'lr_warmup_steps = {lr_warmup_steps}')
         | 
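    # Worked example with hypothetical numbers: 2000 repeat-weighted images, batch size 2,
    # 4 epochs and regularisation images present (reg_factor = 2) give
    #   max_train_steps = ceil(2000 / 2 * 4 * 2) = 8000
    #   stop_text_encoder_training at 80% = ceil(8000 / 100 * 80) = 6400
    #   lr_warmup of 10% = round(10 * 8000 / 100) = 800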
| 436 | 
            +
             | 
| 437 | 
            +
                run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
         | 
| 438 | 
            +
                if v2:
         | 
| 439 | 
            +
                    run_cmd += ' --v2'
         | 
| 440 | 
            +
                if v_parameterization:
         | 
| 441 | 
            +
                    run_cmd += ' --v_parameterization'
         | 
| 442 | 
            +
                if enable_bucket:
         | 
| 443 | 
            +
                    run_cmd += ' --enable_bucket'
         | 
| 444 | 
            +
                if no_token_padding:
         | 
| 445 | 
            +
                    run_cmd += ' --no_token_padding'
         | 
| 446 | 
            +
                run_cmd += (
         | 
| 447 | 
            +
                    f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
         | 
| 448 | 
            +
                )
         | 
| 449 | 
            +
                run_cmd += f' --train_data_dir="{train_data_dir}"'
         | 
| 450 | 
            +
                if len(reg_data_dir):
         | 
| 451 | 
            +
                    run_cmd += f' --reg_data_dir="{reg_data_dir}"'
         | 
| 452 | 
            +
                run_cmd += f' --resolution={max_resolution}'
         | 
| 453 | 
            +
                run_cmd += f' --output_dir="{output_dir}"'
         | 
| 454 | 
            +
                run_cmd += f' --logging_dir="{logging_dir}"'
         | 
| 455 | 
            +
                if not stop_text_encoder_training == 0:
         | 
| 456 | 
            +
                    run_cmd += (
         | 
| 457 | 
            +
                        f' --stop_text_encoder_training={stop_text_encoder_training}'
         | 
| 458 | 
            +
                    )
         | 
| 459 | 
            +
                if not save_model_as == 'same as source model':
         | 
| 460 | 
            +
                    run_cmd += f' --save_model_as={save_model_as}'
         | 
| 461 | 
            +
                # if not resume == '':
         | 
| 462 | 
            +
                #     run_cmd += f' --resume={resume}'
         | 
| 463 | 
            +
                if not float(prior_loss_weight) == 1.0:
         | 
| 464 | 
            +
                    run_cmd += f' --prior_loss_weight={prior_loss_weight}'
         | 
| 465 | 
            +
                if not vae == '':
         | 
| 466 | 
            +
                    run_cmd += f' --vae="{vae}"'
         | 
| 467 | 
            +
                if not output_name == '':
         | 
| 468 | 
            +
                    run_cmd += f' --output_name="{output_name}"'
         | 
| 469 | 
            +
                if int(max_token_length) > 75:
         | 
| 470 | 
            +
                    run_cmd += f' --max_token_length={max_token_length}'
         | 
| 471 | 
            +
                if not max_train_epochs == '':
         | 
| 472 | 
            +
                    run_cmd += f' --max_train_epochs="{max_train_epochs}"'
         | 
| 473 | 
            +
                if not max_data_loader_n_workers == '':
         | 
| 474 | 
            +
                    run_cmd += (
         | 
| 475 | 
            +
                        f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
         | 
| 476 | 
            +
                    )
         | 
| 477 | 
            +
                if int(gradient_accumulation_steps) > 1:
         | 
| 478 | 
            +
                    run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                run_cmd += run_cmd_training(
         | 
| 481 | 
            +
                    learning_rate=learning_rate,
         | 
| 482 | 
            +
                    lr_scheduler=lr_scheduler,
         | 
| 483 | 
            +
                    lr_warmup_steps=lr_warmup_steps,
         | 
| 484 | 
            +
                    train_batch_size=train_batch_size,
         | 
| 485 | 
            +
                    max_train_steps=max_train_steps,
         | 
| 486 | 
            +
                    save_every_n_epochs=save_every_n_epochs,
         | 
| 487 | 
            +
                    mixed_precision=mixed_precision,
         | 
| 488 | 
            +
                    save_precision=save_precision,
         | 
| 489 | 
            +
                    seed=seed,
         | 
| 490 | 
            +
                    caption_extension=caption_extension,
         | 
| 491 | 
            +
                    cache_latents=cache_latents,
         | 
| 492 | 
            +
                    optimizer=optimizer,
         | 
| 493 | 
            +
                    optimizer_args=optimizer_args,
         | 
| 494 | 
            +
                )
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                run_cmd += run_cmd_advanced_training(
         | 
| 497 | 
            +
                    max_train_epochs=max_train_epochs,
         | 
| 498 | 
            +
                    max_data_loader_n_workers=max_data_loader_n_workers,
         | 
| 499 | 
            +
                    max_token_length=max_token_length,
         | 
| 500 | 
            +
                    resume=resume,
         | 
| 501 | 
            +
                    save_state=save_state,
         | 
| 502 | 
            +
                    mem_eff_attn=mem_eff_attn,
         | 
| 503 | 
            +
                    clip_skip=clip_skip,
         | 
| 504 | 
            +
                    flip_aug=flip_aug,
         | 
| 505 | 
            +
                    color_aug=color_aug,
         | 
| 506 | 
            +
                    shuffle_caption=shuffle_caption,
         | 
| 507 | 
            +
                    gradient_checkpointing=gradient_checkpointing,
         | 
| 508 | 
            +
                    full_fp16=full_fp16,
         | 
| 509 | 
            +
                    xformers=xformers,
         | 
| 510 | 
            +
                    # use_8bit_adam=use_8bit_adam,
         | 
| 511 | 
            +
                    keep_tokens=keep_tokens,
         | 
| 512 | 
            +
                    persistent_data_loader_workers=persistent_data_loader_workers,
         | 
| 513 | 
            +
                    bucket_no_upscale=bucket_no_upscale,
         | 
| 514 | 
            +
                    random_crop=random_crop,
         | 
| 515 | 
            +
                    bucket_reso_steps=bucket_reso_steps,
         | 
| 516 | 
            +
                    caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
         | 
| 517 | 
            +
                    caption_dropout_rate=caption_dropout_rate,
         | 
| 518 | 
            +
                    noise_offset=noise_offset,
         | 
| 519 | 
            +
                    additional_parameters=additional_parameters,
         | 
| 520 | 
            +
                    vae_batch_size=vae_batch_size,
         | 
| 521 | 
            +
                    min_snr_gamma=min_snr_gamma,
         | 
| 522 | 
            +
                )
         | 
| 523 | 
            +
             | 
| 524 | 
            +
                run_cmd += run_cmd_sample(
         | 
| 525 | 
            +
                    sample_every_n_steps,
         | 
| 526 | 
            +
                    sample_every_n_epochs,
         | 
| 527 | 
            +
                    sample_sampler,
         | 
| 528 | 
            +
                    sample_prompts,
         | 
| 529 | 
            +
                    output_dir,
         | 
| 530 | 
            +
                )
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                print(run_cmd)
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                # Run the command
         | 
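    # on POSIX systems the assembled command string is passed to the shell via os.system;
    # elsewhere (e.g. Windows) it is executed with subprocess.run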
| 535 | 
            +
                if os.name == 'posix':
         | 
| 536 | 
            +
                    os.system(run_cmd)
         | 
| 537 | 
            +
                else:
         | 
| 538 | 
            +
                    subprocess.run(run_cmd)
         | 
| 539 | 
            +
             | 
| 540 | 
            +
    # check if output_dir/output_name is a folder... if so it is a Diffusers model
         | 
| 541 | 
            +
                last_dir = pathlib.Path(f'{output_dir}/{output_name}')
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                if not last_dir.is_dir():
         | 
| 544 | 
            +
                    # Copy inference model for v2 if required
         | 
| 545 | 
            +
                    save_inference_file(output_dir, v2, v_parameterization, output_name)
         | 
| 546 | 
            +
             | 
| 547 | 
            +
             | 
| 548 | 
            +
            def dreambooth_tab(
         | 
| 549 | 
            +
                train_data_dir=gr.Textbox(),
         | 
| 550 | 
            +
                reg_data_dir=gr.Textbox(),
         | 
| 551 | 
            +
                output_dir=gr.Textbox(),
         | 
| 552 | 
            +
                logging_dir=gr.Textbox(),
         | 
| 553 | 
            +
            ):
         | 
| 554 | 
            +
                dummy_db_true = gr.Label(value=True, visible=False)
         | 
| 555 | 
            +
                dummy_db_false = gr.Label(value=False, visible=False)
         | 
| 556 | 
            +
    gr.Markdown('Train a custom model using the kohya Dreambooth Python code...')
         | 
| 557 | 
            +
                (
         | 
| 558 | 
            +
                    button_open_config,
         | 
| 559 | 
            +
                    button_save_config,
         | 
| 560 | 
            +
                    button_save_as_config,
         | 
| 561 | 
            +
                    config_file_name,
         | 
| 562 | 
            +
                    button_load_config,
         | 
| 563 | 
            +
                ) = gradio_config()
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                (
         | 
| 566 | 
            +
                    pretrained_model_name_or_path,
         | 
| 567 | 
            +
                    v2,
         | 
| 568 | 
            +
                    v_parameterization,
         | 
| 569 | 
            +
                    save_model_as,
         | 
| 570 | 
            +
                    model_list,
         | 
| 571 | 
            +
                ) = gradio_source_model()
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                with gr.Tab('Folders'):
         | 
| 574 | 
            +
                    with gr.Row():
         | 
| 575 | 
            +
                        train_data_dir = gr.Textbox(
         | 
| 576 | 
            +
                            label='Image folder',
         | 
| 577 | 
            +
                            placeholder='Folder where the training folders containing the images are located',
         | 
| 578 | 
            +
                        )
         | 
| 579 | 
            +
                        train_data_dir_input_folder = gr.Button(
         | 
| 580 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 581 | 
            +
                        )
         | 
| 582 | 
            +
                        train_data_dir_input_folder.click(
         | 
| 583 | 
            +
                            get_folder_path,
         | 
| 584 | 
            +
                            outputs=train_data_dir,
         | 
| 585 | 
            +
                            show_progress=False,
         | 
| 586 | 
            +
                        )
         | 
| 587 | 
            +
                        reg_data_dir = gr.Textbox(
         | 
| 588 | 
            +
                            label='Regularisation folder',
         | 
| 589 | 
            +
                placeholder='(Optional) Folder where the regularization folders containing the images are located',
         | 
| 590 | 
            +
                        )
         | 
| 591 | 
            +
                        reg_data_dir_input_folder = gr.Button(
         | 
| 592 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 593 | 
            +
                        )
         | 
| 594 | 
            +
                        reg_data_dir_input_folder.click(
         | 
| 595 | 
            +
                            get_folder_path,
         | 
| 596 | 
            +
                            outputs=reg_data_dir,
         | 
| 597 | 
            +
                            show_progress=False,
         | 
| 598 | 
            +
                        )
         | 
| 599 | 
            +
                    with gr.Row():
         | 
| 600 | 
            +
                        output_dir = gr.Textbox(
         | 
| 601 | 
            +
                            label='Model output folder',
         | 
| 602 | 
            +
                            placeholder='Folder to output trained model',
         | 
| 603 | 
            +
                        )
         | 
| 604 | 
            +
                        output_dir_input_folder = gr.Button(
         | 
| 605 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 606 | 
            +
                        )
         | 
| 607 | 
            +
                        output_dir_input_folder.click(get_folder_path, outputs=output_dir)
         | 
| 608 | 
            +
                        logging_dir = gr.Textbox(
         | 
| 609 | 
            +
                            label='Logging folder',
         | 
| 610 | 
            +
                            placeholder='Optional: enable logging and output TensorBoard log to this folder',
         | 
| 611 | 
            +
                        )
         | 
| 612 | 
            +
                        logging_dir_input_folder = gr.Button(
         | 
| 613 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 614 | 
            +
                        )
         | 
| 615 | 
            +
                        logging_dir_input_folder.click(
         | 
| 616 | 
            +
                            get_folder_path,
         | 
| 617 | 
            +
                            outputs=logging_dir,
         | 
| 618 | 
            +
                            show_progress=False,
         | 
| 619 | 
            +
                        )
         | 
| 620 | 
            +
                    with gr.Row():
         | 
| 621 | 
            +
                        output_name = gr.Textbox(
         | 
| 622 | 
            +
                            label='Model output name',
         | 
| 623 | 
            +
                            placeholder='Name of the model to output',
         | 
| 624 | 
            +
                            value='last',
         | 
| 625 | 
            +
                            interactive=True,
         | 
| 626 | 
            +
                        )
         | 
| 627 | 
            +
                    train_data_dir.change(
         | 
| 628 | 
            +
                        remove_doublequote,
         | 
| 629 | 
            +
                        inputs=[train_data_dir],
         | 
| 630 | 
            +
                        outputs=[train_data_dir],
         | 
| 631 | 
            +
                    )
         | 
| 632 | 
            +
                    reg_data_dir.change(
         | 
| 633 | 
            +
                        remove_doublequote,
         | 
| 634 | 
            +
                        inputs=[reg_data_dir],
         | 
| 635 | 
            +
                        outputs=[reg_data_dir],
         | 
| 636 | 
            +
                    )
         | 
| 637 | 
            +
                    output_dir.change(
         | 
| 638 | 
            +
                        remove_doublequote,
         | 
| 639 | 
            +
                        inputs=[output_dir],
         | 
| 640 | 
            +
                        outputs=[output_dir],
         | 
| 641 | 
            +
                    )
         | 
| 642 | 
            +
                    logging_dir.change(
         | 
| 643 | 
            +
                        remove_doublequote,
         | 
| 644 | 
            +
                        inputs=[logging_dir],
         | 
| 645 | 
            +
                        outputs=[logging_dir],
         | 
| 646 | 
            +
                    )
         | 
| 647 | 
            +
                with gr.Tab('Training parameters'):
         | 
| 648 | 
            +
                    (
         | 
| 649 | 
            +
                        learning_rate,
         | 
| 650 | 
            +
                        lr_scheduler,
         | 
| 651 | 
            +
                        lr_warmup,
         | 
| 652 | 
            +
                        train_batch_size,
         | 
| 653 | 
            +
                        epoch,
         | 
| 654 | 
            +
                        save_every_n_epochs,
         | 
| 655 | 
            +
                        mixed_precision,
         | 
| 656 | 
            +
                        save_precision,
         | 
| 657 | 
            +
                        num_cpu_threads_per_process,
         | 
| 658 | 
            +
                        seed,
         | 
| 659 | 
            +
                        caption_extension,
         | 
| 660 | 
            +
                        cache_latents,
         | 
| 661 | 
            +
                        optimizer,
         | 
| 662 | 
            +
                        optimizer_args,
         | 
| 663 | 
            +
                    ) = gradio_training(
         | 
| 664 | 
            +
                        learning_rate_value='1e-5',
         | 
| 665 | 
            +
                        lr_scheduler_value='cosine',
         | 
| 666 | 
            +
                        lr_warmup_value='10',
         | 
| 667 | 
            +
                    )
         | 
| 668 | 
            +
                    with gr.Row():
         | 
| 669 | 
            +
                        max_resolution = gr.Textbox(
         | 
| 670 | 
            +
                            label='Max resolution',
         | 
| 671 | 
            +
                            value='512,512',
         | 
| 672 | 
            +
                            placeholder='512,512',
         | 
| 673 | 
            +
                        )
         | 
| 674 | 
            +
                        stop_text_encoder_training = gr.Slider(
         | 
| 675 | 
            +
                            minimum=-1,
         | 
| 676 | 
            +
                            maximum=100,
         | 
| 677 | 
            +
                            value=0,
         | 
| 678 | 
            +
                            step=1,
         | 
| 679 | 
            +
                            label='Stop text encoder training',
         | 
| 680 | 
            +
                        )
         | 
| 681 | 
            +
                        enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
         | 
| 682 | 
            +
                    with gr.Accordion('Advanced Configuration', open=False):
         | 
| 683 | 
            +
                        with gr.Row():
         | 
| 684 | 
            +
                            no_token_padding = gr.Checkbox(
         | 
| 685 | 
            +
                                label='No token padding', value=False
         | 
| 686 | 
            +
                            )
         | 
| 687 | 
            +
                            gradient_accumulation_steps = gr.Number(
         | 
| 688 | 
            +
                    label='Gradient accumulation steps', value='1'
         | 
| 689 | 
            +
                            )
         | 
| 690 | 
            +
                        with gr.Row():
         | 
| 691 | 
            +
                            prior_loss_weight = gr.Number(
         | 
| 692 | 
            +
                                label='Prior loss weight', value=1.0
         | 
| 693 | 
            +
                            )
         | 
| 694 | 
            +
                            vae = gr.Textbox(
         | 
| 695 | 
            +
                                label='VAE',
         | 
| 696 | 
            +
                    placeholder='(Optional) path to a VAE checkpoint to use instead of the default VAE for training',
         | 
| 697 | 
            +
                            )
         | 
| 698 | 
            +
                            vae_button = gr.Button('📂', elem_id='open_folder_small')
         | 
| 699 | 
            +
                            vae_button.click(
         | 
| 700 | 
            +
                                get_any_file_path,
         | 
| 701 | 
            +
                                outputs=vae,
         | 
| 702 | 
            +
                                show_progress=False,
         | 
| 703 | 
            +
                            )
         | 
| 704 | 
            +
                        (
         | 
| 705 | 
            +
                            # use_8bit_adam,
         | 
| 706 | 
            +
                            xformers,
         | 
| 707 | 
            +
                            full_fp16,
         | 
| 708 | 
            +
                            gradient_checkpointing,
         | 
| 709 | 
            +
                            shuffle_caption,
         | 
| 710 | 
            +
                            color_aug,
         | 
| 711 | 
            +
                            flip_aug,
         | 
| 712 | 
            +
                            clip_skip,
         | 
| 713 | 
            +
                            mem_eff_attn,
         | 
| 714 | 
            +
                            save_state,
         | 
| 715 | 
            +
                            resume,
         | 
| 716 | 
            +
                            max_token_length,
         | 
| 717 | 
            +
                            max_train_epochs,
         | 
| 718 | 
            +
                            max_data_loader_n_workers,
         | 
| 719 | 
            +
                            keep_tokens,
         | 
| 720 | 
            +
                            persistent_data_loader_workers,
         | 
| 721 | 
            +
                            bucket_no_upscale,
         | 
| 722 | 
            +
                            random_crop,
         | 
| 723 | 
            +
                            bucket_reso_steps,
         | 
| 724 | 
            +
                            caption_dropout_every_n_epochs,
         | 
| 725 | 
            +
                            caption_dropout_rate,
         | 
| 726 | 
            +
                            noise_offset,
         | 
| 727 | 
            +
                            additional_parameters,
         | 
| 728 | 
            +
                            vae_batch_size,
         | 
| 729 | 
            +
                            min_snr_gamma,
         | 
| 730 | 
            +
                        ) = gradio_advanced_training()
         | 
| 731 | 
            +
                        color_aug.change(
         | 
| 732 | 
            +
                            color_aug_changed,
         | 
| 733 | 
            +
                            inputs=[color_aug],
         | 
| 734 | 
            +
                            outputs=[cache_latents],
         | 
| 735 | 
            +
                        )
         | 
| 736 | 
            +
             | 
| 737 | 
            +
                    (
         | 
| 738 | 
            +
                        sample_every_n_steps,
         | 
| 739 | 
            +
                        sample_every_n_epochs,
         | 
| 740 | 
            +
                        sample_sampler,
         | 
| 741 | 
            +
                        sample_prompts,
         | 
| 742 | 
            +
                    ) = sample_gradio_config()
         | 
| 743 | 
            +
             | 
| 744 | 
            +
                with gr.Tab('Tools'):
         | 
| 745 | 
            +
                    gr.Markdown(
         | 
| 746 | 
            +
        'This section provides Dreambooth tools to help set up your dataset...'
         | 
| 747 | 
            +
                    )
         | 
| 748 | 
            +
                    gradio_dreambooth_folder_creation_tab(
         | 
| 749 | 
            +
                        train_data_dir_input=train_data_dir,
         | 
| 750 | 
            +
                        reg_data_dir_input=reg_data_dir,
         | 
| 751 | 
            +
                        output_dir_input=output_dir,
         | 
| 752 | 
            +
                        logging_dir_input=logging_dir,
         | 
| 753 | 
            +
                    )
         | 
| 754 | 
            +
             | 
| 755 | 
            +
                button_run = gr.Button('Train model', variant='primary')
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                # Setup gradio tensorboard buttons
         | 
| 758 | 
            +
                button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
         | 
| 759 | 
            +
             | 
| 760 | 
            +
                button_start_tensorboard.click(
         | 
| 761 | 
            +
                    start_tensorboard,
         | 
| 762 | 
            +
                    inputs=logging_dir,
         | 
| 763 | 
            +
                    show_progress=False,
         | 
| 764 | 
            +
                )
         | 
| 765 | 
            +
             | 
| 766 | 
            +
                button_stop_tensorboard.click(
         | 
| 767 | 
            +
                    stop_tensorboard,
         | 
| 768 | 
            +
                    show_progress=False,
         | 
| 769 | 
            +
                )
         | 
| 770 | 
            +
             | 
| 771 | 
            +
                settings_list = [
         | 
| 772 | 
            +
                    pretrained_model_name_or_path,
         | 
| 773 | 
            +
                    v2,
         | 
| 774 | 
            +
                    v_parameterization,
         | 
| 775 | 
            +
                    logging_dir,
         | 
| 776 | 
            +
                    train_data_dir,
         | 
| 777 | 
            +
                    reg_data_dir,
         | 
| 778 | 
            +
                    output_dir,
         | 
| 779 | 
            +
                    max_resolution,
         | 
| 780 | 
            +
                    learning_rate,
         | 
| 781 | 
            +
                    lr_scheduler,
         | 
| 782 | 
            +
                    lr_warmup,
         | 
| 783 | 
            +
                    train_batch_size,
         | 
| 784 | 
            +
                    epoch,
         | 
| 785 | 
            +
                    save_every_n_epochs,
         | 
| 786 | 
            +
                    mixed_precision,
         | 
| 787 | 
            +
                    save_precision,
         | 
| 788 | 
            +
                    seed,
         | 
| 789 | 
            +
                    num_cpu_threads_per_process,
         | 
| 790 | 
            +
                    cache_latents,
         | 
| 791 | 
            +
                    caption_extension,
         | 
| 792 | 
            +
                    enable_bucket,
         | 
| 793 | 
            +
                    gradient_checkpointing,
         | 
| 794 | 
            +
                    full_fp16,
         | 
| 795 | 
            +
                    no_token_padding,
         | 
| 796 | 
            +
                    stop_text_encoder_training,
         | 
| 797 | 
            +
                    # use_8bit_adam,
         | 
| 798 | 
            +
                    xformers,
         | 
| 799 | 
            +
                    save_model_as,
         | 
| 800 | 
            +
                    shuffle_caption,
         | 
| 801 | 
            +
                    save_state,
         | 
| 802 | 
            +
                    resume,
         | 
| 803 | 
            +
                    prior_loss_weight,
         | 
| 804 | 
            +
                    color_aug,
         | 
| 805 | 
            +
                    flip_aug,
         | 
| 806 | 
            +
                    clip_skip,
         | 
| 807 | 
            +
                    vae,
         | 
| 808 | 
            +
                    output_name,
         | 
| 809 | 
            +
                    max_token_length,
         | 
| 810 | 
            +
                    max_train_epochs,
         | 
| 811 | 
            +
                    max_data_loader_n_workers,
         | 
| 812 | 
            +
                    mem_eff_attn,
         | 
| 813 | 
            +
                    gradient_accumulation_steps,
         | 
| 814 | 
            +
                    model_list,
         | 
| 815 | 
            +
                    keep_tokens,
         | 
| 816 | 
            +
                    persistent_data_loader_workers,
         | 
| 817 | 
            +
                    bucket_no_upscale,
         | 
| 818 | 
            +
                    random_crop,
         | 
| 819 | 
            +
                    bucket_reso_steps,
         | 
| 820 | 
            +
                    caption_dropout_every_n_epochs,
         | 
| 821 | 
            +
                    caption_dropout_rate,
         | 
| 822 | 
            +
                    optimizer,
         | 
| 823 | 
            +
                    optimizer_args,
         | 
| 824 | 
            +
                    noise_offset,
         | 
| 825 | 
            +
                    sample_every_n_steps,
         | 
| 826 | 
            +
                    sample_every_n_epochs,
         | 
| 827 | 
            +
                    sample_sampler,
         | 
| 828 | 
            +
                    sample_prompts,
         | 
| 829 | 
            +
                    additional_parameters,
         | 
| 830 | 
            +
                    vae_batch_size,
         | 
| 831 | 
            +
                    min_snr_gamma,
         | 
| 832 | 
            +
                ]
         | 
| 833 | 
            +
             | 
| 834 | 
            +
                button_open_config.click(
         | 
| 835 | 
            +
                    open_configuration,
         | 
| 836 | 
            +
                    inputs=[dummy_db_true, config_file_name] + settings_list,
         | 
| 837 | 
            +
                    outputs=[config_file_name] + settings_list,
         | 
| 838 | 
            +
                    show_progress=False,
         | 
| 839 | 
            +
                )
         | 
| 840 | 
            +
             | 
| 841 | 
            +
                button_load_config.click(
         | 
| 842 | 
            +
                    open_configuration,
         | 
| 843 | 
            +
                    inputs=[dummy_db_false, config_file_name] + settings_list,
         | 
| 844 | 
            +
                    outputs=[config_file_name] + settings_list,
         | 
| 845 | 
            +
                    show_progress=False,
         | 
| 846 | 
            +
                )
         | 
| 847 | 
            +
             | 
| 848 | 
            +
                button_save_config.click(
         | 
| 849 | 
            +
                    save_configuration,
         | 
| 850 | 
            +
                    inputs=[dummy_db_false, config_file_name] + settings_list,
         | 
| 851 | 
            +
                    outputs=[config_file_name],
         | 
| 852 | 
            +
                    show_progress=False,
         | 
| 853 | 
            +
                )
         | 
| 854 | 
            +
             | 
| 855 | 
            +
                button_save_as_config.click(
         | 
| 856 | 
            +
                    save_configuration,
         | 
| 857 | 
            +
                    inputs=[dummy_db_true, config_file_name] + settings_list,
         | 
| 858 | 
            +
                    outputs=[config_file_name],
         | 
| 859 | 
            +
                    show_progress=False,
         | 
| 860 | 
            +
                )
         | 
| 861 | 
            +
             | 
| 862 | 
            +
                button_run.click(
         | 
| 863 | 
            +
                    train_model,
         | 
| 864 | 
            +
                    inputs=settings_list,
         | 
| 865 | 
            +
                    show_progress=False,
         | 
| 866 | 
            +
                )
         | 
| 867 | 
            +
             | 
| 868 | 
            +
                return (
         | 
| 869 | 
            +
                    train_data_dir,
         | 
| 870 | 
            +
                    reg_data_dir,
         | 
| 871 | 
            +
                    output_dir,
         | 
| 872 | 
            +
                    logging_dir,
         | 
| 873 | 
            +
                )
         | 
| 874 | 
            +
             | 
| 875 | 
            +
             | 
| 876 | 
            +
            def UI(**kwargs):
         | 
| 877 | 
            +
                css = ''
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                if os.path.exists('./style.css'):
         | 
| 880 | 
            +
                    with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
         | 
| 881 | 
            +
                        print('Load CSS...')
         | 
| 882 | 
            +
                        css += file.read() + '\n'
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                interface = gr.Blocks(css=css)
         | 
| 885 | 
            +
             | 
| 886 | 
            +
                with interface:
         | 
| 887 | 
            +
                    with gr.Tab('Dreambooth'):
         | 
| 888 | 
            +
                        (
         | 
| 889 | 
            +
                            train_data_dir_input,
         | 
| 890 | 
            +
                            reg_data_dir_input,
         | 
| 891 | 
            +
                            output_dir_input,
         | 
| 892 | 
            +
                            logging_dir_input,
         | 
| 893 | 
            +
                        ) = dreambooth_tab()
         | 
| 894 | 
            +
                    with gr.Tab('Utilities'):
         | 
| 895 | 
            +
                        utilities_tab(
         | 
| 896 | 
            +
                            train_data_dir_input=train_data_dir_input,
         | 
| 897 | 
            +
                            reg_data_dir_input=reg_data_dir_input,
         | 
| 898 | 
            +
                            output_dir_input=output_dir_input,
         | 
| 899 | 
            +
                            logging_dir_input=logging_dir_input,
         | 
| 900 | 
            +
                            enable_copy_info_button=True,
         | 
| 901 | 
            +
                        )
         | 
| 902 | 
            +
             | 
| 903 | 
            +
                # Show the interface
         | 
| 904 | 
            +
                launch_kwargs = {}
         | 
| 905 | 
            +
                if not kwargs.get('username', None) == '':
         | 
| 906 | 
            +
                    launch_kwargs['auth'] = (
         | 
| 907 | 
            +
                        kwargs.get('username', None),
         | 
| 908 | 
            +
                        kwargs.get('password', None),
         | 
| 909 | 
            +
                    )
         | 
| 910 | 
            +
                if kwargs.get('server_port', 0) > 0:
         | 
| 911 | 
            +
                    launch_kwargs['server_port'] = kwargs.get('server_port', 0)
         | 
| 912 | 
            +
                if kwargs.get('inbrowser', False):
         | 
| 913 | 
            +
                    launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
         | 
| 914 | 
            +
                print(launch_kwargs)
         | 
| 915 | 
            +
                interface.launch(**launch_kwargs)
         | 
| 916 | 
            +
             | 
| 917 | 
            +
             | 
| 918 | 
            +
            if __name__ == '__main__':
         | 
| 919 | 
            +
                # torch.cuda.set_per_process_memory_fraction(0.48)
         | 
| 920 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 921 | 
            +
                parser.add_argument(
         | 
| 922 | 
            +
                    '--username', type=str, default='', help='Username for authentication'
         | 
| 923 | 
            +
                )
         | 
| 924 | 
            +
                parser.add_argument(
         | 
| 925 | 
            +
                    '--password', type=str, default='', help='Password for authentication'
         | 
| 926 | 
            +
                )
         | 
| 927 | 
            +
                parser.add_argument(
         | 
| 928 | 
            +
                    '--server_port',
         | 
| 929 | 
            +
                    type=int,
         | 
| 930 | 
            +
                    default=0,
         | 
| 931 | 
            +
                    help='Port to run the server listener on',
         | 
| 932 | 
            +
                )
         | 
| 933 | 
            +
                parser.add_argument(
         | 
| 934 | 
            +
                    '--inbrowser', action='store_true', help='Open in browser'
         | 
| 935 | 
            +
                )
         | 
| 936 | 
            +
             | 
| 937 | 
            +
                args = parser.parse_args()
         | 
| 938 | 
            +
             | 
| 939 | 
            +
                UI(
         | 
| 940 | 
            +
                    username=args.username,
         | 
| 941 | 
            +
                    password=args.password,
         | 
| 942 | 
            +
                    inbrowser=args.inbrowser,
         | 
| 943 | 
            +
                    server_port=args.server_port,
         | 
| 944 | 
            +
                )
         | 
    	
        fine_tune.py
    ADDED
    
    | @@ -0,0 +1,430 @@ | |
| 1 | 
            +
            # training with captions
         | 
| 2 | 
            +
            # XXX dropped option: hypernetwork training
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            import gc
         | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            import os
         | 
| 8 | 
            +
            import toml
         | 
| 9 | 
            +
            from multiprocessing import Value
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from tqdm import tqdm
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from accelerate.utils import set_seed
         | 
| 14 | 
            +
            import diffusers
         | 
| 15 | 
            +
            from diffusers import DDPMScheduler
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import library.train_util as train_util
         | 
| 18 | 
            +
            import library.config_util as config_util
         | 
| 19 | 
            +
            from library.config_util import (
         | 
| 20 | 
            +
                ConfigSanitizer,
         | 
| 21 | 
            +
                BlueprintGenerator,
         | 
| 22 | 
            +
            )
         | 
| 23 | 
            +
            import library.custom_train_functions as custom_train_functions
         | 
| 24 | 
            +
            from library.custom_train_functions import apply_snr_weight
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def train(args):
         | 
| 28 | 
            +
                train_util.verify_training_args(args)
         | 
| 29 | 
            +
                train_util.prepare_dataset_args(args, True)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                cache_latents = args.cache_latents
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                if args.seed is not None:
         | 
| 34 | 
            +
        set_seed(args.seed)  # initialize the random number generator
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                tokenizer = train_util.load_tokenizer(args)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
         | 
| 39 | 
            +
                if args.dataset_config is not None:
         | 
| 40 | 
            +
                    print(f"Load dataset config from {args.dataset_config}")
         | 
| 41 | 
            +
                    user_config = config_util.load_user_config(args.dataset_config)
         | 
| 42 | 
            +
                    ignored = ["train_data_dir", "in_json"]
         | 
| 43 | 
            +
                    if any(getattr(args, attr) is not None for attr in ignored):
         | 
| 44 | 
            +
                        print(
         | 
| 45 | 
            +
                            "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
         | 
| 46 | 
            +
                                ", ".join(ignored)
         | 
| 47 | 
            +
                            )
         | 
| 48 | 
            +
                        )
         | 
| 49 | 
            +
                else:
         | 
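        # no dataset config file given: build a single dataset/subset from --train_data_dir and --in_json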
| 50 | 
            +
                    user_config = {
         | 
| 51 | 
            +
                        "datasets": [
         | 
| 52 | 
            +
                            {
         | 
| 53 | 
            +
                                "subsets": [
         | 
| 54 | 
            +
                                    {
         | 
| 55 | 
            +
                                        "image_dir": args.train_data_dir,
         | 
| 56 | 
            +
                                        "metadata_file": args.in_json,
         | 
| 57 | 
            +
                                    }
         | 
| 58 | 
            +
                                ]
         | 
| 59 | 
            +
                            }
         | 
| 60 | 
            +
                        ]
         | 
| 61 | 
            +
                    }
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
         | 
| 64 | 
            +
                train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                current_epoch = Value("i", 0)
         | 
| 67 | 
            +
                current_step = Value("i", 0)
         | 
| 68 | 
            +
                ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
         | 
| 69 | 
            +
                collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                if args.debug_dataset:
         | 
| 72 | 
            +
                    train_util.debug_dataset(train_dataset_group)
         | 
| 73 | 
            +
                    return
         | 
| 74 | 
            +
                if len(train_dataset_group) == 0:
         | 
| 75 | 
            +
                    print(
         | 
| 76 | 
            +
                        "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
         | 
| 77 | 
            +
                    )
         | 
| 78 | 
            +
                    return
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                if cache_latents:
         | 
| 81 | 
            +
                    assert (
         | 
| 82 | 
            +
                        train_dataset_group.is_latent_cacheable()
         | 
| 83 | 
            +
                    ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
         | 
| 84 | 
            +
             | 
| 85 | 
            +
    # prepare the accelerator
         | 
| 86 | 
            +
                print("prepare accelerator")
         | 
| 87 | 
            +
                accelerator, unwrap_model = train_util.prepare_accelerator(args)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
    # prepare dtypes that match the mixed precision setting and cast as appropriate
         | 
| 90 | 
            +
                weight_dtype, save_dtype = train_util.prepare_dtype(args)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
    # load the models
         | 
| 93 | 
            +
                text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                # verify load/save model formats
         | 
| 96 | 
            +
                if load_stable_diffusion_format:
         | 
| 97 | 
            +
                    src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
         | 
| 98 | 
            +
                    src_diffusers_model_path = None
         | 
| 99 | 
            +
                else:
         | 
| 100 | 
            +
                    src_stable_diffusion_ckpt = None
         | 
| 101 | 
            +
                    src_diffusers_model_path = args.pretrained_model_name_or_path
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                if args.save_model_as is None:
         | 
| 104 | 
            +
                    save_stable_diffusion_format = load_stable_diffusion_format
         | 
| 105 | 
            +
                    use_safetensors = args.use_safetensors
         | 
| 106 | 
            +
                else:
         | 
| 107 | 
            +
                    save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
         | 
| 108 | 
            +
                    use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
         | 
| 109 | 
            +
             | 
| 110 | 
            +
    # function to set the xformers flag for the Diffusers model
         | 
| 111 | 
            +
                def set_diffusers_xformers_flag(model, valid):
         | 
| 112 | 
            +
        #   model.set_use_memory_efficient_attention_xformers(valid)            # likely to be removed in the next release
         | 
| 113 | 
            +
        # apparently the pipeline searches its modules recursively for set_use_memory_efficient_attention_xformers
         | 
| 114 | 
            +
        # unclear what to do when only the U-Net is used... copy the logic here as a workaround
         | 
| 115 | 
            +
                    # in 0.10.2 this reverted to setting the flag per module
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    # Recursively walk through all the children.
         | 
| 118 | 
            +
                    # Any children which exposes the set_use_memory_efficient_attention_xformers method
         | 
| 119 | 
            +
                    # gets the message
         | 
| 120 | 
            +
                    def fn_recursive_set_mem_eff(module: torch.nn.Module):
         | 
| 121 | 
            +
                        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
         | 
| 122 | 
            +
                            module.set_use_memory_efficient_attention_xformers(valid)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        for child in module.children():
         | 
| 125 | 
            +
                            fn_recursive_set_mem_eff(child)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    fn_recursive_set_mem_eff(model)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                # build xformers / memory efficient attention into the model
         | 
| 130 | 
            +
                if args.diffusers_xformers:
         | 
| 131 | 
            +
                    print("Use xformers by Diffusers")
         | 
| 132 | 
            +
                    set_diffusers_xformers_flag(unet, True)
         | 
| 133 | 
            +
                else:
         | 
| 134 | 
            +
                    # xformers on Windows cannot train in float32, so a setting that disables xformers must also be available
         | 
| 135 | 
            +
                    print("Disable Diffusers' xformers")
         | 
| 136 | 
            +
                    set_diffusers_xformers_flag(unet, False)
         | 
| 137 | 
            +
                    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                # prepare for training
         | 
| 140 | 
            +
                if cache_latents:
         | 
| 141 | 
            +
                    vae.to(accelerator.device, dtype=weight_dtype)
         | 
| 142 | 
            +
                    vae.requires_grad_(False)
         | 
| 143 | 
            +
                    vae.eval()
         | 
| 144 | 
            +
                    with torch.no_grad():
         | 
| 145 | 
            +
                        train_dataset_group.cache_latents(vae, args.vae_batch_size)
         | 
| 146 | 
            +
                    vae.to("cpu")
         | 
| 147 | 
            +
                    if torch.cuda.is_available():
         | 
| 148 | 
            +
                        torch.cuda.empty_cache()
         | 
| 149 | 
            +
                    gc.collect()
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                # prepare for training: put the models into the proper state
         | 
| 152 | 
            +
                training_models = []
         | 
| 153 | 
            +
                if args.gradient_checkpointing:
         | 
| 154 | 
            +
                    unet.enable_gradient_checkpointing()
         | 
| 155 | 
            +
                training_models.append(unet)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                if args.train_text_encoder:
         | 
| 158 | 
            +
                    print("enable text encoder training")
         | 
| 159 | 
            +
                    if args.gradient_checkpointing:
         | 
| 160 | 
            +
                        text_encoder.gradient_checkpointing_enable()
         | 
| 161 | 
            +
                    training_models.append(text_encoder)
         | 
| 162 | 
            +
                else:
         | 
| 163 | 
            +
                    text_encoder.to(accelerator.device, dtype=weight_dtype)
         | 
| 164 | 
            +
                    text_encoder.requires_grad_(False)  # the text encoder is not trained
         | 
| 165 | 
            +
                    if args.gradient_checkpointing:
         | 
| 166 | 
            +
                        text_encoder.gradient_checkpointing_enable()
         | 
| 167 | 
            +
                        text_encoder.train()  # required for gradient_checkpointing
         | 
| 168 | 
            +
                    else:
         | 
| 169 | 
            +
                        text_encoder.eval()
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                if not cache_latents:
         | 
| 172 | 
            +
                    vae.requires_grad_(False)
         | 
| 173 | 
            +
                    vae.eval()
         | 
| 174 | 
            +
                    vae.to(accelerator.device, dtype=weight_dtype)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                for m in training_models:
         | 
| 177 | 
            +
                    m.requires_grad_(True)
         | 
| 178 | 
            +
                params = []
         | 
| 179 | 
            +
                for m in training_models:
         | 
| 180 | 
            +
                    params.extend(m.parameters())
         | 
| 181 | 
            +
                params_to_optimize = params
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                # prepare the classes needed for training
         | 
| 184 | 
            +
                print("prepare optimizer, data loader etc.")
         | 
| 185 | 
            +
                _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                # prepare the dataloader
         | 
| 188 | 
            +
                # number of DataLoader worker processes: 0 means only the main process is used
         | 
| 189 | 
            +
                n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified maximum
         | 
| 190 | 
            +
                train_dataloader = torch.utils.data.DataLoader(
         | 
| 191 | 
            +
                    train_dataset_group,
         | 
| 192 | 
            +
                    batch_size=1,
         | 
| 193 | 
            +
                    shuffle=True,
         | 
| 194 | 
            +
                    collate_fn=collater,
         | 
| 195 | 
            +
                    num_workers=n_workers,
         | 
| 196 | 
            +
                    persistent_workers=args.persistent_data_loader_workers,
         | 
| 197 | 
            +
                )
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                # calculate the number of training steps
         | 
| 200 | 
            +
                if args.max_train_epochs is not None:
         | 
| 201 | 
            +
                    args.max_train_steps = args.max_train_epochs * math.ceil(
         | 
| 202 | 
            +
                        len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         | 
| 203 | 
            +
                    )
         | 
| 204 | 
            +
                    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                # also pass the number of training steps to the dataset
         | 
| 207 | 
            +
                train_dataset_group.set_max_train_steps(args.max_train_steps)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                # prepare the lr scheduler
         | 
| 210 | 
            +
                lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                # experimental feature: fp16 training including gradients; cast the whole model to fp16
         | 
| 213 | 
            +
                if args.full_fp16:
         | 
| 214 | 
            +
                    assert (
         | 
| 215 | 
            +
                        args.mixed_precision == "fp16"
         | 
| 216 | 
            +
                    ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
         | 
| 217 | 
            +
                    print("enable full fp16 training.")
         | 
| 218 | 
            +
                    unet.to(weight_dtype)
         | 
| 219 | 
            +
                    text_encoder.to(weight_dtype)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # accelerator apparently takes care of the wrapping for us
         | 
| 222 | 
            +
                if args.train_text_encoder:
         | 
| 223 | 
            +
                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         | 
| 224 | 
            +
                        unet, text_encoder, optimizer, train_dataloader, lr_scheduler
         | 
| 225 | 
            +
                    )
         | 
| 226 | 
            +
                else:
         | 
| 227 | 
            +
                    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
         | 
| 230 | 
            +
                if args.full_fp16:
         | 
| 231 | 
            +
                    train_util.patch_accelerator_for_fp16_training(accelerator)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                # resume training
         | 
| 234 | 
            +
                if args.resume is not None:
         | 
| 235 | 
            +
                    print(f"resume training from state: {args.resume}")
         | 
| 236 | 
            +
                    accelerator.load_state(args.resume)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                # calculate the number of epochs
         | 
| 239 | 
            +
                num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
         | 
| 240 | 
            +
                num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
         | 
| 241 | 
            +
                if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
         | 
| 242 | 
            +
                    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                # start training
         | 
| 245 | 
            +
                total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
         | 
| 246 | 
            +
                print("running training / 学習開始")
         | 
| 247 | 
            +
                print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
         | 
| 248 | 
            +
                print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
         | 
| 249 | 
            +
                print(f"  num epochs / epoch数: {num_train_epochs}")
         | 
| 250 | 
            +
                print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
         | 
| 251 | 
            +
                print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
         | 
| 252 | 
            +
                print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
         | 
| 253 | 
            +
                print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
         | 
| 256 | 
            +
                global_step = 0
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                noise_scheduler = DDPMScheduler(
         | 
| 259 | 
            +
                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
         | 
| 260 | 
            +
                )
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                if accelerator.is_main_process:
         | 
| 263 | 
            +
                    accelerator.init_trackers("finetuning")
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                for epoch in range(num_train_epochs):
         | 
| 266 | 
            +
                    print(f"epoch {epoch+1}/{num_train_epochs}")
         | 
| 267 | 
            +
                    current_epoch.value = epoch + 1
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    for m in training_models:
         | 
| 270 | 
            +
                        m.train()
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    loss_total = 0
         | 
| 273 | 
            +
                    for step, batch in enumerate(train_dataloader):
         | 
| 274 | 
            +
                        current_step.value = global_step
         | 
| 275 | 
            +
                        with accelerator.accumulate(training_models[0]):  # accumulate does not seem to support multiple models, so use the first one for now
         | 
| 276 | 
            +
                            with torch.no_grad():
         | 
| 277 | 
            +
                                if "latents" in batch and batch["latents"] is not None:
         | 
| 278 | 
            +
                                    latents = batch["latents"].to(accelerator.device)
         | 
| 279 | 
            +
                                else:
         | 
| 280 | 
            +
                                    # convert to latents
         | 
| 281 | 
            +
                                    latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
         | 
| 282 | 
            +
                                latents = latents * 0.18215
         | 
| 283 | 
            +
                            b_size = latents.shape[0]
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                            with torch.set_grad_enabled(args.train_text_encoder):
         | 
| 286 | 
            +
                                # Get the text embedding for conditioning
         | 
| 287 | 
            +
                                input_ids = batch["input_ids"].to(accelerator.device)
         | 
| 288 | 
            +
                                encoder_hidden_states = train_util.get_hidden_states(
         | 
| 289 | 
            +
                                    args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
         | 
| 290 | 
            +
                                )
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                            # Sample noise that we'll add to the latents
         | 
| 293 | 
            +
                            noise = torch.randn_like(latents, device=latents.device)
         | 
| 294 | 
            +
                            if args.noise_offset:
         | 
| 295 | 
            +
                                # https://www.crosslabs.org//blog/diffusion-with-offset-noise
         | 
| 296 | 
            +
                                noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                            # Sample a random timestep for each image
         | 
| 299 | 
            +
                            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
         | 
| 300 | 
            +
                            timesteps = timesteps.long()
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                            # Add noise to the latents according to the noise magnitude at each timestep
         | 
| 303 | 
            +
                            # (this is the forward diffusion process)
         | 
| 304 | 
            +
                            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                            # Predict the noise residual
         | 
| 307 | 
            +
                            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                            if args.v_parameterization:
         | 
| 310 | 
            +
                                # v-parameterization training
         | 
| 311 | 
            +
                                target = noise_scheduler.get_velocity(latents, noise, timesteps)
         | 
| 312 | 
            +
                            else:
         | 
| 313 | 
            +
                                target = noise
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                            if args.min_snr_gamma:
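                                # Min-SNR weighting (arXiv:2303.09556): weight each sample's loss by its timestep's
                                # SNR clamped at gamma, so that low-noise (high-SNR) timesteps do not dominate training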
         | 
| 316 | 
            +
                                # do not mean over batch dimension for snr weight
         | 
| 317 | 
            +
                                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
         | 
| 318 | 
            +
                                loss = loss.mean([1, 2, 3])
         | 
| 319 | 
            +
                                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
         | 
| 320 | 
            +
                                loss = loss.mean()  # mean over batch dimension
         | 
| 321 | 
            +
                            else:
         | 
| 322 | 
            +
                                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                            accelerator.backward(loss)
         | 
| 325 | 
            +
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
         | 
| 326 | 
            +
                                params_to_clip = []
         | 
| 327 | 
            +
                                for m in training_models:
         | 
| 328 | 
            +
                                    params_to_clip.extend(m.parameters())
         | 
| 329 | 
            +
                                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                            optimizer.step()
         | 
| 332 | 
            +
                            lr_scheduler.step()
         | 
| 333 | 
            +
                            optimizer.zero_grad(set_to_none=True)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                        # Checks if the accelerator has performed an optimization step behind the scenes
         | 
| 336 | 
            +
                        if accelerator.sync_gradients:
         | 
| 337 | 
            +
                            progress_bar.update(1)
         | 
| 338 | 
            +
                            global_step += 1
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                            train_util.sample_images(
         | 
| 341 | 
            +
                                accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
         | 
| 342 | 
            +
                            )
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                        current_loss = loss.detach().item()  # this is already a mean, so batch size should not matter
         | 
| 345 | 
            +
                        if args.logging_dir is not None:
         | 
| 346 | 
            +
                            logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
         | 
| 347 | 
            +
                            if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
         | 
| 348 | 
            +
                                logs["lr/d*lr"] = (
         | 
| 349 | 
            +
                                    lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
         | 
| 350 | 
            +
                                )
         | 
| 351 | 
            +
                            accelerator.log(logs, step=global_step)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                        # TODO: use a moving average
         | 
| 354 | 
            +
                        loss_total += current_loss
         | 
| 355 | 
            +
                        avr_loss = loss_total / (step + 1)
         | 
| 356 | 
            +
                        logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
         | 
| 357 | 
            +
                        progress_bar.set_postfix(**logs)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                        if global_step >= args.max_train_steps:
         | 
| 360 | 
            +
                            break
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    if args.logging_dir is not None:
         | 
| 363 | 
            +
                        logs = {"loss/epoch": loss_total / len(train_dataloader)}
         | 
| 364 | 
            +
                        accelerator.log(logs, step=epoch + 1)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    accelerator.wait_for_everyone()
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    if args.save_every_n_epochs is not None:
         | 
| 369 | 
            +
                        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
         | 
| 370 | 
            +
                        train_util.save_sd_model_on_epoch_end(
         | 
| 371 | 
            +
                            args,
         | 
| 372 | 
            +
                            accelerator,
         | 
| 373 | 
            +
                            src_path,
         | 
| 374 | 
            +
                            save_stable_diffusion_format,
         | 
| 375 | 
            +
                            use_safetensors,
         | 
| 376 | 
            +
                            save_dtype,
         | 
| 377 | 
            +
                            epoch,
         | 
| 378 | 
            +
                            num_train_epochs,
         | 
| 379 | 
            +
                            global_step,
         | 
| 380 | 
            +
                            unwrap_model(text_encoder),
         | 
| 381 | 
            +
                            unwrap_model(unet),
         | 
| 382 | 
            +
                            vae,
         | 
| 383 | 
            +
                        )
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                is_main_process = accelerator.is_main_process
         | 
| 388 | 
            +
                if is_main_process:
         | 
| 389 | 
            +
                    unet = unwrap_model(unet)
         | 
| 390 | 
            +
                    text_encoder = unwrap_model(text_encoder)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                accelerator.end_training()
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                if args.save_state:
         | 
| 395 | 
            +
                    train_util.save_state_on_train_end(args, accelerator)
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                del accelerator  # delete this because memory is needed afterwards
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                if is_main_process:
         | 
| 400 | 
            +
                    src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
         | 
| 401 | 
            +
                    train_util.save_sd_model_on_train_end(
         | 
| 402 | 
            +
                        args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         | 
| 403 | 
            +
                    )
         | 
| 404 | 
            +
                    print("model saved.")
         | 
| 405 | 
            +
             | 
| 406 | 
            +
             | 
| 407 | 
            +
            def setup_parser() -> argparse.ArgumentParser:
         | 
| 408 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                train_util.add_sd_models_arguments(parser)
         | 
| 411 | 
            +
                train_util.add_dataset_arguments(parser, False, True, True)
         | 
| 412 | 
            +
                train_util.add_training_arguments(parser, False)
         | 
| 413 | 
            +
                train_util.add_sd_saving_arguments(parser)
         | 
| 414 | 
            +
                train_util.add_optimizer_arguments(parser)
         | 
| 415 | 
            +
                config_util.add_config_arguments(parser)
         | 
| 416 | 
            +
                custom_train_functions.add_custom_train_arguments(parser)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
         | 
| 419 | 
            +
                parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                return parser
         | 
| 422 | 
            +
             | 
| 423 | 
            +
             | 
| 424 | 
            +
            if __name__ == "__main__":
         | 
| 425 | 
            +
                parser = setup_parser()
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                args = parser.parse_args()
         | 
| 428 | 
            +
                args = train_util.read_config_from_file(args, parser)
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                train(args)
         | 
    	
        fine_tune_README.md
    ADDED
    
    | @@ -0,0 +1,465 @@ | |
| 1 | 
            +
            This is fine tuning that supports NovelAI's proposed training method, automatic captioning, automatic tagging, a Windows + 12GB VRAM environment (for v1.4/1.5), and more.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            ## Overview
         | 
| 4 | 
            +
            Fine tuning of Stable Diffusion's U-Net using Diffusers. It supports the following improvements from NovelAI's article (for Aspect Ratio Bucketing I referred to NovelAI's code, but the final code is entirely original).
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            * Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder).
         | 
| 7 | 
            +
            * Learning at non-square resolutions (Aspect Ratio Bucketing).
         | 
| 8 | 
            +
            * Extend token length from 75 to 225.
         | 
| 9 | 
            +
            * Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger.
         | 
| 10 | 
            +
            * Also supports Hypernetwork learning.
         | 
| 11 | 
            +
            * Supports Stable Diffusion v2.0 (base and 768/v).
         | 
| 12 | 
            +
            * By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            The Text Encoder is not trained by default. When fine tuning the whole model, it seems common to train only the U-Net (NovelAI appears to do the same). The Text Encoder can optionally be trained as well.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            ## Additional features
         | 
| 17 | 
            +
            ### Change CLIP output
         | 
| 18 | 
            +
            CLIP (Text Encoder) converts the text into features so that the prompt is reflected in the image. Stable Diffusion uses the output of the last layer of CLIP, but this can be changed to use the output of the penultimate (second-to-last) layer instead. According to NovelAI, this reflects prompts more accurately.
         | 
| 19 | 
            +
            It is also possible to use the output of the last layer as is.
         | 
| 20 | 
            +
            *Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option.
         | 
| 21 | 
            +
             | 
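A rough sketch of what selecting the penultimate layer means, using the standard Hugging Face transformers CLIP classes (illustration only, not the script's own code; clip_skip=2 corresponds to this penultimate layer):

```
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a drawing of a girl", return_tensors="pt", padding="max_length",
                   max_length=tokenizer.model_max_length, truncation=True)
with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

last_layer = out.last_hidden_state                  # the default v1 behavior
penultimate = out.hidden_states[-2]                 # output of the second-to-last layer
penultimate = text_encoder.text_model.final_layer_norm(penultimate)  # usually re-normalized before use
```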
| 22 | 
            +
            ### Training in non-square resolutions
         | 
| 23 | 
            +
            Stable Diffusion is trained at 512\*512, but here we also train at resolutions such as 256\*1024 and 384\*640. This is expected to reduce the amount of cropping and let the model learn the relationship between prompts and images more correctly.
         | 
| 24 | 
            +
            The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            In machine learning it is common to unify all input sizes, but there is no strict requirement; it is enough that sizes are unified within each batch. NovelAI's bucketing appears to mean classifying the training data in advance into buckets of training resolutions according to aspect ratio, and then building each batch from the images in one bucket so that the image size within a batch is unified.
         | 
| 27 | 
            +
             | 
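A toy sketch of the bucketing idea (illustration only; the actual implementation follows NovelAI's scheme and the min/max bucket options described later): enumerate resolutions in 64-pixel steps whose area stays within the budget, then assign each image to the bucket with the closest aspect ratio.

```
def make_buckets(max_width=512, max_height=512, min_size=256, max_size=1024, step=64):
    # all (width, height) pairs in 64-pixel steps whose area fits within max_width*max_height
    max_area = max_width * max_height
    buckets = set()
    width = min_size
    while width <= max_size:
        height = min(max_size, (max_area // width) // step * step)
        if height >= min_size:
            buckets.add((width, height))
            buckets.add((height, width))
        width += step
    return sorted(buckets)

def assign_bucket(image_width, image_height, buckets):
    # pick the bucket whose aspect ratio is closest to the image's aspect ratio
    aspect = image_width / image_height
    return min(buckets, key=lambda wh: abs(wh[0] / wh[1] - aspect))

buckets = make_buckets()
print(buckets)                            # includes (256, 1024), (384, 640), (512, 512), ...
print(assign_bucket(1200, 800, buckets))  # -> (640, 384) for a 3:2 landscape image
```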
| 28 | 
            +
            ### Extending token length from 75 to 225
         | 
| 29 | 
            +
            Stable Diffusion has a maximum of 75 tokens (77 tokens including the start and end tokens), but we extend this to 225 tokens.
         | 
| 30 | 
            +
            However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens, we simply divide it into thirds, call CLIP, and then concatenate the results.
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            *I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently.
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            *Automatic1111's Web UI seems to divide the text with commas in mind, but in my case, it's a simple division.
         | 
| 35 | 
            +
             | 
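A rough sketch of the splitting described above (illustration only; the actual script also handles the duplicated start/end tokens and the penultimate-layer option): the 225 prompt tokens are split into three 75-token chunks, each chunk is wrapped with the start/end tokens and encoded by CLIP separately, and the three outputs are concatenated.

```
import torch

def encode_long_prompt(input_ids, text_encoder):
    # input_ids: (batch, 227) = start token + 225 prompt tokens + end token
    bos, eos = input_ids[:, :1], input_ids[:, -1:]
    chunks = input_ids[:, 1:-1].chunk(3, dim=1)      # three chunks of 75 tokens each
    states = []
    for chunk in chunks:
        ids = torch.cat([bos, chunk, eos], dim=1)    # 77 tokens: the length CLIP accepts
        states.append(text_encoder(ids)[0])          # (batch, 77, hidden_size)
    return torch.cat(states, dim=1)                  # (batch, 231, hidden_size)
```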
| 36 | 
            +
            ## Environment setup
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            See the [README](./README-en.md) in this repository.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            ## Preparing teacher data
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            Prepare the images you want to train on and put them in any folder. No prior preparation such as resizing is required.
         | 
| 43 | 
            +
            However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining the quality using super-resolution.
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            It also supports multiple teacher data folders. Preprocessing will be executed for each folder.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            For example, store the images like this:
         | 
| 48 | 
            +
             | 
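A layout along these lines (folder and file names are only placeholders):

```
train_data/
  img0001.png
  img0002.jpg
  img0003.png
  ...
```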
| 49 | 
            +
            
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            ## Automatic captioning
         | 
| 52 | 
            +
            Skip this if you only want to train with tags and do not need captions.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            If you prepare captions manually, put them in the same directory as the teacher data images, with the same file name and an extension such as .caption. Each file should be a text file containing a single line.
         | 
| 55 | 
            +
             | 
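For example (hypothetical file names), an image and its hand-written caption sit side by side, and the .caption file contains one line of text:

```
train_data/
  img0001.png
  img0001.caption    <- e.g. "a girl standing in a garden"
```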
| 56 | 
            +
            ### Captioning with BLIP
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            The latest version no longer requires cloning BLIP, downloading weights, or an additional virtual environment; it works as-is.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            Run make_captions.py in the finetune folder.
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            ```
         | 
| 63 | 
            +
            python finetune\make_captions.py --batch_size <batch size> <teacher data folder>
         | 
| 64 | 
            +
            ```
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            If the batch size is 8 and the training data is placed in the parent folder train_data, it will be as follows.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            ```
         | 
| 69 | 
            +
            python finetune\make_captions.py --batch_size 8 ..\train_data
         | 
| 70 | 
            +
            ```
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            A caption file is created in the same directory as the teacher data image with the same file name and extension .caption.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (with 12GB of VRAM you can probably go a little higher).
         | 
| 75 | 
            +
            You can specify the maximum length of the caption with the max_length option. The default is 75. It may be worth making it longer if you train the model with a token length of 225.
         | 
| 76 | 
            +
            You can change the caption extension with the caption_extension option. Default is .caption (.txt conflicts with DeepDanbooru described later).
         | 
| 77 | 
            +
             | 
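For instance, combining these options (the values are only an example):

```
python finetune\make_captions.py --batch_size 8 --max_length 150 --caption_extension .caption --seed 42 ..\train_data
```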
| 78 | 
            +
            If there are multiple teacher data folders, execute for each folder.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            Note that the inference is random, so the results will change each time you run it. If you want to fix it, specify a random number seed like "--seed 42" with the --seed option.
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source).
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            A caption file is generated with the extension .caption by default.
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            For example, with captions like:
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            ## Tagging with DeepDanbooru
         | 
| 93 | 
            +
            If you do not want to use danbooru-style tags, skip ahead to "Preprocessing caption and tag information".
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter.
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            ### Environment setup
         | 
| 98 | 
            +
            Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it.
         | 
| 99 | 
            +
            Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder.
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            Download from below. Click to open Assets and download from there.
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            Make a directory structure like this:
         | 
| 106 | 
            +
             | 
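Roughly like this (only the pretrained model folder name is fixed; the other names are placeholders):

```
working_folder/
  DeepDanbooru/
    deepdanbooru-v3-20211112-sgd-e28/   <- extracted pretrained model
    requirements.txt
    ...
```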
| 107 | 
            +
            
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io).
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            ```
         | 
| 112 | 
            +
            pip install -r requirements.txt
         | 
| 113 | 
            +
            ```
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            Next, install DeepDanbooru itself.
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            ```
         | 
| 118 | 
            +
            pip install .
         | 
| 119 | 
            +
            ```
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            This completes the preparation of the environment for tagging.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            ### Running the tagging
         | 
| 124 | 
            +
            Go to DeepDanbooru's folder and run deepdanbooru to tag.
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            ```
         | 
| 127 | 
            +
            deepdanbooru evaluate <teacher data folder> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
         | 
| 128 | 
            +
            ```
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            If you put the training data in the parent folder train_data, it will be as follows.
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            ```
         | 
| 133 | 
            +
            deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
         | 
| 134 | 
            +
            ```
         | 
| 135 | 
            +
             | 
| 136 | 
            +
            A tag file is created in the same directory as each teacher data image, with the same file name and the extension .txt. It is slow because images are processed one at a time.
         | 
| 137 | 
            +
             | 
| 138 | 
            +
            If there are multiple teacher data folders, execute for each folder.
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            It is generated as follows.
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            A tag is attached like this (great amount of information...).
         | 
| 145 | 
            +
             | 
| 146 | 
            +
            
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            ## Tagging with WD14Tagger
         | 
| 149 | 
            +
            This procedure uses WD14Tagger instead of DeepDanbooru.
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            Use the tagger from Automatic1111's WebUI. I referred to the information on this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            The required modules were already installed during the initial environment setup. The weights are downloaded automatically from Hugging Face.
         | 
| 154 | 
            +
             | 
| 155 | 
            +
            ### Running the tagging
         | 
| 156 | 
            +
            Run the script to do the tagging.
         | 
| 157 | 
            +
            ```
         | 
| 158 | 
            +
            python tag_images_by_wd14_tagger.py --batch_size <batch size> <teacher data folder>
         | 
| 159 | 
            +
            ```
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            If you put the training data in the parent folder train_data, it will be as follows.
         | 
| 162 | 
            +
            ```
         | 
| 163 | 
            +
            python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
         | 
| 164 | 
            +
            ```
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows.
         | 
| 167 | 
            +
             | 
| 168 | 
            +
            
         | 
| 169 | 
            +
             | 
| 170 | 
            +
            A tag file is created in the same directory as the teacher data image with the same file name and extension .txt.
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            
         | 
| 173 | 
            +
             | 
| 174 | 
            +
            
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            With the thresh option you can specify the confidence threshold a tag must reach in order to be attached. The default is 0.35, the same as the WD14Tagger sample. Lower values attach more tags but with lower accuracy.
         | 
| 177 | 
            +
            Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). You can change the tag file extension with the caption_extension option. Default is .txt.
         | 
| 178 | 
            +
            You can specify the folder where the model is saved with the model_dir option.
         | 
| 179 | 
            +
            Also, if you specify the force_download option, the model is re-downloaded even if the destination folder already exists.
         | 
| 180 | 
            +
             | 
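For instance, combining these options (the values are only an example):

```
python tag_images_by_wd14_tagger.py --batch_size 4 --thresh 0.5 --caption_extension .txt --model_dir wd14_tagger_model ..\train_data
```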
| 181 | 
            +
            If there are multiple teacher data folders, execute for each folder.
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            ## Preprocessing caption and tag information
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            Combine captions and tags into a single file as metadata for easy processing from scripts.
         | 
| 186 | 
            +
             | 
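Conceptually the metadata is a JSON dictionary keyed by image (file name without extension, or the full path when --full_path is used), holding the caption and tags for each image; roughly along these lines (the contents shown are only illustrative):

```
{
  "img0001": {
    "caption": "a girl standing in a garden",
    "tags": "1girl, outdoors, smile"
  }
}
```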
| 187 | 
            +
            ### Caption preprocessing
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            To put the captions into the metadata, run the following in your working folder (if you do not use captions for training, you do not need to run this). The command is actually written on a single line; the same applies to the commands below.
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            ```
         | 
| 192 | 
            +
            python merge_captions_to_metadata.py <teacher data folder>
         | 
| 193 | 
            +
            --in_json <metadata file name to read>
         | 
| 194 | 
            +
                 <metadata file name>
         | 
| 195 | 
            +
            ```
         | 
| 196 | 
            +
             | 
| 197 | 
            +
            The metadata file name can be anything you like.
         | 
| 198 | 
            +
            If the training data is in train_data, there is no existing metadata file to read, and the metadata file to write is meta_cap.json, the command is as follows.
         | 
| 199 | 
            +
             | 
| 200 | 
            +
            ```
         | 
| 201 | 
            +
            python merge_captions_to_metadata.py train_data meta_cap.json
         | 
| 202 | 
            +
            ```
         | 
| 203 | 
            +
             | 
| 204 | 
            +
            You can specify the caption extension with the caption_extension option.
         | 
| 205 | 
            +
             | 
| 206 | 
            +
            If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder.
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            ```
         | 
| 209 | 
            +
            python merge_captions_to_metadata.py --full_path
         | 
| 210 | 
            +
                 train_data1 meta_cap1.json
         | 
| 211 | 
            +
            python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
         | 
| 212 | 
            +
                 train_data2 meta_cap2.json
         | 
| 213 | 
            +
            ```
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            If in_json is omitted and the destination metadata file already exists, it is read from and then overwritten in place.
         | 
| 216 | 
            +
             | 
| 217 | 
            +
            __*It is safer to change the in_json option and the destination each time, writing to a separate metadata file.__
         | 
| 218 | 
            +
             | 
| 219 | 
            +
            ### Tag preprocessing
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning).
         | 
| 222 | 
            +
            ```
         | 
| 223 | 
            +
            python merge_dd_tags_to_metadata.py <teacher data folder>
         | 
| 224 | 
            +
                 --in_json <metadata file name to load>
         | 
| 225 | 
            +
                 <metadata file name to write>
         | 
| 226 | 
            +
            ```
         | 
| 227 | 
            +
             | 
| 228 | 
            +
            With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows.
         | 
| 229 | 
            +
            ```
         | 
| 230 | 
            +
            python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json
         | 
| 231 | 
            +
            ```
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
         | 
| 234 | 
            +
             | 
| 235 | 
            +
            ```
         | 
| 236 | 
            +
            python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
         | 
| 237 | 
            +
                 train_data1 meta_cap_dd1.json
         | 
| 238 | 
            +
            python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
         | 
| 239 | 
            +
                 train_data2 meta_cap_dd2.json
         | 
| 240 | 
            +
            ```
         | 
| 241 | 
            +
             | 
| 242 | 
            +
            If in_json is omitted and the destination metadata file already exists, it is read from and then overwritten in place.
         | 
| 243 | 
            +
             | 
| 244 | 
            +
            __*It is safer to change the in_json option and the destination each time, writing to a separate metadata file.__
         | 
| 245 | 
            +
             | 
| 246 | 
            +
            ### Cleaning captions and tags
         | 
| 247 | 
            +
            Up to this point, captions and DeepDanbooru tags have been collected into the metadata file. However, automatically generated captions have problems such as spelling variations (*), and tags contain underscores and ratings (in the case of DeepDanbooru), so you should clean the captions and tags using your editor's search-and-replace function or similar.
         | 
| 248 | 
            +
             | 
| 249 | 
            +
            *For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl".
         | 
| 250 | 
            +
             | 
| 251 | 
            +
            A script for cleaning is provided, so please edit the contents of the script according to the situation and use it.
         | 
| 252 | 
            +
             | 
| 253 | 
            +
            (It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
            ```
         | 
| 256 | 
            +
            python clean_captions_and_tags.py <metadata file name to read> <metadata file name to write>
         | 
| 257 | 
            +
            ```
         | 
| 258 | 
            +
             | 
| 259 | 
            +
            Please note that --in_json is not included. For example:
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            ```
         | 
| 262 | 
            +
            python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
         | 
| 263 | 
            +
            ```
         | 
| 264 | 
            +
             | 
| 265 | 
            +
            Preprocessing of captions and tags is now complete.
         | 
| 266 | 
            +
             | 
| 267 | 
            +
            ## Get latents in advance
         | 
| 268 | 
            +
             | 
| 269 | 
            +
            In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed.
         | 
| 270 | 
            +
             | 
| 271 | 
            +
            In your working folder, type:
         | 
| 272 | 
            +
            ```
         | 
| 273 | 
            +
            python prepare_buckets_latents.py <teacher data folder>
         | 
| 274 | 
            +
                 <metadata file name to read> <metadata file name to write>
         | 
| 275 | 
            +
                 <model name or checkpoint for fine tuning>
         | 
| 276 | 
            +
                 --batch_size <batch size>
         | 
| 277 | 
            +
                 --max_resolution <resolution width, height>
         | 
| 278 | 
            +
                 --mixed_precision <precision>
         | 
| 279 | 
            +
            ```
         | 
| 280 | 
            +
             | 
| 281 | 
            +
            If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json:
         | 
| 282 | 
            +
             | 
| 283 | 
            +
            ```
         | 
| 284 | 
            +
            python prepare_buckets_latents.py
         | 
| 285 | 
            +
                 train_data meta_clean.json meta_lat.json model.ckpt
         | 
| 286 | 
            +
                 --batch_size 4 --max_resolution 512,512 --mixed_precision no
         | 
| 287 | 
            +
            ```
         | 
| 288 | 
            +
             | 
| 289 | 
            +
            Latents are saved in numpy npz format in the teacher data folder.
         | 
| 290 | 
            +
             | 
| 291 | 
            +
            Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required).
         | 
| 292 | 
            +
             | 
| 293 | 
            +
            You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will not use resolutions such as 256\*1024 or 320\*768.
         | 
| 294 | 
            +
            If you increase the resolution to something like 768\*768, you should specify something like 1280 for the maximum size.
         | 
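For example, a call that narrows the bucket range for 768\*768 training might look like the following (a sketch based on the command above; the 384 and 1280 values are only illustrative):

```
python prepare_buckets_latents.py
    train_data meta_clean.json meta_lat.json model.ckpt
    --batch_size 4 --max_resolution 768,768 --mixed_precision no
    --min_bucket_reso 384 --max_bucket_reso 1280
```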
| 295 | 
            +
             | 
| 296 | 
            +
If you specify the --flip_aug option, horizontal flip augmentation (data augmentation) is performed. This can artificially double the amount of data, but if it is specified for data that is not left-right symmetrical (for example, character appearance or hairstyle), training will not go well.
         | 
| 297 | 
            +
(This is a simple implementation that obtains the latents for the flipped image and saves them as a \*\_flip.npz file. No extra options are needed for fine_tune.py; when a \_flip file exists, the files with and without \_flip are loaded at random.)
         | 
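For example, to enable flip augmentation, add the option to the command shown earlier (only do this for left-right symmetrical data):

```
python prepare_buckets_latents.py
    train_data meta_clean.json meta_lat.json model.ckpt
    --batch_size 4 --max_resolution 512,512 --mixed_precision no
    --flip_aug
```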
| 298 | 
            +
             | 
| 299 | 
            +
            The batch size may be increased a little more even with 12GB of VRAM.
         | 
| 300 | 
            +
The resolution must be a number divisible by 64 and is specified as "width,height". The resolution is directly linked to memory usage during fine tuning. 512,512 seems to be the limit with 12GB of VRAM (*). With 16GB it may be possible to raise it to 512,704 or 512,768. Even at 256,256 and the like, 8GB of VRAM seems to be difficult (because parameters, optimizer state, and so on require a certain amount of memory regardless of resolution).
         | 
| 301 | 
            +
             | 
| 302 | 
            +
* There was also a report that training with batch size 1 worked at 640,640 on 12GB of VRAM.
         | 
| 303 | 
            +
             | 
| 304 | 
            +
            The result of bucketing is displayed as follows.
         | 
| 305 | 
            +
             | 
| 306 | 
            +
            
         | 
| 307 | 
            +
             | 
| 308 | 
            +
            If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
         | 
| 309 | 
            +
            ```
         | 
| 310 | 
            +
            python prepare_buckets_latents.py --full_path
         | 
| 311 | 
            +
                 train_data1 meta_clean.json meta_lat1.json model.ckpt
         | 
| 312 | 
            +
                 --batch_size 4 --max_resolution 512,512 --mixed_precision no
         | 
| 313 | 
            +
             | 
| 314 | 
            +
            python prepare_buckets_latents.py --full_path
         | 
| 315 | 
            +
                 train_data2 meta_lat1.json meta_lat2.json model.ckpt
         | 
| 316 | 
            +
                 --batch_size 4 --max_resolution 512,512 --mixed_precision no
         | 
| 317 | 
            +
             | 
| 318 | 
            +
            ```
         | 
| 319 | 
            +
            It is possible to make the read source and write destination the same, but separate is safer.
         | 
| 320 | 
            +
             | 
| 321 | 
            +
__* It is safest to change the arguments each time and write to a separate metadata file.__
         | 
| 322 | 
            +
             | 
| 323 | 
            +
             | 
| 324 | 
            +
            ## Run training
         | 
| 325 | 
            +
Run the training, for example, as follows. The settings below are aimed at saving memory.
         | 
| 326 | 
            +
            ```
         | 
| 327 | 
            +
            accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
         | 
| 328 | 
            +
                 --pretrained_model_name_or_path=model.ckpt
         | 
| 329 | 
            +
                 --in_json meta_lat.json
         | 
| 330 | 
            +
                 --train_data_dir=train_data
         | 
| 331 | 
            +
                 --output_dir=fine_tuned
         | 
| 332 | 
            +
                 --shuffle_caption
         | 
| 333 | 
            +
                 --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
         | 
| 334 | 
            +
                 --use_8bit_adam --xformers --gradient_checkpointing
         | 
| 335 | 
            +
                 --mixed_precision=bf16
         | 
| 336 | 
            +
                 --save_every_n_epochs=4
         | 
| 337 | 
            +
            ```
         | 
| 338 | 
            +
             | 
| 339 | 
            +
It seems best to set accelerate's num_cpu_threads_per_process to the number of CPU cores.
         | 
| 340 | 
            +
             | 
| 341 | 
            +
            Specify the model to be trained in pretrained_model_name_or_path (Stable Diffusion checkpoint or Diffusers model). Stable Diffusion checkpoint supports .ckpt and .safetensors (automatically determined by extension).
         | 
| 342 | 
            +
             | 
| 343 | 
            +
Specify the metadata file created when the latents were cached with in_json.
         | 
| 344 | 
            +
             | 
| 345 | 
            +
            Specify the training data folder for train_data_dir and the output destination folder for the trained model for output_dir.
         | 
| 346 | 
            +
             | 
| 347 | 
            +
            If shuffle_caption is specified, captions and tags are shuffled and learned in units separated by commas (this is the method used in Waifu Diffusion v1.3).
         | 
| 348 | 
            +
            (You can keep some of the leading tokens fixed without shuffling. See keep_tokens for other options.)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
            Specify the batch size in train_batch_size. Specify 1 or 2 for VRAM 12GB. The number that can be specified also changes depending on the resolution.
         | 
| 351 | 
            +
            The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly.
         | 
| 352 | 
            +
             | 
| 353 | 
            +
            Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6.
         | 
| 354 | 
            +
            Specify the number of steps in max_train_steps.
         | 
| 355 | 
            +
             | 
| 356 | 
            +
            Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease.
         | 
| 357 | 
            +
             | 
| 358 | 
            +
            Specifying xformers replaces CrossAttention to save memory and speed up.
         | 
| 359 | 
            +
            * As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers).
         | 
| 360 | 
            +
             | 
| 361 | 
            +
gradient_checkpointing enables gradient checkpointing (intermediate results are recomputed instead of being kept in memory). It is slower, but uses less memory.
         | 
| 362 | 
            +
             | 
| 363 | 
            +
            Specifies whether to use mixed precision with mixed_precision. Specifying "fp16" or "bf16" saves memory, but accuracy is inferior.
         | 
| 364 | 
            +
            "fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried).
         | 
| 365 | 
            +
            If "no" is specified, it will not be used (it will be float32).
         | 
| 366 | 
            +
             | 
| 367 | 
            +
* It seems that an error occurs when loading checkpoints trained with bf16 in AUTOMATIC1111's Web UI. This appears to be because the bfloat16 data type triggers an error in the Web UI's model safety checker. Save in fp16 or float32 format with the save_precision option, or save in safetensors format.
         | 
| 368 | 
            +
             | 
| 369 | 
            +
            Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed.
         | 
| 370 | 
            +
             | 
| 371 | 
            +
            ### Supports Stable Diffusion 2.0
         | 
| 372 | 
            +
Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt.
         | 
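For example, a command line for a 768-v model might look like the following (a sketch based on the memory-saving example above, with only the model and the two options changed):

```
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
    --v2 --v_parameterization
    --pretrained_model_name_or_path=768-v-ema.ckpt
    --in_json meta_lat.json
    --train_data_dir=train_data --output_dir=fine_tuned
    --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
    --use_8bit_adam --xformers --gradient_checkpointing
    --mixed_precision=bf16
```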
| 373 | 
            +
             | 
| 374 | 
            +
            ### Increase accuracy and speed when memory is available
         | 
| 375 | 
            +
First, removing gradient_checkpointing speeds things up. However, the batch size that can be set becomes smaller, so choose settings while balancing accuracy and speed.
         | 
| 376 | 
            +
             | 
| 377 | 
            +
Increasing the batch size improves speed and accuracy. Increase it within the range that memory allows, while checking the speed per data item (speed may actually drop when memory is right at the limit).
         | 
| 378 | 
            +
             | 
| 379 | 
            +
            ### Change CLIP output used
         | 
| 380 | 
            +
Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 is specified or the option is omitted, the last layer is used.
         | 
| 381 | 
            +
The trained model should be usable for inference in AUTOMATIC1111's Web UI.
         | 
| 382 | 
            +
             | 
| 383 | 
            +
* SD2.0 uses the second-from-last layer by default, so do not specify this option when training SD2.0.
         | 
| 384 | 
            +
             | 
| 385 | 
            +
            If the model being trained was originally trained to use the second layer, 2 is a good value.
         | 
| 386 | 
            +
             | 
| 387 | 
            +
If the model was instead trained using the last layer, the entire model was trained on that assumption. In that case, training again with the second layer may require a certain amount of teacher data and a longer training time to obtain the desired result.
         | 
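For example, to train with the next-to-last layer, add the following to the fine_tune.py command line:

```
--clip_skip=2
```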
| 388 | 
            +
             | 
| 389 | 
            +
            ### Extending Token Length
         | 
| 390 | 
            +
You can train with an extended token length by specifying 150 or 225 for max_token_length.
         | 
| 391 | 
            +
The trained model should be usable for inference in AUTOMATIC1111's Web UI.
         | 
| 392 | 
            +
             | 
| 393 | 
            +
            As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time.
         | 
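For example, to train with the longest supported length, add the following to the command line (150 is the other supported value):

```
--max_token_length=225
```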
| 394 | 
            +
             | 
| 395 | 
            +
            ### Save learning log
         | 
| 396 | 
            +
            Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved.
         | 
| 397 | 
            +
             | 
| 398 | 
            +
            For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder.
         | 
| 399 | 
            +
            Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=fine_tune_style1" for identification.
         | 
| 400 | 
            +
             | 
| 401 | 
            +
To check the logs with TensorBoard, open another command prompt and enter the following in the working folder (tensorboard should be installed together with Diffusers; if it is not, install it with pip install tensorboard).
         | 
| 402 | 
            +
            ```
         | 
| 403 | 
            +
            tensorboard --logdir=logs
         | 
| 404 | 
            +
            ```
         | 
| 405 | 
            +
             | 
| 406 | 
            +
            ### Learning Hypernetworks
         | 
| 407 | 
            +
            It will be explained in another article.
         | 
| 408 | 
            +
             | 
| 409 | 
            +
            ### Learning with fp16 gradient (experimental feature)
         | 
| 410 | 
            +
The full_fp16 option changes the gradients from the usual float32 to float16 (fp16) during training (this appears to be full fp16 training rather than mixed precision). As a result, it seems that SD1.x at 512\*512 can be trained with less than 8GB of VRAM, and SD2.x at 512\*512 with less than 12GB of VRAM.
         | 
| 411 | 
            +
             | 
| 412 | 
            +
            Specify fp16 in advance in accelerate config and optionally set mixed_precision="fp16" (does not work with bf16).
         | 
| 413 | 
            +
             | 
| 414 | 
            +
            To minimize memory usage, use the xformers, use_8bit_adam, gradient_checkpointing options and set train_batch_size to 1.
         | 
| 415 | 
            +
            (If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.)
         | 
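A minimal sketch of such a command line, reusing the file names from the earlier example, might be:

```
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
    --pretrained_model_name_or_path=model.ckpt
    --in_json meta_lat.json
    --train_data_dir=train_data --output_dir=fine_tuned
    --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
    --use_8bit_adam --xformers --gradient_checkpointing
    --mixed_precision=fp16 --full_fp16
```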
| 416 | 
            +
             | 
| 417 | 
            +
This is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably, and the probability of training failing partway through also increases. The learning rate and the number of steps also seem to be sensitive. Be aware of these points and use it at your own risk.
         | 
| 418 | 
            +
             | 
| 419 | 
            +
            ### Other Options
         | 
| 420 | 
            +
             | 
| 421 | 
            +
            #### keep_tokens
         | 
| 422 | 
            +
            If a number is specified, the specified number of tokens (comma-separated strings) from the beginning of the caption are fixed without being shuffled.
         | 
| 423 | 
            +
             | 
| 424 | 
            +
If there are both captions and tags, the prompt during training is concatenated like "caption, tag 1, tag 2...", so if you set "--keep_tokens=1", the caption always stays at the beginning during training.
         | 
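For example, to keep the leading caption fixed while shuffling the remaining tags:

```
--shuffle_caption --keep_tokens=1
```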
| 425 | 
            +
             | 
| 426 | 
            +
            #### dataset_repeats
         | 
| 427 | 
            +
If the dataset is extremely small, each epoch ends very quickly (and there is some overhead at each epoch boundary), so specify a number to repeat the data that many times and make the epochs longer.
         | 
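For example, to treat each image as if it appeared 10 times per epoch (10 is only an illustrative value):

```
--dataset_repeats=10
```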
| 428 | 
            +
             | 
| 429 | 
            +
            #### train_text_encoder
         | 
| 430 | 
            +
The Text Encoder is also trained. Memory usage increases slightly.
         | 
| 431 | 
            +
             | 
| 432 | 
            +
In normal fine tuning, the Text Encoder is not trained (probably because U-Net is trained to follow the Text Encoder's output), but when the amount of training data is small, training the Text Encoder as in DreamBooth also seems to be effective.
         | 
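For example, simply add the flag to the training command:

```
--train_text_encoder
```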
| 433 | 
            +
             | 
| 434 | 
            +
            #### save_precision
         | 
| 435 | 
            +
The data format used when saving checkpoints can be chosen from float, fp16, and bf16 (if not specified, the same format as during training is used). fp16 and bf16 save disk space, but the model's generation results will change somewhat. If you specify float or fp16, the checkpoint should be readable in AUTOMATIC1111's Web UI.
         | 
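For example, to save checkpoints in fp16 regardless of the training precision:

```
--save_precision=fp16
```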
| 436 | 
            +
             | 
| 437 | 
            +
* The VAE keeps the data format of the original checkpoint, so even with fp16 the model size may not shrink to just over 2GB.
         | 
| 438 | 
            +
             | 
| 439 | 
            +
            #### save_model_as
         | 
| 440 | 
            +
            Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors.
         | 
| 441 | 
            +
             | 
| 442 | 
            +
When loading the Stable Diffusion format (ckpt or safetensors) and saving in the Diffusers format, missing information is supplemented by downloading the v1.5 or v2.1 information from Hugging Face.
         | 
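For example, to save the trained model in safetensors format:

```
--save_model_as=safetensors
```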
| 443 | 
            +
             | 
| 444 | 
            +
            #### use_safetensors
         | 
| 445 | 
            +
Specifying this option saves checkpoints in safetensors format. The save format itself remains the default (the same format as the one loaded).
         | 
| 446 | 
            +
             | 
| 447 | 
            +
            #### save_state and resume
         | 
| 448 | 
            +
The save_state option saves the training state (the optimizer state, etc.) to a folder in addition to the checkpoint, both at intermediate saves and at the final save. This avoids a drop in accuracy when training is resumed after an interruption (the optimizer keeps internal state while optimizing, so resetting that state means optimizing again from the initial state). Note that the number of steps is not saved due to Accelerate's specifications.
         | 
| 449 | 
            +
             | 
| 450 | 
            +
            When starting the script, you can resume by specifying the folder where the state is saved with the resume option.
         | 
| 451 | 
            +
             | 
| 452 | 
            +
Note that the saved training state is about 5 GB per save, so keep an eye on your disk capacity.
         | 
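A sketch of the two halves of the workflow (the "..." stands for the rest of your usual options, and the state folder name depends on your output settings, so the path is only a placeholder):

```
# first run: also save the training state
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py ... --save_state

# later run: resume from the saved state folder
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py ... --resume=<saved state folder>
```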
| 453 | 
            +
             | 
| 454 | 
            +
            #### gradient_accumulation_steps
         | 
| 455 | 
            +
Accumulates gradients over the specified number of steps before applying the update. The effect is similar to increasing the batch size, but memory consumption increases slightly.
         | 
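For example, accumulating over 4 steps with a batch size of 1 behaves roughly like an effective batch size of 4:

```
--train_batch_size=1 --gradient_accumulation_steps=4
```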
| 456 | 
            +
             | 
| 457 | 
            +
            *The Accelerate specification does not support multiple learning models, so if you set Text Encoder as the learning target and specify a value of 2 or more for this option, an error may occur.
         | 
| 458 | 
            +
             | 
| 459 | 
            +
            #### lr_scheduler / lr_warmup_steps
         | 
| 460 | 
            +
            You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant.
         | 
| 461 | 
            +
             | 
| 462 | 
            +
            With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details.
         | 
| 463 | 
            +
             | 
| 464 | 
            +
            #### diffusers_xformers
         | 
| 465 | 
            +
            Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork learning is no longer possible.
         | 
    	
        fine_tune_README_ja.md
    ADDED
    
    | @@ -0,0 +1,140 @@ | |
| 1 | 
            +
            NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # 概要
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
         | 
| 10 | 
            +
            * 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
         | 
| 11 | 
            +
            * トークン長を75から225に拡張する。
         | 
| 12 | 
            +
            * BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
         | 
| 13 | 
            +
            * Hypernetworkの学習にも対応する。
         | 
| 14 | 
            +
            * Stable Diffusion v2.0(baseおよび768/v)に対応。
         | 
| 15 | 
            +
            * VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # 追加機能について
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            ## CLIPの出力の変更
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
         | 
| 24 | 
            +
            元のまま、最後の層の出力を用いることも可能です。
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            ## 正方形以外の解像度での学習
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
         | 
| 31 | 
            +
            学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            ## トークン長の75から225への拡張
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
         | 
| 38 | 
            +
            ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            ※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            # 学習の手順
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            ## データの準備
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            [学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            ## 学習の実行
         | 
| 53 | 
            +
            たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            ```
         | 
| 56 | 
            +
            accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 
         | 
| 57 | 
            +
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ> 
         | 
| 58 | 
            +
                --output_dir=<学習したモデルの出力先フォルダ>  
         | 
| 59 | 
            +
                --output_name=<学習したモデル出力時のファイル名> 
         | 
| 60 | 
            +
                --dataset_config=<データ準備で作成した.tomlファイル> 
         | 
| 61 | 
            +
                --save_model_as=safetensors 
         | 
| 62 | 
            +
                --learning_rate=5e-6 --max_train_steps=10000 
         | 
| 63 | 
            +
                --use_8bit_adam --xformers --gradient_checkpointing
         | 
| 64 | 
            +
                --mixed_precision=fp16
         | 
| 65 | 
            +
            ```
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            `num_cpu_threads_per_process` には通常は1を指定するとよいようです。
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
         | 
| 74 | 
            +
             | 
| 75 | 
            +
            学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            ### よく使われるオプションについて
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            以下の場合にはオプションに関するドキュメントを参照してください。
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            - Stable Diffusion 2.xまたはそこからの派生モデルを学習する
         | 
| 90 | 
            +
            - clip skipを2以上を前提としたモデルを学習する
         | 
| 91 | 
            +
            - 75トークンを超えたキャプションで学習する
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            ### バッチサイズについて
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            ### 学習率について
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            ### 以前の形式のデータセット指定をした場合のコマンドライン
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            ```
         | 
| 106 | 
            +
            accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 
         | 
| 107 | 
            +
                --pretrained_model_name_or_path=model.ckpt 
         | 
| 108 | 
            +
                --in_json meta_lat.json 
         | 
| 109 | 
            +
                --train_data_dir=train_data 
         | 
| 110 | 
            +
                --output_dir=fine_tuned 
         | 
| 111 | 
            +
                --shuffle_caption 
         | 
| 112 | 
            +
                --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 
         | 
| 113 | 
            +
                --use_8bit_adam --xformers --gradient_checkpointing
         | 
| 114 | 
            +
                --mixed_precision=bf16
         | 
| 115 | 
            +
                --save_every_n_epochs=4
         | 
| 116 | 
            +
            ```
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            <!-- 
         | 
| 119 | 
            +
            ### 勾配をfp16とした学習(実験的機能)
         | 
| 120 | 
            +
            full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
         | 
| 125 | 
            +
            (余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
         | 
| 128 | 
            +
            -->
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            # fine tuning特有のその他の主なオプション
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            すべてのオプションについては別文書を参照してください。
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            ## `train_text_encoder`
         | 
| 135 | 
            +
            Text Encoderも学習対象とします。メモリ使用量が若干増加します。
         | 
| 136 | 
            +
             | 
| 137 | 
            +
            通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            ## `diffusers_xformers`
         | 
| 140 | 
            +
            スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
         | 
    	
        finetune/blip/blip.py
    ADDED
    
    | @@ -0,0 +1,240 @@ | |
| 1 | 
            +
            '''
         | 
| 2 | 
            +
             * Copyright (c) 2022, salesforce.com, inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             * SPDX-License-Identifier: BSD-3-Clause
         | 
| 5 | 
            +
             * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
         | 
| 6 | 
            +
             * By Junnan Li
         | 
| 7 | 
            +
            '''
         | 
| 8 | 
            +
            import warnings
         | 
| 9 | 
            +
            warnings.filterwarnings("ignore")
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # from models.vit import VisionTransformer, interpolate_pos_embed
         | 
| 12 | 
            +
            # from models.med import BertConfig, BertModel, BertLMHeadModel
         | 
| 13 | 
            +
            from blip.vit import VisionTransformer, interpolate_pos_embed
         | 
| 14 | 
            +
            from blip.med import BertConfig, BertModel, BertLMHeadModel
         | 
| 15 | 
            +
            from transformers import BertTokenizer
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            from torch import nn
         | 
| 19 | 
            +
            import torch.nn.functional as F
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import os
         | 
| 22 | 
            +
            from urllib.parse import urlparse
         | 
| 23 | 
            +
            from timm.models.hub import download_cached_file
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            class BLIP_Base(nn.Module):
         | 
| 26 | 
            +
                def __init__(self,                 
         | 
| 27 | 
            +
                             med_config = 'configs/med_config.json',  
         | 
| 28 | 
            +
                             image_size = 224,
         | 
| 29 | 
            +
                             vit = 'base',
         | 
| 30 | 
            +
                             vit_grad_ckpt = False,
         | 
| 31 | 
            +
                             vit_ckpt_layer = 0,                 
         | 
| 32 | 
            +
                             ):
         | 
| 33 | 
            +
                    """
         | 
| 34 | 
            +
                    Args:
         | 
| 35 | 
            +
                        med_config (str): path for the mixture of encoder-decoder model's configuration file
         | 
| 36 | 
            +
                        image_size (int): input image size
         | 
| 37 | 
            +
                        vit (str): model size of vision transformer
         | 
| 38 | 
            +
                    """               
         | 
| 39 | 
            +
                    super().__init__()
         | 
| 40 | 
            +
                    
         | 
| 41 | 
            +
                    self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
         | 
| 42 | 
            +
                    self.tokenizer = init_tokenizer()   
         | 
| 43 | 
            +
                    med_config = BertConfig.from_json_file(med_config)
         | 
| 44 | 
            +
                    med_config.encoder_width = vision_width
         | 
| 45 | 
            +
                    self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)  
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    
         | 
| 48 | 
            +
                def forward(self, image, caption, mode):
         | 
| 49 | 
            +
                    
         | 
| 50 | 
            +
                    assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
         | 
| 51 | 
            +
                    text = self.tokenizer(caption, return_tensors="pt").to(image.device) 
         | 
| 52 | 
            +
                    
         | 
| 53 | 
            +
                    if mode=='image':    
         | 
| 54 | 
            +
                        # return image features
         | 
| 55 | 
            +
                        image_embeds = self.visual_encoder(image)             
         | 
| 56 | 
            +
                        return image_embeds
         | 
| 57 | 
            +
                    
         | 
| 58 | 
            +
                    elif mode=='text':
         | 
| 59 | 
            +
                        # return text features
         | 
| 60 | 
            +
                        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
         | 
| 61 | 
            +
                                                        return_dict = True, mode = 'text')  
         | 
| 62 | 
            +
                        return text_output.last_hidden_state
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    elif mode=='multimodal':
         | 
| 65 | 
            +
                        # return multimodel features
         | 
| 66 | 
            +
                        image_embeds = self.visual_encoder(image)    
         | 
| 67 | 
            +
                        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)      
         | 
| 68 | 
            +
                        
         | 
| 69 | 
            +
                        text.input_ids[:,0] = self.tokenizer.enc_token_id
         | 
| 70 | 
            +
                        output = self.text_encoder(text.input_ids,
         | 
| 71 | 
            +
                                                   attention_mask = text.attention_mask,
         | 
| 72 | 
            +
                                                   encoder_hidden_states = image_embeds,
         | 
| 73 | 
            +
                                                   encoder_attention_mask = image_atts,      
         | 
| 74 | 
            +
                                                   return_dict = True,
         | 
| 75 | 
            +
                                                  )              
         | 
| 76 | 
            +
                        return output.last_hidden_state
         | 
| 77 | 
            +
                    
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    
         | 
| 80 | 
            +
            class BLIP_Decoder(nn.Module):
         | 
| 81 | 
            +
                def __init__(self,                 
         | 
| 82 | 
            +
                             med_config = 'configs/med_config.json',  
         | 
| 83 | 
            +
                             image_size = 384,
         | 
| 84 | 
            +
                             vit = 'base',
         | 
| 85 | 
            +
                             vit_grad_ckpt = False,
         | 
| 86 | 
            +
                             vit_ckpt_layer = 0,
         | 
| 87 | 
            +
                             prompt = 'a picture of ',
         | 
| 88 | 
            +
                             ):
         | 
| 89 | 
            +
                    """
         | 
| 90 | 
            +
                    Args:
         | 
| 91 | 
            +
                        med_config (str): path for the mixture of encoder-decoder model's configuration file
         | 
| 92 | 
            +
                        image_size (int): input image size
         | 
| 93 | 
            +
                        vit (str): model size of vision transformer
         | 
| 94 | 
            +
                    """            
         | 
| 95 | 
            +
                    super().__init__()
         | 
| 96 | 
            +
                    
         | 
| 97 | 
            +
                    self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
         | 
| 98 | 
            +
                    self.tokenizer = init_tokenizer()   
         | 
| 99 | 
            +
                    med_config = BertConfig.from_json_file(med_config)
         | 
| 100 | 
            +
                    med_config.encoder_width = vision_width
         | 
| 101 | 
            +
                    self.text_decoder = BertLMHeadModel(config=med_config)    
         | 
| 102 | 
            +
                    
         | 
| 103 | 
            +
                    self.prompt = prompt
         | 
| 104 | 
            +
                    self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    
         | 
| 107 | 
            +
                def forward(self, image, caption):
         | 
| 108 | 
            +
                    
         | 
| 109 | 
            +
                    image_embeds = self.visual_encoder(image) 
         | 
| 110 | 
            +
                    image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
         | 
| 111 | 
            +
                    
         | 
| 112 | 
            +
                    text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 
         | 
| 113 | 
            +
                    
         | 
| 114 | 
            +
                    text.input_ids[:,0] = self.tokenizer.bos_token_id
         | 
| 115 | 
            +
                    
         | 
| 116 | 
            +
                    decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)         
         | 
| 117 | 
            +
                    decoder_targets[:,:self.prompt_length] = -100
         | 
| 118 | 
            +
                 
         | 
| 119 | 
            +
                    decoder_output = self.text_decoder(text.input_ids, 
         | 
| 120 | 
            +
                                                       attention_mask = text.attention_mask, 
         | 
| 121 | 
            +
                                                       encoder_hidden_states = image_embeds,
         | 
| 122 | 
            +
                                                       encoder_attention_mask = image_atts,                  
         | 
| 123 | 
            +
                                                       labels = decoder_targets,
         | 
| 124 | 
            +
                                                       return_dict = True,   
         | 
| 125 | 
            +
                                                      )   
         | 
| 126 | 
            +
                    loss_lm = decoder_output.loss
         | 
| 127 | 
            +
                    
         | 
| 128 | 
            +
                    return loss_lm
         | 
| 129 | 
            +
                    
         | 
| 130 | 
            +
                def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
         | 
| 131 | 
            +
                    image_embeds = self.visual_encoder(image)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    if not sample:
         | 
| 134 | 
            +
                        image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
         | 
| 135 | 
            +
                        
         | 
| 136 | 
            +
                    image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
         | 
| 137 | 
            +
                    model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
         | 
| 138 | 
            +
                    
         | 
| 139 | 
            +
                    prompt = [self.prompt] * image.size(0)
         | 
| 140 | 
            +
                    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 
         | 
| 141 | 
            +
                    input_ids[:,0] = self.tokenizer.bos_token_id
         | 
| 142 | 
            +
                    input_ids = input_ids[:, :-1] 
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if sample:
         | 
| 145 | 
            +
                        #nucleus sampling
         | 
| 146 | 
            +
                        outputs = self.text_decoder.generate(input_ids=input_ids,
         | 
| 147 | 
            +
                                                              max_length=max_length,
         | 
| 148 | 
            +
                                                              min_length=min_length,
         | 
| 149 | 
            +
                                                              do_sample=True,
         | 
| 150 | 
            +
                                                              top_p=top_p,
         | 
| 151 | 
            +
                                                              num_return_sequences=1,
         | 
| 152 | 
            +
                                                              eos_token_id=self.tokenizer.sep_token_id,
         | 
| 153 | 
            +
                                                              pad_token_id=self.tokenizer.pad_token_id, 
         | 
| 154 | 
            +
                                                              repetition_penalty=1.1,                                            
         | 
| 155 | 
            +
                                                              **model_kwargs)
         | 
| 156 | 
            +
                    else:
         | 
| 157 | 
            +
                        #beam search
         | 
| 158 | 
            +
                        outputs = self.text_decoder.generate(input_ids=input_ids,
         | 
| 159 | 
            +
                                                              max_length=max_length,
         | 
| 160 | 
            +
                                                              min_length=min_length,
         | 
| 161 | 
            +
                                                              num_beams=num_beams,
         | 
| 162 | 
            +
                                                              eos_token_id=self.tokenizer.sep_token_id,
         | 
| 163 | 
            +
                                                              pad_token_id=self.tokenizer.pad_token_id,     
         | 
| 164 | 
            +
                                                              repetition_penalty=repetition_penalty,
         | 
| 165 | 
            +
                                                              **model_kwargs)            
         | 
| 166 | 
            +
                        
         | 
| 167 | 
            +
                    captions = []    
         | 
| 168 | 
            +
                    for output in outputs:
         | 
| 169 | 
            +
                        caption = self.tokenizer.decode(output, skip_special_tokens=True)    
         | 
| 170 | 
            +
                        captions.append(caption[len(self.prompt):])
         | 
| 171 | 
            +
                    return captions
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
             | 
| 174 | 
            +
            def blip_decoder(pretrained='',**kwargs):
         | 
| 175 | 
            +
                model = BLIP_Decoder(**kwargs)
         | 
| 176 | 
            +
                if pretrained:
         | 
| 177 | 
            +
                    model,msg = load_checkpoint(model,pretrained)
         | 
| 178 | 
            +
                    assert(len(msg.missing_keys)==0)
         | 
| 179 | 
            +
                return model    
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
            def blip_feature_extractor(pretrained='',**kwargs):
         | 
| 182 | 
            +
                model = BLIP_Base(**kwargs)
         | 
| 183 | 
            +
                if pretrained:
         | 
| 184 | 
            +
                    model,msg = load_checkpoint(model,pretrained)
         | 
| 185 | 
            +
                    assert(len(msg.missing_keys)==0)
         | 
| 186 | 
            +
                return model        
         | 
| 187 | 
            +
             | 
| 188 | 
            +
            def init_tokenizer():
         | 
| 189 | 
            +
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
         | 
| 190 | 
            +
                tokenizer.add_special_tokens({'bos_token':'[DEC]'})
         | 
| 191 | 
            +
                tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})       
         | 
| 192 | 
            +
                tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]  
         | 
| 193 | 
            +
                return tokenizer
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
         | 
| 197 | 
            +
                    
         | 
| 198 | 
            +
                assert vit in ['base', 'large'], "vit parameter must be base or large"
         | 
| 199 | 
            +
                if vit=='base':
         | 
| 200 | 
            +
                    vision_width = 768
         | 
| 201 | 
            +
                    visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 
         | 
| 202 | 
            +
                                                       num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
         | 
| 203 | 
            +
                                                       drop_path_rate=0 or drop_path_rate
         | 
| 204 | 
            +
                                                      )   
         | 
| 205 | 
            +
                elif vit=='large':
         | 
| 206 | 
            +
                    vision_width = 1024
         | 
| 207 | 
            +
                    visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 
         | 
| 208 | 
            +
                                                       num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
         | 
| 209 | 
            +
                                                       drop_path_rate=0.1 or drop_path_rate
         | 
| 210 | 
            +
                                                      )   
         | 
| 211 | 
            +
                return visual_encoder, vision_width
         | 
| 212 | 
            +
             | 
| 213 | 
            +
            def is_url(url_or_filename):
         | 
| 214 | 
            +
                parsed = urlparse(url_or_filename)
         | 
| 215 | 
            +
                return parsed.scheme in ("http", "https")
         | 
| 216 | 
            +
             | 
| 217 | 
            +
            def load_checkpoint(model,url_or_filename):
         | 
| 218 | 
            +
                if is_url(url_or_filename):
         | 
| 219 | 
            +
                    cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
         | 
| 220 | 
            +
                    checkpoint = torch.load(cached_file, map_location='cpu') 
         | 
| 221 | 
            +
                elif os.path.isfile(url_or_filename):        
         | 
| 222 | 
            +
                    checkpoint = torch.load(url_or_filename, map_location='cpu') 
         | 
| 223 | 
            +
                else:
         | 
| 224 | 
            +
                    raise RuntimeError('checkpoint url or path is invalid')
         | 
| 225 | 
            +
                    
         | 
| 226 | 
            +
                state_dict = checkpoint['model']
         | 
| 227 | 
            +
                
         | 
| 228 | 
            +
                state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 
         | 
| 229 | 
            +
                if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
         | 
| 230 | 
            +
                    state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
         | 
| 231 | 
            +
                                                                                     model.visual_encoder_m)    
         | 
| 232 | 
            +
                for key in model.state_dict().keys():
         | 
| 233 | 
            +
                    if key in state_dict.keys():
         | 
| 234 | 
            +
                        if state_dict[key].shape!=model.state_dict()[key].shape:
         | 
| 235 | 
            +
                            del state_dict[key]
         | 
| 236 | 
            +
                
         | 
| 237 | 
            +
                msg = model.load_state_dict(state_dict,strict=False)
         | 
| 238 | 
            +
                print('load checkpoint from %s'%url_or_filename)  
         | 
| 239 | 
            +
                return model,msg
         | 
| 240 | 
            +
                
         | 
    	
        finetune/blip/med.py
    ADDED
    
    | @@ -0,0 +1,955 @@ | |
| 1 | 
            +
            '''
         | 
| 2 | 
            +
             * Copyright (c) 2022, salesforce.com, inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             * SPDX-License-Identifier: BSD-3-Clause
         | 
| 5 | 
            +
             * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
         | 
| 6 | 
            +
             * By Junnan Li
         | 
| 7 | 
            +
             * Based on huggingface code base
         | 
| 8 | 
            +
             * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
         | 
| 9 | 
            +
            '''
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import math
         | 
| 12 | 
            +
            import os
         | 
| 13 | 
            +
            import warnings
         | 
| 14 | 
            +
            from dataclasses import dataclass
         | 
| 15 | 
            +
            from typing import Optional, Tuple
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            from torch import Tensor, device, dtype, nn
         | 
| 19 | 
            +
            import torch.utils.checkpoint
         | 
| 20 | 
            +
            from torch import nn
         | 
| 21 | 
            +
            from torch.nn import CrossEntropyLoss
         | 
| 22 | 
            +
            import torch.nn.functional as F
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from transformers.activations import ACT2FN
         | 
| 25 | 
            +
            from transformers.file_utils import (
         | 
| 26 | 
            +
                ModelOutput,
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
            from transformers.modeling_outputs import (
         | 
| 29 | 
            +
                BaseModelOutputWithPastAndCrossAttentions,
         | 
| 30 | 
            +
                BaseModelOutputWithPoolingAndCrossAttentions,
         | 
| 31 | 
            +
                CausalLMOutputWithCrossAttentions,
         | 
| 32 | 
            +
                MaskedLMOutput,
         | 
| 33 | 
            +
                MultipleChoiceModelOutput,
         | 
| 34 | 
            +
                NextSentencePredictorOutput,
         | 
| 35 | 
            +
                QuestionAnsweringModelOutput,
         | 
| 36 | 
            +
                SequenceClassifierOutput,
         | 
| 37 | 
            +
                TokenClassifierOutput,
         | 
| 38 | 
            +
            )
         | 
| 39 | 
            +
            from transformers.modeling_utils import (
         | 
| 40 | 
            +
                PreTrainedModel,
         | 
| 41 | 
            +
                apply_chunking_to_forward,
         | 
| 42 | 
            +
                find_pruneable_heads_and_indices,
         | 
| 43 | 
            +
                prune_linear_layer,
         | 
| 44 | 
            +
            )
         | 
| 45 | 
            +
            from transformers.utils import logging
         | 
| 46 | 
            +
            from transformers.models.bert.configuration_bert import BertConfig
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            class BertEmbeddings(nn.Module):
         | 
| 53 | 
            +
                """Construct the embeddings from word and position embeddings."""
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def __init__(self, config):
         | 
| 56 | 
            +
                    super().__init__()
         | 
| 57 | 
            +
                    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         | 
| 58 | 
            +
                    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         | 
| 61 | 
            +
                    # any TensorFlow checkpoint file
         | 
| 62 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         | 
| 63 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         | 
| 66 | 
            +
                    self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         | 
| 67 | 
            +
                    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    self.config = config
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def forward(
         | 
| 72 | 
            +
                    self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
         | 
| 73 | 
            +
                ):
         | 
| 74 | 
            +
                    if input_ids is not None:
         | 
| 75 | 
            +
                        input_shape = input_ids.size()
         | 
| 76 | 
            +
                    else:
         | 
| 77 | 
            +
                        input_shape = inputs_embeds.size()[:-1]
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    seq_length = input_shape[1]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    if position_ids is None:
         | 
| 82 | 
            +
                        position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if inputs_embeds is None:
         | 
| 85 | 
            +
                        inputs_embeds = self.word_embeddings(input_ids)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    embeddings = inputs_embeds
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    if self.position_embedding_type == "absolute":
         | 
| 90 | 
            +
                        position_embeddings = self.position_embeddings(position_ids)
         | 
| 91 | 
            +
                        embeddings += position_embeddings
         | 
| 92 | 
            +
                    embeddings = self.LayerNorm(embeddings)
         | 
| 93 | 
            +
                    embeddings = self.dropout(embeddings)
         | 
| 94 | 
            +
                    return embeddings
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            class BertSelfAttention(nn.Module):
         | 
| 98 | 
            +
                def __init__(self, config, is_cross_attention):
         | 
| 99 | 
            +
                    super().__init__()
         | 
| 100 | 
            +
                    self.config = config
         | 
| 101 | 
            +
                    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
         | 
| 102 | 
            +
                        raise ValueError(
         | 
| 103 | 
            +
                            "The hidden size (%d) is not a multiple of the number of attention "
         | 
| 104 | 
            +
                            "heads (%d)" % (config.hidden_size, config.num_attention_heads)
         | 
| 105 | 
            +
                        )
         | 
| 106 | 
            +
                    
         | 
| 107 | 
            +
                    self.num_attention_heads = config.num_attention_heads
         | 
| 108 | 
            +
                    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         | 
| 109 | 
            +
                    self.all_head_size = self.num_attention_heads * self.attention_head_size
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.query = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 112 | 
            +
                    if is_cross_attention:
         | 
| 113 | 
            +
                        self.key = nn.Linear(config.encoder_width, self.all_head_size)
         | 
| 114 | 
            +
                        self.value = nn.Linear(config.encoder_width, self.all_head_size)
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        self.key = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 117 | 
            +
                        self.value = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         | 
| 120 | 
            +
                    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         | 
| 121 | 
            +
                    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
         | 
| 122 | 
            +
                        self.max_position_embeddings = config.max_position_embeddings
         | 
| 123 | 
            +
                        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
         | 
| 124 | 
            +
                    self.save_attention = False   
         | 
| 125 | 
            +
                        
         | 
| 126 | 
            +
                def save_attn_gradients(self, attn_gradients):
         | 
| 127 | 
            +
                    self.attn_gradients = attn_gradients
         | 
| 128 | 
            +
                    
         | 
| 129 | 
            +
                def get_attn_gradients(self):
         | 
| 130 | 
            +
                    return self.attn_gradients
         | 
| 131 | 
            +
                
         | 
| 132 | 
            +
                def save_attention_map(self, attention_map):
         | 
| 133 | 
            +
                    self.attention_map = attention_map
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                def get_attention_map(self):
         | 
| 136 | 
            +
                    return self.attention_map
         | 
| 137 | 
            +
                
         | 
| 138 | 
            +
                def transpose_for_scores(self, x):
         | 
| 139 | 
            +
                    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         | 
| 140 | 
            +
                    x = x.view(*new_x_shape)
         | 
| 141 | 
            +
                    return x.permute(0, 2, 1, 3)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                def forward(
         | 
| 144 | 
            +
                    self,
         | 
| 145 | 
            +
                    hidden_states,
         | 
| 146 | 
            +
                    attention_mask=None,
         | 
| 147 | 
            +
                    head_mask=None,
         | 
| 148 | 
            +
                    encoder_hidden_states=None,
         | 
| 149 | 
            +
                    encoder_attention_mask=None,
         | 
| 150 | 
            +
                    past_key_value=None,
         | 
| 151 | 
            +
                    output_attentions=False,
         | 
| 152 | 
            +
                ):
         | 
| 153 | 
            +
                    mixed_query_layer = self.query(hidden_states)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # If this is instantiated as a cross-attention module, the keys
         | 
| 156 | 
            +
                    # and values come from an encoder; the attention mask needs to be
         | 
| 157 | 
            +
                    # such that the encoder's padding tokens are not attended to.
         | 
| 158 | 
            +
                    is_cross_attention = encoder_hidden_states is not None
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    if is_cross_attention:
         | 
| 161 | 
            +
                        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
         | 
| 162 | 
            +
                        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
         | 
| 163 | 
            +
                        attention_mask = encoder_attention_mask
         | 
| 164 | 
            +
                    elif past_key_value is not None:
         | 
| 165 | 
            +
                        key_layer = self.transpose_for_scores(self.key(hidden_states))
         | 
| 166 | 
            +
                        value_layer = self.transpose_for_scores(self.value(hidden_states))
         | 
| 167 | 
            +
                        key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
         | 
| 168 | 
            +
                        value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         | 
| 169 | 
            +
                    else:
         | 
| 170 | 
            +
                        key_layer = self.transpose_for_scores(self.key(hidden_states))
         | 
| 171 | 
            +
                        value_layer = self.transpose_for_scores(self.value(hidden_states))
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    query_layer = self.transpose_for_scores(mixed_query_layer)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    past_key_value = (key_layer, value_layer)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    # Take the dot product between "query" and "key" to get the raw attention scores.
         | 
| 178 | 
            +
                    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
         | 
| 181 | 
            +
                        seq_length = hidden_states.size()[1]
         | 
| 182 | 
            +
                        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
         | 
| 183 | 
            +
                        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
         | 
| 184 | 
            +
                        distance = position_ids_l - position_ids_r
         | 
| 185 | 
            +
                        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
         | 
| 186 | 
            +
                        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                        if self.position_embedding_type == "relative_key":
         | 
| 189 | 
            +
                            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
         | 
| 190 | 
            +
                            attention_scores = attention_scores + relative_position_scores
         | 
| 191 | 
            +
                        elif self.position_embedding_type == "relative_key_query":
         | 
| 192 | 
            +
                            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
         | 
| 193 | 
            +
                            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
         | 
| 194 | 
            +
                            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         | 
| 197 | 
            +
                    if attention_mask is not None:
         | 
| 198 | 
            +
                    # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
         | 
| 199 | 
            +
                        attention_scores = attention_scores + attention_mask
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # Normalize the attention scores to probabilities.
         | 
| 202 | 
            +
                    attention_probs = nn.Softmax(dim=-1)(attention_scores)
         | 
| 203 | 
            +
                    
         | 
| 204 | 
            +
                    if is_cross_attention and self.save_attention:
         | 
| 205 | 
            +
                        self.save_attention_map(attention_probs)
         | 
| 206 | 
            +
                        attention_probs.register_hook(self.save_attn_gradients)         
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    # This is actually dropping out entire tokens to attend to, which might
         | 
| 209 | 
            +
                    # seem a bit unusual, but is taken from the original Transformer paper.
         | 
| 210 | 
            +
                    attention_probs_dropped = self.dropout(attention_probs)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                    # Mask heads if we want to
         | 
| 213 | 
            +
                    if head_mask is not None:
         | 
| 214 | 
            +
                        attention_probs_dropped = attention_probs_dropped * head_mask
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    context_layer = torch.matmul(attention_probs_dropped, value_layer)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         | 
| 219 | 
            +
                    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         | 
| 220 | 
            +
                    context_layer = context_layer.view(*new_context_layer_shape)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    outputs = outputs + (past_key_value,)
         | 
| 225 | 
            +
                    return outputs
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            class BertSelfOutput(nn.Module):
         | 
| 229 | 
            +
                def __init__(self, config):
         | 
| 230 | 
            +
                    super().__init__()
         | 
| 231 | 
            +
                    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 232 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         | 
| 233 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                def forward(self, hidden_states, input_tensor):
         | 
| 236 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 237 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 238 | 
            +
                    hidden_states = self.LayerNorm(hidden_states + input_tensor)
         | 
| 239 | 
            +
                    return hidden_states
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            class BertAttention(nn.Module):
         | 
| 243 | 
            +
                def __init__(self, config, is_cross_attention=False):
         | 
| 244 | 
            +
                    super().__init__()
         | 
| 245 | 
            +
                    self.self = BertSelfAttention(config, is_cross_attention)
         | 
| 246 | 
            +
                    self.output = BertSelfOutput(config)
         | 
| 247 | 
            +
                    self.pruned_heads = set()
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                def prune_heads(self, heads):
         | 
| 250 | 
            +
                    if len(heads) == 0:
         | 
| 251 | 
            +
                        return
         | 
| 252 | 
            +
                    heads, index = find_pruneable_heads_and_indices(
         | 
| 253 | 
            +
                        heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
         | 
| 254 | 
            +
                    )
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    # Prune linear layers
         | 
| 257 | 
            +
                    self.self.query = prune_linear_layer(self.self.query, index)
         | 
| 258 | 
            +
                    self.self.key = prune_linear_layer(self.self.key, index)
         | 
| 259 | 
            +
                    self.self.value = prune_linear_layer(self.self.value, index)
         | 
| 260 | 
            +
                    self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    # Update hyper params and store pruned heads
         | 
| 263 | 
            +
                    self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         | 
| 264 | 
            +
                    self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         | 
| 265 | 
            +
                    self.pruned_heads = self.pruned_heads.union(heads)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def forward(
         | 
| 268 | 
            +
                    self,
         | 
| 269 | 
            +
                    hidden_states,
         | 
| 270 | 
            +
                    attention_mask=None,
         | 
| 271 | 
            +
                    head_mask=None,
         | 
| 272 | 
            +
                    encoder_hidden_states=None,
         | 
| 273 | 
            +
                    encoder_attention_mask=None,
         | 
| 274 | 
            +
                    past_key_value=None,
         | 
| 275 | 
            +
                    output_attentions=False,
         | 
| 276 | 
            +
                ):
         | 
| 277 | 
            +
                    self_outputs = self.self(
         | 
| 278 | 
            +
                        hidden_states,
         | 
| 279 | 
            +
                        attention_mask,
         | 
| 280 | 
            +
                        head_mask,
         | 
| 281 | 
            +
                        encoder_hidden_states,
         | 
| 282 | 
            +
                        encoder_attention_mask,
         | 
| 283 | 
            +
                        past_key_value,
         | 
| 284 | 
            +
                        output_attentions,
         | 
| 285 | 
            +
                    )
         | 
| 286 | 
            +
                    attention_output = self.output(self_outputs[0], hidden_states)
         | 
| 287 | 
            +
                    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         | 
| 288 | 
            +
                    return outputs
         | 
| 289 | 
            +
             | 
| 290 | 
            +
             | 
| 291 | 
            +
            class BertIntermediate(nn.Module):
         | 
| 292 | 
            +
                def __init__(self, config):
         | 
| 293 | 
            +
                    super().__init__()
         | 
| 294 | 
            +
                    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         | 
| 295 | 
            +
                    if isinstance(config.hidden_act, str):
         | 
| 296 | 
            +
                        self.intermediate_act_fn = ACT2FN[config.hidden_act]
         | 
| 297 | 
            +
                    else:
         | 
| 298 | 
            +
                        self.intermediate_act_fn = config.hidden_act
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                def forward(self, hidden_states):
         | 
| 301 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 302 | 
            +
                    hidden_states = self.intermediate_act_fn(hidden_states)
         | 
| 303 | 
            +
                    return hidden_states
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
            class BertOutput(nn.Module):
         | 
| 307 | 
            +
                def __init__(self, config):
         | 
| 308 | 
            +
                    super().__init__()
         | 
| 309 | 
            +
                    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         | 
| 310 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         | 
| 311 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def forward(self, hidden_states, input_tensor):
         | 
| 314 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 315 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 316 | 
            +
                    hidden_states = self.LayerNorm(hidden_states + input_tensor)
         | 
| 317 | 
            +
                    return hidden_states
         | 
| 318 | 
            +
             | 
| 319 | 
            +
             | 
| 320 | 
            +
            class BertLayer(nn.Module):
         | 
| 321 | 
            +
                def __init__(self, config, layer_num):
         | 
| 322 | 
            +
                    super().__init__()
         | 
| 323 | 
            +
                    self.config = config
         | 
| 324 | 
            +
                    self.chunk_size_feed_forward = config.chunk_size_feed_forward
         | 
| 325 | 
            +
                    self.seq_len_dim = 1
         | 
| 326 | 
            +
                    self.attention = BertAttention(config)      
         | 
| 327 | 
            +
                    self.layer_num = layer_num          
         | 
| 328 | 
            +
                    if self.config.add_cross_attention:
         | 
| 329 | 
            +
                        self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
         | 
| 330 | 
            +
                    self.intermediate = BertIntermediate(config)
         | 
| 331 | 
            +
                    self.output = BertOutput(config)
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def forward(
         | 
| 334 | 
            +
                    self,
         | 
| 335 | 
            +
                    hidden_states,
         | 
| 336 | 
            +
                    attention_mask=None,
         | 
| 337 | 
            +
                    head_mask=None,
         | 
| 338 | 
            +
                    encoder_hidden_states=None,
         | 
| 339 | 
            +
                    encoder_attention_mask=None,
         | 
| 340 | 
            +
                    past_key_value=None,
         | 
| 341 | 
            +
                    output_attentions=False,
         | 
| 342 | 
            +
                    mode=None,
         | 
| 343 | 
            +
                ):
         | 
| 344 | 
            +
                    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         | 
| 345 | 
            +
                    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         | 
| 346 | 
            +
                    self_attention_outputs = self.attention(
         | 
| 347 | 
            +
                        hidden_states,
         | 
| 348 | 
            +
                        attention_mask,
         | 
| 349 | 
            +
                        head_mask,
         | 
| 350 | 
            +
                        output_attentions=output_attentions,
         | 
| 351 | 
            +
                        past_key_value=self_attn_past_key_value,
         | 
| 352 | 
            +
                    )
         | 
| 353 | 
            +
                    attention_output = self_attention_outputs[0]
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    outputs = self_attention_outputs[1:-1]
         | 
| 356 | 
            +
                    present_key_value = self_attention_outputs[-1]
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    if mode=='multimodal':
         | 
| 359 | 
            +
                        assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                        cross_attention_outputs = self.crossattention(
         | 
| 362 | 
            +
                            attention_output,
         | 
| 363 | 
            +
                            attention_mask,
         | 
| 364 | 
            +
                            head_mask,
         | 
| 365 | 
            +
                            encoder_hidden_states,
         | 
| 366 | 
            +
                            encoder_attention_mask,
         | 
| 367 | 
            +
                            output_attentions=output_attentions,
         | 
| 368 | 
            +
                        )
         | 
| 369 | 
            +
                        attention_output = cross_attention_outputs[0]
         | 
| 370 | 
            +
                        outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
         | 
| 371 | 
            +
                    layer_output = apply_chunking_to_forward(
         | 
| 372 | 
            +
                        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         | 
| 373 | 
            +
                    )
         | 
| 374 | 
            +
                    outputs = (layer_output,) + outputs
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    outputs = outputs + (present_key_value,)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    return outputs
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                def feed_forward_chunk(self, attention_output):
         | 
| 381 | 
            +
                    intermediate_output = self.intermediate(attention_output)
         | 
| 382 | 
            +
                    layer_output = self.output(intermediate_output, attention_output)
         | 
| 383 | 
            +
                    return layer_output
         | 
| 384 | 
            +
             | 
| 385 | 
            +
             | 
| 386 | 
            +
            class BertEncoder(nn.Module):
         | 
| 387 | 
            +
                def __init__(self, config):
         | 
| 388 | 
            +
                    super().__init__()
         | 
| 389 | 
            +
                    self.config = config
         | 
| 390 | 
            +
                    self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
         | 
| 391 | 
            +
                    self.gradient_checkpointing = False
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                def forward(
         | 
| 394 | 
            +
                    self,
         | 
| 395 | 
            +
                    hidden_states,
         | 
| 396 | 
            +
                    attention_mask=None,
         | 
| 397 | 
            +
                    head_mask=None,
         | 
| 398 | 
            +
                    encoder_hidden_states=None,
         | 
| 399 | 
            +
                    encoder_attention_mask=None,
         | 
| 400 | 
            +
                    past_key_values=None,
         | 
| 401 | 
            +
                    use_cache=None,
         | 
| 402 | 
            +
                    output_attentions=False,
         | 
| 403 | 
            +
                    output_hidden_states=False,
         | 
| 404 | 
            +
                    return_dict=True,
         | 
| 405 | 
            +
                    mode='multimodal',
         | 
| 406 | 
            +
                ):
         | 
| 407 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 408 | 
            +
                    all_self_attentions = () if output_attentions else None
         | 
| 409 | 
            +
                    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    next_decoder_cache = () if use_cache else None
         | 
| 412 | 
            +
                           
         | 
| 413 | 
            +
                    for i in range(self.config.num_hidden_layers):
         | 
| 414 | 
            +
                        layer_module = self.layer[i]
         | 
| 415 | 
            +
                        if output_hidden_states:
         | 
| 416 | 
            +
                            all_hidden_states = all_hidden_states + (hidden_states,)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                        layer_head_mask = head_mask[i] if head_mask is not None else None
         | 
| 419 | 
            +
                        past_key_value = past_key_values[i] if past_key_values is not None else None
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                        if self.gradient_checkpointing and self.training:
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                            if use_cache:
         | 
| 424 | 
            +
                                logger.warn(
         | 
| 425 | 
            +
                                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
         | 
| 426 | 
            +
                                )
         | 
| 427 | 
            +
                                use_cache = False
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                            def create_custom_forward(module):
         | 
| 430 | 
            +
                                def custom_forward(*inputs):
         | 
| 431 | 
            +
                                    return module(*inputs, past_key_value, output_attentions)
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                                return custom_forward
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                            layer_outputs = torch.utils.checkpoint.checkpoint(
         | 
| 436 | 
            +
                                create_custom_forward(layer_module),
         | 
| 437 | 
            +
                                hidden_states,
         | 
| 438 | 
            +
                                attention_mask,
         | 
| 439 | 
            +
                                layer_head_mask,
         | 
| 440 | 
            +
                                encoder_hidden_states,
         | 
| 441 | 
            +
                                encoder_attention_mask,
         | 
| 442 | 
            +
                                mode=mode,
         | 
| 443 | 
            +
                            )
         | 
| 444 | 
            +
                        else:
         | 
| 445 | 
            +
                            layer_outputs = layer_module(
         | 
| 446 | 
            +
                                hidden_states,
         | 
| 447 | 
            +
                                attention_mask,
         | 
| 448 | 
            +
                                layer_head_mask,
         | 
| 449 | 
            +
                                encoder_hidden_states,
         | 
| 450 | 
            +
                                encoder_attention_mask,
         | 
| 451 | 
            +
                                past_key_value,
         | 
| 452 | 
            +
                                output_attentions,
         | 
| 453 | 
            +
                                mode=mode,
         | 
| 454 | 
            +
                            )
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                        hidden_states = layer_outputs[0]
         | 
| 457 | 
            +
                        if use_cache:
         | 
| 458 | 
            +
                            next_decoder_cache += (layer_outputs[-1],)
         | 
| 459 | 
            +
                        if output_attentions:
         | 
| 460 | 
            +
                            all_self_attentions = all_self_attentions + (layer_outputs[1],)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    if output_hidden_states:
         | 
| 463 | 
            +
                        all_hidden_states = all_hidden_states + (hidden_states,)
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                    if not return_dict:
         | 
| 466 | 
            +
                        return tuple(
         | 
| 467 | 
            +
                            v
         | 
| 468 | 
            +
                            for v in [
         | 
| 469 | 
            +
                                hidden_states,
         | 
| 470 | 
            +
                                next_decoder_cache,
         | 
| 471 | 
            +
                                all_hidden_states,
         | 
| 472 | 
            +
                                all_self_attentions,
         | 
| 473 | 
            +
                                all_cross_attentions,
         | 
| 474 | 
            +
                            ]
         | 
| 475 | 
            +
                            if v is not None
         | 
| 476 | 
            +
                        )
         | 
| 477 | 
            +
                    return BaseModelOutputWithPastAndCrossAttentions(
         | 
| 478 | 
            +
                        last_hidden_state=hidden_states,
         | 
| 479 | 
            +
                        past_key_values=next_decoder_cache,
         | 
| 480 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 481 | 
            +
                        attentions=all_self_attentions,
         | 
| 482 | 
            +
                        cross_attentions=all_cross_attentions,
         | 
| 483 | 
            +
                    )
         | 
| 484 | 
            +
             | 
| 485 | 
            +
             | 
| 486 | 
            +
            class BertPooler(nn.Module):
         | 
| 487 | 
            +
                def __init__(self, config):
         | 
| 488 | 
            +
                    super().__init__()
         | 
| 489 | 
            +
                    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 490 | 
            +
                    self.activation = nn.Tanh()
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                def forward(self, hidden_states):
         | 
| 493 | 
            +
                    # We "pool" the model by simply taking the hidden state corresponding
         | 
| 494 | 
            +
                    # to the first token.
         | 
| 495 | 
            +
                    first_token_tensor = hidden_states[:, 0]
         | 
| 496 | 
            +
                    pooled_output = self.dense(first_token_tensor)
         | 
| 497 | 
            +
                    pooled_output = self.activation(pooled_output)
         | 
| 498 | 
            +
                    return pooled_output
         | 
| 499 | 
            +
             | 
| 500 | 
            +
             | 
| 501 | 
            +
            class BertPredictionHeadTransform(nn.Module):
         | 
| 502 | 
            +
                def __init__(self, config):
         | 
| 503 | 
            +
                    super().__init__()
         | 
| 504 | 
            +
                    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 505 | 
            +
                    if isinstance(config.hidden_act, str):
         | 
| 506 | 
            +
                        self.transform_act_fn = ACT2FN[config.hidden_act]
         | 
| 507 | 
            +
                    else:
         | 
| 508 | 
            +
                        self.transform_act_fn = config.hidden_act
         | 
| 509 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                def forward(self, hidden_states):
         | 
| 512 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 513 | 
            +
                    hidden_states = self.transform_act_fn(hidden_states)
         | 
| 514 | 
            +
                    hidden_states = self.LayerNorm(hidden_states)
         | 
| 515 | 
            +
                    return hidden_states
         | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
            class BertLMPredictionHead(nn.Module):
         | 
| 519 | 
            +
                def __init__(self, config):
         | 
| 520 | 
            +
                    super().__init__()
         | 
| 521 | 
            +
                    self.transform = BertPredictionHeadTransform(config)
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                    # The output weights are the same as the input embeddings, but there is
         | 
| 524 | 
            +
                    # an output-only bias for each token.
         | 
| 525 | 
            +
                    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                    self.bias = nn.Parameter(torch.zeros(config.vocab_size))
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                    # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         | 
| 530 | 
            +
                    self.decoder.bias = self.bias
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                def forward(self, hidden_states):
         | 
| 533 | 
            +
                    hidden_states = self.transform(hidden_states)
         | 
| 534 | 
            +
                    hidden_states = self.decoder(hidden_states)
         | 
| 535 | 
            +
                    return hidden_states
         | 
| 536 | 
            +
             | 
| 537 | 
            +
             | 
| 538 | 
            +
            class BertOnlyMLMHead(nn.Module):
         | 
| 539 | 
            +
                def __init__(self, config):
         | 
| 540 | 
            +
                    super().__init__()
         | 
| 541 | 
            +
                    self.predictions = BertLMPredictionHead(config)
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                def forward(self, sequence_output):
         | 
| 544 | 
            +
                    prediction_scores = self.predictions(sequence_output)
         | 
| 545 | 
            +
                    return prediction_scores
         | 
| 546 | 
            +
             | 
| 547 | 
            +
             | 
| 548 | 
            +
            class BertPreTrainedModel(PreTrainedModel):
         | 
| 549 | 
            +
                """
         | 
| 550 | 
            +
                An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
         | 
| 551 | 
            +
                models.
         | 
| 552 | 
            +
                """
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                config_class = BertConfig
         | 
| 555 | 
            +
                base_model_prefix = "bert"
         | 
| 556 | 
            +
                _keys_to_ignore_on_load_missing = [r"position_ids"]
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                def _init_weights(self, module):
         | 
| 559 | 
            +
                    """ Initialize the weights """
         | 
| 560 | 
            +
                    if isinstance(module, (nn.Linear, nn.Embedding)):
         | 
| 561 | 
            +
                        # Slightly different from the TF version which uses truncated_normal for initialization
         | 
| 562 | 
            +
                        # cf https://github.com/pytorch/pytorch/pull/5617
         | 
| 563 | 
            +
                        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
         | 
| 564 | 
            +
                    elif isinstance(module, nn.LayerNorm):
         | 
| 565 | 
            +
                        module.bias.data.zero_()
         | 
| 566 | 
            +
                        module.weight.data.fill_(1.0)
         | 
| 567 | 
            +
                    if isinstance(module, nn.Linear) and module.bias is not None:
         | 
| 568 | 
            +
                        module.bias.data.zero_()
         | 
| 569 | 
            +
             | 
| 570 | 
            +
             | 
| 571 | 
            +
            class BertModel(BertPreTrainedModel):
         | 
| 572 | 
            +
                """
         | 
| 573 | 
            +
                The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
         | 
| 574 | 
            +
                cross-attention is added between the self-attention layers, following the architecture described in `Attention is
         | 
| 575 | 
            +
                all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
         | 
| 576 | 
            +
                Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
         | 
| 577 | 
            +
    To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
         | 
| 578 | 
            +
                input to the forward pass.
         | 
| 579 | 
            +
                """
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                def __init__(self, config, add_pooling_layer=True):
         | 
| 582 | 
            +
                    super().__init__(config)
         | 
| 583 | 
            +
                    self.config = config
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                    self.embeddings = BertEmbeddings(config)
         | 
| 586 | 
            +
                    
         | 
| 587 | 
            +
                    self.encoder = BertEncoder(config)
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                    self.pooler = BertPooler(config) if add_pooling_layer else None
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    self.init_weights()
         | 
| 592 | 
            +
             
         | 
| 593 | 
            +
             | 
| 594 | 
            +
                def get_input_embeddings(self):
         | 
| 595 | 
            +
                    return self.embeddings.word_embeddings
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                def set_input_embeddings(self, value):
         | 
| 598 | 
            +
                    self.embeddings.word_embeddings = value
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                def _prune_heads(self, heads_to_prune):
         | 
| 601 | 
            +
                    """
         | 
| 602 | 
            +
                    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
         | 
| 603 | 
            +
                    class PreTrainedModel
         | 
| 604 | 
            +
                    """
         | 
| 605 | 
            +
                    for layer, heads in heads_to_prune.items():
         | 
| 606 | 
            +
                        self.encoder.layer[layer].attention.prune_heads(heads)
         | 
| 607 | 
            +
             | 
| 608 | 
            +
                
         | 
| 609 | 
            +
                def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
         | 
| 610 | 
            +
                    """
         | 
| 611 | 
            +
                    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                    Arguments:
         | 
| 614 | 
            +
                        attention_mask (:obj:`torch.Tensor`):
         | 
| 615 | 
            +
                            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
         | 
| 616 | 
            +
                        input_shape (:obj:`Tuple[int]`):
         | 
| 617 | 
            +
                            The shape of the input to the model.
         | 
| 618 | 
            +
                        device: (:obj:`torch.device`):
         | 
| 619 | 
            +
                            The device of the input to the model.
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                    Returns:
         | 
| 622 | 
            +
        :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
         | 
| 623 | 
            +
                    """
         | 
| 624 | 
            +
                    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         | 
| 625 | 
            +
                    # ourselves in which case we just need to make it broadcastable to all heads.
         | 
| 626 | 
            +
                    if attention_mask.dim() == 3:
         | 
| 627 | 
            +
                        extended_attention_mask = attention_mask[:, None, :, :]
         | 
| 628 | 
            +
                    elif attention_mask.dim() == 2:
         | 
| 629 | 
            +
                        # Provided a padding mask of dimensions [batch_size, seq_length]
         | 
| 630 | 
            +
                        # - if the model is a decoder, apply a causal mask in addition to the padding mask
         | 
| 631 | 
            +
                        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         | 
| 632 | 
            +
                        if is_decoder:
         | 
| 633 | 
            +
                            batch_size, seq_length = input_shape
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                            seq_ids = torch.arange(seq_length, device=device)
         | 
| 636 | 
            +
                            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
         | 
| 637 | 
            +
                            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
         | 
| 638 | 
            +
                            # causal and attention masks must have same type with pytorch version < 1.3
         | 
| 639 | 
            +
                            causal_mask = causal_mask.to(attention_mask.dtype)
         | 
| 640 | 
            +
               
         | 
| 641 | 
            +
                            if causal_mask.shape[1] < attention_mask.shape[1]:
         | 
| 642 | 
            +
                                prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
         | 
| 643 | 
            +
                                causal_mask = torch.cat(
         | 
| 644 | 
            +
                                    [
         | 
| 645 | 
            +
                                        torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
         | 
| 646 | 
            +
                                        causal_mask,
         | 
| 647 | 
            +
                                    ],
         | 
| 648 | 
            +
                                    axis=-1,
         | 
| 649 | 
            +
                                )                     
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
         | 
| 652 | 
            +
                        else:
         | 
| 653 | 
            +
                            extended_attention_mask = attention_mask[:, None, None, :]
         | 
| 654 | 
            +
                    else:
         | 
| 655 | 
            +
                        raise ValueError(
         | 
| 656 | 
            +
                            "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
         | 
| 657 | 
            +
                                input_shape, attention_mask.shape
         | 
| 658 | 
            +
                            )
         | 
| 659 | 
            +
                        )
         | 
| 660 | 
            +
             | 
| 661 | 
            +
                    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         | 
| 662 | 
            +
                    # masked positions, this operation will create a tensor which is 0.0 for
         | 
| 663 | 
            +
                    # positions we want to attend and -10000.0 for masked positions.
         | 
| 664 | 
            +
                    # Since we are adding it to the raw scores before the softmax, this is
         | 
| 665 | 
            +
                    # effectively the same as removing these entirely.
         | 
| 666 | 
            +
                    extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
         | 
| 667 | 
            +
                    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         | 
| 668 | 
            +
                    return extended_attention_mask
         | 
| 669 | 
            +
                
         | 
| 670 | 
            +
                def forward(
         | 
| 671 | 
            +
                    self,
         | 
| 672 | 
            +
                    input_ids=None,
         | 
| 673 | 
            +
                    attention_mask=None,
         | 
| 674 | 
            +
                    position_ids=None,
         | 
| 675 | 
            +
                    head_mask=None,
         | 
| 676 | 
            +
                    inputs_embeds=None,
         | 
| 677 | 
            +
                    encoder_embeds=None,
         | 
| 678 | 
            +
                    encoder_hidden_states=None,
         | 
| 679 | 
            +
                    encoder_attention_mask=None,
         | 
| 680 | 
            +
                    past_key_values=None,
         | 
| 681 | 
            +
                    use_cache=None,
         | 
| 682 | 
            +
                    output_attentions=None,
         | 
| 683 | 
            +
                    output_hidden_states=None,
         | 
| 684 | 
            +
                    return_dict=None,
         | 
| 685 | 
            +
                    is_decoder=False,
         | 
| 686 | 
            +
                    mode='multimodal',
         | 
| 687 | 
            +
                ):
         | 
| 688 | 
            +
                    r"""
         | 
| 689 | 
            +
                    encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
         | 
| 690 | 
            +
                        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
         | 
| 691 | 
            +
                        the model is configured as a decoder.
         | 
| 692 | 
            +
                    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
         | 
| 693 | 
            +
                        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
         | 
| 694 | 
            +
                        the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
         | 
| 695 | 
            +
                        - 1 for tokens that are **not masked**,
         | 
| 696 | 
            +
                        - 0 for tokens that are **masked**.
         | 
| 697 | 
            +
                    past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
         | 
| 698 | 
            +
                        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
         | 
| 699 | 
            +
                        If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
         | 
| 700 | 
            +
                        (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
         | 
| 701 | 
            +
                        instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
         | 
| 702 | 
            +
                    use_cache (:obj:`bool`, `optional`):
         | 
| 703 | 
            +
                        If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
         | 
| 704 | 
            +
                        decoding (see :obj:`past_key_values`).
         | 
| 705 | 
            +
                    """
         | 
| 706 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 707 | 
            +
                    output_hidden_states = (
         | 
| 708 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 709 | 
            +
                    )
         | 
| 710 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                    if is_decoder:
         | 
| 713 | 
            +
                        use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 714 | 
            +
                    else:
         | 
| 715 | 
            +
                        use_cache = False
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                    if input_ids is not None and inputs_embeds is not None:
         | 
| 718 | 
            +
                        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         | 
| 719 | 
            +
                    elif input_ids is not None:
         | 
| 720 | 
            +
                        input_shape = input_ids.size()
         | 
| 721 | 
            +
                        batch_size, seq_length = input_shape
         | 
| 722 | 
            +
                        device = input_ids.device
         | 
| 723 | 
            +
                    elif inputs_embeds is not None:
         | 
| 724 | 
            +
                        input_shape = inputs_embeds.size()[:-1]
         | 
| 725 | 
            +
                        batch_size, seq_length = input_shape
         | 
| 726 | 
            +
                        device = inputs_embeds.device
         | 
| 727 | 
            +
                    elif encoder_embeds is not None:    
         | 
| 728 | 
            +
                        input_shape = encoder_embeds.size()[:-1]
         | 
| 729 | 
            +
                        batch_size, seq_length = input_shape 
         | 
| 730 | 
            +
                        device = encoder_embeds.device
         | 
| 731 | 
            +
                    else:
         | 
| 732 | 
            +
                        raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
         | 
| 733 | 
            +
             | 
| 734 | 
            +
                    # past_key_values_length
         | 
| 735 | 
            +
                    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
         | 
| 736 | 
            +
             | 
| 737 | 
            +
                    if attention_mask is None:
         | 
| 738 | 
            +
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
         | 
| 739 | 
            +
                        
         | 
| 740 | 
            +
                    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         | 
| 741 | 
            +
                    # ourselves in which case we just need to make it broadcastable to all heads.
         | 
| 742 | 
            +
                    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 
         | 
| 743 | 
            +
                                                                                             device, is_decoder)
         | 
| 744 | 
            +
             | 
| 745 | 
            +
                    # If a 2D or 3D attention mask is provided for the cross-attention
         | 
| 746 | 
            +
                    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         | 
| 747 | 
            +
                    if encoder_hidden_states is not None:
         | 
| 748 | 
            +
                        if type(encoder_hidden_states) == list:
         | 
| 749 | 
            +
                            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
         | 
| 750 | 
            +
                        else:
         | 
| 751 | 
            +
                            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
         | 
| 752 | 
            +
                        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
         | 
| 753 | 
            +
                        
         | 
| 754 | 
            +
                        if type(encoder_attention_mask) == list:
         | 
| 755 | 
            +
                            encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
         | 
| 756 | 
            +
                        elif encoder_attention_mask is None:
         | 
| 757 | 
            +
                            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
         | 
| 758 | 
            +
                            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         | 
| 759 | 
            +
                        else:    
         | 
| 760 | 
            +
                            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         | 
| 761 | 
            +
                    else:
         | 
| 762 | 
            +
                        encoder_extended_attention_mask = None
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                    # Prepare head mask if needed
         | 
| 765 | 
            +
                    # 1.0 in head_mask indicate we keep the head
         | 
| 766 | 
            +
                    # attention_probs has shape bsz x n_heads x N x N
         | 
| 767 | 
            +
                    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         | 
| 768 | 
            +
                    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         | 
| 769 | 
            +
                    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         | 
| 770 | 
            +
                    
         | 
| 771 | 
            +
                    if encoder_embeds is None:
         | 
| 772 | 
            +
                        embedding_output = self.embeddings(
         | 
| 773 | 
            +
                            input_ids=input_ids,
         | 
| 774 | 
            +
                            position_ids=position_ids,
         | 
| 775 | 
            +
                            inputs_embeds=inputs_embeds,
         | 
| 776 | 
            +
                            past_key_values_length=past_key_values_length,
         | 
| 777 | 
            +
                        )
         | 
| 778 | 
            +
                    else:
         | 
| 779 | 
            +
                        embedding_output = encoder_embeds
         | 
| 780 | 
            +
                        
         | 
| 781 | 
            +
                    encoder_outputs = self.encoder(
         | 
| 782 | 
            +
                        embedding_output,
         | 
| 783 | 
            +
                        attention_mask=extended_attention_mask,
         | 
| 784 | 
            +
                        head_mask=head_mask,
         | 
| 785 | 
            +
                        encoder_hidden_states=encoder_hidden_states,
         | 
| 786 | 
            +
                        encoder_attention_mask=encoder_extended_attention_mask,
         | 
| 787 | 
            +
                        past_key_values=past_key_values,
         | 
| 788 | 
            +
                        use_cache=use_cache,
         | 
| 789 | 
            +
                        output_attentions=output_attentions,
         | 
| 790 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 791 | 
            +
                        return_dict=return_dict,
         | 
| 792 | 
            +
                        mode=mode,
         | 
| 793 | 
            +
                    )
         | 
| 794 | 
            +
                    sequence_output = encoder_outputs[0]
         | 
| 795 | 
            +
                    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
         | 
| 796 | 
            +
             | 
| 797 | 
            +
                    if not return_dict:
         | 
| 798 | 
            +
                        return (sequence_output, pooled_output) + encoder_outputs[1:]
         | 
| 799 | 
            +
             | 
| 800 | 
            +
                    return BaseModelOutputWithPoolingAndCrossAttentions(
         | 
| 801 | 
            +
                        last_hidden_state=sequence_output,
         | 
| 802 | 
            +
                        pooler_output=pooled_output,
         | 
| 803 | 
            +
                        past_key_values=encoder_outputs.past_key_values,
         | 
| 804 | 
            +
                        hidden_states=encoder_outputs.hidden_states,
         | 
| 805 | 
            +
                        attentions=encoder_outputs.attentions,
         | 
| 806 | 
            +
                        cross_attentions=encoder_outputs.cross_attentions,
         | 
| 807 | 
            +
                    )
         | 
| 808 | 
            +
             | 
| 809 | 
            +
             | 
| 810 | 
            +
             | 
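A small self-contained sketch (for illustration only, not part of the uploaded file) of what `get_extended_attention_mask` computes in decoder mode: the lower-triangular causal mask is combined with the padding mask and converted into an additive bias of 0 for visible positions and -10000 for masked ones, mirroring the method above.

import torch

batch_size, seq_length = 1, 4
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
seq_ids = torch.arange(seq_length)
# entry [b, i, j] is 1.0 when position j is visible from position i (j <= i)
causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]).float()
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :].float()
extended = (1.0 - extended) * -10000.0
print(extended[0, 0])  # row i: future positions and the padded token get -10000.0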
| 811 | 
            +
            class BertLMHeadModel(BertPreTrainedModel):
         | 
| 812 | 
            +
             | 
| 813 | 
            +
                _keys_to_ignore_on_load_unexpected = [r"pooler"]
         | 
| 814 | 
            +
                _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
         | 
| 815 | 
            +
             | 
| 816 | 
            +
                def __init__(self, config):
         | 
| 817 | 
            +
                    super().__init__(config)
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                    self.bert = BertModel(config, add_pooling_layer=False)
         | 
| 820 | 
            +
                    self.cls = BertOnlyMLMHead(config)
         | 
| 821 | 
            +
             | 
| 822 | 
            +
                    self.init_weights()
         | 
| 823 | 
            +
             | 
| 824 | 
            +
                def get_output_embeddings(self):
         | 
| 825 | 
            +
                    return self.cls.predictions.decoder
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                def set_output_embeddings(self, new_embeddings):
         | 
| 828 | 
            +
                    self.cls.predictions.decoder = new_embeddings
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                def forward(
         | 
| 831 | 
            +
                    self,
         | 
| 832 | 
            +
                    input_ids=None,
         | 
| 833 | 
            +
                    attention_mask=None,
         | 
| 834 | 
            +
                    position_ids=None,
         | 
| 835 | 
            +
                    head_mask=None,
         | 
| 836 | 
            +
                    inputs_embeds=None,
         | 
| 837 | 
            +
                    encoder_hidden_states=None,
         | 
| 838 | 
            +
                    encoder_attention_mask=None,
         | 
| 839 | 
            +
                    labels=None,
         | 
| 840 | 
            +
                    past_key_values=None,
         | 
| 841 | 
            +
                    use_cache=None,
         | 
| 842 | 
            +
                    output_attentions=None,
         | 
| 843 | 
            +
                    output_hidden_states=None,
         | 
| 844 | 
            +
                    return_dict=None,
         | 
| 845 | 
            +
                    return_logits=False,            
         | 
| 846 | 
            +
                    is_decoder=True,
         | 
| 847 | 
            +
                    reduction='mean',
         | 
| 848 | 
            +
                    mode='multimodal', 
         | 
| 849 | 
            +
                ):
         | 
| 850 | 
            +
                    r"""
         | 
| 851 | 
            +
                    encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
         | 
| 852 | 
            +
                        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
         | 
| 853 | 
            +
                        the model is configured as a decoder.
         | 
| 854 | 
            +
                    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
         | 
| 855 | 
            +
                        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
         | 
| 856 | 
            +
                        the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
         | 
| 857 | 
            +
                        - 1 for tokens that are **not masked**,
         | 
| 858 | 
            +
                        - 0 for tokens that are **masked**.
         | 
| 859 | 
            +
                    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
         | 
| 860 | 
            +
                        Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
         | 
| 861 | 
            +
                        ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
         | 
| 862 | 
            +
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
         | 
| 863 | 
            +
                    past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
         | 
| 864 | 
            +
                        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
         | 
| 865 | 
            +
                        If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
         | 
| 866 | 
            +
                        (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
         | 
| 867 | 
            +
                        instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
         | 
| 868 | 
            +
                    use_cache (:obj:`bool`, `optional`):
         | 
| 869 | 
            +
                        If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
         | 
| 870 | 
            +
                        decoding (see :obj:`past_key_values`).
         | 
| 871 | 
            +
                    Returns:
         | 
| 872 | 
            +
                    Example::
         | 
| 873 | 
            +
                        >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
         | 
| 874 | 
            +
                        >>> import torch
         | 
| 875 | 
            +
                        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
         | 
| 876 | 
            +
                        >>> config = BertConfig.from_pretrained("bert-base-cased")
         | 
| 877 | 
            +
                        >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
         | 
| 878 | 
            +
                        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         | 
| 879 | 
            +
                        >>> outputs = model(**inputs)
         | 
| 880 | 
            +
                        >>> prediction_logits = outputs.logits
         | 
| 881 | 
            +
                    """
         | 
| 882 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 883 | 
            +
                    if labels is not None:
         | 
| 884 | 
            +
                        use_cache = False
         | 
| 885 | 
            +
             | 
| 886 | 
            +
                    outputs = self.bert(
         | 
| 887 | 
            +
                        input_ids,
         | 
| 888 | 
            +
                        attention_mask=attention_mask,
         | 
| 889 | 
            +
                        position_ids=position_ids,
         | 
| 890 | 
            +
                        head_mask=head_mask,
         | 
| 891 | 
            +
                        inputs_embeds=inputs_embeds,
         | 
| 892 | 
            +
                        encoder_hidden_states=encoder_hidden_states,
         | 
| 893 | 
            +
                        encoder_attention_mask=encoder_attention_mask,
         | 
| 894 | 
            +
                        past_key_values=past_key_values,
         | 
| 895 | 
            +
                        use_cache=use_cache,
         | 
| 896 | 
            +
                        output_attentions=output_attentions,
         | 
| 897 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 898 | 
            +
                        return_dict=return_dict,
         | 
| 899 | 
            +
                        is_decoder=is_decoder,
         | 
| 900 | 
            +
                        mode=mode,
         | 
| 901 | 
            +
                    )
         | 
| 902 | 
            +
                    
         | 
| 903 | 
            +
                    sequence_output = outputs[0]
         | 
| 904 | 
            +
                    prediction_scores = self.cls(sequence_output)
         | 
| 905 | 
            +
                    
         | 
| 906 | 
            +
                    if return_logits:
         | 
| 907 | 
            +
                        return prediction_scores[:, :-1, :].contiguous()  
         | 
| 908 | 
            +
             | 
| 909 | 
            +
                    lm_loss = None
         | 
| 910 | 
            +
                    if labels is not None:
         | 
| 911 | 
            +
                        # we are doing next-token prediction; shift prediction scores and input ids by one
         | 
| 912 | 
            +
                        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
         | 
| 913 | 
            +
                        labels = labels[:, 1:].contiguous()
         | 
| 914 | 
            +
                        loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 
         | 
| 915 | 
            +
                        lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
         | 
| 916 | 
            +
                        if reduction=='none':
         | 
| 917 | 
            +
                            lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)               
         | 
| 918 | 
            +
             | 
| 919 | 
            +
                    if not return_dict:
         | 
| 920 | 
            +
                        output = (prediction_scores,) + outputs[2:]
         | 
| 921 | 
            +
                        return ((lm_loss,) + output) if lm_loss is not None else output
         | 
| 922 | 
            +
             | 
| 923 | 
            +
                    return CausalLMOutputWithCrossAttentions(
         | 
| 924 | 
            +
                        loss=lm_loss,
         | 
| 925 | 
            +
                        logits=prediction_scores,
         | 
| 926 | 
            +
                        past_key_values=outputs.past_key_values,
         | 
| 927 | 
            +
                        hidden_states=outputs.hidden_states,
         | 
| 928 | 
            +
                        attentions=outputs.attentions,
         | 
| 929 | 
            +
                        cross_attentions=outputs.cross_attentions,
         | 
| 930 | 
            +
                    )
         | 
| 931 | 
            +
             | 
| 932 | 
            +
                def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
         | 
| 933 | 
            +
                    input_shape = input_ids.shape
         | 
| 934 | 
            +
                    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         | 
| 935 | 
            +
                    if attention_mask is None:
         | 
| 936 | 
            +
                        attention_mask = input_ids.new_ones(input_shape)
         | 
| 937 | 
            +
             | 
| 938 | 
            +
                    # cut decoder_input_ids if past is used
         | 
| 939 | 
            +
                    if past is not None:
         | 
| 940 | 
            +
                        input_ids = input_ids[:, -1:]
         | 
| 941 | 
            +
             | 
| 942 | 
            +
                    return {
         | 
| 943 | 
            +
                        "input_ids": input_ids, 
         | 
| 944 | 
            +
                        "attention_mask": attention_mask, 
         | 
| 945 | 
            +
                        "past_key_values": past,
         | 
| 946 | 
            +
                        "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
         | 
| 947 | 
            +
                        "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
         | 
| 948 | 
            +
                        "is_decoder": True,
         | 
| 949 | 
            +
                    }
         | 
| 950 | 
            +
             | 
| 951 | 
            +
                def _reorder_cache(self, past, beam_idx):
         | 
| 952 | 
            +
                    reordered_past = ()
         | 
| 953 | 
            +
                    for layer_past in past:
         | 
| 954 | 
            +
                        reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         | 
| 955 | 
            +
                    return reordered_past
         | 
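A minimal usage sketch for `BertLMHeadModel` as an image-conditioned text decoder follows. It is an assumption-laden illustration, not part of the diff: the module path (`finetune.blip.med`), the config file path, and the random image features are placeholders; in BLIP the image features come from the ViT defined in `vit.py` below. Labels are shifted inside `forward`, and positions set to -100 are ignored by the loss. Note also that `prepare_inputs_for_generation` above forwards `encoder_hidden_states` / `encoder_attention_mask`, which is what allows `generate()` to do image-conditioned decoding.

import torch
from transformers import BertConfig, BertTokenizer
from finetune.blip.med import BertLMHeadModel  # module path assumed for illustration

config = BertConfig.from_json_file("finetune/blip/med_config.json")  # path assumed
decoder = BertLMHeadModel(config)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = tokenizer(["a photo of a cat"], return_tensors="pt")
labels = text.input_ids.masked_fill(text.input_ids == tokenizer.pad_token_id, -100)

image_embeds = torch.randn(1, 197, 768)  # stand-in for ViT output: (batch, patches + [CLS], width)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

out = decoder(
    input_ids=text.input_ids,
    attention_mask=text.attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    labels=labels,
    return_dict=True,
)
out.loss.backward()  # cross-entropy with label smoothing 0.1, as in forward()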
    	
        finetune/blip/med_config.json
    ADDED
    
    | @@ -0,0 +1,22 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "architectures": [
         | 
| 3 | 
            +
                  "BertModel"
         | 
| 4 | 
            +
                ],
         | 
| 5 | 
            +
                "attention_probs_dropout_prob": 0.1,
         | 
| 6 | 
            +
                "hidden_act": "gelu",
         | 
| 7 | 
            +
                "hidden_dropout_prob": 0.1,
         | 
| 8 | 
            +
                "hidden_size": 768,
         | 
| 9 | 
            +
                "initializer_range": 0.02,
         | 
| 10 | 
            +
                "intermediate_size": 3072,
         | 
| 11 | 
            +
                "layer_norm_eps": 1e-12,
         | 
| 12 | 
            +
                "max_position_embeddings": 512,
         | 
| 13 | 
            +
                "model_type": "bert",
         | 
| 14 | 
            +
                "num_attention_heads": 12,
         | 
| 15 | 
            +
                "num_hidden_layers": 12,
         | 
| 16 | 
            +
                "pad_token_id": 0,
         | 
| 17 | 
            +
                "type_vocab_size": 2,
         | 
| 18 | 
            +
                "vocab_size": 30524,
         | 
| 19 | 
            +
                "encoder_width": 768,
         | 
| 20 | 
            +
                "add_cross_attention": true   
         | 
| 21 | 
            +
              }
         | 
| 22 | 
            +
              
         | 
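A brief note on this config (an observation, not part of the diff): it is a standard bert-base layout plus `encoder_width` (the ViT feature dimension consumed by the cross-attention layers) and `add_cross_attention: true` to enable those layers; the `vocab_size` of 30524 leaves room for two extra special tokens on top of the 30522-token bert-base-uncased vocabulary. A config object can be built from it with the usual transformers helper, for example:

from transformers import BertConfig

config = BertConfig.from_json_file("finetune/blip/med_config.json")  # path assumed
print(config.encoder_width, config.add_cross_attention)  # 768 True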
    	
        finetune/blip/vit.py
    ADDED
    
    | @@ -0,0 +1,305 @@ | |
| 1 | 
            +
            '''
         | 
| 2 | 
            +
             * Copyright (c) 2022, salesforce.com, inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             * SPDX-License-Identifier: BSD-3-Clause
         | 
| 5 | 
            +
             * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
         | 
| 6 | 
            +
             * By Junnan Li
         | 
| 7 | 
            +
             * Based on timm code base
         | 
| 8 | 
            +
             * https://github.com/rwightman/pytorch-image-models/tree/master/timm
         | 
| 9 | 
            +
            '''
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import torch.nn as nn
         | 
| 13 | 
            +
            import torch.nn.functional as F
         | 
| 14 | 
            +
            from functools import partial
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from timm.models.vision_transformer import _cfg, PatchEmbed
         | 
| 17 | 
            +
            from timm.models.registry import register_model
         | 
| 18 | 
            +
            from timm.models.layers import trunc_normal_, DropPath
         | 
| 19 | 
            +
            from timm.models.helpers import named_apply, adapt_input_conv
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            class Mlp(nn.Module):
         | 
| 24 | 
            +
                """ MLP as used in Vision Transformer, MLP-Mixer and related networks
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
                def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         | 
| 27 | 
            +
                    super().__init__()
         | 
| 28 | 
            +
                    out_features = out_features or in_features
         | 
| 29 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 30 | 
            +
                    self.fc1 = nn.Linear(in_features, hidden_features)
         | 
| 31 | 
            +
                    self.act = act_layer()
         | 
| 32 | 
            +
                    self.fc2 = nn.Linear(hidden_features, out_features)
         | 
| 33 | 
            +
                    self.drop = nn.Dropout(drop)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def forward(self, x):
         | 
| 36 | 
            +
                    x = self.fc1(x)
         | 
| 37 | 
            +
                    x = self.act(x)
         | 
| 38 | 
            +
                    x = self.drop(x)
         | 
| 39 | 
            +
                    x = self.fc2(x)
         | 
| 40 | 
            +
                    x = self.drop(x)
         | 
| 41 | 
            +
                    return x
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class Attention(nn.Module):
         | 
| 45 | 
            +
                def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
         | 
| 46 | 
            +
                    super().__init__()
         | 
| 47 | 
            +
                    self.num_heads = num_heads
         | 
| 48 | 
            +
                    head_dim = dim // num_heads
         | 
| 49 | 
            +
                    # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         | 
| 50 | 
            +
                    self.scale = qk_scale or head_dim ** -0.5
         | 
| 51 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 52 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 53 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 54 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 55 | 
            +
                    self.attn_gradients = None
         | 
| 56 | 
            +
                    self.attention_map = None
         | 
| 57 | 
            +
                    
         | 
| 58 | 
            +
                def save_attn_gradients(self, attn_gradients):
         | 
| 59 | 
            +
                    self.attn_gradients = attn_gradients
         | 
| 60 | 
            +
                    
         | 
| 61 | 
            +
                def get_attn_gradients(self):
         | 
| 62 | 
            +
                    return self.attn_gradients
         | 
| 63 | 
            +
                
         | 
| 64 | 
            +
                def save_attention_map(self, attention_map):
         | 
| 65 | 
            +
                    self.attention_map = attention_map
         | 
| 66 | 
            +
                    
         | 
| 67 | 
            +
                def get_attention_map(self):
         | 
| 68 | 
            +
                    return self.attention_map
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
                def forward(self, x, register_hook=False):
         | 
| 71 | 
            +
                    B, N, C = x.shape
         | 
| 72 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         | 
| 73 | 
            +
                    q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    attn = (q @ k.transpose(-2, -1)) * self.scale
         | 
| 76 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 77 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 78 | 
            +
                            
         | 
| 79 | 
            +
                    if register_hook:
         | 
| 80 | 
            +
                        self.save_attention_map(attn)
         | 
| 81 | 
            +
                        attn.register_hook(self.save_attn_gradients)        
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 84 | 
            +
                    x = self.proj(x)
         | 
| 85 | 
            +
                    x = self.proj_drop(x)
         | 
| 86 | 
            +
                    return x
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            class Block(nn.Module):
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
         | 
| 92 | 
            +
                             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
         | 
| 93 | 
            +
                    super().__init__()
         | 
| 94 | 
            +
                    self.norm1 = norm_layer(dim)
         | 
| 95 | 
            +
                    self.attn = Attention(
         | 
| 96 | 
            +
                        dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         | 
| 97 | 
            +
                    # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         | 
| 98 | 
            +
                    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 99 | 
            +
                    self.norm2 = norm_layer(dim)
         | 
| 100 | 
            +
                    mlp_hidden_dim = int(dim * mlp_ratio)
         | 
| 101 | 
            +
                    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    if use_grad_checkpointing:
         | 
| 104 | 
            +
                        self.attn = checkpoint_wrapper(self.attn)
         | 
| 105 | 
            +
                        self.mlp = checkpoint_wrapper(self.mlp)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def forward(self, x, register_hook=False):
         | 
| 108 | 
            +
                    x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
         | 
| 109 | 
            +
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
         | 
| 110 | 
            +
                    return x
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                
         | 
| 113 | 
            +
            class VisionTransformer(nn.Module):
         | 
| 114 | 
            +
                """ Vision Transformer
         | 
| 115 | 
            +
    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
         | 
| 116 | 
            +
                    https://arxiv.org/abs/2010.11929
         | 
| 117 | 
            +
                """
         | 
| 118 | 
            +
                def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
         | 
| 119 | 
            +
                             num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
         | 
| 120 | 
            +
                             drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 
         | 
| 121 | 
            +
                             use_grad_checkpointing=False, ckpt_layer=0):
         | 
| 122 | 
            +
                    """
         | 
| 123 | 
            +
                    Args:
         | 
| 124 | 
            +
                        img_size (int, tuple): input image size
         | 
| 125 | 
            +
                        patch_size (int, tuple): patch size
         | 
| 126 | 
            +
                        in_chans (int): number of input channels
         | 
| 127 | 
            +
                        num_classes (int): number of classes for classification head
         | 
| 128 | 
            +
                        embed_dim (int): embedding dimension
         | 
| 129 | 
            +
                        depth (int): depth of transformer
         | 
| 130 | 
            +
                        num_heads (int): number of attention heads
         | 
| 131 | 
            +
                        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
         | 
| 132 | 
            +
                        qkv_bias (bool): enable bias for qkv if True
         | 
| 133 | 
            +
                        qk_scale (float): override default qk scale of head_dim ** -0.5 if set
         | 
| 134 | 
            +
                        representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
         | 
| 135 | 
            +
                        drop_rate (float): dropout rate
         | 
| 136 | 
            +
                        attn_drop_rate (float): attention dropout rate
         | 
| 137 | 
            +
                        drop_path_rate (float): stochastic depth rate
         | 
| 138 | 
            +
                        norm_layer: (nn.Module): normalization layer
         | 
| 139 | 
            +
                    """
         | 
| 140 | 
            +
                    super().__init__()
         | 
| 141 | 
            +
                    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         | 
| 142 | 
            +
                    norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    self.patch_embed = PatchEmbed(
         | 
| 145 | 
            +
                        img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    num_patches = self.patch_embed.num_patches
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         | 
| 150 | 
            +
                    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 151 | 
            +
                    self.pos_drop = nn.Dropout(p=drop_rate)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         | 
| 154 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 155 | 
            +
                        Block(
         | 
| 156 | 
            +
                            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
         | 
| 157 | 
            +
                            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
         | 
| 158 | 
            +
                            use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
                        for i in range(depth)])
         | 
| 161 | 
            +
                    self.norm = norm_layer(embed_dim)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    trunc_normal_(self.pos_embed, std=.02)
         | 
| 164 | 
            +
                    trunc_normal_(self.cls_token, std=.02)
         | 
| 165 | 
            +
                    self.apply(self._init_weights)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                def _init_weights(self, m):
         | 
| 168 | 
            +
                    if isinstance(m, nn.Linear):
         | 
| 169 | 
            +
                        trunc_normal_(m.weight, std=.02)
         | 
| 170 | 
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         | 
| 171 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 172 | 
            +
                    elif isinstance(m, nn.LayerNorm):
         | 
| 173 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 174 | 
            +
                        nn.init.constant_(m.weight, 1.0)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                @torch.jit.ignore
         | 
| 177 | 
            +
                def no_weight_decay(self):
         | 
| 178 | 
            +
                    return {'pos_embed', 'cls_token'}
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def forward(self, x, register_blk=-1):
         | 
| 181 | 
            +
                    B = x.shape[0]
         | 
| 182 | 
            +
                    x = self.patch_embed(x)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
         | 
| 185 | 
            +
                    x = torch.cat((cls_tokens, x), dim=1)
         | 
| 186 | 
            +
              
         | 
| 187 | 
            +
                    x = x + self.pos_embed[:,:x.size(1),:]
         | 
| 188 | 
            +
                    x = self.pos_drop(x)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    for i,blk in enumerate(self.blocks):
         | 
| 191 | 
            +
                        x = blk(x, register_blk==i)
         | 
| 192 | 
            +
                    x = self.norm(x)
         | 
| 193 | 
            +
                    
         | 
| 194 | 
            +
                    return x
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                @torch.jit.ignore()
         | 
| 197 | 
            +
                def load_pretrained(self, checkpoint_path, prefix=''):
         | 
| 198 | 
            +
                    _load_weights(self, checkpoint_path, prefix)
         | 
| 199 | 
            +
                    
         | 
| 200 | 
            +
             | 
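A short usage sketch for this `VisionTransformer` (illustrative only; it assumes `timm` and `fairscale` are installed and that the class is importable from `finetune.blip.vit`): a ViT-B/16 configuration maps a 224x224 image to 196 patch tokens plus one [CLS] token, and this token sequence is what the BERT branch in med.py cross-attends to.

from functools import partial

import torch
import torch.nn as nn

from finetune.blip.vit import VisionTransformer  # module path assumed for illustration

vit = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
x = torch.randn(2, 3, 224, 224)  # dummy image batch
feats = vit(x)
print(feats.shape)  # torch.Size([2, 197, 768]) -> [CLS] + 14*14 patch tokens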
| 201 | 
            +
            @torch.no_grad()
         | 
| 202 | 
            +
            def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
         | 
| 203 | 
            +
                """ Load weights from .npz checkpoints for official Google Brain Flax implementation
         | 
| 204 | 
            +
                """
         | 
| 205 | 
            +
                import numpy as np
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def _n2p(w, t=True):
         | 
| 208 | 
            +
                    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
         | 
| 209 | 
            +
                        w = w.flatten()
         | 
| 210 | 
            +
                    if t:
         | 
| 211 | 
            +
                        if w.ndim == 4:
         | 
| 212 | 
            +
                            w = w.transpose([3, 2, 0, 1])
         | 
| 213 | 
            +
                        elif w.ndim == 3:
         | 
| 214 | 
            +
                            w = w.transpose([2, 0, 1])
         | 
| 215 | 
            +
                        elif w.ndim == 2:
         | 
| 216 | 
            +
                            w = w.transpose([1, 0])
         | 
| 217 | 
            +
                    return torch.from_numpy(w)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                w = np.load(checkpoint_path)
         | 
| 220 | 
            +
                if not prefix and 'opt/target/embedding/kernel' in w:
         | 
| 221 | 
            +
                    prefix = 'opt/target/'
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                if hasattr(model.patch_embed, 'backbone'):
         | 
| 224 | 
            +
                    # hybrid
         | 
| 225 | 
            +
                    backbone = model.patch_embed.backbone
         | 
| 226 | 
            +
                    stem_only = not hasattr(backbone, 'stem')
         | 
| 227 | 
            +
                    stem = backbone if stem_only else backbone.stem
         | 
| 228 | 
            +
                    stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
         | 
| 229 | 
            +
                    stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
         | 
| 230 | 
            +
                    stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
         | 
| 231 | 
            +
                    if not stem_only:
         | 
| 232 | 
            +
                        for i, stage in enumerate(backbone.stages):
         | 
| 233 | 
            +
                            for j, block in enumerate(stage.blocks):
         | 
| 234 | 
            +
                                bp = f'{prefix}block{i + 1}/unit{j + 1}/'
         | 
| 235 | 
            +
                                for r in range(3):
         | 
| 236 | 
            +
                                    getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
         | 
| 237 | 
            +
                                    getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
         | 
| 238 | 
            +
                                    getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
         | 
| 239 | 
            +
                                if block.downsample is not None:
         | 
| 240 | 
            +
                                    block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
         | 
| 241 | 
            +
                                    block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
         | 
| 242 | 
            +
                                    block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
         | 
| 243 | 
            +
                    embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
         | 
| 244 | 
            +
                else:
         | 
| 245 | 
            +
                    embed_conv_w = adapt_input_conv(
         | 
| 246 | 
            +
                        model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
         | 
| 247 | 
            +
                model.patch_embed.proj.weight.copy_(embed_conv_w)
         | 
| 248 | 
            +
                model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
         | 
| 249 | 
            +
                model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
         | 
| 250 | 
            +
                pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
         | 
| 251 | 
            +
                if pos_embed_w.shape != model.pos_embed.shape:
         | 
| 252 | 
            +
                    pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
         | 
| 253 | 
            +
                        pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
         | 
| 254 | 
            +
                model.pos_embed.copy_(pos_embed_w)
         | 
| 255 | 
            +
                model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
         | 
| 256 | 
            +
                model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
         | 
| 257 | 
            +
            #     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         | 
| 258 | 
            +
            #         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         | 
| 259 | 
            +
            #         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
         | 
| 260 | 
            +
            #     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
         | 
| 261 | 
            +
            #         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
         | 
| 262 | 
            +
            #         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
         | 
| 263 | 
            +
                for i, block in enumerate(model.blocks.children()):
         | 
| 264 | 
            +
                    block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         | 
| 265 | 
            +
                    mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
         | 
| 266 | 
            +
                    block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
         | 
| 267 | 
            +
                    block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
         | 
| 268 | 
            +
                    block.attn.qkv.weight.copy_(torch.cat([
         | 
| 269 | 
            +
                        _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
         | 
| 270 | 
            +
                    block.attn.qkv.bias.copy_(torch.cat([
         | 
| 271 | 
            +
                        _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
         | 
| 272 | 
            +
                    block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
         | 
| 273 | 
            +
                    block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
         | 
| 274 | 
            +
                    for r in range(2):
         | 
| 275 | 
            +
                        getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
         | 
| 276 | 
            +
                        getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
         | 
| 277 | 
            +
                    block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
         | 
| 278 | 
            +
                    block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                        
         | 
| 281 | 
            +
            def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):        
         | 
| 282 | 
            +
                # interpolate position embedding
         | 
| 283 | 
            +
                embedding_size = pos_embed_checkpoint.shape[-1]
         | 
| 284 | 
            +
                num_patches = visual_encoder.patch_embed.num_patches
         | 
| 285 | 
            +
                num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
         | 
| 286 | 
            +
                # height (== width) for the checkpoint position embedding
         | 
| 287 | 
            +
                orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
         | 
| 288 | 
            +
                # height (== width) for the new position embedding
         | 
| 289 | 
            +
                new_size = int(num_patches ** 0.5)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                if orig_size!=new_size:
         | 
| 292 | 
            +
                    # class_token and dist_token are kept unchanged
         | 
| 293 | 
            +
                    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
         | 
| 294 | 
            +
                    # only the position tokens are interpolated
         | 
| 295 | 
            +
                    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         | 
| 296 | 
            +
                    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
         | 
| 297 | 
            +
                    pos_tokens = torch.nn.functional.interpolate(
         | 
| 298 | 
            +
                        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
         | 
| 299 | 
            +
                    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
         | 
| 300 | 
            +
                    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
         | 
| 301 | 
            +
                    print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
         | 
| 302 | 
            +
                    
         | 
| 303 | 
            +
                    return new_pos_embed    
         | 
| 304 | 
            +
                else:
         | 
| 305 | 
            +
                    return pos_embed_checkpoint
         | 
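Usage note (not part of the original file): interpolate_pos_embed returns a resized position-embedding tensor rather than modifying the model, so callers are expected to patch the checkpoint before loading it. A minimal sketch of that pattern follows; the visual_encoder attribute and the 'visual_encoder.pos_embed' key are assumptions matching BLIP's layout, not names defined in this file.

def load_with_resized_pos_embed(model, state_dict, key='visual_encoder.pos_embed'):
    # Hypothetical helper: resize the pretrained pos_embed to the model's patch grid
    # before load_state_dict, so a checkpoint trained at another resolution still loads.
    if key in state_dict:
        state_dict[key] = interpolate_pos_embed(state_dict[key], model.visual_encoder)
    model.load_state_dict(state_dict, strict=False)
    return model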
    	
        finetune/clean_captions_and_tags.py
    ADDED
    
    | @@ -0,0 +1,190 @@ | |
| 1 | 
            +
            # This script is licensed under the Apache License 2.0
         | 
| 2 | 
            +
            # (c) 2022 Kohya S. @kohya_ss
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            import glob
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            import re
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from tqdm import tqdm
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
         | 
| 13 | 
            +
            PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
         | 
| 14 | 
            +
            PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
         | 
| 15 | 
            +
            PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # When multiple people are present, remove hair/eye colour tags if more than one is defined
         | 
| 18 | 
            +
            PATTERNS_REMOVE_IN_MULTI = [
         | 
| 19 | 
            +
                PATTERN_HAIR_LENGTH,
         | 
| 20 | 
            +
                PATTERN_HAIR_CUT,
         | 
| 21 | 
            +
                re.compile(r', [\w\-]+ eyes, '),
         | 
| 22 | 
            +
                re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
         | 
| 23 | 
            +
                # remove hairstyle tags if more than one is defined
         | 
| 24 | 
            +
                re.compile(
         | 
| 25 | 
            +
                    r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
         | 
| 26 | 
            +
            ]
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def clean_tags(image_key, tags):
         | 
| 30 | 
            +
              # replace '_' to ' '
         | 
| 31 | 
            +
              tags = tags.replace('^_^', '^@@@^')
         | 
| 32 | 
            +
              tags = tags.replace('_', ' ')
         | 
| 33 | 
            +
              tags = tags.replace('^@@@^', '^_^')
         | 
| 34 | 
            +
             | 
| 35 | 
            +
              # remove rating: deepdanbooru only
         | 
| 36 | 
            +
              tokens = tags.split(", rating")
         | 
| 37 | 
            +
              if len(tokens) == 1:
         | 
| 38 | 
            +
                # this is the WD14 tagger case, so no message is printed
         | 
| 39 | 
            +
                # print("no rating:")
         | 
| 40 | 
            +
                # print(f"{image_key} {tags}")
         | 
| 41 | 
            +
                pass
         | 
| 42 | 
            +
              else:
         | 
| 43 | 
            +
                if len(tokens) > 2:
         | 
| 44 | 
            +
                  print("multiple ratings:")
         | 
| 45 | 
            +
                  print(f"{image_key} {tags}")
         | 
| 46 | 
            +
                tags = tokens[0]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
              tags = ", " + tags.replace(", ", ", , ") + ", "     # カンマ付きで検索をするための身も蓋もない対策
         | 
| 49 | 
            +
              
         | 
| 50 | 
            +
              # if multiple people are present, remove hair colour and similar tags
         | 
| 51 | 
            +
              if 'girls' in tags or 'boys' in tags:
         | 
| 52 | 
            +
                for pat in PATTERNS_REMOVE_IN_MULTI:
         | 
| 53 | 
            +
                  found = pat.findall(tags)
         | 
| 54 | 
            +
                  if len(found) > 1:                        # 二つ以上、タグがある
         | 
| 55 | 
            +
                    tags = pat.sub("", tags)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                # special handling for hair tags
         | 
| 58 | 
            +
                srch_hair_len = PATTERN_HAIR_LENGTH.search(tags)   # hair-length tags are an exception (everyone may share the same hair length), so set them aside
         | 
| 59 | 
            +
                if srch_hair_len:
         | 
| 60 | 
            +
                  org = srch_hair_len.group()
         | 
| 61 | 
            +
                  tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                found = PATTERN_HAIR.findall(tags)
         | 
| 64 | 
            +
                if len(found) > 1:
         | 
| 65 | 
            +
                  tags = PATTERN_HAIR.sub("", tags)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                if srch_hair_len:
         | 
| 68 | 
            +
                  tags = tags.replace(", @@@, ", org)                   # 戻す
         | 
| 69 | 
            +
             | 
| 70 | 
            +
              # remove duplicated tags such as "shirt" alongside "white shirt"
         | 
| 71 | 
            +
              found = PATTERN_WORD.findall(tags)
         | 
| 72 | 
            +
              for word in found:
         | 
| 73 | 
            +
                if re.search(f", ((\w+) )+{word}, ", tags):
         | 
| 74 | 
            +
                  tags = tags.replace(f", {word}, ", "")
         | 
| 75 | 
            +
             | 
| 76 | 
            +
              tags = tags.replace(", , ", ", ")
         | 
| 77 | 
            +
              assert tags.startswith(", ") and tags.endswith(", ")
         | 
| 78 | 
            +
              tags = tags[2:-2]
         | 
| 79 | 
            +
              return tags
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            # searched and replaced in order, from top to bottom
         | 
| 83 | 
            +
            # ('string to search for', 'replacement string')
         | 
| 84 | 
            +
            CAPTION_REPLACEMENTS = [
         | 
| 85 | 
            +
                ('anime anime', 'anime'),
         | 
| 86 | 
            +
                ('young ', ''),
         | 
| 87 | 
            +
                ('anime girl', 'girl'),
         | 
| 88 | 
            +
                ('cartoon female', 'girl'),
         | 
| 89 | 
            +
                ('cartoon lady', 'girl'),
         | 
| 90 | 
            +
                ('cartoon character', 'girl'),      # a or ~s
         | 
| 91 | 
            +
                ('cartoon woman', 'girl'),
         | 
| 92 | 
            +
                ('cartoon women', 'girls'),
         | 
| 93 | 
            +
                ('cartoon girl', 'girl'),
         | 
| 94 | 
            +
                ('anime female', 'girl'),
         | 
| 95 | 
            +
                ('anime lady', 'girl'),
         | 
| 96 | 
            +
                ('anime character', 'girl'),      # a or ~s
         | 
| 97 | 
            +
                ('anime woman', 'girl'),
         | 
| 98 | 
            +
                ('anime women', 'girls'),
         | 
| 99 | 
            +
                ('lady', 'girl'),
         | 
| 100 | 
            +
                ('female', 'girl'),
         | 
| 101 | 
            +
                ('woman', 'girl'),
         | 
| 102 | 
            +
                ('women', 'girls'),
         | 
| 103 | 
            +
                ('people', 'girls'),
         | 
| 104 | 
            +
                ('person', 'girl'),
         | 
| 105 | 
            +
                ('a cartoon figure', 'a figure'),
         | 
| 106 | 
            +
                ('a cartoon image', 'an image'),
         | 
| 107 | 
            +
                ('a cartoon picture', 'a picture'),
         | 
| 108 | 
            +
                ('an anime cartoon image', 'an image'),
         | 
| 109 | 
            +
                ('a cartoon anime drawing', 'a drawing'),
         | 
| 110 | 
            +
                ('a cartoon drawing', 'a drawing'),
         | 
| 111 | 
            +
                ('girl girl', 'girl'),
         | 
| 112 | 
            +
            ]
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def clean_caption(caption):
         | 
| 116 | 
            +
              for rf, rt in CAPTION_REPLACEMENTS:
         | 
| 117 | 
            +
                replaced = True
         | 
| 118 | 
            +
                while replaced:
         | 
| 119 | 
            +
                  bef = caption
         | 
| 120 | 
            +
                  caption = caption.replace(rf, rt)
         | 
| 121 | 
            +
                  replaced = bef != caption
         | 
| 122 | 
            +
              return caption
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            def main(args):
         | 
| 126 | 
            +
              if os.path.exists(args.in_json):
         | 
| 127 | 
            +
                print(f"loading existing metadata: {args.in_json}")
         | 
| 128 | 
            +
                with open(args.in_json, "rt", encoding='utf-8') as f:
         | 
| 129 | 
            +
                  metadata = json.load(f)
         | 
| 130 | 
            +
              else:
         | 
| 131 | 
            +
                print("no metadata / メタデータファイルがありません")
         | 
| 132 | 
            +
                return
         | 
| 133 | 
            +
             | 
| 134 | 
            +
              print("cleaning captions and tags.")
         | 
| 135 | 
            +
              image_keys = list(metadata.keys())
         | 
| 136 | 
            +
              for image_key in tqdm(image_keys):
         | 
| 137 | 
            +
                tags = metadata[image_key].get('tags')
         | 
| 138 | 
            +
                if tags is None:
         | 
| 139 | 
            +
                  print(f"image does not have tags / メタデータにタグがありません: {image_key}")
         | 
| 140 | 
            +
                else:
         | 
| 141 | 
            +
                  org = tags
         | 
| 142 | 
            +
                  tags = clean_tags(image_key, tags)
         | 
| 143 | 
            +
                  metadata[image_key]['tags'] = tags
         | 
| 144 | 
            +
                  if args.debug and org != tags:
         | 
| 145 | 
            +
                    print("FROM: " + org)
         | 
| 146 | 
            +
                    print("TO:   " + tags)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                caption = metadata[image_key].get('caption')
         | 
| 149 | 
            +
                if caption is None:
         | 
| 150 | 
            +
                  print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
         | 
| 151 | 
            +
                else:
         | 
| 152 | 
            +
                  org = caption
         | 
| 153 | 
            +
                  caption = clean_caption(caption)
         | 
| 154 | 
            +
                  metadata[image_key]['caption'] = caption
         | 
| 155 | 
            +
                  if args.debug and org != caption:
         | 
| 156 | 
            +
                    print("FROM: " + org)
         | 
| 157 | 
            +
                    print("TO:   " + caption)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
              # write out the metadata and finish
         | 
| 160 | 
            +
              print(f"writing metadata: {args.out_json}")
         | 
| 161 | 
            +
              with open(args.out_json, "wt", encoding='utf-8') as f:
         | 
| 162 | 
            +
                json.dump(metadata, f, indent=2)
         | 
| 163 | 
            +
              print("done!")
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            def setup_parser() -> argparse.ArgumentParser:
         | 
| 167 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 168 | 
            +
              # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
         | 
| 169 | 
            +
              parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
         | 
| 170 | 
            +
              parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
         | 
| 171 | 
            +
              parser.add_argument("--debug", action="store_true", help="debug mode")
         | 
| 172 | 
            +
             | 
| 173 | 
            +
              return parser
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
            if __name__ == '__main__':
         | 
| 177 | 
            +
              parser = setup_parser()
         | 
| 178 | 
            +
             | 
| 179 | 
            +
              args, unknown = parser.parse_known_args()
         | 
| 180 | 
            +
              if len(unknown) == 1:
         | 
| 181 | 
            +
                print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
         | 
| 182 | 
            +
                print("All captions and tags in the metadata are processed.")
         | 
| 183 | 
            +
                print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
         | 
| 184 | 
            +
                print("メタデータ内のすべてのキャプションとタグが処理されます。")
         | 
| 185 | 
            +
                args.in_json = args.out_json
         | 
| 186 | 
            +
                args.out_json = unknown[0]
         | 
| 187 | 
            +
              elif len(unknown) > 0:
         | 
| 188 | 
            +
                raise ValueError(f"error: unrecognized arguments: {unknown}")
         | 
| 189 | 
            +
             | 
| 190 | 
            +
              main(args)
         | 
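For reference, a hedged illustration of what clean_tags above does to a made-up tag string for a multi-person image (the tags are invented for this example and come from no real dataset):

sample = "2girls, long_hair, blonde hair, black hair, blue eyes, red eyes, white shirt, shirt"
print(clean_tags("example_key", sample))
# -> 2girls, long hair, white shirt
# Underscores become spaces, the shared hair-length tag is kept, the conflicting hair
# and eye colours are dropped because more than one is present, and the bare "shirt"
# duplicate of "white shirt" is removed.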
    	
        finetune/hypernetwork_nai.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
| 1 | 
            +
            # NAI compatible
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class HypernetworkModule(torch.nn.Module):
         | 
| 7 | 
            +
              def __init__(self, dim, multiplier=1.0):
         | 
| 8 | 
            +
                super().__init__()
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                linear1 = torch.nn.Linear(dim, dim * 2)
         | 
| 11 | 
            +
                linear2 = torch.nn.Linear(dim * 2, dim)
         | 
| 12 | 
            +
                linear1.weight.data.normal_(mean=0.0, std=0.01)
         | 
| 13 | 
            +
                linear1.bias.data.zero_()
         | 
| 14 | 
            +
                linear2.weight.data.normal_(mean=0.0, std=0.01)
         | 
| 15 | 
            +
                linear2.bias.data.zero_()
         | 
| 16 | 
            +
                linears = [linear1, linear2]
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                self.linear = torch.nn.Sequential(*linears)
         | 
| 19 | 
            +
                self.multiplier = multiplier
         | 
| 20 | 
            +
             | 
| 21 | 
            +
              def forward(self, x):
         | 
| 22 | 
            +
                return x + self.linear(x) * self.multiplier
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class Hypernetwork(torch.nn.Module):
         | 
| 26 | 
            +
              enable_sizes = [320, 640, 768, 1280]
         | 
| 27 | 
            +
              # return self.modules[Hypernetwork.enable_sizes.index(size)]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
              def __init__(self, multiplier=1.0) -> None:
         | 
| 30 | 
            +
                super().__init__()
         | 
| 31 | 
            +
                self.modules = []
         | 
| 32 | 
            +
                for size in Hypernetwork.enable_sizes:
         | 
| 33 | 
            +
                  self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
         | 
| 34 | 
            +
                  self.register_module(f"{size}_0", self.modules[-1][0])
         | 
| 35 | 
            +
                  self.register_module(f"{size}_1", self.modules[-1][1])
         | 
| 36 | 
            +
             | 
| 37 | 
            +
              def apply_to_stable_diffusion(self, text_encoder, vae, unet):
         | 
| 38 | 
            +
                blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
         | 
| 39 | 
            +
                for block in blocks:
         | 
| 40 | 
            +
                  for subblk in block:
         | 
| 41 | 
            +
                    if 'SpatialTransformer' in str(type(subblk)):
         | 
| 42 | 
            +
                      for tf_block in subblk.transformer_blocks:
         | 
| 43 | 
            +
                        for attn in [tf_block.attn1, tf_block.attn2]:
         | 
| 44 | 
            +
                          size = attn.context_dim
         | 
| 45 | 
            +
                          if size in Hypernetwork.enable_sizes:
         | 
| 46 | 
            +
                            attn.hypernetwork = self
         | 
| 47 | 
            +
                          else:
         | 
| 48 | 
            +
                            attn.hypernetwork = None
         | 
| 49 | 
            +
             | 
| 50 | 
            +
              def apply_to_diffusers(self, text_encoder, vae, unet):
         | 
| 51 | 
            +
                blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
         | 
| 52 | 
            +
                for block in blocks:
         | 
| 53 | 
            +
                  if hasattr(block, 'attentions'):
         | 
| 54 | 
            +
                    for subblk in block.attentions:
         | 
| 55 | 
            +
                      if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):      # 0.6.0 and 0.7~
         | 
| 56 | 
            +
                        for tf_block in subblk.transformer_blocks:
         | 
| 57 | 
            +
                          for attn in [tf_block.attn1, tf_block.attn2]:
         | 
| 58 | 
            +
                            size = attn.to_k.in_features
         | 
| 59 | 
            +
                            if size in Hypernetwork.enable_sizes:
         | 
| 60 | 
            +
                              attn.hypernetwork = self
         | 
| 61 | 
            +
                            else:
         | 
| 62 | 
            +
                              attn.hypernetwork = None
         | 
| 63 | 
            +
                return True       # TODO error checking
         | 
| 64 | 
            +
             | 
| 65 | 
            +
              def forward(self, x, context):
         | 
| 66 | 
            +
                size = context.shape[-1]
         | 
| 67 | 
            +
                assert size in Hypernetwork.enable_sizes
         | 
| 68 | 
            +
                module = self.modules[Hypernetwork.enable_sizes.index(size)]
         | 
| 69 | 
            +
                return module[0].forward(context), module[1].forward(context)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
              def load_from_state_dict(self, state_dict):
         | 
| 72 | 
            +
                # old ver to new ver
         | 
| 73 | 
            +
                changes = {
         | 
| 74 | 
            +
                    'linear1.bias': 'linear.0.bias',
         | 
| 75 | 
            +
                    'linear1.weight': 'linear.0.weight',
         | 
| 76 | 
            +
                    'linear2.bias': 'linear.1.bias',
         | 
| 77 | 
            +
                    'linear2.weight': 'linear.1.weight',
         | 
| 78 | 
            +
                }
         | 
| 79 | 
            +
                for key_from, key_to in changes.items():
         | 
| 80 | 
            +
                  if key_from in state_dict:
         | 
| 81 | 
            +
                    state_dict[key_to] = state_dict[key_from]
         | 
| 82 | 
            +
                    del state_dict[key_from]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                for size, sd in state_dict.items():
         | 
| 85 | 
            +
                  if type(size) == int:
         | 
| 86 | 
            +
                    self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
         | 
| 87 | 
            +
                    self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
         | 
| 88 | 
            +
                return True
         | 
| 89 | 
            +
             | 
| 90 | 
            +
              def get_state_dict(self):
         | 
| 91 | 
            +
                state_dict = {}
         | 
| 92 | 
            +
                for i, size in enumerate(Hypernetwork.enable_sizes):
         | 
| 93 | 
            +
                  sd0 = self.modules[i][0].state_dict()
         | 
| 94 | 
            +
                  sd1 = self.modules[i][1].state_dict()
         | 
| 95 | 
            +
                  state_dict[size] = [sd0, sd1]
         | 
| 96 | 
            +
                return state_dict
         | 
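A short usage sketch for the class above; the checkpoint path is a placeholder, and the .pt file is assumed to follow the NAI hypernetwork layout that load_from_state_dict expects.

import torch

def load_nai_hypernetwork(path, unet, multiplier=1.0):
    # Build an empty hypernetwork, fill it from a NAI-style state dict, then hook it
    # into a Diffusers U-Net so its cross-attention layers can call it at inference time.
    hypernetwork = Hypernetwork(multiplier=multiplier)
    state_dict = torch.load(path, map_location="cpu")
    hypernetwork.load_from_state_dict(state_dict)
    hypernetwork.apply_to_diffusers(None, None, unet)  # text_encoder and vae are unused
    return hypernetwork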
    	
        finetune/make_captions.py
    ADDED
    
    | @@ -0,0 +1,168 @@ | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import glob
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            from tqdm import tqdm
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torchvision import transforms
         | 
| 12 | 
            +
            from torchvision.transforms.functional import InterpolationMode
         | 
| 13 | 
            +
            from blip.blip import blip_decoder
         | 
| 14 | 
            +
            import library.train_util as train_util
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            IMAGE_SIZE = 384
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # a square resize seems questionable, but the original source does it this way
         | 
| 22 | 
            +
            IMAGE_TRANSFORM = transforms.Compose([
         | 
| 23 | 
            +
                transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
         | 
| 24 | 
            +
                transforms.ToTensor(),
         | 
| 25 | 
            +
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
         | 
| 26 | 
            +
            ])
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            # would be nice to share this with the other scripts, but the processing differs slightly
         | 
| 29 | 
            +
            class ImageLoadingTransformDataset(torch.utils.data.Dataset):
         | 
| 30 | 
            +
              def __init__(self, image_paths):
         | 
| 31 | 
            +
                self.images = image_paths
         | 
| 32 | 
            +
             | 
| 33 | 
            +
              def __len__(self):
         | 
| 34 | 
            +
                return len(self.images)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
              def __getitem__(self, idx):
         | 
| 37 | 
            +
                img_path = self.images[idx]
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                try:
         | 
| 40 | 
            +
                  image = Image.open(img_path).convert("RGB")
         | 
| 41 | 
            +
                  # convert to tensor temporarily so dataloader will accept it
         | 
| 42 | 
            +
                  tensor = IMAGE_TRANSFORM(image)
         | 
| 43 | 
            +
                except Exception as e:
         | 
| 44 | 
            +
                  print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
         | 
| 45 | 
            +
                  return None
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                return (tensor, img_path)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def collate_fn_remove_corrupted(batch):
         | 
| 51 | 
            +
              """Collate function that allows to remove corrupted examples in the
         | 
| 52 | 
            +
              dataloader. It expects that the dataloader returns 'None' when that occurs.
         | 
| 53 | 
            +
              The 'None's in the batch are removed.
         | 
| 54 | 
            +
              """
         | 
| 55 | 
            +
              # Filter out all the Nones (corrupted examples)
         | 
| 56 | 
            +
              batch = list(filter(lambda x: x is not None, batch))
         | 
| 57 | 
            +
              return batch
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            def main(args):
         | 
| 61 | 
            +
              # fix the seed for reproducibility
         | 
| 62 | 
            +
              seed = args.seed  # + utils.get_rank()
         | 
| 63 | 
            +
              torch.manual_seed(seed)
         | 
| 64 | 
            +
              np.random.seed(seed)
         | 
| 65 | 
            +
              random.seed(seed)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
              if not os.path.exists("blip"):
         | 
| 68 | 
            +
                args.train_data_dir = os.path.abspath(args.train_data_dir)        # convert to absolute path
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                cwd = os.getcwd()
         | 
| 71 | 
            +
                print('Current Working Directory is: ', cwd)
         | 
| 72 | 
            +
                os.chdir('finetune')
         | 
| 73 | 
            +
             | 
| 74 | 
            +
              print(f"load images from {args.train_data_dir}")
         | 
| 75 | 
            +
              image_paths = train_util.glob_images(args.train_data_dir)
         | 
| 76 | 
            +
              print(f"found {len(image_paths)} images.")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
              print(f"loading BLIP caption: {args.caption_weights}")
         | 
| 79 | 
            +
              model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
         | 
| 80 | 
            +
              model.eval()
         | 
| 81 | 
            +
              model = model.to(DEVICE)
         | 
| 82 | 
            +
              print("BLIP loaded")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
              # run captioning
         | 
| 85 | 
            +
              def run_batch(path_imgs):
         | 
| 86 | 
            +
                imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                with torch.no_grad():
         | 
| 89 | 
            +
                  if args.beam_search:
         | 
| 90 | 
            +
                    captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
         | 
| 91 | 
            +
                                              max_length=args.max_length, min_length=args.min_length)
         | 
| 92 | 
            +
                  else:
         | 
| 93 | 
            +
                    captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                for (image_path, _), caption in zip(path_imgs, captions):
         | 
| 96 | 
            +
                  with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
         | 
| 97 | 
            +
                    f.write(caption + "\n")
         | 
| 98 | 
            +
                    if args.debug:
         | 
| 99 | 
            +
                      print(image_path, caption)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
              # option to read images via a DataLoader to speed up loading
         | 
| 102 | 
            +
              if args.max_data_loader_n_workers is not None:
         | 
| 103 | 
            +
                dataset = ImageLoadingTransformDataset(image_paths)
         | 
| 104 | 
            +
                data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
         | 
| 105 | 
            +
                                                  num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
         | 
| 106 | 
            +
              else:
         | 
| 107 | 
            +
                data = [[(None, ip)] for ip in image_paths]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
              b_imgs = []
         | 
| 110 | 
            +
              for data_entry in tqdm(data, smoothing=0.0):
         | 
| 111 | 
            +
                for data in data_entry:
         | 
| 112 | 
            +
                  if data is None:
         | 
| 113 | 
            +
                    continue
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                  img_tensor, image_path = data
         | 
| 116 | 
            +
                  if img_tensor is None:
         | 
| 117 | 
            +
                    try:
         | 
| 118 | 
            +
                      raw_image = Image.open(image_path)
         | 
| 119 | 
            +
                      if raw_image.mode != 'RGB':
         | 
| 120 | 
            +
                        raw_image = raw_image.convert("RGB")
         | 
| 121 | 
            +
                      img_tensor = IMAGE_TRANSFORM(raw_image)
         | 
| 122 | 
            +
                    except Exception as e:
         | 
| 123 | 
            +
                      print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
         | 
| 124 | 
            +
                      continue
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                  b_imgs.append((image_path, img_tensor))
         | 
| 127 | 
            +
                  if len(b_imgs) >= args.batch_size:
         | 
| 128 | 
            +
                    run_batch(b_imgs)
         | 
| 129 | 
            +
                    b_imgs.clear()
         | 
| 130 | 
            +
              if len(b_imgs) > 0:
         | 
| 131 | 
            +
                run_batch(b_imgs)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
              print("done!")
         | 
| 134 | 
            +
             | 
| 135 | 
            +
             | 
| 136 | 
            +
            def setup_parser() -> argparse.ArgumentParser:
         | 
| 137 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 138 | 
            +
              parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
         | 
| 139 | 
            +
              parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
         | 
| 140 | 
            +
                                  help="BLIP caption weights (model_large_caption.pth) / BLIP captionの��みファイル(model_large_caption.pth)")
         | 
| 141 | 
            +
              parser.add_argument("--caption_extention", type=str, default=None,
         | 
| 142 | 
            +
                                  help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
         | 
| 143 | 
            +
              parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
         | 
| 144 | 
            +
              parser.add_argument("--beam_search", action="store_true",
         | 
| 145 | 
            +
                                  help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
         | 
| 146 | 
            +
              parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
         | 
| 147 | 
            +
              parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
         | 
| 148 | 
            +
                                  help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
         | 
| 149 | 
            +
              parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
         | 
| 150 | 
            +
              parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
         | 
| 151 | 
            +
              parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
         | 
| 152 | 
            +
              parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
         | 
| 153 | 
            +
              parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
         | 
| 154 | 
            +
              parser.add_argument("--debug", action="store_true", help="debug mode")
         | 
| 155 | 
            +
             | 
| 156 | 
            +
              return parser
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            if __name__ == '__main__':
         | 
| 160 | 
            +
              parser = setup_parser()
         | 
| 161 | 
            +
             | 
| 162 | 
            +
              args = parser.parse_args()
         | 
| 163 | 
            +
             | 
| 164 | 
            +
              # restore the value from the misspelled option (kept for backward compatibility)
         | 
| 165 | 
            +
              if args.caption_extention is not None:
         | 
| 166 | 
            +
                args.caption_extension = args.caption_extention
         | 
| 167 | 
            +
             | 
| 168 | 
            +
              main(args)
         | 
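As a usage note, the script writes one sidecar caption file per image, next to the image and sharing its base name. A small hedged helper for reading those files back (the default extension mirrors --caption_extension; paths are placeholders):

import os

def read_caption(image_path, caption_extension=".caption"):
    # The script writes os.path.splitext(image_path)[0] + caption_extension,
    # containing the caption followed by a newline.
    caption_path = os.path.splitext(image_path)[0] + caption_extension
    with open(caption_path, "rt", encoding="utf-8") as f:
        return f.read().strip()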
    	
        finetune/make_captions_by_git.py
    ADDED
    
    | @@ -0,0 +1,151 @@ | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import re
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            from tqdm import tqdm
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from transformers import AutoProcessor, AutoModelForCausalLM
         | 
| 9 | 
            +
            from transformers.generation.utils import GenerationMixin
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import library.train_util as train_util
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            PATTERN_REPLACE = [
         | 
| 17 | 
            +
                re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
         | 
| 18 | 
            +
                re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
         | 
| 19 | 
            +
                re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
         | 
| 20 | 
            +
                re.compile(r'with the number \d+ on (it|\w+ \w+)'),
         | 
| 21 | 
            +
                re.compile(r'with the words "'),
         | 
| 22 | 
            +
                re.compile(r'word \w+ on it'),
         | 
| 23 | 
            +
                re.compile(r'that says the word \w+ on it'),
         | 
| 24 | 
            +
                re.compile('that says\'the word "( on it)?'),
         | 
| 25 | 
            +
            ]
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # remove "with the word xxxx" phrases, which are very frequent false positives
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def remove_words(captions, debug):
         | 
| 31 | 
            +
              removed_caps = []
         | 
| 32 | 
            +
              for caption in captions:
         | 
| 33 | 
            +
                cap = caption
         | 
| 34 | 
            +
                for pat in PATTERN_REPLACE:
         | 
| 35 | 
            +
                  cap = pat.sub("", cap)
         | 
| 36 | 
            +
                if debug and cap != caption:
         | 
| 37 | 
            +
                  print(caption)
         | 
| 38 | 
            +
                  print(cap)
         | 
| 39 | 
            +
                removed_caps.append(cap)
         | 
| 40 | 
            +
              return removed_caps
         | 
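A hedged example of the cleanup performed by remove_words above, using a made-up caption of the kind GIT tends to hallucinate:

cleaned = remove_words(['a girl with a sign that says "hello" on it in a park'], False)
print(cleaned)
# -> ['a girl  in a park']
# The hallucinated sign text is stripped; the leftover double space is kept as-is,
# since the patterns only delete text and do not re-normalise whitespace.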
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def collate_fn_remove_corrupted(batch):
         | 
| 44 | 
            +
              """Collate function that allows to remove corrupted examples in the
         | 
| 45 | 
            +
              dataloader. It expects that the dataloader returns 'None' when that occurs.
         | 
| 46 | 
            +
              The 'None's in the batch are removed.
         | 
| 47 | 
            +
              """
         | 
| 48 | 
            +
              # Filter out all the Nones (corrupted examples)
         | 
| 49 | 
            +
              batch = list(filter(lambda x: x is not None, batch))
         | 
| 50 | 
            +
              return batch
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def main(args):
         | 
| 54 | 
            +
              # patch GIT so it also works with batch sizes larger than 1: for transformers 4.26.0
         | 
| 55 | 
            +
              org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
         | 
| 56 | 
            +
              curr_batch_size = [args.batch_size]         # wrapped in a list so it can be swapped out when the final batch has fewer than batch_size items
         | 
| 57 | 
            +
             | 
| 58 | 
            +
              # input_ids must have as many rows as the batch size; the batch size is not visible from inside the function, so it is passed in from outside
         | 
| 59 | 
            +
              # replacing things higher up the call stack would be far more involved
         | 
| 60 | 
            +
              def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
         | 
| 61 | 
            +
                input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
         | 
| 62 | 
            +
                if input_ids.size()[0] != curr_batch_size[0]:
         | 
| 63 | 
            +
                  input_ids = input_ids.repeat(curr_batch_size[0], 1)
         | 
| 64 | 
            +
                return input_ids
         | 
| 65 | 
            +
              GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
         | 
| 66 | 
            +
             | 
| 67 | 
            +
              print(f"load images from {args.train_data_dir}")
         | 
| 68 | 
            +
              image_paths = train_util.glob_images(args.train_data_dir)
         | 
| 69 | 
            +
              print(f"found {len(image_paths)} images.")
         | 
| 70 | 
            +
             | 
| 71 | 
            +
              # ideally the weights would be downloaded explicitly rather than relying on the cache
         | 
| 72 | 
            +
              print(f"loading GIT: {args.model_id}")
         | 
| 73 | 
            +
              git_processor = AutoProcessor.from_pretrained(args.model_id)
         | 
| 74 | 
            +
              git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
         | 
| 75 | 
            +
              print("GIT loaded")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
              # captioningする
         | 
| 78 | 
            +
              def run_batch(path_imgs):
         | 
| 79 | 
            +
                imgs = [im for _, im in path_imgs]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                curr_batch_size[0] = len(path_imgs)
         | 
| 82 | 
            +
                inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE)           # 画像はpil形式
         | 
| 83 | 
            +
                generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
         | 
| 84 | 
            +
                captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                if args.remove_words:
         | 
| 87 | 
            +
                  captions = remove_words(captions, args.debug)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                for (image_path, _), caption in zip(path_imgs, captions):
         | 
| 90 | 
            +
                  with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
         | 
| 91 | 
            +
                    f.write(caption + "\n")
         | 
| 92 | 
            +
                    if args.debug:
         | 
| 93 | 
            +
                      print(image_path, caption)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
              # 読み込みの高速化のためにDataLoaderを使うオプション
         | 
| 96 | 
            +
              if args.max_data_loader_n_workers is not None:
         | 
| 97 | 
            +
                dataset = train_util.ImageLoadingDataset(image_paths)
         | 
| 98 | 
            +
                data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
         | 
| 99 | 
            +
                                                   num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
         | 
| 100 | 
            +
              else:
         | 
| 101 | 
            +
                data = [[(None, ip)] for ip in image_paths]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
              b_imgs = []
         | 
| 104 | 
            +
              for data_entry in tqdm(data, smoothing=0.0):
         | 
| 105 | 
            +
                for data in data_entry:
         | 
| 106 | 
            +
                  if data is None:
         | 
| 107 | 
            +
                    continue
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                  image, image_path = data
         | 
| 110 | 
            +
                  if image is None:
         | 
| 111 | 
            +
                    try:
         | 
| 112 | 
            +
                      image = Image.open(image_path)
         | 
| 113 | 
            +
                      if image.mode != 'RGB':
         | 
| 114 | 
            +
                        image = image.convert("RGB")
         | 
| 115 | 
            +
                    except Exception as e:
         | 
| 116 | 
            +
                      print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
         | 
| 117 | 
            +
                      continue
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                  b_imgs.append((image_path, image))
         | 
| 120 | 
            +
                  if len(b_imgs) >= args.batch_size:
         | 
| 121 | 
            +
                    run_batch(b_imgs)
         | 
| 122 | 
            +
                    b_imgs.clear()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
              if len(b_imgs) > 0:
         | 
| 125 | 
            +
                run_batch(b_imgs)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
              print("done!")
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            def setup_parser() -> argparse.ArgumentParser:
         | 
| 131 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 132 | 
            +
              parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
         | 
| 133 | 
            +
              parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
         | 
| 134 | 
            +
              parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
         | 
| 135 | 
            +
                                  help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
         | 
| 136 | 
            +
              parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
         | 
| 137 | 
            +
              parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
         | 
| 138 | 
            +
                                  help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
         | 
| 139 | 
            +
              parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
         | 
| 140 | 
            +
              parser.add_argument("--remove_words", action="store_true",
         | 
| 141 | 
            +
                                  help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
         | 
| 142 | 
            +
              parser.add_argument("--debug", action="store_true", help="debug mode")
         | 
| 143 | 
            +
             | 
| 144 | 
            +
              return parser
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            if __name__ == '__main__':
         | 
| 148 | 
            +
              parser = setup_parser()
         | 
| 149 | 
            +
             | 
| 150 | 
            +
              args = parser.parse_args()
         | 
| 151 | 
            +
              main(args)
         | 
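Note on the GenerationMixin patch above: per the script's own comment, transformers 4.26.0 prepares a single-row input_ids tensor for GIT regardless of how many images are passed, so the patch tiles that row to the current batch size. A minimal, self-contained sketch of the same idea (hypothetical helper name, not part of this repository):

import torch

def expand_bos_to_batch(bos_token_id: int, batch_size: int) -> torch.Tensor:
    # what the generation code prepares by default: a single (1, 1) BOS row
    input_ids = torch.full((1, 1), bos_token_id, dtype=torch.long)
    # the patch repeats that row so generate() sees one row per image in the batch
    if input_ids.size(0) != batch_size:
        input_ids = input_ids.repeat(batch_size, 1)
    return input_ids

print(expand_bos_to_batch(bos_token_id=101, batch_size=4).shape)  # torch.Size([4, 1])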
    	
        finetune/merge_captions_to_metadata.py
    ADDED
    
    | @@ -0,0 +1,76 @@ | |
import argparse
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os

def main(args):
  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

  train_data_dir_path = Path(args.train_data_dir)
  image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
  print(f"found {len(image_paths)} images.")

  if args.in_json is None and Path(args.out_json).is_file():
    args.in_json = args.out_json

  if args.in_json is not None:
    print(f"loading existing metadata: {args.in_json}")
    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
    print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
  else:
    print("new metadata will be created / 新しいメタデータファイルが作成されます")
    metadata = {}

  print("merge caption texts to metadata json.")
  for image_path in tqdm(image_paths):
    caption_path = image_path.with_suffix(args.caption_extension)
    if not os.path.exists(caption_path):
      # fall back to caption files named "<image file name><extension>" (e.g. "image.png.caption")
      caption_path = Path(str(image_path) + args.caption_extension)
    caption = caption_path.read_text(encoding='utf-8').strip()

    image_key = str(image_path) if args.full_path else image_path.stem
    if image_key not in metadata:
      metadata[image_key] = {}

    metadata[image_key]['caption'] = caption
    if args.debug:
      print(image_key, caption)

  # write out the metadata and finish
  print(f"writing metadata: {args.out_json}")
  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
  print("done!")


def setup_parser() -> argparse.ArgumentParser:
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("--in_json", type=str,
                      help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
  parser.add_argument("--caption_extention", type=str, default=None,
                      help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
  parser.add_argument("--full_path", action="store_true",
                      help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  parser.add_argument("--recursive", action="store_true",
                      help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
  parser.add_argument("--debug", action="store_true", help="debug mode")

  return parser


if __name__ == '__main__':
  parser = setup_parser()

  args = parser.parse_args()

  # restore the misspelled option name (kept for backward compatibility)
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

  main(args)
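For orientation, the script above writes a flat JSON object keyed by image stem (or by full path when --full_path is given), with one 'caption' field per image. An illustrative sketch of the output shape, with invented file names that are not from the repository:

import json

metadata = {
    "img_0001": {"caption": "a photo of a cat sitting on a chair"},
    "img_0002": {"caption": "a dog running on the beach"},
}
print(json.dumps(metadata, indent=2))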
    	
        finetune/merge_dd_tags_to_metadata.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
import argparse
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os

def main(args):
  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

  train_data_dir_path = Path(args.train_data_dir)
  image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
  print(f"found {len(image_paths)} images.")

  if args.in_json is None and Path(args.out_json).is_file():
    args.in_json = args.out_json

  if args.in_json is not None:
    print(f"loading existing metadata: {args.in_json}")
    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
    print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
  else:
    print("new metadata will be created / 新しいメタデータファイルが作成されます")
    metadata = {}

  print("merge tags to metadata json.")
  for image_path in tqdm(image_paths):
    tags_path = image_path.with_suffix(args.caption_extension)
    if not os.path.exists(tags_path):
      # fall back to tag files named "<image file name><extension>" (e.g. "image.png.txt")
      tags_path = Path(str(image_path) + args.caption_extension)
    tags = tags_path.read_text(encoding='utf-8').strip()

    image_key = str(image_path) if args.full_path else image_path.stem
    if image_key not in metadata:
      metadata[image_key] = {}

    metadata[image_key]['tags'] = tags
    if args.debug:
      print(image_key, tags)

  # write out the metadata and finish
  print(f"writing metadata: {args.out_json}")
  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')

  print("done!")


def setup_parser() -> argparse.ArgumentParser:
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("--in_json", type=str,
                      help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
  parser.add_argument("--full_path", action="store_true",
                      help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  parser.add_argument("--recursive", action="store_true",
                      help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
  parser.add_argument("--caption_extension", type=str, default=".txt",
                      help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
  parser.add_argument("--debug", action="store_true", help="debug mode, print tags")

  return parser


if __name__ == '__main__':
  parser = setup_parser()

  args = parser.parse_args()
  main(args)
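Because this script and merge_captions_to_metadata.py both default --in_json to an existing out_json and each only sets its own field per image, running the caption merge and then the tag merge against the same out_json leaves entries carrying both fields. Roughly, with illustrative values only:

metadata = {
    "img_0001": {
        "caption": "a photo of a cat sitting on a chair",
        "tags": "cat, chair, indoors",
    },
}
assert set(metadata["img_0001"]) == {"caption", "tags"}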
    	
        finetune/prepare_buckets_latents.py
    ADDED
    
    | @@ -0,0 +1,267 @@ | |
import argparse
import os
import json

from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from torchvision import transforms

import library.model_util as model_util
import library.train_util as train_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def collate_fn_remove_corrupted(batch):
  """Collate function that allows removing corrupted examples from the
  dataloader. It expects the dataloader to return 'None' when that occurs.
  The 'None's in the batch are removed.
  """
  # Filter out all the Nones (corrupted examples)
  batch = list(filter(lambda x: x is not None, batch))
  return batch


def get_latents(vae, images, weight_dtype):
  img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
  img_tensors = torch.stack(img_tensors)
  img_tensors = img_tensors.to(DEVICE, weight_dtype)
  with torch.no_grad():
    latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
  return latents


def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
  if is_full_path:
    base_name = os.path.splitext(os.path.basename(image_key))[0]
  else:
    base_name = image_key
  if flip:
    base_name += '_flip'
  return os.path.join(data_dir, base_name)


def main(args):
  # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
  if args.bucket_reso_steps % 8 > 0:
    print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")

  image_paths = train_util.glob_images(args.train_data_dir)
  print(f"found {len(image_paths)} images.")

  if os.path.exists(args.in_json):
    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
      metadata = json.load(f)
  else:
    print(f"no metadata / メタデータファイルがありません: {args.in_json}")
    return

  weight_dtype = torch.float32
  if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
  elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

  vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
  vae.eval()
  vae.to(DEVICE, dtype=weight_dtype)

  # compute the bucket sizes
  max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
  assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

  bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
                                            args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
  if not args.bucket_no_upscale:
    bucket_manager.make_buckets()
  else:
    print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")

  # assign each image to the appropriate bucket and compute its latents
  img_ar_errors = []

  def process_batch(is_last):
    for bucket in bucket_manager.buckets:
      if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
        latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
        assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
            f"latent shape {latents.shape}, {bucket[0][1].shape}"

        for (image_key, _), latent in zip(bucket, latents):
          npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
          np.savez(npz_file_name, latent)

        # flip
        if args.flip_aug:
          latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype)   # without copy() the array cannot be converted to a Tensor

          for (image_key, _), latent in zip(bucket, latents):
            npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
            np.savez(npz_file_name, latent)
        else:
          # remove existing flipped npz
          for image_key, _ in bucket:
            npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
            if os.path.isfile(npz_file_name):
              print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
              os.remove(npz_file_name)

        bucket.clear()

  # option to use a DataLoader to speed up image loading
  if args.max_data_loader_n_workers is not None:
    dataset = train_util.ImageLoadingDataset(image_paths)
    data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
  else:
    data = [[(None, ip)] for ip in image_paths]

  bucket_counts = {}
  for data_entry in tqdm(data, smoothing=0.0):
    if data_entry[0] is None:
      continue

    img_tensor, image_path = data_entry[0]
    if img_tensor is not None:
      image = transforms.functional.to_pil_image(img_tensor)
    else:
      try:
        image = Image.open(image_path)
        if image.mode != 'RGB':
          image = image.convert("RGB")
      except Exception as e:
        print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
        continue

    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
    if image_key not in metadata:
      metadata[image_key] = {}

    # the following part could also be moved into the Dataset for more speed, but that would take a fair amount of work

    reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
    img_ar_errors.append(abs(ar_error))
    bucket_counts[reso] = bucket_counts.get(reso, 0) + 1

    # the resolution recorded in metadata is in latent units, so round down to a multiple of 8
    metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)

    if not args.bucket_no_upscale:
      # when not upscaling, verify that the resized size matches the bucket size in either width or height
      assert resized_size[0] == reso[0] or resized_size[1] == reso[
          1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
      assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
          1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"

    assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
        1], f"internal error resized size is small: {resized_size}, {reso}"

    # if the npz file already exists, check its shape and skip the image if it matches
    if args.skip_existing:
      npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
      if args.flip_aug:
        npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")

      found = True
      for npz_file in npz_files:
        if not os.path.exists(npz_file):
          found = False
          break

        dat = np.load(npz_file)['arr_0']
        if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:     # check the latents shape
          found = False
          break
      if found:
        continue

    # resize and crop the image
    # PIL has no INTER_AREA, so use cv2
    image = np.array(image)
    if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:            # is a resize needed?
      image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

    if resized_size[0] > reso[0]:
      trim_size = resized_size[0] - reso[0]
      image = image[:, trim_size//2:trim_size//2 + reso[0]]

    if resized_size[1] > reso[1]:
      trim_size = resized_size[1] - reso[1]
      image = image[trim_size//2:trim_size//2 + reso[1]]

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

    # # debug
    # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])

    # add to the batch
    bucket_manager.add_image(reso, (image_key, image))

    # decide whether the batch should be inferred, and infer it
    process_batch(False)

  # process the remainder
  process_batch(True)

  bucket_manager.sort()
  for i, reso in enumerate(bucket_manager.resos):
    count = bucket_counts.get(reso, 0)
    if count > 0:
      print(f"bucket {i} {reso}: {count}")
  img_ar_errors = np.array(img_ar_errors)
  print(f"mean ar error: {np.mean(img_ar_errors)}")

  # write out the metadata and finish
  print(f"writing metadata: {args.out_json}")
  with open(args.out_json, "wt", encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
  print("done!")


def setup_parser() -> argparse.ArgumentParser:
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
  parser.add_argument("--v2", action='store_true',
                      help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
  parser.add_argument("--max_resolution", type=str, default="512,512",
                      help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
  parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
  parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
  parser.add_argument("--bucket_reso_steps", type=int, default=64,
                      help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
  parser.add_argument("--bucket_no_upscale", action="store_true",
                      help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
  parser.add_argument("--mixed_precision", type=str, default="no",
                      choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
  parser.add_argument("--full_path", action="store_true",
                      help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  parser.add_argument("--flip_aug", action="store_true",
                      help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
  parser.add_argument("--skip_existing", action="store_true",
                      help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")

  return parser


if __name__ == '__main__':
  parser = setup_parser()

  args = parser.parse_args()
  main(args)
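The bookkeeping above hinges on the VAE's 8x spatial downscale: the resolution written to metadata['train_resolution'] is each bucket side rounded down to a multiple of 8, and with --skip_existing a cached npz is only reused when its latent matches (height // 8, width // 8). A standalone sketch of that arithmetic (not from the repository):

def train_resolution(reso):
    # bucket (width, height) rounded down to multiples of 8, as stored in the metadata
    w, h = reso
    return (w - w % 8, h - h % 8)

def expected_latent_hw(reso):
    # latent shape the skip-existing check compares against: (height // 8, width // 8)
    w, h = reso
    return (h // 8, w // 8)

print(train_resolution((570, 830)))    # (568, 824)
print(expected_latent_hw((576, 832)))  # (104, 72)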
    	
        finetune/tag_images_by_wd14_tagger.py
    ADDED
    
    | @@ -0,0 +1,206 @@ | |
import argparse
import csv
import glob
import os

from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch

import library.train_util as train_util

# from wd14 tagger
IMAGE_SIZE = 448

# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]


def preprocess_image(image):
  image = np.array(image)
  image = image[:, :, ::-1]                         # RGB->BGR

  # pad to square
  size = max(image.shape[0:2])
  pad_x = size - image.shape[1]
  pad_y = size - image.shape[0]
  pad_l = pad_x // 2
  pad_t = pad_y // 2
  image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)

  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
  image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

  image = image.astype(np.float32)
  return image


class ImageLoadingPrepDataset(torch.utils.data.Dataset):
  def __init__(self, image_paths):
    self.images = image_paths

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img_path = self.images[idx]

    try:
      image = Image.open(img_path).convert("RGB")
      image = preprocess_image(image)
      tensor = torch.tensor(image)
    except Exception as e:
      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
      return None

    return (tensor, img_path)


def collate_fn_remove_corrupted(batch):
  """Collate function that allows removing corrupted examples from the
  dataloader. It expects the dataloader to return 'None' when that occurs.
  The 'None's in the batch are removed.
  """
  # Filter out all the Nones (corrupted examples)
  batch = list(filter(lambda x: x is not None, batch))
  return batch


def main(args):
  # using hf_hub_download as-is reportedly causes symlink-related problems, so work around it by specifying the cache directory and force_filename
  # this prints a deprecation warning, but deal with that if the API actually goes away
  # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
  if not os.path.exists(args.model_dir) or args.force_download:
    print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
    for file in FILES:
      hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
    for file in SUB_DIR_FILES:
      hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
          args.model_dir, SUB_DIR), force_download=True, force_filename=file)
  else:
    print("using existing wd14 tagger model")

  # load images
  image_paths = train_util.glob_images(args.train_data_dir)
  print(f"found {len(image_paths)} images.")

  print("loading model and labels")
  model = load_model(args.model_dir)

  # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
  # read the csv by hand to avoid adding another dependency
  with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    l = [row for row in reader]
    header = l[0]             # tag_id,name,category,count
    rows = l[1:]
  assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"

  tags = [row[1] for row in rows[1:] if row[2] == '0']      # category 0 only, i.e. general tags

  # run inference
  def run_batch(path_imgs):
    imgs = np.array([im for _, im in path_imgs])

    probs = model(imgs, training=False)
    probs = probs.numpy()

    for (image_path, _), prob in zip(path_imgs, probs):
      # the first four entries are ratings, so ignore them
      # # First 4 labels are actually ratings: pick one with argmax
      # ratings_names = label_names[:4]
      # rating_index = ratings_names["probs"].argmax()
      # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]

      # Everything else is tags: pick any where prediction confidence > threshold
      tag_text = ""
      for i, p in enumerate(prob[4:]):                # numpy would be nicer, but there are few enough entries that a loop is fine
        if p >= args.thresh and i < len(tags):
          tag_text += ", " + tags[i]

      if len(tag_text) > 0:
        tag_text = tag_text[2:]                   # drop the leading ", "

      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
        f.write(tag_text + '\n')
        if args.debug:
          print(image_path, tag_text)

  # option to use a DataLoader to speed up image loading
  if args.max_data_loader_n_workers is not None:
    dataset = ImageLoadingPrepDataset(image_paths)
    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
  else:
    data = [[(None, ip)] for ip in image_paths]

  b_imgs = []
  for data_entry in tqdm(data, smoothing=0.0):
    for data in data_entry:
      if data is None:
        continue

      image, image_path = data
      if image is not None:
        image = image.detach().numpy()
      else:
        try:
          image = Image.open(image_path)
          if image.mode != 'RGB':
            image = image.convert("RGB")
          image = preprocess_image(image)
        except Exception as e:
          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
          continue
      b_imgs.append((image_path, image))

      if len(b_imgs) >= args.batch_size:
        run_batch(b_imgs)
        b_imgs.clear()
         | 
| 169 | 
            +
             | 
| 170 | 
            +
              if len(b_imgs) > 0:
         | 
| 171 | 
            +
                run_batch(b_imgs)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
              print("done!")
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
            def setup_parser() -> argparse.ArgumentParser:
         | 
| 177 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 178 | 
            +
              parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
         | 
| 179 | 
            +
              parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
         | 
| 180 | 
            +
                                  help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
         | 
| 181 | 
            +
              parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
         | 
| 182 | 
            +
                                  help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
         | 
| 183 | 
            +
              parser.add_argument("--force_download", action='store_true',
         | 
| 184 | 
            +
                                  help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
         | 
| 185 | 
            +
              parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
         | 
| 186 | 
            +
              parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
         | 
| 187 | 
            +
              parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
         | 
| 188 | 
            +
                                  help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
         | 
| 189 | 
            +
              parser.add_argument("--caption_extention", type=str, default=None,
         | 
| 190 | 
            +
                                  help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
         | 
| 191 | 
            +
              parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
         | 
| 192 | 
            +
              parser.add_argument("--debug", action="store_true", help="debug mode")
         | 
| 193 | 
            +
             | 
| 194 | 
            +
              return parser
         | 
| 195 | 
            +
             | 
| 196 | 
            +
             | 
| 197 | 
            +
            if __name__ == '__main__':
         | 
| 198 | 
            +
              parser = setup_parser()
         | 
| 199 | 
            +
             | 
| 200 | 
            +
              args = parser.parse_args()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
              # スペルミスしていたオプションを復元する
         | 
| 203 | 
            +
              if args.caption_extention is not None:
         | 
| 204 | 
            +
                args.caption_extension = args.caption_extention
         | 
| 205 | 
            +
             | 
| 206 | 
            +
              main(args)
         | 
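For reference, the tag-selection step above is just a fixed threshold over the per-label probabilities returned by the model. A minimal standalone sketch of that loop, using made-up tag names and probabilities rather than the real wd14 label set:

# Illustrative only: toy tags/probabilities standing in for the wd14 CSV labels and model output.
tags = ["1girl", "solo", "outdoors"]          # hypothetical label names (category 0 rows of the CSV)
prob = [0.9, 0.7, 0.2, 0.1,                   # first four entries are rating scores and are skipped
        0.82, 0.40, 0.15]                     # remaining entries line up with `tags`
thresh = 0.35

tag_text = ""
for i, p in enumerate(prob[4:]):
    if p >= thresh and i < len(tags):
        tag_text += ", " + tags[i]
tag_text = tag_text[2:] if tag_text else tag_text   # drop the leading ", "

print(tag_text)   # -> "1girl, solo"

With the defaults registered in setup_parser (--thresh 0.35, --caption_extension .txt), each image ends up with a sibling .txt file containing its comma-separated tags.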
    	
        finetune_gui.py
    ADDED
    
@@ -0,0 +1,888 @@
import gradio as gr
import json
import math
import os
import subprocess
import pathlib
import argparse
from library.common_gui import (
    get_folder_path,
    get_file_path,
    get_saveasfile_path,
    save_inference_file,
    gradio_advanced_training,
    run_cmd_advanced_training,
    gradio_training,
    gradio_config,
    gradio_source_model,
    color_aug_changed,
    run_cmd_training,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
    start_tensorboard,
    stop_tensorboard,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'   # 📄

PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'

            +
             | 
| 41 | 
            +
            def save_configuration(
         | 
| 42 | 
            +
                save_as,
         | 
| 43 | 
            +
                file_path,
         | 
| 44 | 
            +
                pretrained_model_name_or_path,
         | 
| 45 | 
            +
                v2,
         | 
| 46 | 
            +
                v_parameterization,
         | 
| 47 | 
            +
                train_dir,
         | 
| 48 | 
            +
                image_folder,
         | 
| 49 | 
            +
                output_dir,
         | 
| 50 | 
            +
                logging_dir,
         | 
| 51 | 
            +
                max_resolution,
         | 
| 52 | 
            +
                min_bucket_reso,
         | 
| 53 | 
            +
                max_bucket_reso,
         | 
| 54 | 
            +
                batch_size,
         | 
| 55 | 
            +
                flip_aug,
         | 
| 56 | 
            +
                caption_metadata_filename,
         | 
| 57 | 
            +
                latent_metadata_filename,
         | 
| 58 | 
            +
                full_path,
         | 
| 59 | 
            +
                learning_rate,
         | 
| 60 | 
            +
                lr_scheduler,
         | 
| 61 | 
            +
                lr_warmup,
         | 
| 62 | 
            +
                dataset_repeats,
         | 
| 63 | 
            +
                train_batch_size,
         | 
| 64 | 
            +
                epoch,
         | 
| 65 | 
            +
                save_every_n_epochs,
         | 
| 66 | 
            +
                mixed_precision,
         | 
| 67 | 
            +
                save_precision,
         | 
| 68 | 
            +
                seed,
         | 
| 69 | 
            +
                num_cpu_threads_per_process,
         | 
| 70 | 
            +
                train_text_encoder,
         | 
| 71 | 
            +
                create_caption,
         | 
| 72 | 
            +
                create_buckets,
         | 
| 73 | 
            +
                save_model_as,
         | 
| 74 | 
            +
                caption_extension,
         | 
| 75 | 
            +
                # use_8bit_adam,
         | 
| 76 | 
            +
                xformers,
         | 
| 77 | 
            +
                clip_skip,
         | 
| 78 | 
            +
                save_state,
         | 
| 79 | 
            +
                resume,
         | 
| 80 | 
            +
                gradient_checkpointing,
         | 
| 81 | 
            +
                gradient_accumulation_steps,
         | 
| 82 | 
            +
                mem_eff_attn,
         | 
| 83 | 
            +
                shuffle_caption,
         | 
| 84 | 
            +
                output_name,
         | 
| 85 | 
            +
                max_token_length,
         | 
| 86 | 
            +
                max_train_epochs,
         | 
| 87 | 
            +
                max_data_loader_n_workers,
         | 
| 88 | 
            +
                full_fp16,
         | 
| 89 | 
            +
                color_aug,
         | 
| 90 | 
            +
                model_list,
         | 
| 91 | 
            +
                cache_latents,
         | 
| 92 | 
            +
                use_latent_files,
         | 
| 93 | 
            +
                keep_tokens,
         | 
| 94 | 
            +
                persistent_data_loader_workers,
         | 
| 95 | 
            +
                bucket_no_upscale,
         | 
| 96 | 
            +
                random_crop,
         | 
| 97 | 
            +
                bucket_reso_steps,
         | 
| 98 | 
            +
                caption_dropout_every_n_epochs,
         | 
| 99 | 
            +
                caption_dropout_rate,
         | 
| 100 | 
            +
                optimizer,
         | 
| 101 | 
            +
                optimizer_args,
         | 
| 102 | 
            +
                noise_offset,
         | 
| 103 | 
            +
                sample_every_n_steps,
         | 
| 104 | 
            +
                sample_every_n_epochs,
         | 
| 105 | 
            +
                sample_sampler,
         | 
| 106 | 
            +
                sample_prompts,
         | 
| 107 | 
            +
                additional_parameters,
         | 
| 108 | 
            +
                vae_batch_size,
         | 
| 109 | 
            +
                min_snr_gamma,
         | 
| 110 | 
            +
            ):
         | 
| 111 | 
            +
                # Get list of function parameters and values
         | 
| 112 | 
            +
                parameters = list(locals().items())
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                original_file_path = file_path
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                save_as_bool = True if save_as.get('label') == 'True' else False
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                if save_as_bool:
         | 
| 119 | 
            +
                    print('Save as...')
         | 
| 120 | 
            +
                    file_path = get_saveasfile_path(file_path)
         | 
| 121 | 
            +
                else:
         | 
| 122 | 
            +
                    print('Save...')
         | 
| 123 | 
            +
                    if file_path == None or file_path == '':
         | 
| 124 | 
            +
                        file_path = get_saveasfile_path(file_path)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                # print(file_path)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                if file_path == None or file_path == '':
         | 
| 129 | 
            +
                    return original_file_path  # In case a file_path was provided and the user decide to cancel the open action
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                # Return the values of the variables as a dictionary
         | 
| 132 | 
            +
                variables = {
         | 
| 133 | 
            +
                    name: value
         | 
| 134 | 
            +
                    for name, value in parameters  # locals().items()
         | 
| 135 | 
            +
                    if name
         | 
| 136 | 
            +
                    not in [
         | 
| 137 | 
            +
                        'file_path',
         | 
| 138 | 
            +
                        'save_as',
         | 
| 139 | 
            +
                    ]
         | 
| 140 | 
            +
                }
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                # Extract the destination directory from the file path
         | 
| 143 | 
            +
                destination_directory = os.path.dirname(file_path)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                # Create the destination directory if it doesn't exist
         | 
| 146 | 
            +
                if not os.path.exists(destination_directory):
         | 
| 147 | 
            +
                    os.makedirs(destination_directory)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                # Save the data to the selected file
         | 
| 150 | 
            +
                with open(file_path, 'w') as file:
         | 
| 151 | 
            +
                    json.dump(variables, file, indent=2)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                return file_path
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
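The `parameters = list(locals().items())` idiom at the top of `save_configuration` above (and of `open_configuration` below) is what lets the GUI serialize its long argument list without repeating every field name. A minimal sketch of the same pattern, with hypothetical field names rather than the real GUI inputs:

import json

def save_settings(file_path, learning_rate, epoch):
    # Capture the arguments as (name, value) pairs before any other locals are created.
    parameters = list(locals().items())
    variables = {name: value for name, value in parameters if name != 'file_path'}
    with open(file_path, 'w') as f:
        json.dump(variables, f, indent=2)

save_settings('cfg.json', learning_rate=1e-6, epoch=2)  # writes {"learning_rate": 1e-06, "epoch": 2}

`open_configuration` then walks the same parameter list in order and returns `my_data.get(key, value)` for each entry, so the loaded JSON values line up with the Gradio components as long as the two function signatures stay in sync.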
def open_configuration(
    ask_for_file,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    create_caption,
    create_buckets,
    save_model_as,
    caption_extension,
    # use_8bit_adam,
    xformers,
    clip_skip,
    save_state,
    resume,
    gradient_checkpointing,
    gradient_accumulation_steps,
    mem_eff_attn,
    shuffle_caption,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    full_fp16,
    color_aug,
    model_list,
    cache_latents,
    use_latent_files,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
    min_snr_gamma,
):
    # Get list of function parameters and values
    parameters = list(locals().items())

    ask_for_file = True if ask_for_file.get('label') == 'True' else False

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path)

    if not file_path == '' and not file_path == None:
        # load variables from JSON file
        with open(file_path, 'r') as f:
            my_data = json.load(f)
            print('Loading config...')
            # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
            my_data = update_my_data(my_data)
    else:
        file_path = original_file_path  # In case a file_path was provided and the user decided to cancel the open action
        my_data = {}

    values = [file_path]
    for key, value in parameters:
        # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
        if not key in ['ask_for_file', 'file_path']:
            values.append(my_data.get(key, value))
    return tuple(values)

            +
            def train_model(
         | 
| 256 | 
            +
                pretrained_model_name_or_path,
         | 
| 257 | 
            +
                v2,
         | 
| 258 | 
            +
                v_parameterization,
         | 
| 259 | 
            +
                train_dir,
         | 
| 260 | 
            +
                image_folder,
         | 
| 261 | 
            +
                output_dir,
         | 
| 262 | 
            +
                logging_dir,
         | 
| 263 | 
            +
                max_resolution,
         | 
| 264 | 
            +
                min_bucket_reso,
         | 
| 265 | 
            +
                max_bucket_reso,
         | 
| 266 | 
            +
                batch_size,
         | 
| 267 | 
            +
                flip_aug,
         | 
| 268 | 
            +
                caption_metadata_filename,
         | 
| 269 | 
            +
                latent_metadata_filename,
         | 
| 270 | 
            +
                full_path,
         | 
| 271 | 
            +
                learning_rate,
         | 
| 272 | 
            +
                lr_scheduler,
         | 
| 273 | 
            +
                lr_warmup,
         | 
| 274 | 
            +
                dataset_repeats,
         | 
| 275 | 
            +
                train_batch_size,
         | 
| 276 | 
            +
                epoch,
         | 
| 277 | 
            +
                save_every_n_epochs,
         | 
| 278 | 
            +
                mixed_precision,
         | 
| 279 | 
            +
                save_precision,
         | 
| 280 | 
            +
                seed,
         | 
| 281 | 
            +
                num_cpu_threads_per_process,
         | 
| 282 | 
            +
                train_text_encoder,
         | 
| 283 | 
            +
                generate_caption_database,
         | 
| 284 | 
            +
                generate_image_buckets,
         | 
| 285 | 
            +
                save_model_as,
         | 
| 286 | 
            +
                caption_extension,
         | 
| 287 | 
            +
                # use_8bit_adam,
         | 
| 288 | 
            +
                xformers,
         | 
| 289 | 
            +
                clip_skip,
         | 
| 290 | 
            +
                save_state,
         | 
| 291 | 
            +
                resume,
         | 
| 292 | 
            +
                gradient_checkpointing,
         | 
| 293 | 
            +
                gradient_accumulation_steps,
         | 
| 294 | 
            +
                mem_eff_attn,
         | 
| 295 | 
            +
                shuffle_caption,
         | 
| 296 | 
            +
                output_name,
         | 
| 297 | 
            +
                max_token_length,
         | 
| 298 | 
            +
                max_train_epochs,
         | 
| 299 | 
            +
                max_data_loader_n_workers,
         | 
| 300 | 
            +
                full_fp16,
         | 
| 301 | 
            +
                color_aug,
         | 
| 302 | 
            +
                model_list,  # Keep this. Yes, it is unused here but required given the common list used
         | 
| 303 | 
            +
                cache_latents,
         | 
| 304 | 
            +
                use_latent_files,
         | 
| 305 | 
            +
                keep_tokens,
         | 
| 306 | 
            +
                persistent_data_loader_workers,
         | 
| 307 | 
            +
                bucket_no_upscale,
         | 
| 308 | 
            +
                random_crop,
         | 
| 309 | 
            +
                bucket_reso_steps,
         | 
| 310 | 
            +
                caption_dropout_every_n_epochs,
         | 
| 311 | 
            +
                caption_dropout_rate,
         | 
| 312 | 
            +
                optimizer,
         | 
| 313 | 
            +
                optimizer_args,
         | 
| 314 | 
            +
                noise_offset,
         | 
| 315 | 
            +
                sample_every_n_steps,
         | 
| 316 | 
            +
                sample_every_n_epochs,
         | 
| 317 | 
            +
                sample_sampler,
         | 
| 318 | 
            +
                sample_prompts,
         | 
| 319 | 
            +
                additional_parameters,
         | 
| 320 | 
            +
                vae_batch_size,
         | 
| 321 | 
            +
                min_snr_gamma,
         | 
| 322 | 
            +
            ):
         | 
| 323 | 
            +
                if check_if_model_exist(output_name, output_dir, save_model_as):
         | 
| 324 | 
            +
                    return
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                # create caption json file
         | 
| 327 | 
            +
                if generate_caption_database:
         | 
| 328 | 
            +
                    if not os.path.exists(train_dir):
         | 
| 329 | 
            +
                        os.mkdir(train_dir)
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py'
         | 
| 332 | 
            +
                    if caption_extension == '':
         | 
| 333 | 
            +
                        run_cmd += f' --caption_extension=".caption"'
         | 
| 334 | 
            +
                    else:
         | 
| 335 | 
            +
                        run_cmd += f' --caption_extension={caption_extension}'
         | 
| 336 | 
            +
                    run_cmd += f' "{image_folder}"'
         | 
| 337 | 
            +
                    run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
         | 
| 338 | 
            +
                    if full_path:
         | 
| 339 | 
            +
                        run_cmd += f' --full_path'
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    print(run_cmd)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    # Run the command
         | 
| 344 | 
            +
                    if os.name == 'posix':
         | 
| 345 | 
            +
                        os.system(run_cmd)
         | 
| 346 | 
            +
                    else:
         | 
| 347 | 
            +
                        subprocess.run(run_cmd)
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                # create images buckets
         | 
| 350 | 
            +
                if generate_image_buckets:
         | 
| 351 | 
            +
                    run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py'
         | 
| 352 | 
            +
                    run_cmd += f' "{image_folder}"'
         | 
| 353 | 
            +
                    run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
         | 
| 354 | 
            +
                    run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
         | 
| 355 | 
            +
                    run_cmd += f' "{pretrained_model_name_or_path}"'
         | 
| 356 | 
            +
                    run_cmd += f' --batch_size={batch_size}'
         | 
| 357 | 
            +
                    run_cmd += f' --max_resolution={max_resolution}'
         | 
| 358 | 
            +
                    run_cmd += f' --min_bucket_reso={min_bucket_reso}'
         | 
| 359 | 
            +
                    run_cmd += f' --max_bucket_reso={max_bucket_reso}'
         | 
| 360 | 
            +
                    run_cmd += f' --mixed_precision={mixed_precision}'
         | 
| 361 | 
            +
                    # if flip_aug:
         | 
| 362 | 
            +
                    #     run_cmd += f' --flip_aug'
         | 
| 363 | 
            +
                    if full_path:
         | 
| 364 | 
            +
                        run_cmd += f' --full_path'
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    print(run_cmd)
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    # Run the command
         | 
| 369 | 
            +
                    if os.name == 'posix':
         | 
| 370 | 
            +
                        os.system(run_cmd)
         | 
| 371 | 
            +
                    else:
         | 
| 372 | 
            +
                        subprocess.run(run_cmd)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                image_num = len(
         | 
| 375 | 
            +
                    [
         | 
| 376 | 
            +
                        f
         | 
| 377 | 
            +
                        for f, lower_f in (
         | 
| 378 | 
            +
                            (file, file.lower()) for file in os.listdir(image_folder)
         | 
| 379 | 
            +
                        )
         | 
| 380 | 
            +
                        if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
         | 
| 381 | 
            +
                    ]
         | 
| 382 | 
            +
                )
         | 
| 383 | 
            +
                print(f'image_num = {image_num}')
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                repeats = int(image_num) * int(dataset_repeats)
         | 
| 386 | 
            +
                print(f'repeats = {str(repeats)}')
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                # calculate max_train_steps
         | 
| 389 | 
            +
                max_train_steps = int(
         | 
| 390 | 
            +
                    math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
         | 
| 391 | 
            +
                )
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                # Divide by two because flip augmentation create two copied of the source images
         | 
| 394 | 
            +
                if flip_aug:
         | 
| 395 | 
            +
                    max_train_steps = int(math.ceil(float(max_train_steps) / 2))
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                print(f'max_train_steps = {max_train_steps}')
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
         | 
| 400 | 
            +
                print(f'lr_warmup_steps = {lr_warmup_steps}')
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
         | 
| 403 | 
            +
                if v2:
         | 
| 404 | 
            +
                    run_cmd += ' --v2'
         | 
| 405 | 
            +
                if v_parameterization:
         | 
| 406 | 
            +
                    run_cmd += ' --v_parameterization'
         | 
| 407 | 
            +
                if train_text_encoder:
         | 
| 408 | 
            +
                    run_cmd += ' --train_text_encoder'
         | 
| 409 | 
            +
                run_cmd += (
         | 
| 410 | 
            +
                    f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
         | 
| 411 | 
            +
                )
         | 
| 412 | 
            +
                if use_latent_files == 'Yes':
         | 
| 413 | 
            +
                    run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
         | 
| 414 | 
            +
                else:
         | 
| 415 | 
            +
                    run_cmd += f' --in_json="{train_dir}/{caption_metadata_filename}"'
         | 
| 416 | 
            +
                run_cmd += f' --train_data_dir="{image_folder}"'
         | 
| 417 | 
            +
                run_cmd += f' --output_dir="{output_dir}"'
         | 
| 418 | 
            +
                if not logging_dir == '':
         | 
| 419 | 
            +
                    run_cmd += f' --logging_dir="{logging_dir}"'
         | 
| 420 | 
            +
                run_cmd += f' --dataset_repeats={dataset_repeats}'
         | 
| 421 | 
            +
                run_cmd += f' --learning_rate={learning_rate}'
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                run_cmd += ' --enable_bucket'
         | 
| 424 | 
            +
                run_cmd += f' --resolution={max_resolution}'
         | 
| 425 | 
            +
                run_cmd += f' --min_bucket_reso={min_bucket_reso}'
         | 
| 426 | 
            +
                run_cmd += f' --max_bucket_reso={max_bucket_reso}'
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                if not save_model_as == 'same as source model':
         | 
| 429 | 
            +
                    run_cmd += f' --save_model_as={save_model_as}'
         | 
| 430 | 
            +
                if int(gradient_accumulation_steps) > 1:
         | 
| 431 | 
            +
                    run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
         | 
| 432 | 
            +
                # if save_state:
         | 
| 433 | 
            +
                #     run_cmd += ' --save_state'
         | 
| 434 | 
            +
                # if not resume == '':
         | 
| 435 | 
            +
                #     run_cmd += f' --resume={resume}'
         | 
| 436 | 
            +
                if not output_name == '':
         | 
| 437 | 
            +
                    run_cmd += f' --output_name="{output_name}"'
         | 
| 438 | 
            +
                if int(max_token_length) > 75:
         | 
| 439 | 
            +
                    run_cmd += f' --max_token_length={max_token_length}'
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                run_cmd += run_cmd_training(
         | 
| 442 | 
            +
                    learning_rate=learning_rate,
         | 
| 443 | 
            +
                    lr_scheduler=lr_scheduler,
         | 
| 444 | 
            +
                    lr_warmup_steps=lr_warmup_steps,
         | 
| 445 | 
            +
                    train_batch_size=train_batch_size,
         | 
| 446 | 
            +
                    max_train_steps=max_train_steps,
         | 
| 447 | 
            +
                    save_every_n_epochs=save_every_n_epochs,
         | 
| 448 | 
            +
                    mixed_precision=mixed_precision,
         | 
| 449 | 
            +
                    save_precision=save_precision,
         | 
| 450 | 
            +
                    seed=seed,
         | 
| 451 | 
            +
                    caption_extension=caption_extension,
         | 
| 452 | 
            +
                    cache_latents=cache_latents,
         | 
| 453 | 
            +
                    optimizer=optimizer,
         | 
| 454 | 
            +
                    optimizer_args=optimizer_args,
         | 
| 455 | 
            +
                )
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                run_cmd += run_cmd_advanced_training(
         | 
| 458 | 
            +
                    max_train_epochs=max_train_epochs,
         | 
| 459 | 
            +
                    max_data_loader_n_workers=max_data_loader_n_workers,
         | 
| 460 | 
            +
                    max_token_length=max_token_length,
         | 
| 461 | 
            +
                    resume=resume,
         | 
| 462 | 
            +
                    save_state=save_state,
         | 
| 463 | 
            +
                    mem_eff_attn=mem_eff_attn,
         | 
| 464 | 
            +
                    clip_skip=clip_skip,
         | 
| 465 | 
            +
                    flip_aug=flip_aug,
         | 
| 466 | 
            +
                    color_aug=color_aug,
         | 
| 467 | 
            +
                    shuffle_caption=shuffle_caption,
         | 
| 468 | 
            +
                    gradient_checkpointing=gradient_checkpointing,
         | 
| 469 | 
            +
                    full_fp16=full_fp16,
         | 
| 470 | 
            +
                    xformers=xformers,
         | 
| 471 | 
            +
                    # use_8bit_adam=use_8bit_adam,
         | 
| 472 | 
            +
                    keep_tokens=keep_tokens,
         | 
| 473 | 
            +
                    persistent_data_loader_workers=persistent_data_loader_workers,
         | 
| 474 | 
            +
                    bucket_no_upscale=bucket_no_upscale,
         | 
| 475 | 
            +
                    random_crop=random_crop,
         | 
| 476 | 
            +
                    bucket_reso_steps=bucket_reso_steps,
         | 
| 477 | 
            +
                    caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
         | 
| 478 | 
            +
                    caption_dropout_rate=caption_dropout_rate,
         | 
| 479 | 
            +
                    noise_offset=noise_offset,
         | 
| 480 | 
            +
                    additional_parameters=additional_parameters,
         | 
| 481 | 
            +
                    vae_batch_size=vae_batch_size,
         | 
| 482 | 
            +
                    min_snr_gamma=min_snr_gamma,
         | 
| 483 | 
            +
                )
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                run_cmd += run_cmd_sample(
         | 
| 486 | 
            +
                    sample_every_n_steps,
         | 
| 487 | 
            +
                    sample_every_n_epochs,
         | 
| 488 | 
            +
                    sample_sampler,
         | 
| 489 | 
            +
                    sample_prompts,
         | 
| 490 | 
            +
                    output_dir,
         | 
| 491 | 
            +
                )
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                print(run_cmd)
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                # Run the command
         | 
| 496 | 
            +
                if os.name == 'posix':
         | 
| 497 | 
            +
                    os.system(run_cmd)
         | 
| 498 | 
            +
                else:
         | 
| 499 | 
            +
                    subprocess.run(run_cmd)
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                # check if output_dir/last is a folder... therefore it is a diffuser model
         | 
| 502 | 
            +
                last_dir = pathlib.Path(f'{output_dir}/{output_name}')
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                if not last_dir.is_dir():
         | 
| 505 | 
            +
                    # Copy inference model for v2 if required
         | 
| 506 | 
            +
                    save_inference_file(output_dir, v2, v_parameterization, output_name)
         | 
| 507 | 
            +
             | 
| 508 | 
            +
             | 
| 509 | 
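To make the step arithmetic in train_model concrete, here is a small worked example with hypothetical numbers (none of these values are GUI defaults):

import math

image_num = 200        # images found in the training folder (hypothetical)
dataset_repeats = 10
train_batch_size = 4
epoch = 2
lr_warmup = 10         # percent, as entered in the GUI
flip_aug = False

repeats = image_num * dataset_repeats                                  # 2000
max_train_steps = int(math.ceil(repeats / train_batch_size * epoch))   # ceil(2000 / 4 * 2) = 1000
if flip_aug:
    # flipped copies double the effective dataset, so halve the step count
    max_train_steps = int(math.ceil(max_train_steps / 2))
lr_warmup_steps = round(lr_warmup * max_train_steps / 100)             # 10% of 1000 = 100

print(max_train_steps, lr_warmup_steps)  # 1000 100

These values are then folded into the accelerate launch command for fine_tune.py via run_cmd_training; the GUI never trains in-process, it only assembles and launches the command shown by print(run_cmd).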
            +
            def remove_doublequote(file_path):
         | 
| 510 | 
            +
                if file_path != None:
         | 
| 511 | 
            +
                    file_path = file_path.replace('"', '')
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                return file_path
         | 
| 514 | 
            +
             | 
| 515 | 
            +
             | 
| 516 | 
            +
            def finetune_tab():
         | 
| 517 | 
            +
                dummy_db_true = gr.Label(value=True, visible=False)
         | 
| 518 | 
            +
                dummy_db_false = gr.Label(value=False, visible=False)
         | 
| 519 | 
            +
                gr.Markdown('Train a custom model using kohya finetune python code...')
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                (
         | 
| 522 | 
            +
                    button_open_config,
         | 
| 523 | 
            +
                    button_save_config,
         | 
| 524 | 
            +
                    button_save_as_config,
         | 
| 525 | 
            +
                    config_file_name,
         | 
| 526 | 
            +
                    button_load_config,
         | 
| 527 | 
            +
                ) = gradio_config()
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                (
         | 
| 530 | 
            +
                    pretrained_model_name_or_path,
         | 
| 531 | 
            +
                    v2,
         | 
| 532 | 
            +
                    v_parameterization,
         | 
| 533 | 
            +
                    save_model_as,
         | 
| 534 | 
            +
                    model_list,
         | 
| 535 | 
            +
                ) = gradio_source_model()
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                with gr.Tab('Folders'):
         | 
| 538 | 
            +
                    with gr.Row():
         | 
| 539 | 
            +
                        train_dir = gr.Textbox(
         | 
| 540 | 
            +
                            label='Training config folder',
         | 
| 541 | 
            +
                            placeholder='folder where the training configuration files will be saved',
         | 
| 542 | 
            +
                        )
         | 
| 543 | 
            +
                        train_dir_folder = gr.Button(
         | 
                folder_symbol, elem_id='open_folder_small'
            )
            train_dir_folder.click(
                get_folder_path,
                outputs=train_dir,
                show_progress=False,
            )

            image_folder = gr.Textbox(
                label='Training Image folder',
                placeholder='folder where the training images are located',
            )
            image_folder_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            image_folder_input_folder.click(
                get_folder_path,
                outputs=image_folder,
                show_progress=False,
            )
        with gr.Row():
            output_dir = gr.Textbox(
                label='Model output folder',
                placeholder='folder where the model will be saved',
            )
            output_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            output_dir_input_folder.click(
                get_folder_path,
                outputs=output_dir,
                show_progress=False,
            )

            logging_dir = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            logging_dir_input_folder.click(
                get_folder_path,
                outputs=logging_dir,
                show_progress=False,
            )
        with gr.Row():
            output_name = gr.Textbox(
                label='Model output name',
                placeholder='Name of the model to output',
                value='last',
                interactive=True,
            )
        train_dir.change(
            remove_doublequote,
            inputs=[train_dir],
            outputs=[train_dir],
        )
        image_folder.change(
            remove_doublequote,
            inputs=[image_folder],
            outputs=[image_folder],
        )
        output_dir.change(
            remove_doublequote,
            inputs=[output_dir],
            outputs=[output_dir],
        )
    with gr.Tab('Dataset preparation'):
        with gr.Row():
            max_resolution = gr.Textbox(
                label='Resolution (width,height)', value='512,512'
            )
            min_bucket_reso = gr.Textbox(
                label='Min bucket resolution', value='256'
            )
            max_bucket_reso = gr.Textbox(
                label='Max bucket resolution', value='1024'
            )
            batch_size = gr.Textbox(label='Batch size', value='1')
        with gr.Row():
            create_caption = gr.Checkbox(
                label='Generate caption metadata', value=True
            )
            create_buckets = gr.Checkbox(
                label='Generate image buckets metadata', value=True
            )
            use_latent_files = gr.Dropdown(
                label='Use latent files',
                choices=[
                    'No',
                    'Yes',
                ],
                value='Yes',
            )
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                caption_metadata_filename = gr.Textbox(
                    label='Caption metadata filename', value='meta_cap.json'
                )
                latent_metadata_filename = gr.Textbox(
                    label='Latent metadata filename', value='meta_lat.json'
                )
                full_path = gr.Checkbox(label='Use full path', value=True)
    with gr.Tab('Training parameters'):
        (
            learning_rate,
            lr_scheduler,
            lr_warmup,
            train_batch_size,
            epoch,
            save_every_n_epochs,
            mixed_precision,
            save_precision,
            num_cpu_threads_per_process,
            seed,
            caption_extension,
            cache_latents,
            optimizer,
            optimizer_args,
        ) = gradio_training(learning_rate_value='1e-5')
        with gr.Row():
            dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
            train_text_encoder = gr.Checkbox(
                label='Train text encoder', value=True
            )
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                gradient_accumulation_steps = gr.Number(
                    label='Gradient accumulate steps', value='1'
                )
            (
                # use_8bit_adam,
                xformers,
                full_fp16,
                gradient_checkpointing,
                shuffle_caption,
                color_aug,
                flip_aug,
                clip_skip,
                mem_eff_attn,
                save_state,
                resume,
                max_token_length,
                max_train_epochs,
                max_data_loader_n_workers,
                keep_tokens,
                persistent_data_loader_workers,
                bucket_no_upscale,
                random_crop,
                bucket_reso_steps,
                caption_dropout_every_n_epochs,
                caption_dropout_rate,
                noise_offset,
                additional_parameters,
                vae_batch_size,
                min_snr_gamma,
            ) = gradio_advanced_training()
            color_aug.change(
                color_aug_changed,
                inputs=[color_aug],
                outputs=[cache_latents],  # Not applicable to fine_tune.py
            )

        (
            sample_every_n_steps,
            sample_every_n_epochs,
            sample_sampler,
            sample_prompts,
        ) = sample_gradio_config()

    button_run = gr.Button('Train model', variant='primary')

    # Setup gradio tensorboard buttons
    button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()

    button_start_tensorboard.click(
        start_tensorboard,
        inputs=logging_dir,
    )

    button_stop_tensorboard.click(
        stop_tensorboard,
        show_progress=False,
    )

    settings_list = [
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        train_dir,
        image_folder,
        output_dir,
        logging_dir,
        max_resolution,
        min_bucket_reso,
        max_bucket_reso,
        batch_size,
        flip_aug,
        caption_metadata_filename,
        latent_metadata_filename,
        full_path,
        learning_rate,
        lr_scheduler,
        lr_warmup,
        dataset_repeats,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        seed,
        num_cpu_threads_per_process,
        train_text_encoder,
        create_caption,
        create_buckets,
        save_model_as,
        caption_extension,
        # use_8bit_adam,
        xformers,
        clip_skip,
        save_state,
        resume,
        gradient_checkpointing,
        gradient_accumulation_steps,
        mem_eff_attn,
        shuffle_caption,
        output_name,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        full_fp16,
        color_aug,
        model_list,
        cache_latents,
        use_latent_files,
        keep_tokens,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        optimizer,
        optimizer_args,
        noise_offset,
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
        vae_batch_size,
        min_snr_gamma,
    ]

    button_run.click(train_model, inputs=settings_list)

    button_open_config.click(
        open_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_save_config.click(
        save_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_save_as_config.click(
        save_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )


def UI(**kwargs):

    css = ''

    if os.path.exists('./style.css'):
        with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('Finetune'):
            finetune_tab()
        with gr.Tab('Utilities'):
            utilities_tab(enable_dreambooth_tab=False)

    # Show the interface
    launch_kwargs = {}
    if not kwargs.get('username', None) == '':
        launch_kwargs['auth'] = (
            kwargs.get('username', None),
            kwargs.get('password', None),
        )
    if kwargs.get('server_port', 0) > 0:
        launch_kwargs['server_port'] = kwargs.get('server_port', 0)
    if kwargs.get('inbrowser', False):
        launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
    print(launch_kwargs)
    interface.launch(**launch_kwargs)


if __name__ == '__main__':
    # torch.cuda.set_per_process_memory_fraction(0.48)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )

    args = parser.parse_args()

    UI(
        username=args.username,
        password=args.password,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
    )
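
The tail of finetune_gui.py above follows the pattern used throughout these GUI modules: widgets are created inside nested gr.Blocks/gr.Tab/gr.Row contexts, collected into a flat settings_list, and wired to callbacks with .click(). A minimal, self-contained sketch of that wiring, with hypothetical component names and a stand-in callback (not code from this repository):

import gradio as gr

# Stand-in for train_model(); it only echoes the collected settings.
def fake_train(train_dir, learning_rate):
    return f'Would train from {train_dir} at lr={learning_rate}'

with gr.Blocks() as demo:
    with gr.Tab('Folders'):
        train_dir = gr.Textbox(label='Training config folder')
    with gr.Tab('Training parameters'):
        learning_rate = gr.Textbox(label='Learning rate', value='1e-5')
    status = gr.Textbox(label='Status')
    button_run = gr.Button('Train model', variant='primary')

    # Same wiring style as button_run.click(train_model, inputs=settings_list) above.
    settings_list = [train_dir, learning_rate]
    button_run.click(fake_train, inputs=settings_list, outputs=status)

if __name__ == '__main__':
    demo.launch()

The real module passes dozens of components in settings_list; the order of that list is the contract between the UI and train_model(), open_configuration(), and save_configuration().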
    	
        gen_img_diffusers.py
    ADDED
    
    The diff for this file is too large to render. See raw diff.
    	
        gui.sh
    ADDED
    
#!/usr/bin/env bash

# Activate the virtual environment
source ./venv/bin/activate

# If the requirements are validated, run the kohya_gui.py script with the command-line arguments
if python tools/validate_requirements.py; then
    python kohya_gui.py "$@"
fi
    	
        kohya_gui.py
    ADDED
    
import gradio as gr
import os
import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from lora_gui import lora_tab


def UI(**kwargs):
    css = ''

    if os.path.exists('./style.css'):
        with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css, title='Kohya_ss GUI')

    with interface:
        with gr.Tab('Dreambooth'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab()
        with gr.Tab('Dreambooth LoRA'):
            lora_tab()
        with gr.Tab('Dreambooth TI'):
            ti_tab()
        with gr.Tab('Finetune'):
            finetune_tab()
        with gr.Tab('Utilities'):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
            )
            gradio_extract_lora_tab()
            gradio_extract_lycoris_locon_tab()
            gradio_merge_lora_tab()
            gradio_resize_lora_tab()

    # Show the interface
    launch_kwargs = {}
    username = kwargs.get('username')
    password = kwargs.get('password')
    server_port = kwargs.get('server_port', 0)
    inbrowser = kwargs.get('inbrowser', False)
    share = kwargs.get('share', False)
    server_name = kwargs.get('listen')

    launch_kwargs['server_name'] = server_name
    if username and password:
        launch_kwargs['auth'] = (username, password)
    if server_port > 0:
        launch_kwargs['server_port'] = server_port
    if inbrowser:
        launch_kwargs['inbrowser'] = inbrowser
    if share:
        launch_kwargs['share'] = share
    interface.launch(**launch_kwargs)


if __name__ == '__main__':
    # torch.cuda.set_per_process_memory_fraction(0.48)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )

    args = parser.parse_args()

    UI(
        username=args.username,
        password=args.password,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        share=args.share,
        listen=args.listen,
    )
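
Because UI() in kohya_gui.py accepts the same options that its argparse block parses, the GUI can also be started programmatically. A minimal sketch, assuming it is run from the repository root so kohya_gui.py and the tab modules it imports are on the path:

# Hypothetical programmatic launch; values mirror the argparse defaults above.
from kohya_gui import UI

UI(
    username='',            # empty credentials -> no auth prompt
    password='',
    inbrowser=True,
    server_port=7860,       # 0 would let Gradio fall back to its default port
    share=False,
    listen='127.0.0.1',
)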
    	
        kohya_ss_colab.ipynb
    ADDED
    
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/panguin6010/kohya_ss_google_colab/blob/master/kohya_ss_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MvroZ9rJ1iqN"
      },
      "source": [
        "# Kohya SS WebUI Colab Setup\n",
        "\n",
        "This Colab workbook sets up a Kohya SS instance on Colab and provides a link to access the Kohya WebUI on Gradio Live. Kohya SS is a Python library that provides Stable Diffusion-based models for image, text, and audio generation tasks. This Colab workbook provides a convenient way for users to run Kohya SS without needing to install anything on their local machine.\n",
        "\n",
        "This workbook was inspired by the work of [Spaceginner](https://github.com/Spaceginner)'s original Colab workbook and the [Kohya SS project](https://github.com/bmaltais/kohya_ss) by [bmaltais](https://github.com/bmaltais). The Colab workbook was coded by [panguin6010](https://github.com/panguin6010) \n",
        "\n",
        "\n",
        "## Tutorials\n",
        "\n",
        "Before running this code, make sure you are familiar with using Colab workbooks and have a basic understanding of Kohya SS and its usage. You can find tutorials for these online. If you encounter any issues or have suggestions for improvement, feel free to contribute to the project.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DrAnm1um5vjh"
      },
      "source": [
        "\n",
        "\n",
        "\n",
        "---\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vmoRnFQEqOeY",
        "outputId": "09876c9a-d043-4881-d92f-6ed54313c390"
      },
      "outputs": [],
      "source": [
        "#@markdown #Step 1: Mounting Google Drive\n",
        "\n",
        "#@markdown The first step in setting up Kohya SS on Colab is to mount your Google Drive to the Colab notebook. This allows you to save and access files from your Google Drive in the Colab notebook.\n",
        "\n",
        "#@markdown To mount your Google Drive, run the following code block, which mounts your Google Drive to the /content/gdrive directory in the Colab notebook.\n",
        "\n",
        "\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mvQwnr4354BM"
      },
      "source": [
        "\n",
        "\n",
        "---\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "7ca7f6f727da46ac9a1149e69c16c81f",
            "77e5e07552b641cf9c368fb3939cb1d1",
            "235e01b92646444387ebd31ab945358e"
          ]
        },
        "id": "jnhm7ycMrLWb",
        "outputId": "63ba39ed-90c6-4b2d-f03e-61775587b083"
      },
      "outputs": [],
      "source": [
        "#@markdown #Kohya SS WebUI Installation\n",
        "\n",
        "#@markdown Now that your Google Drive is linked, we need to install the Kohya SS WebUI.\n",
        "\n",
        "#@markdown The code clones the [Kohya SS Google Colab](\"https://github.com/panguin6010/kohya_ss_google_colab\") repository and creates the necessary directories for Kohya SS to run. It then resets the git repository and pulls the latest changes. Finally, it displays a success message.\n",
        "\n",
        "#@markdown Note: If Google Drive is not connected, the code will use Colab storage instead.\n",
        "\n",
        "#@title\n",
        "# Import necessary libraries\n",
        "from IPython.display import clear_output\n",
        "from IPython.utils import capture\n",
        "from subprocess import getoutput\n",
        "import ipywidgets as widgets\n",
        "import sys\n",
        "import fileinput\n",
        "import os\n",
        "import time\n",
        "\n",
        "# WebUI Installation\n",
        "\n",
        "# Check if Google Drive is connected\n",
        "if not os.path.exists(\"/content/gdrive/MyDrive/\"):\n",
        "    print(\"Gdrive not connected, using colab storage ...\")\n",
        "    time.sleep(4)\n",
        "    !mkdir -p /content/gdrive/MyDrive/\n",
        "\n",
        "# Clone the repository and create necessary directories\n",
        "with capture.capture_output() as cap:\n",
        "    def inf(msg, style, wdth):\n",
        "        inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth))\n",
        "        display(inf)\n",
        "\n",
        "    %mkdir -p /content/gdrive/MyDrive/sd\n",
        "    %cd /content/gdrive/MyDrive/sd\n",
        "    !git clone https://github.com/panguin6010/kohya_ss_google_colab kohya_ss_colab\n",
        "    !mkdir -p /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface\n",
        "    !ln -s /content/gdrive/MyDrive/sd/kohya_ss_colab/cache/huggingface /root/.cache/\n",
        "\n",
        "# Reset the git repository and pull the latest changes\n",
        "with capture.capture_output() as cap:\n",
        "    %cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n",
        "    !git reset --hard\n",
        "    time.sleep(1)\n",
        "\n",
        "print(\"Updating the repository...\")\n",
        "!git pull\n",
        "\n",
        "# Clear the output and display the success message\n",
        "clear_output()\n",
        "inf(\"✓ Done\", \"success\", \"50px\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8SrMhmFz7Lt4"
      },
      "source": [
        "---"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "54e929bcb37e4997a696d0becdecfd84",
            "43fbca3abb04401296967f819680f94f",
            "6d87b2c916394932b1a53382fe3cdb4e"
          ]
        },
        "id": "yjvkHRlDtDmV",
        "outputId": "06e1e873-b1ed-4403-c9a4-19ac1caa961b"
      },
      "outputs": [],
      "source": [
        "#@markdown #Requirements Installation\n",
        "\n",
        "#@markdown Now that we have downloaded the Kohya SS WebUI, we need to install the necessary requirements.\n",
        "\n",
        "# Print the status message\n",
        "print(\"Installing requirements...\")\n",
        "\n",
        "# Change the working directory to the project folder\n",
        "%cd /content/gdrive/MyDrive/sd/kohya_ss_colab/\n",
        "\n",
        "# Install the requirements\n",
        "with capture.capture_output() as cap:\n",
        "    # Uncomment the following line if you need to install specific versions of torch and torchvision\n",
        "    # !pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116\n",
        "    \n",
        "    # Install the requirements from the requirements.txt file\n",
        "    !pip install -r requirements.txt\n",
        "\n",
        "# Clear the output to keep the notebook clean\n",
        "clear_output()\n",
        "\n",
        "# Print the success message\n",
        "inf(\"✓ Done\", \"success\", \"50px\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FLDvlHm1tYud"
      },
      "source": [
        "\n",
        "---\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IzS3hvuTtTqW",
        "outputId": "9e629e1f-c8eb-43a2-9639-2583937ba81a"
      },
      "outputs": [],
      "source": [
        "#@markdown # Start Kohya ss WebUI\n",
        "\n",
        "User = \"\" #@param {type:\"string\"}\n",
        "Password = \"\" #@param {type:\"string\"}\n",
        "\n",
        "#@markdown - Adding a username and password is not necessary but it will improve the security of your Kohya instance.\n",
        "#@markdown ______\n",
        "#@markdown # Please click the link that concludes with ```gradio.live``` to access your instance\n",
        "# Encourage users to contribute improvements\n",
        "print(\"Please feel free to make any changes or improvements you think would enhance this setup. Your input and contributions are greatly appreciated!\")\n",
        "# Check if the user has provided a username and password\n",
        "if User and Password:\n",
        "    # Run the Kohya GUI with the provided credentials\n",
        "    !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --username $User --password $Password --share \n",
        "else:\n",
        "    # Run the Kohya GUI without credentials\n",
        "    !python /content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py --share \n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "authorship_tag": "ABX9TyOZmOjfS55zOBmbTmRNOf3b",
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "235e01b92646444387ebd31ab945358e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ButtonStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ButtonStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "button_color": null,
            "font_weight": ""
          }
        },
        "43fbca3abb04401296967f819680f94f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
| 286 | 
            +
                        "_model_name": "LayoutModel",
         | 
| 287 | 
            +
                        "_view_count": null,
         | 
| 288 | 
            +
                        "_view_module": "@jupyter-widgets/base",
         | 
| 289 | 
            +
                        "_view_module_version": "1.2.0",
         | 
| 290 | 
            +
                        "_view_name": "LayoutView",
         | 
| 291 | 
            +
                        "align_content": null,
         | 
| 292 | 
            +
                        "align_items": null,
         | 
| 293 | 
            +
                        "align_self": null,
         | 
| 294 | 
            +
                        "border": null,
         | 
| 295 | 
            +
                        "bottom": null,
         | 
| 296 | 
            +
                        "display": null,
         | 
| 297 | 
            +
                        "flex": null,
         | 
| 298 | 
            +
                        "flex_flow": null,
         | 
| 299 | 
            +
                        "grid_area": null,
         | 
| 300 | 
            +
                        "grid_auto_columns": null,
         | 
| 301 | 
            +
                        "grid_auto_flow": null,
         | 
| 302 | 
            +
                        "grid_auto_rows": null,
         | 
| 303 | 
            +
                        "grid_column": null,
         | 
| 304 | 
            +
                        "grid_gap": null,
         | 
| 305 | 
            +
                        "grid_row": null,
         | 
| 306 | 
            +
                        "grid_template_areas": null,
         | 
| 307 | 
            +
                        "grid_template_columns": null,
         | 
| 308 | 
            +
                        "grid_template_rows": null,
         | 
| 309 | 
            +
                        "height": null,
         | 
| 310 | 
            +
                        "justify_content": null,
         | 
| 311 | 
            +
                        "justify_items": null,
         | 
| 312 | 
            +
                        "left": null,
         | 
| 313 | 
            +
                        "margin": null,
         | 
| 314 | 
            +
                        "max_height": null,
         | 
| 315 | 
            +
                        "max_width": null,
         | 
| 316 | 
            +
                        "min_height": null,
         | 
| 317 | 
            +
                        "min_width": "50px",
         | 
| 318 | 
            +
                        "object_fit": null,
         | 
| 319 | 
            +
                        "object_position": null,
         | 
| 320 | 
            +
                        "order": null,
         | 
| 321 | 
            +
                        "overflow": null,
         | 
| 322 | 
            +
                        "overflow_x": null,
         | 
| 323 | 
            +
                        "overflow_y": null,
         | 
| 324 | 
            +
                        "padding": null,
         | 
| 325 | 
            +
                        "right": null,
         | 
| 326 | 
            +
                        "top": null,
         | 
| 327 | 
            +
                        "visibility": null,
         | 
| 328 | 
            +
                        "width": null
         | 
| 329 | 
            +
                      }
         | 
| 330 | 
            +
                    },
         | 
| 331 | 
            +
                    "54e929bcb37e4997a696d0becdecfd84": {
         | 
| 332 | 
            +
                      "model_module": "@jupyter-widgets/controls",
         | 
| 333 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 334 | 
            +
                      "model_name": "ButtonModel",
         | 
| 335 | 
            +
                      "state": {
         | 
| 336 | 
            +
                        "_dom_classes": [],
         | 
| 337 | 
            +
                        "_model_module": "@jupyter-widgets/controls",
         | 
| 338 | 
            +
                        "_model_module_version": "1.5.0",
         | 
| 339 | 
            +
                        "_model_name": "ButtonModel",
         | 
| 340 | 
            +
                        "_view_count": null,
         | 
| 341 | 
            +
                        "_view_module": "@jupyter-widgets/controls",
         | 
| 342 | 
            +
                        "_view_module_version": "1.5.0",
         | 
| 343 | 
            +
                        "_view_name": "ButtonView",
         | 
| 344 | 
            +
                        "button_style": "success",
         | 
| 345 | 
            +
                        "description": "✓ Done",
         | 
| 346 | 
            +
                        "disabled": true,
         | 
| 347 | 
            +
                        "icon": "",
         | 
| 348 | 
            +
                        "layout": "IPY_MODEL_43fbca3abb04401296967f819680f94f",
         | 
| 349 | 
            +
                        "style": "IPY_MODEL_6d87b2c916394932b1a53382fe3cdb4e",
         | 
| 350 | 
            +
                        "tooltip": ""
         | 
| 351 | 
            +
                      }
         | 
| 352 | 
            +
                    },
         | 
| 353 | 
            +
                    "6d87b2c916394932b1a53382fe3cdb4e": {
         | 
| 354 | 
            +
                      "model_module": "@jupyter-widgets/controls",
         | 
| 355 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 356 | 
            +
                      "model_name": "ButtonStyleModel",
         | 
| 357 | 
            +
                      "state": {
         | 
| 358 | 
            +
                        "_model_module": "@jupyter-widgets/controls",
         | 
| 359 | 
            +
                        "_model_module_version": "1.5.0",
         | 
| 360 | 
            +
                        "_model_name": "ButtonStyleModel",
         | 
| 361 | 
            +
                        "_view_count": null,
         | 
| 362 | 
            +
                        "_view_module": "@jupyter-widgets/base",
         | 
| 363 | 
            +
                        "_view_module_version": "1.2.0",
         | 
| 364 | 
            +
                        "_view_name": "StyleView",
         | 
| 365 | 
            +
                        "button_color": null,
         | 
| 366 | 
            +
                        "font_weight": ""
         | 
| 367 | 
            +
                      }
         | 
| 368 | 
            +
                    },
         | 
| 369 | 
            +
                    "77e5e07552b641cf9c368fb3939cb1d1": {
         | 
| 370 | 
            +
                      "model_module": "@jupyter-widgets/base",
         | 
| 371 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 372 | 
            +
                      "model_name": "LayoutModel",
         | 
| 373 | 
            +
                      "state": {
         | 
| 374 | 
            +
                        "_model_module": "@jupyter-widgets/base",
         | 
| 375 | 
            +
                        "_model_module_version": "1.2.0",
         | 
| 376 | 
            +
                        "_model_name": "LayoutModel",
         | 
| 377 | 
            +
                        "_view_count": null,
         | 
| 378 | 
            +
                        "_view_module": "@jupyter-widgets/base",
         | 
| 379 | 
            +
                        "_view_module_version": "1.2.0",
         | 
| 380 | 
            +
                        "_view_name": "LayoutView",
         | 
| 381 | 
            +
                        "align_content": null,
         | 
| 382 | 
            +
                        "align_items": null,
         | 
| 383 | 
            +
                        "align_self": null,
         | 
| 384 | 
            +
                        "border": null,
         | 
| 385 | 
            +
                        "bottom": null,
         | 
| 386 | 
            +
                        "display": null,
         | 
| 387 | 
            +
                        "flex": null,
         | 
| 388 | 
            +
                        "flex_flow": null,
         | 
| 389 | 
            +
                        "grid_area": null,
         | 
| 390 | 
            +
                        "grid_auto_columns": null,
         | 
| 391 | 
            +
                        "grid_auto_flow": null,
         | 
| 392 | 
            +
                        "grid_auto_rows": null,
         | 
| 393 | 
            +
                        "grid_column": null,
         | 
| 394 | 
            +
                        "grid_gap": null,
         | 
| 395 | 
            +
                        "grid_row": null,
         | 
| 396 | 
            +
                        "grid_template_areas": null,
         | 
| 397 | 
            +
                        "grid_template_columns": null,
         | 
| 398 | 
            +
                        "grid_template_rows": null,
         | 
| 399 | 
            +
                        "height": null,
         | 
| 400 | 
            +
                        "justify_content": null,
         | 
| 401 | 
            +
                        "justify_items": null,
         | 
| 402 | 
            +
                        "left": null,
         | 
| 403 | 
            +
                        "margin": null,
         | 
| 404 | 
            +
                        "max_height": null,
         | 
| 405 | 
            +
                        "max_width": null,
         | 
| 406 | 
            +
                        "min_height": null,
         | 
| 407 | 
            +
                        "min_width": "50px",
         | 
| 408 | 
            +
                        "object_fit": null,
         | 
| 409 | 
            +
                        "object_position": null,
         | 
| 410 | 
            +
                        "order": null,
         | 
| 411 | 
            +
                        "overflow": null,
         | 
| 412 | 
            +
                        "overflow_x": null,
         | 
| 413 | 
            +
                        "overflow_y": null,
         | 
| 414 | 
            +
                        "padding": null,
         | 
| 415 | 
            +
                        "right": null,
         | 
| 416 | 
            +
                        "top": null,
         | 
| 417 | 
            +
                        "visibility": null,
         | 
| 418 | 
            +
                        "width": null
         | 
| 419 | 
            +
                      }
         | 
| 420 | 
            +
                    },
         | 
| 421 | 
            +
                    "7ca7f6f727da46ac9a1149e69c16c81f": {
         | 
| 422 | 
            +
                      "model_module": "@jupyter-widgets/controls",
         | 
| 423 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 424 | 
            +
                      "model_name": "ButtonModel",
         | 
| 425 | 
            +
                      "state": {
         | 
| 426 | 
            +
                        "_dom_classes": [],
         | 
| 427 | 
            +
                        "_model_module": "@jupyter-widgets/controls",
         | 
| 428 | 
            +
                        "_model_module_version": "1.5.0",
         | 
| 429 | 
            +
                        "_model_name": "ButtonModel",
         | 
| 430 | 
            +
                        "_view_count": null,
         | 
| 431 | 
            +
                        "_view_module": "@jupyter-widgets/controls",
         | 
| 432 | 
            +
                        "_view_module_version": "1.5.0",
         | 
| 433 | 
            +
                        "_view_name": "ButtonView",
         | 
| 434 | 
            +
                        "button_style": "success",
         | 
| 435 | 
            +
                        "description": "✓ Done",
         | 
| 436 | 
            +
                        "disabled": true,
         | 
| 437 | 
            +
                        "icon": "",
         | 
| 438 | 
            +
                        "layout": "IPY_MODEL_77e5e07552b641cf9c368fb3939cb1d1",
         | 
| 439 | 
            +
                        "style": "IPY_MODEL_235e01b92646444387ebd31ab945358e",
         | 
| 440 | 
            +
                        "tooltip": ""
         | 
| 441 | 
            +
                      }
         | 
| 442 | 
            +
                    }
         | 
| 443 | 
            +
                  }
         | 
| 444 | 
            +
                }
         | 
| 445 | 
            +
              },
         | 
| 446 | 
            +
              "nbformat": 4,
         | 
| 447 | 
            +
              "nbformat_minor": 0
         | 
| 448 | 
            +
            }
         | 
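The launch cell above only passes the credential flags when both User and Password are filled in. As a minimal sketch of that same logic in plain Python (using subprocess instead of the notebook's `!` shell magic; the helper name launch_kohya_gui is hypothetical, while the script path and the --username/--password/--share flags are the ones used in the cell):

import subprocess

def launch_kohya_gui(user='', password='',
                     script='/content/gdrive/MyDrive/sd/kohya_ss_colab/kohya_gui.py'):
    # Always request a public gradio.live link, as the notebook cell does
    cmd = ['python', script, '--share']
    if user and password:
        # Pass credentials only when both fields are provided
        cmd += ['--username', user, '--password', password]
    subprocess.run(cmd)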
    	
        library/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        library/basic_caption_gui.py
    ADDED
    
    | @@ -0,0 +1,140 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            from .common_gui import get_folder_path, add_pre_postfix, find_replace
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def caption_images(
         | 
| 9 | 
            +
                caption_text,
         | 
| 10 | 
            +
                images_dir,
         | 
| 11 | 
            +
                overwrite,
         | 
| 12 | 
            +
                caption_ext,
         | 
| 13 | 
            +
                prefix,
         | 
| 14 | 
            +
                postfix,
         | 
| 15 | 
            +
                find_text,
         | 
| 16 | 
            +
                replace_text,
         | 
| 17 | 
            +
            ):
         | 
| 18 | 
            +
                # Check for images_dir
         | 
| 19 | 
            +
                if not images_dir:
         | 
| 20 | 
            +
                    msgbox('Image folder is missing...')
         | 
| 21 | 
            +
                    return
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                if not caption_ext:
         | 
| 24 | 
            +
                    msgbox('Please provide an extension for the caption files.')
         | 
| 25 | 
            +
                    return
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                if caption_text:
         | 
| 28 | 
            +
                    print(f'Captioning files in {images_dir} with {caption_text}...')
         | 
| 29 | 
            +
                    run_cmd = f'python "tools/caption.py"'
         | 
| 30 | 
            +
                    run_cmd += f' --caption_text="{caption_text}"'
         | 
| 31 | 
            +
                    if overwrite:
         | 
| 32 | 
            +
                        run_cmd += f' --overwrite'
         | 
| 33 | 
            +
                    if caption_ext:
         | 
| 34 | 
            +
                        run_cmd += f' --caption_file_ext="{caption_ext}"'
         | 
| 35 | 
            +
                    run_cmd += f' "{images_dir}"'
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    print(run_cmd)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    # Run the command
         | 
| 40 | 
            +
                    if os.name == 'posix':
         | 
| 41 | 
            +
                        os.system(run_cmd)
         | 
| 42 | 
            +
                    else:
         | 
| 43 | 
            +
                        subprocess.run(run_cmd)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if overwrite:
         | 
| 46 | 
            +
                    if prefix or postfix:
         | 
| 47 | 
            +
                        # Add prefix and postfix
         | 
| 48 | 
            +
                        add_pre_postfix(
         | 
| 49 | 
            +
                            folder=images_dir,
         | 
| 50 | 
            +
                            caption_file_ext=caption_ext,
         | 
| 51 | 
            +
                            prefix=prefix,
         | 
| 52 | 
            +
                            postfix=postfix,
         | 
| 53 | 
            +
                        )
         | 
| 54 | 
            +
                    if find_text:
         | 
| 55 | 
            +
                        find_replace(
         | 
| 56 | 
            +
                            folder_path=images_dir,
         | 
| 57 | 
            +
                            caption_file_ext=caption_ext,
         | 
| 58 | 
            +
                            search_text=find_text,
         | 
| 59 | 
            +
                            replace_text=replace_text,
         | 
| 60 | 
            +
                        )
         | 
| 61 | 
            +
                else:
         | 
| 62 | 
            +
                    if prefix or postfix:
         | 
| 63 | 
            +
                        msgbox(
         | 
| 64 | 
            +
                            'Could not modify caption files with the requested change because the "Overwrite existing captions in folder" option is not selected...'
         | 
| 65 | 
            +
                        )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                print('...captioning done')
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            # Gradio UI
         | 
| 71 | 
            +
            def gradio_basic_caption_gui_tab():
         | 
| 72 | 
            +
                with gr.Tab('Basic Captioning'):
         | 
| 73 | 
            +
                    gr.Markdown(
         | 
| 74 | 
            +
                        'This utility creates simple caption files for each image in a folder.'
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
                    with gr.Row():
         | 
| 77 | 
            +
                        images_dir = gr.Textbox(
         | 
| 78 | 
            +
                            label='Image folder to caption',
         | 
| 79 | 
            +
                            placeholder='Directory containing the images to caption',
         | 
| 80 | 
            +
                            interactive=True,
         | 
| 81 | 
            +
                        )
         | 
| 82 | 
            +
                        folder_button = gr.Button('📂', elem_id='open_folder_small')
         | 
| 83 | 
            +
                        folder_button.click(
         | 
| 84 | 
            +
                            get_folder_path,
         | 
| 85 | 
            +
                            outputs=images_dir,
         | 
| 86 | 
            +
                            show_progress=False,
         | 
| 87 | 
            +
                        )
         | 
| 88 | 
            +
                        caption_ext = gr.Textbox(
         | 
| 89 | 
            +
                            label='Caption file extension',
         | 
| 90 | 
            +
                            placeholder='Extension for caption file. eg: .caption, .txt',
         | 
| 91 | 
            +
                            value='.txt',
         | 
| 92 | 
            +
                            interactive=True,
         | 
| 93 | 
            +
                        )
         | 
| 94 | 
            +
                        overwrite = gr.Checkbox(
         | 
| 95 | 
            +
                            label='Overwrite existing captions in folder',
         | 
| 96 | 
            +
                            interactive=True,
         | 
| 97 | 
            +
                            value=False,
         | 
| 98 | 
            +
                        )
         | 
| 99 | 
            +
                    with gr.Row():
         | 
| 100 | 
            +
                        prefix = gr.Textbox(
         | 
| 101 | 
            +
                            label='Prefix to add to caption',
         | 
| 102 | 
            +
                            placeholder='(Optional)',
         | 
| 103 | 
            +
                            interactive=True,
         | 
| 104 | 
            +
                        )
         | 
| 105 | 
            +
                        caption_text = gr.Textbox(
         | 
| 106 | 
            +
                            label='Caption text',
         | 
| 107 | 
            +
                            placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
         | 
| 108 | 
            +
                            interactive=True,
         | 
| 109 | 
            +
                        )
         | 
| 110 | 
            +
                        postfix = gr.Textbox(
         | 
| 111 | 
            +
                            label='Postfix to add to caption',
         | 
| 112 | 
            +
                            placeholder='(Optional)',
         | 
| 113 | 
            +
                            interactive=True,
         | 
| 114 | 
            +
                        )
         | 
| 115 | 
            +
                    with gr.Row():
         | 
| 116 | 
            +
                        find_text = gr.Textbox(
         | 
| 117 | 
            +
                            label='Find text',
         | 
| 118 | 
            +
                            placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
         | 
| 119 | 
            +
                            interactive=True,
         | 
| 120 | 
            +
                        )
         | 
| 121 | 
            +
                        replace_text = gr.Textbox(
         | 
| 122 | 
            +
                            label='Replacement text',
         | 
| 123 | 
            +
                            placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
         | 
| 124 | 
            +
                            interactive=True,
         | 
| 125 | 
            +
                        )
         | 
| 126 | 
            +
                        caption_button = gr.Button('Caption images')
         | 
| 127 | 
            +
                        caption_button.click(
         | 
| 128 | 
            +
                            caption_images,
         | 
| 129 | 
            +
                            inputs=[
         | 
| 130 | 
            +
                                caption_text,
         | 
| 131 | 
            +
                                images_dir,
         | 
| 132 | 
            +
                                overwrite,
         | 
| 133 | 
            +
                                caption_ext,
         | 
| 134 | 
            +
                                prefix,
         | 
| 135 | 
            +
                                postfix,
         | 
| 136 | 
            +
                                find_text,
         | 
| 137 | 
            +
                                replace_text,
         | 
| 138 | 
            +
                            ],
         | 
| 139 | 
            +
                            show_progress=False,
         | 
| 140 | 
            +
                        )
         | 
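For orientation, caption_images() above can also be driven outside the Gradio tab; a rough usage sketch follows, in which every argument value is illustrative rather than a project default:

from library.basic_caption_gui import caption_images

caption_images(
    caption_text='a photo of sks person',   # text written into each caption file by tools/caption.py
    images_dir='/path/to/images',           # placeholder path
    overwrite=True,                         # required for prefix/postfix and find/replace to apply
    caption_ext='.txt',
    prefix='',
    postfix=', high quality',
    find_text='',
    replace_text='',
)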
    	
        library/blip_caption_gui.py
    ADDED
    
    | @@ -0,0 +1,149 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from .common_gui import get_folder_path, add_pre_postfix
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def caption_images(
         | 
| 11 | 
            +
                train_data_dir,
         | 
| 12 | 
            +
                caption_file_ext,
         | 
| 13 | 
            +
                batch_size,
         | 
| 14 | 
            +
                num_beams,
         | 
| 15 | 
            +
                top_p,
         | 
| 16 | 
            +
                max_length,
         | 
| 17 | 
            +
                min_length,
         | 
| 18 | 
            +
                beam_search,
         | 
| 19 | 
            +
                prefix,
         | 
| 20 | 
            +
                postfix,
         | 
| 21 | 
            +
            ):
         | 
| 22 | 
            +
                # Check for caption_text_input
         | 
| 23 | 
            +
                # if caption_text_input == "":
         | 
| 24 | 
            +
                #     msgbox("Caption text is missing...")
         | 
| 25 | 
            +
                #     return
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                # Check for images_dir_input
         | 
| 28 | 
            +
                if train_data_dir == '':
         | 
| 29 | 
            +
                    msgbox('Image folder is missing...')
         | 
| 30 | 
            +
                    return
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                if caption_file_ext == '':
         | 
| 33 | 
            +
                    msgbox('Please provide an extension for the caption files.')
         | 
| 34 | 
            +
                    return
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                print(f'Captioning files in {train_data_dir}...')
         | 
| 37 | 
            +
                run_cmd = f'{PYTHON} "finetune/make_captions.py"'
         | 
| 38 | 
            +
                run_cmd += f' --batch_size="{int(batch_size)}"'
         | 
| 39 | 
            +
                run_cmd += f' --num_beams="{int(num_beams)}"'
         | 
| 40 | 
            +
                run_cmd += f' --top_p="{top_p}"'
         | 
| 41 | 
            +
                run_cmd += f' --max_length="{int(max_length)}"'
         | 
| 42 | 
            +
                run_cmd += f' --min_length="{int(min_length)}"'
         | 
| 43 | 
            +
                if beam_search:
         | 
| 44 | 
            +
                    run_cmd += f' --beam_search'
         | 
| 45 | 
            +
                if caption_file_ext != '':
         | 
| 46 | 
            +
                    run_cmd += f' --caption_extension="{caption_file_ext}"'
         | 
| 47 | 
            +
                run_cmd += f' "{train_data_dir}"'
         | 
| 48 | 
            +
                run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                print(run_cmd)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                # Run the command
         | 
| 53 | 
            +
                if os.name == 'posix':
         | 
| 54 | 
            +
                    os.system(run_cmd)
         | 
| 55 | 
            +
                else:
         | 
| 56 | 
            +
                    subprocess.run(run_cmd)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                # Add prefix and postfix
         | 
| 59 | 
            +
                add_pre_postfix(
         | 
| 60 | 
            +
                    folder=train_data_dir,
         | 
| 61 | 
            +
                    caption_file_ext=caption_file_ext,
         | 
| 62 | 
            +
                    prefix=prefix,
         | 
| 63 | 
            +
                    postfix=postfix,
         | 
| 64 | 
            +
                )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                print('...captioning done')
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            ###
         | 
| 70 | 
            +
            # Gradio UI
         | 
| 71 | 
            +
            ###
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def gradio_blip_caption_gui_tab():
         | 
| 75 | 
            +
                with gr.Tab('BLIP Captioning'):
         | 
| 76 | 
            +
                    gr.Markdown(
         | 
| 77 | 
            +
                        'This utility will use BLIP to create caption files for each image in a folder.'
         | 
| 78 | 
            +
                    )
         | 
| 79 | 
            +
                    with gr.Row():
         | 
| 80 | 
            +
                        train_data_dir = gr.Textbox(
         | 
| 81 | 
            +
                            label='Image folder to caption',
         | 
| 82 | 
            +
                            placeholder='Directory containing the images to caption',
         | 
| 83 | 
            +
                            interactive=True,
         | 
| 84 | 
            +
                        )
         | 
| 85 | 
            +
                        button_train_data_dir_input = gr.Button(
         | 
| 86 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 87 | 
            +
                        )
         | 
| 88 | 
            +
                        button_train_data_dir_input.click(
         | 
| 89 | 
            +
                            get_folder_path,
         | 
| 90 | 
            +
                            outputs=train_data_dir,
         | 
| 91 | 
            +
                            show_progress=False,
         | 
| 92 | 
            +
                        )
         | 
| 93 | 
            +
                    with gr.Row():
         | 
| 94 | 
            +
                        caption_file_ext = gr.Textbox(
         | 
| 95 | 
            +
                            label='Caption file extension',
         | 
| 96 | 
            +
                            placeholder='Extension for caption file. eg: .caption, .txt',
         | 
| 97 | 
            +
                            value='.txt',
         | 
| 98 | 
            +
                            interactive=True,
         | 
| 99 | 
            +
                        )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                        prefix = gr.Textbox(
         | 
| 102 | 
            +
                            label='Prefix to add to BLIP caption',
         | 
| 103 | 
            +
                            placeholder='(Optional)',
         | 
| 104 | 
            +
                            interactive=True,
         | 
| 105 | 
            +
                        )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                        postfix = gr.Textbox(
         | 
| 108 | 
            +
                            label='Postfix to add to BLIP caption',
         | 
| 109 | 
            +
                            placeholder='(Optional)',
         | 
| 110 | 
            +
                            interactive=True,
         | 
| 111 | 
            +
                        )
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                        batch_size = gr.Number(
         | 
| 114 | 
            +
                            value=1, label='Batch size', interactive=True
         | 
| 115 | 
            +
                        )
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    with gr.Row():
         | 
| 118 | 
            +
                        beam_search = gr.Checkbox(
         | 
| 119 | 
            +
                            label='Use beam search', interactive=True, value=True
         | 
| 120 | 
            +
                        )
         | 
| 121 | 
            +
                        num_beams = gr.Number(
         | 
| 122 | 
            +
                            value=1, label='Number of beams', interactive=True
         | 
| 123 | 
            +
                        )
         | 
| 124 | 
            +
                        top_p = gr.Number(value=0.9, label='Top p', interactive=True)
         | 
| 125 | 
            +
                        max_length = gr.Number(
         | 
| 126 | 
            +
                            value=75, label='Max length', interactive=True
         | 
| 127 | 
            +
                        )
         | 
| 128 | 
            +
                        min_length = gr.Number(
         | 
| 129 | 
            +
                            value=5, label='Min length', interactive=True
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    caption_button = gr.Button('Caption images')
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    caption_button.click(
         | 
| 135 | 
            +
                        caption_images,
         | 
| 136 | 
            +
                        inputs=[
         | 
| 137 | 
            +
                            train_data_dir,
         | 
| 138 | 
            +
                            caption_file_ext,
         | 
| 139 | 
            +
                            batch_size,
         | 
| 140 | 
            +
                            num_beams,
         | 
| 141 | 
            +
                            top_p,
         | 
| 142 | 
            +
                            max_length,
         | 
| 143 | 
            +
                            min_length,
         | 
| 144 | 
            +
                            beam_search,
         | 
| 145 | 
            +
                            prefix,
         | 
| 146 | 
            +
                            postfix,
         | 
| 147 | 
            +
                        ],
         | 
| 148 | 
            +
                        show_progress=False,
         | 
| 149 | 
            +
                    )
         | 
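As an illustration only, the command string that caption_images() above assembles with the GUI's default values looks roughly like this (the training-data path is a placeholder):

PYTHON = 'python3'  # or ./venv/Scripts/python.exe on Windows, as defined above
run_cmd = (
    f'{PYTHON} "finetune/make_captions.py"'
    ' --batch_size="1" --num_beams="1" --top_p="0.9"'
    ' --max_length="75" --min_length="5" --beam_search'
    ' --caption_extension=".txt" "/path/to/train_data"'
    ' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
)
print(run_cmd)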
    	
        library/common_gui.py
    ADDED
    
    | @@ -0,0 +1,978 @@ | |
| 1 | 
            +
            from tkinter import filedialog, Tk
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import gradio as gr
         | 
| 5 | 
            +
            import easygui
         | 
| 6 | 
            +
            import shutil
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            folder_symbol = '\U0001f4c2'  # 📂
         | 
| 9 | 
            +
            refresh_symbol = '\U0001f504'  # 🔄
         | 
| 10 | 
            +
            save_style_symbol = '\U0001f4be'  # 💾
         | 
| 11 | 
            +
            document_symbol = '\U0001F4C4'   # 📄
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # define a list of substrings to search for v2 base models
         | 
| 14 | 
            +
            V2_BASE_MODELS = [
         | 
| 15 | 
            +
                'stabilityai/stable-diffusion-2-1-base',
         | 
| 16 | 
            +
                'stabilityai/stable-diffusion-2-base',
         | 
| 17 | 
            +
            ]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # define a list of substrings to search for v_parameterization models
         | 
| 20 | 
            +
            V_PARAMETERIZATION_MODELS = [
         | 
| 21 | 
            +
                'stabilityai/stable-diffusion-2-1',
         | 
| 22 | 
            +
                'stabilityai/stable-diffusion-2',
         | 
| 23 | 
            +
            ]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # define a list of substrings to search for v1.x models
         | 
| 26 | 
            +
            V1_MODELS = [
         | 
| 27 | 
            +
                'CompVis/stable-diffusion-v1-4',
         | 
| 28 | 
            +
                'runwayml/stable-diffusion-v1-5',
         | 
| 29 | 
            +
            ]
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            # define a list of substrings to search for
         | 
| 32 | 
            +
            ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def check_if_model_exist(output_name, output_dir, save_model_as):
         | 
| 38 | 
            +
                if save_model_as in ['diffusers', 'diffusers_safetendors']:
         | 
| 39 | 
            +
                    ckpt_folder = os.path.join(output_dir, output_name)
         | 
| 40 | 
            +
                    if os.path.isdir(ckpt_folder):
         | 
| 41 | 
            +
                        msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
         | 
| 42 | 
            +
                        if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
         | 
| 43 | 
            +
                            print(
         | 
| 44 | 
            +
                                'Aborting training due to existing model with same name...'
         | 
| 45 | 
            +
                            )
         | 
| 46 | 
            +
                            return True
         | 
| 47 | 
            +
                elif save_model_as in ['ckpt', 'safetensors']:
         | 
| 48 | 
            +
                    ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
         | 
| 49 | 
            +
                    if os.path.isfile(ckpt_file):
         | 
| 50 | 
            +
                        msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
         | 
| 51 | 
            +
                        if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
         | 
| 52 | 
            +
                            print(
         | 
| 53 | 
            +
                                'Aborting training due to existing model with same name...'
         | 
| 54 | 
            +
                            )
         | 
| 55 | 
            +
                            return True
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    print(
         | 
| 58 | 
            +
                        'Can\'t verify if an existing model exists when "save model as" is set to "same as source model", continuing to train model...'
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
                    return False
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                return False
         | 
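A brief, hypothetical example of how check_if_model_exist() above is meant to be used as a pre-training guard (the output name and directory are placeholders):

from library.common_gui import check_if_model_exist

# Returns True only when a model with the same name exists and the user declines to overwrite it
if check_if_model_exist('my_model', '/path/to/outputs', 'safetensors'):
    print('Aborting: an output file with this name already exists.')
else:
    print('No conflict detected, training can proceed.')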
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def update_my_data(my_data):
         | 
| 66 | 
            +
                # Update the optimizer based on the use_8bit_adam flag
         | 
| 67 | 
            +
                use_8bit_adam = my_data.get('use_8bit_adam', False)
         | 
| 68 | 
            +
                my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
         | 
| 71 | 
            +
                model_list = my_data.get('model_list', [])
         | 
| 72 | 
            +
                pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
         | 
| 73 | 
            +
                if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
         | 
| 74 | 
            +
                    my_data['model_list'] = 'custom'
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                # Convert epoch and save_every_n_epochs values to int if they are strings
         | 
| 77 | 
            +
                for key in ['epoch', 'save_every_n_epochs']:
         | 
| 78 | 
            +
                    value = my_data.get(key, -1)
         | 
| 79 | 
            +
                    if isinstance(value, str) and value.isdigit():
         | 
| 80 | 
            +
                        my_data[key] = int(value)
         | 
| 81 | 
            +
                    elif not value:
         | 
| 82 | 
            +
                        my_data[key] = -1
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                # Update LoRA_type if it is set to LoCon
         | 
| 85 | 
            +
                if my_data.get('LoRA_type', 'Standard') == 'LoCon':
         | 
| 86 | 
            +
                    my_data['LoRA_type'] = 'LyCORIS/LoCon'
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                # Update model save choices due to changes for LoRA and TI training
         | 
| 89 | 
            +
                if (
         | 
| 90 | 
            +
                    (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
         | 
| 91 | 
            +
                    and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
         | 
| 92 | 
            +
                ):
         | 
| 93 | 
            +
                    message = (
         | 
| 94 | 
            +
                        'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
         | 
| 95 | 
            +
                    )
         | 
| 96 | 
            +
                    if my_data.get('LoRA_type'):
         | 
| 97 | 
            +
                        print(message.format('LoRA'))
         | 
| 98 | 
            +
                    if my_data.get('num_vectors_per_token'):
         | 
| 99 | 
            +
                        print(message.format('TI'))
         | 
| 100 | 
            +
                    my_data['save_model_as'] = 'safetensors'
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                return my_data
         | 
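To make the migrations in update_my_data() above concrete, here is a small, made-up config dictionary and the fields the function would rewrite (values are illustrative only):

from library.common_gui import update_my_data

old_config = {
    'use_8bit_adam': True,            # -> optimizer defaults to 'AdamW8bit'
    'epoch': '10',                    # digit string -> int 10
    'save_every_n_epochs': '',        # empty value -> -1
    'LoRA_type': 'LoCon',             # renamed to 'LyCORIS/LoCon'
    'save_model_as': 'diffusers',     # forced to 'safetensors' for LoRA configs
}
# 'model_list' is absent, so it is also reset to 'custom'
print(update_my_data(old_config))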
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            def get_dir_and_file(file_path):
         | 
| 106 | 
            +
                dir_path, file_name = os.path.split(file_path)
         | 
| 107 | 
            +
                return (dir_path, file_name)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            # def has_ext_files(directory, extension):
         | 
| 111 | 
            +
            #     # Iterate through all the files in the directory
         | 
| 112 | 
            +
            #     for file in os.listdir(directory):
         | 
| 113 | 
            +
            #         # If the file name ends with extension, return True
         | 
| 114 | 
            +
            #         if file.endswith(extension):
         | 
| 115 | 
            +
            #             return True
         | 
| 116 | 
            +
            #     # If no extension files were found, return False
         | 
| 117 | 
            +
            #     return False
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def get_file_path(
         | 
| 121 | 
            +
                file_path='', default_extension='.json', extension_name='Config files'
         | 
| 122 | 
            +
            ):
         | 
| 123 | 
            +
                if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
         | 
| 124 | 
            +
                    current_file_path = file_path
         | 
| 125 | 
            +
                    # print(f'current file path: {current_file_path}')
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    initial_dir, initial_file = get_dir_and_file(file_path)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # Create a hidden Tkinter root window
         | 
| 130 | 
            +
                    root = Tk()
         | 
| 131 | 
            +
                    root.wm_attributes('-topmost', 1)
         | 
| 132 | 
            +
                    root.withdraw()
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # Show the open file dialog and get the selected file path
         | 
| 135 | 
            +
                    file_path = filedialog.askopenfilename(
         | 
| 136 | 
            +
                        filetypes=(
         | 
| 137 | 
            +
                            (extension_name, f'*{default_extension}'),
         | 
| 138 | 
            +
                            ('All files', '*.*'),
         | 
| 139 | 
            +
                        ),
         | 
| 140 | 
            +
                        defaultextension=default_extension,
         | 
| 141 | 
            +
                        initialfile=initial_file,
         | 
| 142 | 
            +
                        initialdir=initial_dir,
         | 
| 143 | 
            +
                    )
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # Destroy the hidden root window
         | 
| 146 | 
            +
                    root.destroy()
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    # If no file is selected, use the current file path
         | 
| 149 | 
            +
                    if not file_path:
         | 
| 150 | 
            +
                        file_path = current_file_path
         | 
| 151 | 
            +
                    current_file_path = file_path
         | 
| 152 | 
            +
                    # print(f'current file path: {current_file_path}')
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                return file_path
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
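# Note: this and the other dialog helpers below only open a Tkinter dialog when none of the
# environment variables in FILE_ENV_EXCLUSION are set; otherwise the incoming path is
# returned untouched, which appears intended to keep these helpers usable where no display
# is available.
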
def get_any_file_path(file_path=''):
    if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
        current_file_path = file_path
        # print(f'current file path: {current_file_path}')

        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        file_path = filedialog.askopenfilename(
            initialdir=initial_dir,
            initialfile=initial_file,
        )
        root.destroy()

        if file_path == '':
            file_path = current_file_path

    return file_path


def remove_doublequote(file_path):
    if file_path is not None:
        file_path = file_path.replace('"', '')

    return file_path


# def set_legacy_8bitadam(optimizer, use_8bit_adam):
#     if optimizer == 'AdamW8bit':
#         # use_8bit_adam = True
#         return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
#             value=True, interactive=False, visible=True
#         )
#     else:
#         # use_8bit_adam = False
#         return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
#             value=False, interactive=False, visible=True
#         )


def get_folder_path(folder_path=''):
    if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
        current_folder_path = folder_path

        initial_dir, initial_file = get_dir_and_file(folder_path)

        root = Tk()
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        folder_path = filedialog.askdirectory(initialdir=initial_dir)
        root.destroy()

        if folder_path == '':
            folder_path = current_folder_path

    return folder_path


def get_saveasfile_path(
    file_path='', defaultextension='.json', extension_name='Config files'
):
    if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
        current_file_path = file_path
        # print(f'current file path: {current_file_path}')

        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        save_file_path = filedialog.asksaveasfile(
            filetypes=(
                (f'{extension_name}', f'{defaultextension}'),
                ('All files', '*'),
            ),
            defaultextension=defaultextension,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
        root.destroy()

        # print(save_file_path)

        if save_file_path is None:
            file_path = current_file_path
        else:
            print(save_file_path.name)
            file_path = save_file_path.name

        # print(file_path)

    return file_path


def get_saveasfilename_path(
    file_path='', extensions='*', extension_name='Config files'
):
    if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
        current_file_path = file_path
        # print(f'current file path: {current_file_path}')

        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        save_file_path = filedialog.asksaveasfilename(
            filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
            defaultextension=extensions,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
        root.destroy()

        if save_file_path == '':
            file_path = current_file_path
        else:
            # print(save_file_path)
            file_path = save_file_path

    return file_path

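# Note: get_saveasfile_path goes through filedialog.asksaveasfile and therefore reads the
# chosen path from the returned file object's .name attribute, while get_saveasfilename_path
# uses filedialog.asksaveasfilename and works with the returned string directly.
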
def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add prefix and/or postfix to the content of caption files within a folder.
    If no caption files are found, create one with the requested prefix and/or postfix.

    Args:
        folder (str): Path to the folder containing caption files.
        prefix (str, optional): Prefix to add to the content of the caption files.
        postfix (str, optional): Postfix to add to the content of the caption files.
        caption_file_ext (str, optional): Extension of the caption files.
    """

    if prefix == '' and postfix == '':
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        if not os.path.exists(caption_file_path):
            with open(caption_file_path, 'w') as f:
                separator = ' ' if prefix and postfix else ''
                f.write(f'{prefix}{separator}{postfix}')
        else:
            with open(caption_file_path, 'r+') as f:
                content = f.read()
                content = content.rstrip()
                f.seek(0, 0)

                prefix_separator = ' ' if prefix else ''
                postfix_separator = ' ' if postfix else ''
                f.write(
                    f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
                )

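# Example (illustrative): add_pre_postfix(folder='img/10_mytoken', prefix='mytoken',
# caption_file_ext='.txt') prepends 'mytoken ' to every existing .txt caption in the folder
# and writes a new caption file containing just 'mytoken' for any image that has none.
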
def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Check if there are any files with the specified extension in the given folder.

    Args:
        folder_path (str): Path to the folder containing files.
        file_extension (str): Extension of the files to look for.

    Returns:
        bool: True if files with the specified extension are found, False otherwise.
    """
    for file in os.listdir(folder_path):
        if file.endswith(file_extension):
            return True
    return False

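# Example (illustrative): has_ext_files('img/10_mytoken', '.txt') returns True as soon as a
# single file ending in '.txt' is found, and False for an empty or caption-less folder.
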
def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
    """
    print('Running caption find/replace')

    if not has_ext_files(folder_path, caption_file_ext):
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if search_text == '':
        return

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    for caption_file in caption_files:
        with open(
            os.path.join(folder_path, caption_file), 'r', errors='ignore'
        ) as f:
            content = f.read()

        content = content.replace(search_text, replace_text)

        with open(os.path.join(folder_path, caption_file), 'w') as f:
            f.write(content)

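# Example (illustrative): find_replace(folder_path='img/10_mytoken', caption_file_ext='.txt',
# search_text='grey hair', replace_text='silver hair') rewrites every matching caption file
# in place; when the folder has no files with that extension a message box is shown and the
# function returns without touching anything.
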
def color_aug_changed(color_aug):
    if color_aug:
        msgbox(
            'Disabling "Cache latent" because "Color augmentation" has been selected...'
        )
        return gr.Checkbox.update(value=False, interactive=False)
    else:
        return gr.Checkbox.update(value=True, interactive=True)


def save_inference_file(output_dir, v2, v_parameterization, output_name):
    # List all files in the directory
    files = os.listdir(output_dir)

    # Iterate over the list of files
    for file in files:
        # Check if the file starts with the value of output_name
        if file.startswith(output_name):
            # Check if it is a file or a directory
            if os.path.isfile(os.path.join(output_dir, file)):
                # Split the file name and extension
                file_name, ext = os.path.splitext(file)

                # Copy the matching v2 inference yaml next to the saved model, with a .yaml extension
                if v2 and v_parameterization:
                    print(
                        f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
                    )
                    shutil.copy(
                        './v2_inference/v2-inference-v.yaml',
                        f'{output_dir}/{file_name}.yaml',
                    )
                elif v2:
                    print(
                        f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
                    )
                    shutil.copy(
                        './v2_inference/v2-inference.yaml',
                        f'{output_dir}/{file_name}.yaml',
                    )

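# Note: only SD v2 checkpoints get a sidecar yaml (v2-inference-v.yaml when
# v_parameterization is enabled, v2-inference.yaml otherwise); for v1 models the loop falls
# through and nothing is copied.
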
            +
            def set_pretrained_model_name_or_path_input(
         | 
| 430 | 
            +
                model_list, pretrained_model_name_or_path, v2, v_parameterization
         | 
| 431 | 
            +
            ):
         | 
| 432 | 
            +
                # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
         | 
| 433 | 
            +
                if str(model_list) in V2_BASE_MODELS:
         | 
| 434 | 
            +
                    print('SD v2 model detected. Setting --v2 parameter')
         | 
| 435 | 
            +
                    v2 = True
         | 
| 436 | 
            +
                    v_parameterization = False
         | 
| 437 | 
            +
                    pretrained_model_name_or_path = str(model_list)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
         | 
| 440 | 
            +
                if str(model_list) in V_PARAMETERIZATION_MODELS:
         | 
| 441 | 
            +
                    print(
         | 
| 442 | 
            +
                        'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
         | 
| 443 | 
            +
                    )
         | 
| 444 | 
            +
                    v2 = True
         | 
| 445 | 
            +
                    v_parameterization = True
         | 
| 446 | 
            +
                    pretrained_model_name_or_path = str(model_list)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                if str(model_list) in V1_MODELS:
         | 
| 449 | 
            +
                    v2 = False
         | 
| 450 | 
            +
                    v_parameterization = False
         | 
| 451 | 
            +
                    pretrained_model_name_or_path = str(model_list)
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                if model_list == 'custom':
         | 
| 454 | 
            +
                    if (
         | 
| 455 | 
            +
                        str(pretrained_model_name_or_path) in V1_MODELS
         | 
| 456 | 
            +
                        or str(pretrained_model_name_or_path) in V2_BASE_MODELS
         | 
| 457 | 
            +
                        or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
         | 
| 458 | 
            +
                    ):
         | 
| 459 | 
            +
                        pretrained_model_name_or_path = ''
         | 
| 460 | 
            +
                        v2 = False
         | 
| 461 | 
            +
                        v_parameterization = False
         | 
| 462 | 
            +
                return model_list, pretrained_model_name_or_path, v2, v_parameterization
         | 
| 463 | 
            +
             | 
| 464 | 
            +
             | 
| 465 | 
            +
            def set_v2_checkbox(model_list, v2, v_parameterization):
         | 
| 466 | 
            +
                # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
         | 
| 467 | 
            +
                if str(model_list) in V2_BASE_MODELS:
         | 
| 468 | 
            +
                    v2 = True
         | 
| 469 | 
            +
                    v_parameterization = False
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
         | 
| 472 | 
            +
                if str(model_list) in V_PARAMETERIZATION_MODELS:
         | 
| 473 | 
            +
                    v2 = True
         | 
| 474 | 
            +
                    v_parameterization = True
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                if str(model_list) in V1_MODELS:
         | 
| 477 | 
            +
                    v2 = False
         | 
| 478 | 
            +
                    v_parameterization = False
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                return v2, v_parameterization
         | 
| 481 | 
            +
             | 
| 482 | 
            +
             | 
| 483 | 
            +
            def set_model_list(
         | 
| 484 | 
            +
                model_list,
         | 
| 485 | 
            +
                pretrained_model_name_or_path,
         | 
| 486 | 
            +
                v2,
         | 
| 487 | 
            +
                v_parameterization,
         | 
| 488 | 
            +
            ):
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
         | 
| 491 | 
            +
                    model_list = 'custom'
         | 
| 492 | 
            +
                else:
         | 
| 493 | 
            +
                    model_list = pretrained_model_name_or_path
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                return model_list, v2, v_parameterization
         | 
| 496 | 
            +
             | 
| 497 | 
            +
             | 
| 498 | 
            +
            ###
         | 
| 499 | 
            +
            ### Gradio common GUI section
         | 
| 500 | 
            +
            ###
         | 
| 501 | 
            +
             | 
| 502 | 
            +
             | 
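# Example (illustrative, assuming 'stabilityai/stable-diffusion-2-1' is listed in
# V_PARAMETERIZATION_MODELS, as the quick-pick choices below suggest): picking it in the
# Model Quick Pick dropdown turns both v2 and v_parameterization on, while typing a local
# .safetensors path into the model textbox makes set_model_list flip the dropdown back to
# 'custom'.
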
def gradio_config():
    with gr.Accordion('Configuration file', open=False):
        with gr.Row():
            button_open_config = gr.Button('Open 📂', elem_id='open_folder')
            button_save_config = gr.Button('Save 💾', elem_id='open_folder')
            button_save_as_config = gr.Button(
                'Save as... 💾', elem_id='open_folder'
            )
            config_file_name = gr.Textbox(
                label='',
                placeholder="type the configuration file path or use the 'Open' button above to select it...",
                interactive=True,
            )
            button_load_config = gr.Button('Load 💾', elem_id='open_folder')
            config_file_name.change(
                remove_doublequote,
                inputs=[config_file_name],
                outputs=[config_file_name],
            )
    return (
        button_open_config,
        button_save_config,
        button_save_as_config,
        config_file_name,
        button_load_config,
    )


def get_pretrained_model_name_or_path_file(
    model_list, pretrained_model_name_or_path
):
    pretrained_model_name_or_path = get_any_file_path(
        pretrained_model_name_or_path
    )
    # Note: set_model_list takes four arguments and returns updated values; this call passes
    # only two and discards the result.
    set_model_list(model_list, pretrained_model_name_or_path)

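# Illustrative wiring (done by the tabs that call gradio_config): the returned buttons and
# textbox are hooked up by the caller, along the lines of
#   button_open_config.click(open_config_handler, inputs=[...], outputs=[...])
# where open_config_handler stands for the caller's own load routine (hypothetical name);
# gradio_config itself only builds the components.
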
def gradio_source_model(
    save_model_as_choices=[
        'same as source model',
        'ckpt',
        'diffusers',
        'diffusers_safetensors',
        'safetensors',
    ]
):
    with gr.Tab('Source model'):
        # Define the input elements
        with gr.Row():
            pretrained_model_name_or_path = gr.Textbox(
                label='Pretrained model name or path',
                placeholder='enter the path to custom model or name of pretrained model',
                value='runwayml/stable-diffusion-v1-5',
            )
            pretrained_model_name_or_path_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_file.click(
                get_any_file_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            pretrained_model_name_or_path_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_folder.click(
                get_folder_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            model_list = gr.Dropdown(
                label='Model Quick Pick',
                choices=[
                    'custom',
                    'stabilityai/stable-diffusion-2-1-base',
                    'stabilityai/stable-diffusion-2-base',
                    'stabilityai/stable-diffusion-2-1',
                    'stabilityai/stable-diffusion-2',
                    'runwayml/stable-diffusion-v1-5',
                    'CompVis/stable-diffusion-v1-4',
                ],
                value='runwayml/stable-diffusion-v1-5',
            )
            save_model_as = gr.Dropdown(
                label='Save trained model as',
                choices=save_model_as_choices,
                value='safetensors',
            )

        with gr.Row():
            v2 = gr.Checkbox(label='v2', value=False)
            v_parameterization = gr.Checkbox(
                label='v_parameterization', value=False
            )
            v2.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
            v_parameterization.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
        # Update the model list and parameters when the user edits the model path field
        pretrained_model_name_or_path.change(
            set_model_list,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
    return (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    )

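# Note: the five returned components are what each training tab passes on to its own
# save/train handlers; the .change hooks above keep the quick-pick dropdown, the path
# textbox and the v2/v_parameterization checkboxes consistent with one another.
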
def gradio_training(
    learning_rate_value='1e-6',
    lr_scheduler_value='constant',
    lr_warmup_value='0',
):
    with gr.Row():
        train_batch_size = gr.Slider(
            minimum=1,
            maximum=64,
            label='Train batch size',
            value=1,
            step=1,
        )
        epoch = gr.Number(label='Epoch', value=1, precision=0)
        save_every_n_epochs = gr.Number(
            label='Save every N epochs', value=1, precision=0
        )
        caption_extension = gr.Textbox(
            label='Caption Extension',
            placeholder='(Optional) Extension for caption files. default: .caption',
        )
    with gr.Row():
        mixed_precision = gr.Dropdown(
            label='Mixed precision',
            choices=[
                'no',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        save_precision = gr.Dropdown(
            label='Save precision',
            choices=[
                'float',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        num_cpu_threads_per_process = gr.Slider(
            minimum=1,
            maximum=os.cpu_count(),
            step=1,
            label='Number of CPU threads per core',
            value=2,
        )
        seed = gr.Textbox(label='Seed', placeholder='(Optional) eg:1234')
        cache_latents = gr.Checkbox(label='Cache latent', value=True)
    with gr.Row():
        learning_rate = gr.Textbox(
            label='Learning rate', value=learning_rate_value
        )
        lr_scheduler = gr.Dropdown(
            label='LR Scheduler',
            choices=[
                'adafactor',
                'constant',
                'constant_with_warmup',
                'cosine',
                'cosine_with_restarts',
                'linear',
                'polynomial',
            ],
            value=lr_scheduler_value,
        )
        lr_warmup = gr.Textbox(
            label='LR warmup (% of steps)', value=lr_warmup_value
        )
        optimizer = gr.Dropdown(
            label='Optimizer',
            choices=[
                'AdamW',
                'AdamW8bit',
                'Adafactor',
                'DAdaptation',
                'Lion',
                'SGDNesterov',
                'SGDNesterov8bit',
            ],
            value='AdamW8bit',
            interactive=True,
        )
    with gr.Row():
        optimizer_args = gr.Textbox(
            label='Optimizer extra arguments',
            placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
        )
    return (
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        num_cpu_threads_per_process,
        seed,
        caption_extension,
        cache_latents,
        optimizer,
        optimizer_args,
    )

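# Note: lr_warmup is entered as a percentage of the total training steps; the calling tabs
# are expected to convert it into the absolute --lr_warmup_steps value before handing it to
# run_cmd_training below.
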
def run_cmd_training(**kwargs):
    options = [
        f' --learning_rate="{kwargs.get("learning_rate", "")}"'
        if kwargs.get('learning_rate')
        else '',
        f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
        if kwargs.get('lr_scheduler')
        else '',
        f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
        if kwargs.get('lr_warmup_steps')
        else '',
        f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
        if kwargs.get('train_batch_size')
        else '',
        f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
        if kwargs.get('max_train_steps')
        else '',
        f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
        if int(kwargs.get('save_every_n_epochs'))
        else '',
        f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
        if kwargs.get('mixed_precision')
        else '',
        f' --save_precision="{kwargs.get("save_precision", "")}"'
        if kwargs.get('save_precision')
        else '',
        f' --seed="{kwargs.get("seed", "")}"'
        if kwargs.get('seed') != ''
        else '',
        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
        if kwargs.get('caption_extension')
        else '',
        ' --cache_latents' if kwargs.get('cache_latents') else '',
        # ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
        f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
        f' --optimizer_args {kwargs.get("optimizer_args", "")}'
        if kwargs.get('optimizer_args') != ''
        else '',
    ]
    run_cmd = ''.join(options)
    return run_cmd

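# Example (illustrative): run_cmd_training(learning_rate='1e-5', lr_scheduler='cosine',
# lr_warmup_steps='', train_batch_size=1, max_train_steps='', save_every_n_epochs=1,
# mixed_precision='fp16', save_precision='fp16', seed='1234', caption_extension='.txt',
# cache_latents=True, optimizer='AdamW8bit', optimizer_args='') returns, wrapped here for
# readability:
#   ' --learning_rate="1e-5" --lr_scheduler="cosine" --train_batch_size="1"
#     --save_every_n_epochs="1" --mixed_precision="fp16" --save_precision="fp16"
#     --seed="1234" --caption_extension=".txt" --cache_latents --optimizer_type="AdamW8bit"'
# which the training tabs can append to the command line they build.
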
| 799 | 
            +
            def gradio_advanced_training():
         | 
| 800 | 
            +
                with gr.Row():
         | 
| 801 | 
            +
                    additional_parameters = gr.Textbox(
         | 
| 802 | 
            +
                        label='Additional parameters',
         | 
| 803 | 
            +
                        placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
         | 
| 804 | 
            +
                    )
         | 
| 805 | 
            +
                with gr.Row():
         | 
| 806 | 
            +
                    keep_tokens = gr.Slider(
         | 
| 807 | 
            +
                        label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
         | 
| 808 | 
            +
                    )
         | 
| 809 | 
            +
                    clip_skip = gr.Slider(
         | 
| 810 | 
            +
                        label='Clip skip', value='1', minimum=1, maximum=12, step=1
         | 
| 811 | 
            +
                    )
         | 
| 812 | 
            +
                    max_token_length = gr.Dropdown(
         | 
| 813 | 
            +
                        label='Max Token Length',
         | 
| 814 | 
            +
                        choices=[
         | 
| 815 | 
            +
                            '75',
         | 
| 816 | 
            +
                            '150',
         | 
| 817 | 
            +
                            '225',
         | 
| 818 | 
            +
                        ],
         | 
| 819 | 
            +
                        value='75',
         | 
| 820 | 
            +
                    )
         | 
| 821 | 
            +
                    full_fp16 = gr.Checkbox(
         | 
| 822 | 
            +
                        label='Full fp16 training (experimental)', value=False
         | 
| 823 | 
            +
                    )
         | 
| 824 | 
            +
                with gr.Row():
         | 
| 825 | 
            +
                    gradient_checkpointing = gr.Checkbox(
         | 
| 826 | 
            +
                        label='Gradient checkpointing', value=False
         | 
| 827 | 
            +
                    )
         | 
| 828 | 
            +
                    shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
         | 
| 829 | 
            +
                    persistent_data_loader_workers = gr.Checkbox(
         | 
| 830 | 
            +
                        label='Persistent data loader', value=False
         | 
| 831 | 
            +
                    )
         | 
| 832 | 
            +
                    mem_eff_attn = gr.Checkbox(
         | 
| 833 | 
            +
                        label='Memory efficient attention', value=False
         | 
| 834 | 
            +
                    )
         | 
| 835 | 
            +
                with gr.Row():
         | 
| 836 | 
            +
                    # This use_8bit_adam element should be removed in a future release as it is no longer used
         | 
| 837 | 
            +
                    # use_8bit_adam = gr.Checkbox(
         | 
| 838 | 
            +
                    #     label='Use 8bit adam', value=False, visible=False
         | 
| 839 | 
            +
                    # )
         | 
| 840 | 
            +
                    xformers = gr.Checkbox(label='Use xformers', value=True)
         | 
| 841 | 
            +
                    color_aug = gr.Checkbox(label='Color augmentation', value=False)
         | 
| 842 | 
            +
                    flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
         | 
| 843 | 
            +
                    min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
         | 
| 844 | 
            +
                with gr.Row():
         | 
| 845 | 
            +
                    bucket_no_upscale = gr.Checkbox(
         | 
| 846 | 
            +
                        label="Don't upscale bucket resolution", value=True
         | 
| 847 | 
            +
                    )
         | 
| 848 | 
            +
                    bucket_reso_steps = gr.Number(
         | 
| 849 | 
            +
                        label='Bucket resolution steps', value=64
         | 
| 850 | 
            +
                    )
         | 
| 851 | 
            +
                    random_crop = gr.Checkbox(
         | 
| 852 | 
            +
                        label='Random crop instead of center crop', value=False
         | 
| 853 | 
            +
                    )
         | 
| 854 | 
            +
                    noise_offset = gr.Textbox(
         | 
| 855 | 
            +
                         label='Noise offset (0 - 1)', placeholder='(Optional) eg: 0.1'
         | 
| 856 | 
            +
                    )
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                with gr.Row():
         | 
| 859 | 
            +
                    caption_dropout_every_n_epochs = gr.Number(
         | 
| 860 | 
            +
                        label='Dropout caption every n epochs', value=0
         | 
| 861 | 
            +
                    )
         | 
| 862 | 
            +
                    caption_dropout_rate = gr.Slider(
         | 
| 863 | 
            +
                        label='Rate of caption dropout', value=0, minimum=0, maximum=1
         | 
| 864 | 
            +
                    )
         | 
| 865 | 
            +
                    vae_batch_size = gr.Slider(
         | 
| 866 | 
            +
                        label='VAE batch size',
         | 
| 867 | 
            +
                        minimum=0,
         | 
| 868 | 
            +
                        maximum=32,
         | 
| 869 | 
            +
                        value=0,
         | 
| 870 | 
            +
                        step=1
         | 
| 871 | 
            +
                    )
         | 
| 872 | 
            +
                with gr.Row():
         | 
| 873 | 
            +
                    save_state = gr.Checkbox(label='Save training state', value=False)
         | 
| 874 | 
            +
                    resume = gr.Textbox(
         | 
| 875 | 
            +
                        label='Resume from saved training state',
         | 
| 876 | 
            +
                        placeholder='path to "last-state" state folder to resume from',
         | 
| 877 | 
            +
                    )
         | 
| 878 | 
            +
                    resume_button = gr.Button('📂', elem_id='open_folder_small')
         | 
| 879 | 
            +
                    resume_button.click(
         | 
| 880 | 
            +
                        get_folder_path,
         | 
| 881 | 
            +
                        outputs=resume,
         | 
| 882 | 
            +
                        show_progress=False,
         | 
| 883 | 
            +
                    )
         | 
| 884 | 
            +
                    max_train_epochs = gr.Textbox(
         | 
| 885 | 
            +
                        label='Max train epoch',
         | 
| 886 | 
            +
                         placeholder='(Optional) Override number of epochs',
         | 
| 887 | 
            +
                    )
         | 
| 888 | 
            +
                    max_data_loader_n_workers = gr.Textbox(
         | 
| 889 | 
            +
                        label='Max num workers for DataLoader',
         | 
| 890 | 
            +
                         placeholder='(Optional) Override the number of DataLoader workers. Default: 8',
         | 
| 891 | 
            +
                        value="0",
         | 
| 892 | 
            +
                    )
         | 
| 893 | 
            +
                return (
         | 
| 894 | 
            +
                    # use_8bit_adam,
         | 
| 895 | 
            +
                    xformers,
         | 
| 896 | 
            +
                    full_fp16,
         | 
| 897 | 
            +
                    gradient_checkpointing,
         | 
| 898 | 
            +
                    shuffle_caption,
         | 
| 899 | 
            +
                    color_aug,
         | 
| 900 | 
            +
                    flip_aug,
         | 
| 901 | 
            +
                    clip_skip,
         | 
| 902 | 
            +
                    mem_eff_attn,
         | 
| 903 | 
            +
                    save_state,
         | 
| 904 | 
            +
                    resume,
         | 
| 905 | 
            +
                    max_token_length,
         | 
| 906 | 
            +
                    max_train_epochs,
         | 
| 907 | 
            +
                    max_data_loader_n_workers,
         | 
| 908 | 
            +
                    keep_tokens,
         | 
| 909 | 
            +
                    persistent_data_loader_workers,
         | 
| 910 | 
            +
                    bucket_no_upscale,
         | 
| 911 | 
            +
                    random_crop,
         | 
| 912 | 
            +
                    bucket_reso_steps,
         | 
| 913 | 
            +
                    caption_dropout_every_n_epochs,
         | 
| 914 | 
            +
                    caption_dropout_rate,
         | 
| 915 | 
            +
                    noise_offset,
         | 
| 916 | 
            +
                    additional_parameters,
         | 
| 917 | 
            +
                    vae_batch_size,
         | 
| 918 | 
            +
                    min_snr_gamma,
         | 
| 919 | 
            +
                )
         | 
| 920 | 
            +
             | 
| 921 | 
            +
             | 
| 922 | 
            +
            def run_cmd_advanced_training(**kwargs):
         | 
| 923 | 
            +
                options = [
         | 
| 924 | 
            +
                    f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
         | 
| 925 | 
            +
                    if kwargs.get('max_train_epochs')
         | 
| 926 | 
            +
                    else '',
         | 
| 927 | 
            +
                    f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
         | 
| 928 | 
            +
                    if kwargs.get('max_data_loader_n_workers')
         | 
| 929 | 
            +
                    else '',
         | 
| 930 | 
            +
                    f' --max_token_length={kwargs.get("max_token_length", "")}'
         | 
| 931 | 
            +
                    if int(kwargs.get('max_token_length', 75)) > 75
         | 
| 932 | 
            +
                    else '',
         | 
| 933 | 
            +
                    f' --clip_skip={kwargs.get("clip_skip", "")}'
         | 
| 934 | 
            +
                    if int(kwargs.get('clip_skip', 1)) > 1
         | 
| 935 | 
            +
                    else '',
         | 
| 936 | 
            +
                    f' --resume="{kwargs.get("resume", "")}"'
         | 
| 937 | 
            +
                    if kwargs.get('resume')
         | 
| 938 | 
            +
                    else '',
         | 
| 939 | 
            +
                    f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
         | 
| 940 | 
            +
                    if int(kwargs.get('keep_tokens', 0)) > 0
         | 
| 941 | 
            +
                    else '',
         | 
| 942 | 
            +
                    f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
         | 
| 943 | 
            +
                    if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
         | 
| 944 | 
            +
                    else '',
         | 
| 945 | 
            +
                     f' --caption_dropout_rate="{float(kwargs.get("caption_dropout_rate", 0))}"'
         | 
| 946 | 
            +
                     if float(kwargs.get('caption_dropout_rate', 0)) > 0
         | 
| 947 | 
            +
                    else '',
         | 
| 948 | 
            +
                    f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"'
         | 
| 949 | 
            +
                    if int(kwargs.get('vae_batch_size', 0)) > 0
         | 
| 950 | 
            +
                    else '',
         | 
| 951 | 
            +
                     f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 64))}'
         | 
| 952 | 
            +
                    if int(kwargs.get('bucket_reso_steps', 64)) >= 1
         | 
| 953 | 
            +
                    else '',
         | 
| 954 | 
            +
                    f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}'
         | 
| 955 | 
            +
                    if int(kwargs.get('min_snr_gamma', 0)) >= 1
         | 
| 956 | 
            +
                    else '',
         | 
| 957 | 
            +
                    ' --save_state' if kwargs.get('save_state') else '',
         | 
| 958 | 
            +
                    ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
         | 
| 959 | 
            +
                    ' --color_aug' if kwargs.get('color_aug') else '',
         | 
| 960 | 
            +
                    ' --flip_aug' if kwargs.get('flip_aug') else '',
         | 
| 961 | 
            +
                    ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
         | 
| 962 | 
            +
                    ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing')
         | 
| 963 | 
            +
                    else '',
         | 
| 964 | 
            +
                    ' --full_fp16' if kwargs.get('full_fp16') else '',
         | 
| 965 | 
            +
                    ' --xformers' if kwargs.get('xformers') else '',
         | 
| 966 | 
            +
                    # ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
         | 
| 967 | 
            +
                    ' --persistent_data_loader_workers'
         | 
| 968 | 
            +
                    if kwargs.get('persistent_data_loader_workers')
         | 
| 969 | 
            +
                    else '',
         | 
| 970 | 
            +
                    ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
         | 
| 971 | 
            +
                    ' --random_crop' if kwargs.get('random_crop') else '',
         | 
| 972 | 
            +
                    f' --noise_offset={float(kwargs.get("noise_offset", 0))}'
         | 
| 973 | 
            +
                    if not kwargs.get('noise_offset', '') == ''
         | 
| 974 | 
            +
                    else '',
         | 
| 975 | 
            +
                    f' {kwargs.get("additional_parameters", "")}',
         | 
| 976 | 
            +
                ]
         | 
| 977 | 
            +
                run_cmd = ''.join(options)
         | 
| 978 | 
            +
                return run_cmd
         | 
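For reference, here is a minimal sketch of how the fragment returned by run_cmd_advanced_training() might be appended to a base launch command. The kwargs names mirror the values gathered by gradio_advanced_training(); the base command and the concrete values are illustrative assumptions only, not defaults taken from this repository.

# Hypothetical usage sketch; the base command and values below are assumptions.
base_cmd = 'accelerate launch "train_network.py"'
advanced = run_cmd_advanced_training(
    max_train_epochs='10',
    clip_skip=2,
    keep_tokens=1,
    xformers=True,
    noise_offset='0.05',
)
full_cmd = base_cmd + advanced
# full_cmd now contains flags such as --max_train_epochs="10", --clip_skip=2,
# --keep_tokens="1", --xformers and --noise_offset=0.05 appended to the base command.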
    	
        library/config_util.py
    ADDED
    
    | @@ -0,0 +1,536 @@ | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            from dataclasses import (
         | 
| 3 | 
            +
              asdict,
         | 
| 4 | 
            +
              dataclass,
         | 
| 5 | 
            +
            )
         | 
| 6 | 
            +
            import functools
         | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            from textwrap import dedent, indent
         | 
| 9 | 
            +
            import json
         | 
| 10 | 
            +
            from pathlib import Path
         | 
| 11 | 
            +
            # from toolz import curry
         | 
| 12 | 
            +
            from typing import (
         | 
| 13 | 
            +
              List,
         | 
| 14 | 
            +
              Optional,
         | 
| 15 | 
            +
              Sequence,
         | 
| 16 | 
            +
              Tuple,
         | 
| 17 | 
            +
              Union,
         | 
| 18 | 
            +
            )
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import toml
         | 
| 21 | 
            +
            import voluptuous
         | 
| 22 | 
            +
            from voluptuous import (
         | 
| 23 | 
            +
              Any,
         | 
| 24 | 
            +
              ExactSequence,
         | 
| 25 | 
            +
              MultipleInvalid,
         | 
| 26 | 
            +
              Object,
         | 
| 27 | 
            +
              Required,
         | 
| 28 | 
            +
              Schema,
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
            from transformers import CLIPTokenizer
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            from . import train_util
         | 
| 33 | 
            +
            from .train_util import (
         | 
| 34 | 
            +
              DreamBoothSubset,
         | 
| 35 | 
            +
              FineTuningSubset,
         | 
| 36 | 
            +
              DreamBoothDataset,
         | 
| 37 | 
            +
              FineTuningDataset,
         | 
| 38 | 
            +
              DatasetGroup,
         | 
| 39 | 
            +
            )
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def add_config_arguments(parser: argparse.ArgumentParser):
         | 
| 43 | 
            +
              parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            # TODO: inherit Params class in Subset, Dataset
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            @dataclass
         | 
| 48 | 
            +
            class BaseSubsetParams:
         | 
| 49 | 
            +
              image_dir: Optional[str] = None
         | 
| 50 | 
            +
              num_repeats: int = 1
         | 
| 51 | 
            +
              shuffle_caption: bool = False
         | 
| 52 | 
            +
              keep_tokens: int = 0
         | 
| 53 | 
            +
              color_aug: bool = False
         | 
| 54 | 
            +
              flip_aug: bool = False
         | 
| 55 | 
            +
              face_crop_aug_range: Optional[Tuple[float, float]] = None
         | 
| 56 | 
            +
              random_crop: bool = False
         | 
| 57 | 
            +
              caption_dropout_rate: float = 0.0
         | 
| 58 | 
            +
              caption_dropout_every_n_epochs: int = 0
         | 
| 59 | 
            +
              caption_tag_dropout_rate: float = 0.0
         | 
| 60 | 
            +
              token_warmup_min: int = 1
         | 
| 61 | 
            +
              token_warmup_step: float = 0
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            @dataclass
         | 
| 64 | 
            +
            class DreamBoothSubsetParams(BaseSubsetParams):
         | 
| 65 | 
            +
              is_reg: bool = False
         | 
| 66 | 
            +
              class_tokens: Optional[str] = None
         | 
| 67 | 
            +
              caption_extension: str = ".caption"
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            @dataclass
         | 
| 70 | 
            +
            class FineTuningSubsetParams(BaseSubsetParams):
         | 
| 71 | 
            +
              metadata_file: Optional[str] = None
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            @dataclass
         | 
| 74 | 
            +
            class BaseDatasetParams:
         | 
| 75 | 
            +
              tokenizer: CLIPTokenizer = None
         | 
| 76 | 
            +
              max_token_length: int = None
         | 
| 77 | 
            +
              resolution: Optional[Tuple[int, int]] = None
         | 
| 78 | 
            +
              debug_dataset: bool = False
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            @dataclass
         | 
| 81 | 
            +
            class DreamBoothDatasetParams(BaseDatasetParams):
         | 
| 82 | 
            +
              batch_size: int = 1
         | 
| 83 | 
            +
              enable_bucket: bool = False
         | 
| 84 | 
            +
              min_bucket_reso: int = 256
         | 
| 85 | 
            +
              max_bucket_reso: int = 1024
         | 
| 86 | 
            +
              bucket_reso_steps: int = 64
         | 
| 87 | 
            +
              bucket_no_upscale: bool = False
         | 
| 88 | 
            +
              prior_loss_weight: float = 1.0
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            @dataclass
         | 
| 91 | 
            +
            class FineTuningDatasetParams(BaseDatasetParams):
         | 
| 92 | 
            +
              batch_size: int = 1
         | 
| 93 | 
            +
              enable_bucket: bool = False
         | 
| 94 | 
            +
              min_bucket_reso: int = 256
         | 
| 95 | 
            +
              max_bucket_reso: int = 1024
         | 
| 96 | 
            +
              bucket_reso_steps: int = 64
         | 
| 97 | 
            +
              bucket_no_upscale: bool = False
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            @dataclass
         | 
| 100 | 
            +
            class SubsetBlueprint:
         | 
| 101 | 
            +
              params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            @dataclass
         | 
| 104 | 
            +
            class DatasetBlueprint:
         | 
| 105 | 
            +
              is_dreambooth: bool
         | 
| 106 | 
            +
              params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
         | 
| 107 | 
            +
              subsets: Sequence[SubsetBlueprint]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            @dataclass
         | 
| 110 | 
            +
            class DatasetGroupBlueprint:
         | 
| 111 | 
            +
              datasets: Sequence[DatasetBlueprint]
         | 
| 112 | 
            +
            @dataclass
         | 
| 113 | 
            +
            class Blueprint:
         | 
| 114 | 
            +
              dataset_group: DatasetGroupBlueprint
         | 
| 115 | 
            +
             | 
| 116 | 
            +
             | 
| 117 | 
            +
            class ConfigSanitizer:
         | 
| 118 | 
            +
              # @curry
         | 
| 119 | 
            +
              @staticmethod
         | 
| 120 | 
            +
              def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
         | 
| 121 | 
            +
                Schema(ExactSequence([klass, klass]))(value)
         | 
| 122 | 
            +
                return tuple(value)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
              # @curry
         | 
| 125 | 
            +
              @staticmethod
         | 
| 126 | 
            +
              def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
         | 
| 127 | 
            +
                Schema(Any(klass, ExactSequence([klass, klass])))(value)
         | 
| 128 | 
            +
                try:
         | 
| 129 | 
            +
                  Schema(klass)(value)
         | 
| 130 | 
            +
                  return (value, value)
         | 
| 131 | 
            +
                except:
         | 
| 132 | 
            +
                  return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
              # subset schema
         | 
| 135 | 
            +
              SUBSET_ASCENDABLE_SCHEMA = {
         | 
| 136 | 
            +
                "color_aug": bool,
         | 
| 137 | 
            +
                "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
         | 
| 138 | 
            +
                "flip_aug": bool,
         | 
| 139 | 
            +
                "num_repeats": int,
         | 
| 140 | 
            +
                "random_crop": bool,
         | 
| 141 | 
            +
                "shuffle_caption": bool,
         | 
| 142 | 
            +
                "keep_tokens": int,
         | 
| 143 | 
            +
                "token_warmup_min": int,
         | 
| 144 | 
            +
                "token_warmup_step": Any(float,int),
         | 
| 145 | 
            +
              }
         | 
| 146 | 
            +
              # DO means DropOut
         | 
| 147 | 
            +
              DO_SUBSET_ASCENDABLE_SCHEMA = {
         | 
| 148 | 
            +
                "caption_dropout_every_n_epochs": int,
         | 
| 149 | 
            +
                "caption_dropout_rate": Any(float, int),
         | 
| 150 | 
            +
                "caption_tag_dropout_rate": Any(float, int),
         | 
| 151 | 
            +
              }
         | 
| 152 | 
            +
              # DB means DreamBooth
         | 
| 153 | 
            +
              DB_SUBSET_ASCENDABLE_SCHEMA = {
         | 
| 154 | 
            +
                "caption_extension": str,
         | 
| 155 | 
            +
                "class_tokens": str,
         | 
| 156 | 
            +
              }
         | 
| 157 | 
            +
              DB_SUBSET_DISTINCT_SCHEMA = {
         | 
| 158 | 
            +
                Required("image_dir"): str,
         | 
| 159 | 
            +
                "is_reg": bool,
         | 
| 160 | 
            +
              }
         | 
| 161 | 
            +
              # FT means FineTuning
         | 
| 162 | 
            +
              FT_SUBSET_DISTINCT_SCHEMA = {
         | 
| 163 | 
            +
                Required("metadata_file"): str,
         | 
| 164 | 
            +
                "image_dir": str,
         | 
| 165 | 
            +
              }
         | 
| 166 | 
            +
             | 
| 167 | 
            +
              # datasets schema
         | 
| 168 | 
            +
              DATASET_ASCENDABLE_SCHEMA = {
         | 
| 169 | 
            +
                "batch_size": int,
         | 
| 170 | 
            +
                "bucket_no_upscale": bool,
         | 
| 171 | 
            +
                "bucket_reso_steps": int,
         | 
| 172 | 
            +
                "enable_bucket": bool,
         | 
| 173 | 
            +
                "max_bucket_reso": int,
         | 
| 174 | 
            +
                "min_bucket_reso": int,
         | 
| 175 | 
            +
                "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         | 
| 176 | 
            +
              }
         | 
| 177 | 
            +
             | 
| 178 | 
            +
              # options handled by argparse but not handled by user config
         | 
| 179 | 
            +
              ARGPARSE_SPECIFIC_SCHEMA = {
         | 
| 180 | 
            +
                "debug_dataset": bool,
         | 
| 181 | 
            +
                "max_token_length": Any(None, int),
         | 
| 182 | 
            +
                "prior_loss_weight": Any(float, int),
         | 
| 183 | 
            +
              }
         | 
| 184 | 
            +
              # for handling default None value of argparse
         | 
| 185 | 
            +
              ARGPARSE_NULLABLE_OPTNAMES = [
         | 
| 186 | 
            +
                "face_crop_aug_range",
         | 
| 187 | 
            +
                "resolution",
         | 
| 188 | 
            +
              ]
         | 
| 189 | 
            +
              # prepare map because option name may differ among argparse and user config
         | 
| 190 | 
            +
              ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
         | 
| 191 | 
            +
                "train_batch_size": "batch_size",
         | 
| 192 | 
            +
                "dataset_repeats": "num_repeats",
         | 
| 193 | 
            +
              }
         | 
| 194 | 
            +
             | 
| 195 | 
            +
              def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
         | 
| 196 | 
            +
                assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                self.db_subset_schema = self.__merge_dict(
         | 
| 199 | 
            +
                  self.SUBSET_ASCENDABLE_SCHEMA,
         | 
| 200 | 
            +
                  self.DB_SUBSET_DISTINCT_SCHEMA,
         | 
| 201 | 
            +
                  self.DB_SUBSET_ASCENDABLE_SCHEMA,
         | 
| 202 | 
            +
                  self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
         | 
| 203 | 
            +
                )
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                self.ft_subset_schema = self.__merge_dict(
         | 
| 206 | 
            +
                  self.SUBSET_ASCENDABLE_SCHEMA,
         | 
| 207 | 
            +
                  self.FT_SUBSET_DISTINCT_SCHEMA,
         | 
| 208 | 
            +
                  self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
         | 
| 209 | 
            +
                )
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                self.db_dataset_schema = self.__merge_dict(
         | 
| 212 | 
            +
                  self.DATASET_ASCENDABLE_SCHEMA,
         | 
| 213 | 
            +
                  self.SUBSET_ASCENDABLE_SCHEMA,
         | 
| 214 | 
            +
                  self.DB_SUBSET_ASCENDABLE_SCHEMA,
         | 
| 215 | 
            +
                  self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
         | 
| 216 | 
            +
                  {"subsets": [self.db_subset_schema]},
         | 
| 217 | 
            +
                )
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                self.ft_dataset_schema = self.__merge_dict(
         | 
| 220 | 
            +
                  self.DATASET_ASCENDABLE_SCHEMA,
         | 
| 221 | 
            +
                  self.SUBSET_ASCENDABLE_SCHEMA,
         | 
| 222 | 
            +
                  self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
         | 
| 223 | 
            +
                  {"subsets": [self.ft_subset_schema]},
         | 
| 224 | 
            +
                )
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                if support_dreambooth and support_finetuning:
         | 
| 227 | 
            +
                  def validate_flex_dataset(dataset_config: dict):
         | 
| 228 | 
            +
                    subsets_config = dataset_config.get("subsets", [])
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    # check dataset meets FT style
         | 
| 231 | 
            +
                    # NOTE: all FT subsets should have "metadata_file"
         | 
| 232 | 
            +
                    if all(["metadata_file" in subset for subset in subsets_config]):
         | 
| 233 | 
            +
                      return Schema(self.ft_dataset_schema)(dataset_config)
         | 
| 234 | 
            +
                    # check dataset meets DB style
         | 
| 235 | 
            +
                    # NOTE: all DB subsets should have no "metadata_file"
         | 
| 236 | 
            +
                    elif all(["metadata_file" not in subset for subset in subsets_config]):
         | 
| 237 | 
            +
                      return Schema(self.db_dataset_schema)(dataset_config)
         | 
| 238 | 
            +
                    else:
         | 
| 239 | 
            +
                       raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                  self.dataset_schema = validate_flex_dataset
         | 
| 242 | 
            +
                elif support_dreambooth:
         | 
| 243 | 
            +
                  self.dataset_schema = self.db_dataset_schema
         | 
| 244 | 
            +
                else:
         | 
| 245 | 
            +
                  self.dataset_schema = self.ft_dataset_schema
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                self.general_schema = self.__merge_dict(
         | 
| 248 | 
            +
                  self.DATASET_ASCENDABLE_SCHEMA,
         | 
| 249 | 
            +
                  self.SUBSET_ASCENDABLE_SCHEMA,
         | 
| 250 | 
            +
                  self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
         | 
| 251 | 
            +
                  self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
         | 
| 252 | 
            +
                )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                self.user_config_validator = Schema({
         | 
| 255 | 
            +
                  "general": self.general_schema,
         | 
| 256 | 
            +
                  "datasets": [self.dataset_schema],
         | 
| 257 | 
            +
                })
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                self.argparse_schema = self.__merge_dict(
         | 
| 260 | 
            +
                  self.general_schema,
         | 
| 261 | 
            +
                  self.ARGPARSE_SPECIFIC_SCHEMA,
         | 
| 262 | 
            +
                  {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
         | 
| 263 | 
            +
                  {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
         | 
| 264 | 
            +
                )
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
              def sanitize_user_config(self, user_config: dict) -> dict:
         | 
| 269 | 
            +
                try:
         | 
| 270 | 
            +
                  return self.user_config_validator(user_config)
         | 
| 271 | 
            +
                except MultipleInvalid:
         | 
| 272 | 
            +
                   # TODO: make the error message easier to understand
         | 
| 273 | 
            +
                  print("Invalid user config / ユーザ設定の形式が正しくないようです")
         | 
| 274 | 
            +
                  raise
         | 
| 275 | 
            +
             | 
| 276 | 
            +
               # NOTE: Strictly speaking, the argument parser result does not need to be sanitized,
         | 
| 277 | 
            +
               #   but doing so helps us detect program bugs
         | 
| 278 | 
            +
              def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
         | 
| 279 | 
            +
                try:
         | 
| 280 | 
            +
                  return self.argparse_config_validator(argparse_namespace)
         | 
| 281 | 
            +
                except MultipleInvalid:
         | 
| 282 | 
            +
                   # XXX: if this happens, it is likely a program bug
         | 
| 283 | 
            +
                  print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
         | 
| 284 | 
            +
                  raise
         | 
| 285 | 
            +
             | 
| 286 | 
            +
               # NOTE: a value is overwritten by a later dict if the same key already exists
         | 
| 287 | 
            +
              @staticmethod
         | 
| 288 | 
            +
              def __merge_dict(*dict_list: dict) -> dict:
         | 
| 289 | 
            +
                merged = {}
         | 
| 290 | 
            +
                for schema in dict_list:
         | 
| 291 | 
            +
                  # merged |= schema
         | 
| 292 | 
            +
                  for k, v in schema.items():
         | 
| 293 | 
            +
                    merged[k] = v
         | 
| 294 | 
            +
                return merged
         | 
| 295 | 
            +
             | 
| 296 | 
            +
             | 
| 297 | 
            +
            class BlueprintGenerator:
         | 
| 298 | 
            +
              BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
         | 
| 299 | 
            +
              }
         | 
| 300 | 
            +
             | 
| 301 | 
            +
              def __init__(self, sanitizer: ConfigSanitizer):
         | 
| 302 | 
            +
                self.sanitizer = sanitizer
         | 
| 303 | 
            +
             | 
| 304 | 
            +
               # runtime_params is for parameters which are only configurable at runtime, such as the tokenizer
         | 
| 305 | 
            +
              def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
         | 
| 306 | 
            +
                sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
         | 
| 307 | 
            +
                sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                # convert argparse namespace to dict like config
         | 
| 310 | 
            +
                # NOTE: it is ok to have extra entries in dict
         | 
| 311 | 
            +
                optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
         | 
| 312 | 
            +
                argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                general_config = sanitized_user_config.get("general", {})
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                dataset_blueprints = []
         | 
| 317 | 
            +
                for dataset_config in sanitized_user_config.get("datasets", []):
         | 
| 318 | 
            +
                  # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
         | 
| 319 | 
            +
                  subsets = dataset_config.get("subsets", [])
         | 
| 320 | 
            +
                  is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
         | 
| 321 | 
            +
                  if is_dreambooth:
         | 
| 322 | 
            +
                    subset_params_klass = DreamBoothSubsetParams
         | 
| 323 | 
            +
                    dataset_params_klass = DreamBoothDatasetParams
         | 
| 324 | 
            +
                  else:
         | 
| 325 | 
            +
                    subset_params_klass = FineTuningSubsetParams
         | 
| 326 | 
            +
                    dataset_params_klass = FineTuningDatasetParams
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                  subset_blueprints = []
         | 
| 329 | 
            +
                  for subset_config in subsets:
         | 
| 330 | 
            +
                    params = self.generate_params_by_fallbacks(subset_params_klass,
         | 
| 331 | 
            +
                                                               [subset_config, dataset_config, general_config, argparse_config, runtime_params])
         | 
| 332 | 
            +
                    subset_blueprints.append(SubsetBlueprint(params))
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                  params = self.generate_params_by_fallbacks(dataset_params_klass,
         | 
| 335 | 
            +
                                                             [dataset_config, general_config, argparse_config, runtime_params])
         | 
| 336 | 
            +
                  dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                return Blueprint(dataset_group_blueprint)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
              @staticmethod
         | 
| 343 | 
            +
              def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
         | 
| 344 | 
            +
                name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
         | 
| 345 | 
            +
                search_value = BlueprintGenerator.search_value
         | 
| 346 | 
            +
                default_params = asdict(param_klass())
         | 
| 347 | 
            +
                param_names = default_params.keys()
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                return param_klass(**params)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
              @staticmethod
         | 
| 354 | 
            +
              def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
         | 
| 355 | 
            +
                for cand in fallbacks:
         | 
| 356 | 
            +
                  value = cand.get(key)
         | 
| 357 | 
            +
                  if value is not None:
         | 
| 358 | 
            +
                    return value
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                return default_value
         | 
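As an illustration of the fallback order used by generate_params_by_fallbacks above (subset config first, then dataset config, general config, argparse values and finally runtime params), search_value simply returns the first non-None hit. The dictionaries below are made-up examples, assuming this module is importable as library.config_util:

# Hypothetical fallback chain: earlier dicts take precedence over later ones.
from library.config_util import BlueprintGenerator

fallbacks = [
    {"num_repeats": 10},                    # subset-level config
    {"num_repeats": 2, "flip_aug": True},   # dataset-level config
    {"keep_tokens": 1},                     # general config
]
print(BlueprintGenerator.search_value("num_repeats", fallbacks, 1))    # -> 10
print(BlueprintGenerator.search_value("flip_aug", fallbacks, False))   # -> True
print(BlueprintGenerator.search_value("color_aug", fallbacks, False))  # -> False (default)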
| 361 | 
            +
             | 
| 362 | 
            +
             | 
| 363 | 
            +
            def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
         | 
| 364 | 
            +
              datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
         | 
| 365 | 
            +
             | 
| 366 | 
            +
              for dataset_blueprint in dataset_group_blueprint.datasets:
         | 
| 367 | 
            +
                if dataset_blueprint.is_dreambooth:
         | 
| 368 | 
            +
                  subset_klass = DreamBoothSubset
         | 
| 369 | 
            +
                  dataset_klass = DreamBoothDataset
         | 
| 370 | 
            +
                else:
         | 
| 371 | 
            +
                  subset_klass = FineTuningSubset
         | 
| 372 | 
            +
                  dataset_klass = FineTuningDataset
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
         | 
| 375 | 
            +
                dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
         | 
| 376 | 
            +
                datasets.append(dataset)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
              # print info
         | 
| 379 | 
            +
              info = ""
         | 
| 380 | 
            +
              for i, dataset in enumerate(datasets):
         | 
| 381 | 
            +
                is_dreambooth = isinstance(dataset, DreamBoothDataset)
         | 
| 382 | 
            +
                info += dedent(f"""\
         | 
| 383 | 
            +
                  [Dataset {i}]
         | 
| 384 | 
            +
                    batch_size: {dataset.batch_size}
         | 
| 385 | 
            +
                    resolution: {(dataset.width, dataset.height)}
         | 
| 386 | 
            +
                    enable_bucket: {dataset.enable_bucket}
         | 
| 387 | 
            +
                """)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                if dataset.enable_bucket:
         | 
| 390 | 
            +
                  info += indent(dedent(f"""\
         | 
| 391 | 
            +
                    min_bucket_reso: {dataset.min_bucket_reso}
         | 
| 392 | 
            +
                    max_bucket_reso: {dataset.max_bucket_reso}
         | 
| 393 | 
            +
                    bucket_reso_steps: {dataset.bucket_reso_steps}
         | 
| 394 | 
            +
                    bucket_no_upscale: {dataset.bucket_no_upscale}
         | 
| 395 | 
            +
                  \n"""), "  ")
         | 
| 396 | 
            +
                else:
         | 
| 397 | 
            +
                  info += "\n"
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                for j, subset in enumerate(dataset.subsets):
         | 
| 400 | 
            +
                  info += indent(dedent(f"""\
         | 
| 401 | 
            +
                    [Subset {j} of Dataset {i}]
         | 
| 402 | 
            +
                      image_dir: "{subset.image_dir}"
         | 
| 403 | 
            +
                      image_count: {subset.img_count}
         | 
| 404 | 
            +
                      num_repeats: {subset.num_repeats}
         | 
| 405 | 
            +
                      shuffle_caption: {subset.shuffle_caption}
         | 
| 406 | 
            +
                      keep_tokens: {subset.keep_tokens}
         | 
| 407 | 
            +
                      caption_dropout_rate: {subset.caption_dropout_rate}
         | 
| 408 | 
            +
                       caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
         | 
| 409 | 
            +
                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
         | 
| 410 | 
            +
                      color_aug: {subset.color_aug}
         | 
| 411 | 
            +
                      flip_aug: {subset.flip_aug}
         | 
| 412 | 
            +
                      face_crop_aug_range: {subset.face_crop_aug_range}
         | 
| 413 | 
            +
                      random_crop: {subset.random_crop}
         | 
| 414 | 
            +
                       token_warmup_min: {subset.token_warmup_min}
         | 
| 415 | 
            +
                       token_warmup_step: {subset.token_warmup_step}
         | 
| 416 | 
            +
                  """), "  ")
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                  if is_dreambooth:
         | 
| 419 | 
            +
                    info += indent(dedent(f"""\
         | 
| 420 | 
            +
                      is_reg: {subset.is_reg}
         | 
| 421 | 
            +
                      class_tokens: {subset.class_tokens}
         | 
| 422 | 
            +
                      caption_extension: {subset.caption_extension}
         | 
| 423 | 
            +
                    \n"""), "    ")
         | 
| 424 | 
            +
                  else:
         | 
| 425 | 
            +
                    info += indent(dedent(f"""\
         | 
| 426 | 
            +
                      metadata_file: {subset.metadata_file}
         | 
| 427 | 
            +
                    \n"""), "    ")
         | 
| 428 | 
            +
             | 
| 429 | 
            +
              print(info)
         | 
| 430 | 
            +
             | 
| 431 | 
            +
              # make buckets first because it determines the length of dataset
         | 
| 432 | 
            +
              # and set the same seed for all datasets
         | 
| 433 | 
            +
              seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
         | 
| 434 | 
            +
              for i, dataset in enumerate(datasets):
         | 
| 435 | 
            +
                print(f"[Dataset {i}]")
         | 
| 436 | 
            +
                dataset.make_buckets()
         | 
| 437 | 
            +
                dataset.set_seed(seed)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
              return DatasetGroup(datasets)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
             | 
| 442 | 
            +
            def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
         | 
| 443 | 
            +
              def extract_dreambooth_params(name: str) -> Tuple[int, str]:
         | 
| 444 | 
            +
                tokens = name.split('_')
         | 
| 445 | 
            +
                try:
         | 
| 446 | 
            +
                  n_repeats = int(tokens[0])
         | 
| 447 | 
            +
                except ValueError as e:
         | 
| 448 | 
            +
                  print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
         | 
| 449 | 
            +
                  return 0, ""
         | 
| 450 | 
            +
                caption_by_folder = '_'.join(tokens[1:])
         | 
| 451 | 
            +
                return n_repeats, caption_by_folder
         | 
| 452 | 
            +
             | 
| 453 | 
            +
              def generate(base_dir: Optional[str], is_reg: bool):
         | 
| 454 | 
            +
                if base_dir is None:
         | 
| 455 | 
            +
                  return []
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                base_dir: Path = Path(base_dir)
         | 
| 458 | 
            +
                if not base_dir.is_dir():
         | 
| 459 | 
            +
                  return []
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                subsets_config = []
         | 
| 462 | 
            +
                for subdir in base_dir.iterdir():
         | 
| 463 | 
            +
                  if not subdir.is_dir():
         | 
| 464 | 
            +
                    continue
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                  num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
         | 
| 467 | 
            +
                  if num_repeats < 1:
         | 
| 468 | 
            +
                    continue
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                  subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
         | 
| 471 | 
            +
                  subsets_config.append(subset_config)
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                return subsets_config
         | 
| 474 | 
            +
             | 
| 475 | 
            +
              subsets_config = []
         | 
| 476 | 
            +
              subsets_config += generate(train_data_dir, False)
         | 
| 477 | 
            +
              subsets_config += generate(reg_data_dir, True)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
              return subsets_config
         | 
| 480 | 
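For example, with the usual "<repeats>_<class tokens>" folder naming, a hypothetical layout like the one sketched below would yield the following subsets config (paths and names are illustrative only; the order of entries depends on filesystem iteration order):

# Hypothetical folder layout:
#   train/10_sks dog   -> 10 repeats, class tokens "sks dog"
#   reg/1_dog          -> 1 repeat, regularization images, class tokens "dog"
from library.config_util import generate_dreambooth_subsets_config_by_subdirs

subsets = generate_dreambooth_subsets_config_by_subdirs(train_data_dir='train', reg_data_dir='reg')
# subsets would contain entries like:
#   {'image_dir': 'train/10_sks dog', 'num_repeats': 10, 'is_reg': False, 'class_tokens': 'sks dog'}
#   {'image_dir': 'reg/1_dog', 'num_repeats': 1, 'is_reg': True, 'class_tokens': 'dog'}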
            +
             | 
| 481 | 
            +
             | 
| 482 | 
            +
            def load_user_config(file: str) -> dict:
         | 
| 483 | 
            +
              file: Path = Path(file)
         | 
| 484 | 
            +
              if not file.is_file():
         | 
| 485 | 
            +
                raise ValueError(f"file not found / ファイルが見つかりません: {file}")
         | 
| 486 | 
            +
             | 
| 487 | 
            +
              if file.name.lower().endswith('.json'):
         | 
| 488 | 
            +
                try:
         | 
| 489 | 
            +
                   config = json.loads(file.read_text(encoding='utf-8'))
         | 
| 490 | 
            +
                except Exception:
         | 
| 491 | 
            +
                  print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
         | 
| 492 | 
            +
                  raise
         | 
| 493 | 
            +
              elif file.name.lower().endswith('.toml'):
         | 
| 494 | 
            +
                try:
         | 
| 495 | 
            +
                  config = toml.load(file)
         | 
| 496 | 
            +
                except Exception:
         | 
| 497 | 
            +
                  print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
         | 
| 498 | 
            +
                  raise
         | 
| 499 | 
            +
              else:
         | 
| 500 | 
            +
                raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
         | 
| 501 | 
            +
             | 
| 502 | 
            +
              return config
         | 
| 503 | 
            +
             | 
| 504 | 
            +
            # for config test
         | 
| 505 | 
            +
            if __name__ == "__main__":
         | 
| 506 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 507 | 
            +
              parser.add_argument("--support_dreambooth", action="store_true")
         | 
| 508 | 
            +
              parser.add_argument("--support_finetuning", action="store_true")
         | 
| 509 | 
            +
              parser.add_argument("--support_dropout", action="store_true")
         | 
| 510 | 
            +
              parser.add_argument("dataset_config")
         | 
| 511 | 
            +
              config_args, remain = parser.parse_known_args()
         | 
| 512 | 
            +
             | 
| 513 | 
            +
              parser = argparse.ArgumentParser()
         | 
| 514 | 
            +
              train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
         | 
| 515 | 
            +
              train_util.add_training_arguments(parser, config_args.support_dreambooth)
         | 
| 516 | 
            +
              argparse_namespace = parser.parse_args(remain)
         | 
| 517 | 
            +
              train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
         | 
| 518 | 
            +
             | 
| 519 | 
            +
              print("[argparse_namespace]")
         | 
| 520 | 
            +
              print(vars(argparse_namespace))
         | 
| 521 | 
            +
             | 
| 522 | 
            +
              user_config = load_user_config(config_args.dataset_config)
         | 
| 523 | 
            +
             | 
| 524 | 
            +
              print("\n[user_config]")
         | 
| 525 | 
            +
              print(user_config)
         | 
| 526 | 
            +
             | 
| 527 | 
            +
              sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
         | 
| 528 | 
            +
              sanitized_user_config = sanitizer.sanitize_user_config(user_config)
         | 
| 529 | 
            +
             | 
| 530 | 
            +
              print("\n[sanitized_user_config]")
         | 
| 531 | 
            +
              print(sanitized_user_config)
         | 
| 532 | 
            +
             | 
| 533 | 
            +
              blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
         | 
| 534 | 
            +
             | 
| 535 | 
            +
              print("\n[blueprint]")
         | 
| 536 | 
            +
              print(blueprint)
         | 
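The subdirectory helper above turns DreamBooth-style "<repeats>_<class tokens>" image folders into subset entries, and load_user_config reads the same information from a .json or .toml file for the sanitizer/blueprint test harness that follows it. As a rough sketch only (the folder names and numbers below are invented, not taken from this repo), the generated structure looks like:

# Hypothetical output of the subdirectory helper above; folder names are examples.
subsets_config = [
    {"image_dir": "img/10_mychar person", "num_repeats": 10, "is_reg": False, "class_tokens": "mychar person"},
    {"image_dir": "reg/1_person", "num_repeats": 1, "is_reg": True, "class_tokens": "person"},
]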
    	
        library/convert_model_gui.py
    ADDED
    
    | @@ -0,0 +1,247 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import shutil
         | 
| 6 | 
            +
            from .common_gui import get_folder_path, get_file_path
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            folder_symbol = '\U0001f4c2'  # 📂
         | 
| 9 | 
            +
            refresh_symbol = '\U0001f504'  # 🔄
         | 
| 10 | 
            +
            save_style_symbol = '\U0001f4be'  # 💾
         | 
| 11 | 
            +
            document_symbol = '\U0001F4C4'   # 📄
         | 
| 12 | 
            +
            PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def convert_model(
         | 
| 16 | 
            +
                source_model_input,
         | 
| 17 | 
            +
                source_model_type,
         | 
| 18 | 
            +
                target_model_folder_input,
         | 
| 19 | 
            +
                target_model_name_input,
         | 
| 20 | 
            +
                target_model_type,
         | 
| 21 | 
            +
                target_save_precision_type,
         | 
| 22 | 
            +
            ):
         | 
| 23 | 
            +
    # Check that a source model type was selected
         | 
| 24 | 
            +
                if source_model_type == '':
         | 
| 25 | 
            +
                    msgbox('Invalid source model type')
         | 
| 26 | 
            +
                    return
         | 
| 27 | 
            +
             | 
| 28 | 
            +
    # Check if the source model exists
         | 
| 29 | 
            +
                if os.path.isfile(source_model_input):
         | 
| 30 | 
            +
                    print('The provided source model is a file')
         | 
| 31 | 
            +
                elif os.path.isdir(source_model_input):
         | 
| 32 | 
            +
                    print('The provided model is a folder')
         | 
| 33 | 
            +
                else:
         | 
| 34 | 
            +
                    msgbox('The provided source model is neither a file nor a folder')
         | 
| 35 | 
            +
                    return
         | 
| 36 | 
            +
             | 
| 37 | 
            +
    # Check if the target folder exists
         | 
| 38 | 
            +
                if os.path.isdir(target_model_folder_input):
         | 
| 39 | 
            +
        print('The provided model folder exists')
         | 
| 40 | 
            +
                else:
         | 
| 41 | 
            +
                    msgbox('The provided target folder does not exist')
         | 
| 42 | 
            +
                    return
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"'
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                v1_models = [
         | 
| 47 | 
            +
                    'runwayml/stable-diffusion-v1-5',
         | 
| 48 | 
            +
                    'CompVis/stable-diffusion-v1-4',
         | 
| 49 | 
            +
                ]
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                # check if v1 models
         | 
| 52 | 
            +
                if str(source_model_type) in v1_models:
         | 
| 53 | 
            +
                    print('SD v1 model specified. Setting --v1 parameter')
         | 
| 54 | 
            +
                    run_cmd += ' --v1'
         | 
| 55 | 
            +
                else:
         | 
| 56 | 
            +
                    print('SD v2 model specified. Setting --v2 parameter')
         | 
| 57 | 
            +
                    run_cmd += ' --v2'
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                if not target_save_precision_type == 'unspecified':
         | 
| 60 | 
            +
                    run_cmd += f' --{target_save_precision_type}'
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                if (
         | 
| 63 | 
            +
                    target_model_type == 'diffuser'
         | 
| 64 | 
            +
                    or target_model_type == 'diffuser_safetensors'
         | 
| 65 | 
            +
                ):
         | 
| 66 | 
            +
                    run_cmd += f' --reference_model="{source_model_type}"'
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                if target_model_type == 'diffuser_safetensors':
         | 
| 69 | 
            +
                    run_cmd += ' --use_safetensors'
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                run_cmd += f' "{source_model_input}"'
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                if (
         | 
| 74 | 
            +
                    target_model_type == 'diffuser'
         | 
| 75 | 
            +
                    or target_model_type == 'diffuser_safetensors'
         | 
| 76 | 
            +
                ):
         | 
| 77 | 
            +
                    target_model_path = os.path.join(
         | 
| 78 | 
            +
                        target_model_folder_input, target_model_name_input
         | 
| 79 | 
            +
                    )
         | 
| 80 | 
            +
                    run_cmd += f' "{target_model_path}"'
         | 
| 81 | 
            +
                else:
         | 
| 82 | 
            +
                    target_model_path = os.path.join(
         | 
| 83 | 
            +
                        target_model_folder_input,
         | 
| 84 | 
            +
                        f'{target_model_name_input}.{target_model_type}',
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    run_cmd += f' "{target_model_path}"'
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                print(run_cmd)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                # Run the command
         | 
| 91 | 
            +
                if os.name == 'posix':
         | 
| 92 | 
            +
                    os.system(run_cmd)
         | 
| 93 | 
            +
                else:
         | 
| 94 | 
            +
                    subprocess.run(run_cmd)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                if (
         | 
| 97 | 
            +
        target_model_type != 'diffuser'
         | 
| 98 | 
            +
        and target_model_type != 'diffuser_safetensors'
         | 
| 99 | 
            +
                ):
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    v2_models = [
         | 
| 102 | 
            +
                        'stabilityai/stable-diffusion-2-1-base',
         | 
| 103 | 
            +
                        'stabilityai/stable-diffusion-2-base',
         | 
| 104 | 
            +
                    ]
         | 
| 105 | 
            +
                    v_parameterization = [
         | 
| 106 | 
            +
                        'stabilityai/stable-diffusion-2-1',
         | 
| 107 | 
            +
                        'stabilityai/stable-diffusion-2',
         | 
| 108 | 
            +
                    ]
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if str(source_model_type) in v2_models:
         | 
| 111 | 
            +
                        inference_file = os.path.join(
         | 
| 112 | 
            +
                            target_model_folder_input, f'{target_model_name_input}.yaml'
         | 
| 113 | 
            +
                        )
         | 
| 114 | 
            +
                        print(f'Saving v2-inference.yaml as {inference_file}')
         | 
| 115 | 
            +
                        shutil.copy(
         | 
| 116 | 
            +
                            f'./v2_inference/v2-inference.yaml',
         | 
| 117 | 
            +
                            f'{inference_file}',
         | 
| 118 | 
            +
                        )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    if str(source_model_type) in v_parameterization:
         | 
| 121 | 
            +
                        inference_file = os.path.join(
         | 
| 122 | 
            +
                            target_model_folder_input, f'{target_model_name_input}.yaml'
         | 
| 123 | 
            +
                        )
         | 
| 124 | 
            +
                        print(f'Saving v2-inference-v.yaml as {inference_file}')
         | 
| 125 | 
            +
                        shutil.copy(
         | 
| 126 | 
            +
                            f'./v2_inference/v2-inference-v.yaml',
         | 
| 127 | 
            +
                            f'{inference_file}',
         | 
| 128 | 
            +
                        )
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            #   parser = argparse.ArgumentParser()
         | 
| 132 | 
            +
            #   parser.add_argument("--v1", action='store_true',
         | 
| 133 | 
            +
            #                       help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
         | 
| 134 | 
            +
            #   parser.add_argument("--v2", action='store_true',
         | 
| 135 | 
            +
            #                       help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
         | 
| 136 | 
            +
            #   parser.add_argument("--fp16", action='store_true',
         | 
| 137 | 
            +
            #                       help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
         | 
| 138 | 
            +
            #   parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
         | 
| 139 | 
            +
            #   parser.add_argument("--float", action='store_true',
         | 
| 140 | 
            +
            #                       help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
         | 
| 141 | 
            +
            #   parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記���するepoch数の値')
         | 
| 142 | 
            +
            #   parser.add_argument("--global_step", type=int, default=0,
         | 
| 143 | 
            +
            #                       help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
         | 
| 144 | 
            +
            #   parser.add_argument("--reference_model", type=str, default=None,
         | 
| 145 | 
            +
#                       help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
         | 
| 146 | 
            +
             | 
| 147 | 
            +
            #   parser.add_argument("model_to_load", type=str, default=None,
         | 
| 148 | 
            +
            #                       help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
         | 
| 149 | 
            +
            #   parser.add_argument("model_to_save", type=str, default=None,
         | 
| 150 | 
            +
#                       help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            ###
         | 
| 154 | 
            +
            # Gradio UI
         | 
| 155 | 
            +
            ###
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def gradio_convert_model_tab():
         | 
| 159 | 
            +
                with gr.Tab('Convert model'):
         | 
| 160 | 
            +
                    gr.Markdown(
         | 
| 161 | 
            +
                        'This utility can be used to convert from one stable diffusion model format to another.'
         | 
| 162 | 
            +
                    )
         | 
| 163 | 
            +
                    with gr.Row():
         | 
| 164 | 
            +
                        source_model_input = gr.Textbox(
         | 
| 165 | 
            +
                            label='Source model',
         | 
| 166 | 
            +
            placeholder='path to source model folder or file to convert...',
         | 
| 167 | 
            +
                            interactive=True,
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                        button_source_model_dir = gr.Button(
         | 
| 170 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 171 | 
            +
                        )
         | 
| 172 | 
            +
                        button_source_model_dir.click(
         | 
| 173 | 
            +
                            get_folder_path,
         | 
| 174 | 
            +
                            outputs=source_model_input,
         | 
| 175 | 
            +
                            show_progress=False,
         | 
| 176 | 
            +
                        )
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                        button_source_model_file = gr.Button(
         | 
| 179 | 
            +
                            document_symbol, elem_id='open_folder_small'
         | 
| 180 | 
            +
                        )
         | 
| 181 | 
            +
                        button_source_model_file.click(
         | 
| 182 | 
            +
                            get_file_path,
         | 
| 183 | 
            +
                            inputs=[source_model_input],
         | 
| 184 | 
            +
                            outputs=source_model_input,
         | 
| 185 | 
            +
                            show_progress=False,
         | 
| 186 | 
            +
                        )
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                        source_model_type = gr.Dropdown(
         | 
| 189 | 
            +
                            label='Source model type',
         | 
| 190 | 
            +
                            choices=[
         | 
| 191 | 
            +
                                'stabilityai/stable-diffusion-2-1-base',
         | 
| 192 | 
            +
                                'stabilityai/stable-diffusion-2-base',
         | 
| 193 | 
            +
                                'stabilityai/stable-diffusion-2-1',
         | 
| 194 | 
            +
                                'stabilityai/stable-diffusion-2',
         | 
| 195 | 
            +
                                'runwayml/stable-diffusion-v1-5',
         | 
| 196 | 
            +
                                'CompVis/stable-diffusion-v1-4',
         | 
| 197 | 
            +
                            ],
         | 
| 198 | 
            +
                        )
         | 
| 199 | 
            +
                    with gr.Row():
         | 
| 200 | 
            +
                        target_model_folder_input = gr.Textbox(
         | 
| 201 | 
            +
                            label='Target model folder',
         | 
| 202 | 
            +
            placeholder='path to target model folder or file name to create...',
         | 
| 203 | 
            +
                            interactive=True,
         | 
| 204 | 
            +
                        )
         | 
| 205 | 
            +
                        button_target_model_folder = gr.Button(
         | 
| 206 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 207 | 
            +
                        )
         | 
| 208 | 
            +
                        button_target_model_folder.click(
         | 
| 209 | 
            +
                            get_folder_path,
         | 
| 210 | 
            +
                            outputs=target_model_folder_input,
         | 
| 211 | 
            +
                            show_progress=False,
         | 
| 212 | 
            +
                        )
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                        target_model_name_input = gr.Textbox(
         | 
| 215 | 
            +
                            label='Target model name',
         | 
| 216 | 
            +
                            placeholder='target model name...',
         | 
| 217 | 
            +
                            interactive=True,
         | 
| 218 | 
            +
                        )
         | 
| 219 | 
            +
                        target_model_type = gr.Dropdown(
         | 
| 220 | 
            +
                            label='Target model type',
         | 
| 221 | 
            +
                            choices=[
         | 
| 222 | 
            +
                                'diffuser',
         | 
| 223 | 
            +
                                'diffuser_safetensors',
         | 
| 224 | 
            +
                                'ckpt',
         | 
| 225 | 
            +
                                'safetensors',
         | 
| 226 | 
            +
                            ],
         | 
| 227 | 
            +
                        )
         | 
| 228 | 
            +
                        target_save_precision_type = gr.Dropdown(
         | 
| 229 | 
            +
                            label='Target model precision',
         | 
| 230 | 
            +
                            choices=['unspecified', 'fp16', 'bf16', 'float'],
         | 
| 231 | 
            +
                            value='unspecified',
         | 
| 232 | 
            +
                        )
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    convert_button = gr.Button('Convert model')
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    convert_button.click(
         | 
| 237 | 
            +
                        convert_model,
         | 
| 238 | 
            +
                        inputs=[
         | 
| 239 | 
            +
                            source_model_input,
         | 
| 240 | 
            +
                            source_model_type,
         | 
| 241 | 
            +
                            target_model_folder_input,
         | 
| 242 | 
            +
                            target_model_name_input,
         | 
| 243 | 
            +
                            target_model_type,
         | 
| 244 | 
            +
                            target_save_precision_type,
         | 
| 245 | 
            +
                        ],
         | 
| 246 | 
            +
                        show_progress=False,
         | 
| 247 | 
            +
                    )
         | 
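For reference, here is the kind of call and command convert_model assembles; every path and name below is an invented example, not captured output. With a v1 source and a safetensors target, the --reference_model and --use_safetensors options are skipped and the target path gets the chosen extension:

# Hypothetical GUI inputs and the resulting command (paths are made up):
convert_model(
    source_model_input='/models/v1-5-pruned.ckpt',
    source_model_type='runwayml/stable-diffusion-v1-5',
    target_model_folder_input='/models/out',
    target_model_name_input='v1-5-converted',
    target_model_type='safetensors',
    target_save_precision_type='fp16',
)
# builds and runs (on posix):
# python3 "tools/convert_diffusers20_original_sd.py" --v1 --fp16 "/models/v1-5-pruned.ckpt" "/models/out/v1-5-converted.safetensors"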
    	
        library/custom_train_functions.py
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): 
         | 
| 5 | 
            +
              alphas_cumprod = noise_scheduler.alphas_cumprod
         | 
| 6 | 
            +
              sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
         | 
| 7 | 
            +
              sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
         | 
| 8 | 
            +
              alpha = sqrt_alphas_cumprod
         | 
| 9 | 
            +
              sigma = sqrt_one_minus_alphas_cumprod
         | 
| 10 | 
            +
              all_snr = (alpha / sigma) ** 2
         | 
| 11 | 
            +
              snr = torch.stack([all_snr[t] for t in timesteps])
         | 
| 12 | 
            +
              gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
         | 
| 13 | 
            +
              snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
         | 
| 14 | 
            +
              loss = loss * snr_weight
         | 
| 15 | 
            +
              return loss
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            def add_custom_train_arguments(parser: argparse.ArgumentParser):
         | 
| 18 | 
            +
              parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
         | 
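apply_snr_weight implements min-SNR loss weighting: the per-timestep signal-to-noise ratio is derived from the scheduler's alphas_cumprod and each sample's loss is scaled by min(gamma / SNR, 1). Below is a minimal usage sketch, not taken from the training scripts; noise_pred, noise, timesteps, noise_scheduler and args are assumed to come from the surrounding training step:

import torch.nn.functional as F

# per-sample loss with shape (batch,) so the (batch,)-shaped SNR weights broadcast cleanly
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction='none').mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None:   # e.g. passed as --min_snr_gamma 5
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean()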
    	
        library/dataset_balancing_gui.py
    ADDED
    
    | @@ -0,0 +1,146 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            from easygui import msgbox, boolbox
         | 
| 5 | 
            +
            from .common_gui import get_folder_path
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # def select_folder():
         | 
| 8 | 
            +
            #     # Open a file dialog to select a directory
         | 
| 9 | 
            +
            #     folder = filedialog.askdirectory()
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            #     # Update the GUI to display the selected folder
         | 
| 12 | 
            +
            #     selected_folder_label.config(text=folder)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def dataset_balancing(concept_repeats, folder, insecure):
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                if not concept_repeats > 0:
         | 
| 18 | 
            +
                    # Display an error message if the total number of repeats is not a valid integer
         | 
| 19 | 
            +
                    msgbox('Please enter a valid integer for the total number of repeats.')
         | 
| 20 | 
            +
                    return
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                concept_repeats = int(concept_repeats)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                # Check if folder exist
         | 
| 25 | 
            +
                if folder == '' or not os.path.isdir(folder):
         | 
| 26 | 
            +
                    msgbox('Please enter a valid folder for balancing.')
         | 
| 27 | 
            +
                    return
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                pattern = re.compile(r'^\d+_.+$')
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                # Iterate over the subdirectories in the selected folder
         | 
| 32 | 
            +
                for subdir in os.listdir(folder):
         | 
| 33 | 
            +
                    if pattern.match(subdir) or insecure:
         | 
| 34 | 
            +
                        # Calculate the number of repeats for the current subdirectory
         | 
| 35 | 
            +
                        # Get a list of all the files in the folder
         | 
| 36 | 
            +
                        files = os.listdir(os.path.join(folder, subdir))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                        # Filter the list to include only image files
         | 
| 39 | 
            +
                        image_files = [
         | 
| 40 | 
            +
                            f
         | 
| 41 | 
            +
                            for f in files
         | 
| 42 | 
            +
                            if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
         | 
| 43 | 
            +
                        ]
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        # Count the number of image files
         | 
| 46 | 
            +
                        images = len(image_files)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                        # Check if the subdirectory name starts with a number inside braces,
         | 
| 49 | 
            +
                        # indicating that the repeats value should be multiplied
         | 
| 50 | 
            +
                        match = re.match(r'^\{(\d+\.?\d*)\}', subdir)
         | 
| 51 | 
            +
                        if match:
         | 
| 52 | 
            +
                            # Multiply the repeats value by the number inside the braces
         | 
| 53 | 
            +
                            if not images == 0:
         | 
| 54 | 
            +
                                repeats = max(
         | 
| 55 | 
            +
                                    1,
         | 
| 56 | 
            +
                                    round(
         | 
| 57 | 
            +
                                        concept_repeats / images * float(match.group(1))
         | 
| 58 | 
            +
                                    ),
         | 
| 59 | 
            +
                                )
         | 
| 60 | 
            +
                            else:
         | 
| 61 | 
            +
                                repeats = 0
         | 
| 62 | 
            +
                            subdir = subdir[match.end() :]
         | 
| 63 | 
            +
                        else:
         | 
| 64 | 
            +
                            if not images == 0:
         | 
| 65 | 
            +
                                repeats = max(1, round(concept_repeats / images))
         | 
| 66 | 
            +
                            else:
         | 
| 67 | 
            +
                                repeats = 0
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                        # Check if the subdirectory name already has a number at the beginning
         | 
| 70 | 
            +
                        match = re.match(r'^\d+_', subdir)
         | 
| 71 | 
            +
                        if match:
         | 
| 72 | 
            +
                            # Replace the existing number with the new number
         | 
| 73 | 
            +
                            old_name = os.path.join(folder, subdir)
         | 
| 74 | 
            +
                            new_name = os.path.join(
         | 
| 75 | 
            +
                                folder, f'{repeats}_{subdir[match.end():]}'
         | 
| 76 | 
            +
                            )
         | 
| 77 | 
            +
                        else:
         | 
| 78 | 
            +
                            # Add the new number at the beginning of the name
         | 
| 79 | 
            +
                            old_name = os.path.join(folder, subdir)
         | 
| 80 | 
            +
                            new_name = os.path.join(folder, f'{repeats}_{subdir}')
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                        os.rename(old_name, new_name)
         | 
| 83 | 
            +
                    else:
         | 
| 84 | 
            +
                        print(
         | 
| 85 | 
            +
                            f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
         | 
| 86 | 
            +
                        )
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                msgbox('Dataset balancing completed...')
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def warning(insecure):
         | 
| 92 | 
            +
                if insecure:
         | 
| 93 | 
            +
                    if boolbox(
         | 
| 94 | 
            +
                        f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
         | 
| 95 | 
            +
                        choices=('Yes, I like danger', 'No, get me out of here'),
         | 
| 96 | 
            +
                    ):
         | 
| 97 | 
            +
                        return True
         | 
| 98 | 
            +
                    else:
         | 
| 99 | 
            +
                        return False
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def gradio_dataset_balancing_tab():
         | 
| 103 | 
            +
                with gr.Tab('Dreambooth/LoRA Dataset balancing'):
         | 
| 104 | 
            +
                    gr.Markdown(
         | 
| 105 | 
            +
                        'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
                    gr.Markdown(
         | 
| 108 | 
            +
                        'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
                    with gr.Row():
         | 
| 111 | 
            +
                        select_dataset_folder_input = gr.Textbox(
         | 
| 112 | 
            +
                            label='Dataset folder',
         | 
| 113 | 
            +
                            placeholder='Folder containing the concepts folders to balance...',
         | 
| 114 | 
            +
                            interactive=True,
         | 
| 115 | 
            +
                        )
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                        select_dataset_folder_button = gr.Button(
         | 
| 118 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 119 | 
            +
                        )
         | 
| 120 | 
            +
                        select_dataset_folder_button.click(
         | 
| 121 | 
            +
                            get_folder_path,
         | 
| 122 | 
            +
                            outputs=select_dataset_folder_input,
         | 
| 123 | 
            +
                            show_progress=False,
         | 
| 124 | 
            +
                        )
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        total_repeats_number = gr.Number(
         | 
| 127 | 
            +
                            value=1000,
         | 
| 128 | 
            +
                            interactive=True,
         | 
| 129 | 
            +
                            label='Training steps per concept per epoch',
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
                    with gr.Accordion('Advanced options', open=False):
         | 
| 132 | 
            +
                        insecure = gr.Checkbox(
         | 
| 133 | 
            +
                            value=False,
         | 
| 134 | 
            +
                            label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
         | 
| 135 | 
            +
                        )
         | 
| 136 | 
            +
                        insecure.change(warning, inputs=insecure, outputs=insecure)
         | 
| 137 | 
            +
                    balance_button = gr.Button('Balance dataset')
         | 
| 138 | 
            +
                    balance_button.click(
         | 
| 139 | 
            +
                        dataset_balancing,
         | 
| 140 | 
            +
                        inputs=[
         | 
| 141 | 
            +
                            total_repeats_number,
         | 
| 142 | 
            +
                            select_dataset_folder_input,
         | 
| 143 | 
            +
                            insecure,
         | 
| 144 | 
            +
                        ],
         | 
| 145 | 
            +
                        show_progress=False,
         | 
| 146 | 
            +
                    )
         | 
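As a worked example (numbers invented): with 'Training steps per concept per epoch' left at 1000, a folder named 5_myconcept containing 25 images is renamed to 40_myconcept, since round(1000 / 25) = 40; the old numeric prefix is replaced rather than stacked. With the insecure option enabled, a folder named {2}myconcept would instead receive round(1000 / 25 * 2.0) = 80 repeats and be renamed 80_myconcept.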
    	
        library/dreambooth_folder_creation_gui.py
    ADDED
    
    | @@ -0,0 +1,210 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import diropenbox, msgbox
         | 
| 3 | 
            +
            from .common_gui import get_folder_path
         | 
| 4 | 
            +
            import shutil
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def copy_info_to_Folders_tab(training_folder):
         | 
| 9 | 
            +
                img_folder = os.path.join(training_folder, 'img')
         | 
| 10 | 
            +
                if os.path.exists(os.path.join(training_folder, 'reg')):
         | 
| 11 | 
            +
                    reg_folder = os.path.join(training_folder, 'reg')
         | 
| 12 | 
            +
                else:
         | 
| 13 | 
            +
                    reg_folder = ''
         | 
| 14 | 
            +
                model_folder = os.path.join(training_folder, 'model')
         | 
| 15 | 
            +
                log_folder = os.path.join(training_folder, 'log')
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                return img_folder, reg_folder, model_folder, log_folder
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def dreambooth_folder_preparation(
         | 
| 21 | 
            +
                util_training_images_dir_input,
         | 
| 22 | 
            +
                util_training_images_repeat_input,
         | 
| 23 | 
            +
                util_instance_prompt_input,
         | 
| 24 | 
            +
                util_regularization_images_dir_input,
         | 
| 25 | 
            +
                util_regularization_images_repeat_input,
         | 
| 26 | 
            +
                util_class_prompt_input,
         | 
| 27 | 
            +
                util_training_dir_output,
         | 
| 28 | 
            +
            ):
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                # Check if the input variables are empty
         | 
| 31 | 
            +
                if not len(util_training_dir_output):
         | 
| 32 | 
            +
                    print(
         | 
| 33 | 
            +
                        "Destination training directory is missing... can't perform the required task..."
         | 
| 34 | 
            +
                    )
         | 
| 35 | 
            +
                    return
         | 
| 36 | 
            +
                else:
         | 
| 37 | 
            +
                    # Create the util_training_dir_output directory if it doesn't exist
         | 
| 38 | 
            +
                    os.makedirs(util_training_dir_output, exist_ok=True)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                # Check for instance prompt
         | 
| 41 | 
            +
                if util_instance_prompt_input == '':
         | 
| 42 | 
            +
                    msgbox('Instance prompt missing...')
         | 
| 43 | 
            +
                    return
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                # Check for class prompt
         | 
| 46 | 
            +
                if util_class_prompt_input == '':
         | 
| 47 | 
            +
                    msgbox('Class prompt missing...')
         | 
| 48 | 
            +
                    return
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # Create the training_dir path
         | 
| 51 | 
            +
                if util_training_images_dir_input == '':
         | 
| 52 | 
            +
                    print(
         | 
| 53 | 
            +
                        "Training images directory is missing... can't perform the required task..."
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    return
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    training_dir = os.path.join(
         | 
| 58 | 
            +
                        util_training_dir_output,
         | 
| 59 | 
            +
                        f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
         | 
| 60 | 
            +
                    )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # Remove folders if they exist
         | 
| 63 | 
            +
                    if os.path.exists(training_dir):
         | 
| 64 | 
            +
                        print(f'Removing existing directory {training_dir}...')
         | 
| 65 | 
            +
                        shutil.rmtree(training_dir)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    # Copy the training images to their respective directories
         | 
| 68 | 
            +
                    print(f'Copy {util_training_images_dir_input} to {training_dir}...')
         | 
| 69 | 
            +
                    shutil.copytree(util_training_images_dir_input, training_dir)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                if not util_regularization_images_dir_input == '':
         | 
| 72 | 
            +
                    # Create the regularization_dir path
         | 
| 73 | 
            +
                    if not util_regularization_images_repeat_input > 0:
         | 
| 74 | 
            +
                        print('Repeats is missing... not copying regularisation images...')
         | 
| 75 | 
            +
                    else:
         | 
| 76 | 
            +
                        regularization_dir = os.path.join(
         | 
| 77 | 
            +
                            util_training_dir_output,
         | 
| 78 | 
            +
                            f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
         | 
| 79 | 
            +
                        )
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        # Remove folders if they exist
         | 
| 82 | 
            +
                        if os.path.exists(regularization_dir):
         | 
| 83 | 
            +
                            print(f'Removing existing directory {regularization_dir}...')
         | 
| 84 | 
            +
                            shutil.rmtree(regularization_dir)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                        # Copy the regularisation images to their respective directories
         | 
| 87 | 
            +
                        print(
         | 
| 88 | 
            +
                            f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
         | 
| 89 | 
            +
                        )
         | 
| 90 | 
            +
                        shutil.copytree(
         | 
| 91 | 
            +
                            util_regularization_images_dir_input, regularization_dir
         | 
| 92 | 
            +
                        )
         | 
| 93 | 
            +
                else:
         | 
| 94 | 
            +
                    print(
         | 
| 95 | 
            +
                        'Regularization images directory is missing... not copying regularisation images...'
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                # create log and model folder
         | 
| 99 | 
            +
                # Check if the log folder exists and create it if it doesn't
         | 
| 100 | 
            +
                if not os.path.exists(os.path.join(util_training_dir_output, 'log')):
         | 
| 101 | 
            +
                    os.makedirs(os.path.join(util_training_dir_output, 'log'))
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                # Check if the model folder exists and create it if it doesn't
         | 
| 104 | 
            +
                if not os.path.exists(os.path.join(util_training_dir_output, 'model')):
         | 
| 105 | 
            +
                    os.makedirs(os.path.join(util_training_dir_output, 'model'))
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                print(
         | 
| 108 | 
            +
                    f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
         | 
| 109 | 
            +
                )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            def gradio_dreambooth_folder_creation_tab(
         | 
| 113 | 
            +
                train_data_dir_input=gr.Textbox(),
         | 
| 114 | 
            +
                reg_data_dir_input=gr.Textbox(),
         | 
| 115 | 
            +
                output_dir_input=gr.Textbox(),
         | 
| 116 | 
            +
                logging_dir_input=gr.Textbox(),
         | 
| 117 | 
            +
            ):
         | 
| 118 | 
            +
                with gr.Tab('Dreambooth/LoRA Folder preparation'):
         | 
| 119 | 
            +
                    gr.Markdown(
         | 
| 120 | 
            +
                'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohya_ss Dreambooth/LoRA method to function correctly.'
         | 
| 121 | 
            +
                    )
         | 
| 122 | 
            +
                    with gr.Row():
         | 
| 123 | 
            +
                        util_instance_prompt_input = gr.Textbox(
         | 
| 124 | 
            +
                            label='Instance prompt',
         | 
| 125 | 
            +
                            placeholder='Eg: asd',
         | 
| 126 | 
            +
                            interactive=True,
         | 
| 127 | 
            +
                        )
         | 
| 128 | 
            +
                        util_class_prompt_input = gr.Textbox(
         | 
| 129 | 
            +
                            label='Class prompt',
         | 
| 130 | 
            +
                            placeholder='Eg: person',
         | 
| 131 | 
            +
                            interactive=True,
         | 
| 132 | 
            +
                        )
         | 
| 133 | 
            +
                    with gr.Row():
         | 
| 134 | 
            +
                        util_training_images_dir_input = gr.Textbox(
         | 
| 135 | 
            +
                            label='Training images',
         | 
| 136 | 
            +
                            placeholder='Directory containing the training images',
         | 
| 137 | 
            +
                            interactive=True,
         | 
| 138 | 
            +
                        )
         | 
| 139 | 
            +
                        button_util_training_images_dir_input = gr.Button(
         | 
| 140 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 141 | 
            +
                        )
         | 
| 142 | 
            +
                        button_util_training_images_dir_input.click(
         | 
| 143 | 
            +
                            get_folder_path,
         | 
| 144 | 
            +
                            outputs=util_training_images_dir_input,
         | 
| 145 | 
            +
                            show_progress=False,
         | 
| 146 | 
            +
                        )
         | 
| 147 | 
            +
                        util_training_images_repeat_input = gr.Number(
         | 
| 148 | 
            +
                            label='Repeats',
         | 
| 149 | 
            +
                            value=40,
         | 
| 150 | 
            +
                            interactive=True,
         | 
| 151 | 
            +
                            elem_id='number_input',
         | 
| 152 | 
            +
                        )
         | 
| 153 | 
            +
                    with gr.Row():
         | 
| 154 | 
            +
                        util_regularization_images_dir_input = gr.Textbox(
         | 
| 155 | 
            +
                            label='Regularisation images',
         | 
| 156 | 
            +
                            placeholder='(Optional) Directory containing the regularisation images',
         | 
| 157 | 
            +
                            interactive=True,
         | 
| 158 | 
            +
                        )
         | 
| 159 | 
            +
                        button_util_regularization_images_dir_input = gr.Button(
         | 
| 160 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
                        button_util_regularization_images_dir_input.click(
         | 
| 163 | 
            +
                            get_folder_path,
         | 
| 164 | 
            +
                            outputs=util_regularization_images_dir_input,
         | 
| 165 | 
            +
                            show_progress=False,
         | 
| 166 | 
            +
                        )
         | 
| 167 | 
            +
                        util_regularization_images_repeat_input = gr.Number(
         | 
| 168 | 
            +
                            label='Repeats',
         | 
| 169 | 
            +
                            value=1,
         | 
| 170 | 
            +
                            interactive=True,
         | 
| 171 | 
            +
                            elem_id='number_input',
         | 
| 172 | 
            +
                        )
         | 
| 173 | 
            +
                    with gr.Row():
         | 
| 174 | 
            +
                        util_training_dir_output = gr.Textbox(
         | 
| 175 | 
            +
                            label='Destination training directory',
         | 
| 176 | 
            +
                            placeholder='Directory where formatted training and regularisation folders will be placed',
         | 
| 177 | 
            +
                            interactive=True,
         | 
| 178 | 
            +
                        )
         | 
| 179 | 
            +
                        button_util_training_dir_output = gr.Button(
         | 
| 180 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 181 | 
            +
                        )
         | 
| 182 | 
            +
                        button_util_training_dir_output.click(
         | 
| 183 | 
            +
                            get_folder_path, outputs=util_training_dir_output
         | 
| 184 | 
            +
                        )
         | 
| 185 | 
            +
                    button_prepare_training_data = gr.Button('Prepare training data')
         | 
| 186 | 
            +
                    button_prepare_training_data.click(
         | 
| 187 | 
            +
                        dreambooth_folder_preparation,
         | 
| 188 | 
            +
                        inputs=[
         | 
| 189 | 
            +
                            util_training_images_dir_input,
         | 
| 190 | 
            +
                            util_training_images_repeat_input,
         | 
| 191 | 
            +
                            util_instance_prompt_input,
         | 
| 192 | 
            +
                            util_regularization_images_dir_input,
         | 
| 193 | 
            +
                            util_regularization_images_repeat_input,
         | 
| 194 | 
            +
                            util_class_prompt_input,
         | 
| 195 | 
            +
                            util_training_dir_output,
         | 
| 196 | 
            +
                        ],
         | 
| 197 | 
            +
                        show_progress=False,
         | 
| 198 | 
            +
                    )
         | 
| 199 | 
            +
                    button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
         | 
| 200 | 
            +
                    button_copy_info_to_Folders_tab.click(
         | 
| 201 | 
            +
                        copy_info_to_Folders_tab,
         | 
| 202 | 
            +
                        inputs=[util_training_dir_output],
         | 
| 203 | 
            +
                        outputs=[
         | 
| 204 | 
            +
                            train_data_dir_input,
         | 
| 205 | 
            +
                            reg_data_dir_input,
         | 
| 206 | 
            +
                            output_dir_input,
         | 
| 207 | 
            +
                            logging_dir_input,
         | 
| 208 | 
            +
                        ],
         | 
| 209 | 
            +
                        show_progress=False,
         | 
| 210 | 
            +
                    )
         | 
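For readers following the wiring above: the copy_info_to_Folders_tab click handler declares one input (util_training_dir_output) and four outputs, so the underlying function must return one value per output component, in the declared order. The sketch below only illustrates that Gradio pattern; the subfolder names ('img', 'reg', 'model', 'log') are assumptions for illustration, not taken from the actual implementation.

import os

def copy_info_to_folders_tab_sketch(training_dir: str):
    # One return value per declared output component, in the declared order.
    # Subfolder names are illustrative assumptions only.
    return (
        os.path.join(training_dir, 'img'),    # -> train_data_dir_input
        os.path.join(training_dir, 'reg'),    # -> reg_data_dir_input
        os.path.join(training_dir, 'model'),  # -> output_dir_input
        os.path.join(training_dir, 'log'),    # -> logging_dir_input
    )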
    	
        library/extract_lora_gui.py
    ADDED
    
    | @@ -0,0 +1,178 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from .common_gui import (
         | 
| 6 | 
            +
                get_saveasfilename_path,
         | 
| 7 | 
            +
                get_any_file_path,
         | 
| 8 | 
            +
                get_file_path,
         | 
| 9 | 
            +
            )
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            folder_symbol = '\U0001f4c2'  # 📂
         | 
| 12 | 
            +
            refresh_symbol = '\U0001f504'  # 🔄
         | 
| 13 | 
            +
            save_style_symbol = '\U0001f4be'  # 💾
         | 
| 14 | 
            +
            document_symbol = '\U0001F4C4'   # 📄
         | 
| 15 | 
            +
            PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def extract_lora(
         | 
| 19 | 
            +
                model_tuned,
         | 
| 20 | 
            +
                model_org,
         | 
| 21 | 
            +
                save_to,
         | 
| 22 | 
            +
                save_precision,
         | 
| 23 | 
            +
                dim,
         | 
| 24 | 
            +
                v2,
         | 
| 25 | 
            +
                conv_dim,
         | 
| 26 | 
            +
                device,
         | 
| 27 | 
            +
            ):
         | 
| 28 | 
            +
                # Check that both model paths were provided
         | 
| 29 | 
            +
                if model_tuned == '':
         | 
| 30 | 
            +
                    msgbox('Invalid finetuned model file')
         | 
| 31 | 
            +
                    return
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                if model_org == '':
         | 
| 34 | 
            +
                    msgbox('Invalid base model file')
         | 
| 35 | 
            +
                    return
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                # Check that the provided model paths point to existing files
         | 
| 38 | 
            +
                if not os.path.isfile(model_tuned):
         | 
| 39 | 
            +
                    msgbox('The provided finetuned model is not a file')
         | 
| 40 | 
            +
                    return
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                if not os.path.isfile(model_org):
         | 
| 43 | 
            +
                    msgbox('The provided base model is not a file')
         | 
| 44 | 
            +
                    return
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                run_cmd = (
         | 
| 47 | 
            +
                    f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"'
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
                run_cmd += f' --save_precision {save_precision}'
         | 
| 50 | 
            +
                run_cmd += f' --save_to "{save_to}"'
         | 
| 51 | 
            +
                run_cmd += f' --model_org "{model_org}"'
         | 
| 52 | 
            +
                run_cmd += f' --model_tuned "{model_tuned}"'
         | 
| 53 | 
            +
                run_cmd += f' --dim {dim}'
         | 
| 54 | 
            +
                run_cmd += f' --device {device}'
         | 
| 55 | 
            +
                if conv_dim > 0:
         | 
| 56 | 
            +
                    run_cmd += f' --conv_dim {conv_dim}'
         | 
| 57 | 
            +
                if v2:
         | 
| 58 | 
            +
                    run_cmd += f' --v2'
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                print(run_cmd)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                # Run the command
         | 
| 63 | 
            +
                if os.name == 'posix':
         | 
| 64 | 
            +
                    os.system(run_cmd)
         | 
| 65 | 
            +
                else:
         | 
| 66 | 
            +
                    subprocess.run(run_cmd)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            ###
         | 
| 70 | 
            +
            # Gradio UI
         | 
| 71 | 
            +
            ###
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def gradio_extract_lora_tab():
         | 
| 75 | 
            +
                with gr.Tab('Extract LoRA'):
         | 
| 76 | 
            +
                    gr.Markdown(
         | 
| 77 | 
            +
                        'This utility can extract a LoRA network from a finetuned model.'
         | 
| 78 | 
            +
                    )
         | 
| 79 | 
            +
                    lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
         | 
| 80 | 
            +
                    lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
         | 
| 81 | 
            +
                    model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
         | 
| 82 | 
            +
                    model_ext_name = gr.Textbox(value='Model types', visible=False)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    with gr.Row():
         | 
| 85 | 
            +
                        model_tuned = gr.Textbox(
         | 
| 86 | 
            +
                            label='Finetuned model',
         | 
| 87 | 
            +
                            placeholder='Path to the finetuned model to extract',
         | 
| 88 | 
            +
                            interactive=True,
         | 
| 89 | 
            +
                        )
         | 
| 90 | 
            +
                        button_model_tuned_file = gr.Button(
         | 
| 91 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 92 | 
            +
                        )
         | 
| 93 | 
            +
                        button_model_tuned_file.click(
         | 
| 94 | 
            +
                            get_file_path,
         | 
| 95 | 
            +
                            inputs=[model_tuned, model_ext, model_ext_name],
         | 
| 96 | 
            +
                            outputs=model_tuned,
         | 
| 97 | 
            +
                            show_progress=False,
         | 
| 98 | 
            +
                        )
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                        model_org = gr.Textbox(
         | 
| 101 | 
            +
                            label='Stable Diffusion base model',
         | 
| 102 | 
            +
                            placeholder='Stable Diffusion original model: ckpt or safetensors file',
         | 
| 103 | 
            +
                            interactive=True,
         | 
| 104 | 
            +
                        )
         | 
| 105 | 
            +
                        button_model_org_file = gr.Button(
         | 
| 106 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 107 | 
            +
                        )
         | 
| 108 | 
            +
                        button_model_org_file.click(
         | 
| 109 | 
            +
                            get_file_path,
         | 
| 110 | 
            +
                            inputs=[model_org, model_ext, model_ext_name],
         | 
| 111 | 
            +
                            outputs=model_org,
         | 
| 112 | 
            +
                            show_progress=False,
         | 
| 113 | 
            +
                        )
         | 
| 114 | 
            +
                    with gr.Row():
         | 
| 115 | 
            +
                        save_to = gr.Textbox(
         | 
| 116 | 
            +
                            label='Save to',
         | 
| 117 | 
            +
                            placeholder='Path where the extracted LoRA model will be saved...',
         | 
| 118 | 
            +
                            interactive=True,
         | 
| 119 | 
            +
                        )
         | 
| 120 | 
            +
                        button_save_to = gr.Button(
         | 
| 121 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
                        button_save_to.click(
         | 
| 124 | 
            +
                            get_saveasfilename_path,
         | 
| 125 | 
            +
                            inputs=[save_to, lora_ext, lora_ext_name],
         | 
| 126 | 
            +
                            outputs=save_to,
         | 
| 127 | 
            +
                            show_progress=False,
         | 
| 128 | 
            +
                        )
         | 
| 129 | 
            +
                        save_precision = gr.Dropdown(
         | 
| 130 | 
            +
                            label='Save precision',
         | 
| 131 | 
            +
                            choices=['fp16', 'bf16', 'float'],
         | 
| 132 | 
            +
                            value='float',
         | 
| 133 | 
            +
                            interactive=True,
         | 
| 134 | 
            +
                        )
         | 
| 135 | 
            +
                    with gr.Row():
         | 
| 136 | 
            +
                        dim = gr.Slider(
         | 
| 137 | 
            +
                            minimum=4,
         | 
| 138 | 
            +
                            maximum=1024,
         | 
| 139 | 
            +
                            label='Network Dimension (Rank)',
         | 
| 140 | 
            +
                            value=128,
         | 
| 141 | 
            +
                            step=1,
         | 
| 142 | 
            +
                            interactive=True,
         | 
| 143 | 
            +
                        )
         | 
| 144 | 
            +
                        conv_dim = gr.Slider(
         | 
| 145 | 
            +
                            minimum=0,
         | 
| 146 | 
            +
                            maximum=1024,
         | 
| 147 | 
            +
                            label='Conv Dimension (Rank)',
         | 
| 148 | 
            +
                            value=128,
         | 
| 149 | 
            +
                            step=1,
         | 
| 150 | 
            +
                            interactive=True,
         | 
| 151 | 
            +
                        )
         | 
| 152 | 
            +
                        v2 = gr.Checkbox(label='v2', value=False, interactive=True)
         | 
| 153 | 
            +
                        device = gr.Dropdown(
         | 
| 154 | 
            +
                            label='Device',
         | 
| 155 | 
            +
                            choices=[
         | 
| 156 | 
            +
                                'cpu',
         | 
| 157 | 
            +
                                'cuda',
         | 
| 158 | 
            +
                            ],
         | 
| 159 | 
            +
                            value='cuda',
         | 
| 160 | 
            +
                            interactive=True,
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    extract_button = gr.Button('Extract LoRA model')
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    extract_button.click(
         | 
| 166 | 
            +
                        extract_lora,
         | 
| 167 | 
            +
                        inputs=[
         | 
| 168 | 
            +
                            model_tuned,
         | 
| 169 | 
            +
                            model_org,
         | 
| 170 | 
            +
                            save_to,
         | 
| 171 | 
            +
                            save_precision,
         | 
| 172 | 
            +
                            dim,
         | 
| 173 | 
            +
                            v2,
         | 
| 174 | 
            +
                            conv_dim,
         | 
| 175 | 
            +
                            device
         | 
| 176 | 
            +
                        ],
         | 
| 177 | 
            +
                        show_progress=False,
         | 
| 178 | 
            +
                    )
         | 
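For reference, this is roughly the command string that extract_lora() assembles on a POSIX system with the UI defaults shown above; the model and output paths are hypothetical placeholders, not values from the source.

import os

PYTHON = 'python3'  # the POSIX branch of the PYTHON constant defined above
run_cmd = f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"'
run_cmd += ' --save_precision float'
run_cmd += ' --save_to "/tmp/extracted_lora.safetensors"'    # hypothetical path
run_cmd += ' --model_org "/models/base.safetensors"'         # hypothetical path
run_cmd += ' --model_tuned "/models/finetuned.safetensors"'  # hypothetical path
run_cmd += ' --dim 128'
run_cmd += ' --device cuda'
run_cmd += ' --conv_dim 128'  # appended only because the default conv_dim (128) is > 0
print(run_cmd)
# ' --v2' would additionally be appended if the v2 checkbox were ticked.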
    	
        library/extract_lycoris_locon_gui.py
    ADDED
    
    | @@ -0,0 +1,309 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from .common_gui import (
         | 
| 6 | 
            +
                get_saveasfilename_path,
         | 
| 7 | 
            +
                get_any_file_path,
         | 
| 8 | 
            +
                get_file_path,
         | 
| 9 | 
            +
            )
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            folder_symbol = '\U0001f4c2'  # 📂
         | 
| 12 | 
            +
            refresh_symbol = '\U0001f504'  # 🔄
         | 
| 13 | 
            +
            save_style_symbol = '\U0001f4be'  # 💾
         | 
| 14 | 
            +
            document_symbol = '\U0001F4C4'   # 📄
         | 
| 15 | 
            +
            PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def extract_lycoris_locon(
         | 
| 19 | 
            +
                db_model,
         | 
| 20 | 
            +
                base_model,
         | 
| 21 | 
            +
                output_name,
         | 
| 22 | 
            +
                device,
         | 
| 23 | 
            +
                is_v2,
         | 
| 24 | 
            +
                mode,
         | 
| 25 | 
            +
                linear_dim,
         | 
| 26 | 
            +
                conv_dim,
         | 
| 27 | 
            +
                linear_threshold,
         | 
| 28 | 
            +
                conv_threshold,
         | 
| 29 | 
            +
                linear_ratio,
         | 
| 30 | 
            +
                conv_ratio,
         | 
| 31 | 
            +
                linear_quantile,
         | 
| 32 | 
            +
                conv_quantile,
         | 
| 33 | 
            +
                use_sparse_bias,
         | 
| 34 | 
            +
                sparsity,
         | 
| 35 | 
            +
                disable_cp,
         | 
| 36 | 
            +
            ):
         | 
| 37 | 
            +
                # Check that both model paths were provided
         | 
| 38 | 
            +
                if db_model == '':
         | 
| 39 | 
            +
                    msgbox('Invalid finetuned model file')
         | 
| 40 | 
            +
                    return
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                if base_model == '':
         | 
| 43 | 
            +
                    msgbox('Invalid base model file')
         | 
| 44 | 
            +
                    return
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                # Check that the provided model paths point to existing files
         | 
| 47 | 
            +
                if not os.path.isfile(db_model):
         | 
| 48 | 
            +
                    msgbox('The provided finetuned model is not a file')
         | 
| 49 | 
            +
                    return
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                if not os.path.isfile(base_model):
         | 
| 52 | 
            +
                    msgbox('The provided base model is not a file')
         | 
| 53 | 
            +
                    return
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
         | 
| 56 | 
            +
                if is_v2:
         | 
| 57 | 
            +
                    run_cmd += f' --is_v2'
         | 
| 58 | 
            +
                run_cmd += f' --device {device}'
         | 
| 59 | 
            +
                run_cmd += f' --mode {mode}'
         | 
| 60 | 
            +
                run_cmd += f' --safetensors'
         | 
| 61 | 
            +
                run_cmd += f' --linear_dim {linear_dim}'
         | 
| 62 | 
            +
                run_cmd += f' --conv_dim {conv_dim}'
         | 
| 63 | 
            +
                run_cmd += f' --linear_threshold {linear_threshold}'
         | 
| 64 | 
            +
                run_cmd += f' --conv_threshold {conv_threshold}'
         | 
| 65 | 
            +
                run_cmd += f' --linear_ratio {linear_ratio}'
         | 
| 66 | 
            +
                run_cmd += f' --conv_ratio {conv_ratio}'
         | 
| 67 | 
            +
                run_cmd += f' --linear_quantile {linear_quantile}'
         | 
| 68 | 
            +
                run_cmd += f' --conv_quantile {conv_quantile}'
         | 
| 69 | 
            +
                if use_sparse_bias:
         | 
| 70 | 
            +
                    run_cmd += f' --use_sparse_bias'
         | 
| 71 | 
            +
                run_cmd += f' --sparsity {sparsity}'
         | 
| 72 | 
            +
                if disable_cp:
         | 
| 73 | 
            +
                    run_cmd += f' --disable_cp'
         | 
| 74 | 
            +
                run_cmd += f' "{base_model}"'
         | 
| 75 | 
            +
                run_cmd += f' "{db_model}"'
         | 
| 76 | 
            +
                run_cmd += f' "{output_name}"'
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                print(run_cmd)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                # Run the command
         | 
| 81 | 
            +
                if os.name == 'posix':
         | 
| 82 | 
            +
                    os.system(run_cmd)
         | 
| 83 | 
            +
                else:
         | 
| 84 | 
            +
                    subprocess.run(run_cmd)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            ###
         | 
| 88 | 
            +
            # Gradio UI
         | 
| 89 | 
            +
            ###
         | 
| 90 | 
            +
            # def update_mode(mode):
         | 
| 91 | 
            +
            #     # 'fixed', 'threshold','ratio','quantile'
         | 
| 92 | 
            +
            #     if mode == 'fixed':
         | 
| 93 | 
            +
            #         return gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False)
         | 
| 94 | 
            +
            #     if mode == 'threshold':
         | 
| 95 | 
            +
            #         return gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False)
         | 
| 96 | 
            +
            #     if mode == 'ratio':
         | 
| 97 | 
            +
            #         return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False)
         | 
| 98 | 
            +
            #     if mode == 'threshold':
         | 
| 99 | 
            +
            #         return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def update_mode(mode):
         | 
| 103 | 
            +
                # Create a list of possible mode values
         | 
| 104 | 
            +
                modes = ['fixed', 'threshold', 'ratio', 'quantile']
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # Initialize an empty list to store visibility updates
         | 
| 107 | 
            +
                updates = []
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                # Iterate through the possible modes
         | 
| 110 | 
            +
                for m in modes:
         | 
| 111 | 
            +
                    # Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop
         | 
| 112 | 
            +
                    updates.append(gr.Row.update(visible=(mode == m)))
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # Return the visibility updates as a tuple
         | 
| 115 | 
            +
                return tuple(updates)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            def gradio_extract_lycoris_locon_tab():
         | 
| 119 | 
            +
                with gr.Tab('Extract LyCORIS LoCON'):
         | 
| 120 | 
            +
                    gr.Markdown(
         | 
| 121 | 
            +
                        'This utility can extract a LyCORIS LoCon network from a finetuned model.'
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
                    lora_ext = gr.Textbox(
         | 
| 124 | 
            +
                        value='*.safetensors', visible=False
         | 
| 125 | 
            +
                    )   # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
         | 
| 126 | 
            +
                    lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
         | 
| 127 | 
            +
                    model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
         | 
| 128 | 
            +
                    model_ext_name = gr.Textbox(value='Model types', visible=False)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    with gr.Row():
         | 
| 131 | 
            +
                        db_model = gr.Textbox(
         | 
| 132 | 
            +
                            label='Finetuned model',
         | 
| 133 | 
            +
                            placeholder='Path to the finetuned model to extract',
         | 
| 134 | 
            +
                            interactive=True,
         | 
| 135 | 
            +
                        )
         | 
| 136 | 
            +
                        button_db_model_file = gr.Button(
         | 
| 137 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 138 | 
            +
                        )
         | 
| 139 | 
            +
                        button_db_model_file.click(
         | 
| 140 | 
            +
                            get_file_path,
         | 
| 141 | 
            +
                            inputs=[db_model, model_ext, model_ext_name],
         | 
| 142 | 
            +
                            outputs=db_model,
         | 
| 143 | 
            +
                            show_progress=False,
         | 
| 144 | 
            +
                        )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                        base_model = gr.Textbox(
         | 
| 147 | 
            +
                            label='Stable Diffusion base model',
         | 
| 148 | 
            +
                            placeholder='Stable Diffusion original model: ckpt or safetensors file',
         | 
| 149 | 
            +
                            interactive=True,
         | 
| 150 | 
            +
                        )
         | 
| 151 | 
            +
                        button_base_model_file = gr.Button(
         | 
| 152 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 153 | 
            +
                        )
         | 
| 154 | 
            +
                        button_base_model_file.click(
         | 
| 155 | 
            +
                            get_file_path,
         | 
| 156 | 
            +
                            inputs=[base_model, model_ext, model_ext_name],
         | 
| 157 | 
            +
                            outputs=base_model,
         | 
| 158 | 
            +
                            show_progress=False,
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
                    with gr.Row():
         | 
| 161 | 
            +
                        output_name = gr.Textbox(
         | 
| 162 | 
            +
                            label='Save to',
         | 
| 163 | 
            +
                            placeholder='Path where the extracted LoRA model will be saved...',
         | 
| 164 | 
            +
                            interactive=True,
         | 
| 165 | 
            +
                        )
         | 
| 166 | 
            +
                        button_output_name = gr.Button(
         | 
| 167 | 
            +
                            folder_symbol, elem_id='open_folder_small'
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                        button_output_name.click(
         | 
| 170 | 
            +
                            get_saveasfilename_path,
         | 
| 171 | 
            +
                            inputs=[output_name, lora_ext, lora_ext_name],
         | 
| 172 | 
            +
                            outputs=output_name,
         | 
| 173 | 
            +
                            show_progress=False,
         | 
| 174 | 
            +
                        )
         | 
| 175 | 
            +
                        device = gr.Dropdown(
         | 
| 176 | 
            +
                            label='Device',
         | 
| 177 | 
            +
                            choices=[
         | 
| 178 | 
            +
                                'cpu',
         | 
| 179 | 
            +
                                'cuda',
         | 
| 180 | 
            +
                            ],
         | 
| 181 | 
            +
                            value='cuda',
         | 
| 182 | 
            +
                            interactive=True,
         | 
| 183 | 
            +
                        )
         | 
| 184 | 
            +
                        is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
         | 
| 185 | 
            +
                    mode = gr.Dropdown(
         | 
| 186 | 
            +
                        label='Mode',
         | 
| 187 | 
            +
                        choices=['fixed', 'threshold', 'ratio', 'quantile'],
         | 
| 188 | 
            +
                        value='fixed',
         | 
| 189 | 
            +
                        interactive=True,
         | 
| 190 | 
            +
                    )
         | 
| 191 | 
            +
                    with gr.Row(visible=True) as fixed:
         | 
| 192 | 
            +
                        linear_dim = gr.Slider(
         | 
| 193 | 
            +
                            minimum=1,
         | 
| 194 | 
            +
                            maximum=1024,
         | 
| 195 | 
            +
                            label='Network Dimension',
         | 
| 196 | 
            +
                            value=1,
         | 
| 197 | 
            +
                            step=1,
         | 
| 198 | 
            +
                            interactive=True,
         | 
| 199 | 
            +
                        )
         | 
| 200 | 
            +
                        conv_dim = gr.Slider(
         | 
| 201 | 
            +
                            minimum=1,
         | 
| 202 | 
            +
                            maximum=1024,
         | 
| 203 | 
            +
                            label='Conv Dimension',
         | 
| 204 | 
            +
                            value=1,
         | 
| 205 | 
            +
                            step=1,
         | 
| 206 | 
            +
                            interactive=True,
         | 
| 207 | 
            +
                        )
         | 
| 208 | 
            +
                    with gr.Row(visible=False) as threshold:
         | 
| 209 | 
            +
                        linear_threshold = gr.Slider(
         | 
| 210 | 
            +
                            minimum=0,
         | 
| 211 | 
            +
                            maximum=1,
         | 
| 212 | 
            +
                            label='Linear threshold',
         | 
| 213 | 
            +
                            value=0,
         | 
| 214 | 
            +
                            step=0.01,
         | 
| 215 | 
            +
                            interactive=True,
         | 
| 216 | 
            +
                        )
         | 
| 217 | 
            +
                        conv_threshold = gr.Slider(
         | 
| 218 | 
            +
                            minimum=0,
         | 
| 219 | 
            +
                            maximum=1,
         | 
| 220 | 
            +
                            label='Conv threshold',
         | 
| 221 | 
            +
                            value=0,
         | 
| 222 | 
            +
                            step=0.01,
         | 
| 223 | 
            +
                            interactive=True,
         | 
| 224 | 
            +
                        )
         | 
| 225 | 
            +
                    with gr.Row(visible=False) as ratio:
         | 
| 226 | 
            +
                        linear_ratio = gr.Slider(
         | 
| 227 | 
            +
                            minimum=0,
         | 
| 228 | 
            +
                            maximum=1,
         | 
| 229 | 
            +
                            label='Linear ratio',
         | 
| 230 | 
            +
                            value=0,
         | 
| 231 | 
            +
                            step=0.01,
         | 
| 232 | 
            +
                            interactive=True,
         | 
| 233 | 
            +
                        )
         | 
| 234 | 
            +
                        conv_ratio = gr.Slider(
         | 
| 235 | 
            +
                            minimum=0,
         | 
| 236 | 
            +
                            maximum=1,
         | 
| 237 | 
            +
                            label='Conv ratio',
         | 
| 238 | 
            +
                            value=0,
         | 
| 239 | 
            +
                            step=0.01,
         | 
| 240 | 
            +
                            interactive=True,
         | 
| 241 | 
            +
                        )
         | 
| 242 | 
            +
                    with gr.Row(visible=False) as quantile:
         | 
| 243 | 
            +
                        linear_quantile = gr.Slider(
         | 
| 244 | 
            +
                            minimum=0,
         | 
| 245 | 
            +
                            maximum=1,
         | 
| 246 | 
            +
                            label='Linear quantile',
         | 
| 247 | 
            +
                            value=0.75,
         | 
| 248 | 
            +
                            step=0.01,
         | 
| 249 | 
            +
                            interactive=True,
         | 
| 250 | 
            +
                        )
         | 
| 251 | 
            +
                        conv_quantile = gr.Slider(
         | 
| 252 | 
            +
                            minimum=0,
         | 
| 253 | 
            +
                            maximum=1,
         | 
| 254 | 
            +
                            label='Conv quantile',
         | 
| 255 | 
            +
                            value=0.75,
         | 
| 256 | 
            +
                            step=0.01,
         | 
| 257 | 
            +
                            interactive=True,
         | 
| 258 | 
            +
                        )
         | 
| 259 | 
            +
                    with gr.Row():
         | 
| 260 | 
            +
                        use_sparse_bias = gr.Checkbox(
         | 
| 261 | 
            +
                            label='Use sparse bias', value=False, interactive=True
         | 
| 262 | 
            +
                        )
         | 
| 263 | 
            +
                        sparsity = gr.Slider(
         | 
| 264 | 
            +
                            minimum=0,
         | 
| 265 | 
            +
                            maximum=1,
         | 
| 266 | 
            +
                            label='Sparsity',
         | 
| 267 | 
            +
                            value=0.98,
         | 
| 268 | 
            +
                            step=0.01,
         | 
| 269 | 
            +
                            interactive=True,
         | 
| 270 | 
            +
                        )
         | 
| 271 | 
            +
                        disable_cp = gr.Checkbox(
         | 
| 272 | 
            +
                            label='Disable CP decomposition', value=False, interactive=True
         | 
| 273 | 
            +
                        )
         | 
| 274 | 
            +
                    mode.change(
         | 
| 275 | 
            +
                        update_mode,
         | 
| 276 | 
            +
                        inputs=[mode],
         | 
| 277 | 
            +
                        outputs=[
         | 
| 278 | 
            +
                            fixed,
         | 
| 279 | 
            +
                            threshold,
         | 
| 280 | 
            +
                            ratio,
         | 
| 281 | 
            +
                            quantile,
         | 
| 282 | 
            +
                        ],
         | 
| 283 | 
            +
                    )
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    extract_button = gr.Button('Extract LyCORIS LoCon')
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    extract_button.click(
         | 
| 288 | 
            +
                        extract_lycoris_locon,
         | 
| 289 | 
            +
                        inputs=[
         | 
| 290 | 
            +
                            db_model,
         | 
| 291 | 
            +
                            base_model,
         | 
| 292 | 
            +
                            output_name,
         | 
| 293 | 
            +
                            device,
         | 
| 294 | 
            +
                            is_v2,
         | 
| 295 | 
            +
                            mode,
         | 
| 296 | 
            +
                            linear_dim,
         | 
| 297 | 
            +
                            conv_dim,
         | 
| 298 | 
            +
                            linear_threshold,
         | 
| 299 | 
            +
                            conv_threshold,
         | 
| 300 | 
            +
                            linear_ratio,
         | 
| 301 | 
            +
                            conv_ratio,
         | 
| 302 | 
            +
                            linear_quantile,
         | 
| 303 | 
            +
                            conv_quantile,
         | 
| 304 | 
            +
                            use_sparse_bias,
         | 
| 305 | 
            +
                            sparsity,
         | 
| 306 | 
            +
                            disable_cp,
         | 
| 307 | 
            +
                        ],
         | 
| 308 | 
            +
                        show_progress=False,
         | 
| 309 | 
            +
                    )
         | 
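To make the mode.change wiring above easier to follow: update_mode() returns one gr.Row.update per row, with visible=True only for the row matching the selected mode, so exactly one of the fixed/threshold/ratio/quantile rows is shown at a time. The standalone sketch below mirrors that logic with plain booleans so it can be checked without Gradio.

MODES = ['fixed', 'threshold', 'ratio', 'quantile']

def visibility_for(selected_mode):
    # True only for the row whose mode matches the dropdown selection.
    return tuple(selected_mode == m for m in MODES)

assert visibility_for('fixed') == (True, False, False, False)
assert visibility_for('quantile') == (False, False, False, True)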
    	
        library/git_caption_gui.py
    ADDED
    
    | @@ -0,0 +1,136 @@ | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from easygui import msgbox
         | 
| 3 | 
            +
            import subprocess
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from .common_gui import get_folder_path, add_pre_postfix
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def caption_images(
         | 
| 11 | 
            +
                train_data_dir,
         | 
| 12 | 
            +
                caption_ext,
         | 
| 13 | 
            +
                batch_size,
         | 
| 14 | 
            +
                max_data_loader_n_workers,
         | 
| 15 | 
            +
                max_length,
         | 
| 16 | 
            +
                model_id,
         | 
| 17 | 
            +
                prefix,
         | 
| 18 | 
            +
                postfix,
         | 
| 19 | 
            +
            ):
         | 
| 20 | 
            +
                # Check that an image folder was provided
         | 
| 21 | 
            +
                if train_data_dir == '':
         | 
| 22 | 
            +
                    msgbox('Image folder is missing...')
         | 
| 23 | 
            +
                    return
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                if caption_ext == '':
         | 
| 26 | 
            +
                    msgbox('Please provide an extension for the caption files.')
         | 
| 27 | 
            +
                    return
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                print(f'GIT captioning files in {train_data_dir}...')
         | 
| 30 | 
            +
                run_cmd = (
         | 
| 31 | 
            +
                f'{PYTHON} "finetune/make_captions_by_git.py"'
         | 
| 32 | 
            +
                )
         | 
| 33 | 
            +
                if model_id != '':
         | 
| 34 | 
            +
                    run_cmd += f' --model_id="{model_id}"'
         | 
| 35 | 
            +
                run_cmd += f' --batch_size="{int(batch_size)}"'
         | 
| 36 | 
            +
                run_cmd += (
         | 
| 37 | 
            +
                    f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
         | 
| 38 | 
            +
                )
         | 
| 39 | 
            +
                run_cmd += f' --max_length="{int(max_length)}"'
         | 
| 40 | 
            +
                if caption_ext != '':
         | 
| 41 | 
            +
                    run_cmd += f' --caption_extension="{caption_ext}"'
         | 
| 42 | 
            +
                run_cmd += f' "{train_data_dir}"'
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                print(run_cmd)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                # Run the command
         | 
| 47 | 
            +
                subprocess.run(run_cmd)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                # Add prefix and postfix
         | 
| 50 | 
            +
                add_pre_postfix(
         | 
| 51 | 
            +
                    folder=train_data_dir,
         | 
| 52 | 
            +
                    caption_file_ext=caption_ext,
         | 
| 53 | 
            +
                    prefix=prefix,
         | 
| 54 | 
            +
                    postfix=postfix,
         | 
| 55 | 
            +
                )
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                print('...captioning done')
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            ###
         | 
| 61 | 
            +
            # Gradio UI
         | 
| 62 | 
            +
            ###
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def gradio_git_caption_gui_tab():
         | 
| 66 | 
            +
                with gr.Tab('GIT Captioning'):
         | 
| 67 | 
            +
                    gr.Markdown(
         | 
| 68 | 
            +
                        'This utility will use GIT to caption each image in a folder.'
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
                    with gr.Row():
         | 
| 71 | 
            +
                        train_data_dir = gr.Textbox(
         | 
| 72 | 
            +
                            label='Image folder to caption',
         | 
| 73 | 
            +
                            placeholder='Directory containing the images to caption',
         | 
| 74 | 
            +
                            interactive=True,
         | 
| 75 | 
            +
                        )
         | 
| 76 | 
            +
                        button_train_data_dir_input = gr.Button(
         | 
| 77 | 
            +
                            '📂', elem_id='open_folder_small'
         | 
| 78 | 
            +
                        )
         | 
| 79 | 
            +
                        button_train_data_dir_input.click(
         | 
| 80 | 
            +
                            get_folder_path,
         | 
| 81 | 
            +
                            outputs=train_data_dir,
         | 
| 82 | 
            +
                            show_progress=False,
         | 
| 83 | 
            +
                        )
         | 
| 84 | 
            +
                    with gr.Row():
         | 
| 85 | 
            +
                        caption_ext = gr.Textbox(
         | 
| 86 | 
            +
                            label='Caption file extension',
         | 
| 87 | 
            +
                            placeholder='Extension for the caption files, e.g. .caption, .txt',
         | 
| 88 | 
            +
                            value='.txt',
         | 
| 89 | 
            +
                            interactive=True,
         | 
| 90 | 
            +
                        )
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                        prefix = gr.Textbox(
         | 
| 93 | 
            +
                            label='Prefix to add to GIT caption',
         | 
| 94 | 
            +
                            placeholder='(Optional)',
         | 
| 95 | 
            +
                            interactive=True,
         | 
| 96 | 
            +
                        )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                        postfix = gr.Textbox(
         | 
| 99 | 
            +
                            label='Postfix to add to GIT caption',
         | 
| 100 | 
            +
                            placeholder='(Optional)',
         | 
| 101 | 
            +
                            interactive=True,
         | 
| 102 | 
            +
                        )
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                        batch_size = gr.Number(
         | 
| 105 | 
            +
                            value=1, label='Batch size', interactive=True
         | 
| 106 | 
            +
                        )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    with gr.Row():
         | 
| 109 | 
            +
                        max_data_loader_n_workers = gr.Number(
         | 
| 110 | 
            +
                            value=2, label='Number of workers', interactive=True
         | 
| 111 | 
            +
                        )
         | 
| 112 | 
            +
                        max_length = gr.Number(
         | 
| 113 | 
            +
                            value=75, label='Max length', interactive=True
         | 
| 114 | 
            +
                        )
         | 
| 115 | 
            +
                        model_id = gr.Textbox(
         | 
| 116 | 
            +
                            label='Model',
         | 
| 117 | 
            +
                            placeholder='(Optional) model id for GIT in Hugging Face',
         | 
| 118 | 
            +
                            interactive=True,
         | 
| 119 | 
            +
                        )
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    caption_button = gr.Button('Caption images')
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    caption_button.click(
         | 
| 124 | 
            +
                        caption_images,
         | 
| 125 | 
            +
                        inputs=[
         | 
| 126 | 
            +
                            train_data_dir,
         | 
| 127 | 
            +
                            caption_ext,
         | 
| 128 | 
            +
                            batch_size,
         | 
| 129 | 
            +
                            max_data_loader_n_workers,
         | 
| 130 | 
            +
                            max_length,
         | 
| 131 | 
            +
                            model_id,
         | 
| 132 | 
            +
                            prefix,
         | 
| 133 | 
            +
                            postfix,
         | 
| 134 | 
            +
                        ],
         | 
| 135 | 
            +
                        show_progress=False,
         | 
| 136 | 
            +
                    )
         | 
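For orientation, the lines below reproduce roughly the run_cmd string that caption_images() builds on a POSIX system with the default UI values above (batch size 1, 2 workers, max length 75, '.txt' extension) and no custom model_id; the image folder path is a hypothetical placeholder.

train_data_dir = '/path/to/train_images'   # hypothetical folder
run_cmd = 'python3 "finetune/make_captions_by_git.py"'
run_cmd += ' --batch_size="1"'
run_cmd += ' --max_data_loader_n_workers="2"'
run_cmd += ' --max_length="75"'
run_cmd += ' --caption_extension=".txt"'
run_cmd += f' "{train_data_dir}"'
print(run_cmd)
# -> python3 "finetune/make_captions_by_git.py" --batch_size="1" --max_data_loader_n_workers="2" --max_length="75" --caption_extension=".txt" "/path/to/train_images"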
    	
        library/lpw_stable_diffusion.py
    ADDED
    
    | @@ -0,0 +1,1179 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x

import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging


try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


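# Weight compounding in a nutshell: nesting multiplies, so "((word))" weights "word" by
# 1.1 * 1.1 = 1.21, while "[(word:2.0)]" gives 2.0 / 1.1 ≈ 1.82, because an explicit ":n"
# replaces only the default multiplier of the bracket it closes. A small sketch:
#
#   >>> parse_prompt_attention("a (red) [car]")
#   [['a ', 1.0], ['red', 1.1], [' ', 1.0], ['car', 0.9090909090909091]]

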
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return their tokens with the weight of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


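# Size sketch: with CLIP's model_max_length of 77 and max_embeddings_multiples=3, the
# padded token sequence is (77 - 2) * 3 + 2 = 227 entries; with no_boseos_middle=False the
# weight list instead becomes 3 * 77 = 231 entries, because every chunk keeps a 1.0 weight
# for its own starting and ending token.

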
def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # the chunk is only BOS followed by PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = pipe.text_encoder(text_input_chunk)[0]
            else:
                enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(text_input)[0]
    return text_embeddings
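# Note on clip_skip: a value of 2 takes the penultimate hidden state
# (enc_out["hidden_states"][-2]) and re-applies the text model's final_layer_norm before
# use, mirroring the usual "CLIP skip" convention; clip_skip=None or 1 simply uses the
# encoder's default output.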


def get_weighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.

    Args:
        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
            starting and ending tokens of each chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    pad = pipe.tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            clip_skip,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
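# Usage sketch (illustrative only; `pipe` stands for any loaded pipeline exposing
# `tokenizer`, `text_encoder` and `device`, such as the class defined below):
#
#   cond, uncond = get_weighted_text_embeddings(
#       pipe,
#       prompt="a (very beautiful:1.3) landscape, ((masterpiece))",
#       uncond_prompt="lowres, (worst quality:1.4)",
#       max_embeddings_multiples=3,
#       clip_skip=2,
#   )
#
# Both tensors share the same sequence length, which grows in chunks of
# pipe.tokenizer.model_max_length for long prompts, and they are what _encode_prompt below
# concatenates for classifier-free guidance.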


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the transpose itself is an identity permutation
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask
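# Shape sketch for a 512x512 input: preprocess_image returns a float tensor of shape
# (1, 3, 512, 512) scaled to [-1, 1], while preprocess_mask returns (1, 4, 64, 64) in
# [0, 1] with white (repainted) pixels mapped to 0 and black (kept) pixels mapped to 1,
# matching the 8x spatial downscale and 4 channels of the VAE latents.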


class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
    parsing weights in the prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        clip_skip: int,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )
        self.clip_skip = clip_skip
        self.__init__additional__()

    # else:
    #     def __init__(
    #         self,
    #         vae: AutoencoderKL,
    #         text_encoder: CLIPTextModel,
    #         tokenizer: CLIPTokenizer,
    #         unet: UNet2DConditionModel,
    #         scheduler: SchedulerMixin,
    #         safety_checker: StableDiffusionSafetyChecker,
    #         feature_extractor: CLIPFeatureExtractor,
    #     ):
    #         super().__init__(
    #             vae=vae,
    #             text_encoder=text_encoder,
    #             tokenizer=tokenizer,
    #             unet=unet,
    #             scheduler=scheduler,
    #             safety_checker=safety_checker,
    #             feature_extractor=feature_extractor,
    #         )
    #         self.__init__additional__()

| 507 | 
            +
                def __init__additional__(self):
         | 
| 508 | 
            +
                    if not hasattr(self, "vae_scale_factor"):
         | 
| 509 | 
            +
                        setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                @property
         | 
| 512 | 
            +
                def _execution_device(self):
         | 
| 513 | 
            +
                    r"""
         | 
| 514 | 
            +
                    Returns the device on which the pipeline's models will be executed. After calling
         | 
| 515 | 
            +
                    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         | 
| 516 | 
            +
                    hooks.
         | 
| 517 | 
            +
                    """
         | 
| 518 | 
            +
                    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
         | 
| 519 | 
            +
                        return self.device
         | 
| 520 | 
            +
                    for module in self.unet.modules():
         | 
| 521 | 
            +
                        if (
         | 
| 522 | 
            +
                            hasattr(module, "_hf_hook")
         | 
| 523 | 
            +
                            and hasattr(module._hf_hook, "execution_device")
         | 
| 524 | 
            +
                            and module._hf_hook.execution_device is not None
         | 
| 525 | 
            +
                        ):
         | 
| 526 | 
            +
                            return torch.device(module._hf_hook.execution_device)
         | 
| 527 | 
            +
                    return self.device
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                def _encode_prompt(
         | 
| 530 | 
            +
                    self,
         | 
| 531 | 
            +
                    prompt,
         | 
| 532 | 
            +
                    device,
         | 
| 533 | 
            +
                    num_images_per_prompt,
         | 
| 534 | 
            +
                    do_classifier_free_guidance,
         | 
| 535 | 
            +
                    negative_prompt,
         | 
| 536 | 
            +
                    max_embeddings_multiples,
         | 
| 537 | 
            +
                ):
         | 
| 538 | 
            +
                    r"""
         | 
| 539 | 
            +
                    Encodes the prompt into text encoder hidden states.
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                    Args:
         | 
| 542 | 
            +
                        prompt (`str` or `list(int)`):
         | 
| 543 | 
            +
                            prompt to be encoded
         | 
| 544 | 
            +
                        device: (`torch.device`):
         | 
| 545 | 
            +
                            torch device
         | 
| 546 | 
            +
                        num_images_per_prompt (`int`):
         | 
| 547 | 
            +
                            number of images that should be generated per prompt
         | 
| 548 | 
            +
                        do_classifier_free_guidance (`bool`):
         | 
| 549 | 
            +
                            whether to use classifier free guidance or not
         | 
| 550 | 
            +
                        negative_prompt (`str` or `List[str]`):
         | 
| 551 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 552 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 553 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 554 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 555 | 
            +
                    """
         | 
| 556 | 
            +
                    batch_size = len(prompt) if isinstance(prompt, list) else 1
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                    if negative_prompt is None:
         | 
| 559 | 
            +
                        negative_prompt = [""] * batch_size
         | 
| 560 | 
            +
                    elif isinstance(negative_prompt, str):
         | 
| 561 | 
            +
                        negative_prompt = [negative_prompt] * batch_size
         | 
| 562 | 
            +
                    if batch_size != len(negative_prompt):
         | 
| 563 | 
            +
                        raise ValueError(
         | 
| 564 | 
            +
                            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
         | 
| 565 | 
            +
                            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
         | 
| 566 | 
            +
                            " the batch size of `prompt`."
         | 
| 567 | 
            +
                        )
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
         | 
| 570 | 
            +
                        pipe=self,
         | 
| 571 | 
            +
                        prompt=prompt,
         | 
| 572 | 
            +
                        uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
         | 
| 573 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 574 | 
            +
                        clip_skip=self.clip_skip,
         | 
| 575 | 
            +
                    )
         | 
| 576 | 
            +
                    bs_embed, seq_len, _ = text_embeddings.shape
         | 
| 577 | 
            +
                    text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
         | 
| 578 | 
            +
                    text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
         | 
| 579 | 
            +
             | 
| 580 | 
            +
                    if do_classifier_free_guidance:
         | 
| 581 | 
            +
                        bs_embed, seq_len, _ = uncond_embeddings.shape
         | 
| 582 | 
            +
                        uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
         | 
| 583 | 
            +
                        uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
         | 
| 584 | 
            +
                        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                    return text_embeddings
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                def check_inputs(self, prompt, height, width, strength, callback_steps):
         | 
| 589 | 
            +
                    if not isinstance(prompt, str) and not isinstance(prompt, list):
         | 
| 590 | 
            +
                        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                    if strength < 0 or strength > 1:
         | 
| 593 | 
            +
                        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    if height % 8 != 0 or width % 8 != 0:
         | 
| 596 | 
            +
                        print(height, width)
         | 
| 597 | 
            +
                        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                    if (callback_steps is None) or (
         | 
| 600 | 
            +
                        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         | 
| 601 | 
            +
                    ):
         | 
| 602 | 
            +
                        raise ValueError(
         | 
| 603 | 
            +
                            f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
         | 
| 604 | 
            +
                        )
         | 
| 605 | 
            +
             | 
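As a quick illustration of the divisibility check above (values are illustrative; the divide-by-8 requirement comes from the VAE's 8x spatial downsampling):

for h, w in [(512, 512), (500, 512)]:
    ok = (h % 8 == 0) and (w % 8 == 0)
    print(h, w, "ok" if ok else "would raise ValueError in check_inputs")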
| 606 | 
            +
                def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
         | 
| 607 | 
            +
                    if is_text2img:
         | 
| 608 | 
            +
                        return self.scheduler.timesteps.to(device), num_inference_steps
         | 
| 609 | 
            +
                    else:
         | 
| 610 | 
            +
                        # get the original timestep using init_timestep
         | 
| 611 | 
            +
                        offset = self.scheduler.config.get("steps_offset", 0)
         | 
| 612 | 
            +
                        init_timestep = int(num_inference_steps * strength) + offset
         | 
| 613 | 
            +
                        init_timestep = min(init_timestep, num_inference_steps)
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                        t_start = max(num_inference_steps - init_timestep + offset, 0)
         | 
| 616 | 
            +
                        timesteps = self.scheduler.timesteps[t_start:].to(device)
         | 
| 617 | 
            +
                        return timesteps, num_inference_steps - t_start
         | 
| 618 | 
            +
             | 
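As a concrete example of how `strength` trims the schedule in the img2img branch of `get_timesteps` (illustrative numbers, assuming `steps_offset` is 0 in the scheduler config):

num_inference_steps = 50
strength = 0.8
offset = 0
init_timestep = int(num_inference_steps * strength) + offset    # 40
init_timestep = min(init_timestep, num_inference_steps)         # 40
t_start = max(num_inference_steps - init_timestep + offset, 0)  # 10
# the denoising loop then runs over timesteps[10:], i.e. 40 of the 50 scheduled steps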
| 619 | 
            +
                def run_safety_checker(self, image, device, dtype):
         | 
| 620 | 
            +
                    if self.safety_checker is not None:
         | 
| 621 | 
            +
                        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
         | 
| 622 | 
            +
                        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
         | 
| 623 | 
            +
                    else:
         | 
| 624 | 
            +
                        has_nsfw_concept = None
         | 
| 625 | 
            +
                    return image, has_nsfw_concept
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                def decode_latents(self, latents):
         | 
| 628 | 
            +
                    latents = 1 / 0.18215 * latents
         | 
| 629 | 
            +
                    image = self.vae.decode(latents).sample
         | 
| 630 | 
            +
                    image = (image / 2 + 0.5).clamp(0, 1)
         | 
| 631 | 
            +
                    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         | 
| 632 | 
            +
                    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         | 
| 633 | 
            +
                    return image
         | 
| 634 | 
            +
             | 
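A quick way to see why the 1 / 0.18215 factor appears in `decode_latents` is the round trip below, which mirrors the 0.18215 scaling applied in `prepare_latents` further down; the latent tensor is a dummy, not a real VAE encoding.

import torch

scaling = 0.18215                      # SD v1 VAE latent scaling constant
latent = torch.randn(1, 4, 64, 64)     # dummy latent for illustration

stored = scaling * latent              # what prepare_latents keeps after vae.encode
recovered = (1 / scaling) * stored     # what decode_latents applies before vae.decode
assert torch.allclose(recovered, latent)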
| 635 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 636 | 
            +
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 637 | 
            +
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 638 | 
            +
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 639 | 
            +
                    # and should be between [0, 1]
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 642 | 
            +
                    extra_step_kwargs = {}
         | 
| 643 | 
            +
                    if accepts_eta:
         | 
| 644 | 
            +
                        extra_step_kwargs["eta"] = eta
         | 
| 645 | 
            +
             | 
| 646 | 
            +
                    # check if the scheduler accepts generator
         | 
| 647 | 
            +
                    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 648 | 
            +
                    if accepts_generator:
         | 
| 649 | 
            +
                        extra_step_kwargs["generator"] = generator
         | 
| 650 | 
            +
                    return extra_step_kwargs
         | 
| 651 | 
            +
             | 
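The signature-inspection trick used in `prepare_extra_step_kwargs` can be demonstrated in isolation; the dummy step function below is purely hypothetical and stands in for `scheduler.step`.

import inspect

def fake_step(model_output, timestep, sample, eta=0.0):  # hypothetical scheduler.step
    ...

accepts_eta = "eta" in set(inspect.signature(fake_step).parameters.keys())
accepts_generator = "generator" in set(inspect.signature(fake_step).parameters.keys())
print(accepts_eta, accepts_generator)  # True False -> only eta would be forwarded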
| 652 | 
            +
                def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
         | 
| 653 | 
            +
                    if image is None:
         | 
| 654 | 
            +
                        shape = (
         | 
| 655 | 
            +
                            batch_size,
         | 
| 656 | 
            +
                            self.unet.in_channels,
         | 
| 657 | 
            +
                            height // self.vae_scale_factor,
         | 
| 658 | 
            +
                            width // self.vae_scale_factor,
         | 
| 659 | 
            +
                        )
         | 
| 660 | 
            +
             | 
| 661 | 
            +
                        if latents is None:
         | 
| 662 | 
            +
                            if device.type == "mps":
         | 
| 663 | 
            +
                                # randn does not work reproducibly on mps
         | 
| 664 | 
            +
                                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
         | 
| 665 | 
            +
                            else:
         | 
| 666 | 
            +
                                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         | 
| 667 | 
            +
                        else:
         | 
| 668 | 
            +
                            if latents.shape != shape:
         | 
| 669 | 
            +
                                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
         | 
| 670 | 
            +
                            latents = latents.to(device)
         | 
| 671 | 
            +
             | 
| 672 | 
            +
                        # scale the initial noise by the standard deviation required by the scheduler
         | 
| 673 | 
            +
                        latents = latents * self.scheduler.init_noise_sigma
         | 
| 674 | 
            +
                        return latents, None, None
         | 
| 675 | 
            +
                    else:
         | 
| 676 | 
            +
                        init_latent_dist = self.vae.encode(image).latent_dist
         | 
| 677 | 
            +
                        init_latents = init_latent_dist.sample(generator=generator)
         | 
| 678 | 
            +
                        init_latents = 0.18215 * init_latents
         | 
| 679 | 
            +
                        init_latents = torch.cat([init_latents] * batch_size, dim=0)
         | 
| 680 | 
            +
                        init_latents_orig = init_latents
         | 
| 681 | 
            +
                        shape = init_latents.shape
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                        # add noise to latents using the timesteps
         | 
| 684 | 
            +
                        if device.type == "mps":
         | 
| 685 | 
            +
                            noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
         | 
| 686 | 
            +
                        else:
         | 
| 687 | 
            +
                            noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         | 
| 688 | 
            +
                        latents = self.scheduler.add_noise(init_latents, noise, timestep)
         | 
| 689 | 
            +
                        return latents, init_latents_orig, noise
         | 
| 690 | 
            +
             | 
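For the text-to-image branch of `prepare_latents`, the latent shape works out as below; the 4 UNet input channels and VAE scale factor of 8 are the usual SD v1 values, stated here as assumptions rather than read from this hunk.

batch_size = 1
in_channels = 4            # assumed UNet input channels (SD v1)
vae_scale_factor = 8       # assumed VAE downsampling factor
height = width = 512

shape = (batch_size, in_channels, height // vae_scale_factor, width // vae_scale_factor)
print(shape)               # (1, 4, 64, 64)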
| 691 | 
            +
                @torch.no_grad()
         | 
| 692 | 
            +
                def __call__(
         | 
| 693 | 
            +
                    self,
         | 
| 694 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 695 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 696 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         | 
| 697 | 
            +
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         | 
| 698 | 
            +
                    height: int = 512,
         | 
| 699 | 
            +
                    width: int = 512,
         | 
| 700 | 
            +
                    num_inference_steps: int = 50,
         | 
| 701 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 702 | 
            +
                    strength: float = 0.8,
         | 
| 703 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 704 | 
            +
                    eta: float = 0.0,
         | 
| 705 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 706 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 707 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 708 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 709 | 
            +
                    return_dict: bool = True,
         | 
| 710 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 711 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 712 | 
            +
                    callback_steps: int = 1,
         | 
| 713 | 
            +
                ):
         | 
| 714 | 
            +
                    r"""
         | 
| 715 | 
            +
                    Function invoked when calling the pipeline for generation.
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                    Args:
         | 
| 718 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 719 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 720 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 721 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 722 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 723 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 724 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 725 | 
            +
                            process.
         | 
| 726 | 
            +
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 727 | 
            +
                            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
         | 
| 728 | 
            +
                            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
         | 
| 729 | 
            +
                            PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
         | 
| 730 | 
            +
                            contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
         | 
| 731 | 
            +
                        height (`int`, *optional*, defaults to 512):
         | 
| 732 | 
            +
                            The height in pixels of the generated image.
         | 
| 733 | 
            +
                        width (`int`, *optional*, defaults to 512):
         | 
| 734 | 
            +
                            The width in pixels of the generated image.
         | 
| 735 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 736 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 737 | 
            +
                            expense of slower inference.
         | 
| 738 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 739 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 740 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 741 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 742 | 
            +
                            1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
         | 
| 743 | 
            +
                            usually at the expense of lower image quality.
         | 
| 744 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 745 | 
            +
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
         | 
| 746 | 
            +
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
         | 
| 747 | 
            +
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
         | 
| 748 | 
            +
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
         | 
| 749 | 
            +
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
         | 
| 750 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 751 | 
            +
                            The number of images to generate per prompt.
         | 
| 752 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 753 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 754 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 755 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 756 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 757 | 
            +
                            deterministic.
         | 
| 758 | 
            +
                        latents (`torch.FloatTensor`, *optional*):
         | 
| 759 | 
            +
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
         | 
| 760 | 
            +
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
         | 
| 761 | 
            +
                            tensor will be generated by sampling using the supplied random `generator`.
         | 
| 762 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 763 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 764 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 765 | 
            +
                            The output format of the generated image. Choose between
         | 
| 766 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 767 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 768 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 769 | 
            +
                            plain tuple.
         | 
| 770 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 771 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 772 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 773 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 774 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 775 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 776 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 777 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 778 | 
            +
                            called at every step.
         | 
| 779 | 
            +
             | 
| 780 | 
            +
                    Returns:
         | 
| 781 | 
            +
                        `None` if cancelled by `is_cancelled_callback`,
         | 
| 782 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 783 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 784 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 785 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 786 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 787 | 
            +
                    """
         | 
| 788 | 
            +
                    # 0. Default height and width to unet
         | 
| 789 | 
            +
                    height = height or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 790 | 
            +
                    width = width or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                    # 1. Check inputs. Raise error if not correct
         | 
| 793 | 
            +
                    self.check_inputs(prompt, height, width, strength, callback_steps)
         | 
| 794 | 
            +
             | 
| 795 | 
            +
                    # 2. Define call parameters
         | 
| 796 | 
            +
                    batch_size = 1 if isinstance(prompt, str) else len(prompt)
         | 
| 797 | 
            +
                    device = self._execution_device
         | 
| 798 | 
            +
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         | 
| 799 | 
            +
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         | 
| 800 | 
            +
                    # corresponds to doing no classifier free guidance.
         | 
| 801 | 
            +
                    do_classifier_free_guidance = guidance_scale > 1.0
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    # 3. Encode input prompt
         | 
| 804 | 
            +
                    text_embeddings = self._encode_prompt(
         | 
| 805 | 
            +
                        prompt,
         | 
| 806 | 
            +
                        device,
         | 
| 807 | 
            +
                        num_images_per_prompt,
         | 
| 808 | 
            +
                        do_classifier_free_guidance,
         | 
| 809 | 
            +
                        negative_prompt,
         | 
| 810 | 
            +
                        max_embeddings_multiples,
         | 
| 811 | 
            +
                    )
         | 
| 812 | 
            +
                    dtype = text_embeddings.dtype
         | 
| 813 | 
            +
             | 
| 814 | 
            +
                    # 4. Preprocess image and mask
         | 
| 815 | 
            +
                    if isinstance(image, PIL.Image.Image):
         | 
| 816 | 
            +
                        image = preprocess_image(image)
         | 
| 817 | 
            +
                    if image is not None:
         | 
| 818 | 
            +
                        image = image.to(device=self.device, dtype=dtype)
         | 
| 819 | 
            +
                    if isinstance(mask_image, PIL.Image.Image):
         | 
| 820 | 
            +
                        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
         | 
| 821 | 
            +
                    if mask_image is not None:
         | 
| 822 | 
            +
                        mask = mask_image.to(device=self.device, dtype=dtype)
         | 
| 823 | 
            +
                        mask = torch.cat([mask] * batch_size * num_images_per_prompt)
         | 
| 824 | 
            +
                    else:
         | 
| 825 | 
            +
                        mask = None
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                    # 5. set timesteps
         | 
| 828 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         | 
| 829 | 
            +
                    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
         | 
| 830 | 
            +
                    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         | 
| 831 | 
            +
             | 
| 832 | 
            +
                    # 6. Prepare latent variables
         | 
| 833 | 
            +
                    latents, init_latents_orig, noise = self.prepare_latents(
         | 
| 834 | 
            +
                        image,
         | 
| 835 | 
            +
                        latent_timestep,
         | 
| 836 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 837 | 
            +
                        height,
         | 
| 838 | 
            +
                        width,
         | 
| 839 | 
            +
                        dtype,
         | 
| 840 | 
            +
                        device,
         | 
| 841 | 
            +
                        generator,
         | 
| 842 | 
            +
                        latents,
         | 
| 843 | 
            +
                    )
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         | 
| 846 | 
            +
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         | 
| 847 | 
            +
             | 
| 848 | 
            +
                    # 8. Denoising loop
         | 
| 849 | 
            +
                    for i, t in enumerate(self.progress_bar(timesteps)):
         | 
| 850 | 
            +
                        # expand the latents if we are doing classifier free guidance
         | 
| 851 | 
            +
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         | 
| 852 | 
            +
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                        # predict the noise residual
         | 
| 855 | 
            +
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
         | 
| 856 | 
            +
             | 
| 857 | 
            +
                        # perform guidance
         | 
| 858 | 
            +
                        if do_classifier_free_guidance:
         | 
| 859 | 
            +
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 860 | 
            +
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
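                            # In equation form (a restatement of the guidance comment near step 2, not new behaviour):
                            #     noise_pred = eps_uncond + guidance_scale * (eps_text - eps_uncond)
                            # guidance_scale == 1 reduces to the plain text-conditioned prediction eps_text,
                            # while larger values push the sample further toward the prompt-conditioned direction.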
| 861 | 
            +
             | 
| 862 | 
            +
                        # compute the previous noisy sample x_t -> x_t-1
         | 
| 863 | 
            +
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
         | 
| 864 | 
            +
             | 
| 865 | 
            +
                        if mask is not None:
         | 
| 866 | 
            +
                            # masking
         | 
| 867 | 
            +
                            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
         | 
| 868 | 
            +
                            latents = (init_latents_proper * mask) + (latents * (1 - mask))
         | 
| 869 | 
            +
             | 
| 870 | 
            +
                        # call the callback, if provided
         | 
| 871 | 
            +
                        if i % callback_steps == 0:
         | 
| 872 | 
            +
                            if callback is not None:
         | 
| 873 | 
            +
                                callback(i, t, latents)
         | 
| 874 | 
            +
                            if is_cancelled_callback is not None and is_cancelled_callback():
         | 
| 875 | 
            +
                                return None
         | 
| 876 | 
            +
             | 
| 877 | 
            +
                    # 9. Post-processing
         | 
| 878 | 
            +
                    image = self.decode_latents(latents)
         | 
| 879 | 
            +
             | 
| 880 | 
            +
                    # 10. Run safety checker
         | 
| 881 | 
            +
                    image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
         | 
| 882 | 
            +
             | 
| 883 | 
            +
                    # 11. Convert to PIL
         | 
| 884 | 
            +
                    if output_type == "pil":
         | 
| 885 | 
            +
                        image = self.numpy_to_pil(image)
         | 
| 886 | 
            +
             | 
| 887 | 
            +
                    if not return_dict:
         | 
| 888 | 
            +
                        return image, has_nsfw_concept
         | 
| 889 | 
            +
             | 
| 890 | 
            +
                    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
         | 
| 891 | 
            +
             | 
| 892 | 
            +
                def text2img(
         | 
| 893 | 
            +
                    self,
         | 
| 894 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 895 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 896 | 
            +
                    height: int = 512,
         | 
| 897 | 
            +
                    width: int = 512,
         | 
| 898 | 
            +
                    num_inference_steps: int = 50,
         | 
| 899 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 900 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 901 | 
            +
                    eta: float = 0.0,
         | 
| 902 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 903 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 904 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 905 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 906 | 
            +
                    return_dict: bool = True,
         | 
| 907 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 908 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 909 | 
            +
                    callback_steps: int = 1,
         | 
| 910 | 
            +
                ):
         | 
| 911 | 
            +
                    r"""
         | 
| 912 | 
            +
                    Function for text-to-image generation.
         | 
| 913 | 
            +
                    Args:
         | 
| 914 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 915 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 916 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 917 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 918 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 919 | 
            +
                        height (`int`, *optional*, defaults to 512):
         | 
| 920 | 
            +
                            The height in pixels of the generated image.
         | 
| 921 | 
            +
                        width (`int`, *optional*, defaults to 512):
         | 
| 922 | 
            +
                            The width in pixels of the generated image.
         | 
| 923 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 924 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 925 | 
            +
                            expense of slower inference.
         | 
| 926 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 927 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 928 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 929 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 930 | 
            +
                            1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
         | 
| 931 | 
            +
                            usually at the expense of lower image quality.
         | 
| 932 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 933 | 
            +
                            The number of images to generate per prompt.
         | 
| 934 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 935 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 936 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 937 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 938 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 939 | 
            +
                            deterministic.
         | 
| 940 | 
            +
                        latents (`torch.FloatTensor`, *optional*):
         | 
| 941 | 
            +
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
         | 
| 942 | 
            +
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
         | 
| 943 | 
            +
                            tensor will be generated by sampling using the supplied random `generator`.
         | 
| 944 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 945 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 946 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 947 | 
            +
                            The output format of the generated image. Choose between
         | 
| 948 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 949 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 950 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 951 | 
            +
                            plain tuple.
         | 
| 952 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 953 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 954 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 955 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 956 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 957 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 958 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 959 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 960 | 
            +
                            called at every step.
         | 
| 961 | 
            +
                    Returns:
         | 
| 962 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 963 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 964 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 965 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 966 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 967 | 
            +
                    """
         | 
| 968 | 
            +
                    return self.__call__(
         | 
| 969 | 
            +
                        prompt=prompt,
         | 
| 970 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 971 | 
            +
                        height=height,
         | 
| 972 | 
            +
                        width=width,
         | 
| 973 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 974 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 975 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 976 | 
            +
                        eta=eta,
         | 
| 977 | 
            +
                        generator=generator,
         | 
| 978 | 
            +
                        latents=latents,
         | 
| 979 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 980 | 
            +
                        output_type=output_type,
         | 
| 981 | 
            +
                        return_dict=return_dict,
         | 
| 982 | 
            +
                        callback=callback,
         | 
| 983 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 984 | 
            +
                        callback_steps=callback_steps,
         | 
| 985 | 
            +
                    )
         | 
| 986 | 
            +
             | 
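As a usage sketch for the wrapper above (assuming `pipe` is an already-constructed instance of the enclosing pipeline class, which is defined outside this hunk; the `(word:1.2)` weighting syntax is the long-prompt-weighting feature handled by `get_weighted_text_embeddings` in `_encode_prompt`):

# `pipe` is assumed to be built elsewhere; values here are illustrative only
result = pipe.text2img(
    "a (highly detailed:1.2) watercolor painting of a fox",
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    guidance_scale=7.5,
    num_images_per_prompt=2,
)
result.images[0].save("fox.png")  # output_type defaults to "pil"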
| 987 | 
            +
                def img2img(
         | 
| 988 | 
            +
                    self,
         | 
| 989 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 990 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 991 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 992 | 
            +
                    strength: float = 0.8,
         | 
| 993 | 
            +
                    num_inference_steps: Optional[int] = 50,
         | 
| 994 | 
            +
                    guidance_scale: Optional[float] = 7.5,
         | 
| 995 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 996 | 
            +
                    eta: Optional[float] = 0.0,
         | 
| 997 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 998 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 999 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1000 | 
            +
                    return_dict: bool = True,
         | 
| 1001 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1002 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1003 | 
            +
                    callback_steps: int = 1,
         | 
| 1004 | 
            +
                ):
         | 
| 1005 | 
            +
                    r"""
         | 
| 1006 | 
            +
                    Function for image-to-image generation.
         | 
| 1007 | 
            +
                    Args:
         | 
| 1008 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1009 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 1010 | 
            +
                            process.
         | 
| 1011 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1012 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1013 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1014 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1015 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1016 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 1017 | 
            +
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
         | 
| 1018 | 
            +
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
         | 
| 1019 | 
            +
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
         | 
| 1020 | 
            +
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
         | 
| 1021 | 
            +
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
         | 
| 1022 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1023 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 1024 | 
            +
                            expense of slower inference. This parameter will be modulated by `strength`.
         | 
| 1025 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1026 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1027 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1028 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1029 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 1030 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1031 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1032 | 
            +
                            The number of images to generate per prompt.
         | 
| 1033 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1034 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1035 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1036 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1037 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1038 | 
            +
                            deterministic.
         | 
| 1039 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1040 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1041 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1042 | 
            +
                            The output format of the generate image. Choose between
         | 
| 1043 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1044 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1045 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1046 | 
            +
                            plain tuple.
         | 
| 1047 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1048 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1049 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1050 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1051 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1052 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1053 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1054 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1055 | 
            +
                            called at every step.
         | 
| 1056 | 
            +
                    Returns:
         | 
| 1057 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1058 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
         | 
| 1059 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1060 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1061 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1062 | 
            +
                    """
         | 
| 1063 | 
            +
                    return self.__call__(
         | 
| 1064 | 
            +
                        prompt=prompt,
         | 
| 1065 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1066 | 
            +
                        image=image,
         | 
| 1067 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1068 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1069 | 
            +
                        strength=strength,
         | 
| 1070 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1071 | 
            +
                        eta=eta,
         | 
| 1072 | 
            +
                        generator=generator,
         | 
| 1073 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1074 | 
            +
                        output_type=output_type,
         | 
| 1075 | 
            +
                        return_dict=return_dict,
         | 
| 1076 | 
            +
                        callback=callback,
         | 
| 1077 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1078 | 
            +
                        callback_steps=callback_steps,
         | 
| 1079 | 
            +
                    )
         | 
| 1080 | 
            +
             | 
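And a matching sketch for the img2img path (again assuming a pre-built `pipe` instance; the input path and parameter values are illustrative):

from PIL import Image

init_image = Image.open("input.png").convert("RGB").resize((512, 512))  # hypothetical file

result = pipe.img2img(
    image=init_image,
    prompt="the same scene at sunset, warm lighting",
    strength=0.6,            # keep more of the original than the 0.8 default
    num_inference_steps=50,  # about 30 steps actually run at strength 0.6
    guidance_scale=7.5,
)
result.images[0].save("sunset.png")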
| 1081 | 
            +
                def inpaint(
         | 
| 1082 | 
            +
                    self,
         | 
| 1083 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1084 | 
            +
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1085 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 1086 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 1087 | 
            +
                    strength: float = 0.8,
         | 
| 1088 | 
            +
                    num_inference_steps: Optional[int] = 50,
         | 
| 1089 | 
            +
                    guidance_scale: Optional[float] = 7.5,
         | 
| 1090 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 1091 | 
            +
                    eta: Optional[float] = 0.0,
         | 
| 1092 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 1093 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 1094 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1095 | 
            +
                    return_dict: bool = True,
         | 
| 1096 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1097 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1098 | 
            +
                    callback_steps: int = 1,
         | 
| 1099 | 
            +
                ):
         | 
| 1100 | 
            +
                    r"""
         | 
| 1101 | 
            +
                    Function for inpainting.
         | 
| 1102 | 
            +
                    Args:
         | 
| 1103 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1104 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 1105 | 
            +
                            process. This is the image whose masked region will be inpainted.
         | 
| 1106 | 
            +
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1107 | 
            +
                            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
         | 
| 1108 | 
            +
                            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
         | 
| 1109 | 
            +
                            PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
         | 
| 1110 | 
            +
                            contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
         | 
| 1111 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1112 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1113 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1114 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1115 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1116 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 1117 | 
            +
                            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
         | 
| 1118 | 
            +
                            is 1, the denoising process will be run on the masked area for the full number of iterations specified
         | 
| 1119 | 
            +
                            in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
         | 
| 1120 | 
            +
                            noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
         | 
| 1121 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1122 | 
            +
                            The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
         | 
| 1123 | 
            +
                            the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
         | 
| 1124 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1125 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1126 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1127 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1128 | 
            +
                            1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
         | 
| 1129 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1130 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1131 | 
            +
                            The number of images to generate per prompt.
         | 
| 1132 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1133 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1134 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1135 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1136 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1137 | 
            +
                            deterministic.
         | 
| 1138 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1139 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1140 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1141 | 
            +
                            The output format of the generated image. Choose between
         | 
| 1142 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1143 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1144 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1145 | 
            +
                            plain tuple.
         | 
| 1146 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1147 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1148 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1149 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1150 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1151 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1152 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1153 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1154 | 
            +
                            called at every step.
         | 
| 1155 | 
            +
                    Returns:
         | 
| 1156 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1157 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
         | 
| 1158 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1159 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1160 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1161 | 
            +
                    """
         | 
| 1162 | 
            +
                    return self.__call__(
         | 
| 1163 | 
            +
                        prompt=prompt,
         | 
| 1164 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1165 | 
            +
                        image=image,
         | 
| 1166 | 
            +
                        mask_image=mask_image,
         | 
| 1167 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1168 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1169 | 
            +
                        strength=strength,
         | 
| 1170 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1171 | 
            +
                        eta=eta,
         | 
| 1172 | 
            +
                        generator=generator,
         | 
| 1173 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1174 | 
            +
                        output_type=output_type,
         | 
| 1175 | 
            +
                        return_dict=return_dict,
         | 
| 1176 | 
            +
                        callback=callback,
         | 
| 1177 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1178 | 
            +
                        callback_steps=callback_steps,
         | 
| 1179 | 
            +
                    )
         | 
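A minimal usage sketch for the wrapper above, not part of the diff: it assumes the long-prompt-weighting pipeline defined in this file has been instantiated as `pipe` and that the method shown (which forwards `image`, `mask_image`, and `strength`) is exposed as `pipe.inpaint`; the file paths are placeholders. As described in the docstring, `strength` scales the effective number of denoising steps, so roughly `int(num_inference_steps * strength)` steps actually run.

    import torch
    from PIL import Image

    # Placeholder inputs; the mask marks the region to repaint.
    init_image = Image.open("input.png").convert("RGB").resize((512, 512))
    mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))

    # `pipe` is assumed to be an instance of the pipeline defined in this file.
    result = pipe.inpaint(
        prompt="a wooden bench in a park",
        image=init_image,
        mask_image=mask_image,
        strength=0.75,                 # ~75% of the 50 steps below are actually run
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=torch.Generator("cuda").manual_seed(0),
    )
    result.images[0].save("inpainted.png")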
    	
        library/merge_lora_gui.py
    ADDED
    
    | @@ -0,0 +1,156 @@ | |
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import (
    get_saveasfilename_path,
    get_any_file_path,
    get_file_path,
)

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'   # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'


def merge_lora(
    lora_a_model,
    lora_b_model,
    ratio,
    save_to,
    precision,
    save_precision,
):
    # Check that both LoRA model paths were provided
    if lora_a_model == '':
        msgbox('Invalid model A file')
        return

    if lora_b_model == '':
        msgbox('Invalid model B file')
        return

    # Check if the source models exist
    if not os.path.isfile(lora_a_model):
        msgbox('The provided model A is not a file')
        return

    if not os.path.isfile(lora_b_model):
        msgbox('The provided model B is not a file')
        return

    ratio_a = ratio
    ratio_b = 1 - ratio

    run_cmd = f'{PYTHON} "{os.path.join("networks","merge_lora.py")}"'
    run_cmd += f' --save_precision {save_precision}'
    run_cmd += f' --precision {precision}'
    run_cmd += f' --save_to "{save_to}"'
    run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
    run_cmd += f' --ratios {ratio_a} {ratio_b}'

    print(run_cmd)

    # Run the command
    if os.name == 'posix':
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)


###
# Gradio UI
###


def gradio_merge_lora_tab():
    with gr.Tab('Merge LoRA'):
        gr.Markdown('This utility can merge two LoRA networks together.')

        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

        with gr.Row():
            lora_a_model = gr.Textbox(
                label='LoRA model "A"',
                placeholder='Path to the LoRA A model',
                interactive=True,
            )
            button_lora_a_model_file = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_lora_a_model_file.click(
                get_file_path,
                inputs=[lora_a_model, lora_ext, lora_ext_name],
                outputs=lora_a_model,
                show_progress=False,
            )

            lora_b_model = gr.Textbox(
                label='LoRA model "B"',
                placeholder='Path to the LoRA B model',
                interactive=True,
            )
            button_lora_b_model_file = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_lora_b_model_file.click(
                get_file_path,
                inputs=[lora_b_model, lora_ext, lora_ext_name],
                outputs=lora_b_model,
                show_progress=False,
            )
        with gr.Row():
            ratio = gr.Slider(
                label='Merge ratio (e.g. 0.7 means 70% of model A and 30% of model B)',
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.5,
                interactive=True,
            )

        with gr.Row():
            save_to = gr.Textbox(
                label='Save to',
                placeholder='path for the file to save...',
                interactive=True,
            )
            button_save_to = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_save_to.click(
                get_saveasfilename_path,
                inputs=[save_to, lora_ext, lora_ext_name],
                outputs=save_to,
                show_progress=False,
            )
            precision = gr.Dropdown(
                label='Merge precision',
                choices=['fp16', 'bf16', 'float'],
                value='float',
                interactive=True,
            )
            save_precision = gr.Dropdown(
                label='Save precision',
                choices=['fp16', 'bf16', 'float'],
                value='float',
                interactive=True,
            )

        convert_button = gr.Button('Merge model')

        convert_button.click(
            merge_lora,
            inputs=[
                lora_a_model,
                lora_b_model,
                ratio,
                save_to,
                precision,
                save_precision,
            ],
            show_progress=False,
        )
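For reference, a hedged sketch of what the `merge_lora` handler above does (not part of the diff, and the file names are placeholders): with `ratio=0.7` it weights model A at 0.7 and model B at `1 - 0.7 = 0.3`, then shells out to networks/merge_lora.py with those ratios.

    # Hypothetical direct call to the handler defined above:
    merge_lora(
        lora_a_model='loras/styleA.safetensors',
        lora_b_model='loras/styleB.safetensors',
        ratio=0.7,                      # ratio_a = 0.7, ratio_b = 1 - 0.7 = 0.3
        save_to='loras/merged.safetensors',
        precision='float',
        save_precision='fp16',
    )
    # On Linux this prints and runs approximately:
    # python3 "networks/merge_lora.py" --save_precision fp16 --precision float \
    #   --save_to "loras/merged.safetensors" \
    #   --models "loras/styleA.safetensors" "loras/styleB.safetensors" \
    #   --ratios 0.7 0.30000000000000004   (the second ratio shows the float artifact of 1 - 0.7)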
    	
        library/model_util.py
    ADDED
    
    | @@ -0,0 +1,1165 @@ | |
# v1: split from train_db_fixed.py.
# v2: support safetensors

import math
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file

# Model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024

# Reference models used to load the Diffusers configurations
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"


# region StableDiffusion -> Diffusers conversion code
# copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')
        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')

        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
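A short illustration (not part of the diff; the key below is a made-up LDM UNet key) of what the renaming helpers above produce:

    # shave_segments drops dot-separated segments from the front (positive n) or back (negative n):
    shave_segments("model.diffusion_model.input_blocks.0.0.weight", 2)
    # -> "input_blocks.0.0.weight"

    # renew_resnet_paths maps LDM resnet key fragments to Diffusers names:
    renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
    # -> [{"old": "input_blocks.1.0.in_layers.0.weight",
    #      "new": "input_blocks.1.0.norm1.weight"}]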
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
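A small sketch (not from the diff; tensor sizes are arbitrary) of the shape change `linear_transformer_to_conv` performs: SD v2 checkpoints store the transformer `proj_in`/`proj_out` weights as 2-D linear matrices, while the target UNet layout here expects 4-D 1x1 conv kernels, so the helper appends two singleton spatial dimensions.

    import torch

    w_linear = torch.randn(1280, 1280)           # Linear weight as stored in an SD v2 checkpoint: [out, in]
    w_conv = w_linear.unsqueeze(2).unsqueeze(2)  # what linear_transformer_to_conv writes back: [out, in, 1, 1]
    assert w_conv.shape == (1280, 1280, 1, 1)    # i.e. an equivalent 1x1 Conv2d kernel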
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

            # Original:
            # if ["conv.weight", "conv.bias"] in output_block_list.values():
            #   index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])

            # Avoid depending on the order of bias and weight; there is probably a better way to do this
            for l in output_block_list.values():
                l.sort()

            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # In SD v2, the 1x1 conv2d layers were changed to linear layers, so convert linear -> conv
    if v2:
        linear_transformer_to_conv(new_checkpoint)

    return new_checkpoint
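A hedged usage sketch for the UNet converter above (not part of the diff): it assumes a v1 `.ckpt` whose UNet weights live under the `model.diffusion_model.` prefix; the checkpoint path is hypothetical, and only the `layers_per_block` key of `config` is read by the code shown here (the rest of this module presumably builds the full Diffusers configuration).

    import torch
    from library import model_util  # assuming this file is importable as library.model_util

    ckpt = torch.load("models/v1-5-pruned.ckpt", map_location="cpu")  # hypothetical checkpoint path
    state_dict = ckpt.get("state_dict", ckpt)
    unet_config = {"layers_per_block": model_util.UNET_PARAMS_NUM_RES_BLOCKS}  # only key read by the code shown above
    unet_sd = model_util.convert_ldm_unet_checkpoint(v2=False, checkpoint=state_dict, config=unet_config)
    print(f"converted {len(unet_sd)} UNet tensors to Diffusers naming")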
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #   # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt file
    #   vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
         | 
| 412 | 
            +
                        new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
         | 
| 413 | 
            +
                            f"encoder.down.{i}.downsample.conv.weight"
         | 
| 414 | 
            +
                        )
         | 
| 415 | 
            +
                        new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
         | 
| 416 | 
            +
                            f"encoder.down.{i}.downsample.conv.bias"
         | 
| 417 | 
            +
                        )
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 420 | 
            +
                    meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
         | 
| 421 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
         | 
| 424 | 
            +
                num_mid_res_blocks = 2
         | 
| 425 | 
            +
                for i in range(1, num_mid_res_blocks + 1):
         | 
| 426 | 
            +
                    resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 429 | 
            +
                    meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
         | 
| 430 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
         | 
| 433 | 
            +
                paths = renew_vae_attention_paths(mid_attentions)
         | 
| 434 | 
            +
                meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
         | 
| 435 | 
            +
                assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 436 | 
            +
                conv_attn_to_linear(new_checkpoint)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                for i in range(num_up_blocks):
         | 
| 439 | 
            +
                    block_id = num_up_blocks - 1 - i
         | 
| 440 | 
            +
                    resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
         | 
| 443 | 
            +
                        new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
         | 
| 444 | 
            +
                            f"decoder.up.{block_id}.upsample.conv.weight"
         | 
| 445 | 
            +
                        ]
         | 
| 446 | 
            +
                        new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
         | 
| 447 | 
            +
                            f"decoder.up.{block_id}.upsample.conv.bias"
         | 
| 448 | 
            +
                        ]
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 451 | 
            +
                    meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
         | 
| 452 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
         | 
| 455 | 
            +
                num_mid_res_blocks = 2
         | 
| 456 | 
            +
                for i in range(1, num_mid_res_blocks + 1):
         | 
| 457 | 
            +
                    resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 460 | 
            +
                    meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
         | 
| 461 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
         | 
| 464 | 
            +
                paths = renew_vae_attention_paths(mid_attentions)
         | 
| 465 | 
            +
                meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
         | 
| 466 | 
            +
                assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 467 | 
            +
                conv_attn_to_linear(new_checkpoint)
         | 
| 468 | 
            +
                return new_checkpoint
         | 
| 469 | 
            +
             | 
| 470 | 
            +
             | 
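
            # Usage sketch added for illustration (not part of the original module):
            # extracting just the VAE from a Stable Diffusion checkpoint.  The path is
            # hypothetical, and the sketch assumes the imports at the top of this file
            # (torch, AutoencoderKL, ...) plus create_vae_diffusers_config() and
            # load_checkpoint_with_text_encoder_conversion() defined further below.
            def _example_extract_vae(ckpt_path="model.ckpt", device="cpu"):
                _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
                vae_config = create_vae_diffusers_config()
                converted = convert_ldm_vae_checkpoint(state_dict, vae_config)
                vae = AutoencoderKL(**vae_config)
                info = vae.load_state_dict(converted)  # all keys should match
                print("example VAE load:", info)
                return vae
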
| 471 | 
            +
            def create_unet_diffusers_config(v2):
         | 
| 472 | 
            +
                """
         | 
| 473 | 
            +
                Creates a config for the diffusers based on the config of the LDM model.
         | 
| 474 | 
            +
                """
         | 
| 475 | 
            +
                # unet_params = original_config.model.params.unet_config.params
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                down_block_types = []
         | 
| 480 | 
            +
                resolution = 1
         | 
| 481 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 482 | 
            +
                    block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
         | 
| 483 | 
            +
                    down_block_types.append(block_type)
         | 
| 484 | 
            +
                    if i != len(block_out_channels) - 1:
         | 
| 485 | 
            +
                        resolution *= 2
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                up_block_types = []
         | 
| 488 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 489 | 
            +
                    block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
         | 
| 490 | 
            +
                    up_block_types.append(block_type)
         | 
| 491 | 
            +
                    resolution //= 2
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                config = dict(
         | 
| 494 | 
            +
                    sample_size=UNET_PARAMS_IMAGE_SIZE,
         | 
| 495 | 
            +
                    in_channels=UNET_PARAMS_IN_CHANNELS,
         | 
| 496 | 
            +
                    out_channels=UNET_PARAMS_OUT_CHANNELS,
         | 
| 497 | 
            +
                    down_block_types=tuple(down_block_types),
         | 
| 498 | 
            +
                    up_block_types=tuple(up_block_types),
         | 
| 499 | 
            +
                    block_out_channels=tuple(block_out_channels),
         | 
| 500 | 
            +
                    layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
         | 
| 501 | 
            +
                    cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
         | 
| 502 | 
            +
                    attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
         | 
| 503 | 
            +
                )
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                return config
         | 
| 506 | 
            +
             | 
| 507 | 
            +
             | 
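
            # Illustration (not in the original module): assuming the usual SD 1.x
            # constants defined near the top of this file (model channels 320,
            # channel mult [1, 2, 4, 4], attention at resolutions [4, 2, 1],
            # 2 res blocks, context dim 768, 8 heads), the config built above is
            # expected to look roughly like the comments below.
            def _example_show_v1_unet_config():
                cfg = create_unet_diffusers_config(v2=False)
                # expected under the assumptions above:
                #   block_out_channels == (320, 640, 1280, 1280)
                #   down_block_types   == ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
                #   up_block_types     == ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3
                #   layers_per_block   == 2, cross_attention_dim == 768
                print(cfg)
                return cfg
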
| 508 | 
            +
            def create_vae_diffusers_config():
         | 
| 509 | 
            +
                """
         | 
| 510 | 
            +
                Creates a config for the diffusers based on the config of the LDM model.
         | 
| 511 | 
            +
                """
         | 
| 512 | 
            +
                # vae_params = original_config.model.params.first_stage_config.params.ddconfig
         | 
| 513 | 
            +
                # _ = original_config.model.params.first_stage_config.params.embed_dim
         | 
| 514 | 
            +
                block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
         | 
| 515 | 
            +
                down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
         | 
| 516 | 
            +
                up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                config = dict(
         | 
| 519 | 
            +
                    sample_size=VAE_PARAMS_RESOLUTION,
         | 
| 520 | 
            +
                    in_channels=VAE_PARAMS_IN_CHANNELS,
         | 
| 521 | 
            +
                    out_channels=VAE_PARAMS_OUT_CH,
         | 
| 522 | 
            +
                    down_block_types=tuple(down_block_types),
         | 
| 523 | 
            +
                    up_block_types=tuple(up_block_types),
         | 
| 524 | 
            +
                    block_out_channels=tuple(block_out_channels),
         | 
| 525 | 
            +
                    latent_channels=VAE_PARAMS_Z_CHANNELS,
         | 
| 526 | 
            +
                    layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
         | 
| 527 | 
            +
                )
         | 
| 528 | 
            +
                return config
         | 
| 529 | 
            +
             | 
| 530 | 
            +
             | 
| 531 | 
            +
            def convert_ldm_clip_checkpoint_v1(checkpoint):
         | 
| 532 | 
            +
                keys = list(checkpoint.keys())
         | 
| 533 | 
            +
                text_model_dict = {}
         | 
| 534 | 
            +
                for key in keys:
         | 
| 535 | 
            +
                    if key.startswith("cond_stage_model.transformer"):
         | 
| 536 | 
            +
                        text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
         | 
| 537 | 
            +
                return text_model_dict
         | 
| 538 | 
            +
             | 
| 539 | 
            +
             | 
| 540 | 
            +
            def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
         | 
| 541 | 
            +
                # the two layouts differ annoyingly much!
         | 
| 542 | 
            +
                def convert_key(key):
         | 
| 543 | 
            +
                    if not key.startswith("cond_stage_model"):
         | 
| 544 | 
            +
                        return None
         | 
| 545 | 
            +
             | 
| 546 | 
            +
                    # common conversion
         | 
| 547 | 
            +
                    key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
         | 
| 548 | 
            +
                    key = key.replace("cond_stage_model.model.", "text_model.")
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    if "resblocks" in key:
         | 
| 551 | 
            +
                        # resblocks conversion
         | 
| 552 | 
            +
                        key = key.replace(".resblocks.", ".layers.")
         | 
| 553 | 
            +
                        if ".ln_" in key:
         | 
| 554 | 
            +
                            key = key.replace(".ln_", ".layer_norm")
         | 
| 555 | 
            +
                        elif ".mlp." in key:
         | 
| 556 | 
            +
                            key = key.replace(".c_fc.", ".fc1.")
         | 
| 557 | 
            +
                            key = key.replace(".c_proj.", ".fc2.")
         | 
| 558 | 
            +
                        elif ".attn.out_proj" in key:
         | 
| 559 | 
            +
                            key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
         | 
| 560 | 
            +
                        elif ".attn.in_proj" in key:
         | 
| 561 | 
            +
                            key = None  # special case, handled separately below
         | 
| 562 | 
            +
                        else:
         | 
| 563 | 
            +
                            raise ValueError(f"unexpected key in SD: {key}")
         | 
| 564 | 
            +
                    elif ".positional_embedding" in key:
         | 
| 565 | 
            +
                        key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
         | 
| 566 | 
            +
                    elif ".text_projection" in key:
         | 
| 567 | 
            +
                        key = None  # apparently unused???
         | 
| 568 | 
            +
                    elif ".logit_scale" in key:
         | 
| 569 | 
            +
                        key = None  # apparently unused???
         | 
| 570 | 
            +
                    elif ".token_embedding" in key:
         | 
| 571 | 
            +
                        key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
         | 
| 572 | 
            +
                    elif ".ln_final" in key:
         | 
| 573 | 
            +
                        key = key.replace(".ln_final", ".final_layer_norm")
         | 
| 574 | 
            +
                    return key
         | 
| 575 | 
            +
             | 
| 576 | 
            +
                keys = list(checkpoint.keys())
         | 
| 577 | 
            +
                new_sd = {}
         | 
| 578 | 
            +
                for key in keys:
         | 
| 579 | 
            +
                    # skip resblock 23 (not used by the Diffusers text encoder)
         | 
| 580 | 
            +
                    if ".resblocks.23." in key:
         | 
| 581 | 
            +
                        continue
         | 
| 582 | 
            +
                    new_key = convert_key(key)
         | 
| 583 | 
            +
                    if new_key is None:
         | 
| 584 | 
            +
                        continue
         | 
| 585 | 
            +
                    new_sd[new_key] = checkpoint[key]
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                # convert the attention projection weights
         | 
| 588 | 
            +
                for key in keys:
         | 
| 589 | 
            +
                    if ".resblocks.23." in key:
         | 
| 590 | 
            +
                        continue
         | 
| 591 | 
            +
                    if ".resblocks" in key and ".attn.in_proj_" in key:
         | 
| 592 | 
            +
                        # split into three chunks (q, k, v)
         | 
| 593 | 
            +
                        values = torch.chunk(checkpoint[key], 3)
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                        key_suffix = ".weight" if "weight" in key else ".bias"
         | 
| 596 | 
            +
                        key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
         | 
| 597 | 
            +
                        key_pfx = key_pfx.replace("_weight", "")
         | 
| 598 | 
            +
                        key_pfx = key_pfx.replace("_bias", "")
         | 
| 599 | 
            +
                        key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
         | 
| 600 | 
            +
                        new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
         | 
| 601 | 
            +
                        new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
         | 
| 602 | 
            +
                        new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                # rename or add position_ids
         | 
| 605 | 
            +
                ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
         | 
| 606 | 
            +
                if ANOTHER_POSITION_IDS_KEY in new_sd:
         | 
| 607 | 
            +
                    # waifu diffusion v1.4
         | 
| 608 | 
            +
                    position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
         | 
| 609 | 
            +
                    del new_sd[ANOTHER_POSITION_IDS_KEY]
         | 
| 610 | 
            +
                else:
         | 
| 611 | 
            +
                    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                new_sd["text_model.embeddings.position_ids"] = position_ids
         | 
| 614 | 
            +
                return new_sd
         | 
| 615 | 
            +
             | 
| 616 | 
            +
             | 
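
            # Tiny demonstration added for illustration (not in the original module):
            # an SD v2 OpenCLIP key such as
            #   cond_stage_model.model.transformer.resblocks.0.ln_1.weight
            # is renamed by the function above to
            #   text_model.encoder.layers.0.layer_norm1.weight
            # while fused attn.in_proj_* weights are chunked into q/k/v and resblock 23
            # is dropped.  The single-key state dict below is fabricated.
            def _example_v2_text_key_mapping():
                dummy = {"cond_stage_model.model.transformer.resblocks.0.ln_1.weight": torch.zeros(1024)}
                converted = convert_ldm_clip_checkpoint_v2(dummy, max_length=77)
                # -> contains "text_model.encoder.layers.0.layer_norm1.weight"
                #    plus a synthesized "text_model.embeddings.position_ids"
                return converted
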
| 617 | 
            +
            # endregion
         | 
| 618 | 
            +
             | 
| 619 | 
            +
             | 
| 620 | 
            +
            # region conversion code: Diffusers -> StableDiffusion
         | 
| 621 | 
            +
            # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
         | 
| 622 | 
            +
             | 
| 623 | 
            +
             | 
| 624 | 
            +
            def conv_transformer_to_linear(checkpoint):
         | 
| 625 | 
            +
                keys = list(checkpoint.keys())
         | 
| 626 | 
            +
                tf_keys = ["proj_in.weight", "proj_out.weight"]
         | 
| 627 | 
            +
                for key in keys:
         | 
| 628 | 
            +
                    if ".".join(key.split(".")[-2:]) in tf_keys:
         | 
| 629 | 
            +
                        if checkpoint[key].ndim > 2:
         | 
| 630 | 
            +
                            checkpoint[key] = checkpoint[key][:, :, 0, 0]
         | 
| 631 | 
            +
             | 
| 632 | 
            +
             | 
| 633 | 
            +
            def convert_unet_state_dict_to_sd(v2, unet_state_dict):
         | 
| 634 | 
            +
                unet_conversion_map = [
         | 
| 635 | 
            +
                    # (stable-diffusion, HF Diffusers)
         | 
| 636 | 
            +
                    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
         | 
| 637 | 
            +
                    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
         | 
| 638 | 
            +
                    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
         | 
| 639 | 
            +
                    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
         | 
| 640 | 
            +
                    ("input_blocks.0.0.weight", "conv_in.weight"),
         | 
| 641 | 
            +
                    ("input_blocks.0.0.bias", "conv_in.bias"),
         | 
| 642 | 
            +
                    ("out.0.weight", "conv_norm_out.weight"),
         | 
| 643 | 
            +
                    ("out.0.bias", "conv_norm_out.bias"),
         | 
| 644 | 
            +
                    ("out.2.weight", "conv_out.weight"),
         | 
| 645 | 
            +
                    ("out.2.bias", "conv_out.bias"),
         | 
| 646 | 
            +
                ]
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                unet_conversion_map_resnet = [
         | 
| 649 | 
            +
                    # (stable-diffusion, HF Diffusers)
         | 
| 650 | 
            +
                    ("in_layers.0", "norm1"),
         | 
| 651 | 
            +
                    ("in_layers.2", "conv1"),
         | 
| 652 | 
            +
                    ("out_layers.0", "norm2"),
         | 
| 653 | 
            +
                    ("out_layers.3", "conv2"),
         | 
| 654 | 
            +
                    ("emb_layers.1", "time_emb_proj"),
         | 
| 655 | 
            +
                    ("skip_connection", "conv_shortcut"),
         | 
| 656 | 
            +
                ]
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                unet_conversion_map_layer = []
         | 
| 659 | 
            +
                for i in range(4):
         | 
| 660 | 
            +
                    # loop over downblocks/upblocks
         | 
| 661 | 
            +
             | 
| 662 | 
            +
                    for j in range(2):
         | 
| 663 | 
            +
                        # loop over resnets/attentions for downblocks
         | 
| 664 | 
            +
                        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
         | 
| 665 | 
            +
                        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
         | 
| 666 | 
            +
                        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
         | 
| 667 | 
            +
             | 
| 668 | 
            +
                        if i < 3:
         | 
| 669 | 
            +
                            # no attention layers in down_blocks.3
         | 
| 670 | 
            +
                            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
         | 
| 671 | 
            +
                            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
         | 
| 672 | 
            +
                            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                    for j in range(3):
         | 
| 675 | 
            +
                        # loop over resnets/attentions for upblocks
         | 
| 676 | 
            +
                        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
         | 
| 677 | 
            +
                        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
         | 
| 678 | 
            +
                        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
         | 
| 679 | 
            +
             | 
| 680 | 
            +
                        if i > 0:
         | 
| 681 | 
            +
                            # no attention layers in up_blocks.0
         | 
| 682 | 
            +
                            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
         | 
| 683 | 
            +
                            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
         | 
| 684 | 
            +
                            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    if i < 3:
         | 
| 687 | 
            +
                        # no downsample in down_blocks.3
         | 
| 688 | 
            +
                        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
         | 
| 689 | 
            +
                        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
         | 
| 690 | 
            +
                        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                        # no upsample in up_blocks.3
         | 
| 693 | 
            +
                        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
         | 
| 694 | 
            +
                        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
         | 
| 695 | 
            +
                        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                hf_mid_atn_prefix = "mid_block.attentions.0."
         | 
| 698 | 
            +
                sd_mid_atn_prefix = "middle_block.1."
         | 
| 699 | 
            +
                unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                for j in range(2):
         | 
| 702 | 
            +
                    hf_mid_res_prefix = f"mid_block.resnets.{j}."
         | 
| 703 | 
            +
                    sd_mid_res_prefix = f"middle_block.{2*j}."
         | 
| 704 | 
            +
                    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                # buyer beware: this is a *brittle* function,
         | 
| 707 | 
            +
                # and correct output requires that all of these pieces interact in
         | 
| 708 | 
            +
                # the exact order in which I have arranged them.
         | 
| 709 | 
            +
                mapping = {k: k for k in unet_state_dict.keys()}
         | 
| 710 | 
            +
                for sd_name, hf_name in unet_conversion_map:
         | 
| 711 | 
            +
                    mapping[hf_name] = sd_name
         | 
| 712 | 
            +
                for k, v in mapping.items():
         | 
| 713 | 
            +
                    if "resnets" in k:
         | 
| 714 | 
            +
                        for sd_part, hf_part in unet_conversion_map_resnet:
         | 
| 715 | 
            +
                            v = v.replace(hf_part, sd_part)
         | 
| 716 | 
            +
                        mapping[k] = v
         | 
| 717 | 
            +
                for k, v in mapping.items():
         | 
| 718 | 
            +
                    for sd_part, hf_part in unet_conversion_map_layer:
         | 
| 719 | 
            +
                        v = v.replace(hf_part, sd_part)
         | 
| 720 | 
            +
                    mapping[k] = v
         | 
| 721 | 
            +
                new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                if v2:
         | 
| 724 | 
            +
                    conv_transformer_to_linear(new_state_dict)
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                return new_state_dict
         | 
| 727 | 
            +
             | 
| 728 | 
            +
             | 
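
            # Sketch added for illustration (not in the original module): build a
            # randomly initialised v1 UNet from the config helper above and convert its
            # state dict back to Stable Diffusion naming.  For example, the Diffusers
            # key down_blocks.0.resnets.0.norm1.weight first has its resnet part
            # rewritten (norm1 -> in_layers.0) and then its block prefix rewritten
            # (down_blocks.0.resnets.0. -> input_blocks.1.0.).
            def _example_unet_to_sd_keys():
                unet = UNet2DConditionModel(**create_unet_diffusers_config(v2=False))
                sd = convert_unet_state_dict_to_sd(False, unet.state_dict())
                assert "input_blocks.1.0.in_layers.0.weight" in sd
                return sd
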
| 729 | 
            +
            # ================#
         | 
| 730 | 
            +
            # VAE Conversion #
         | 
| 731 | 
            +
            # ================#
         | 
| 732 | 
            +
             | 
| 733 | 
            +
             | 
| 734 | 
            +
            def reshape_weight_for_sd(w):
         | 
| 735 | 
            +
                # convert HF linear weights to SD conv2d weights
         | 
| 736 | 
            +
                return w.reshape(*w.shape, 1, 1)
         | 
| 737 | 
            +
             | 
| 738 | 
            +
             | 
| 739 | 
            +
            def convert_vae_state_dict(vae_state_dict):
         | 
| 740 | 
            +
                vae_conversion_map = [
         | 
| 741 | 
            +
                    # (stable-diffusion, HF Diffusers)
         | 
| 742 | 
            +
                    ("nin_shortcut", "conv_shortcut"),
         | 
| 743 | 
            +
                    ("norm_out", "conv_norm_out"),
         | 
| 744 | 
            +
                    ("mid.attn_1.", "mid_block.attentions.0."),
         | 
| 745 | 
            +
                ]
         | 
| 746 | 
            +
             | 
| 747 | 
            +
                for i in range(4):
         | 
| 748 | 
            +
                    # down_blocks have two resnets
         | 
| 749 | 
            +
                    for j in range(2):
         | 
| 750 | 
            +
                        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
         | 
| 751 | 
            +
                        sd_down_prefix = f"encoder.down.{i}.block.{j}."
         | 
| 752 | 
            +
                        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
         | 
| 753 | 
            +
             | 
| 754 | 
            +
                    if i < 3:
         | 
| 755 | 
            +
                        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
         | 
| 756 | 
            +
                        sd_downsample_prefix = f"down.{i}.downsample."
         | 
| 757 | 
            +
                        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
         | 
| 758 | 
            +
             | 
| 759 | 
            +
                        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
         | 
| 760 | 
            +
                        sd_upsample_prefix = f"up.{3-i}.upsample."
         | 
| 761 | 
            +
                        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
         | 
| 762 | 
            +
             | 
| 763 | 
            +
                    # up_blocks have three resnets
         | 
| 764 | 
            +
                    # also, up blocks in hf are numbered in reverse from sd
         | 
| 765 | 
            +
                    for j in range(3):
         | 
| 766 | 
            +
                        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
         | 
| 767 | 
            +
                        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
         | 
| 768 | 
            +
                        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                # this part accounts for mid blocks in both the encoder and the decoder
         | 
| 771 | 
            +
                for i in range(2):
         | 
| 772 | 
            +
                    hf_mid_res_prefix = f"mid_block.resnets.{i}."
         | 
| 773 | 
            +
                    sd_mid_res_prefix = f"mid.block_{i+1}."
         | 
| 774 | 
            +
                    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
         | 
| 775 | 
            +
             | 
| 776 | 
            +
                vae_conversion_map_attn = [
         | 
| 777 | 
            +
                    # (stable-diffusion, HF Diffusers)
         | 
| 778 | 
            +
                    ("norm.", "group_norm."),
         | 
| 779 | 
            +
                    ("q.", "query."),
         | 
| 780 | 
            +
                    ("k.", "key."),
         | 
| 781 | 
            +
                    ("v.", "value."),
         | 
| 782 | 
            +
                    ("proj_out.", "proj_attn."),
         | 
| 783 | 
            +
                ]
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                mapping = {k: k for k in vae_state_dict.keys()}
         | 
| 786 | 
            +
                for k, v in mapping.items():
         | 
| 787 | 
            +
                    for sd_part, hf_part in vae_conversion_map:
         | 
| 788 | 
            +
                        v = v.replace(hf_part, sd_part)
         | 
| 789 | 
            +
                    mapping[k] = v
         | 
| 790 | 
            +
                for k, v in mapping.items():
         | 
| 791 | 
            +
                    if "attentions" in k:
         | 
| 792 | 
            +
                        for sd_part, hf_part in vae_conversion_map_attn:
         | 
| 793 | 
            +
                            v = v.replace(hf_part, sd_part)
         | 
| 794 | 
            +
                        mapping[k] = v
         | 
| 795 | 
            +
                new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
         | 
| 796 | 
            +
                weights_to_convert = ["q", "k", "v", "proj_out"]
         | 
| 797 | 
            +
                for k, v in new_state_dict.items():
         | 
| 798 | 
            +
                    for weight_name in weights_to_convert:
         | 
| 799 | 
            +
                        if f"mid.attn_1.{weight_name}.weight" in k:
         | 
| 800 | 
            +
                            # print(f"Reshaping {k} for SD format")
         | 
| 801 | 
            +
                            new_state_dict[k] = reshape_weight_for_sd(v)
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                return new_state_dict
         | 
| 804 | 
            +
             | 
| 805 | 
            +
             | 
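
            # Illustration (not in the original module): the Diffusers VAE stores the
            # mid-block attention q/k/v/proj_out weights as [C, C] linear weights;
            # reshape_weight_for_sd() turns them back into the [C, C, 1, 1] conv
            # weights that the original Stable Diffusion VAE expects.
            def _example_vae_attention_reshape():
                w = torch.randn(512, 512)
                assert reshape_weight_for_sd(w).shape == (512, 512, 1, 1)
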
| 806 | 
            +
            # endregion
         | 
| 807 | 
            +
             | 
| 808 | 
            +
            # region custom model loading/saving helpers
         | 
| 809 | 
            +
             | 
| 810 | 
            +
             | 
| 811 | 
            +
            def is_safetensors(path):
         | 
| 812 | 
            +
                return os.path.splitext(path)[1].lower() == ".safetensors"
         | 
| 813 | 
            +
             | 
| 814 | 
            +
             | 
| 815 | 
            +
            def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
         | 
| 816 | 
            +
                # support models whose text encoder keys are stored in a different layout (missing 'text_model')
         | 
| 817 | 
            +
                TEXT_ENCODER_KEY_REPLACEMENTS = [
         | 
| 818 | 
            +
                    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
         | 
| 819 | 
            +
                    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
         | 
| 820 | 
            +
                    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
         | 
| 821 | 
            +
                ]
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                if is_safetensors(ckpt_path):
         | 
| 824 | 
            +
                    checkpoint = None
         | 
| 825 | 
            +
                    state_dict = load_file(ckpt_path)  # , device)  # passing the device here may cause an error
         | 
| 826 | 
            +
                else:
         | 
| 827 | 
            +
                    checkpoint = torch.load(ckpt_path, map_location=device)
         | 
| 828 | 
            +
                    if "state_dict" in checkpoint:
         | 
| 829 | 
            +
                        state_dict = checkpoint["state_dict"]
         | 
| 830 | 
            +
                    else:
         | 
| 831 | 
            +
                        state_dict = checkpoint
         | 
| 832 | 
            +
                        checkpoint = None
         | 
| 833 | 
            +
             | 
| 834 | 
            +
                key_reps = []
         | 
| 835 | 
            +
                for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
         | 
| 836 | 
            +
                    for key in state_dict.keys():
         | 
| 837 | 
            +
                        if key.startswith(rep_from):
         | 
| 838 | 
            +
                            new_key = rep_to + key[len(rep_from) :]
         | 
| 839 | 
            +
                            key_reps.append((key, new_key))
         | 
| 840 | 
            +
             | 
| 841 | 
            +
                for key, new_key in key_reps:
         | 
| 842 | 
            +
                    state_dict[new_key] = state_dict[key]
         | 
| 843 | 
            +
                    del state_dict[key]
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                return checkpoint, state_dict
         | 
| 846 | 
            +
             | 
| 847 | 
            +
             | 
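
            # Behaviour sketch added for illustration (not in the original module):
            # for a .safetensors file only the state dict is returned (checkpoint is
            # None), while a .ckpt also returns the wrapper dict, which may carry
            # "epoch" and "global_step".  The path below is hypothetical.
            def _example_inspect_checkpoint(ckpt_path="model.safetensors"):
                checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
                print("has wrapper dict:", checkpoint is not None, "tensors:", len(state_dict))
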
| 848 | 
            +
            # TODO the dtype handling looks questionable and needs checking; creating the text_encoder with the requested dtype is unverified
         | 
| 849 | 
            +
            def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
         | 
| 850 | 
            +
                _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
         | 
| 851 | 
            +
             | 
| 852 | 
            +
                # Convert the UNet2DConditionModel model.
         | 
| 853 | 
            +
                unet_config = create_unet_diffusers_config(v2)
         | 
| 854 | 
            +
                converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
         | 
| 855 | 
            +
             | 
| 856 | 
            +
                unet = UNet2DConditionModel(**unet_config).to(device)
         | 
| 857 | 
            +
                info = unet.load_state_dict(converted_unet_checkpoint)
         | 
| 858 | 
            +
                print("loading u-net:", info)
         | 
| 859 | 
            +
             | 
| 860 | 
            +
                # Convert the VAE model.
         | 
| 861 | 
            +
                vae_config = create_vae_diffusers_config()
         | 
| 862 | 
            +
                converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                vae = AutoencoderKL(**vae_config).to(device)
         | 
| 865 | 
            +
                info = vae.load_state_dict(converted_vae_checkpoint)
         | 
| 866 | 
            +
                print("loading vae:", info)
         | 
| 867 | 
            +
             | 
| 868 | 
            +
                # convert text_model
         | 
| 869 | 
            +
                if v2:
         | 
| 870 | 
            +
                    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
         | 
| 871 | 
            +
                    cfg = CLIPTextConfig(
         | 
| 872 | 
            +
                        vocab_size=49408,
         | 
| 873 | 
            +
                        hidden_size=1024,
         | 
| 874 | 
            +
                        intermediate_size=4096,
         | 
| 875 | 
            +
                        num_hidden_layers=23,
         | 
| 876 | 
            +
                        num_attention_heads=16,
         | 
| 877 | 
            +
                        max_position_embeddings=77,
         | 
| 878 | 
            +
                        hidden_act="gelu",
         | 
| 879 | 
            +
                        layer_norm_eps=1e-05,
         | 
| 880 | 
            +
                        dropout=0.0,
         | 
| 881 | 
            +
                        attention_dropout=0.0,
         | 
| 882 | 
            +
                        initializer_range=0.02,
         | 
| 883 | 
            +
                        initializer_factor=1.0,
         | 
| 884 | 
            +
                        pad_token_id=1,
         | 
| 885 | 
            +
                        bos_token_id=0,
         | 
| 886 | 
            +
                        eos_token_id=2,
         | 
| 887 | 
            +
                        model_type="clip_text_model",
         | 
| 888 | 
            +
                        projection_dim=512,
         | 
| 889 | 
            +
                        torch_dtype="float32",
         | 
| 890 | 
            +
                        transformers_version="4.25.0.dev0",
         | 
| 891 | 
            +
                    )
         | 
| 892 | 
            +
                    text_model = CLIPTextModel._from_config(cfg)
         | 
| 893 | 
            +
                    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
         | 
| 894 | 
            +
                else:
         | 
| 895 | 
            +
                    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                    logging.set_verbosity_error()  # don't show annoying warning
         | 
| 898 | 
            +
                    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
         | 
| 899 | 
            +
                    logging.set_verbosity_warning()
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
         | 
| 902 | 
            +
                print("loading text encoder:", info)
         | 
| 903 | 
            +
             | 
| 904 | 
            +
                return text_model, vae, unet
         | 
| 905 | 
            +
             | 
| 906 | 
            +
             | 
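
            # Usage sketch added for illustration (not part of the original module):
            # load an SD 1.x checkpoint into Diffusers modules and cast them to fp16
            # for inference.  The path is hypothetical.
            def _example_load_v1_models(ckpt_path="v1-5-pruned.safetensors"):
                text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(False, ckpt_path)
                for m in (text_encoder, vae, unet):
                    m.to(torch.float16).eval()
                return text_encoder, vae, unet
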
| 907 | 
            +
            def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
         | 
| 908 | 
            +
                def convert_key(key):
         | 
| 909 | 
            +
                    # drop position_ids
         | 
| 910 | 
            +
                    if ".position_ids" in key:
         | 
| 911 | 
            +
                        return None
         | 
| 912 | 
            +
             | 
| 913 | 
            +
                    # common
         | 
| 914 | 
            +
                    key = key.replace("text_model.encoder.", "transformer.")
         | 
| 915 | 
            +
                    key = key.replace("text_model.", "")
         | 
| 916 | 
            +
                    if "layers" in key:
         | 
| 917 | 
            +
                        # resblocks conversion
         | 
| 918 | 
            +
                        key = key.replace(".layers.", ".resblocks.")
         | 
| 919 | 
            +
                        if ".layer_norm" in key:
         | 
| 920 | 
            +
                            key = key.replace(".layer_norm", ".ln_")
         | 
| 921 | 
            +
                        elif ".mlp." in key:
         | 
| 922 | 
            +
                            key = key.replace(".fc1.", ".c_fc.")
         | 
| 923 | 
            +
                            key = key.replace(".fc2.", ".c_proj.")
         | 
| 924 | 
            +
                        elif ".self_attn.out_proj" in key:
         | 
| 925 | 
            +
                            key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
         | 
| 926 | 
            +
                        elif ".self_attn." in key:
         | 
| 927 | 
            +
                            key = None  # special case, handled separately below
         | 
| 928 | 
            +
                        else:
         | 
| 929 | 
            +
                            raise ValueError(f"unexpected key in Diffusers model: {key}")
         | 
| 930 | 
            +
                    elif ".position_embedding" in key:
         | 
| 931 | 
            +
                        key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
         | 
| 932 | 
            +
                    elif ".token_embedding" in key:
         | 
| 933 | 
            +
                        key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
         | 
| 934 | 
            +
                    elif "final_layer_norm" in key:
         | 
| 935 | 
            +
                        key = key.replace("final_layer_norm", "ln_final")
         | 
| 936 | 
            +
                    return key
         | 
| 937 | 
            +
             | 
| 938 | 
            +
                keys = list(checkpoint.keys())
         | 
| 939 | 
            +
                new_sd = {}
         | 
| 940 | 
            +
                for key in keys:
         | 
| 941 | 
            +
                    new_key = convert_key(key)
         | 
| 942 | 
            +
                    if new_key is None:
         | 
| 943 | 
            +
                        continue
         | 
| 944 | 
            +
                    new_sd[new_key] = checkpoint[key]
         | 
| 945 | 
            +
             | 
| 946 | 
            +
                # convert the attention projection weights
         | 
| 947 | 
            +
                for key in keys:
         | 
| 948 | 
            +
                    if "layers" in key and "q_proj" in key:
         | 
| 949 | 
            +
                        # concatenate the three projections (q, k, v)
         | 
| 950 | 
            +
                        key_q = key
         | 
| 951 | 
            +
                        key_k = key.replace("q_proj", "k_proj")
         | 
| 952 | 
            +
                        key_v = key.replace("q_proj", "v_proj")
         | 
| 953 | 
            +
             | 
| 954 | 
            +
                        value_q = checkpoint[key_q]
         | 
| 955 | 
            +
                        value_k = checkpoint[key_k]
         | 
| 956 | 
            +
                        value_v = checkpoint[key_v]
         | 
| 957 | 
            +
                        value = torch.cat([value_q, value_k, value_v])
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                        new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
         | 
| 960 | 
            +
                        new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
         | 
| 961 | 
            +
                        new_sd[new_key] = value
         | 
| 962 | 
            +
             | 
| 963 | 
            +
                # optionally fabricate the final layer and other missing weights
         | 
| 964 | 
            +
                if make_dummy_weights:
         | 
| 965 | 
            +
                    print("make dummy weights for resblock.23, text_projection and logit scale.")
         | 
| 966 | 
            +
                    keys = list(new_sd.keys())
         | 
| 967 | 
            +
                    for key in keys:
         | 
| 968 | 
            +
                        if key.startswith("transformer.resblocks.22."):
         | 
| 969 | 
            +
                            new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # without cloning, saving to safetensors fails
         | 
| 970 | 
            +
             | 
| 971 | 
            +
                    # create weights that are not present in the Diffusers model
         | 
| 972 | 
            +
                    new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
         | 
| 973 | 
            +
                    new_sd["logit_scale"] = torch.tensor(1)
         | 
| 974 | 
            +
             | 
| 975 | 
            +
                return new_sd
         | 
| 976 | 
            +
             | 
| 977 | 
            +
             | 
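
            # Tiny demonstration added for illustration (not in the original module):
            # the function above is the inverse of convert_ldm_clip_checkpoint_v2; a
            # Diffusers key such as text_model.encoder.layers.0.layer_norm1.weight maps
            # back to transformer.resblocks.0.ln_1.weight, and separate q/k/v
            # projections are re-fused into attn.in_proj_*.  The dict is fabricated.
            def _example_v2_text_key_reverse():
                dummy = {"text_model.encoder.layers.0.layer_norm1.weight": torch.zeros(1024)}
                # returns {"transformer.resblocks.0.ln_1.weight": ...}
                return convert_text_encoder_state_dict_to_sd_v2(dummy)
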
| 978 | 
            +
            def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
         | 
| 979 | 
            +
                if ckpt_path is not None:
         | 
| 980 | 
            +
                    # reuse epoch/step from the source checkpoint; also reload it (including the VAE), e.g. when the VAE is not in memory
         | 
| 981 | 
            +
                    checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
         | 
| 982 | 
            +
                    if checkpoint is None:  # safetensors file or a bare state_dict ckpt
         | 
| 983 | 
            +
                        checkpoint = {}
         | 
| 984 | 
            +
                        strict = False
         | 
| 985 | 
            +
                    else:
         | 
| 986 | 
            +
                        strict = True
         | 
| 987 | 
            +
                    if "state_dict" in state_dict:
         | 
| 988 | 
            +
                        del state_dict["state_dict"]
         | 
| 989 | 
            +
                else:
         | 
| 990 | 
            +
                    # build a new checkpoint from scratch
         | 
| 991 | 
            +
                    assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
         | 
| 992 | 
            +
                    checkpoint = {}
         | 
| 993 | 
            +
                    state_dict = {}
         | 
| 994 | 
            +
                    strict = False
         | 
| 995 | 
            +
             | 
| 996 | 
            +
                def update_sd(prefix, sd):
         | 
| 997 | 
            +
                    for k, v in sd.items():
         | 
| 998 | 
            +
                        key = prefix + k
         | 
| 999 | 
            +
                        assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
         | 
| 1000 | 
            +
                        if save_dtype is not None:
         | 
| 1001 | 
            +
                            v = v.detach().clone().to("cpu").to(save_dtype)
         | 
| 1002 | 
            +
                        state_dict[key] = v
         | 
| 1003 | 
            +
             | 
| 1004 | 
            +
                # Convert the UNet model
         | 
| 1005 | 
            +
                unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
         | 
| 1006 | 
            +
                update_sd("model.diffusion_model.", unet_state_dict)
         | 
| 1007 | 
            +
             | 
| 1008 | 
            +
                # Convert the text encoder model
         | 
| 1009 | 
            +
                if v2:
         | 
| 1010 | 
            +
                    make_dummy = ckpt_path is None  # without a source checkpoint, fill in dummy weights (e.g. duplicate the last layer from the previous one)
         | 
| 1011 | 
            +
                    text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
         | 
| 1012 | 
            +
                    update_sd("cond_stage_model.model.", text_enc_dict)
         | 
| 1013 | 
            +
                else:
         | 
| 1014 | 
            +
                    text_enc_dict = text_encoder.state_dict()
         | 
| 1015 | 
            +
                    update_sd("cond_stage_model.transformer.", text_enc_dict)
         | 
| 1016 | 
            +
             | 
| 1017 | 
            +
                # Convert the VAE
         | 
| 1018 | 
            +
                if vae is not None:
         | 
| 1019 | 
            +
                    vae_dict = convert_vae_state_dict(vae.state_dict())
         | 
| 1020 | 
            +
                    update_sd("first_stage_model.", vae_dict)
         | 
| 1021 | 
            +
             | 
| 1022 | 
            +
                # Put together new checkpoint
         | 
| 1023 | 
            +
                key_count = len(state_dict.keys())
         | 
| 1024 | 
            +
                new_ckpt = {"state_dict": state_dict}
         | 
| 1025 | 
            +
             | 
| 1026 | 
            +
                # epoch and global_step are sometimes not int
         | 
| 1027 | 
            +
                try:
         | 
| 1028 | 
            +
                    if "epoch" in checkpoint:
         | 
| 1029 | 
            +
                        epochs += checkpoint["epoch"]
         | 
| 1030 | 
            +
                    if "global_step" in checkpoint:
         | 
| 1031 | 
            +
                        steps += checkpoint["global_step"]
         | 
| 1032 | 
            +
                except:
         | 
| 1033 | 
            +
                    pass
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                new_ckpt["epoch"] = epochs
         | 
| 1036 | 
            +
                new_ckpt["global_step"] = steps
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                if is_safetensors(output_file):
         | 
| 1039 | 
            +
                    # TODO Tensor以外のdictの値を削除したほうがいいか
         | 
| 1040 | 
            +
                    save_file(state_dict, output_file)
         | 
| 1041 | 
            +
                else:
         | 
| 1042 | 
            +
                    torch.save(new_ckpt, output_file)
         | 
| 1043 | 
            +
             | 
| 1044 | 
            +
                return key_count
         | 
| 1045 | 
            +
             | 
| 1046 | 
            +
             | 
| 1047 | 
            +
            def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
         | 
| 1048 | 
            +
                if pretrained_model_name_or_path is None:
         | 
| 1049 | 
            +
                    # load default settings for v1/v2
         | 
| 1050 | 
            +
                    if v2:
         | 
| 1051 | 
            +
                        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
         | 
| 1052 | 
            +
                    else:
         | 
| 1053 | 
            +
                        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
         | 
| 1054 | 
            +
             | 
| 1055 | 
            +
                scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
         | 
| 1056 | 
            +
                tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
         | 
| 1057 | 
            +
                if vae is None:
         | 
| 1058 | 
            +
                    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
         | 
| 1059 | 
            +
             | 
| 1060 | 
            +
                pipeline = StableDiffusionPipeline(
         | 
| 1061 | 
            +
                    unet=unet,
         | 
| 1062 | 
            +
                    text_encoder=text_encoder,
         | 
| 1063 | 
            +
                    vae=vae,
         | 
| 1064 | 
            +
                    scheduler=scheduler,
         | 
| 1065 | 
            +
                    tokenizer=tokenizer,
         | 
| 1066 | 
            +
                    safety_checker=None,
         | 
| 1067 | 
            +
                    feature_extractor=None,
         | 
| 1068 | 
            +
                    requires_safety_checker=None,
         | 
| 1069 | 
            +
                )
         | 
| 1070 | 
            +
                pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
         | 
| 1071 | 
            +
             | 
| 1072 | 
            +
             | 
| 1073 | 
            +
            VAE_PREFIX = "first_stage_model."
         | 
| 1074 | 
            +
             | 
| 1075 | 
            +
             | 
| 1076 | 
            +
            def load_vae(vae_id, dtype):
         | 
| 1077 | 
            +
                print(f"load VAE: {vae_id}")
         | 
| 1078 | 
            +
                if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
         | 
| 1079 | 
            +
                    # Diffusers local/remote
         | 
| 1080 | 
            +
                    try:
         | 
| 1081 | 
            +
                        vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
         | 
| 1082 | 
            +
                    except EnvironmentError as e:
         | 
| 1083 | 
            +
                        print(f"exception occurs in loading vae: {e}")
         | 
| 1084 | 
            +
                        print("retry with subfolder='vae'")
         | 
| 1085 | 
            +
                        vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
         | 
| 1086 | 
            +
                    return vae
         | 
| 1087 | 
            +
             | 
| 1088 | 
            +
                # local
         | 
| 1089 | 
            +
                vae_config = create_vae_diffusers_config()
         | 
| 1090 | 
            +
             | 
| 1091 | 
            +
                if vae_id.endswith(".bin"):
         | 
| 1092 | 
            +
                    # SD 1.5 VAE on Huggingface
         | 
| 1093 | 
            +
                    converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
         | 
| 1094 | 
            +
                else:
         | 
| 1095 | 
            +
                    # StableDiffusion
         | 
| 1096 | 
            +
                    vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
         | 
| 1097 | 
            +
                    vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
         | 
| 1098 | 
            +
             | 
| 1099 | 
            +
                    # vae only or full model
         | 
| 1100 | 
            +
                    full_model = False
         | 
| 1101 | 
            +
                    for vae_key in vae_sd:
         | 
| 1102 | 
            +
                        if vae_key.startswith(VAE_PREFIX):
         | 
| 1103 | 
            +
                            full_model = True
         | 
| 1104 | 
            +
                            break
         | 
| 1105 | 
            +
                    if not full_model:
         | 
| 1106 | 
            +
                        sd = {}
         | 
| 1107 | 
            +
                        for key, value in vae_sd.items():
         | 
| 1108 | 
            +
                            sd[VAE_PREFIX + key] = value
         | 
| 1109 | 
            +
                        vae_sd = sd
         | 
| 1110 | 
            +
                        del sd
         | 
| 1111 | 
            +
             | 
| 1112 | 
            +
                    # Convert the VAE model.
         | 
| 1113 | 
            +
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
         | 
| 1114 | 
            +
             | 
| 1115 | 
            +
                vae = AutoencoderKL(**vae_config)
         | 
| 1116 | 
            +
                vae.load_state_dict(converted_vae_checkpoint)
         | 
| 1117 | 
            +
                return vae
         | 
| 1118 | 
            +
             | 
| 1119 | 
            +
             | 
| 1120 | 
            +
            # endregion
         | 
| 1121 | 
            +
             | 
| 1122 | 
            +
             | 
| 1123 | 
            +
            def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
         | 
| 1124 | 
            +
                max_width, max_height = max_reso
         | 
| 1125 | 
            +
                max_area = (max_width // divisible) * (max_height // divisible)
         | 
| 1126 | 
            +
             | 
| 1127 | 
            +
                resos = set()
         | 
| 1128 | 
            +
             | 
| 1129 | 
            +
                size = int(math.sqrt(max_area)) * divisible
         | 
| 1130 | 
            +
                resos.add((size, size))
         | 
| 1131 | 
            +
             | 
| 1132 | 
            +
                size = min_size
         | 
| 1133 | 
            +
                while size <= max_size:
         | 
| 1134 | 
            +
                    width = size
         | 
| 1135 | 
            +
                    height = min(max_size, (max_area // (width // divisible)) * divisible)
         | 
| 1136 | 
            +
                    resos.add((width, height))
         | 
| 1137 | 
            +
                    resos.add((height, width))
         | 
| 1138 | 
            +
             | 
| 1139 | 
            +
                    # # make additional resos
         | 
| 1140 | 
            +
                    # if width >= height and width - divisible >= min_size:
         | 
| 1141 | 
            +
                    #   resos.add((width - divisible, height))
         | 
| 1142 | 
            +
                    #   resos.add((height, width - divisible))
         | 
| 1143 | 
            +
                    # if height >= width and height - divisible >= min_size:
         | 
| 1144 | 
            +
                    #   resos.add((width, height - divisible))
         | 
| 1145 | 
            +
                    #   resos.add((height - divisible, width))
         | 
| 1146 | 
            +
             | 
| 1147 | 
            +
                    size += divisible
         | 
| 1148 | 
            +
             | 
| 1149 | 
            +
                resos = list(resos)
         | 
| 1150 | 
            +
                resos.sort()
         | 
| 1151 | 
            +
                return resos
         | 
| 1152 | 
            +
             | 
| 1153 | 
            +
             | 
| 1154 | 
            +
            if __name__ == "__main__":
         | 
| 1155 | 
            +
                resos = make_bucket_resolutions((512, 768))
         | 
| 1156 | 
            +
                print(len(resos))
         | 
| 1157 | 
            +
                print(resos)
         | 
| 1158 | 
            +
                aspect_ratios = [w / h for w, h in resos]
         | 
| 1159 | 
            +
                print(aspect_ratios)
         | 
| 1160 | 
            +
             | 
| 1161 | 
            +
                ars = set()
         | 
| 1162 | 
            +
                for ar in aspect_ratios:
         | 
| 1163 | 
            +
                    if ar in ars:
         | 
| 1164 | 
            +
                        print("error! duplicate ar:", ar)
         | 
| 1165 | 
            +
                    ars.add(ar)
         | 
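
As a rough usage sketch only (not part of the uploaded file): the conversion helpers above can be chained to turn a single-file Stable Diffusion checkpoint into a Diffusers folder with a swapped-in VAE. The paths below are placeholders, and load_models_from_stable_diffusion_checkpoint is assumed to be the loader defined earlier in this module.

    # Hypothetical sketch; file paths are examples only.
    from library import model_util

    # Assumed loader from earlier in model_util: returns text_encoder, vae, unet.
    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
        False, "models/some_model.safetensors"
    )

    # Optionally replace the VAE with a standalone one, then write a Diffusers-format folder.
    vae = model_util.load_vae("models/some_vae.safetensors", dtype=None)
    model_util.save_diffusers_checkpoint(
        v2=False,
        output_dir="converted_model",
        text_encoder=text_encoder,
        unet=unet,
        pretrained_model_name_or_path=None,  # falls back to DIFFUSERS_REF_MODEL_ID_V1
        vae=vae,
        use_safetensors=True,
    )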
    	
        library/resize_lora_gui.py
    ADDED
    
    @@ -0,0 +1,173 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_file_path

PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'   # 📄


def resize_lora(
    model,
    new_rank,
    save_to,
    save_precision,
    device,
    dynamic_method,
    dynamic_param,
    verbose,
):
    # Check that a model file was provided
    if model == '':
        msgbox('Invalid model file')
        return

    # Check that the source model exists
    if not os.path.isfile(model):
        msgbox('The provided model is not a file')
        return

    if dynamic_method == 'sv_ratio':
        if float(dynamic_param) < 2:
            msgbox(
                f'Dynamic parameter for {dynamic_method} needs to be 2 or greater...'
            )
            return

    if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
        if float(dynamic_param) < 0 or float(dynamic_param) > 1:
            msgbox(
                f'Dynamic parameter for {dynamic_method} needs to be between 0 and 1...'
            )
            return

    # Check if save_to ends with one of the defined extensions. If not, add .safetensors.
    if not save_to.endswith(('.pt', '.safetensors')):
        save_to += '.safetensors'

    if device == '':
        device = 'cuda'

    run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"'
    run_cmd += f' --save_precision {save_precision}'
    run_cmd += f' --save_to "{save_to}"'
    run_cmd += f' --model "{model}"'
    run_cmd += f' --new_rank {new_rank}'
    run_cmd += f' --device {device}'
    if not dynamic_method == 'None':
        run_cmd += f' --dynamic_method {dynamic_method}'
        run_cmd += f' --dynamic_param {dynamic_param}'
    if verbose:
        run_cmd += ' --verbose'

    print(run_cmd)

    # Run the command
    if os.name == 'posix':
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)


###
# Gradio UI
###


def gradio_resize_lora_tab():
    with gr.Tab('Resize LoRA'):
        gr.Markdown('This utility can resize a LoRA.')

        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

        with gr.Row():
            model = gr.Textbox(
                label='Source LoRA',
                placeholder='Path to the LoRA to resize',
                interactive=True,
            )
            button_lora_a_model_file = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_lora_a_model_file.click(
                get_file_path,
                inputs=[model, lora_ext, lora_ext_name],
                outputs=model,
                show_progress=False,
            )
        with gr.Row():
            new_rank = gr.Slider(
                label='Desired LoRA rank',
                minimum=1,
                maximum=1024,
                step=1,
                value=4,
                interactive=True,
            )

        with gr.Row():
            dynamic_method = gr.Dropdown(
                choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
                value='sv_fro',
                label='Dynamic method',
                interactive=True,
            )
            dynamic_param = gr.Textbox(
                label='Dynamic parameter',
                value='0.9',
                interactive=True,
                placeholder='Value for the dynamic method selected.',
            )
            verbose = gr.Checkbox(label='Verbose', value=False)
        with gr.Row():
            save_to = gr.Textbox(
                label='Save to',
                placeholder='path for the LoRA file to save...',
                interactive=True,
            )
            button_save_to = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_save_to.click(
                get_saveasfilename_path,
                inputs=[save_to, lora_ext, lora_ext_name],
                outputs=save_to,
                show_progress=False,
            )
            save_precision = gr.Dropdown(
                label='Save precision',
                choices=['fp16', 'bf16', 'float'],
                value='fp16',
                interactive=True,
            )
            device = gr.Dropdown(
                label='Device',
                choices=[
                    'cpu',
                    'cuda',
                ],
                value='cuda',
                interactive=True,
            )

        convert_button = gr.Button('Resize model')

        convert_button.click(
            resize_lora,
            inputs=[
                model,
                new_rank,
                save_to,
                save_precision,
                device,
                dynamic_method,
                dynamic_param,
                verbose,
            ],
            show_progress=False,
        )
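
For reference, resize_lora only assembles and runs a networks/resize_lora.py command line; a direct call outside the Gradio tab would behave roughly like the sketch below. All argument values are examples, and networks/resize_lora.py must exist relative to the working directory.

    # Hypothetical direct call, bypassing the UI; values are examples only.
    from library.resize_lora_gui import resize_lora

    resize_lora(
        model='loras/my_lora.safetensors',
        new_rank=8,
        save_to='loras/my_lora_rank8',   # '.safetensors' is appended automatically
        save_precision='fp16',
        device='cuda',
        dynamic_method='sv_fro',
        dynamic_param='0.9',
        verbose=True,
    )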
    	
        library/sampler_gui.py
    ADDED
    
    @@ -0,0 +1,102 @@
import tempfile
import os
import gradio as gr
from easygui import msgbox

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'   # 📄


###
### Gradio common sampler GUI section
###


def sample_gradio_config():
    with gr.Accordion('Sample images config', open=False):
        with gr.Row():
            sample_every_n_steps = gr.Number(
                label='Sample every n steps',
                value=0,
                precision=0,
                interactive=True,
            )
            sample_every_n_epochs = gr.Number(
                label='Sample every n epochs',
                value=0,
                precision=0,
                interactive=True,
            )
            sample_sampler = gr.Dropdown(
                label='Sample sampler',
                choices=[
                    'ddim',
                    'pndm',
                    'lms',
                    'euler',
                    'euler_a',
                    'heun',
                    'dpm_2',
                    'dpm_2_a',
                    'dpmsolver',
                    'dpmsolver++',
                    'dpmsingle',
                    'k_lms',
                    'k_euler',
                    'k_euler_a',
                    'k_dpm_2',
                    'k_dpm_2_a',
                ],
                value='euler_a',
                interactive=True,
            )
        with gr.Row():
            sample_prompts = gr.Textbox(
                lines=5,
                label='Sample prompts',
                interactive=True,
                placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28',
            )
    return (
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
    )


def run_cmd_sample(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    output_dir,
):
    output_dir = os.path.join(output_dir, 'sample')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    run_cmd = ''

    if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
        return run_cmd

    # Create the prompt file and get its path
    sample_prompts_path = os.path.join(output_dir, 'prompt.txt')

    with open(sample_prompts_path, 'w') as f:
        f.write(sample_prompts)

    run_cmd += f' --sample_sampler={sample_sampler}'
    run_cmd += f' --sample_prompts="{sample_prompts_path}"'

    if not sample_every_n_epochs == 0:
        run_cmd += f' --sample_every_n_epochs="{sample_every_n_epochs}"'

    if not sample_every_n_steps == 0:
        run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"'

    return run_cmd
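
run_cmd_sample only prepares the sample-image arguments: it writes the prompts to <output_dir>/sample/prompt.txt and returns a command-line fragment that the training tabs are expected to append to the training command. A hypothetical call (values are examples; the path shown assumes Linux/macOS separators):

    # Hypothetical usage sketch; writes outputs/my_run/sample/prompt.txt as a side effect.
    from library.sampler_gui import run_cmd_sample

    extra_args = run_cmd_sample(
        sample_every_n_steps=0,
        sample_every_n_epochs=1,
        sample_sampler='euler_a',
        sample_prompts='1girl, solo --w 512 --h 512 --l 7.5 --s 28',
        output_dir='outputs/my_run',
    )
    # extra_args is then:
    #   ' --sample_sampler=euler_a --sample_prompts="outputs/my_run/sample/prompt.txt" --sample_every_n_epochs="1"'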
    	
        library/svd_merge_lora_gui.py
    ADDED
    
    @@ -0,0 +1,190 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import (
    get_saveasfilename_path,
    get_any_file_path,
    get_file_path,
)

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'   # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'


def svd_merge_lora(
    lora_a_model,
    lora_b_model,
    ratio,
    save_to,
    precision,
    save_precision,
    new_rank,
    new_conv_rank,
    device,
):
    # Check that both model files were provided
    if lora_a_model == '':
        msgbox('Invalid model A file')
        return

    if lora_b_model == '':
        msgbox('Invalid model B file')
        return

    # Check that the source models exist
    if not os.path.isfile(lora_a_model):
        msgbox('The provided model A is not a file')
        return

    if not os.path.isfile(lora_b_model):
        msgbox('The provided model B is not a file')
        return

    ratio_a = ratio
    ratio_b = 1 - ratio

    run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"'
    run_cmd += f' --save_precision {save_precision}'
    run_cmd += f' --precision {precision}'
    run_cmd += f' --save_to "{save_to}"'
    run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
    run_cmd += f' --ratios {ratio_a} {ratio_b}'
    run_cmd += f' --device {device}'
    run_cmd += f' --new_rank "{new_rank}"'
    run_cmd += f' --new_conv_rank "{new_conv_rank}"'

    print(run_cmd)

    # Run the command
    if os.name == 'posix':
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)


###
# Gradio UI
###


def gradio_svd_merge_lora_tab():
    with gr.Tab('Merge LoRA (SVD)'):
        gr.Markdown('This utility can merge two LoRA networks together.')

        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

        with gr.Row():
            lora_a_model = gr.Textbox(
                label='LoRA model "A"',
                placeholder='Path to the LoRA A model',
                interactive=True,
            )
            button_lora_a_model_file = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_lora_a_model_file.click(
                get_file_path,
                inputs=[lora_a_model, lora_ext, lora_ext_name],
                outputs=lora_a_model,
                show_progress=False,
            )

            lora_b_model = gr.Textbox(
                label='LoRA model "B"',
                placeholder='Path to the LoRA B model',
                interactive=True,
            )
            button_lora_b_model_file = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_lora_b_model_file.click(
                get_file_path,
                inputs=[lora_b_model, lora_ext, lora_ext_name],
                outputs=lora_b_model,
                show_progress=False,
            )
        with gr.Row():
            ratio = gr.Slider(
                label='Merge ratio (e.g. 0.7 means 70% of model A and 30% of model B)',
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.5,
                interactive=True,
            )
            new_rank = gr.Slider(
                label='New Rank',
                minimum=1,
                maximum=1024,
                step=1,
                value=128,
                interactive=True,
            )
            new_conv_rank = gr.Slider(
                label='New Conv Rank',
                minimum=1,
                maximum=1024,
                step=1,
                value=128,
                interactive=True,
            )

        with gr.Row():
            save_to = gr.Textbox(
                label='Save to',
                placeholder='path for the file to save...',
                interactive=True,
            )
            button_save_to = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            button_save_to.click(
                get_saveasfilename_path,
                inputs=[save_to, lora_ext, lora_ext_name],
                outputs=save_to,
                show_progress=False,
            )
            precision = gr.Dropdown(
                label='Merge precision',
                choices=['fp16', 'bf16', 'float'],
                value='float',
                interactive=True,
            )
            save_precision = gr.Dropdown(
                label='Save precision',
                choices=['fp16', 'bf16', 'float'],
                value='float',
                interactive=True,
            )
            device = gr.Dropdown(
                label='Device',
                choices=[
                    'cpu',
                    'cuda',
                ],
                value='cuda',
                interactive=True,
            )

        convert_button = gr.Button('Merge model')

        convert_button.click(
            svd_merge_lora,
            inputs=[
                lora_a_model,
                lora_b_model,
                ratio,
                save_to,
                precision,
                save_precision,
                new_rank,
                new_conv_rank,
                device,
            ],
            show_progress=False,
        )
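
Like the resize tab, svd_merge_lora simply shells out to networks/svd_merge_lora.py; with ratio=0.7 the command weights LoRA A at 0.7 and LoRA B at 0.3, matching the ratio_a/ratio_b split above. A hypothetical direct call (paths and ranks are examples only):

    # Hypothetical direct call, bypassing the Gradio tab.
    from library.svd_merge_lora_gui import svd_merge_lora

    svd_merge_lora(
        lora_a_model='loras/style_a.safetensors',
        lora_b_model='loras/style_b.safetensors',
        ratio=0.7,
        save_to='loras/merged.safetensors',
        precision='float',
        save_precision='fp16',
        new_rank=64,
        new_conv_rank=64,
        device='cuda',
    )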