Upload 8 files
- .gitattributes +35 -0
- LICENSE +437 -0
- README.md +201 -0
- environment.yaml +197 -0
- main.py +738 -0
- test.py +447 -0
- test.sh +13 -0
- train.sh +1 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. BY-NC-SA Compatible License means a license listed at
     creativecommons.org/compatiblelicenses, approved by Creative
     Commons as essentially the equivalent of this Public License.

  d. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  e. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  f. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  g. License Elements means the license attributes listed in the name
     of a Creative Commons Public License. The License Elements of this
     Public License are Attribution, NonCommercial, and ShareAlike.

  h. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  i. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  j. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  k. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  l. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  m. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  n. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. Additional offer from the Licensor -- Adapted Material.
               Every recipient of Adapted Material from You
               automatically receives an offer from the Licensor to
               exercise the Licensed Rights in the Adapted Material
               under the conditions of the Adapter's License You apply.

            c. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.
       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

  b. ShareAlike.

     In addition to the conditions in Section 3(a), if You Share
     Adapted Material You produce, the following conditions also apply.

       1. The Adapter's License You apply must be a Creative Commons
          license with the same License Elements, this version or
          later, or a BY-NC-SA Compatible License.

       2. You must include the text of, or the URI or hyperlink to, the
          Adapter's License You apply. You may satisfy this condition
          in any reasonable manner based on the medium, means, and
          context in which You Share Adapted Material.

       3. You may not offer or impose any additional or different terms
          or conditions on, or apply any Effective Technological
          Measures to, Adapted Material that restrict exercise of the
          rights granted under the Adapter's License You apply.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material,
     including for purposes of Section 3(b); and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md
ADDED
@@ -0,0 +1,201 @@
<<<<<<< HEAD
# MV-VTON

PyTorch implementation of **MV-VTON: Multi-View Virtual Try-On with Diffusion Models**

[](https://arxiv.org/abs/2404.17364)
[](https://hywang2002.github.io/MV-VTON/)

[](https://creativecommons.org/licenses/by-nc-sa/4.0/)

## News
- 🔥 The first multi-view virtual try-on dataset, MVG, is now available.
- 🔥 Checkpoints for both the frontal-view and multi-view virtual try-on tasks are released.

## Overview

> **Abstract:**
> The goal of image-based virtual try-on is to generate an image of the target person naturally wearing the given
> clothing. However, most existing methods solely focus on frontal try-on using frontal clothing. When the views of
> the clothing and person are significantly inconsistent, particularly when the person's view is non-frontal, the
> results are unsatisfactory. To address this challenge, we introduce Multi-View Virtual Try-On (MV-VTON), which
> aims to reconstruct the dressing results of a person from multiple views using the given clothes. On the one hand,
> given that single-view clothes provide insufficient information for MV-VTON, we instead employ two images, i.e.,
> the frontal and back views of the clothing, to encompass the complete view as much as possible. On the other hand,
> diffusion models that have demonstrated superior abilities are adopted to perform our MV-VTON. In particular, we
> propose a view-adaptive selection method where hard-selection and soft-selection are applied to global and local
> clothing feature extraction, respectively. This ensures that the clothing features roughly fit the person's view.
> Subsequently, we suggest a joint attention block to align and fuse clothing features with person features.
> Additionally, we collect an MV-VTON dataset, i.e., Multi-View Garment (MVG), in which each person has multiple
> photos with diverse views and poses. Experiments show that the proposed method not only achieves state-of-the-art
> results on the MV-VTON task using our MVG dataset, but also has superiority on the frontal-view virtual try-on
> task using the VITON-HD and DressCode datasets.

## Getting Started

### Installation

1. Clone the repository

```shell
git clone https://github.com/hywang2002/MV-VTON.git
cd MV-VTON
```

2. Install Python dependencies

```shell
conda env create -f environment.yaml
conda activate mv-vton
```

3. Download the pretrained [vgg](https://drive.google.com/file/d/1rvow8jStPt8t2prDcSRlnf8yzXhrYeGo/view?usp=sharing) checkpoint and put it in `models/vgg/` for Multi-View VTON and in `Frontal-View VTON/models/vgg/` for Frontal-View VTON.
4. Download the pretrained models `mvg.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/17SC8fHE5w2g7gEtzJgRRew?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1J91PoT8A9yqHWNxkgRe6ZCnDEhN-H9O6/view?usp=sharing) and `vitonhd.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/1R2yGgm35UwTpnXPEU6-tlA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/13A0uzUY6PuvitLOqzyHzWASOh0dNXdem/view?usp=sharing); put `mvg.ckpt` in `checkpoint/` and `vitonhd.ckpt` in `Frontal-View VTON/checkpoint/`.

### Datasets

#### MVG

1. Fill in the `Dataset Request Form` via [Baidu Cloud](https://pan.baidu.com/s/12HAq0V4FfgpU_q8AeyZzwA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1zWt6HYBz7Vzaxu8rp1bwkhRoBkxbwQjw/view?usp=sharing), then contact `cshy2mvvton@outlook.com` with the completed form to obtain the MVG dataset (non-institutional email addresses such as gmail.com are not accepted; please use your institutional email address).

After this, the folder structure should look like the following (the warp_feat_unpair* folders are only included in the test directory):

```
├── MVG
|   ├── unpaired.txt
│   ├── [train | test]
|   |   ├── image-wo-bg
│   │   ├── cloth
│   │   ├── cloth-mask
│   │   ├── warp_feat
│   │   ├── warp_feat_unpair
│   │   ├── ...
```
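A minimal layout check along these lines can catch a misplaced folder before training. The split and subfolder names below are taken from the tree above; the script itself is only an illustrative sketch, not part of the repository:

```python
# Illustrative sanity check for the MVG dataset layout (not part of the repo).
from pathlib import Path

def check_mvg_layout(root: str = "MVG") -> None:
    root_path = Path(root)
    expected = ["image-wo-bg", "cloth", "cloth-mask", "warp_feat"]
    for split in ("train", "test"):
        for sub in expected:
            path = root_path / split / sub
            print(f"{path}: {'ok' if path.is_dir() else 'MISSING'}")
    # warp_feat_unpair is only expected under test/
    unpair = root_path / "test" / "warp_feat_unpair"
    print(f"{unpair}: {'ok' if unpair.is_dir() else 'MISSING'}")

if __name__ == "__main__":
    check_mvg_layout()
```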

#### VITON-HD

1. Download the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
2. Download the pre-warped cloth images/masks via [Baidu Cloud](https://pan.baidu.com/s/1uQM0IOltOmbeqwdOKX5kCw?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/18DTWfhxUnfg41nnwwpCKN--akC4eT9DM/view?usp=sharing) and put them under the VITON-HD dataset.

After this, the folder structure should look like the following (the unpaired-cloth* folders are only included in the test directory):

```
├── VITON-HD
|   ├── test_pairs.txt
|   ├── train_pairs.txt
│   ├── [train | test]
|   |   ├── image
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth-mask
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth-warp
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth-warp-mask
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── unpaired-cloth-warp
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── unpaired-cloth-warp-mask
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
```

### Inference

#### MVG

To test in the paired setting (using `cp_dataset_mv_paired.py`), either modify `configs/viton512.yaml` and `main.py`, or simply rename `cp_dataset_mv_paired.py` to `cp_dataset.py` (recommended). Then run:

```shell
sh test.sh
```

To test in the unpaired setting, rename `cp_dataset_mv_unpaired.py` to `cp_dataset.py` instead and repeat the same steps.

#### VITON-HD

To test in the paired setting, run `cd Frontal-View\ VTON/`, then:

```shell
sh test.sh
```

To test in the unpaired setting, run `cd Frontal-View\ VTON/`, add `--unpaired` to `test.sh`, and then run:

```shell
sh test.sh
```

#### Metrics

We compute `LPIPS`, `SSIM`, `FID`, and `KID` using the same tools as [LaDI-VTON](https://github.com/miccunifi/ladi-vton).
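As a hedged illustration of the paired half of that evaluation, SSIM for one generated/ground-truth pair can be computed with the `scikit-image` version pinned in `environment.yaml`. The reported numbers come from the LaDI-VTON tooling, so treat this only as a sanity-check sketch with hypothetical file paths:

```python
# Sketch of paired SSIM for one image pair (illustrative; reported metrics use LaDI-VTON's tools).
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity

def ssim_pair(gen_path: str, gt_path: str) -> float:
    gen = np.asarray(Image.open(gen_path).convert("RGB"), dtype=np.float32) / 255.0
    gt = np.asarray(Image.open(gt_path).convert("RGB"), dtype=np.float32) / 255.0
    # channel_axis=2 treats the last dimension as RGB channels; inputs must share a resolution.
    return structural_similarity(gen, gt, channel_axis=2, data_range=1.0)

if __name__ == "__main__":
    print(ssim_pair("results/000006_00.jpg", "VITON-HD/test/image/000006_00.jpg"))
```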

### Training

#### MVG

We use Paint-by-Example for initialization: download the pretrained model from [Google Drive](https://drive.google.com/file/d/15QzaTWsvZonJcXsNv-ilMRCYaQLhzR_i/view) and save it to the `checkpoints` directory. Rename `cp_dataset_mv_paired.py` to `cp_dataset.py`, then run:

```shell
sh train.sh
```

#### VITON-HD

Run `cd Frontal-View\ VTON/`, then:

```shell
sh train.sh
```

## Acknowledgements

Our code is heavily borrowed from [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example) and [DCI-VTON](https://github.com/bcmi/DCI-VTON-Virtual-Try-On). We also thank the previous works [PF-AFN](https://github.com/geyuying/PF-AFN), [GP-VTON](https://github.com/xiezhy6/GP-VTON), [LaDI-VTON](https://github.com/miccunifi/ladi-vton), and [StableVITON](https://github.com/rlawjdghek/StableVITON).

## LICENSE

MV-VTON: Multi-View Virtual Try-On with Diffusion Models © 2024 by Haoyu Wang, Zhilu Zhang, Donglin Di, Shiliang Zhang, Wangmeng Zuo is licensed under CC BY-NC-SA 4.0.

## Citation

```
@article{wang2024mv,
  title={MV-VTON: Multi-View Virtual Try-On with Diffusion Models},
  author={Wang, Haoyu and Zhang, Zhilu and Di, Donglin and Zhang, Shiliang and Zuo, Wangmeng},
  journal={arXiv preprint arXiv:2404.17364},
  year={2024}
}
```
=======
---
title: Mv Vton Demo
emoji: 👁
colorFrom: gray
colorTo: pink
sdk: gradio
sdk_version: 5.26.0
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
>>>>>>> 2a4541c57faf075fa9e813ae2777dfaa55fc0306
environment.yaml
ADDED
@@ -0,0 +1,197 @@
name: mv-vton
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py38h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2023.11.17=py38h06a4308_0
  - cffi=1.15.1=py38h74dc2b5_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=41.0.3=py38h130f0dd_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py38h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libgfortran-ng=11.2.0=h00389a5_1
  - libgfortran5=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.44.2=h5eee18b_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - lz4-c=1.9.4=h6a678d5_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py38h7f8727e_0
  - mkl_fft=1.3.1=py38hd3c417c_0
  - mkl_random=1.2.2=py38h51133e4_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=1.1.1w=h7f8727e_0
  - pillow=10.0.1=py38ha6cbd5a_0
  - pip=20.3.3=py38h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.2.0=py38h06a4308_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.5=h7579374_1
  - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
  - pytorch-mutex=1.0=cuda
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py38h06a4308_0
  - setuptools=68.0.0=py38h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - torchvision=0.12.0=py38_cu113
  - typing_extensions=4.7.1=py38h06a4308_0
  - urllib3=1.26.18=py38h06a4308_0
  - wheel=0.41.2=py38h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
    - absl-py==2.0.0
    - aiohttp==3.9.1
    - aiosignal==1.3.1
    - albumentations==0.4.3
    - altair==5.2.0
    - antlr4-python3-runtime==4.9.3
    - async-timeout==4.0.3
    - attrs==23.1.0
    - av==12.0.0
    - backports-zoneinfo==0.2.1
    - bezier==2023.7.28
    - black==24.2.0
    - blinker==1.7.0
    - cachetools==5.3.2
    - click==8.1.7
    - clip==0.2.0
    - cloudpickle==3.0.0
    - contourpy==1.1.1
    - cupy==12.3.0
    - cycler==0.12.1
    - diffusers==0.20.0
    - einops==0.3.0
    - fastrlock==0.8.2
    - filelock==3.13.1
    - fonttools==4.45.1
    - frozenlist==1.4.0
    - fsspec==2023.10.0
    - future==0.18.3
    - fvcore==0.1.5.post20221221
    - gitdb==4.0.11
    - gitpython==3.1.40
    - google-auth==2.23.4
    - google-auth-oauthlib==1.0.0
    - grpcio==1.59.3
    - huggingface-hub==0.19.4
    - hydra-core==1.3.2
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - imgaug==0.2.6
    - importlib-metadata==6.8.0
    - importlib-resources==6.1.1
    - invisible-watermark==0.2.0
    - iopath==0.1.9
    - jinja2==3.1.2
    - jsonschema==4.20.0
    - jsonschema-specifications==2023.11.1
    - kiwisolver==1.4.5
    - kornia==0.6.0
    - lazy-loader==0.3
    - markdown==3.5.1
    - markdown-it-py==3.0.0
    - markupsafe==2.1.3
    - matplotlib==3.7.4
    - mdurl==0.1.2
    - multidict==6.0.4
    - mypy-extensions==1.0.0
    - networkx==3.1
    - numpy==1.24.4
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - opencv-python==4.1.2.30
    - opencv-python-headless==4.8.1.78
    - packaging==23.2
    - pandas==2.0.3
    - pathspec==0.12.1
    - pkgutil-resolve-name==1.3.10
    - platformdirs==4.2.0
    - portalocker==2.8.2
    - protobuf==4.25.1
    - pudb==2019.2
    - pyarrow==14.0.1
    - pyasn1==0.5.1
    - pyasn1-modules==0.3.0
    - pycocotools==2.0.7
    - pydeck==0.8.1b0
    - pydeprecate==0.3.1
    - pygments==2.17.2
    - pyparsing==3.1.1
    - python-dateutil==2.8.2
    - pytorch-lightning==1.4.2
    - pytz==2023.3.post1
    - pywavelets==1.4.1
    - pyyaml==6.0.1
    - referencing==0.31.1
    - regex==2023.10.3
    - requests-oauthlib==1.3.1
    - rich==13.7.0
    - rpds-py==0.13.2
    - rsa==4.9
    - safetensors==0.4.1
    - scikit-image==0.20.0
    - scipy==1.9.1
    - smmap==5.0.1
    - streamlit==1.28.2
    - tabulate==0.9.0
    - taming-transformers==0.0.1
    - tenacity==8.2.3
    - tensorboard==2.14.0
    - tensorboard-data-server==0.7.2
    - termcolor==2.4.0
    - test-tube==0.7.5
    - tifffile==2023.7.10
    - tokenizers==0.12.1
    - toml==0.10.2
    - tomli==2.0.1
    - toolz==0.12.0
    - torch-fidelity==0.3.0
    - torchmetrics==0.6.0
    - tornado==6.4
    - tqdm==4.66.1
    - transformers==4.27.3
    - tzdata==2023.3
    - tzlocal==5.2
    - urwid==2.2.3
    - validators==0.22.0
    - watchdog==3.0.0
    - werkzeug==3.0.1
    - yacs==0.1.8
    - yarl==1.9.3
    - zipp==3.17.0
prefix: /mnt/pfs-mc0p4k/cvg/team/didonglin/conda_envs/mv-vton
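With the environment created (`conda env create -f environment.yaml`), a quick check that the pinned PyTorch/CUDA stack resolved as expected might look like the following sketch; the expected version strings are read off the pins above:

```python
# Quick sanity check for the pinned environment (illustrative sketch).
import torch
import torchvision
import pytorch_lightning

print(torch.__version__)              # expected to start with 1.11.0
print(torchvision.__version__)        # expected to start with 0.12.0
print(pytorch_lightning.__version__)  # expected 1.4.2
print(torch.version.cuda)             # expected 11.3
print(torch.cuda.is_available())      # True on a machine with a compatible GPU/driver
```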
main.py
ADDED
@@ -0,0 +1,738 @@
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl

sys.setrecursionlimit(10000)
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
import socket
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=["configs/stable-diffusion/v1-inference-inpaint.yaml"],
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default="",
        help="path to pretrained model",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    parser.add_argument(
        "--train_from_scratch",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Train from scratch",
    )
    return parser
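The `str2bool` helper combined with `nargs="?"` and `const=True` lets each boolean flag be passed bare (`-t`), with a value (`-t false`), or omitted entirely. A self-contained sketch of the same pattern, using a hypothetical `--demo` flag rather than anything from this file:

```python
# Standalone reproduction of the str2bool + nargs="?" + const=True pattern (illustrative).
import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--demo", type=str2bool, nargs="?", const=True, default=False)

print(parser.parse_args([]).demo)                 # False -> default used when flag absent
print(parser.parse_args(["--demo"]).demo)         # True  -> const used for a bare flag
print(parser.parse_args(["--demo", "no"]).demo)   # False -> str2bool parses the value
```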
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def nondefault_trainer_args(opt):
|
| 144 |
+
parser = argparse.ArgumentParser()
|
| 145 |
+
parser = Trainer.add_argparse_args(parser)
|
| 146 |
+
args = parser.parse_args([])
|
| 147 |
+
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class WrappedDataset(Dataset):
|
| 151 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
| 152 |
+
|
| 153 |
+
def __init__(self, dataset):
|
| 154 |
+
self.data = dataset
|
| 155 |
+
|
| 156 |
+
def __len__(self):
|
| 157 |
+
return len(self.data)
|
| 158 |
+
|
| 159 |
+
def __getitem__(self, idx):
|
| 160 |
+
return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)
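

# Note: without the per-worker reseeding above, every DataLoader worker would
# inherit an identical NumPy RNG state from the parent process, so any
# np.random-based augmentation would produce the same values in all workers.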


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=not is_iterable_dataset,
                          worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        # check the test split itself (the original checked 'train' here)
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)
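

# A minimal sketch of the `data` section this module expects in the YAML
# config (field values here are assumptions, not copied from this repo):
#
#   data:
#     target: main.DataModuleFromConfig
#     params:
#       batch_size: 4
#       num_workers: 8
#       wrap: false
#       train:
#         target: ldm.data.cp_dataset.CPDataset
#         params:
#           dataroot: /path/to/dataset
#           mode: train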


class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            if hasattr(self.config, 'lora_config'):
                ckpt_path = os.path.join(self.ckptdir, "lora_last.ckpt")
                from lora.lora import save_lora_weight
                save_lora_weight(trainer.model, path=ckpt_path)
            else:
                ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
            return True
        return False
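
    # With increase_log_steps=True, log_steps is [1, 2, 4, ..., batch_frequency],
    # so images are logged at powers of two early in training and every
    # batch_frequency steps thereafter, as the early steps are popped off.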

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f} MiB")
        except AttributeError:
            pass


if __name__ == "__main__":

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint."
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    # init and save configs
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    lightning_config = config.pop("lightning", OmegaConf.create())
    # merge trainer cli with config
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    # default to ddp
    trainer_config["accelerator"] = "ddp"
    for k in nondefault_trainer_args(opt):
        trainer_config[k] = getattr(opt, k)
    if "gpus" not in trainer_config:
        del trainer_config["accelerator"]
        cpu = True
    else:
        gpuinfo = trainer_config["gpus"]
        print(f"Running on GPUs {gpuinfo}")
        cpu = False
    trainer_opt = argparse.Namespace(**trainer_config)
    lightning_config.trainer = trainer_config

    # model
    model = instantiate_from_config(config.model)
    if not opt.resume:
        if opt.train_from_scratch:
            ckpt_file = torch.load(opt.pretrained_model, map_location='cpu')['state_dict']
            ckpt_file = {key: value for key, value in ckpt_file.items() if not key.startswith('model.')}
            model.load_state_dict(ckpt_file, strict=False)
            print("Train from scratch!")
        else:
            model.load_state_dict(torch.load(opt.pretrained_model, map_location='cpu')['state_dict'], strict=False)
            print("Load Stable Diffusion v1-4!")

    # lora
    if hasattr(config, 'lora_config'):
        model.eval()
        model._requires_grad = False
        from lora.lora import inject_trainable_lora_extended

        params, names = inject_trainable_lora_extended(model, r=config.lora_config.rank)

    model.requires_grad_(False)
    for name, param in model.named_parameters():
        if "diffusion_model.output_blocks" in name and "transformer_blocks" in name:
            param.requires_grad = True
        if "local_controlnet" in name or "pose" in name:
            param.requires_grad = True
    # Open a file to record the trainable parameter names
    with open("module_names.txt", "w") as file:
        # iterate over all model parameters and write the trainable ones to the file
        for name, param in model.named_parameters():
            if param.requires_grad:
                file.write(name + "\n")

    # trainer and callbacks
    trainer_kwargs = dict()

    # default logger configs
    default_logger_cfgs = {
        "wandb": {
            "target": "pytorch_lightning.loggers.WandbLogger",
            "params": {
                "name": nowname,
                "save_dir": logdir,
                "offline": opt.debug,
                "id": nowname,
            }
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.TestTubeLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    default_logger_cfg = default_logger_cfgs["testtube"]
    if "logger" in lightning_config:
        logger_cfg = lightning_config.logger
    else:
        logger_cfg = OmegaConf.create()
    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
    # specify which metric is used to determine best models
    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": False,
            "every_n_epochs": 1
        }
    }
    if hasattr(model, "monitor"):
        print(f"Monitoring {model.monitor} as checkpoint metric.")
        default_modelckpt_cfg["params"]["monitor"] = model.monitor
        default_modelckpt_cfg["params"]["save_top_k"] = 30

    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
    print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
    if version.parse(pl.__version__) < version.parse('1.4.0'):
        trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

    # add callback which sets up log directory
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "main.SetupCallback",
            "params": {
                "resume": opt.resume,
                "now": now,
                "logdir": logdir,
                "ckptdir": ckptdir,
                "cfgdir": cfgdir,
                "config": config,
                "lightning_config": lightning_config,
            }
        },
        "image_logger": {
            "target": "main.ImageLogger",
            "params": {
                "batch_frequency": 500,
                "max_images": 4,
                "clamp": True
            }
        },
        "learning_rate_logger": {
            "target": "main.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                # "log_momentum": True
            }
        },
        "cuda_callback": {
            "target": "main.CUDACallback"
        },
    }
    if version.parse(pl.__version__) >= version.parse('1.4.0'):
        default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()

    if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
        print(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint':
                {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                 'params': {
                     "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                     "filename": "{epoch:06}-{step:09}",
                     "verbose": True,
                     'save_top_k': -1,
                     'every_n_train_steps': 10000,
                     'save_weights_only': True
                 }
                 }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
    if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
        callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
    elif 'ignore_keys_callback' in callbacks_cfg:
        del callbacks_cfg['ignore_keys_callback']

    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

    trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
    # trainer.plugins = [MyCluster()]
    trainer.logdir = logdir  ###

    # data
    data = instantiate_from_config(config.data)
    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # calling these ourselves should not be necessary but it is.
    # lightning still takes care of proper multiprocessing though
    data.prepare_data()
    data.setup()
    print("#### Data #####")
    for k in data.datasets:
        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

    # configure learning rate
    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
    if not cpu:
        ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
    else:
        ngpu = 1
    if 'accumulate_grad_batches' in lightning_config.trainer:
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
    else:
        accumulate_grad_batches = 1
    # if 'num_nodes' in lightning_config.trainer:
    #     num_nodes = lightning_config.trainer.num_nodes
    # else:
    num_nodes = 1
    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
    if opt.scale_lr:
        model.learning_rate = accumulate_grad_batches * num_nodes * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_nodes) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, num_nodes, ngpu, bs, base_lr))
    else:
        model.learning_rate = base_lr
        print("++++ NOT USING LR SCALING ++++")
        print(f"Setting learning rate to {model.learning_rate:.2e}")


    # allow checkpointing via USR1
    def melk(*args, **kwargs):
        # run all checkpoint hooks
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)


    def divein(*args, **kwargs):
        if trainer.global_rank == 0:
            import pudb
            pudb.set_trace()


    import signal

    signal.signal(signal.SIGUSR1, melk)
    signal.signal(signal.SIGUSR2, divein)
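
    # e.g. `kill -USR1 <pid>` saves <logdir>/checkpoints/last.ckpt mid-run;
    # `kill -USR2 <pid>` drops rank 0 into the pudb debugger.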

    # run
    if opt.train:
        try:
            trainer.fit(model, data)
        except Exception:
            melk()
            raise
    if not opt.no_test and not trainer.interrupted:
        trainer.test(model, data)

test.py
ADDED
@@ -0,0 +1,447 @@
import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision

from ldm.data.cp_dataset import CPDataset
from ldm.resizer import Resizer
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.data.deepfashions import DFPairDataset

import clip
from torchvision.transforms import Resize


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
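

# Example: list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]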


def get_tensor_clip(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]

    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                            (0.26862954, 0.26130258, 0.27577711))]
    return torchvision.transforms.Compose(transform_list)


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def put_watermark(img, wm_encoder=None):
    if wm_encoder is not None:
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = wm_encoder.encode(img, 'dwtDct')
        img = Image.fromarray(img[:, :, ::-1])
    return img


def load_replacement(x):
    try:
        hwc = x.shape
        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
        y = (np.array(y) / 255.0).astype(x.dtype)
        assert y.shape == x.shape
        return y
    except Exception:
        return x


def get_tensor(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]

    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))]
    return torchvision.transforms.Compose(transform_list)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="which gpu to use",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=30,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_imgs",
        type=int,
        default=100,
        help="number of images to test",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given reference image. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    parser.add_argument(
        "--unpaired",
        action='store_true',
        help="if enabled, evaluates on unpaired (person, garment) test pairs",
    )
    parser.add_argument(
        "--dataroot",
        type=str,
        help="path to dataroot of the dataset",
        default=""
    )

    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device("cuda:{}".format(opt.gpu_id)) if torch.cuda.is_available() else torch.device("cpu")
    torch.cuda.set_device(device)

    config = OmegaConf.load(f"{opt.config}")
    version = opt.config.split('/')[-1].split('.')[0]
    model = load_model_from_config(config, f"{opt.ckpt}")

    # model = model.to(device)
    dataset = CPDataset(opt.dataroot, opt.H, mode='test', unpaired=opt.unpaired)
    loader = DataLoader(dataset, batch_size=opt.n_samples, shuffle=False, num_workers=4, pin_memory=True)
    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    result_path = os.path.join(outpath, "upper_body")
    os.makedirs(result_path, exist_ok=True)

    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    iterator = tqdm(loader, desc='Test Dataset', total=len(loader))
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for data in iterator:
                    mask_tensor = data['inpaint_mask']
                    inpaint_image = data['inpaint_image']
                    ref_tensor_f = data['ref_imgs_f']
                    ref_tensor_b = data['ref_imgs_b']
                    skeleton_cf = data['skeleton_cf']
                    skeleton_cb = data['skeleton_cb']
                    skeleton_p = data['skeleton_p']
                    order = data['order']
                    feat_tensor = data['warp_feat']
                    image_tensor = data['GT']

                    controlnet_cond_f = data['controlnet_cond_f']
                    controlnet_cond_b = data['controlnet_cond_b']

                    ref_tensor = ref_tensor_f
                    for i in range(len(order)):
                        if order[i] == "1" or order[i] == "2":
                            continue
                        elif order[i] == "3":
                            ref_tensor[i] = ref_tensor_b[i]
                        else:
                            raise ValueError("Invalid order")

                    # filename = data['file_name']

                    test_model_kwargs = {}
                    test_model_kwargs['inpaint_mask'] = mask_tensor.to(device)
                    test_model_kwargs['inpaint_image'] = inpaint_image.to(device)
                    feat_tensor = feat_tensor.to(device)
                    ref_tensor = ref_tensor.to(device)

                    controlnet_cond_f = controlnet_cond_f.to(device)
                    controlnet_cond_b = controlnet_cond_b.to(device)
                    skeleton_cf = skeleton_cf.to(device)
                    skeleton_cb = skeleton_cb.to(device)
                    skeleton_p = skeleton_p.to(device)

                    uc = None
                    if opt.scale != 1.0:
                        uc = model.learnable_vector
                        uc = uc.repeat(ref_tensor.size(0), 1, 1)
                    c = model.get_learned_conditioning(ref_tensor.to(torch.float16))
                    c = model.proj_out(c)

                    # z_gt = model.encode_first_stage(image_tensor.to(device))
                    # z_gt = model.get_first_stage_encoding(z_gt).detach()

                    z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image'])
                    z_inpaint = model.get_first_stage_encoding(z_inpaint).detach()
                    test_model_kwargs['inpaint_image'] = z_inpaint
                    test_model_kwargs['inpaint_mask'] = Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(
                        test_model_kwargs['inpaint_mask'])

                    warp_feat = model.encode_first_stage(feat_tensor)
                    warp_feat = model.get_first_stage_encoding(warp_feat).detach()

                    ts = torch.full((1,), 999, device=device, dtype=torch.long)
                    start_code = model.q_sample(warp_feat, ts)
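
                    # Rather than pure Gaussian noise, sampling starts from the
                    # warped-garment latent diffused to t=999, anchoring the
                    # reverse process to the warped clothing layout.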

                    # local_controlnet
                    ehs_cf = model.pose_model(skeleton_cf)
                    ehs_cb = model.pose_model(skeleton_cb)
                    ehs_p = model.pose_model(skeleton_p)
                    ehs_text = torch.zeros((c.shape[0], 1, 768)).to("cuda")
                    # controlnet_cond = torch.cat((controlnet_cond_f, controlnet_cond_b, ehs_cf, ehs_cb, ehs_p), dim=1)
                    x_noisy = torch.cat(
                        (start_code, test_model_kwargs['inpaint_image'], test_model_kwargs['inpaint_mask']), dim=1)

                    down_samples_f, mid_samples_f = model.local_controlnet(
                        x_noisy, ts, encoder_hidden_states=ehs_text.to("cuda"),
                        controlnet_cond=controlnet_cond_f, ehs_c=ehs_cf, ehs_p=ehs_p)
                    down_samples_b, mid_samples_b = model.local_controlnet(
                        x_noisy, ts, encoder_hidden_states=ehs_text.to("cuda"),
                        controlnet_cond=controlnet_cond_b, ehs_c=ehs_cb, ehs_p=ehs_p)

                    # Debugging snippet for inspecting the ControlNet features:
                    # print(torch.max(down_samples_f[0]))
                    # print(torch.min(down_samples_f[0]))
                    # normalized_tensor = (down_samples_f[0] + 1) / 2
                    # # scale tensor values from [0, 1] to [0, 255]
                    # scaled_tensor = normalized_tensor * 255
                    # # convert the tensor to a NumPy array
                    # numpy_array = scaled_tensor.squeeze().cpu().numpy().astype(np.uint8)
                    # # convert the NumPy array to a PIL image
                    # image = Image.fromarray(numpy_array)
                    # # save the image
                    # image.save("down_samples_f.jpg")
                    # (and likewise for down_samples_b[0] -> "down_samples_b.jpg")

                    mid_samples = mid_samples_f + mid_samples_b
                    down_samples = ()
                    for ds in range(len(down_samples_f)):
                        tmp = torch.cat((down_samples_f[ds], down_samples_b[ds]), dim=1)
                        down_samples = down_samples + (tmp,)

                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                     conditioning=c,
                                                     batch_size=opt.n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
                                                     unconditional_conditioning=uc,
                                                     eta=opt.ddim_eta,
                                                     x_T=start_code,
                                                     down_samples=down_samples,
                                                     test_model_kwargs=test_model_kwargs)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_sample_result = x_samples_ddim
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image = x_samples_ddim
                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
                    x_source = torch.clamp((image_tensor + 1.0) / 2.0, min=0.0, max=1.0)
                    x_result = x_checked_image_torch * (1 - mask_tensor) + mask_tensor * x_source
                    # x_result = x_checked_image_torch

                    resize = transforms.Resize((opt.H, int(opt.H / 256 * 192)))

                    if not opt.skip_save:

                        def un_norm(x):
                            return (x + 1.0) / 2.0

                        for i, x_sample in enumerate(x_result):
                            filename = data['file_name'][i]
                            # filename = data['file_name']
                            save_x = resize(x_sample)
                            save_x = 255. * rearrange(save_x.cpu().numpy(), 'c h w -> h w c')
                            img = Image.fromarray(save_x.astype(np.uint8))
                            img.save(os.path.join(result_path, filename[:-4] + ".png"))

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


if __name__ == "__main__":
    main()

test.sh
ADDED
@@ -0,0 +1,13 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=3 python test.py --gpu_id 0 \
    --ddim_steps 50 \
    --outdir results/try/ \
    --config configs/viton512.yaml \
    --dataroot /datasets/NVG \
    --ckpt checkpoints/mvg.ckpt \
    --n_samples 1 \
    --seed 23 \
    --scale 1 \
    --H 512 \
    --W 384
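
# Note: --H 512 --W 384 matches the aspect of test.py's final resize,
# since int(512 / 256 * 192) == 384.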
train.sh
ADDED
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=4,5 python -u main.py --logdir models/oc --pretrained_model checkpoints/model.ckpt --base configs/viton512.yaml --scale_lr False
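# With --scale_lr False, the base learning rate from configs/viton512.yaml is
# used as-is (no accumulate * num_nodes * ngpu * batch_size scaling in main.py).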