diff --git a/modules/__pycache__/call_queue.cpython-310.pyc b/modules/__pycache__/call_queue.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b2ccc988fc8e87652c577a2bdebbb4da9e5353c
Binary files /dev/null and b/modules/__pycache__/call_queue.cpython-310.pyc differ
diff --git a/modules/__pycache__/codeformer_model.cpython-310.pyc b/modules/__pycache__/codeformer_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac51265399ac38b4e7dceb7b688243f983deeee
Binary files /dev/null and b/modules/__pycache__/codeformer_model.cpython-310.pyc differ
diff --git a/modules/__pycache__/deepbooru.cpython-310.pyc b/modules/__pycache__/deepbooru.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..670e0e10f53f870973d8a9214b85dbb2f093fac1
Binary files /dev/null and b/modules/__pycache__/deepbooru.cpython-310.pyc differ
diff --git a/modules/__pycache__/deepbooru_model.cpython-310.pyc b/modules/__pycache__/deepbooru_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb0657513026a633f8b76409d42dd4b6df98eb9e
Binary files /dev/null and b/modules/__pycache__/deepbooru_model.cpython-310.pyc differ
diff --git a/modules/__pycache__/devices.cpython-310.pyc b/modules/__pycache__/devices.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cae3334a22d6687741148f657fca9c99d1b5dbb
Binary files /dev/null and b/modules/__pycache__/devices.cpython-310.pyc differ
diff --git a/modules/__pycache__/errors.cpython-310.pyc b/modules/__pycache__/errors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbeaed671098ab280fbcdabac5621aec8b116695
Binary files /dev/null and b/modules/__pycache__/errors.cpython-310.pyc differ
diff --git a/modules/__pycache__/errors.cpython-311.pyc b/modules/__pycache__/errors.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96ce9c6bd5754c18355c4b0c033239ad923ca1fd
Binary files /dev/null and b/modules/__pycache__/errors.cpython-311.pyc differ
diff --git a/modules/__pycache__/esrgan_model.cpython-310.pyc b/modules/__pycache__/esrgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3dee8362487293e588cc65dd90bd0cf05c8c50b
Binary files /dev/null and b/modules/__pycache__/esrgan_model.cpython-310.pyc differ
diff --git a/modules/__pycache__/esrgan_model_arch.cpython-310.pyc b/modules/__pycache__/esrgan_model_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a0cb4869fedf942afd501571b7affa75852c7ad
Binary files /dev/null and b/modules/__pycache__/esrgan_model_arch.cpython-310.pyc differ
diff --git a/modules/__pycache__/extensions.cpython-310.pyc b/modules/__pycache__/extensions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4948a28b11dec190b0cdce252d84988bf39cb1a
Binary files /dev/null and b/modules/__pycache__/extensions.cpython-310.pyc differ
diff --git a/modules/__pycache__/extra_networks.cpython-310.pyc b/modules/__pycache__/extra_networks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e146bb16667d2260bbef6d5b9b5656ee69a63456
Binary files /dev/null and b/modules/__pycache__/extra_networks.cpython-310.pyc differ
diff --git a/modules/__pycache__/extra_networks_hypernet.cpython-310.pyc b/modules/__pycache__/extra_networks_hypernet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2f96dd1cf4fd2667a9bc8be4af8a096c170cd4f
Binary files /dev/null and b/modules/__pycache__/extra_networks_hypernet.cpython-310.pyc differ
diff --git a/modules/__pycache__/extras.cpython-310.pyc b/modules/__pycache__/extras.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24cbd1726ffd567da1b59c4b5800a4b51b4f5e08
Binary files /dev/null and b/modules/__pycache__/extras.cpython-310.pyc differ
diff --git a/modules/__pycache__/face_restoration.cpython-310.pyc b/modules/__pycache__/face_restoration.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87372fed55828f1b096bee886b74534ddec4b2d2
Binary files /dev/null and b/modules/__pycache__/face_restoration.cpython-310.pyc differ
diff --git a/modules/__pycache__/generation_parameters_copypaste.cpython-310.pyc b/modules/__pycache__/generation_parameters_copypaste.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb1a5f08420059311a6012741cc8de87b73a807c
Binary files /dev/null and b/modules/__pycache__/generation_parameters_copypaste.cpython-310.pyc differ
diff --git a/modules/__pycache__/gfpgan_model.cpython-310.pyc b/modules/__pycache__/gfpgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83f5f3035e60bee40e8118a10a5fbdd087caadb2
Binary files /dev/null and b/modules/__pycache__/gfpgan_model.cpython-310.pyc differ
diff --git a/modules/__pycache__/hashes.cpython-310.pyc b/modules/__pycache__/hashes.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..541001d19c1a2bc159cf134c4f76c40744e19698
Binary files /dev/null and b/modules/__pycache__/hashes.cpython-310.pyc differ
diff --git a/modules/__pycache__/images.cpython-310.pyc b/modules/__pycache__/images.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52f815cef182a732a1a9a0aeb7d74aa061f02ead
Binary files /dev/null and b/modules/__pycache__/images.cpython-310.pyc differ
diff --git a/modules/__pycache__/img2img.cpython-310.pyc b/modules/__pycache__/img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09ec4b64eb4f5e76af04183acb3e88ce1b0f24f5
Binary files /dev/null and b/modules/__pycache__/img2img.cpython-310.pyc differ
diff --git a/modules/__pycache__/import_hook.cpython-310.pyc b/modules/__pycache__/import_hook.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0b86adffa0cd3fcbb04b64bc7deb7f6feb036d0
Binary files /dev/null and b/modules/__pycache__/import_hook.cpython-310.pyc differ
diff --git a/modules/__pycache__/interrogate.cpython-310.pyc b/modules/__pycache__/interrogate.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0d6078f96e53322ccd95d1f4b246086327d160e
Binary files /dev/null and b/modules/__pycache__/interrogate.cpython-310.pyc differ
diff --git a/modules/__pycache__/localization.cpython-310.pyc b/modules/__pycache__/localization.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..929442984544c2ead95002b499eb5e83870a1090
Binary files /dev/null and b/modules/__pycache__/localization.cpython-310.pyc differ
diff --git a/modules/__pycache__/lowvram.cpython-310.pyc b/modules/__pycache__/lowvram.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04418a8f567464f17e8c8165bf1c6d92ad383a8d
Binary files /dev/null and b/modules/__pycache__/lowvram.cpython-310.pyc differ
diff --git a/modules/__pycache__/masking.cpython-310.pyc b/modules/__pycache__/masking.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c1b1511be360b5fce2678400f7530351c98b55d
Binary files /dev/null and b/modules/__pycache__/masking.cpython-310.pyc differ
diff --git a/modules/__pycache__/memmon.cpython-310.pyc b/modules/__pycache__/memmon.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f5222f91605cb4abfeb2b0176b1936d8850f9ef
Binary files /dev/null and b/modules/__pycache__/memmon.cpython-310.pyc differ
diff --git a/modules/__pycache__/modelloader.cpython-310.pyc b/modules/__pycache__/modelloader.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3789998bdc67347d200c6a74ee43e61b02a5da41
Binary files /dev/null and b/modules/__pycache__/modelloader.cpython-310.pyc differ
diff --git a/modules/__pycache__/paths.cpython-310.pyc b/modules/__pycache__/paths.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75299b95b6e2ef8850707934298a38665db48310
Binary files /dev/null and b/modules/__pycache__/paths.cpython-310.pyc differ
diff --git a/modules/__pycache__/postprocessing.cpython-310.pyc b/modules/__pycache__/postprocessing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29925fad8e0cab134b8153941b69688582fccbc0
Binary files /dev/null and b/modules/__pycache__/postprocessing.cpython-310.pyc differ
diff --git a/modules/__pycache__/processing.cpython-310.pyc b/modules/__pycache__/processing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07fc7e0eefc5ceefca40d304b03e9f971e631a4b
Binary files /dev/null and b/modules/__pycache__/processing.cpython-310.pyc differ
diff --git a/modules/__pycache__/progress.cpython-310.pyc b/modules/__pycache__/progress.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a16bd37a32e0dbe92a4908440167872b9a0c38a3
Binary files /dev/null and b/modules/__pycache__/progress.cpython-310.pyc differ
diff --git a/modules/__pycache__/prompt_parser.cpython-310.pyc b/modules/__pycache__/prompt_parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e47bce72155967fff015c175f05b4a7d9f8afbc
Binary files /dev/null and b/modules/__pycache__/prompt_parser.cpython-310.pyc differ
diff --git a/modules/__pycache__/realesrgan_model.cpython-310.pyc b/modules/__pycache__/realesrgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3285eefc1c210486403de739dde1435b0871d370
Binary files /dev/null and b/modules/__pycache__/realesrgan_model.cpython-310.pyc differ
diff --git a/modules/__pycache__/safe.cpython-310.pyc b/modules/__pycache__/safe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83a61d5cb00746a9ad15413ee8036053d4d8806c
Binary files /dev/null and b/modules/__pycache__/safe.cpython-310.pyc differ
diff --git a/modules/__pycache__/script_callbacks.cpython-310.pyc b/modules/__pycache__/script_callbacks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2617cc7471d8a13266dbdfe54af443b6a50fd23b
Binary files /dev/null and b/modules/__pycache__/script_callbacks.cpython-310.pyc differ
diff --git a/modules/__pycache__/script_loading.cpython-310.pyc b/modules/__pycache__/script_loading.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7407b29e3d95f1ae2875a6e488440e5273524ea1
Binary files /dev/null and b/modules/__pycache__/script_loading.cpython-310.pyc differ
diff --git a/modules/__pycache__/scripts.cpython-310.pyc b/modules/__pycache__/scripts.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b0c951846470305b4822a29ad4f798b69671db8
Binary files /dev/null and b/modules/__pycache__/scripts.cpython-310.pyc differ
diff --git a/modules/__pycache__/scripts_auto_postprocessing.cpython-310.pyc b/modules/__pycache__/scripts_auto_postprocessing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61d267908149cfeab9a3c0d47b62c8ffadb3b841
Binary files /dev/null and b/modules/__pycache__/scripts_auto_postprocessing.cpython-310.pyc differ
diff --git a/modules/__pycache__/scripts_postprocessing.cpython-310.pyc b/modules/__pycache__/scripts_postprocessing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6d4c3522cecfefa90c59957d13c4855c6f9be33
Binary files /dev/null and b/modules/__pycache__/scripts_postprocessing.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_disable_initialization.cpython-310.pyc b/modules/__pycache__/sd_disable_initialization.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09fa0fe41d6bd513d6f4a34bf6d52e93a881bb60
Binary files /dev/null and b/modules/__pycache__/sd_disable_initialization.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack.cpython-310.pyc b/modules/__pycache__/sd_hijack.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9527a6041b7198ab885fef3520a1a38c9911d64c
Binary files /dev/null and b/modules/__pycache__/sd_hijack.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_checkpoint.cpython-310.pyc b/modules/__pycache__/sd_hijack_checkpoint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a4337801a7450f5e5eee8e292f951bd19b2995a
Binary files /dev/null and b/modules/__pycache__/sd_hijack_checkpoint.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_clip.cpython-310.pyc b/modules/__pycache__/sd_hijack_clip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07cf240ed22e89f8add0e4baa37d2fe953601c96
Binary files /dev/null and b/modules/__pycache__/sd_hijack_clip.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_inpainting.cpython-310.pyc b/modules/__pycache__/sd_hijack_inpainting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57e5e57b55d793174a8f41f1abcbe15cd8929104
Binary files /dev/null and b/modules/__pycache__/sd_hijack_inpainting.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_open_clip.cpython-310.pyc b/modules/__pycache__/sd_hijack_open_clip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01a4d0999f9e64cbef003ecbd7743b975f906b66
Binary files /dev/null and b/modules/__pycache__/sd_hijack_open_clip.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_optimizations.cpython-310.pyc b/modules/__pycache__/sd_hijack_optimizations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b3778fd8f2e69660c06631ca20bf327cef0d4ec
Binary files /dev/null and b/modules/__pycache__/sd_hijack_optimizations.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_unet.cpython-310.pyc b/modules/__pycache__/sd_hijack_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecdde8bf87a6cf4a54992c2f271b5e16a460285b
Binary files /dev/null and b/modules/__pycache__/sd_hijack_unet.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_utils.cpython-310.pyc b/modules/__pycache__/sd_hijack_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dfa34e3f884602cb55258c8a07258d81ebbb615
Binary files /dev/null and b/modules/__pycache__/sd_hijack_utils.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_hijack_xlmr.cpython-310.pyc b/modules/__pycache__/sd_hijack_xlmr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..926fa9a3e4e79c8468c32f6cbf0b2e29c6ed751c
Binary files /dev/null and b/modules/__pycache__/sd_hijack_xlmr.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_models.cpython-310.pyc b/modules/__pycache__/sd_models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6c1c6fbfc0c67552d4687776703531c75e5a5e4
Binary files /dev/null and b/modules/__pycache__/sd_models.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_models_config.cpython-310.pyc b/modules/__pycache__/sd_models_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3bc73f30aca04f42ca2bcfd4ded294ff019e026
Binary files /dev/null and b/modules/__pycache__/sd_models_config.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_samplers.cpython-310.pyc b/modules/__pycache__/sd_samplers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fe0eb1d24c20d639e1fdf5320e136b20b5351d9
Binary files /dev/null and b/modules/__pycache__/sd_samplers.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_samplers_common.cpython-310.pyc b/modules/__pycache__/sd_samplers_common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c70f557d47ffb4ec7b2dff1a727b35b64b44c2a
Binary files /dev/null and b/modules/__pycache__/sd_samplers_common.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_samplers_compvis.cpython-310.pyc b/modules/__pycache__/sd_samplers_compvis.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e412e80eb9337e89470495f4495ec78927c217db
Binary files /dev/null and b/modules/__pycache__/sd_samplers_compvis.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_samplers_kdiffusion.cpython-310.pyc b/modules/__pycache__/sd_samplers_kdiffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5167d47836f94dba2798b18dcd76e9a040bc73f8
Binary files /dev/null and b/modules/__pycache__/sd_samplers_kdiffusion.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_vae.cpython-310.pyc b/modules/__pycache__/sd_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90eee90824e2020230b76d67bec78b3aca25a148
Binary files /dev/null and b/modules/__pycache__/sd_vae.cpython-310.pyc differ
diff --git a/modules/__pycache__/sd_vae_approx.cpython-310.pyc b/modules/__pycache__/sd_vae_approx.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce5acce701d2012c0d53f4a57c43bd64795e18c2
Binary files /dev/null and b/modules/__pycache__/sd_vae_approx.cpython-310.pyc differ
diff --git a/modules/__pycache__/shared.cpython-310.pyc b/modules/__pycache__/shared.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69643de0a0aa71843a7d17e3ece6c342279cf2f2
Binary files /dev/null and b/modules/__pycache__/shared.cpython-310.pyc differ
diff --git a/modules/__pycache__/shared_items.cpython-310.pyc b/modules/__pycache__/shared_items.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4765da45eb0877aeb1f3a63da77237962bdaa73a
Binary files /dev/null and b/modules/__pycache__/shared_items.cpython-310.pyc differ
diff --git a/modules/__pycache__/styles.cpython-310.pyc b/modules/__pycache__/styles.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f0400b2bebe0d88554ee65300f1c8397e1ab6bc
Binary files /dev/null and b/modules/__pycache__/styles.cpython-310.pyc differ
diff --git a/modules/__pycache__/sub_quadratic_attention.cpython-310.pyc b/modules/__pycache__/sub_quadratic_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..853aba8ead147920fb972c16c533d7317a5cfc2a
Binary files /dev/null and b/modules/__pycache__/sub_quadratic_attention.cpython-310.pyc differ
diff --git a/modules/__pycache__/timer.cpython-310.pyc b/modules/__pycache__/timer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20b06bd1d475993f167955907d859218cc2a2d43
Binary files /dev/null and b/modules/__pycache__/timer.cpython-310.pyc differ
diff --git a/modules/__pycache__/txt2img.cpython-310.pyc b/modules/__pycache__/txt2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f180816bf8033178e766f4f229e4d47b5219508
Binary files /dev/null and b/modules/__pycache__/txt2img.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui.cpython-310.pyc b/modules/__pycache__/ui.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb13882f3ebb3036a286ebd0154f001cd11839fc
Binary files /dev/null and b/modules/__pycache__/ui.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_common.cpython-310.pyc b/modules/__pycache__/ui_common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68cebc39b404007e8879e9a51e48979ade90731e
Binary files /dev/null and b/modules/__pycache__/ui_common.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_components.cpython-310.pyc b/modules/__pycache__/ui_components.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6dae2435487e442a180b21ef7a98c13c6e804814
Binary files /dev/null and b/modules/__pycache__/ui_components.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_extensions.cpython-310.pyc b/modules/__pycache__/ui_extensions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37fa9cca577553c963b89b1710c8e256bfd01f02
Binary files /dev/null and b/modules/__pycache__/ui_extensions.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_extra_networks.cpython-310.pyc b/modules/__pycache__/ui_extra_networks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd097a20163f3fadfd9b9d6cc8dd04b5991d13d1
Binary files /dev/null and b/modules/__pycache__/ui_extra_networks.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_extra_networks_checkpoints.cpython-310.pyc b/modules/__pycache__/ui_extra_networks_checkpoints.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfe42048db4bd9ed7fbb80cbd3da8c0831da0c91
Binary files /dev/null and b/modules/__pycache__/ui_extra_networks_checkpoints.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_extra_networks_hypernets.cpython-310.pyc b/modules/__pycache__/ui_extra_networks_hypernets.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..faece134254e23ea820962fddf2b4cdcfb28fbb4
Binary files /dev/null and b/modules/__pycache__/ui_extra_networks_hypernets.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_extra_networks_textual_inversion.cpython-310.pyc b/modules/__pycache__/ui_extra_networks_textual_inversion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81c17121ef3f5be7b5bb04acc0702698cb8d205c
Binary files /dev/null and b/modules/__pycache__/ui_extra_networks_textual_inversion.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_postprocessing.cpython-310.pyc b/modules/__pycache__/ui_postprocessing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62eacb83f4654d2f3df06dcd7c2b55c45db03e62
Binary files /dev/null and b/modules/__pycache__/ui_postprocessing.cpython-310.pyc differ
diff --git a/modules/__pycache__/ui_tempdir.cpython-310.pyc b/modules/__pycache__/ui_tempdir.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..820a6a124b357b8a938ec052d8e3c649f3f4755d
Binary files /dev/null and b/modules/__pycache__/ui_tempdir.cpython-310.pyc differ
diff --git a/modules/__pycache__/upscaler.cpython-310.pyc b/modules/__pycache__/upscaler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d902911ab1364ca9e34388ed66b02cc474e593e
Binary files /dev/null and b/modules/__pycache__/upscaler.cpython-310.pyc differ
diff --git a/modules/__pycache__/xlmr.cpython-310.pyc b/modules/__pycache__/xlmr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5148f22aa4800ed9ed212ef46e418505763e30b5
Binary files /dev/null and b/modules/__pycache__/xlmr.cpython-310.pyc differ
diff --git a/modules/api/__pycache__/api.cpython-310.pyc b/modules/api/__pycache__/api.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33ad49a1205ac29e8e1a679615dc29d593fd2365
Binary files /dev/null and b/modules/api/__pycache__/api.cpython-310.pyc differ
diff --git a/modules/api/__pycache__/models.cpython-310.pyc b/modules/api/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..815b58fafd38f42a6c662c8d6ae8d82d3c706c64
Binary files /dev/null and b/modules/api/__pycache__/models.cpython-310.pyc differ
diff --git a/modules/api/api.py b/modules/api/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e17afc9dafec5ad369db3f76bcf5ddaf258acd
--- /dev/null
+++ b/modules/api/api.py
@@ -0,0 +1,615 @@
+import base64
+import io
+import time
+import datetime
+import uvicorn
+from threading import Lock
+from io import BytesIO
+from gradio.processing_utils import decode_base64_to_file
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from secrets import compare_digest
+
+import modules.shared as shared
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
+from modules.api.models import *
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
+from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
+from PIL import PngImagePlugin,Image
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
+from modules.realesrgan_model import get_realesrgan_models
+from modules import devices
+from typing import List
+import piexif
+import piexif.helper
+
+def upscaler_to_index(name: str):
+ try:
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+ except:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+
+def script_name_to_index(name, scripts):
+ try:
+ return [script.title().lower() for script in scripts].index(name.lower())
+ except:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+
+def validate_sampler_name(name):
+ config = sd_samplers.all_samplers_map.get(name, None)
+ if config is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+ return name
+
+def setUpscalers(req: dict):
+ reqDict = vars(req)
+ reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
+ reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
+ return reqDict
+
+def decode_base64_to_image(encoding):
+ if encoding.startswith("data:image/"):
+ encoding = encoding.split(";")[1].split(",")[1]
+ try:
+ image = Image.open(BytesIO(base64.b64decode(encoding)))
+ return image
+ except Exception as err:
+ raise HTTPException(status_code=500, detail="Invalid encoded image")
+
+def encode_pil_to_base64(image):
+ with io.BytesIO() as output_bytes:
+
+ if opts.samples_format.lower() == 'png':
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+ image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
+
+ elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ parameters = image.info.get('parameters', None)
+ exif_bytes = piexif.dump({
+ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
+ })
+ if opts.samples_format.lower() in ("jpg", "jpeg"):
+ image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
+ else:
+ image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
+
+ else:
+ raise HTTPException(status_code=500, detail="Invalid image format")
+
+ bytes_data = output_bytes.getvalue()
+
+ return base64.b64encode(bytes_data)
+
+def api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ endpoint = endpoint,
+ duration = duration,
+ ))
+ return res
+
+
+class Api:
+ def __init__(self, app: FastAPI, queue_lock: Lock):
+ if shared.cmd_opts.api_auth:
+ self.credentials = dict()
+ for auth in shared.cmd_opts.api_auth.split(","):
+ user, password = auth.split(":")
+ self.credentials[user] = password
+
+ self.router = APIRouter()
+ self.app = app
+ self.queue_lock = queue_lock
+ api_middleware(self.app)
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+
+ def add_api_route(self, path: str, endpoint, **kwargs):
+ if shared.cmd_opts.api_auth:
+ return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
+ return self.app.add_api_route(path, endpoint, **kwargs)
+
+ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
+ if credentials.username in self.credentials:
+ if compare_digest(credentials.password, self.credentials[credentials.username]):
+ return True
+
+ raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+
+ def get_selectable_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
+ return None, None
+
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+ script = script_runner.selectable_scripts[script_idx]
+ return script, script_idx
+
+ def get_scripts_list(self):
+ t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+ i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+ return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+
+ def get_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
+ return None, None
+
+ script_idx = script_name_to_index(script_name, script_runner.scripts)
+ return script_runner.scripts[script_idx]
+
+ def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
+ #find max idx from the scripts in runner and generate a none array to init script_args
+ last_arg_index = 1
+ for script in script_runner.scripts:
+ if last_arg_index < script.args_to:
+ last_arg_index = script.args_to
+ # None everywhere except position 0 to initialize script args
+ script_args = [None]*last_arg_index
+ # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
+ if selectable_scripts:
+ script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+ script_args[0] = selectable_idx + 1
+ else:
+ # when [0] = 0 no selectable script to run
+ script_args[0] = 0
+
+ # Now check for always on scripts
+ if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ for alwayson_script_name in request.alwayson_scripts.keys():
+ alwayson_script = self.get_script(alwayson_script_name, script_runner)
+ if alwayson_script == None:
+ raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+ # Selectable script in always on script param check
+ if alwayson_script.alwayson == False:
+ raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ # always on script with no arg should always run so you don't really need to add them to the requests
+ if "args" in request.alwayson_scripts[alwayson_script_name]:
+ script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ return script_args
+
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ script_runner = scripts.scripts_txt2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
+
+ populate = txt2imgreq.copy(update={ # Override __init__ params
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
+ "do_not_save_samples": not txt2imgreq.save_images,
+ "do_not_save_grid": not txt2imgreq.save_images,
+ })
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
+
+ args = vars(populate)
+ args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
+
+ with self.queue_lock:
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin()
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
+
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
+
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+ mask = img2imgreq.mask
+ if mask:
+ mask = decode_base64_to_image(mask)
+
+ script_runner = scripts.scripts_img2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(True)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
+ "do_not_save_samples": not img2imgreq.save_images,
+ "do_not_save_grid": not img2imgreq.save_images,
+ "mask": mask,
+ })
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
+
+ args = vars(populate)
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
+
+ with self.queue_lock:
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin()
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
+
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
+
+ if not img2imgreq.include_init_images:
+ img2imgreq.init_images = None
+ img2imgreq.mask = None
+
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ reqDict = setUpscalers(req)
+
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
+
+ with self.queue_lock:
+ result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ reqDict = setUpscalers(req)
+
+ def prepareFiles(file):
+ file = decode_base64_to_file(file.data, file_path=file.name)
+ file.orig_name = file.name
+ return file
+
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
+ reqDict.pop('imageList')
+
+ with self.queue_lock:
+ result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+
+ def pnginfoapi(self, req: PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ image = decode_base64_to_image(req.image.strip())
+ if image is None:
+ return PNGInfoResponse(info="")
+
+ geninfo, items = images.read_info_from_image(image)
+ if geninfo is None:
+ geninfo = ""
+
+ items = {**{'parameters': geninfo}, **items}
+
+ return PNGInfoResponse(info=geninfo, items=items)
+
+ def progressapi(self, req: ProgressRequest = Depends()):
+ # copy from check_progress_call of ui.py
+
+ if shared.state.job_count == 0:
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+
+ # avoid dividing zero
+ progress = 0.01
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+
+ progress = min(progress, 1)
+
+ shared.state.set_current_image()
+
+ current_image = None
+ if shared.state.current_image and not req.skip_current_image:
+ current_image = encode_pil_to_base64(shared.state.current_image)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ img = decode_base64_to_image(image_b64)
+ img = img.convert('RGB')
+
+ # Override object param
+ with self.queue_lock:
+ if interrogatereq.model == "clip":
+ processed = shared.interrogator.interrogate(img)
+ elif interrogatereq.model == "deepdanbooru":
+ processed = deepbooru.model.tag(img)
+ else:
+ raise HTTPException(status_code=404, detail="Model not found")
+
+ return InterrogateResponse(caption=processed)
+
+ def interruptapi(self):
+ shared.state.interrupt()
+
+ return {}
+
+ def skip(self):
+ shared.state.skip()
+
+ def get_config(self):
+ options = {}
+ for key in shared.opts.data.keys():
+ metadata = shared.opts.data_labels.get(key)
+ if(metadata is not None):
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+ else:
+ options.update({key: shared.opts.data.get(key, None)})
+
+ return options
+
+ def set_config(self, req: Dict[str, Any]):
+ for k, v in req.items():
+ shared.opts.set(k, v)
+
+ shared.opts.save(shared.config_filename)
+ return
+
+ def get_cmd_flags(self):
+ return vars(shared.cmd_opts)
+
+ def get_samplers(self):
+ return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
+
+ def get_upscalers(self):
+ return [
+ {
+ "name": upscaler.name,
+ "model_name": upscaler.scaler.model_name,
+ "model_path": upscaler.data_path,
+ "model_url": None,
+ "scale": upscaler.scale,
+ }
+ for upscaler in shared.sd_upscalers
+ ]
+
+ def get_sd_models(self):
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+
+ def get_hypernetworks(self):
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+ def get_face_restorers(self):
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+ def get_realesrgan_models(self):
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+ def get_prompt_styles(self):
+ styleList = []
+ for k in shared.prompt_styles.styles:
+ style = shared.prompt_styles.styles[k]
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
+
+ return styleList
+
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+ return {
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
+ }
+
+ def refresh_checkpoints(self):
+ shared.refresh_checkpoints()
+
+ def create_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_embedding(**args) # create empty embedding
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
+ shared.state.end()
+ return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create embedding error: {error}".format(error = e))
+
+ def create_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_hypernetwork(**args) # create empty embedding
+ shared.state.end()
+ return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+
+ def preprocess(self, args: dict):
+ try:
+ shared.state.begin()
+ preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess complete')
+ except KeyError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+ except AssertionError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+ except FileNotFoundError as e:
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+
+ def train_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ embedding, filename = train_embedding(**args) # can take a long time to complete
+ except Exception as e:
+ error = e
+ finally:
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+
+ def train_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ shared.loaded_hypernetworks = []
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ hypernetwork, filename = train_hypernetwork(**args)
+ except Exception as e:
+ error = e
+ finally:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info="train embedding error: {error}".format(error=error))
+
+ def get_memory(self):
+ try:
+ import os, psutil
+ process = psutil.Process(os.getpid())
+ res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
+ ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
+ ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
+ except Exception as err:
+ ram = { 'error': f'{err}' }
+ try:
+ import torch
+ if torch.cuda.is_available():
+ s = torch.cuda.mem_get_info()
+ system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
+ s = dict(torch.cuda.memory_stats(shared.device))
+ allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
+ reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
+ active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
+ inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
+ warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
+ cuda = {
+ 'system': system,
+ 'active': active,
+ 'allocated': allocated,
+ 'reserved': reserved,
+ 'inactive': inactive,
+ 'events': warnings,
+ }
+ else:
+ cuda = { 'error': 'unavailable' }
+ except Exception as err:
+ cuda = { 'error': f'{err}' }
+ return MemoryResponse(ram = ram, cuda = cuda)
+
+ def launch(self, server_name, port):
+ self.app.include_router(self.router)
+ uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/api/models.py b/modules/api/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a70f440c8f07541e6838ed186066efbe49eff9f
--- /dev/null
+++ b/modules/api/models.py
@@ -0,0 +1,291 @@
+import inspect
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional
+from typing_extensions import Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+from modules.shared import sd_upscalers, opts, parser
+from typing import Dict, List
+
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ # "do_not_save_samples",
+ # "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+ field_exclude: bool = False
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"],
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "init_images", "type": list, "default": None},
+ {"key": "denoising_strength", "type": float, "default": 0.75},
+ {"key": "mask", "type": str, "default": None},
+ {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
+).generate_model()
+
+class TextToImageResponse(BaseModel):
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ImageToImageResponse(BaseModel):
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ExtrasBaseRequest(BaseModel):
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
+
+class ExtraBaseResponse(BaseModel):
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class FileData(BaseModel):
+ data: str = Field(title="File data", description="Base64 representation of the file")
+ name: str = Field(title="File name")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
+ items: dict = Field(title="Items", description="An object containing all the info the image had")
+
+class ProgressRequest(BaseModel):
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
+class ProgressResponse(BaseModel):
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+ eta_relative: float = Field(title="ETA in secs")
+ state: dict = Field(title="State", description="The current state snapshot")
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+class InterrogateRequest(BaseModel):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
+
+class InterrogateResponse(BaseModel):
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
+class TrainResponse(BaseModel):
+ info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+ info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+ info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
+fields = {}
+for key, metadata in opts.data_labels.items():
+ value = opts.data.get(key)
+ optType = opts.typemap.get(type(metadata.default), type(value))
+
+ if (metadata is not None):
+ fields.update({key: (Optional[optType], Field(
+ default=metadata.default ,description=metadata.label))})
+ else:
+ fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+ if(_options[key].dest != 'help'):
+ flag = _options[key]
+ _type = str
+ if _options[key].default is not None: _type = type(_options[key].default)
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+ name: str = Field(title="Name")
+ aliases: List[str] = Field(title="Aliases")
+ options: Dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+ name: str = Field(title="Name")
+ model_name: Optional[str] = Field(title="Model Name")
+ model_path: Optional[str] = Field(title="Path")
+ model_url: Optional[str] = Field(title="URL")
+ scale: Optional[float] = Field(title="Scale")
+
+class SDModelItem(BaseModel):
+ title: str = Field(title="Title")
+ model_name: str = Field(title="Model Name")
+ hash: Optional[str] = Field(title="Short hash")
+ sha256: Optional[str] = Field(title="sha256 hash")
+ filename: str = Field(title="Filename")
+ config: Optional[str] = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+ name: str = Field(title="Name")
+ cmd_dir: Optional[str] = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+ scale: Optional[int] = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+ name: str = Field(title="Name")
+ prompt: Optional[str] = Field(title="Prompt")
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+ name: str = Field(title="Name")
+ score: float = Field(title="Score")
+ category: str = Field(title="Category")
+
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+
+class MemoryResponse(BaseModel):
+ ram: dict = Field(title="RAM", description="System memory stats")
+ cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+ txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
diff --git a/modules/call_queue.py b/modules/call_queue.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eabc06441b531f89e55f805096af97ede4547ee
--- /dev/null
+++ b/modules/call_queue.py
@@ -0,0 +1,109 @@
+import html
+import sys
+import threading
+import traceback
+import time
+
+from modules import shared, progress
+
+queue_lock = threading.Lock()
+
+
+def wrap_queued_call(func):
+ def f(*args, **kwargs):
+ with queue_lock:
+ res = func(*args, **kwargs)
+
+ return res
+
+ return f
+
+
+def wrap_gradio_gpu_call(func, extra_outputs=None):
+ def f(*args, **kwargs):
+
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
+
+ with queue_lock:
+ shared.state.begin()
+ progress.start_task(id_task)
+
+ try:
+ res = func(*args, **kwargs)
+ finally:
+ progress.finish_task(id_task)
+
+ shared.state.end()
+
+ return res
+
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
+
+
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
+ if run_memmon:
+ shared.mem_mon.monitor()
+ t = time.perf_counter()
+
+ try:
+ res = list(func(*args, **kwargs))
+ except Exception as e:
+ # When printing out our debug argument list, do not print out more than a MB of text
+ max_debug_str_len = 131072 # (1024*1024)/8
+
+ print("Error completing request", file=sys.stderr)
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ print(argStr[:max_debug_str_len], file=sys.stderr)
+ if len(argStr) > max_debug_str_len:
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.state.job = ""
+ shared.state.job_count = 0
+
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"]
+
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.job_count = 0
+
+ if not add_stats:
+ return tuple(res)
+
+ elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"{elapsed_s:.2f}s"
+ if elapsed_m > 0:
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
+
+ if run_memmon:
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+ active_peak = mem_stats['active_peak']
+ reserved_peak = mem_stats['reserved_peak']
+ sys_peak = mem_stats['system_peak']
+ sys_total = mem_stats['total']
+ sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+
+ vram_html = f"Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)
"
+ else:
+ vram_html = ''
+
+ # last item is always HTML
+ res[-1] += f""
+
+ return tuple(res)
+
+ return f
+
diff --git a/modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc b/modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b8546163d59f83122a2a0a88dcf115395c1b6c4
Binary files /dev/null and b/modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc differ
diff --git a/modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc b/modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4424feef0c2881e6920a84da9e4522f2bf14d054
Binary files /dev/null and b/modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc differ
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..11dcc3ee76511218c64977c2ecbb306cecd892c3
--- /dev/null
+++ b/modules/codeformer/codeformer_arch.py
@@ -0,0 +1,278 @@
+# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
+
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from modules.codeformer.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have the similar color and illuminations
+ as those in the degradate features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degradate features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator']):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat
\ No newline at end of file
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e729368383aa2d8c224289284ec5489d554f9a33
--- /dev/null
+++ b/modules/codeformer/vqgan_arch.py
@@ -0,0 +1,437 @@
+# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
+
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "min_encoding_scores": min_encoding_scores,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convultion
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_ema' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError('Wrong params!')
+
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ else:
+ raise ValueError('Wrong params!')
+
+ def forward(self, x):
+ return self.main(x)
\ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b492d12373a0486d73af8cba2321940aedd1c69
--- /dev/null
+++ b/modules/codeformer_model.py
@@ -0,0 +1,143 @@
+import os
+import sys
+import traceback
+
+import cv2
+import torch
+
+import modules.face_restoration
+import modules.shared
+from modules import shared, devices, modelloader
+from modules.paths import models_path
+
+# codeformer people made a choice to include modified basicsr library to their project which makes
+# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# I am making a choice to include some files from codeformer to work around this issue.
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+
+have_codeformer = False
+codeformer = None
+
+
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
+ path = modules.paths.paths.get("CodeFormer", None)
+ if path is None:
+ return
+
+ try:
+ from torchvision.transforms.functional import normalize
+ from modules.codeformer.codeformer_arch import CodeFormer
+ from basicsr.utils.download_util import load_file_from_url
+ from basicsr.utils import imwrite, img2tensor, tensor2img
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
+ from facelib.detection.retinaface import retinaface
+ from modules.shared import cmd_opts
+
+ net_class = CodeFormer
+
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "CodeFormer"
+
+ def __init__(self, dirname):
+ self.net = None
+ self.face_helper = None
+ self.cmd_dir = dirname
+
+ def create_models(self):
+
+ if self.net is not None and self.face_helper is not None:
+ self.net.to(devices.device_codeformer)
+ return self.net, self.face_helper
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
+ if len(model_paths) != 0:
+ ckpt_path = model_paths[0]
+ else:
+ print("Unable to load codeformer model.")
+ return None, None
+ net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
+ checkpoint = torch.load(ckpt_path)['params_ema']
+ net.load_state_dict(checkpoint)
+ net.eval()
+
+ if hasattr(retinaface, 'device'):
+ retinaface.device = devices.device_codeformer
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
+
+ self.net = net
+ self.face_helper = face_helper
+
+ return net, face_helper
+
+ def send_model_to(self, device):
+ self.net.to(device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
+ def restore(self, np_image, w=None):
+ np_image = np_image[:, :, ::-1]
+
+ original_resolution = np_image.shape[0:2]
+
+ self.create_models()
+ if self.net is None or self.face_helper is None:
+ return np_image
+
+ self.send_model_to(devices.device_codeformer)
+
+ self.face_helper.clean_all()
+ self.face_helper.read_image(np_image)
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ self.face_helper.align_warp_face()
+
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+ try:
+ with torch.no_grad():
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype('uint8')
+ self.face_helper.add_restored_face(restored_face)
+
+ self.face_helper.get_inverse_affine(None)
+
+ restored_img = self.face_helper.paste_faces_to_input_image()
+ restored_img = restored_img[:, :, ::-1]
+
+ if original_resolution != restored_img.shape[0:2]:
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+
+ self.face_helper.clean_all()
+
+ if shared.opts.face_restoration_unload:
+ self.send_model_to(devices.cpu)
+
+ return restored_img
+
+ global have_codeformer
+ have_codeformer = True
+
+ global codeformer
+ codeformer = FaceRestorerCodeFormer(dirname)
+ shared.face_restorers.append(codeformer)
+
+ except Exception:
+ print("Error setting up CodeFormer:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ # sys.path = stored_sys_path
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
new file mode 100644
index 0000000000000000000000000000000000000000..122fce7f569dbd28f9c6d83af874bb3efed34a5e
--- /dev/null
+++ b/modules/deepbooru.py
@@ -0,0 +1,99 @@
+import os
+import re
+
+import torch
+from PIL import Image
+import numpy as np
+
+from modules import modelloader, paths, deepbooru_model, devices, images, shared
+
+re_special = re.compile(r'([\\()])')
+
+
+class DeepDanbooru:
+ def __init__(self):
+ self.model = None
+
+ def load(self):
+ if self.model is not None:
+ return
+
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
+ model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
+ ext_filter=[".pt"],
+ download_name='model-resnet_custom_v3.pt',
+ )
+
+ self.model = deepbooru_model.DeepDanbooruModel()
+ self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
+
+ self.model.eval()
+ self.model.to(devices.cpu, devices.dtype)
+
+ def start(self):
+ self.load()
+ self.model.to(devices.device)
+
+ def stop(self):
+ if not shared.opts.interrogate_keep_models_in_memory:
+ self.model.to(devices.cpu)
+ devices.torch_gc()
+
+ def tag(self, pil_image):
+ self.start()
+ res = self.tag_multi(pil_image)
+ self.stop()
+
+ return res
+
+ def tag_multi(self, pil_image, force_disable_ranks=False):
+ threshold = shared.opts.interrogate_deepbooru_score_threshold
+ use_spaces = shared.opts.deepbooru_use_spaces
+ use_escape = shared.opts.deepbooru_escape
+ alpha_sort = shared.opts.deepbooru_sort_alpha
+ include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
+
+ pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
+
+ with torch.no_grad(), devices.autocast():
+ x = torch.from_numpy(a).to(devices.device)
+ y = self.model(x)[0].detach().cpu().numpy()
+
+ probability_dict = {}
+
+ for tag, probability in zip(self.model.tags, y):
+ if probability < threshold:
+ continue
+
+ if tag.startswith("rating:"):
+ continue
+
+ probability_dict[tag] = probability
+
+ if alpha_sort:
+ tags = sorted(probability_dict)
+ else:
+ tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
+
+ res = []
+
+ filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+
+ for tag in [x for x in tags if x not in filtertags]:
+ probability = probability_dict[tag]
+ tag_outformat = tag
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{probability:.3f})"
+
+ res.append(tag_outformat)
+
+ return ", ".join(res)
+
+
+model = DeepDanbooru()
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a53884624e96284c35214ce02b8a2891d92c3e8
--- /dev/null
+++ b/modules/deepbooru_model.py
@@ -0,0 +1,678 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules import devices
+
+# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
+
+
+class DeepDanbooruModel(nn.Module):
+ def __init__(self):
+ super(DeepDanbooruModel, self).__init__()
+
+ self.tags = []
+
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
+
+ def forward(self, *inputs):
+ t_358, = inputs
+ t_359 = t_358.permute(*[0, 3, 1, 2])
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
+ t_361 = F.relu(t_360)
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
+ t_362 = self.n_MaxPool_0(t_361)
+ t_363 = self.n_Conv_1(t_362)
+ t_364 = self.n_Conv_2(t_362)
+ t_365 = F.relu(t_364)
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
+ t_366 = self.n_Conv_3(t_365_padded)
+ t_367 = F.relu(t_366)
+ t_368 = self.n_Conv_4(t_367)
+ t_369 = torch.add(t_368, t_363)
+ t_370 = F.relu(t_369)
+ t_371 = self.n_Conv_5(t_370)
+ t_372 = F.relu(t_371)
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
+ t_373 = self.n_Conv_6(t_372_padded)
+ t_374 = F.relu(t_373)
+ t_375 = self.n_Conv_7(t_374)
+ t_376 = torch.add(t_375, t_370)
+ t_377 = F.relu(t_376)
+ t_378 = self.n_Conv_8(t_377)
+ t_379 = F.relu(t_378)
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
+ t_380 = self.n_Conv_9(t_379_padded)
+ t_381 = F.relu(t_380)
+ t_382 = self.n_Conv_10(t_381)
+ t_383 = torch.add(t_382, t_377)
+ t_384 = F.relu(t_383)
+ t_385 = self.n_Conv_11(t_384)
+ t_386 = self.n_Conv_12(t_384)
+ t_387 = F.relu(t_386)
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
+ t_388 = self.n_Conv_13(t_387_padded)
+ t_389 = F.relu(t_388)
+ t_390 = self.n_Conv_14(t_389)
+ t_391 = torch.add(t_390, t_385)
+ t_392 = F.relu(t_391)
+ t_393 = self.n_Conv_15(t_392)
+ t_394 = F.relu(t_393)
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
+ t_395 = self.n_Conv_16(t_394_padded)
+ t_396 = F.relu(t_395)
+ t_397 = self.n_Conv_17(t_396)
+ t_398 = torch.add(t_397, t_392)
+ t_399 = F.relu(t_398)
+ t_400 = self.n_Conv_18(t_399)
+ t_401 = F.relu(t_400)
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
+ t_402 = self.n_Conv_19(t_401_padded)
+ t_403 = F.relu(t_402)
+ t_404 = self.n_Conv_20(t_403)
+ t_405 = torch.add(t_404, t_399)
+ t_406 = F.relu(t_405)
+ t_407 = self.n_Conv_21(t_406)
+ t_408 = F.relu(t_407)
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
+ t_409 = self.n_Conv_22(t_408_padded)
+ t_410 = F.relu(t_409)
+ t_411 = self.n_Conv_23(t_410)
+ t_412 = torch.add(t_411, t_406)
+ t_413 = F.relu(t_412)
+ t_414 = self.n_Conv_24(t_413)
+ t_415 = F.relu(t_414)
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
+ t_416 = self.n_Conv_25(t_415_padded)
+ t_417 = F.relu(t_416)
+ t_418 = self.n_Conv_26(t_417)
+ t_419 = torch.add(t_418, t_413)
+ t_420 = F.relu(t_419)
+ t_421 = self.n_Conv_27(t_420)
+ t_422 = F.relu(t_421)
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
+ t_423 = self.n_Conv_28(t_422_padded)
+ t_424 = F.relu(t_423)
+ t_425 = self.n_Conv_29(t_424)
+ t_426 = torch.add(t_425, t_420)
+ t_427 = F.relu(t_426)
+ t_428 = self.n_Conv_30(t_427)
+ t_429 = F.relu(t_428)
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
+ t_430 = self.n_Conv_31(t_429_padded)
+ t_431 = F.relu(t_430)
+ t_432 = self.n_Conv_32(t_431)
+ t_433 = torch.add(t_432, t_427)
+ t_434 = F.relu(t_433)
+ t_435 = self.n_Conv_33(t_434)
+ t_436 = F.relu(t_435)
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
+ t_437 = self.n_Conv_34(t_436_padded)
+ t_438 = F.relu(t_437)
+ t_439 = self.n_Conv_35(t_438)
+ t_440 = torch.add(t_439, t_434)
+ t_441 = F.relu(t_440)
+ t_442 = self.n_Conv_36(t_441)
+ t_443 = self.n_Conv_37(t_441)
+ t_444 = F.relu(t_443)
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
+ t_445 = self.n_Conv_38(t_444_padded)
+ t_446 = F.relu(t_445)
+ t_447 = self.n_Conv_39(t_446)
+ t_448 = torch.add(t_447, t_442)
+ t_449 = F.relu(t_448)
+ t_450 = self.n_Conv_40(t_449)
+ t_451 = F.relu(t_450)
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
+ t_452 = self.n_Conv_41(t_451_padded)
+ t_453 = F.relu(t_452)
+ t_454 = self.n_Conv_42(t_453)
+ t_455 = torch.add(t_454, t_449)
+ t_456 = F.relu(t_455)
+ t_457 = self.n_Conv_43(t_456)
+ t_458 = F.relu(t_457)
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
+ t_459 = self.n_Conv_44(t_458_padded)
+ t_460 = F.relu(t_459)
+ t_461 = self.n_Conv_45(t_460)
+ t_462 = torch.add(t_461, t_456)
+ t_463 = F.relu(t_462)
+ t_464 = self.n_Conv_46(t_463)
+ t_465 = F.relu(t_464)
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
+ t_466 = self.n_Conv_47(t_465_padded)
+ t_467 = F.relu(t_466)
+ t_468 = self.n_Conv_48(t_467)
+ t_469 = torch.add(t_468, t_463)
+ t_470 = F.relu(t_469)
+ t_471 = self.n_Conv_49(t_470)
+ t_472 = F.relu(t_471)
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
+ t_473 = self.n_Conv_50(t_472_padded)
+ t_474 = F.relu(t_473)
+ t_475 = self.n_Conv_51(t_474)
+ t_476 = torch.add(t_475, t_470)
+ t_477 = F.relu(t_476)
+ t_478 = self.n_Conv_52(t_477)
+ t_479 = F.relu(t_478)
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
+ t_480 = self.n_Conv_53(t_479_padded)
+ t_481 = F.relu(t_480)
+ t_482 = self.n_Conv_54(t_481)
+ t_483 = torch.add(t_482, t_477)
+ t_484 = F.relu(t_483)
+ t_485 = self.n_Conv_55(t_484)
+ t_486 = F.relu(t_485)
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
+ t_487 = self.n_Conv_56(t_486_padded)
+ t_488 = F.relu(t_487)
+ t_489 = self.n_Conv_57(t_488)
+ t_490 = torch.add(t_489, t_484)
+ t_491 = F.relu(t_490)
+ t_492 = self.n_Conv_58(t_491)
+ t_493 = F.relu(t_492)
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
+ t_494 = self.n_Conv_59(t_493_padded)
+ t_495 = F.relu(t_494)
+ t_496 = self.n_Conv_60(t_495)
+ t_497 = torch.add(t_496, t_491)
+ t_498 = F.relu(t_497)
+ t_499 = self.n_Conv_61(t_498)
+ t_500 = F.relu(t_499)
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
+ t_501 = self.n_Conv_62(t_500_padded)
+ t_502 = F.relu(t_501)
+ t_503 = self.n_Conv_63(t_502)
+ t_504 = torch.add(t_503, t_498)
+ t_505 = F.relu(t_504)
+ t_506 = self.n_Conv_64(t_505)
+ t_507 = F.relu(t_506)
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
+ t_508 = self.n_Conv_65(t_507_padded)
+ t_509 = F.relu(t_508)
+ t_510 = self.n_Conv_66(t_509)
+ t_511 = torch.add(t_510, t_505)
+ t_512 = F.relu(t_511)
+ t_513 = self.n_Conv_67(t_512)
+ t_514 = F.relu(t_513)
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
+ t_515 = self.n_Conv_68(t_514_padded)
+ t_516 = F.relu(t_515)
+ t_517 = self.n_Conv_69(t_516)
+ t_518 = torch.add(t_517, t_512)
+ t_519 = F.relu(t_518)
+ t_520 = self.n_Conv_70(t_519)
+ t_521 = F.relu(t_520)
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
+ t_522 = self.n_Conv_71(t_521_padded)
+ t_523 = F.relu(t_522)
+ t_524 = self.n_Conv_72(t_523)
+ t_525 = torch.add(t_524, t_519)
+ t_526 = F.relu(t_525)
+ t_527 = self.n_Conv_73(t_526)
+ t_528 = F.relu(t_527)
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
+ t_529 = self.n_Conv_74(t_528_padded)
+ t_530 = F.relu(t_529)
+ t_531 = self.n_Conv_75(t_530)
+ t_532 = torch.add(t_531, t_526)
+ t_533 = F.relu(t_532)
+ t_534 = self.n_Conv_76(t_533)
+ t_535 = F.relu(t_534)
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
+ t_536 = self.n_Conv_77(t_535_padded)
+ t_537 = F.relu(t_536)
+ t_538 = self.n_Conv_78(t_537)
+ t_539 = torch.add(t_538, t_533)
+ t_540 = F.relu(t_539)
+ t_541 = self.n_Conv_79(t_540)
+ t_542 = F.relu(t_541)
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
+ t_543 = self.n_Conv_80(t_542_padded)
+ t_544 = F.relu(t_543)
+ t_545 = self.n_Conv_81(t_544)
+ t_546 = torch.add(t_545, t_540)
+ t_547 = F.relu(t_546)
+ t_548 = self.n_Conv_82(t_547)
+ t_549 = F.relu(t_548)
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
+ t_550 = self.n_Conv_83(t_549_padded)
+ t_551 = F.relu(t_550)
+ t_552 = self.n_Conv_84(t_551)
+ t_553 = torch.add(t_552, t_547)
+ t_554 = F.relu(t_553)
+ t_555 = self.n_Conv_85(t_554)
+ t_556 = F.relu(t_555)
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
+ t_557 = self.n_Conv_86(t_556_padded)
+ t_558 = F.relu(t_557)
+ t_559 = self.n_Conv_87(t_558)
+ t_560 = torch.add(t_559, t_554)
+ t_561 = F.relu(t_560)
+ t_562 = self.n_Conv_88(t_561)
+ t_563 = F.relu(t_562)
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
+ t_564 = self.n_Conv_89(t_563_padded)
+ t_565 = F.relu(t_564)
+ t_566 = self.n_Conv_90(t_565)
+ t_567 = torch.add(t_566, t_561)
+ t_568 = F.relu(t_567)
+ t_569 = self.n_Conv_91(t_568)
+ t_570 = F.relu(t_569)
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
+ t_571 = self.n_Conv_92(t_570_padded)
+ t_572 = F.relu(t_571)
+ t_573 = self.n_Conv_93(t_572)
+ t_574 = torch.add(t_573, t_568)
+ t_575 = F.relu(t_574)
+ t_576 = self.n_Conv_94(t_575)
+ t_577 = F.relu(t_576)
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
+ t_578 = self.n_Conv_95(t_577_padded)
+ t_579 = F.relu(t_578)
+ t_580 = self.n_Conv_96(t_579)
+ t_581 = torch.add(t_580, t_575)
+ t_582 = F.relu(t_581)
+ t_583 = self.n_Conv_97(t_582)
+ t_584 = F.relu(t_583)
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
+ t_585 = self.n_Conv_98(t_584_padded)
+ t_586 = F.relu(t_585)
+ t_587 = self.n_Conv_99(t_586)
+ t_588 = self.n_Conv_100(t_582)
+ t_589 = torch.add(t_587, t_588)
+ t_590 = F.relu(t_589)
+ t_591 = self.n_Conv_101(t_590)
+ t_592 = F.relu(t_591)
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
+ t_593 = self.n_Conv_102(t_592_padded)
+ t_594 = F.relu(t_593)
+ t_595 = self.n_Conv_103(t_594)
+ t_596 = torch.add(t_595, t_590)
+ t_597 = F.relu(t_596)
+ t_598 = self.n_Conv_104(t_597)
+ t_599 = F.relu(t_598)
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
+ t_600 = self.n_Conv_105(t_599_padded)
+ t_601 = F.relu(t_600)
+ t_602 = self.n_Conv_106(t_601)
+ t_603 = torch.add(t_602, t_597)
+ t_604 = F.relu(t_603)
+ t_605 = self.n_Conv_107(t_604)
+ t_606 = F.relu(t_605)
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
+ t_607 = self.n_Conv_108(t_606_padded)
+ t_608 = F.relu(t_607)
+ t_609 = self.n_Conv_109(t_608)
+ t_610 = torch.add(t_609, t_604)
+ t_611 = F.relu(t_610)
+ t_612 = self.n_Conv_110(t_611)
+ t_613 = F.relu(t_612)
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
+ t_614 = self.n_Conv_111(t_613_padded)
+ t_615 = F.relu(t_614)
+ t_616 = self.n_Conv_112(t_615)
+ t_617 = torch.add(t_616, t_611)
+ t_618 = F.relu(t_617)
+ t_619 = self.n_Conv_113(t_618)
+ t_620 = F.relu(t_619)
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
+ t_621 = self.n_Conv_114(t_620_padded)
+ t_622 = F.relu(t_621)
+ t_623 = self.n_Conv_115(t_622)
+ t_624 = torch.add(t_623, t_618)
+ t_625 = F.relu(t_624)
+ t_626 = self.n_Conv_116(t_625)
+ t_627 = F.relu(t_626)
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
+ t_628 = self.n_Conv_117(t_627_padded)
+ t_629 = F.relu(t_628)
+ t_630 = self.n_Conv_118(t_629)
+ t_631 = torch.add(t_630, t_625)
+ t_632 = F.relu(t_631)
+ t_633 = self.n_Conv_119(t_632)
+ t_634 = F.relu(t_633)
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
+ t_635 = self.n_Conv_120(t_634_padded)
+ t_636 = F.relu(t_635)
+ t_637 = self.n_Conv_121(t_636)
+ t_638 = torch.add(t_637, t_632)
+ t_639 = F.relu(t_638)
+ t_640 = self.n_Conv_122(t_639)
+ t_641 = F.relu(t_640)
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
+ t_642 = self.n_Conv_123(t_641_padded)
+ t_643 = F.relu(t_642)
+ t_644 = self.n_Conv_124(t_643)
+ t_645 = torch.add(t_644, t_639)
+ t_646 = F.relu(t_645)
+ t_647 = self.n_Conv_125(t_646)
+ t_648 = F.relu(t_647)
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
+ t_649 = self.n_Conv_126(t_648_padded)
+ t_650 = F.relu(t_649)
+ t_651 = self.n_Conv_127(t_650)
+ t_652 = torch.add(t_651, t_646)
+ t_653 = F.relu(t_652)
+ t_654 = self.n_Conv_128(t_653)
+ t_655 = F.relu(t_654)
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
+ t_656 = self.n_Conv_129(t_655_padded)
+ t_657 = F.relu(t_656)
+ t_658 = self.n_Conv_130(t_657)
+ t_659 = torch.add(t_658, t_653)
+ t_660 = F.relu(t_659)
+ t_661 = self.n_Conv_131(t_660)
+ t_662 = F.relu(t_661)
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
+ t_663 = self.n_Conv_132(t_662_padded)
+ t_664 = F.relu(t_663)
+ t_665 = self.n_Conv_133(t_664)
+ t_666 = torch.add(t_665, t_660)
+ t_667 = F.relu(t_666)
+ t_668 = self.n_Conv_134(t_667)
+ t_669 = F.relu(t_668)
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
+ t_670 = self.n_Conv_135(t_669_padded)
+ t_671 = F.relu(t_670)
+ t_672 = self.n_Conv_136(t_671)
+ t_673 = torch.add(t_672, t_667)
+ t_674 = F.relu(t_673)
+ t_675 = self.n_Conv_137(t_674)
+ t_676 = F.relu(t_675)
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
+ t_677 = self.n_Conv_138(t_676_padded)
+ t_678 = F.relu(t_677)
+ t_679 = self.n_Conv_139(t_678)
+ t_680 = torch.add(t_679, t_674)
+ t_681 = F.relu(t_680)
+ t_682 = self.n_Conv_140(t_681)
+ t_683 = F.relu(t_682)
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
+ t_684 = self.n_Conv_141(t_683_padded)
+ t_685 = F.relu(t_684)
+ t_686 = self.n_Conv_142(t_685)
+ t_687 = torch.add(t_686, t_681)
+ t_688 = F.relu(t_687)
+ t_689 = self.n_Conv_143(t_688)
+ t_690 = F.relu(t_689)
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
+ t_691 = self.n_Conv_144(t_690_padded)
+ t_692 = F.relu(t_691)
+ t_693 = self.n_Conv_145(t_692)
+ t_694 = torch.add(t_693, t_688)
+ t_695 = F.relu(t_694)
+ t_696 = self.n_Conv_146(t_695)
+ t_697 = F.relu(t_696)
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
+ t_698 = self.n_Conv_147(t_697_padded)
+ t_699 = F.relu(t_698)
+ t_700 = self.n_Conv_148(t_699)
+ t_701 = torch.add(t_700, t_695)
+ t_702 = F.relu(t_701)
+ t_703 = self.n_Conv_149(t_702)
+ t_704 = F.relu(t_703)
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
+ t_705 = self.n_Conv_150(t_704_padded)
+ t_706 = F.relu(t_705)
+ t_707 = self.n_Conv_151(t_706)
+ t_708 = torch.add(t_707, t_702)
+ t_709 = F.relu(t_708)
+ t_710 = self.n_Conv_152(t_709)
+ t_711 = F.relu(t_710)
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
+ t_712 = self.n_Conv_153(t_711_padded)
+ t_713 = F.relu(t_712)
+ t_714 = self.n_Conv_154(t_713)
+ t_715 = torch.add(t_714, t_709)
+ t_716 = F.relu(t_715)
+ t_717 = self.n_Conv_155(t_716)
+ t_718 = F.relu(t_717)
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
+ t_719 = self.n_Conv_156(t_718_padded)
+ t_720 = F.relu(t_719)
+ t_721 = self.n_Conv_157(t_720)
+ t_722 = torch.add(t_721, t_716)
+ t_723 = F.relu(t_722)
+ t_724 = self.n_Conv_158(t_723)
+ t_725 = self.n_Conv_159(t_723)
+ t_726 = F.relu(t_725)
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
+ t_727 = self.n_Conv_160(t_726_padded)
+ t_728 = F.relu(t_727)
+ t_729 = self.n_Conv_161(t_728)
+ t_730 = torch.add(t_729, t_724)
+ t_731 = F.relu(t_730)
+ t_732 = self.n_Conv_162(t_731)
+ t_733 = F.relu(t_732)
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
+ t_734 = self.n_Conv_163(t_733_padded)
+ t_735 = F.relu(t_734)
+ t_736 = self.n_Conv_164(t_735)
+ t_737 = torch.add(t_736, t_731)
+ t_738 = F.relu(t_737)
+ t_739 = self.n_Conv_165(t_738)
+ t_740 = F.relu(t_739)
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
+ t_741 = self.n_Conv_166(t_740_padded)
+ t_742 = F.relu(t_741)
+ t_743 = self.n_Conv_167(t_742)
+ t_744 = torch.add(t_743, t_738)
+ t_745 = F.relu(t_744)
+ t_746 = self.n_Conv_168(t_745)
+ t_747 = self.n_Conv_169(t_745)
+ t_748 = F.relu(t_747)
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
+ t_749 = self.n_Conv_170(t_748_padded)
+ t_750 = F.relu(t_749)
+ t_751 = self.n_Conv_171(t_750)
+ t_752 = torch.add(t_751, t_746)
+ t_753 = F.relu(t_752)
+ t_754 = self.n_Conv_172(t_753)
+ t_755 = F.relu(t_754)
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
+ t_756 = self.n_Conv_173(t_755_padded)
+ t_757 = F.relu(t_756)
+ t_758 = self.n_Conv_174(t_757)
+ t_759 = torch.add(t_758, t_753)
+ t_760 = F.relu(t_759)
+ t_761 = self.n_Conv_175(t_760)
+ t_762 = F.relu(t_761)
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
+ t_763 = self.n_Conv_176(t_762_padded)
+ t_764 = F.relu(t_763)
+ t_765 = self.n_Conv_177(t_764)
+ t_766 = torch.add(t_765, t_760)
+ t_767 = F.relu(t_766)
+ t_768 = self.n_Conv_178(t_767)
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
+ t_770 = torch.squeeze(t_769, 3)
+ t_770 = torch.squeeze(t_770, 2)
+ t_771 = torch.sigmoid(t_770)
+ return t_771
+
+ def load_state_dict(self, state_dict, **kwargs):
+ self.tags = state_dict.get('tags', [])
+
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
+
diff --git a/modules/devices.py b/modules/devices.py
new file mode 100644
index 0000000000000000000000000000000000000000..52c3e7cd773f9c89857dfce14b37d63cb6329fac
--- /dev/null
+++ b/modules/devices.py
@@ -0,0 +1,152 @@
+import sys
+import contextlib
+import torch
+from modules import errors
+
+if sys.platform == "darwin":
+ from modules import mac_specific
+
+
+def has_mps() -> bool:
+ if sys.platform != "darwin":
+ return False
+ else:
+ return mac_specific.has_mps
+
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]:
+ return args[x + 1]
+
+ return None
+
+
+def get_cuda_device_string():
+ from modules import shared
+
+ if shared.cmd_opts.device_id is not None:
+ return f"cuda:{shared.cmd_opts.device_id}"
+
+ return "cuda"
+
+
+def get_optimal_device_name():
+ if torch.cuda.is_available():
+ return get_cuda_device_string()
+
+ if has_mps():
+ return "mps"
+
+ return "cpu"
+
+
+def get_optimal_device():
+ return torch.device(get_optimal_device_name())
+
+
+def get_device_for(task):
+ from modules import shared
+
+ if task in shared.cmd_opts.use_cpu:
+ return cpu
+
+ return get_optimal_device()
+
+
+def torch_gc():
+ if torch.cuda.is_available():
+ with torch.cuda.device(get_cuda_device_string()):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+def enable_tf32():
+ if torch.cuda.is_available():
+
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+ if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ torch.backends.cudnn.benchmark = True
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+
+
+errors.run(enable_tf32, "Enabling TF32")
+
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
+dtype = torch.float16
+dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
+
+
+def cond_cast_unet(input):
+ return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+ return input.float() if unet_needs_upcast else input
+
+
+def randn(seed, shape):
+ torch.manual_seed(seed)
+ if device.type == 'mps':
+ return torch.randn(shape, device=cpu).to(device)
+ return torch.randn(shape, device=device)
+
+
+def randn_without_seed(shape):
+ if device.type == 'mps':
+ return torch.randn(shape, device=cpu).to(device)
+ return torch.randn(shape, device=device)
+
+
+def autocast(disable=False):
+ from modules import shared
+
+ if disable:
+ return contextlib.nullcontext()
+
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")
+
+
+def without_autocast(disable=False):
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ from modules import shared
+
+ if shared.cmd_opts.disable_nan_check:
+ return
+
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ if not shared.cmd_opts.no_half:
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ message += " Use --disable-nan-check commandline argument to disable this check."
+
+ raise NansException(message)
diff --git a/modules/errors.py b/modules/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c9c44497221eb814b402aa5859a3e6aaeaac00
--- /dev/null
+++ b/modules/errors.py
@@ -0,0 +1,43 @@
+import sys
+import traceback
+
+
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
+already_displayed = {}
+
+
+def display_once(e: Exception, task):
+ if task in already_displayed:
+ return
+
+ display(e, task)
+
+ already_displayed[task] = 1
+
+
+def run(code, task):
+ try:
+ code()
+ except Exception as e:
+ display(task, e)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..80131c62cfeaa7f95455df55c45d6e62591adeee
--- /dev/null
+++ b/modules/esrgan_model.py
@@ -0,0 +1,233 @@
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.esrgan_model_arch as arch
+from modules import shared, modelloader, images, devices
+from modules.upscaler import Upscaler, UpscalerData
+from modules.shared import opts
+
+
+
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ out_nc = out_nc
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
+
+
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+ self.model_name = "ESRGAN_4x"
+ self.scalers = []
+ self.user_path = dirname
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+ scalers = []
+ if len(model_paths) == 0:
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+ scalers.append(scaler_data)
+ for file in model_paths:
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+
+ scaler_data = UpscalerData(name, file, self, 4)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(devices.device_esrgan)
+ img = esrgan_upscale(model, img)
+ return img
+
+ def load_model(self, path: str):
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="%s.pth" % self.model_name,
+ progress=True)
+ else:
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (self.model_path, filename))
+ return None
+
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise Exception("The file is not a recognized ESRGAN model.")
+
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ return model
+
+
+def upscale_without_tiling(model, img):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(devices.device_esrgan)
+ with torch.no_grad():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def esrgan_upscale(model, img):
+ if opts.ESRGAN_tile == 0:
+ return upscale_without_tiling(model, img)
+
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ newtiles = []
+ scale_factor = 1
+
+ for y, h, row in grid.tiles:
+ newrow = []
+ for tiledata in row:
+ x, w, tile = tiledata
+
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
+ return output
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec5962b595cfc0e6c52d916275538c9c4252068
--- /dev/null
+++ b/modules/esrgan_model_arch.py
@@ -0,0 +1,464 @@
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+from collections import OrderedDict
+import math
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+####################
+# RRDBNet Generator
+####################
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
+
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
+
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
+
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
+
+
+class RRDB(nn.Module):
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
+
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(RRDB, self).__init__()
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
+
+ def forward(self, x):
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
+ return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
+
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ return layer
+
+
+class Identity(nn.Module):
+ def __init__(self, *kwargs):
+ super(Identity, self).__init__()
+
+ def forward(self, x, *kwargs):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ elif norm_type == 'none':
+ def norm_layer(x): return Identity()
+ else:
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..5918b9c4ec578421e70f09f754f2c9b2dafef03a
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,107 @@
+import os
+import sys
+import traceback
+
+import time
+import git
+
+from modules import paths, shared
+
+extensions = []
+extensions_dir = os.path.join(paths.data_path, "extensions")
+extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
+
+if not os.path.exists(extensions_dir):
+ os.makedirs(extensions_dir)
+
+def active():
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True, is_builtin=False):
+ self.name = name
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+ self.is_builtin = is_builtin
+ self.version = ''
+
+ repo = None
+ try:
+ if os.path.exists(os.path.join(path, ".git")):
+ repo = git.Repo(path)
+ except Exception:
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ try:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+ head = repo.head.commit
+ ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
+ self.version = f'{head.hexsha[:8]} ({ts})'
+
+ except Exception:
+ self.remote = None
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
+
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return res
+
+ def check_updates(self):
+ repo = git.Repo(self.path)
+ for fetch in repo.remote().fetch(dry_run=True):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "behind"
+ return
+
+ self.can_update = False
+ self.status = "latest"
+
+ def fetch_and_reset_hard(self):
+ repo = git.Repo(self.path)
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+ repo.git.fetch(all=True)
+ repo.git.reset('origin', hard=True)
+
+
+def list_extensions():
+ extensions.clear()
+
+ if not os.path.isdir(extensions_dir):
+ return
+
+ paths = []
+ for dirname in [extensions_dir, extensions_builtin_dir]:
+ if not os.path.isdir(dirname):
+ return
+
+ for extension_dirname in sorted(os.listdir(dirname)):
+ path = os.path.join(dirname, extension_dirname)
+ if not os.path.isdir(path):
+ continue
+
+ paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+
+ for dirname, path, is_builtin in paths:
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
+ extensions.append(extension)
+
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf6aa7907156aefa62f183fdcd25fea3e17252b4
--- /dev/null
+++ b/modules/extra_networks.py
@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+ Passes arguments related to this extra network in params_list.
+ User passes arguments by specifying this in his prompt:
+
+
+
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+ separated by colon.
+
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
+ in this case, all effects of this extra networks should be disabled.
+
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+ > "1girl, "
+
+ params_list will be:
+
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. No need to do anything here.
+ """
+
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call
+ activate for all remaining registered networks with an empty argument list"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call
+ deactivate for all remaining registered networks"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+ return ""
+
+ prompt = re.sub(re_extra_net, found, prompt)
+
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+ if extra_data is None:
+ extra_data = parsed_extra_data
+
+ res.append(updated_prompt)
+
+ return res, extra_data
+
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..207343daa673c14a362d4bd2399982d9ad86fe22
--- /dev/null
+++ b/modules/extra_networks_hypernet.py
@@ -0,0 +1,27 @@
+from modules import extra_networks, shared, extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ additional = shared.opts.sd_hypernetwork
+
+ if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ p.all_prompts = [x + f"" for x in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(self, p):
+ pass
diff --git a/modules/extras.py b/modules/extras.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a9af2d8e641fdf1ebd29045078d29b5aeae3d6f
--- /dev/null
+++ b/modules/extras.py
@@ -0,0 +1,258 @@
+import os
+import re
+import shutil
+
+
+import torch
+import tqdm
+
+from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules.ui_common import plaintext_to_html
+import gradio as gr
+import safetensors.torch
+
+
+def run_pnginfo(image):
+ if image is None:
+ return '', '', ''
+
+ geninfo, items = images.read_info_from_image(image)
+ items = {**{'parameters': geninfo}, **items}
+
+ info = ''
+ for key, text in items.items():
+ info += f"""
+
+
{plaintext_to_html(str(key))}
+
{plaintext_to_html(str(text))}
+
+""".strip()+"\n"
+
+ if len(info) == 0:
+ message = "Nothing found in the image."
+ info = f""
+
+ return '', geninfo, info
+
+
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
+ return res if res != shared.sd_default_config else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+
+ if cfg is None:
+ return
+
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+
+ print("Copying config:")
+ print(" from:", cfg)
+ print(" to:", checkpoint_filename)
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
+ def weighted_sum(theta0, theta1, alpha):
+ return ((1 - alpha) * theta0) + (alpha * theta1)
+
+ def get_difference(theta1, theta2):
+ return theta1 - theta2
+
+ def add_difference(theta0, theta1_2_diff, alpha):
+ return theta0 + (alpha * theta1_2_diff)
+
+ def filename_weighted_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_difference():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
+
+ theta_funcs = {
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
+ "Add difference": (filename_add_difference, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
+ }
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
+
+ if not primary_model_name:
+ return fail("Failed: Merging requires a primary model.")
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if theta_func2 and not secondary_model_name:
+ return fail("Failed: Merging requires a secondary model.")
+
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+
+ if theta_func1 and not tertiary_model_name:
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
+
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+
+ result_is_inpainting_model = False
+ result_is_instruct_pix2pix_model = False
+
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
+
+ if theta_func1:
+ shared.state.textinfo = f"Loading C"
+ print(f"Loading {tertiary_model_info.filename}...")
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+
+ shared.state.textinfo = 'Merging B and C'
+ shared.state.sampling_steps = len(theta_1.keys())
+ for key in tqdm.tqdm(theta_1.keys()):
+ if key in checkpoint_dict_skip_on_merge:
+ continue
+
+ if 'model' in key:
+ if key in theta_2:
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+ theta_1[key] = theta_func1(theta_1[key], t2)
+ else:
+ theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
+ del theta_2
+
+ shared.state.nextjob()
+
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
+ print(f"Loading {primary_model_info.filename}...")
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
+
+ print("Merging...")
+ shared.state.textinfo = 'Merging A and B'
+ shared.state.sampling_steps = len(theta_0.keys())
+ for key in tqdm.tqdm(theta_0.keys()):
+ if theta_1 and 'model' in key and key in theta_1:
+
+ if key in checkpoint_dict_skip_on_merge:
+ continue
+
+ a = theta_0[key]
+ b = theta_1[key]
+
+ # this enables merging an inpainting model (A) with another one (B);
+ # where normal model would have 4 channels, for latenst space, inpainting model would
+ # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
+ if a.shape[1] == 4 and b.shape[1] == 9:
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
+ if a.shape[1] == 4 and b.shape[1] == 8:
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
+
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
+ result_is_instruct_pix2pix_model = True
+ else:
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+ result_is_inpainting_model = True
+ else:
+ theta_0[key] = theta_func2(a, b, multiplier)
+
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ shared.state.sampling_step += 1
+
+ del theta_1
+
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
+
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
+
+ del vae_dict
+
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ if discard_weights:
+ regex = re.compile(discard_weights)
+ for key in list(theta_0):
+ if re.search(regex, key):
+ theta_0.pop(key, None)
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
+ filename += "." + checkpoint_format
+
+ output_modelname = os.path.join(ckpt_dir, filename)
+
+ shared.state.nextjob()
+ shared.state.textinfo = "Saving"
+ print(f"Saving to {output_modelname}...")
+
+ _, extension = os.path.splitext(output_modelname)
+ if extension.lower() == ".safetensors":
+ safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ else:
+ torch.save(theta_0, output_modelname)
+
+ sd_models.list_models()
+
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
+ shared.state.end()
+
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/face_restoration.py b/modules/face_restoration.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c86c6ccce338a1411f4367a0bc6e4046ad67cae
--- /dev/null
+++ b/modules/face_restoration.py
@@ -0,0 +1,19 @@
+from modules import shared
+
+
+class FaceRestoration:
+ def name(self):
+ return "None"
+
+ def restore(self, np_image):
+ return np_image
+
+
+def restore_faces(np_image):
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
+ if len(face_restorers) == 0:
+ return np_image
+
+ face_restorer = face_restorers[0]
+
+ return face_restorer.restore(np_image)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7bc7001967a0fc0b0d39a668675dd6b6199dda4
--- /dev/null
+++ b/modules/generation_parameters_copypaste.py
@@ -0,0 +1,409 @@
+import base64
+import html
+import io
+import math
+import os
+import re
+from pathlib import Path
+
+import gradio as gr
+from modules.paths import data_path
+from modules import shared, ui_tempdir, script_callbacks
+import tempfile
+from PIL import Image
+
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
+re_param = re.compile(re_param_code)
+re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
+type_of_gr_update = type(gr.update())
+
+paste_fields = {}
+registered_param_bindings = []
+
+
+class ParamBinding:
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+ self.paste_button = paste_button
+ self.tabname = tabname
+ self.source_text_component = source_text_component
+ self.source_image_component = source_image_component
+ self.source_tabname = source_tabname
+ self.override_settings_component = override_settings_component
+ self.paste_field_names = paste_field_names
+
+
+def reset():
+ paste_fields.clear()
+
+
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
+
+def image_from_url_text(filedata):
+ if filedata is None:
+ return None
+
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
+ filename = filedata["name"]
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
+
+ return Image.open(filename)
+
+ if type(filedata) == list:
+ if len(filedata) == 0:
+ return None
+
+ filedata = filedata[0]
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
+ image = Image.open(io.BytesIO(filedata))
+ return image
+
+
+def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
+
+ # backwards compatibility for existing extensions
+ import modules.ui
+ if tabname == 'txt2img':
+ modules.ui.txt2img_paste_fields = fields
+ elif tabname == 'img2img':
+ modules.ui.img2img_paste_fields = fields
+
+
+def create_buttons(tabs_list):
+ buttons = {}
+ for tab in tabs_list:
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
+ return buttons
+
+
+def bind_buttons(buttons, send_image, send_generate_info):
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
+ for tabname, button in buttons.items():
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
+
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
+
+
+def register_paste_params_button(binding: ParamBinding):
+ registered_param_bindings.append(binding)
+
+
+def connect_paste_params_buttons():
+ binding: ParamBinding
+ for binding in registered_param_bindings:
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
+ fields = paste_fields[binding.tabname]["fields"]
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if binding.source_image_component and destination_image_component:
+ if isinstance(binding.source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
+ else:
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ binding.paste_button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[binding.source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
+
+ if binding.source_text_component is not None and fields is not None:
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
+
+ if binding.source_tabname is not None and fields is not None:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
+ binding.paste_button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
+ )
+
+ binding.paste_button.click(
+ fn=None,
+ _js=f"switch_to_{binding.tabname}",
+ inputs=None,
+ outputs=None,
+ )
+
+
+def send_image_and_dimensions(x):
+ if isinstance(x, Image.Image):
+ img = x
+ else:
+ img = image_from_url_text(x)
+
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+
+ return img, w, h
+
+
+
+def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
+
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
+
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
+ """
+ hypernet_name = hypernet_name.lower()
+ if hypernet_hash is not None:
+ # Try to match the hash in the name
+ for hypernet_key in shared.hypernetworks.keys():
+ result = re_hypernet_hash.search(hypernet_key)
+ if result is not None and result[1] == hypernet_hash:
+ return hypernet_key
+ else:
+ # Fall back to a hypernet with the same name
+ for hypernet_key in shared.hypernetworks.keys():
+ if hypernet_key.lower().startswith(hypernet_name):
+ return hypernet_key
+
+ return None
+
+
+def restore_old_hires_fix_params(res):
+ """for infotexts that specify old First pass size parameter, convert it into
+ width, height, and hr scale"""
+
+ firstpass_width = res.get('First pass size-1', None)
+ firstpass_height = res.get('First pass size-2', None)
+
+ if shared.opts.use_old_hires_fix_width_height:
+ hires_width = int(res.get("Hires resize-1", 0))
+ hires_height = int(res.get("Hires resize-2", 0))
+
+ if hires_width and hires_height:
+ res['Size-1'] = hires_width
+ res['Size-2'] = hires_height
+ return
+
+ if firstpass_width is None or firstpass_height is None:
+ return
+
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+ width = int(res.get("Size-1", 512))
+ height = int(res.get("Size-2", 512))
+
+ if firstpass_width == 0 or firstpass_height == 0:
+ from modules import processing
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
+
+ res['Size-1'] = firstpass_width
+ res['Size-2'] = firstpass_height
+ res['Hires resize-1'] = width
+ res['Hires resize-2'] = height
+
+
+def parse_generation_parameters(x: str):
+ """parses generation parameters string, the one you see in text field under the picture in UI:
+```
+girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
+Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
+Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
+```
+
+ returns a dict with field values
+ """
+
+ res = {}
+
+ prompt = ""
+ negative_prompt = ""
+
+ done_with_prompt = False
+
+ *lines, lastline = x.strip().split("\n")
+ if len(re_param.findall(lastline)) < 3:
+ lines.append(lastline)
+ lastline = ''
+
+ for i, line in enumerate(lines):
+ line = line.strip()
+ if line.startswith("Negative prompt:"):
+ done_with_prompt = True
+ line = line[16:].strip()
+
+ if done_with_prompt:
+ negative_prompt += ("" if negative_prompt == "" else "\n") + line
+ else:
+ prompt += ("" if prompt == "" else "\n") + line
+
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
+
+ for k, v in re_param.findall(lastline):
+ v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
+ m = re_imagesize.match(v)
+ if m is not None:
+ res[k+"-1"] = m.group(1)
+ res[k+"-2"] = m.group(2)
+ else:
+ res[k] = v
+
+ # Missing CLIP skip means it was set to 1 (the default)
+ if "Clip skip" not in res:
+ res["Clip skip"] = "1"
+
+ hypernet = res.get("Hypernet", None)
+ if hypernet is not None:
+ res["Prompt"] += f""""""
+
+ if "Hires resize-1" not in res:
+ res["Hires resize-1"] = 0
+ res["Hires resize-2"] = 0
+
+ restore_old_hires_fix_params(res)
+
+ return res
+
+
+settings_map = {}
+
+
+
+infotext_to_setting_name_mapping = [
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
+ ('Conditional mask weight', 'inpainting_mask_weight'),
+ ('Model hash', 'sd_model_checkpoint'),
+ ('ENSD', 'eta_noise_seed_delta'),
+ ('Noise multiplier', 'initial_noise_multiplier'),
+ ('Eta', 'eta_ancestral'),
+ ('Eta DDIM', 'eta_ddim'),
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
+]
+
+
+def create_override_settings_dict(text_pairs):
+ """creates processing's override_settings parameters from gradio's multiselect
+
+ Example input:
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
+
+ Example output:
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
+ """
+
+ res = {}
+
+ params = {}
+ for pair in text_pairs:
+ k, v = pair.split(":", maxsplit=1)
+
+ params[k] = v.strip()
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ value = params.get(param_name, None)
+
+ if value is None:
+ continue
+
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
+
+ return res
+
+
+def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
+ def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(data_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
+ params = parse_generation_parameters(prompt)
+ script_callbacks.infotext_pasted_callback(prompt, params)
+ res = []
+
+ for output, key in paste_fields:
+ if callable(key):
+ v = key(params)
+ else:
+ v = params.get(key, None)
+
+ if v is None:
+ res.append(gr.update())
+ elif isinstance(v, type_of_gr_update):
+ res.append(v)
+ else:
+ try:
+ valtype = type(output.value)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
+ res.append(gr.update(value=val))
+ except Exception:
+ res.append(gr.update())
+
+ return res
+
+ if override_settings_component is not None:
+ def paste_settings(params):
+ vals = {}
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ vals[param_name] = v
+
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
+
+ button.click(
+ fn=paste_func,
+ _js=f"recalculate_prompts_{tabname}",
+ inputs=[input_comp],
+ outputs=[x[0] for x in paste_fields],
+ )
+
+
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0c5f738e086225505af9738862fde4eecfa4a9
--- /dev/null
+++ b/modules/gfpgan_model.py
@@ -0,0 +1,116 @@
+import os
+import sys
+import traceback
+
+import facexlib
+import gfpgan
+
+import modules.face_restoration
+from modules import paths, shared, devices, modelloader
+
+model_dir = "GFPGAN"
+user_path = None
+model_path = os.path.join(paths.models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+have_gfpgan = False
+loaded_gfpgan_model = None
+
+
+def gfpgann():
+ global loaded_gfpgan_model
+ global model_path
+ if loaded_gfpgan_model is not None:
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
+ return loaded_gfpgan_model
+
+ if gfpgan_constructor is None:
+ return None
+
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
+ latest_file = max(models, key=os.path.getctime)
+ model_file = latest_file
+ else:
+ print("Unable to load gfpgan model!")
+ return None
+ if hasattr(facexlib.detection.retinaface, 'device'):
+ facexlib.detection.retinaface.device = devices.device_gfpgan
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
+ loaded_gfpgan_model = model
+
+ return model
+
+
+def send_model_to(model, device):
+ model.gfpgan.to(device)
+ model.face_helper.face_det.to(device)
+ model.face_helper.face_parse.to(device)
+
+
+def gfpgan_fix_faces(np_image):
+ model = gfpgann()
+ if model is None:
+ return np_image
+
+ send_model_to(model, devices.device_gfpgan)
+
+ np_image_bgr = np_image[:, :, ::-1]
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ np_image = gfpgan_output_bgr[:, :, ::-1]
+
+ model.face_helper.clean_all()
+
+ if shared.opts.face_restoration_unload:
+ send_model_to(model, devices.cpu)
+
+ return np_image
+
+
+gfpgan_constructor = None
+
+
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
+ try:
+ from gfpgan import GFPGANer
+ from facexlib import detection, parsing
+ global user_path
+ global have_gfpgan
+ global gfpgan_constructor
+
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+ def my_load_file_from_url(**kwargs):
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+ def facex_load_file_from_url(**kwargs):
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ def facex_load_file_from_url2(**kwargs):
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+ user_path = dirname
+ have_gfpgan = True
+ gfpgan_constructor = GFPGANer
+
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def restore(self, np_image):
+ return gfpgan_fix_faces(np_image)
+
+ shared.face_restorers.append(FaceRestorerGFPGAN())
+ except Exception:
+ print("Error setting up GFPGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/hashes.py b/modules/hashes.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abf99c304b23bf8e3e394e07c2209d4130afef
--- /dev/null
+++ b/modules/hashes.py
@@ -0,0 +1,91 @@
+import hashlib
+import json
+import os.path
+
+import filelock
+
+from modules import shared
+from modules.paths import data_path
+
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256_from_cache(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
+
+ if shared.cmd_opts.no_hashing:
+ return None
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": os.path.getmtime(filename),
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
diff --git a/modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc b/modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cbb06155ef24639dc08cfb51511685f497d590d
Binary files /dev/null and b/modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc differ
diff --git a/modules/hypernetworks/__pycache__/ui.cpython-310.pyc b/modules/hypernetworks/__pycache__/ui.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78d37fa77c893fc5c6eb2c782d94de55601572b4
Binary files /dev/null and b/modules/hypernetworks/__pycache__/ui.cpython-310.pyc differ
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d6dcc2281daeb30f7f0f380540dd0a650248af
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,811 @@
+import csv
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import inspect
+
+import modules.textual_inversion.dataset
+import torch
+import tqdm
+from einops import rearrange, repeat
+from ldm.util import default
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules.textual_inversion import textual_inversion, logging
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
+
+from collections import defaultdict, deque
+from statistics import stdev, mean
+
+
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
+class HypernetworkModule(torch.nn.Module):
+ activation_dict = {
+ "linear": torch.nn.Identity,
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
+ }
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
+ super().__init__()
+
+ self.multiplier = 1.0
+
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func except last layer
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+ # Everything should be now parsed into dropout structure, and applied here.
+ # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+ # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
+
+ self.linear = torch.nn.Sequential(*linears)
+
+ if state_dict is not None:
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
+ else:
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+ normal_(w, mean=0.0, std=0.01)
+ normal_(b, mean=0.0, std=0)
+ elif weight_init == 'XavierUniform':
+ xavier_uniform_(w)
+ zeros_(b)
+ elif weight_init == 'XavierNormal':
+ xavier_normal_(w)
+ zeros_(b)
+ elif weight_init == 'KaimingUniform':
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ elif weight_init == 'KaimingNormal':
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ else:
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
+ self.to(devices.device)
+
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
+
+ def forward(self, x):
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
+
+ def trainables(self):
+ layer_structure = []
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
+ return layer_structure
+
+
+#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+ if layer_structure is None:
+ layer_structure = [1, 2, 1]
+ if not use_dropout:
+ return [0] * len(layer_structure)
+ dropout_values = [0]
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
+ if last_layer_dropout:
+ dropout_values.append(0.3)
+ else:
+ dropout_values.append(0)
+ dropout_values.append(0)
+ return dropout_values
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+ self.layer_structure = layer_structure
+ self.activation_func = activation_func
+ self.weight_init = weight_init
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
+ self.activate_output = activate_output
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+ self.dropout_structure = kwargs.get('dropout_structure', None)
+ if self.dropout_structure is None:
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
+ self.optional_info = None
+
+ for size in enable_sizes or []:
+ self.layers[size] = (
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
+ )
+ self.eval()
+
+ def weights(self):
+ res = []
+ for k, layers in self.layers.items():
+ for layer in layers:
+ res += layer.parameters()
+ return res
+
+ def train(self, mode=True):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train(mode=mode)
+ for param in layer.parameters():
+ param.requires_grad = mode
+
+ def to(self, device):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
+ def eval(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def save(self, filename):
+ state_dict = {}
+ optimizer_saved_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['weight_initialization'] = self.weight_init
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+ state_dict['activate_output'] = self.activate_output
+ state_dict['use_dropout'] = self.use_dropout
+ state_dict['dropout_structure'] = self.dropout_structure
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
+
+ if self.optimizer_name is not None:
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
+ torch.save(state_dict, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = self.shorthash()
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.optional_info = state_dict.get('optional_info', None)
+ self.activation_func = state_dict.get('activation_func', None)
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.dropout_structure = state_dict.get('dropout_structure', None)
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
+ self.activate_output = state_dict.get('activate_output', True)
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+ # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
+ if self.dropout_structure is None:
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
+
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
+ if self.optimizer_state_dict:
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
+ else:
+ self.optimizer_name = "AdamW"
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
+ )
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+ self.eval()
+
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10] if sha256 else None
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ # Prevent a hypothetical "None.pt" from being listed.
+ if name != "None":
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
+
+ if path is None:
+ return None
+
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
+
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
+
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
+
+
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context_k, context_v
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+
+def statistics(data):
+ if len(data) < 2:
+ std = 0
+ else:
+ std = stdev(data)
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+ recent_data = data[-32:]
+ if len(recent_data) < 2:
+ std = 0
+ else:
+ std = stdev(recent_data)
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(list(loss_info[key]))
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+ assert name, "Name cannot be empty!"
+
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+ else:
+ dropout_structure = [0] * len(layer_structure)
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ weight_init=weight_init,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ dropout_structure=dropout_structure
+ )
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+ from modules import images
+
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
+
+ shared.state.job = "train-hypernetwork"
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ checkpoint = sd_models.select_checkpoint()
+
+ initial_step = hypernetwork.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+
+ pin_memory = shared.opts.pin_memory
+
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
+
+ if shared.opts.save_training_settings_to_txt:
+ saved_params = dict(
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+ )
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
+
+ latent_sampling_method = ds.latent_sampling_method
+
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ weights = hypernetwork.weights()
+ hypernetwork.train()
+
+ # Here we use optimizer from saved HN, or we can specify as UI option.
+ if hypernetwork.optimizer_name in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ optimizer_name = hypernetwork.optimizer_name
+ else:
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
+
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
+
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # n steps = batch_size * gradient_step * n image processed
+ steps_per_epoch = len(ds) // batch_size // gradient_step
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
+ # size = len(ds.indexes)
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
+ # losses = torch.zeros((size,))
+ # previous_mean_losses = [0]
+ # previous_mean_loss = 0
+ # print("Mean loss of {} elements".format(size))
+
+ steps_without_grad = 0
+
+ last_saved_file = ""
+ last_saved_image = ""
+ forced_filename = ""
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ sd_hijack_checkpoint.add()
+
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ if clip_grad:
+ clip_grad_sched.step(hypernetwork.step)
+
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
+ if tag_drop_out != 0 or shuffle_tags:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ else:
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
+ del x
+ del c
+
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
+
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+ loss_logging.append(_loss_step)
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
+
+ scaler.step(optimizer)
+ scaler.update()
+ hypernetwork.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = hypernetwork.step + 1
+
+ epoch_num = hypernetwork.step // steps_per_epoch
+ epoch_step = hypernetwork.step % steps_per_epoch
+
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
+
+
+ if shared.opts.training_enable_tensorboard:
+ epoch_num = hypernetwork.step // len(ds)
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+ mean_loss = sum(loss_logging) / len(loss_logging)
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ hypernetwork.eval()
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = None
+ if torch.cuda.is_available():
+ cuda_rng_state = torch.cuda.get_rng_state_all()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ p.disable_extra_networks = True
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+ torch.set_rng_state(rng_state)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state_all(cuda_rng_state)
+ hypernetwork.train()
+ if image is not None:
+ shared.state.assign_current_image(image)
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
+ f"Validation at epoch {epoch_num}", image,
+ hypernetwork.step)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+
+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved hypernetwork: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ pbar.leave = False
+ pbar.close()
+ hypernetwork.eval()
+ #report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
+
+ return hypernetwork, filename
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2fd77cc76a24d0e7932c6b1fb26efcb18edcc5
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,40 @@
+import html
+import os
+import re
+
+import gradio as gr
+import modules.hypernetworks.hypernetwork
+from modules import devices, sd_hijack, shared
+
+not_available = ["hardswish", "multiheadattention"]
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+
+
+def train_hypernetwork(*args):
+ shared.loaded_hypernetworks = []
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/images.py b/modules/images.py
new file mode 100644
index 0000000000000000000000000000000000000000..97d17de8cada6573f4100353d82db108bb373d32
--- /dev/null
+++ b/modules/images.py
@@ -0,0 +1,674 @@
+import datetime
+import sys
+import traceback
+
+import pytz
+import io
+import math
+import os
+from collections import namedtuple
+import re
+
+import numpy as np
+import piexif
+import piexif.helper
+from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from fonts.ttf import Roboto
+import string
+import json
+import hashlib
+
+from modules import sd_samplers, shared, script_callbacks, errors
+from modules.shared import opts, cmd_opts
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+def image_grid(imgs, batch_size=1, rows=None):
+ if rows is None:
+ if opts.n_rows > 0:
+ rows = opts.n_rows
+ elif opts.n_rows == 0:
+ rows = batch_size
+ elif opts.grid_prevent_empty_spots:
+ rows = math.floor(math.sqrt(len(imgs)))
+ while len(imgs) % rows != 0:
+ rows -= 1
+ else:
+ rows = math.sqrt(len(imgs))
+ rows = round(rows)
+ if rows > len(imgs):
+ rows = len(imgs)
+
+ cols = math.ceil(len(imgs) / rows)
+
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
+
+ for i, img in enumerate(params.imgs):
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
+
+ return grid
+
+
+Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+ w = image.width
+ h = image.height
+
+ non_overlap_width = tile_w - overlap
+ non_overlap_height = tile_h - overlap
+
+ cols = math.ceil((w - overlap) / non_overlap_width)
+ rows = math.ceil((h - overlap) / non_overlap_height)
+
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
+
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
+ for row in range(rows):
+ row_images = []
+
+ y = int(row * dy)
+
+ if y + tile_h >= h:
+ y = h - tile_h
+
+ for col in range(cols):
+ x = int(col * dx)
+
+ if x + tile_w >= w:
+ x = w - tile_w
+
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+ row_images.append([x, tile_w, tile])
+
+ grid.tiles.append([y, tile_h, row_images])
+
+ return grid
+
+
+def combine_grid(grid):
+ def make_mask_image(r):
+ r = r * 255 / grid.overlap
+ r = r.astype(np.uint8)
+ return Image.fromarray(r, 'L')
+
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+ for y, h, row in grid.tiles:
+ combined_row = Image.new("RGB", (grid.image_w, h))
+ for x, w, tile in row:
+ if x == 0:
+ combined_row.paste(tile, (0, 0))
+ continue
+
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+ if y == 0:
+ combined_image.paste(combined_row, (0, 0))
+ continue
+
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
+
+ return combined_image
+
+
+class GridAnnotation:
+ def __init__(self, text='', is_active=True):
+ self.text = text
+ self.is_active = is_active
+ self.size = None
+
+
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+ def wrap(drawing, text, font, line_length):
+ lines = ['']
+ for word in text.split():
+ line = f'{lines[-1]} {word}'.strip()
+ if drawing.textlength(line, font=font) <= line_length:
+ lines[-1] = line
+ else:
+ lines.append(word)
+ return lines
+
+ def get_font(fontsize):
+ try:
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
+ except Exception:
+ return ImageFont.truetype(Roboto, fontsize)
+
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
+ for i, line in enumerate(lines):
+ fnt = initial_fnt
+ fontsize = initial_fontsize
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
+ fontsize -= 1
+ fnt = get_font(fontsize)
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
+
+ if not line.is_active:
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
+
+ draw_y += line.size[1] + line_spacing
+
+ fontsize = (width + height) // 25
+ line_spacing = fontsize // 2
+
+ fnt = get_font(fontsize)
+
+ color_active = (0, 0, 0)
+ color_inactive = (153, 153, 153)
+
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
+
+ cols = im.width // width
+ rows = im.height // height
+
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
+
+ calc_img = Image.new("RGB", (1, 1), "white")
+ calc_d = ImageDraw.Draw(calc_img)
+
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
+ items = [] + texts
+ texts.clear()
+
+ for line in items:
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
+
+ for line in texts:
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+ line.allowed_width = allowed_width
+
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
+
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
+
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+
+ for row in range(rows):
+ for col in range(cols):
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
+
+ d = ImageDraw.Draw(result)
+
+ for col in range(cols):
+ x = pad_left + (width + margin) * col + width / 2
+ y = pad_top / 2 - hor_text_heights[col] / 2
+
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
+
+ for row in range(rows):
+ x = pad_left / 2
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
+
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
+
+ return result
+
+
+def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
+ prompts = all_prompts[1:]
+ boundary = math.ceil(len(prompts) / 2)
+
+ prompts_horiz = prompts[:boundary]
+ prompts_vert = prompts[boundary:]
+
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
+
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
+
+
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+ """
+ Resizes an image with the specified resize_mode, width, and height.
+
+ Args:
+ resize_mode: The mode to use when resizing the image.
+ 0: Resize the image to the specified width and height.
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+ im: The image to resize.
+ width: The width to resize the image to.
+ height: The height to resize the image to.
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+ """
+
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
+ def resize(im, w, h):
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
+ return im.resize((w, h), resample=LANCZOS)
+
+ scale = max(w / im.width, h / im.height)
+
+ if scale > 1.0:
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
+
+ upscaler = upscalers[0]
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
+
+ if im.width != w or im.height != h:
+ im = im.resize((w, h), resample=LANCZOS)
+
+ return im
+
+ if resize_mode == 0:
+ res = resize(im, width, height)
+
+ elif resize_mode == 1:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+ resized = resize(im, src_w, src_h)
+ res = Image.new("RGB", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+ resized = resize(im, src_w, src_h)
+ res = Image.new("RGB", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+ return res
+
+
+invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_prefix = ' '
+invalid_filename_postfix = ' .'
+re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
+max_filename_part_length = 128
+
+
+def sanitize_filename_part(text, replace_spaces=True):
+ if text is None:
+ return None
+
+ if replace_spaces:
+ text = text.replace(' ', '_')
+
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
+ text = text.rstrip(invalid_filename_postfix)
+ return text
+
+
+class FilenameGenerator:
+ replacements = {
+ 'seed': lambda self: self.seed if self.seed is not None else '',
+ 'steps': lambda self: self.p and self.p.steps,
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
+ 'width': lambda self: self.image.width,
+ 'height': lambda self: self.image.height,
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime