Spaces:
Running
Running
add peft, set_adapter
Browse files- model.py +2 -2
- pipeline.py +1 -1
- requirements.txt +1 -0
- unet.py +5 -3
model.py
CHANGED
|
@@ -44,7 +44,7 @@ class Model:
|
|
| 44 |
and self.pipe is not None
|
| 45 |
):
|
| 46 |
unet: UNet2DConditionModelEx = self.pipe.unet
|
| 47 |
-
unet.
|
| 48 |
return self.pipe
|
| 49 |
unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
|
| 50 |
base_model_id, subfolder="unet", torch_dtype=torch.float16
|
|
@@ -82,7 +82,7 @@ class Model:
|
|
| 82 |
if task_name == self.task_name:
|
| 83 |
return
|
| 84 |
unet: UNet2DConditionModelEx = self.pipe.unet
|
| 85 |
-
unet.
|
| 86 |
self.task_name = task_name
|
| 87 |
|
| 88 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
|
|
|
| 44 |
and self.pipe is not None
|
| 45 |
):
|
| 46 |
unet: UNet2DConditionModelEx = self.pipe.unet
|
| 47 |
+
unet.set_adapter(task_name)
|
| 48 |
return self.pipe
|
| 49 |
unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
|
| 50 |
base_model_id, subfolder="unet", torch_dtype=torch.float16
|
|
|
|
| 82 |
if task_name == self.task_name:
|
| 83 |
return
|
| 84 |
unet: UNet2DConditionModelEx = self.pipe.unet
|
| 85 |
+
unet.set_adapter(task_name)
|
| 86 |
self.task_name = task_name
|
| 87 |
|
| 88 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
pipeline.py
CHANGED
|
@@ -949,7 +949,7 @@ class StableDiffusionControlLoraV3Pipeline(
|
|
| 949 |
if adapter_name_ori is not None:
|
| 950 |
break
|
| 951 |
|
| 952 |
-
unet.
|
| 953 |
|
| 954 |
@torch.no_grad()
|
| 955 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
|
|
|
| 949 |
if adapter_name_ori is not None:
|
| 950 |
break
|
| 951 |
|
| 952 |
+
unet.activate_extra_condition_adapters()
|
| 953 |
|
| 954 |
@torch.no_grad()
|
| 955 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@ gradio>=4.26.0
|
|
| 6 |
huggingface-hub>=0.16.4
|
| 7 |
mediapipe>=0.10.1
|
| 8 |
opencv-python-headless>=4.8.0.74
|
|
|
|
| 9 |
safetensors>=0.3.1
|
| 10 |
torch>=2.0.1
|
| 11 |
torchvision>=0.15.2
|
|
|
|
| 6 |
huggingface-hub>=0.16.4
|
| 7 |
mediapipe>=0.10.1
|
| 8 |
opencv-python-headless>=4.8.0.74
|
| 9 |
+
peft>=0.11.1
|
| 10 |
safetensors>=0.3.1
|
| 11 |
torch>=2.0.1
|
| 12 |
torchvision>=0.15.2
|
unet.py
CHANGED
|
@@ -145,11 +145,13 @@ class UNet2DConditionModelEx(UNet2DConditionModel):
|
|
| 145 |
|
| 146 |
return self
|
| 147 |
|
| 148 |
-
def
|
| 149 |
lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
|
| 150 |
for lora_layer in lora_layers:
|
| 151 |
-
|
| 152 |
-
lora_layer.
|
|
|
|
|
|
|
| 153 |
|
| 154 |
def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
|
| 155 |
if isinstance(scale, float):
|
|
|
|
| 145 |
|
| 146 |
return self
|
| 147 |
|
| 148 |
+
def activate_extra_condition_adapters(self):
|
| 149 |
lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
|
| 150 |
for lora_layer in lora_layers:
|
| 151 |
+
adapter_names = [k for k in lora_layer.scaling.keys() if k in self.config.extra_condition_names]
|
| 152 |
+
adapter_names += lora_layer.active_adapters
|
| 153 |
+
adapter_names = list(set(adapter_names))
|
| 154 |
+
lora_layer.set_adapter(adapter_names)
|
| 155 |
|
| 156 |
def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
|
| 157 |
if isinstance(scale, float):
|