Spaces: Running on Zero
| import torch | |
| from modules.BlackForest import Flux | |
| from modules.Utilities import util | |
| from modules.Model import ModelBase | |
| from modules.SD15 import SDClip, SDToken | |
| from modules.Utilities import Latent | |
| from modules.clip import Clip | |
class sm_SD15(ModelBase.BASE):
    """#### Class representing the SD15 model.

    Declares the UNet configuration, the latent format, and the CLIP
    state-dict handling / CLIP target used for SD1.5 checkpoints.

    #### Args:
        - `ModelBase.BASE` (ModelBase.BASE): The base model class.
    """

    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config: dict = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    # The class object itself (not an instance) is stored here, so the
    # annotation is `type[...]`; kept as a string so it is never evaluated.
    latent_format: "type[Latent.SD15]" = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """#### Process the state dictionary for the CLIP model.

        Steps:
        1. Keys under `cond_stage_model.transformer.` that lack the
           `text_model.` level get it inserted (renamed in place).
        2. A float32 `position_ids` tensor is rounded.
        3. The `cond_stage_model.` prefix is replaced with `clip_l.` via
           `util.state_dict_prefix_replace(..., filter_keys=True)`.

        #### Args:
            - `state_dict` (dict): The state dictionary. It is mutated in
              place by step 1 (and step 2, if applicable).

        #### Returns:
            - `dict`: The processed state dictionary returned by
              `util.state_dict_prefix_replace`.
        """
        legacy_prefix = "cond_stage_model.transformer."
        text_model_prefix = "cond_stage_model.transformer.text_model."
        position_ids_key = text_model_prefix + "embeddings.position_ids"

        # Iterate over a snapshot of the keys because the dict is modified in
        # the loop. Only the leading prefix is rewritten: `str.replace` would
        # also rewrite any later occurrence of the substring inside the key.
        for key in list(state_dict.keys()):
            if key.startswith(legacy_prefix) and not key.startswith(text_model_prefix):
                new_key = text_model_prefix + key[len(legacy_prefix):]
                state_dict[new_key] = state_dict.pop(key)

        # Position ids are index values; a float32 copy is rounded so that
        # values stored slightly off an integer are snapped back.
        if position_ids_key in state_dict:
            ids = state_dict[position_ids_key]
            if ids.dtype == torch.float32:
                state_dict[position_ids_key] = ids.round()

        # NOTE(review): filter_keys=True presumably drops keys that do not match
        # the prefix — confirm against util.state_dict_prefix_replace.
        replace_prefix = {"cond_stage_model.": "clip_l."}
        state_dict = util.state_dict_prefix_replace(
            state_dict, replace_prefix, filter_keys=True
        )
        return state_dict

    def clip_target(self) -> Clip.ClipTarget:
        """#### Get the target CLIP model.

        #### Returns:
            - `Clip.ClipTarget`: A target pairing `SDToken.SD1Tokenizer` with
              `SDClip.SD1ClipModel`.
        """
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)
# Model configuration classes exposed by this module.
# NOTE(review): if a caller picks the first matching entry, order matters —
# confirm with the consumer of `models`.
models = [
    sm_SD15,
    Flux.Flux,
]