Commit 1d5bb62 · 0 parent(s) · init space

Files changed:
- .gitattributes +1 -0
- app.py +231 -0
- examples/birdhouse.glb +3 -0
- examples/mario.glb +3 -0
- utils/__init__.py +0 -0
- utils/controlnet_union.py +957 -0
- utils/image_generation.py +299 -0
- utils/mesh_utils.py +500 -0
- utils/pipeline_controlnet_union_sd_xl.py +1397 -0
- utils/pipeline_stable_diffusion_switcher.py +1240 -0
- utils/rasterize.py +166 -0
- utils/render_utils.py +352 -0
- utils/texture_generation.py +309 -0
- wan/__init__.py +0 -0
- wan/pipeline_wan_t2tex_extra.py +366 -0
- wan/wan_t2tex_transformer_3d_extra.py +634 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+*.glb filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,231 @@
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+from utils.image_generation import generate_image_condition
+from utils.mesh_utils import Mesh
+from utils.render_utils import render_views
+from utils.texture_generation import generate_texture
+
+import gradio as gr
+from gradio_litmodel3d import LitModel3D
+
+EXAMPLES = [
+    ["examples/birdhouse.glb", True, False, False, False, 42, "First View", "SDXL", False, "A rustic birdhouse featuring a snow-covered roof, wood textures, and two decorative cardinal birds. It has a circular entryway and conveys a winter-themed aesthetic."],
+    ["examples/mario.glb", False, False, False, True, 6666, "Third View", "FLUX", True, "Mario, a cartoon character wearing a red cap and blue overalls, with brown hair and a mustache, and white gloves, in a fighting pose. The clothes he wears are not in a reflection mode."],
+]
+
+def tensor_to_pil(tensor, mask=None, normalize: bool = True):
+    """
+    Convert tensor to PIL Image.
+    :param tensor: torch.Tensor, shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
+    :param mask: torch.Tensor, shape same as tensor, effective when C=3
+    :return: PIL.Image
+    """
+    # Move to cpu
+    tensor = tensor.detach()
+    if tensor.is_cuda:
+        tensor = tensor.cpu()
+    if mask is not None and mask.is_cuda:
+        mask = mask.cpu()
+
+    # Convert to float32
+    tensor = tensor.float()
+    if mask is not None:
+        mask = mask.float()
+
+    if normalize:
+        tensor = (tensor + 1.0) / 2.0
+        tensor = torch.clamp(tensor, 0.0, 1.0)
+    if mask is not None:
+        if mask.shape[-1] not in [1, 3]:
+            mask = mask.unsqueeze(-1)
+        tensor = torch.cat([tensor, mask], dim=-1)
+
+    shape = tensor.shape
+    # 4D: (Nv, H, W, C) or (Nv, C, H, W)
+    if len(shape) == 4:
+        Nv = shape[0]
+        if shape[-1] in [3, 4]:  # (Nv, H, W, C)
+            tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
+        else:  # (Nv, C, H, W)
+            tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
+    # 3D: (H, W, C) or (C, H, W)
+    elif len(shape) == 3:
+        if shape[-1] in [3, 4]:  # (H, W, C)
+            tensor = rearrange(tensor, 'h w c -> h w c')
+        else:  # (C, H, W)
+            tensor = rearrange(tensor, 'c h w -> h w c')
+    else:
+        raise ValueError(f"Unsupported tensor shape: {shape}")
+
+    # Convert to numpy
+    np_img = (tensor.numpy() * 255).round().astype(np.uint8)
+
+    # Create PIL Image
+    if np_img.shape[2] == 3:
+        return Image.fromarray(np_img, mode="RGB")
+    elif np_img.shape[2] == 4:
+        return Image.fromarray(np_img, mode="RGBA")
+    else:
+        raise ValueError("Only support 3 or 4 channel images.")
+
+if __name__ == '__main__':
+    with gr.Blocks() as demo:
+        gr.Markdown("# 🎨 SeqTex: Generate Mesh Textures in Video Sequence")
+
+        gr.Markdown("""
+        ## 🚀 Welcome to SeqTex!
+        **SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes using image prompts (here we use image generator to get them from textual prompts).
+
+        Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
+        """)
+
+        gr.Markdown("---")
+
+        gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
+        gr.Markdown("""
+        **📋 How to prepare your 3D mesh:**
+        - Upload your 3D mesh in **.obj** or **.glb** format
+        - **💡 Pro Tip**:
+            - For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
+            - Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
+        - **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
+        """)
+        position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix = gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State()
+
+        # fixed_texture_map = Image.open("image.webp").convert("RGB")
+        # Step 1
+        with gr.Row():
+            with gr.Column():
+                mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
+                # uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
+
+                gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
+                y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
+                y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
+                z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
+                upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
+
+            with gr.Column():
+                step1_button = gr.Button("🔄 Process Mesh & Generate Views", variant="primary")
+                step1_progress = gr.Textbox(label="📊 Processing Status", interactive=False)
+                model_input = gr.Model3D(label="📐 Processed 3D Model", height=500)
+
+        with gr.Row(equal_height=True):
+            rgb_views = gr.Image(label="📷 Generated Views (Front, Back, Left, Right)", type="pil", scale=3)
+            position_map = gr.Image(label="🗺️ Position Map", type="pil", scale=1)
+            normal_map = gr.Image(label="🧭 Normal Map", type="pil", scale=1)
+
+        step1_button.click(
+            Mesh.process,
+            inputs=[mesh_upload, gr.State("xAtlas"), y2z, y2x, z2x, upside_down],
+            outputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix, step1_progress]
+        ).then(
+            tensor_to_pil,
+            inputs=[normal_images_tensor, mask_images_tensor],
+            outputs=[rgb_views]
+        ).then(
+            tensor_to_pil,
+            inputs=[position_map_tensor],
+            outputs=[position_map]
+        ).then(
+            tensor_to_pil,
+            inputs=[normal_map_tensor],
+            outputs=[normal_map]
+        ).then(
+            Mesh.export,
+            inputs=[mesh],
+            outputs=[model_input]
+        )
+
+        # Step 2
+        gr.Markdown("---")
+        gr.Markdown("## 👁️ Step 2: Select View & Generate Image Condition")
+        gr.Markdown("""
+        **📋 How to generate image condition:**
+        - Your mesh will be rendered from **four viewpoints** (front, back, left, right)
+        - Choose **one view** as your image condition
+        - Enter a **descriptive text prompt** for the desired texture
+        - Select your preferred AI model:
+            - <span style="color:#27ae60; font-weight:bold;">🎯 SDXL</span>: Fast generation with depth + normal control, better details
+            - <span style="color:#3498db; font-weight:bold;">⚡ FLUX</span>: High-quality generation with depth control (slower due to CPU offloading). Better work with **Edge Refinement**
+        """)
+        with gr.Row():
+            with gr.Column():
+                img_condition_seed = gr.Number(label="🎲 Random Seed", minimum=0, maximum=9999, step=1, value=42, info="Change for different results")
+                selected_view = gr.Radio(["First View", "Second View", "Third View", "Fourth View"], label="📐 Camera View", value="First View", info="Choose which viewpoint to use as reference")
+                with gr.Row():
+                    model_choice = gr.Radio(["SDXL", "FLUX"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing")
+                    edge_refinement = gr.Checkbox(label="✨ Edge Refinement", value=True, info="Smooth boundary artifacts (recommended for cleaner results)")
+                text_prompt = gr.Textbox(label="💬 Texture Description", placeholder="Describe the desired texture appearance (e.g., 'rustic wooden surface with weathered paint')", lines=2)
+                step2_button = gr.Button("🎯 Generate Image Condition", variant="primary")
+                step2_progress = gr.Textbox(label="📊 Generation Status", interactive=False)
+
+            with gr.Column():
+                condition_image = gr.Image(label="🖼️ Generated Image Condition", type="pil")  # , interactive=False
+
+        step2_button.click(
+            generate_image_condition,
+            inputs=[position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, text_prompt, selected_view, img_condition_seed, model_choice, edge_refinement],
+            outputs=[condition_image, step2_progress],
+            concurrency_id="gpu_intensive"
+        )
+
+        # Step 3
+        gr.Markdown("---")
+        gr.Markdown("## 🎨 Step 3: Generate Final Texture")
+        gr.Markdown("""
+        **📋 How to generate final texture:**
+        - The **SeqTex pipeline** will create a complete texture map for your model
+        - View the results from multiple angles and download your textured 3D model (the viewport is a little bit dark)
+        """)
+        texture_map_tensor, mv_out_tensor = gr.State(), gr.State()
+        with gr.Row():
+            with gr.Column(scale=1):
+                step3_button = gr.Button("🎨 Generate Final Texture", variant="primary")
+                step3_progress = gr.Textbox(label="📊 Texture Generation Status", interactive=False)
+                texture_map = gr.Image(label="🏆 Generated Texture Map", interactive=False)
+            with gr.Column(scale=2):
+                rendered_imgs = gr.Image(label="🖼️ Final Rendered Views")
+                mv_branch_imgs = gr.Image(label="🖼️ SeqTex Direct Output")
+            with gr.Column(scale=1.5):
+                # model_display = gr.Model3D(label="🏆 Final Textured Model", height=500)
+                model_display = LitModel3D(label="Model with Texture",
+                                           exposure=30.0,
+                                           height=500)
+
+        step3_button.click(
+            generate_texture,
+            inputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, condition_image, text_prompt, selected_view],
+            outputs=[texture_map_tensor, mv_out_tensor, step3_progress],
+            concurrency_id="gpu_intensive"
+        ).then(
+            tensor_to_pil,
+            inputs=[texture_map_tensor, gr.State(None), gr.State(False)],
+            outputs=[texture_map]
+        ).then(
+            tensor_to_pil,
+            inputs=[mv_out_tensor, gr.State(None), gr.State(False)],
+            outputs=[mv_branch_imgs]
+        ).then(
+            render_views,
+            inputs=[mesh, texture_map_tensor, mvp_matrix],
+            outputs=[rendered_imgs]
+        ).then(
+            Mesh.export,
+            inputs=[mesh, gr.State(None), texture_map],
+            outputs=[model_display]
+        )
+
+        # Add example inputs for user convenience
+        gr.Markdown("---")
+        gr.Markdown("## 🚀 Try Our Examples")
+        gr.Markdown("**Quick Start**: Click on any example below to see SeqTex in action with pre-configured settings!")
+        gr.Examples(
+            examples=EXAMPLES,
+            inputs=[mesh_upload, y2z, y2x, z2x, upside_down, img_condition_seed, selected_view, model_choice, edge_refinement, text_prompt],
+            cache_examples=False
+        )
+
+    demo.launch(server_name="0.0.0.0", server_port=52424)
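The only piece of `app.py` that is not Gradio wiring is `tensor_to_pil`, which tiles the `Nv` rendered views into one wide image and can append a mask as an alpha channel. Below is a minimal standalone sketch of how it behaves, using synthetic tensors only; it assumes the Space's dependencies are installed so that `app.py` imports cleanly.

```python
import torch

from app import tensor_to_pil  # app.py guards its UI behind `if __name__ == '__main__'`

# Four fake views in (Nv, H, W, C) layout, values in [-1, 1] as a renderer might emit.
views = torch.rand(4, 64, 64, 3) * 2.0 - 1.0
mask = torch.ones(4, 64, 64)  # validity mask; tensor_to_pil appends it as an alpha channel

montage = tensor_to_pil(views, mask)  # RGBA image, width = 4 * 64, height = 64
montage.save("views_montage.png")

# Texture maps are already in [0, 1], so the app calls it with normalize=False:
texture = torch.rand(512, 512, 3)
tensor_to_pil(texture, normalize=False).save("texture_preview.png")
```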
examples/birdhouse.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30a006774b35531831aaf4ba0dd1c7b8a5b5b58433af17ebc52c816cfbd654b9
+size 10043504
examples/mario.glb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbe06e0ad2fc52811ba343dcaeccacb0b9cee1705b6f33bcd222d20de770b80c
|
| 3 |
+
size 1970408
|
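Both example meshes are stored via Git LFS, so the two diffs above show only the pointer metadata (spec version, object hash, byte size) rather than the binary `.glb` payload. For illustration, a small hypothetical helper (not part of the Space) that reads those three fields back out of a pointer file:

```python
def parse_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into its fields: version, oid, size."""
    fields = {}
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

# Only meaningful on the raw pointer (i.e. before `git lfs` replaces it with the real mesh).
info = parse_lfs_pointer("examples/mario.glb")
print(info["oid"], int(info["size"]))  # sha256:cbe0... 1970408
```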
utils/__init__.py
ADDED
File without changes
utils/controlnet_union.py
ADDED
@@ -0,0 +1,957 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    get_down_block,
+)
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+from collections import OrderedDict
+
+# Transformer Block
+# Used to exchange info between different conditions and input image
+# With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x)
+        return ret.type(orig_type)
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
+                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+#-----------------------------------------------------------------------------------------------------
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+    """
+    The output of [`ControlNetModel`].
+
+    Args:
+        down_block_res_samples (`tuple[torch.Tensor]`):
+            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_down_block_re_sample (`torch.Tensor`):
+            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+            Output can be used to condition the original UNet's middle block activation.
+    """
+
+    down_block_res_samples: Tuple[torch.Tensor]
+    mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+    """
+    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+    model) to encode image-space conditions ... into feature maps ..."
+    """
+
+    # original setting is (16, 32, 96, 256)
+    def __init__(
+        self,
+        conditioning_embedding_channels: int,
+        conditioning_channels: int = 3,
+        block_out_channels: Tuple[int] = (48, 96, 192, 384),
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+        self.blocks = nn.ModuleList([])
+
+        for i in range(len(block_out_channels) - 1):
+            channel_in = block_out_channels[i]
+            channel_out = block_out_channels[i + 1]
+            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+        self.conv_out = zero_module(
+            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+        )
+
+    def forward(self, conditioning):
+        embedding = self.conv_in(conditioning)
+        embedding = F.silu(embedding)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+            embedding = F.silu(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        return embedding
+
+
+class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+    """
+    A ControlNet model.
+
+    Args:
+        in_channels (`int`, defaults to 4):
+            The number of channels in the input sample.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, defaults to 0):
+            The frequency shift to apply to the time embedding.
+        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, defaults to 2):
+            The number of layers per block.
+        downsample_padding (`int`, defaults to 1):
+            The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, defaults to 1):
+            The scale factor to use for the mid block.
+        act_fn (`str`, defaults to "silu"):
+            The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+            in post-processing.
+        norm_eps (`float`, defaults to 1e-5):
+            The epsilon to use for the normalization.
+        cross_attention_dim (`int`, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+            The dimension of the attention heads.
+        use_linear_projection (`bool`, defaults to `False`):
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
+        num_class_embeds (`int`, *optional*, defaults to 0):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        upcast_attention (`bool`, defaults to `False`):
+        resnet_time_scale_shift (`str`, defaults to `"default"`):
+            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+            The tuple of output channel for each block in the `conditioning_embedding` layer.
+        global_pool_conditions (`bool`, defaults to `False`):
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 4,
+        conditioning_channels: int = 3,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: int = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        controlnet_conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        global_pool_conditions: bool = False,
+        addition_embed_type_num_heads=64,
+        num_control_type = 6,
+    ):
+        super().__init__()
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
+        # Check inputs
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        # input
+        conv_in_kernel = 3
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        time_embed_dim = block_out_channels[0] * 4
+        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+        timestep_input_dim = block_out_channels[0]
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+        )
+
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+        # control net conditioning embedding
+        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+            conditioning_embedding_channels=block_out_channels[0],
+            block_out_channels=conditioning_embedding_out_channels,
+            conditioning_channels=conditioning_channels,
+        )
+
+        # Copyright by Qi Xin(2024/07/06)
+        # Condition Transformer(fuse single/multi conditions with input image)
+        # The Condition Transformer augment the feature representation of conditions
+        # The overall design is somewhat like resnet. The output of Condition Transformer is used to predict a condition bias adding to the original condition feature.
+        # num_control_type = 6
+        num_trans_channel = 320
+        num_trans_head = 8
+        num_trans_layer = 1
+        num_proj_channel = 320
+        task_scale_factor = num_trans_channel ** 0.5
+
+        self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
+        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)])
+        self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
+        #-----------------------------------------------------------------------------------------------------
+
+        # Copyright by Qi Xin(2024/07/06)
+        # Control Encoder to distinguish different control conditions
+        # A simple but effective module, consists of an embedding layer and a linear layer, to inject the control info to time embedding.
+        self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+        self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
+        #-----------------------------------------------------------------------------------------------------
+
+        self.down_blocks = nn.ModuleList([])
+        self.controlnet_down_blocks = nn.ModuleList([])
+
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        # down
+        output_channel = block_out_channels[0]
+
+        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+        controlnet_block = zero_module(controlnet_block)
+        self.controlnet_down_blocks.append(controlnet_block)
+
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=num_attention_heads[i],
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                downsample_padding=downsample_padding,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+            )
+            self.down_blocks.append(down_block)
+
+            for _ in range(layers_per_block):
+                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                controlnet_block = zero_module(controlnet_block)
+                self.controlnet_down_blocks.append(controlnet_block)
+
+            if not is_final_block:
+                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                controlnet_block = zero_module(controlnet_block)
+                self.controlnet_down_blocks.append(controlnet_block)
+
+        # mid
+        mid_block_channel = block_out_channels[-1]
+
+        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+        controlnet_block = zero_module(controlnet_block)
+        self.controlnet_mid_block = controlnet_block
+
+        self.mid_block = UNetMidBlock2DCrossAttn(
+            transformer_layers_per_block=transformer_layers_per_block[-1],
+            in_channels=mid_block_channel,
+            temb_channels=time_embed_dim,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
+            output_scale_factor=mid_block_scale_factor,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads[-1],
+            resnet_groups=norm_num_groups,
+            use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
+        )
+
+    @classmethod
+    def from_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        controlnet_conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        load_weights_from_unet: bool = True,
+    ):
+        r"""
+        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+        Parameters:
+            unet (`UNet2DConditionModel`):
+                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+                where applicable.
+        """
+        transformer_layers_per_block = (
+            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+        )
+        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+        addition_time_embed_dim = (
+            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+        )
+
+        controlnet = cls(
+            encoder_hid_dim=encoder_hid_dim,
+            encoder_hid_dim_type=encoder_hid_dim_type,
+            addition_embed_type=addition_embed_type,
+            addition_time_embed_dim=addition_time_embed_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
+            # transformer_layers_per_block=[1, 2, 5],
+            in_channels=unet.config.in_channels,
+            flip_sin_to_cos=unet.config.flip_sin_to_cos,
+            freq_shift=unet.config.freq_shift,
+            down_block_types=unet.config.down_block_types,
+            only_cross_attention=unet.config.only_cross_attention,
+            block_out_channels=unet.config.block_out_channels,
+            layers_per_block=unet.config.layers_per_block,
+            downsample_padding=unet.config.downsample_padding,
+            mid_block_scale_factor=unet.config.mid_block_scale_factor,
+            act_fn=unet.config.act_fn,
+            norm_num_groups=unet.config.norm_num_groups,
+            norm_eps=unet.config.norm_eps,
+            cross_attention_dim=unet.config.cross_attention_dim,
+            attention_head_dim=unet.config.attention_head_dim,
+            num_attention_heads=unet.config.num_attention_heads,
+            use_linear_projection=unet.config.use_linear_projection,
+            class_embed_type=unet.config.class_embed_type,
+            num_class_embeds=unet.config.num_class_embeds,
+            upcast_attention=unet.config.upcast_attention,
+            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+        )
+
+        if load_weights_from_unet:
+            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+            if controlnet.class_embedding:
+                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+
+        return controlnet
+
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+            module.gradient_checkpointing = value
+
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond_list: torch.FloatTensor,
+        conditioning_scale: float = 1.0,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        """
+        The [`ControlNetModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor.
+            timestep (`Union[torch.Tensor, float, int]`):
+                The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states.
+            controlnet_cond (`torch.FloatTensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
| 744 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
| 745 |
+
embeddings.
|
| 746 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
| 747 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 748 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 749 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
| 750 |
+
added_cond_kwargs (`dict`):
|
| 751 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
| 752 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
| 753 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
| 754 |
+
guess_mode (`bool`, defaults to `False`):
|
| 755 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
| 756 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
| 757 |
+
return_dict (`bool`, defaults to `True`):
|
| 758 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
| 759 |
+
|
| 760 |
+
Returns:
|
| 761 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
| 762 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
| 763 |
+
returned where the first element is the sample tensor.
|
| 764 |
+
"""
|
| 765 |
+
# check channel order
|
| 766 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
| 767 |
+
|
| 768 |
+
if channel_order == "rgb":
|
| 769 |
+
# in rgb order by default
|
| 770 |
+
...
|
| 771 |
+
# elif channel_order == "bgr":
|
| 772 |
+
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
| 773 |
+
else:
|
| 774 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
| 775 |
+
|
| 776 |
+
# prepare attention_mask
|
| 777 |
+
if attention_mask is not None:
|
| 778 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 779 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 780 |
+
|
| 781 |
+
# 1. time
|
| 782 |
+
timesteps = timestep
|
| 783 |
+
if not torch.is_tensor(timesteps):
|
| 784 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 785 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 786 |
+
is_mps = sample.device.type == "mps"
|
| 787 |
+
if isinstance(timestep, float):
|
| 788 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 789 |
+
else:
|
| 790 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 791 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 792 |
+
elif len(timesteps.shape) == 0:
|
| 793 |
+
timesteps = timesteps[None].to(sample.device)
|
| 794 |
+
|
| 795 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 796 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 797 |
+
|
| 798 |
+
t_emb = self.time_proj(timesteps)
|
| 799 |
+
|
| 800 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 801 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 802 |
+
# there might be better ways to encapsulate this.
|
| 803 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 804 |
+
|
| 805 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 806 |
+
aug_emb = None
|
| 807 |
+
|
| 808 |
+
if self.class_embedding is not None:
|
| 809 |
+
if class_labels is None:
|
| 810 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 811 |
+
|
| 812 |
+
if self.config.class_embed_type == "timestep":
|
| 813 |
+
class_labels = self.time_proj(class_labels)
|
| 814 |
+
|
| 815 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 816 |
+
emb = emb + class_emb
|
| 817 |
+
|
| 818 |
+
if self.config.addition_embed_type is not None:
|
| 819 |
+
if self.config.addition_embed_type == "text":
|
| 820 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 821 |
+
|
| 822 |
+
elif self.config.addition_embed_type == "text_time":
|
| 823 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 824 |
+
raise ValueError(
|
| 825 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 826 |
+
)
|
| 827 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 828 |
+
if "time_ids" not in added_cond_kwargs:
|
| 829 |
+
raise ValueError(
|
| 830 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 831 |
+
)
|
| 832 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 833 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 834 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 835 |
+
|
| 836 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 837 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 838 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 839 |
+
|
| 840 |
+
# Copyright by Qi Xin(2024/07/06)
|
| 841 |
+
# inject control type info to time embedding to distinguish different control conditions
|
| 842 |
+
control_type = added_cond_kwargs.get('control_type')
|
| 843 |
+
control_embeds = self.control_type_proj(control_type.flatten())
|
| 844 |
+
control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
|
| 845 |
+
control_embeds = control_embeds.to(emb.dtype)
|
| 846 |
+
control_emb = self.control_add_embedding(control_embeds)
|
| 847 |
+
emb = emb + control_emb
|
| 848 |
+
#---------------------------------------------------------------------------------
|
| 849 |
+
|
| 850 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 851 |
+
|
| 852 |
+
# 2. pre-process
|
| 853 |
+
sample = self.conv_in(sample)
|
| 854 |
+
indices = torch.nonzero(control_type[0])
|
| 855 |
+
|
| 856 |
+
# Copyright by Qi Xin(2024/07/06)
|
| 857 |
+
# add single/multi conditons to input image.
|
| 858 |
+
# Condition Transformer provides an easy and effective way to fuse different features naturally
|
| 859 |
+
inputs = []
|
| 860 |
+
condition_list = []
|
| 861 |
+
|
| 862 |
+
for idx in range(indices.shape[0] + 1):
|
| 863 |
+
if idx == indices.shape[0]:
|
| 864 |
+
controlnet_cond = sample
|
| 865 |
+
feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
|
| 866 |
+
else:
|
| 867 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond_list[indices[idx][0]])
|
| 868 |
+
feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
|
| 869 |
+
feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
|
| 870 |
+
|
| 871 |
+
inputs.append(feat_seq.unsqueeze(1))
|
| 872 |
+
condition_list.append(controlnet_cond)
|
| 873 |
+
|
| 874 |
+
x = torch.cat(inputs, dim=1) # NxLxC
|
| 875 |
+
x = self.transformer_layes(x)
|
| 876 |
+
|
| 877 |
+
controlnet_cond_fuser = sample * 0.0
|
| 878 |
+
for idx in range(indices.shape[0]):
|
| 879 |
+
alpha = self.spatial_ch_projs(x[:, idx])
|
| 880 |
+
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
| 881 |
+
controlnet_cond_fuser += condition_list[idx] + alpha
|
| 882 |
+
|
| 883 |
+
sample = sample + controlnet_cond_fuser
|
| 884 |
+
#-------------------------------------------------------------------------------------------
|
| 885 |
+
|
| 886 |
+
# 3. down
|
| 887 |
+
down_block_res_samples = (sample,)
|
| 888 |
+
for downsample_block in self.down_blocks:
|
| 889 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 890 |
+
sample, res_samples = downsample_block(
|
| 891 |
+
hidden_states=sample,
|
| 892 |
+
temb=emb,
|
| 893 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 894 |
+
attention_mask=attention_mask,
|
| 895 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 896 |
+
)
|
| 897 |
+
else:
|
| 898 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 899 |
+
|
| 900 |
+
down_block_res_samples += res_samples
|
| 901 |
+
|
| 902 |
+
# 4. mid
|
| 903 |
+
if self.mid_block is not None:
|
| 904 |
+
sample = self.mid_block(
|
| 905 |
+
sample,
|
| 906 |
+
emb,
|
| 907 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 908 |
+
attention_mask=attention_mask,
|
| 909 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
# 5. Control net blocks
|
| 913 |
+
|
| 914 |
+
controlnet_down_block_res_samples = ()
|
| 915 |
+
|
| 916 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
| 917 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
| 918 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
| 919 |
+
|
| 920 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
| 921 |
+
|
| 922 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
| 923 |
+
|
| 924 |
+
# 6. scaling
|
| 925 |
+
if guess_mode and not self.config.global_pool_conditions:
|
| 926 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
| 927 |
+
scales = scales * conditioning_scale
|
| 928 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
| 929 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
| 930 |
+
else:
|
| 931 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
| 932 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
| 933 |
+
|
| 934 |
+
if self.config.global_pool_conditions:
|
| 935 |
+
down_block_res_samples = [
|
| 936 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
| 937 |
+
]
|
| 938 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
| 939 |
+
|
| 940 |
+
if not return_dict:
|
| 941 |
+
return (down_block_res_samples, mid_block_res_sample)
|
| 942 |
+
|
| 943 |
+
return ControlNetOutput(
|
| 944 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def zero_module(module):
|
| 950 |
+
for p in module.parameters():
|
| 951 |
+
nn.init.zeros_(p)
|
| 952 |
+
return module
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
|
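
To make the Condition Transformer fusion inside `forward` above easier to follow, here is a minimal, self-contained sketch of the same idea: each active condition is mean-pooled into a single token, offset by a task embedding, mixed by a small transformer together with a token for the latent itself, and the mixed per-condition vectors are added back onto the spatial condition features as channel-wise biases. The batch size, channel count, layer sizes, and module names below are illustrative stand-ins, not the real configuration of `ControlNetModel_Union`.

import torch
import torch.nn as nn

B, C, H, W = 2, 320, 64, 64
sample = torch.randn(B, C, H, W)                      # pre-processed latent features
conds = [torch.randn(B, C, H, W) for _ in range(2)]   # e.g. embedded depth and normal conditions
task_embedding = torch.randn(6, C)                    # one row per supported control type (illustrative)

# One pooled token per condition, plus the latent itself as the last token.
tokens = [(cond.mean(dim=(2, 3)) + task_embedding[i]).unsqueeze(1) for i, cond in enumerate(conds)]
tokens.append(sample.mean(dim=(2, 3)).unsqueeze(1))

mixer = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=C, nhead=8, batch_first=True), num_layers=1
)
mixed = mixer(torch.cat(tokens, dim=1))               # B x (num_conds + 1) x C

spatial_proj = nn.Linear(C, C)
fused = torch.zeros_like(sample)
for i, cond in enumerate(conds):
    alpha = spatial_proj(mixed[:, i]).unsqueeze(-1).unsqueeze(-1)  # per-channel bias
    fused = fused + cond + alpha
sample = sample + fused
print(sample.shape)  # torch.Size([2, 320, 64, 64])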
utils/image_generation.py
ADDED
@@ -0,0 +1,299 @@
import threading

import cv2
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
                       FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage

import gradio as gr

from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import \
    StableDiffusionXLControlNetUnionPipeline
from .render_utils import get_silhouette_image

IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()
# FLUX pipeline globals
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None


def lazy_get_flux_pipe():
    """
    Lazily load the FLUX pipeline with ControlNet for image generation.
    """
    global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
    if FLUX_PIPE is not None:
        return FLUX_PIPE
    gr.Info("First call: loading the FLUX pipeline... It may take about 1 minute.")
    with FLUX_PIPE_LOCK:
        if FLUX_PIPE is not None:
            return FLUX_PIPE
        FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
        FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
        base_model = 'black-forest-labs/FLUX.1-dev'
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'

        controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=torch.bfloat16
        )
        # Use model CPU offload for better GPU utilization during inference
        FLUX_PIPE.enable_model_cpu_offload()
        return FLUX_PIPE


def lazy_get_sdxl_pipe():
    """
    Lazily load the SDXL pipeline with ControlNet for image generation.
    """
    global IMG_PIPE
    if IMG_PIPE is not None:
        return IMG_PIPE
    gr.Info("First call: loading the SDXL pipeline... It may take about 20 seconds.")
    with IMG_PIPE_LOCK:
        if IMG_PIPE is not None:
            return IMG_PIPE
        eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # When testing with another base model, you need to change the VAE as well.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
        IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model,
            vae=vae,
            torch_dtype=torch.float16,
            scheduler=eulera_scheduler,
        )
        # Move pipeline to CUDA device
        IMG_PIPE = IMG_PIPE.to("cuda")
        return IMG_PIPE


def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition using the SDXL model with ControlNet, based on depth and normal images.
    :param depth_img: Depth image from the selected view.
    :param normal_img: Normal image (camera coordinate system) from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: A mask image used to guide the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading SDXL pipeline...")
    pipeline = lazy_get_sdxl_pipe()
    progress(0.3, desc="SDXL pipeline loaded successfully.")

    positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lighting, uniform color, foreground"
    negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

    img_generation_resolution = 1024  # SDXL performs better at 1024x1024
    image = pipeline(prompt=[positive_prompt]*1,
                     image_list=[0, depth_img, 0, 0, normal_img, 0],
                     negative_prompt=[negative_prompt]*1,
                     generator=torch.Generator(device="cuda").manual_seed(seed),
                     width=img_generation_resolution,
                     height=img_generation_resolution,
                     num_inference_steps=50,
                     union_control=True,
                     union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"),  # use depth and normal images
                     progress=progress,
                     ).images[0]
    progress(0.9, desc="Condition tensor generated successfully.")

    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device)  # Ensure mask is in the correct shape
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]

    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)

    # Apply edge refinement if enabled
    if edge_refinement:
        # Move to the CUDA device for edge refinement
        rgb_tensor_cuda = rgb_tensor.to("cuda")
        mask_tensor_cuda = mask_tensor.to("cuda")
        rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
        rgb_tensor = rgb_tensor_cuda.to(pipeline.device)

    background_tensor = torch.zeros_like(rgb_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())

    progress(1, desc="Condition image generated successfully.")
    return condition_image


def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition using the FLUX model with ControlNet, based on the depth image only.
    Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth.
    :param depth_img: Depth image from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: A mask image used to guide the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading FLUX pipeline...")
    pipeline = lazy_get_flux_pipe()
    progress(0.3, desc="FLUX pipeline loaded successfully.")

    # Enhanced prompt for better results
    positive_prompt = text_prompt + FLUX_SUFFIX
    negative_prompt = FLUX_NEGATIVE

    # Get image dimensions
    width, height = depth_img.size

    progress(0.5, desc="Generating image with FLUX (including onload and CPU offload)...")

    # Generate image using FLUX ControlNet with depth control
    # model_cpu_offload handles GPU loading automatically
    image = pipeline(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        control_image=depth_img,
        width=width,
        height=height,
        controlnet_conditioning_scale=0.8,  # Recommended for depth
        control_guidance_end=0.8,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]

    progress(0.9, desc="Applying mask and resizing...")

    # Convert to tensor and apply mask
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]

    # Resize to target dimensions
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)

    # Apply mask (blend with black background)
    background_tensor = torch.zeros_like(rgb_tensor)
    if edge_refinement:
        # Replace edge pixels with inner values
        rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)

    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)

    # Convert back to PIL Image
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.0
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())

    progress(1, desc="FLUX condition image generated successfully.")
    return condition_image


def refine_image_edges(rgb_tensor, mask_tensor):
    """
    Refine image edges using morphological operations to remove white edges while preserving object boundaries.

    Algorithm:
    1. Erode the mask to get eroded_mask
    2. Erode again to get double_eroded_mask
    3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask (a thin ring just inside the boundary)
    4. Use circle_valid_mask to extract circle_rgb (clean edge values)
    5. Dilate circle_rgb to cover the edge region
    6. Final result: use double_eroded_mask for the original RGB foreground, dilated_circle_rgb for the rest

    :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on the CUDA device
    :param mask_tensor: Mask tensor of shape (1, 1, H, W) on the CUDA device, normalized to [0, 1]
    :return: refined_rgb_tensor
    """
    # Convert tensors to numpy for OpenCV processing
    rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)  # (H, W, C)
    mask_np = mask_tensor.squeeze().cpu().numpy()  # Remove batch and channel dimensions
    original_mask_np = (mask_np * 255).astype(np.uint8)  # Convert to 0-255 range

    # Create morphological kernel (3x3)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Step 1: Erode mask to get eroded_mask
    eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)

    # Step 2: Erode again to get double_eroded_mask
    double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)

    # Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask
    circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)

    # Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values)
    circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)

    # Step 5: Dilate circle_rgb to cover the edge region (iterations=8)
    dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)

    # Step 6: Final composition
    # Use double_eroded_mask for the original RGB foreground, dilated_circle_rgb for the background
    double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0

    # Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere
    refined_rgb_np = (rgb_np * double_eroded_mask_3c +
                      dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)

    # Convert refined RGB back to tensor
    refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")

    return refined_rgb_tensor


@spaces.GPU(duration=120)
def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
    """
    Generate the image condition based on the selected view's silhouette and the text prompt.
    :param position_imgs: Position images from different views.
    :param normal_imgs: Normal images from different views.
    :param mask_imgs: Mask images from different views.
    :param w2c: World-to-camera transformation matrices.
    :param text_prompt: The text prompt for image generation.
    :param selected_view: The selected view for image generation.
    :param seed: Random seed for image generation.
    :param model: The image generation model type; supports "SDXL" and "FLUX".
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :param progress: Progress callback for Gradio.
    :return: Generated condition image and status message.
    """

    progress(0, desc="Handling geometry information...")
    silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
    depth_img = silhouette[0]
    normal_img = silhouette[1]
    mask = silhouette[2]

    try:
        if model == "SDXL":
            condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "SDXL condition generated successfully."
        elif model == "FLUX":
            # FLUX only supports depth control, not normal
            condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "FLUX condition generated successfully (depth-only control)."
        else:
            raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
    finally:
        torch.cuda.empty_cache()
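
The six-step edge-refinement recipe in `refine_image_edges` above is easier to picture in isolation. Below is a minimal NumPy/OpenCV sketch of the same steps on a synthetic circular mask; the 64x64 size, the flat gray image, and the variable names are illustrative only, not part of the Space's code.

import cv2
import numpy as np

# Synthetic 64x64 circular mask (255 = foreground) and a dummy flat-gray image.
mask = np.zeros((64, 64), np.uint8)
cv2.circle(mask, (32, 32), 20, 255, -1)
rgb = np.full((64, 64, 3), 200, np.uint8)

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
eroded = cv2.erode(mask, kernel, iterations=3)            # step 1
double_eroded = cv2.erode(eroded, kernel, iterations=5)   # step 2
ring = cv2.bitwise_xor(eroded, double_eroded)             # step 3: thin band just inside the boundary
ring_rgb = (rgb * (cv2.cvtColor(ring, cv2.COLOR_GRAY2BGR) / 255.0)).astype(np.uint8)  # step 4
spread = cv2.dilate(ring_rgb, kernel, iterations=8)       # step 5: push clean inner colors outward
inner = cv2.cvtColor(double_eroded, cv2.COLOR_GRAY2BGR) / 255.0
refined = (rgb * inner + spread * (1 - inner)).astype(np.uint8)  # step 6: composite
print(refined.shape)  # (64, 64, 3)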
utils/mesh_utils.py
ADDED
@@ -0,0 +1,500 @@
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import trimesh
|
| 7 |
+
import xatlas
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
|
| 13 |
+
render_geo_views_tensor, render_views, setup_lights)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mesh:
|
| 17 |
+
def __init__(self, mesh_path=None, uv_tool="xAtlas", device='cuda', progress=gr.Progress()):
|
| 18 |
+
"""
|
| 19 |
+
Initialize the Mesh object with a mesh file path.
|
| 20 |
+
:param mesh_path: Path to the mesh file (e.g., .obj or .glb).
|
| 21 |
+
"""
|
| 22 |
+
self.device = device
|
| 23 |
+
if mesh_path is not None:
|
| 24 |
+
# Initialize _parts dictionary to store all parts
|
| 25 |
+
self._parts = {}
|
| 26 |
+
|
| 27 |
+
if mesh_path.endswith('.obj'):
|
| 28 |
+
progress(0., f"Loading mesh in .obj format...")
|
| 29 |
+
mesh_data = trimesh.load(mesh_path, process=False)
|
| 30 |
+
|
| 31 |
+
# Check if it's a mesh list (multi-part obj)
|
| 32 |
+
if isinstance(mesh_data, list):
|
| 33 |
+
progress(0.1, f"Handling part list...")
|
| 34 |
+
for i, mesh_part in enumerate(mesh_data):
|
| 35 |
+
self._add_part_to_parts(f"part_{i}", mesh_part)
|
| 36 |
+
# Check if it's a Scene (another multi-part format)
|
| 37 |
+
elif isinstance(mesh_data, trimesh.Scene):
|
| 38 |
+
progress(0.1, f"Handling Scenes...")
|
| 39 |
+
geometry = mesh_data.geometry
|
| 40 |
+
if len(geometry) > 0:
|
| 41 |
+
for key, mesh_part in geometry.items():
|
| 42 |
+
self._add_part_to_parts(key, mesh_part)
|
| 43 |
+
else:
|
| 44 |
+
raise ValueError("Empty scene, no mesh data found.")
|
| 45 |
+
else:
|
| 46 |
+
# Single part obj
|
| 47 |
+
progress(0.1, f"Handling single part...")
|
| 48 |
+
self._add_part_to_parts("part_0", mesh_data)
|
| 49 |
+
|
| 50 |
+
elif mesh_path.endswith('.glb'):
|
| 51 |
+
progress(0., f"Loading mesh in .glb format...")
|
| 52 |
+
mesh_loaded = trimesh.load(mesh_path)
|
| 53 |
+
|
| 54 |
+
# Check if it's a Scene (multi-part glb)
|
| 55 |
+
if isinstance(mesh_loaded, trimesh.Scene):
|
| 56 |
+
progress(0.1, f"Handling Scenes...")
|
| 57 |
+
geometry = mesh_loaded.geometry
|
| 58 |
+
if len(geometry) > 0:
|
| 59 |
+
for key, mesh_part in geometry.items():
|
| 60 |
+
self._add_part_to_parts(key, mesh_part)
|
| 61 |
+
else:
|
| 62 |
+
raise ValueError("Empty scene, no mesh data found.")
|
| 63 |
+
else:
|
| 64 |
+
# Single part glb
|
| 65 |
+
progress(0.1, f"Handling single part...")
|
| 66 |
+
self._add_part_to_parts("part_0", mesh_loaded)
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"Unsupported file format: {mesh_path}")
|
| 69 |
+
|
| 70 |
+
# Automatically merge all parts during initialization
|
| 71 |
+
progress(0.2, f"Merging if the mesh have multiple parts.")
|
| 72 |
+
self._merge_parts_internal()
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError("Mesh path cannot be None.")
|
| 75 |
+
self.to(self.device) # Move to the specified device
|
| 76 |
+
|
| 77 |
+
# Initialize transformation flags
|
| 78 |
+
self._upside_down_applied = False
|
| 79 |
+
|
| 80 |
+
# UV parameterization
|
| 81 |
+
if self.has_multi_parts or not self.has_uv:
|
| 82 |
+
progress(0.4, f"Using {uv_tool} for UV parameterization. It may take quite a while (several minutes), if there are many faces. We STRONLY recommend using a mesh with UV parameterization.")
|
| 83 |
+
if uv_tool == "xAtlas":
|
| 84 |
+
self.uv_xatlas_mapping() # Use default parameters
|
| 85 |
+
elif uv_tool == "UVAtlas":
|
| 86 |
+
raise NotImplementedError("UVAtlas parameterization is not implemented yet.")
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError("Unsupported UV parameterization tool.")
|
| 89 |
+
print("UV parameterization completed.")
|
| 90 |
+
else:
|
| 91 |
+
progress(0.4, f"The model has SINGLE UV parameterization, no need to reparameterize.")
|
| 92 |
+
self._vmapping = None # No vmapping needed when not reparameterizing
|
| 93 |
+
|
| 94 |
+
def to(self, device):
|
| 95 |
+
"""
|
| 96 |
+
Move the mesh data to the specified device.
|
| 97 |
+
:param device: The target device (e.g., 'cuda' or 'cpu').
|
| 98 |
+
"""
|
| 99 |
+
self._v_pos = self._v_pos.to(device)
|
| 100 |
+
self._t_pos_idx = self._t_pos_idx.to(device)
|
| 101 |
+
if self._v_tex is not None:
|
| 102 |
+
self._v_tex = self._v_tex.to(device)
|
| 103 |
+
self._t_tex_idx = self._t_tex_idx.to(device)
|
| 104 |
+
if hasattr(self, '_vmapping') and self._vmapping is not None:
|
| 105 |
+
self._vmapping = self._vmapping.to(device)
|
| 106 |
+
self._v_normal = self._v_normal.to(device)
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def has_multi_parts(self):
|
| 110 |
+
"""
|
| 111 |
+
Check if the mesh has multiple parts.
|
| 112 |
+
:return: Boolean indicating whether the mesh has multiple parts.
|
| 113 |
+
"""
|
| 114 |
+
# If _parts is None, it means already merged, not multi-part
|
| 115 |
+
if self._parts is None:
|
| 116 |
+
return False
|
| 117 |
+
return len(self._parts) > 1
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def v_pos(self):
|
| 121 |
+
"""Vertex positions property."""
|
| 122 |
+
return self._v_pos
|
| 123 |
+
|
| 124 |
+
@v_pos.setter
|
| 125 |
+
def v_pos(self, value):
|
| 126 |
+
self._v_pos = value
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def t_pos_idx(self):
|
| 130 |
+
"""Triangle position indices property."""
|
| 131 |
+
return self._t_pos_idx
|
| 132 |
+
|
| 133 |
+
@t_pos_idx.setter
|
| 134 |
+
def t_pos_idx(self, value):
|
| 135 |
+
self._t_pos_idx = value
|
| 136 |
+
|
| 137 |
+
@property
|
| 138 |
+
def v_tex(self):
|
| 139 |
+
"""Vertex texture coordinates property."""
|
| 140 |
+
return self._v_tex
|
| 141 |
+
|
| 142 |
+
@v_tex.setter
|
| 143 |
+
def v_tex(self, value):
|
| 144 |
+
self._v_tex = value
|
| 145 |
+
|
| 146 |
+
@property
|
| 147 |
+
def t_tex_idx(self):
|
| 148 |
+
"""Triangle texture indices property."""
|
| 149 |
+
return self._t_tex_idx
|
| 150 |
+
|
| 151 |
+
@t_tex_idx.setter
|
| 152 |
+
def t_tex_idx(self, value):
|
| 153 |
+
self._t_tex_idx = value
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def v_normal(self):
|
| 157 |
+
"""Vertex normals property."""
|
| 158 |
+
return self._v_normal
|
| 159 |
+
|
| 160 |
+
@v_normal.setter
|
| 161 |
+
def v_normal(self, value):
|
| 162 |
+
self._v_normal = value
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
def has_uv(self):
|
| 166 |
+
"""
|
| 167 |
+
Check if the mesh has a valid UV mapping.
|
| 168 |
+
:return: Boolean indicating whether the mesh has UV mapping.
|
| 169 |
+
"""
|
| 170 |
+
return self.v_tex is not None
|
| 171 |
+
|
| 172 |
+
def uv_xatlas_mapping(self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}):
|
| 173 |
+
# Merged mesh, directly add_mesh as a whole
|
| 174 |
+
atlas = xatlas.Atlas()
|
| 175 |
+
v_pos_np = self.v_pos.detach().cpu().numpy()
|
| 176 |
+
t_pos_idx_np = self.t_pos_idx.cpu().numpy()
|
| 177 |
+
atlas.add_mesh(v_pos_np, t_pos_idx_np)
|
| 178 |
+
|
| 179 |
+
# Set reasonable pack parameters to avoid overlap
|
| 180 |
+
co = xatlas.ChartOptions()
|
| 181 |
+
po = xatlas.PackOptions()
|
| 182 |
+
# Recommended default parameters
|
| 183 |
+
if 'resolution' not in xatlas_pack_options:
|
| 184 |
+
po.resolution = 1024 # or larger
|
| 185 |
+
if 'padding' not in xatlas_pack_options:
|
| 186 |
+
po.padding = 2
|
| 187 |
+
for k, v in xatlas_chart_options.items():
|
| 188 |
+
setattr(co, k, v)
|
| 189 |
+
for k, v in xatlas_pack_options.items():
|
| 190 |
+
setattr(po, k, v)
|
| 191 |
+
atlas.generate(co, po)
|
| 192 |
+
|
| 193 |
+
# Get unpacked data
|
| 194 |
+
vmapping, indices, uvs = atlas.get_mesh(0)
|
| 195 |
+
# vmapping: new UV vertex -> original mesh vertex
|
| 196 |
+
# indices: new triangle face indices (based on new UV vertices)
|
| 197 |
+
# uvs: new UV vertex coordinates
|
| 198 |
+
device = self.v_pos.device
|
| 199 |
+
vmapping = torch.from_numpy(vmapping.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()
|
| 200 |
+
uvs = torch.from_numpy(uvs).to(device).float()
|
| 201 |
+
indices = torch.from_numpy(indices.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()
|
| 202 |
+
|
| 203 |
+
self.v_tex = uvs # new UV vertices
|
| 204 |
+
self.t_tex_idx = indices # new triangle face indices (based on UV vertices)
|
| 205 |
+
self._vmapping = vmapping # save UV vertex to original vertex mapping for export
|
| 206 |
+
|
| 207 |
+
def normalize(self):
|
| 208 |
+
"""
|
| 209 |
+
Normalize mesh vertices to [-1, 1] range.
|
| 210 |
+
"""
|
| 211 |
+
vertices = self.v_pos
|
| 212 |
+
bounding_box_max = vertices.max(0)[0]
|
| 213 |
+
bounding_box_min = vertices.min(0)[0]
|
| 214 |
+
mesh_scale = 2.0 # Scale to [-1, 1]
|
| 215 |
+
scale = mesh_scale / ((bounding_box_max - bounding_box_min).max() + 1e-6)
|
| 216 |
+
center_offset = (bounding_box_max + bounding_box_min) * 0.5
|
| 217 |
+
self.v_pos = (vertices - center_offset) * scale
|
| 218 |
+
|
| 219 |
+
def vertex_transform(self):
|
| 220 |
+
"""
|
| 221 |
+
Apply coordinate transformation to mesh vertices and normals.
|
| 222 |
+
"""
|
| 223 |
+
# Transform normals
|
| 224 |
+
pre_normals = self.v_normal
|
| 225 |
+
normals = torch.clone(pre_normals)
|
| 226 |
+
normals[:, 1] = -pre_normals[:, 2] # -z --> y
|
| 227 |
+
normals[:, 2] = pre_normals[:, 1] # y --> z
|
| 228 |
+
|
| 229 |
+
# Transform vertices
|
| 230 |
+
pre_vertices = self.v_pos
|
| 231 |
+
vertices = torch.clone(pre_vertices)
|
| 232 |
+
vertices[:, 1] = -pre_vertices[:, 2] # -z --> y
|
| 233 |
+
vertices[:, 2] = pre_vertices[:, 1] # y --> z
|
| 234 |
+
|
| 235 |
+
# Update mesh
|
| 236 |
+
self.v_normal = normals
|
| 237 |
+
self.v_pos = vertices
|
| 238 |
+
|
| 239 |
+
def vertex_transform_y2x(self):
|
| 240 |
+
"""
|
| 241 |
+
Apply coordinate transformation to mesh vertices and normals.
|
| 242 |
+
"""
|
| 243 |
+
# Transform normals
|
| 244 |
+
pre_normals = self.v_normal
|
| 245 |
+
normals = torch.clone(pre_normals)
|
| 246 |
+
normals[:, 1] = -pre_normals[:, 0] # -x --> y
|
| 247 |
+
normals[:, 0] = pre_normals[:, 1] # y --> x
|
| 248 |
+
|
| 249 |
+
# Transform vertices
|
| 250 |
+
pre_vertices = self.v_pos
|
| 251 |
+
vertices = torch.clone(pre_vertices)
|
| 252 |
+
vertices[:, 1] = -pre_vertices[:, 0] # -z --> y
|
| 253 |
+
vertices[:, 0] = pre_vertices[:, 1] # y --> z
|
| 254 |
+
|
| 255 |
+
# 更新网格
|
| 256 |
+
self.v_normal = normals
|
| 257 |
+
self.v_pos = vertices
|
| 258 |
+
|
| 259 |
+
def vertex_transform_z2x(self):
|
| 260 |
+
"""
|
| 261 |
+
Apply coordinate transformation to mesh vertices and normals.
|
| 262 |
+
"""
|
| 263 |
+
# 变换法向量
|
| 264 |
+
pre_normals = self.v_normal
|
| 265 |
+
normals = torch.clone(pre_normals)
|
| 266 |
+
normals[:, 2] = -pre_normals[:, 0] # -x --> z
|
| 267 |
+
normals[:, 0] = pre_normals[:, 2] # z --> x
|
| 268 |
+
|
| 269 |
+
# 变换顶点
|
| 270 |
+
pre_vertices = self.v_pos
|
| 271 |
+
vertices = torch.clone(pre_vertices)
|
| 272 |
+
vertices[:, 2] = -pre_vertices[:, 0] # -z --> y
|
| 273 |
+
vertices[:, 0] = pre_vertices[:, 2] # y --> z
|
| 274 |
+
|
| 275 |
+
# 更新网格
|
| 276 |
+
self.v_normal = normals
|
| 277 |
+
self.v_pos = vertices
|
| 278 |
+
|
| 279 |
+
def vertex_transform_upsidedown(self):
|
| 280 |
+
"""
|
| 281 |
+
Apply upside-down transformation to mesh vertices and normals.
|
| 282 |
+
"""
|
| 283 |
+
# 变换法向量
|
| 284 |
+
pre_normals = self.v_normal
|
| 285 |
+
normals = torch.clone(pre_normals)
|
| 286 |
+
normals[:, 2] = -pre_normals[:, 2]
|
| 287 |
+
|
| 288 |
+
# 变换顶点
|
| 289 |
+
pre_vertices = self.v_pos
|
| 290 |
+
vertices = torch.clone(pre_vertices)
|
| 291 |
+
vertices[:, 2] = -pre_vertices[:, 2]
|
| 292 |
+
|
| 293 |
+
# 更新网格
|
| 294 |
+
self.v_normal = normals
|
| 295 |
+
self.v_pos = vertices
|
| 296 |
+
# self.t_pos_idx = faces
|
| 297 |
+
|
| 298 |
+
# 标记已应用上下翻转变换
|
| 299 |
+
self._upside_down_applied = True
|
| 300 |
+
|
| 301 |
+
def _add_part_to_parts(self, key, mesh_part):
|
| 302 |
+
"""
|
| 303 |
+
将单个mesh部分添加到_parts字典中
|
| 304 |
+
:param key: 部分的键名
|
| 305 |
+
:param mesh_part: trimesh对象
|
| 306 |
+
"""
|
| 307 |
+
# exclude PointCloud parts and empty parts
|
| 308 |
+
if hasattr(mesh_part, 'vertices') and hasattr(mesh_part, 'faces') and len(mesh_part.vertices) > 0 and len(mesh_part.faces) > 0:
|
| 309 |
+
raw_uv = getattr(mesh_part.visual, 'uv', None)
|
| 310 |
+
processed_v_tex = None
|
| 311 |
+
processed_t_tex_idx = None
|
| 312 |
+
|
| 313 |
+
# 仅当UV数据存在且不为空时才处理
|
| 314 |
+
if raw_uv is not None and np.asarray(raw_uv).size > 0 and np.asarray(raw_uv).shape[0] > 0:
|
| 315 |
+
processed_v_tex = torch.tensor(raw_uv, dtype=torch.float32)
|
| 316 |
+
# 假设当源数据提供UV时,t_tex_idx 与 t_pos_idx 使用相同的面索引
|
| 317 |
+
# trimesh 通常提供每个顶点的UV
|
| 318 |
+
processed_t_tex_idx = torch.tensor(mesh_part.faces, dtype=torch.int32)
|
| 319 |
+
|
| 320 |
+
self._parts[key] = {
|
| 321 |
+
'v_pos': torch.tensor(mesh_part.vertices, dtype=torch.float32),
|
| 322 |
+
't_pos_idx': torch.tensor(mesh_part.faces, dtype=torch.int32),
|
| 323 |
+
'v_tex': processed_v_tex,
|
| 324 |
+
't_tex_idx': processed_t_tex_idx,
|
| 325 |
+
'v_normal': torch.tensor(mesh_part.vertex_normals, dtype=torch.float32)
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
def _merge_parts_internal(self):
|
| 329 |
+
"""
|
| 330 |
+
内部使用的合并函数,在初始化时自动调用
|
| 331 |
+
将_parts中的所有部分合并为单一的mesh表示
|
| 332 |
+
"""
|
| 333 |
+
# 如果没有部分或只有一个部分,简化处理
|
| 334 |
+
if not self._parts:
|
| 335 |
+
raise ValueError("No mesh parts.")
|
| 336 |
+
elif len(self._parts) == 1:
|
| 337 |
+
key = next(iter(self._parts))
|
| 338 |
+
part = self._parts[key]
|
| 339 |
+
self._v_pos = part['v_pos']
|
| 340 |
+
self._t_pos_idx = part['t_pos_idx']
|
| 341 |
+
self._v_tex = part['v_tex']
|
| 342 |
+
self._t_tex_idx = part['t_tex_idx']
|
| 343 |
+
self._v_normal = part['v_normal']
|
| 344 |
+
self._parts = None # 清理_parts字典,释放内存
|
| 345 |
+
return
|
| 346 |
+
|
| 347 |
+
# 初始化合并后的数据
|
| 348 |
+
vertices = []
|
| 349 |
+
faces = []
|
| 350 |
+
normals = []
|
| 351 |
+
|
| 352 |
+
# Record vertex count for each part, used to adjust face indices
|
| 353 |
+
v_count = 0
|
| 354 |
+
|
| 355 |
+
# Iterate through all parts
|
| 356 |
+
for key, part in self._parts.items():
|
| 357 |
+
# Add vertices
|
| 358 |
+
vertices.append(part['v_pos'])
|
| 359 |
+
|
| 360 |
+
# Adjust face indices and add
|
| 361 |
+
if len(faces) > 0:
|
| 362 |
+
adjusted_faces = part['t_pos_idx'] + v_count
|
| 363 |
+
faces.append(adjusted_faces)
|
| 364 |
+
else:
|
| 365 |
+
faces.append(part['t_pos_idx'])
|
| 366 |
+
|
| 367 |
+
# Add normals
|
| 368 |
+
normals.append(part['v_normal'])
|
| 369 |
+
|
| 370 |
+
# Update vertex count
|
| 371 |
+
v_count += part['v_pos'].shape[0]
|
| 372 |
+
|
| 373 |
+
self._parts = None # Clear _parts dictionary to free memory
|
| 374 |
+
|
| 375 |
+
# Merge all data
|
| 376 |
+
self._v_pos = torch.cat(vertices, dim=0)
|
| 377 |
+
self._t_pos_idx = torch.cat(faces, dim=0)
|
| 378 |
+
self._v_normal = torch.cat(normals, dim=0)
|
| 379 |
+
self._v_tex = None # multi-parts mesh must be reparameterized
|
| 380 |
+
self._t_tex_idx = None # multi-parts mesh must be reparameterized
|
| 381 |
+
self._vmapping = None # multi-parts mesh must be reparameterized
|
| 382 |
+
|
| 383 |
+
@classmethod
|
| 384 |
+
def export(cls, mesh, save_path=None, texture_map: Image.Image = None):
|
| 385 |
+
"""
|
| 386 |
+
Exports the mesh to a GLB file.
|
| 387 |
+
:param mesh: Mesh instance to export
|
| 388 |
+
:param save_path: Optional path to save the GLB file. If None, a temporary file will be created.
|
| 389 |
+
:param texture_map: Optional PIL.Image to use as the texture. If None, a default texture will be used.
|
| 390 |
+
:return: Path to the exported GLB file.
|
| 391 |
+
"""
|
| 392 |
+
# 由于传入的mesh一定是process过的,所以断言确保是单个part且有UV
|
| 393 |
+
assert not mesh.has_multi_parts, "Mesh should be processed and merged to single part"
|
| 394 |
+
assert mesh.has_uv, "Mesh should have UV mapping after processing"
|
| 395 |
+
|
| 396 |
+
if save_path is None:
|
| 397 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
| 398 |
+
save_path = temp_file.name
|
| 399 |
+
temp_file.close()
|
| 400 |
+
|
| 401 |
+
# 创建材质
|
| 402 |
+
if texture_map is not None:
|
| 403 |
+
if type(texture_map) is np.ndarray:
|
| 404 |
+
texture_map = Image.fromarray(texture_map)
|
| 405 |
+
assert type(texture_map) is Image.Image, "texture_map should be a PIL.Image"
|
| 406 |
+
texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
|
| 407 |
+
material = trimesh.visual.texture.SimpleMaterial(image=texture_map)
|
| 408 |
+
else:
|
| 409 |
+
default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
|
| 410 |
+
material = trimesh.visual.texture.SimpleMaterial(image=default_texture)
|
| 411 |
+
|
| 412 |
+
# If vmapping exists (processed by xatlas), need to rebuild vertices to match UV layout
|
| 413 |
+
if hasattr(mesh, '_vmapping') and mesh._vmapping is not None:
|
| 414 |
+
# Use xatlas-generated UV layout to rebuild mesh
|
| 415 |
+
vertices = mesh.v_pos[mesh._vmapping].cpu().numpy()
|
| 416 |
+
faces = mesh.t_tex_idx.cpu().numpy()
|
| 417 |
+
uvs = mesh.v_tex.cpu().numpy()
|
| 418 |
+
else:
|
| 419 |
+
# Original UV mapping, directly use original vertices and faces
|
| 420 |
+
vertices = mesh.v_pos.cpu().numpy()
|
| 421 |
+
faces = mesh.t_pos_idx.cpu().numpy()
|
| 422 |
+
uvs = mesh.v_tex.cpu().numpy()
|
| 423 |
+
|
| 424 |
+
# If upside_down transformation was applied, need to apply face orientation correction
|
| 425 |
+
if hasattr(mesh, '_upside_down_applied') and mesh._upside_down_applied:
|
| 426 |
+
faces_corrected = faces.copy()
|
| 427 |
+
faces_corrected[:, [1, 2]] = faces[:, [2, 1]] # (0,1,2) -> (0,2,1)
|
| 428 |
+
faces = faces_corrected
|
| 429 |
+
|
| 430 |
+
# Apply inverse transformation to convert vertices from rendering coordinate system back to GLB coordinate system
|
| 431 |
+
# This is the inverse of vertex_transform:
|
| 432 |
+
# vertex_transform: y = -z, z = y
|
| 433 |
+
# inverse transformation: y = z, z = -y
|
| 434 |
+
vertices_export = vertices.copy()
|
| 435 |
+
vertices_export[:, 1] = vertices[:, 2] # z → y
|
| 436 |
+
vertices_export[:, 2] = -vertices[:, 1] # -y → z
|
| 437 |
+
|
| 438 |
+
# Create Trimesh object and set texture
|
| 439 |
+
mesh_export = trimesh.Trimesh(vertices=vertices_export, faces=faces, process=False)
|
| 440 |
+
mesh_export.visual = trimesh.visual.TextureVisuals(uv=uvs, material=material)
|
| 441 |
+
|
| 442 |
+
# Export GLB file
|
| 443 |
+
mesh_export.export(file_obj=save_path, file_type='glb')
|
| 444 |
+
|
| 445 |
+
return save_path
|
| 446 |
+
|
| 447 |
+
@classmethod
|
| 448 |
+
def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
|
| 449 |
+
"""
|
| 450 |
+
Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
|
| 451 |
+
Then render the untextured mesh from four views.
|
| 452 |
+
:param mesh_file: uploaded mesh file.
|
| 453 |
+
:param uv_tool: the UV parameterization tool, default is "xAtlas".
|
| 454 |
+
:return: rendered clay model images from four views.
|
| 455 |
+
"""
|
| 456 |
+
# load mesh (automatically merge multiple parts)
|
| 457 |
+
mesh: Mesh = cls(mesh_file, uv_tool, device, progress=progress)
|
| 458 |
+
|
| 459 |
+
progress(0.7, f"Handling transformation and normalization...")
|
| 460 |
+
# normalize mesh
|
| 461 |
+
if y2z:
|
| 462 |
+
mesh.vertex_transform() # transform vertices and normals
|
| 463 |
+
if y2x:
|
| 464 |
+
mesh.vertex_transform_y2x()
|
| 465 |
+
if z2x:
|
| 466 |
+
mesh.vertex_transform_z2x()
|
| 467 |
+
if upside_down:
|
| 468 |
+
mesh.vertex_transform_upsidedown()
|
| 469 |
+
mesh.normalize()
|
| 470 |
+
|
| 471 |
+
# render preparation
|
| 472 |
+
texture = get_pure_texture(uv_size).to(device) # tensor of shape (3, height, width)
|
| 473 |
+
# lights = setup_lights()
|
| 474 |
+
lights = None
|
| 475 |
+
mvp_matrix, w2c = get_mvp_matrix(mesh)
|
| 476 |
+
mvp_matrix = mvp_matrix.to(device)
|
| 477 |
+
w2c = w2c.to(device)
|
| 478 |
+
|
| 479 |
+
# render untextured mesh from four views
|
| 480 |
+
# images = render_views(mesh, texture, mvp_matrix, lights, img_size) # PIL.Image
|
| 481 |
+
progress(0.8, f"Rendering clay model views...")
|
| 482 |
+
print(f"Rendering geometry views...")
|
| 483 |
+
position_images, normal_images, mask_images = render_geo_views_tensor(mesh, mvp_matrix, img_size) # torch.Tensor # [batch_size, height, width, 3]
|
| 484 |
+
progress(0.9, f"Rendering geometry maps...")
|
| 485 |
+
print(f"Rendering geometry maps...")
|
| 486 |
+
position_map, normal_map = render_geo_map(mesh)
|
| 487 |
+
|
| 488 |
+
progress(1, f"Mesh processing completed.")
|
| 489 |
+
return position_map, normal_map, position_images, normal_images, mask_images.squeeze(-1), w2c, mesh, mvp_matrix, "Mesh processing completed."
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
if __name__ == '__main__':
|
| 493 |
+
glb_path = "/mnt/pfs/users/yuanze/projects/clean_seqtex/gradio/examples/multi_parts.glb"
|
| 494 |
+
position_map, normal_map, position_images, normal_images, w2c = Mesh.process(glb_path)
|
| 495 |
+
position_map.save("position_map.png")
|
| 496 |
+
normal_map.save("normal_map.png")
|
| 497 |
+
|
| 498 |
+
# Save the rendered normal_images (values in [-1, 1]) as an image file
|
| 499 |
+
# normal_images = rearrange(normal_images, "B H W C -> B C H W")
|
| 500 |
+
# save_image(normal_images, "normal_images.png", normalize=True, value_range=(-1, 1))
|
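For completeness, a hedged sketch of the saving step that the commented-out lines above describe, assuming `torchvision` is available; the random tensor is a stand-in for the rendered normals in [-1, 1]:

```py
import torch
from einops import rearrange
from torchvision.utils import save_image

normal_images = torch.rand(4, 512, 512, 3) * 2 - 1               # stand-in for rendered normals
normal_images = rearrange(normal_images, "B H W C -> B C H W")   # save_image expects channels first
save_image(normal_images, "normal_images.png", normalize=True, value_range=(-1, 1))
```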
utils/pipeline_controlnet_union_sd_xl.py
ADDED
|
@@ -0,0 +1,1397 @@
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import os
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import PIL.Image
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import gradio as gr
|
| 25 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 26 |
+
|
| 27 |
+
from diffusers.utils.import_utils import is_invisible_watermark_available
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 30 |
+
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin
|
| 31 |
+
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, ImageProjection
|
| 32 |
+
from .controlnet_union import ControlNetModel_Union
|
| 33 |
+
from diffusers.models.attention_processor import (
|
| 34 |
+
AttnProcessor2_0,
|
| 35 |
+
LoRAAttnProcessor2_0,
|
| 36 |
+
LoRAXFormersAttnProcessor,
|
| 37 |
+
XFormersAttnProcessor,
|
| 38 |
+
)
|
| 39 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 40 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 41 |
+
from diffusers.utils import (
|
| 42 |
+
is_accelerate_available,
|
| 43 |
+
is_accelerate_version,
|
| 44 |
+
logging,
|
| 45 |
+
replace_example_docstring,
|
| 46 |
+
)
|
| 47 |
+
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
| 48 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 49 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
| 50 |
+
|
| 51 |
+
if is_invisible_watermark_available():
|
| 52 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
| 53 |
+
|
| 54 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
EXAMPLE_DOC_STRING = """
|
| 61 |
+
Examples:
|
| 62 |
+
```py
|
| 63 |
+
>>> # !pip install opencv-python transformers accelerate
|
| 64 |
+
>>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
|
| 65 |
+
>>> from diffusers.utils import load_image
|
| 66 |
+
>>> import numpy as np
|
| 67 |
+
>>> import torch
|
| 68 |
+
|
| 69 |
+
>>> import cv2
|
| 70 |
+
>>> from PIL import Image
|
| 71 |
+
|
| 72 |
+
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
| 73 |
+
>>> negative_prompt = "low quality, bad quality, sketches"
|
| 74 |
+
|
| 75 |
+
>>> # download an image
|
| 76 |
+
>>> image = load_image(
|
| 77 |
+
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
| 78 |
+
... )
|
| 79 |
+
|
| 80 |
+
>>> # initialize the models and pipeline
|
| 81 |
+
>>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
| 82 |
+
>>> controlnet = ControlNetModel.from_pretrained(
|
| 83 |
+
... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
|
| 84 |
+
... )
|
| 85 |
+
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
| 86 |
+
>>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
| 87 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
|
| 88 |
+
... )
|
| 89 |
+
>>> pipe.enable_model_cpu_offload()
|
| 90 |
+
|
| 91 |
+
>>> # get canny image
|
| 92 |
+
>>> image = np.array(image)
|
| 93 |
+
>>> image = cv2.Canny(image, 100, 200)
|
| 94 |
+
>>> image = image[:, :, None]
|
| 95 |
+
>>> image = np.concatenate([image, image, image], axis=2)
|
| 96 |
+
>>> canny_image = Image.fromarray(image)
|
| 97 |
+
|
| 98 |
+
>>> # generate image
|
| 99 |
+
>>> image = pipe(
|
| 100 |
+
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
| 101 |
+
... ).images[0]
|
| 102 |
+
```
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class StableDiffusionXLControlNetUnionPipeline(
|
| 107 |
+
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin
|
| 108 |
+
):
|
| 109 |
+
r"""
|
| 110 |
+
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
|
| 111 |
+
|
| 112 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 113 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 114 |
+
|
| 115 |
+
The pipeline also inherits the following loading methods:
|
| 116 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 117 |
+
- [`loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 118 |
+
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
vae ([`AutoencoderKL`]):
|
| 122 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 123 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 124 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 125 |
+
text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
|
| 126 |
+
Second frozen text-encoder
|
| 127 |
+
([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
|
| 128 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 129 |
+
A `CLIPTokenizer` to tokenize text.
|
| 130 |
+
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
|
| 131 |
+
A `CLIPTokenizer` to tokenize text.
|
| 132 |
+
unet ([`UNet2DConditionModel`]):
|
| 133 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 134 |
+
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
| 135 |
+
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
| 136 |
+
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
| 137 |
+
additional conditioning.
|
| 138 |
+
scheduler ([`SchedulerMixin`]):
|
| 139 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 140 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 141 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
| 142 |
+
Whether the negative prompt embeddings should always be set to 0. Also see the config of
|
| 143 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
| 144 |
+
add_watermarker (`bool`, *optional*):
|
| 145 |
+
Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
|
| 146 |
+
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
|
| 147 |
+
watermarker is used.
|
| 148 |
+
"""
|
| 149 |
+
model_cpu_offload_seq = (
|
| 150 |
+
"text_encoder->text_encoder_2->image_encoder->unet->vae" # leave controlnet out on purpose because it iterates with unet
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
vae: AutoencoderKL,
|
| 156 |
+
text_encoder: CLIPTextModel,
|
| 157 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 158 |
+
tokenizer: CLIPTokenizer,
|
| 159 |
+
tokenizer_2: CLIPTokenizer,
|
| 160 |
+
unet: UNet2DConditionModel,
|
| 161 |
+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
| 162 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 163 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 164 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 165 |
+
force_zeros_for_empty_prompt: bool = True,
|
| 166 |
+
add_watermarker: Optional[bool] = None,
|
| 167 |
+
):
|
| 168 |
+
super().__init__()
|
| 169 |
+
|
| 170 |
+
if isinstance(controlnet, (list, tuple)):
|
| 171 |
+
controlnet = MultiControlNetModel(controlnet)
|
| 172 |
+
|
| 173 |
+
self.register_modules(
|
| 174 |
+
vae=vae,
|
| 175 |
+
text_encoder=text_encoder,
|
| 176 |
+
text_encoder_2=text_encoder_2,
|
| 177 |
+
tokenizer=tokenizer,
|
| 178 |
+
tokenizer_2=tokenizer_2,
|
| 179 |
+
unet=unet,
|
| 180 |
+
controlnet=controlnet,
|
| 181 |
+
scheduler=scheduler,
|
| 182 |
+
feature_extractor=feature_extractor,
|
| 183 |
+
image_encoder=image_encoder,
|
| 184 |
+
)
|
| 185 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 186 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
| 187 |
+
self.control_image_processor = VaeImageProcessor(
|
| 188 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 189 |
+
)
|
| 190 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
| 191 |
+
|
| 192 |
+
if add_watermarker:
|
| 193 |
+
self.watermark = StableDiffusionXLWatermarker()
|
| 194 |
+
else:
|
| 195 |
+
self.watermark = None
|
| 196 |
+
|
| 197 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
| 198 |
+
|
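As a quick worked example of the scale-factor arithmetic in `__init__` above (the four `block_out_channels` are the usual SDXL VAE values and are an assumption here):

```py
# Sketch only: latent resolution implied by vae_scale_factor for a 1024x1024 image.
block_out_channels = [128, 256, 512, 512]                # assumed SDXL VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)    # -> 8
height, width = 1024, 1024
print(vae_scale_factor, height // vae_scale_factor, width // vae_scale_factor)  # 8 128 128
```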
| 199 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
| 200 |
+
def enable_vae_slicing(self):
|
| 201 |
+
r"""
|
| 202 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 203 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 204 |
+
"""
|
| 205 |
+
self.vae.enable_slicing()
|
| 206 |
+
|
| 207 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
| 208 |
+
def disable_vae_slicing(self):
|
| 209 |
+
r"""
|
| 210 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 211 |
+
computing decoding in one step.
|
| 212 |
+
"""
|
| 213 |
+
self.vae.disable_slicing()
|
| 214 |
+
|
| 215 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
| 216 |
+
def enable_vae_tiling(self):
|
| 217 |
+
r"""
|
| 218 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 219 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 220 |
+
processing larger images.
|
| 221 |
+
"""
|
| 222 |
+
self.vae.enable_tiling()
|
| 223 |
+
|
| 224 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
| 225 |
+
def disable_vae_tiling(self):
|
| 226 |
+
r"""
|
| 227 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 228 |
+
computing decoding in one step.
|
| 229 |
+
"""
|
| 230 |
+
self.vae.disable_tiling()
|
| 231 |
+
|
| 232 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
| 233 |
+
def encode_prompt(
|
| 234 |
+
self,
|
| 235 |
+
prompt: str,
|
| 236 |
+
prompt_2: Optional[str] = None,
|
| 237 |
+
device: Optional[torch.device] = None,
|
| 238 |
+
num_images_per_prompt: int = 1,
|
| 239 |
+
do_classifier_free_guidance: bool = True,
|
| 240 |
+
negative_prompt: Optional[str] = None,
|
| 241 |
+
negative_prompt_2: Optional[str] = None,
|
| 242 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 243 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 244 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 245 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 246 |
+
lora_scale: Optional[float] = None,
|
| 247 |
+
):
|
| 248 |
+
r"""
|
| 249 |
+
Encodes the prompt into text encoder hidden states.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 253 |
+
prompt to be encoded
|
| 254 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 255 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 256 |
+
used in both text-encoders
|
| 257 |
+
device: (`torch.device`):
|
| 258 |
+
torch device
|
| 259 |
+
num_images_per_prompt (`int`):
|
| 260 |
+
number of images that should be generated per prompt
|
| 261 |
+
do_classifier_free_guidance (`bool`):
|
| 262 |
+
whether to use classifier free guidance or not
|
| 263 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 264 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 265 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 266 |
+
less than `1`).
|
| 267 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 268 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 269 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
| 270 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 271 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 272 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 273 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 274 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 275 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 276 |
+
argument.
|
| 277 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 278 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 279 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 280 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 281 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 282 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 283 |
+
input argument.
|
| 284 |
+
lora_scale (`float`, *optional*):
|
| 285 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 286 |
+
"""
|
| 287 |
+
device = device or self._execution_device
|
| 288 |
+
|
| 289 |
+
# set lora scale so that monkey patched LoRA
|
| 290 |
+
# function of text encoder can correctly access it
|
| 291 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin,):
|
| 292 |
+
self._lora_scale = lora_scale
|
| 293 |
+
|
| 294 |
+
# dynamically adjust the LoRA scale
|
| 295 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 296 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
| 297 |
+
|
| 298 |
+
if prompt is not None and isinstance(prompt, str):
|
| 299 |
+
batch_size = 1
|
| 300 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 301 |
+
batch_size = len(prompt)
|
| 302 |
+
else:
|
| 303 |
+
batch_size = prompt_embeds.shape[0]
|
| 304 |
+
|
| 305 |
+
# Define tokenizers and text encoders
|
| 306 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
| 307 |
+
text_encoders = (
|
| 308 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if prompt_embeds is None:
|
| 312 |
+
prompt_2 = prompt_2 or prompt
|
| 313 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 314 |
+
prompt_embeds_list = []
|
| 315 |
+
prompts = [prompt, prompt_2]
|
| 316 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
| 317 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 318 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
| 319 |
+
|
| 320 |
+
text_inputs = tokenizer(
|
| 321 |
+
prompt,
|
| 322 |
+
padding="max_length",
|
| 323 |
+
max_length=tokenizer.model_max_length,
|
| 324 |
+
truncation=True,
|
| 325 |
+
return_tensors="pt",
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
text_input_ids = text_inputs.input_ids
|
| 329 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 330 |
+
|
| 331 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 332 |
+
text_input_ids, untruncated_ids
|
| 333 |
+
):
|
| 334 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
| 335 |
+
logger.warning(
|
| 336 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 337 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
prompt_embeds = text_encoder(
|
| 341 |
+
text_input_ids.to(device),
|
| 342 |
+
output_hidden_states=True,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 346 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 347 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 348 |
+
|
| 349 |
+
prompt_embeds_list.append(prompt_embeds)
|
| 350 |
+
|
| 351 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 352 |
+
|
| 353 |
+
# get unconditional embeddings for classifier free guidance
|
| 354 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
| 355 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
| 356 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
| 357 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
| 358 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 359 |
+
negative_prompt = negative_prompt or ""
|
| 360 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
| 361 |
+
|
| 362 |
+
uncond_tokens: List[str]
|
| 363 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 364 |
+
raise TypeError(
|
| 365 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 366 |
+
f" {type(prompt)}."
|
| 367 |
+
)
|
| 368 |
+
elif isinstance(negative_prompt, str):
|
| 369 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
| 370 |
+
elif batch_size != len(negative_prompt):
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 373 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 374 |
+
" the batch size of `prompt`."
|
| 375 |
+
)
|
| 376 |
+
else:
|
| 377 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
| 378 |
+
|
| 379 |
+
negative_prompt_embeds_list = []
|
| 380 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
| 381 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 382 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
| 383 |
+
|
| 384 |
+
max_length = prompt_embeds.shape[1]
|
| 385 |
+
uncond_input = tokenizer(
|
| 386 |
+
negative_prompt,
|
| 387 |
+
padding="max_length",
|
| 388 |
+
max_length=max_length,
|
| 389 |
+
truncation=True,
|
| 390 |
+
return_tensors="pt",
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
negative_prompt_embeds = text_encoder(
|
| 394 |
+
uncond_input.input_ids.to(device),
|
| 395 |
+
output_hidden_states=True,
|
| 396 |
+
)
|
| 397 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 398 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
| 399 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
| 400 |
+
|
| 401 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
| 402 |
+
|
| 403 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
| 404 |
+
|
| 405 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
| 406 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 407 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 408 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 409 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 410 |
+
|
| 411 |
+
if do_classifier_free_guidance:
|
| 412 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 413 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 414 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
| 415 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 416 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 417 |
+
|
| 418 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
| 419 |
+
bs_embed * num_images_per_prompt, -1
|
| 420 |
+
)
|
| 421 |
+
if do_classifier_free_guidance:
|
| 422 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
| 423 |
+
bs_embed * num_images_per_prompt, -1
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 427 |
+
|
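A hedged sketch of how the four tensors returned by `encode_prompt` are typically combined for classifier-free guidance later in the pipeline; the shapes below are stand-ins for typical SDXL values, not outputs of this code:

```py
import torch

prompt_embeds = torch.randn(1, 77, 2048)                 # illustrative SDXL text embeddings
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
pooled_prompt_embeds = torch.randn(1, 1280)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)

# one UNet forward pass covers both CFG branches when the batches are concatenated
cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)         # (2, 77, 2048)
cfg_pooled = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)  # (2, 1280)
```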
| 428 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 429 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 430 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 431 |
+
|
| 432 |
+
if not isinstance(image, torch.Tensor):
|
| 433 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 434 |
+
|
| 435 |
+
image = image.to(device=device, dtype=dtype)
|
| 436 |
+
if output_hidden_states:
|
| 437 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 438 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 439 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 440 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 441 |
+
).hidden_states[-2]
|
| 442 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 443 |
+
num_images_per_prompt, dim=0
|
| 444 |
+
)
|
| 445 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 446 |
+
else:
|
| 447 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 448 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 449 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 450 |
+
|
| 451 |
+
return image_embeds, uncond_image_embeds
|
| 452 |
+
|
| 453 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
| 454 |
+
def prepare_ip_adapter_image_embeds(
|
| 455 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
| 456 |
+
):
|
| 457 |
+
image_embeds = []
|
| 458 |
+
if do_classifier_free_guidance:
|
| 459 |
+
negative_image_embeds = []
|
| 460 |
+
if ip_adapter_image_embeds is None:
|
| 461 |
+
if not isinstance(ip_adapter_image, list):
|
| 462 |
+
ip_adapter_image = [ip_adapter_image]
|
| 463 |
+
|
| 464 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
| 470 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
| 471 |
+
):
|
| 472 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
| 473 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
| 474 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 478 |
+
if do_classifier_free_guidance:
|
| 479 |
+
negative_image_embeds.append(single_negative_image_embeds[None, :])
|
| 480 |
+
else:
|
| 481 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 482 |
+
if do_classifier_free_guidance:
|
| 483 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
| 484 |
+
negative_image_embeds.append(single_negative_image_embeds)
|
| 485 |
+
image_embeds.append(single_image_embeds)
|
| 486 |
+
|
| 487 |
+
ip_adapter_image_embeds = []
|
| 488 |
+
for i, single_image_embeds in enumerate(image_embeds):
|
| 489 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 490 |
+
if do_classifier_free_guidance:
|
| 491 |
+
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
|
| 492 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
|
| 493 |
+
|
| 494 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 495 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 496 |
+
|
| 497 |
+
return ip_adapter_image_embeds
|
| 498 |
+
|
| 499 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 500 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 501 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 502 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 503 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 504 |
+
# and should be between [0, 1]
|
| 505 |
+
|
| 506 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 507 |
+
extra_step_kwargs = {}
|
| 508 |
+
if accepts_eta:
|
| 509 |
+
extra_step_kwargs["eta"] = eta
|
| 510 |
+
|
| 511 |
+
# check if the scheduler accepts generator
|
| 512 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 513 |
+
if accepts_generator:
|
| 514 |
+
extra_step_kwargs["generator"] = generator
|
| 515 |
+
return extra_step_kwargs
|
| 516 |
+
|
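The scheduler introspection in `prepare_extra_step_kwargs` is a generic pattern; here is a self-contained illustration with a toy `step` function standing in for `scheduler.step`:

```py
import inspect

def step(model_output, timestep, sample, eta=0.0, generator=None):  # toy stand-in
    return sample

accepted = set(inspect.signature(step).parameters.keys())
extra_step_kwargs = {}
if "eta" in accepted:
    extra_step_kwargs["eta"] = 0.0
if "generator" in accepted:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)  # {'eta': 0.0, 'generator': None}
```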
| 517 |
+
def check_inputs(
|
| 518 |
+
self,
|
| 519 |
+
prompt,
|
| 520 |
+
prompt_2,
|
| 521 |
+
image,
|
| 522 |
+
callback_steps,
|
| 523 |
+
negative_prompt=None,
|
| 524 |
+
negative_prompt_2=None,
|
| 525 |
+
prompt_embeds=None,
|
| 526 |
+
negative_prompt_embeds=None,
|
| 527 |
+
pooled_prompt_embeds=None,
|
| 528 |
+
negative_pooled_prompt_embeds=None,
|
| 529 |
+
controlnet_conditioning_scale=1.0,
|
| 530 |
+
control_guidance_start=0.0,
|
| 531 |
+
control_guidance_end=1.0,
|
| 532 |
+
ip_adapter_image=None,
|
| 533 |
+
ip_adapter_image_embeds=None,
|
| 534 |
+
):
|
| 535 |
+
if (callback_steps is None) or (
|
| 536 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 537 |
+
):
|
| 538 |
+
raise ValueError(
|
| 539 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 540 |
+
f" {type(callback_steps)}."
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
if prompt is not None and prompt_embeds is not None:
|
| 544 |
+
raise ValueError(
|
| 545 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 546 |
+
" only forward one of the two."
|
| 547 |
+
)
|
| 548 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 549 |
+
raise ValueError(
|
| 550 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 551 |
+
" only forward one of the two."
|
| 552 |
+
)
|
| 553 |
+
elif prompt is None and prompt_embeds is None:
|
| 554 |
+
raise ValueError(
|
| 555 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 556 |
+
)
|
| 557 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 558 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 559 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 560 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 561 |
+
|
| 562 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 563 |
+
raise ValueError(
|
| 564 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 565 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 566 |
+
)
|
| 567 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 568 |
+
raise ValueError(
|
| 569 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 570 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 574 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 575 |
+
raise ValueError(
|
| 576 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 577 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 578 |
+
f" {negative_prompt_embeds.shape}."
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 582 |
+
raise ValueError(
|
| 583 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 587 |
+
raise ValueError(
|
| 588 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# `prompt` needs more sophisticated handling when there are multiple
|
| 592 |
+
# conditionings.
|
| 593 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
| 594 |
+
if isinstance(prompt, list):
|
| 595 |
+
logger.warning(
|
| 596 |
+
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
| 597 |
+
" prompts. The conditionings will be fixed across the prompts."
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# Check `image`
|
| 601 |
+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
| 602 |
+
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
| 603 |
+
)
|
| 604 |
+
if (
|
| 605 |
+
isinstance(self.controlnet, ControlNetModel)
|
| 606 |
+
or is_compiled
|
| 607 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
| 608 |
+
):
|
| 609 |
+
self.check_image(image, prompt, prompt_embeds)
|
| 610 |
+
elif (
|
| 611 |
+
isinstance(self.controlnet, ControlNetModel_Union)
|
| 612 |
+
or is_compiled
|
| 613 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
|
| 614 |
+
):
|
| 615 |
+
self.check_image(image, prompt, prompt_embeds)
|
| 616 |
+
elif (
|
| 617 |
+
isinstance(self.controlnet, MultiControlNetModel)
|
| 618 |
+
or is_compiled
|
| 619 |
+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
|
| 620 |
+
):
|
| 621 |
+
if not isinstance(image, list):
|
| 622 |
+
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
| 623 |
+
|
| 624 |
+
# When `image` is a nested list:
|
| 625 |
+
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
|
| 626 |
+
elif any(isinstance(i, list) for i in image):
|
| 627 |
+
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
| 628 |
+
elif len(image) != len(self.controlnet.nets):
|
| 629 |
+
raise ValueError(
|
| 630 |
+
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
for image_ in image:
|
| 634 |
+
self.check_image(image_, prompt, prompt_embeds)
|
| 635 |
+
else:
|
| 636 |
+
assert False
|
| 637 |
+
|
| 638 |
+
# Check `controlnet_conditioning_scale`
|
| 639 |
+
if (
|
| 640 |
+
isinstance(self.controlnet, ControlNetModel)
|
| 641 |
+
or is_compiled
|
| 642 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
| 643 |
+
):
|
| 644 |
+
if not isinstance(controlnet_conditioning_scale, float):
|
| 645 |
+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
| 646 |
+
|
| 647 |
+
elif (
|
| 648 |
+
isinstance(self.controlnet, ControlNetModel_Union)
|
| 649 |
+
or is_compiled
|
| 650 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
|
| 651 |
+
):
|
| 652 |
+
if not isinstance(controlnet_conditioning_scale, float):
|
| 653 |
+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
| 654 |
+
|
| 655 |
+
elif (
|
| 656 |
+
isinstance(self.controlnet, MultiControlNetModel)
|
| 657 |
+
or is_compiled
|
| 658 |
+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
|
| 659 |
+
):
|
| 660 |
+
if isinstance(controlnet_conditioning_scale, list):
|
| 661 |
+
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
| 662 |
+
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
| 663 |
+
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
| 664 |
+
self.controlnet.nets
|
| 665 |
+
):
|
| 666 |
+
raise ValueError(
|
| 667 |
+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
| 668 |
+
" the same length as the number of controlnets"
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
assert False
|
| 672 |
+
|
| 673 |
+
if not isinstance(control_guidance_start, (tuple, list)):
|
| 674 |
+
control_guidance_start = [control_guidance_start]
|
| 675 |
+
|
| 676 |
+
if not isinstance(control_guidance_end, (tuple, list)):
|
| 677 |
+
control_guidance_end = [control_guidance_end]
|
| 678 |
+
|
| 679 |
+
if len(control_guidance_start) != len(control_guidance_end):
|
| 680 |
+
raise ValueError(
|
| 681 |
+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
| 685 |
+
if len(control_guidance_start) != len(self.controlnet.nets):
|
| 686 |
+
raise ValueError(
|
| 687 |
+
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
for start, end in zip(control_guidance_start, control_guidance_end):
|
| 691 |
+
if start >= end:
|
| 692 |
+
raise ValueError(
|
| 693 |
+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
| 694 |
+
)
|
| 695 |
+
if start < 0.0:
|
| 696 |
+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
| 697 |
+
if end > 1.0:
|
| 698 |
+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
| 699 |
+
|
| 700 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| 701 |
+
raise ValueError(
|
| 702 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
if ip_adapter_image_embeds is not None:
|
| 706 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 707 |
+
raise ValueError(
|
| 708 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| 709 |
+
)
|
| 710 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| 711 |
+
raise ValueError(
|
| 712 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
| 716 |
+
def check_image(self, image, prompt, prompt_embeds):
|
| 717 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
| 718 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
| 719 |
+
image_is_np = isinstance(image, np.ndarray)
|
| 720 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
| 721 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
| 722 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
| 723 |
+
|
| 724 |
+
if (
|
| 725 |
+
not image_is_pil
|
| 726 |
+
and not image_is_tensor
|
| 727 |
+
and not image_is_np
|
| 728 |
+
and not image_is_pil_list
|
| 729 |
+
and not image_is_tensor_list
|
| 730 |
+
and not image_is_np_list
|
| 731 |
+
):
|
| 732 |
+
raise TypeError(
|
| 733 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
if image_is_pil:
|
| 737 |
+
image_batch_size = 1
|
| 738 |
+
else:
|
| 739 |
+
image_batch_size = len(image)
|
| 740 |
+
|
| 741 |
+
if prompt is not None and isinstance(prompt, str):
|
| 742 |
+
prompt_batch_size = 1
|
| 743 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 744 |
+
prompt_batch_size = len(prompt)
|
| 745 |
+
elif prompt_embeds is not None:
|
| 746 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
| 747 |
+
|
| 748 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
| 749 |
+
raise ValueError(
|
| 750 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
| 754 |
+
def prepare_image(
|
| 755 |
+
self,
|
| 756 |
+
image,
|
| 757 |
+
width,
|
| 758 |
+
height,
|
| 759 |
+
batch_size,
|
| 760 |
+
num_images_per_prompt,
|
| 761 |
+
device,
|
| 762 |
+
dtype,
|
| 763 |
+
do_classifier_free_guidance=False,
|
| 764 |
+
guess_mode=False,
|
| 765 |
+
):
|
| 766 |
+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
| 767 |
+
image_batch_size = image.shape[0]
|
| 768 |
+
|
| 769 |
+
if image_batch_size == 1:
|
| 770 |
+
repeat_by = batch_size
|
| 771 |
+
else:
|
| 772 |
+
# image batch size is the same as prompt batch size
|
| 773 |
+
repeat_by = num_images_per_prompt
|
| 774 |
+
|
| 775 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
| 776 |
+
|
| 777 |
+
image = image.to(device=device, dtype=dtype)
|
| 778 |
+
|
| 779 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 780 |
+
image = torch.cat([image] * 2)
|
| 781 |
+
|
| 782 |
+
return image
|
| 783 |
+
|
| 784 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 785 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 786 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 787 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 788 |
+
raise ValueError(
|
| 789 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 790 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if latents is None:
|
| 794 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 795 |
+
else:
|
| 796 |
+
latents = latents.to(device)
|
| 797 |
+
|
| 798 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 799 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 800 |
+
return latents
|
| 801 |
+
|
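A sketch of the latent-shape arithmetic in `prepare_latents` for a 1024x1024 generation; `num_channels_latents = 4` and `vae_scale_factor = 8` are assumed SDXL defaults, and `init_noise_sigma` depends on the scheduler (1.0 for DDIM-style schedulers):

```py
import torch

batch_size, num_channels_latents = 1, 4
height = width = 1024
vae_scale_factor = 8
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)            # (1, 4, 128, 128)
init_noise_sigma = 1.0                  # scheduler-dependent scaling, as in the code above
latents = latents * init_noise_sigma
```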
| 802 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
| 803 |
+
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
| 804 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
| 805 |
+
|
| 806 |
+
passed_add_embed_dim = (
|
| 807 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
| 808 |
+
)
|
| 809 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
| 810 |
+
|
| 811 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 812 |
+
raise ValueError(
|
| 813 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
| 817 |
+
return add_time_ids
|
| 818 |
+
|
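A worked example of the dimension check in `_get_add_time_ids`; the 256/1280 values are the usual SDXL base configuration and are assumptions here, not read from the model:

```py
# Illustrative arithmetic only, assuming addition_time_embed_dim=256 and projection_dim=1280.
original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)   # 6 values
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
print(add_time_ids, passed_add_embed_dim)  # [1024, 1024, 0, 0, 1024, 1024] 2816
```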
| 819 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
| 820 |
+
def upcast_vae(self):
|
| 821 |
+
dtype = self.vae.dtype
|
| 822 |
+
self.vae.to(dtype=torch.float32)
|
| 823 |
+
use_torch_2_0_or_xformers = isinstance(
|
| 824 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
| 825 |
+
(
|
| 826 |
+
AttnProcessor2_0,
|
| 827 |
+
XFormersAttnProcessor,
|
| 828 |
+
LoRAXFormersAttnProcessor,
|
| 829 |
+
LoRAAttnProcessor2_0,
|
| 830 |
+
),
|
| 831 |
+
)
|
| 832 |
+
# if xformers or torch_2_0 is used attention block does not need
|
| 833 |
+
# to be in float32 which can save lots of memory
|
| 834 |
+
if use_torch_2_0_or_xformers:
|
| 835 |
+
self.vae.post_quant_conv.to(dtype)
|
| 836 |
+
self.vae.decoder.conv_in.to(dtype)
|
| 837 |
+
self.vae.decoder.mid_block.to(dtype)
|
| 838 |
+
|
| 839 |
+
@torch.no_grad()
|
| 840 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 841 |
+
def __call__(
|
| 842 |
+
self,
|
| 843 |
+
prompt: Union[str, List[str]] = None,
|
| 844 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 845 |
+
image_list: PipelineImageInput = None,
|
| 846 |
+
height: Optional[int] = None,
|
| 847 |
+
width: Optional[int] = None,
|
| 848 |
+
num_inference_steps: int = 50,
|
| 849 |
+
guidance_scale: float = 5.0,
|
| 850 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 851 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 852 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 853 |
+
eta: float = 0.0,
|
| 854 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 855 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 856 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 857 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 858 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 859 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 860 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 861 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 862 |
+
output_type: Optional[str] = "pil",
|
| 863 |
+
return_dict: bool = True,
|
| 864 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 865 |
+
callback_steps: int = 1,
|
| 866 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 867 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
| 868 |
+
guess_mode: bool = False,
|
| 869 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
| 870 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
| 871 |
+
original_size: Tuple[int, int] = None,
|
| 872 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 873 |
+
target_size: Tuple[int, int] = None,
|
| 874 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
| 875 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 876 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
| 877 |
+
union_control = False,
|
| 878 |
+
union_control_type = None,
|
| 879 |
+
progress=gr.Progress(),
|
| 880 |
+
|
| 881 |
+
):
|
| 882 |
+
r"""
|
| 883 |
+
The call function to the pipeline for generation.
|
| 884 |
+
|
| 885 |
+
Args:
|
| 886 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 887 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 888 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 889 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 890 |
+
used in both text-encoders.
|
| 891 |
+
image_list (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
| 892 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
| 893 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
| 894 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
| 895 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
| 896 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
| 897 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
| 898 |
+
input to a single ControlNet.
|
| 899 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 900 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
| 901 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 902 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 903 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 904 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
| 905 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 906 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 907 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 908 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 909 |
+
expense of slower inference.
|
| 910 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 911 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 912 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 913 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 914 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 915 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 916 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 917 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
| 918 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
| 919 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 920 |
+
The number of images to generate per prompt.
|
| 921 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 922 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 923 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 924 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 925 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 926 |
+
generation deterministic.
|
| 927 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 928 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 929 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 930 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 931 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 932 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 933 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 934 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 935 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 936 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 937 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 938 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 939 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
| 940 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 941 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
| 942 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
| 943 |
+
argument.
|
| 944 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 945 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 946 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 947 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 948 |
+
plain tuple.
|
| 949 |
+
callback (`Callable`, *optional*):
|
| 950 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 951 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 952 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 953 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 954 |
+
every step.
|
| 955 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 956 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 957 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 958 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 959 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
| 960 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
| 961 |
+
the corresponding scale as a list.
|
| 962 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
| 963 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
| 964 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
| 965 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
| 966 |
+
The percentage of total steps at which the ControlNet starts applying.
|
| 967 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 968 |
+
The percentage of total steps at which the ControlNet stops applying.
|
| 969 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 970 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 971 |
+
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
| 972 |
+
explained in section 2.2 of
|
| 973 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 974 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 975 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 976 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 977 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 978 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 979 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 980 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
| 981 |
+
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
| 982 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 983 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 984 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
| 985 |
+
micro-conditioning as explained in section 2.2 of
|
| 986 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 987 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 988 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 989 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
| 990 |
+
micro-conditioning as explained in section 2.2 of
|
| 991 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 992 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 993 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 994 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
| 995 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 996 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 997 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 998 |
+
|
| 999 |
+
Examples:
|
| 1000 |
+
|
| 1001 |
+
Returns:
|
| 1002 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1003 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 1004 |
+
otherwise a `tuple` is returned containing the output images.
|
| 1005 |
+
"""
|
| 1006 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
| 1007 |
+
|
| 1008 |
+
# align format for control guidance
|
| 1009 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
| 1010 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
| 1011 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
| 1012 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
| 1013 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
| 1014 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
| 1015 |
+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
| 1016 |
+
control_guidance_end
|
| 1017 |
+
]
|
| 1018 |
+
|
| 1019 |
+
# 1. Check inputs. Raise error if not correct
|
| 1020 |
+
for image in image_list:
|
| 1021 |
+
if image:
|
| 1022 |
+
self.check_inputs(
|
| 1023 |
+
prompt,
|
| 1024 |
+
prompt_2,
|
| 1025 |
+
image,
|
| 1026 |
+
callback_steps,
|
| 1027 |
+
negative_prompt,
|
| 1028 |
+
negative_prompt_2,
|
| 1029 |
+
prompt_embeds,
|
| 1030 |
+
negative_prompt_embeds,
|
| 1031 |
+
pooled_prompt_embeds,
|
| 1032 |
+
negative_pooled_prompt_embeds,
|
| 1033 |
+
controlnet_conditioning_scale,
|
| 1034 |
+
control_guidance_start,
|
| 1035 |
+
control_guidance_end,
|
| 1036 |
+
ip_adapter_image,
|
| 1037 |
+
ip_adapter_image_embeds,
|
| 1038 |
+
)
|
| 1039 |
+
# 2. Define call parameters
|
| 1040 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1041 |
+
batch_size = 1
|
| 1042 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1043 |
+
batch_size = len(prompt)
|
| 1044 |
+
else:
|
| 1045 |
+
batch_size = prompt_embeds.shape[0]
|
| 1046 |
+
|
| 1047 |
+
device = self._execution_device
|
| 1048 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
| 1049 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 1050 |
+
# corresponds to doing no classifier free guidance.
|
| 1051 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 1052 |
+
|
| 1053 |
+
global_pool_conditions = (
|
| 1054 |
+
controlnet.config.global_pool_conditions
|
| 1055 |
+
)
|
| 1056 |
+
guess_mode = guess_mode or global_pool_conditions
|
| 1057 |
+
|
| 1058 |
+
# 3. Encode input prompt
|
| 1059 |
+
text_encoder_lora_scale = (
|
| 1060 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
| 1061 |
+
)
|
| 1062 |
+
(
|
| 1063 |
+
prompt_embeds,
|
| 1064 |
+
negative_prompt_embeds,
|
| 1065 |
+
pooled_prompt_embeds,
|
| 1066 |
+
negative_pooled_prompt_embeds,
|
| 1067 |
+
) = self.encode_prompt(
|
| 1068 |
+
prompt,
|
| 1069 |
+
prompt_2,
|
| 1070 |
+
device,
|
| 1071 |
+
num_images_per_prompt,
|
| 1072 |
+
do_classifier_free_guidance,
|
| 1073 |
+
negative_prompt,
|
| 1074 |
+
negative_prompt_2,
|
| 1075 |
+
prompt_embeds=prompt_embeds,
|
| 1076 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1077 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1078 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1079 |
+
lora_scale=text_encoder_lora_scale,
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
# 3.2 Encode ip_adapter_image
|
| 1083 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1084 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1085 |
+
ip_adapter_image,
|
| 1086 |
+
ip_adapter_image_embeds,
|
| 1087 |
+
device,
|
| 1088 |
+
batch_size * num_images_per_prompt,
|
| 1089 |
+
do_classifier_free_guidance,
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
# 4. Prepare image
|
| 1093 |
+
assert isinstance(controlnet, ControlNetModel_Union)
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
for idx in range(len(image_list)):
|
| 1097 |
+
if image_list[idx]:
|
| 1098 |
+
image = self.prepare_image(
|
| 1099 |
+
image=image_list[idx],
|
| 1100 |
+
width=width,
|
| 1101 |
+
height=height,
|
| 1102 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 1103 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1104 |
+
device=device,
|
| 1105 |
+
dtype=controlnet.dtype,
|
| 1106 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 1107 |
+
guess_mode=guess_mode,
|
| 1108 |
+
)
|
| 1109 |
+
height, width = image.shape[-2:]
|
| 1110 |
+
image_list[idx] = image
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
# 5. Prepare timesteps
|
| 1114 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 1115 |
+
timesteps = self.scheduler.timesteps
|
| 1116 |
+
|
| 1117 |
+
# 6. Prepare latent variables
|
| 1118 |
+
num_channels_latents = self.unet.config.in_channels
|
| 1119 |
+
latents = self.prepare_latents(
|
| 1120 |
+
batch_size * num_images_per_prompt,
|
| 1121 |
+
num_channels_latents,
|
| 1122 |
+
height,
|
| 1123 |
+
width,
|
| 1124 |
+
prompt_embeds.dtype,
|
| 1125 |
+
device,
|
| 1126 |
+
generator,
|
| 1127 |
+
latents,
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1131 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1132 |
+
|
| 1133 |
+
# 7.1 Create tensor stating which controlnets to keep
|
| 1134 |
+
controlnet_keep = []
|
| 1135 |
+
for i in range(len(timesteps)):
|
| 1136 |
+
keeps = [
|
| 1137 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
| 1138 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
| 1139 |
+
]
|
| 1140 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, (ControlNetModel, ControlNetModel_Union)) else keeps)
|
| 1141 |
+
|
| 1142 |
+
# 7.2 Prepare added time ids & embeddings
|
| 1143 |
+
for image in image_list:
|
| 1144 |
+
if isinstance(image, torch.Tensor):
|
| 1145 |
+
original_size = original_size or image.shape[-2:]
|
| 1146 |
+
|
| 1147 |
+
target_size = target_size or (height, width)
|
| 1148 |
+
# print(original_size)
|
| 1149 |
+
# print(target_size)
|
| 1150 |
+
add_text_embeds = pooled_prompt_embeds
|
| 1151 |
+
add_time_ids = self._get_add_time_ids(
|
| 1152 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 1156 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 1157 |
+
negative_original_size,
|
| 1158 |
+
negative_crops_coords_top_left,
|
| 1159 |
+
negative_target_size,
|
| 1160 |
+
dtype=prompt_embeds.dtype,
|
| 1161 |
+
)
|
| 1162 |
+
else:
|
| 1163 |
+
negative_add_time_ids = add_time_ids
|
| 1164 |
+
|
| 1165 |
+
if do_classifier_free_guidance:
|
| 1166 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1167 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 1168 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 1169 |
+
|
| 1170 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 1171 |
+
add_text_embeds = add_text_embeds.to(device)
|
| 1172 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
| 1173 |
+
|
| 1174 |
+
# 8. Denoising loop
|
| 1175 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1176 |
+
# with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1177 |
+
# with progress.tqdm(range(num_inference_steps), desc="Diffusing...") as progress_bar:
|
| 1178 |
+
for i, t in progress.tqdm(enumerate(timesteps), desc="Diffusing..."):
|
| 1179 |
+
# expand the latents if we are doing classifier free guidance
|
| 1180 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 1181 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1182 |
+
|
| 1183 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids, \
|
| 1184 |
+
"control_type":union_control_type.reshape(1, -1).to(device, dtype=prompt_embeds.dtype).repeat(batch_size * num_images_per_prompt * 2, 1)}
|
| 1185 |
+
|
| 1186 |
+
# controlnet(s) inference
|
| 1187 |
+
if guess_mode and do_classifier_free_guidance:
|
| 1188 |
+
# Infer ControlNet only for the conditional batch.
|
| 1189 |
+
control_model_input = latents
|
| 1190 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
| 1191 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
| 1192 |
+
controlnet_added_cond_kwargs = {
|
| 1193 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
| 1194 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
| 1195 |
+
}
|
| 1196 |
+
else:
|
| 1197 |
+
control_model_input = latent_model_input
|
| 1198 |
+
controlnet_prompt_embeds = prompt_embeds
|
| 1199 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
| 1200 |
+
|
| 1201 |
+
if isinstance(controlnet_keep[i], list):
|
| 1202 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
| 1203 |
+
else:
|
| 1204 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
| 1205 |
+
if isinstance(controlnet_cond_scale, list):
|
| 1206 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1207 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
# print(image.shape)
|
| 1211 |
+
if isinstance(controlnet, ControlNetModel_Union):
|
| 1212 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1213 |
+
control_model_input,
|
| 1214 |
+
t,
|
| 1215 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
| 1216 |
+
controlnet_cond_list=image_list,
|
| 1217 |
+
conditioning_scale=cond_scale,
|
| 1218 |
+
guess_mode=guess_mode,
|
| 1219 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
| 1220 |
+
return_dict=False,
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
if guess_mode and do_classifier_free_guidance:
|
| 1224 |
+
# Inferred ControlNet only for the conditional batch.
|
| 1225 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
| 1226 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
| 1227 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
| 1228 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
| 1229 |
+
|
| 1230 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1231 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
| 1232 |
+
# predict the noise residual
|
| 1233 |
+
noise_pred = self.unet(
|
| 1234 |
+
latent_model_input,
|
| 1235 |
+
t,
|
| 1236 |
+
encoder_hidden_states=prompt_embeds,
|
| 1237 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 1238 |
+
down_block_additional_residuals=down_block_res_samples,
|
| 1239 |
+
mid_block_additional_residual=mid_block_res_sample,
|
| 1240 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1241 |
+
return_dict=False,
|
| 1242 |
+
)[0]
|
| 1243 |
+
|
| 1244 |
+
# perform guidance
|
| 1245 |
+
if do_classifier_free_guidance:
|
| 1246 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1247 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1248 |
+
|
| 1249 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1250 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1251 |
+
|
| 1252 |
+
# call the callback, if provided
|
| 1253 |
+
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1254 |
+
# progress_bar.update()
|
| 1255 |
+
# if callback is not None and i % callback_steps == 0:
|
| 1256 |
+
# callback(i, t, latents)
|
| 1257 |
+
|
| 1258 |
+
# manually for max memory savings
|
| 1259 |
+
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
| 1260 |
+
self.upcast_vae()
|
| 1261 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 1262 |
+
|
| 1263 |
+
if not output_type == "latent":
|
| 1264 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 1265 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 1266 |
+
|
| 1267 |
+
if needs_upcasting:
|
| 1268 |
+
self.upcast_vae()
|
| 1269 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 1270 |
+
|
| 1271 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 1272 |
+
|
| 1273 |
+
# cast back to fp16 if needed
|
| 1274 |
+
if needs_upcasting:
|
| 1275 |
+
self.vae.to(dtype=torch.float16)
|
| 1276 |
+
else:
|
| 1277 |
+
image = latents
|
| 1278 |
+
|
| 1279 |
+
if not output_type == "latent":
|
| 1280 |
+
# apply watermark if available
|
| 1281 |
+
if self.watermark is not None:
|
| 1282 |
+
image = self.watermark.apply_watermark(image)
|
| 1283 |
+
|
| 1284 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1285 |
+
|
| 1286 |
+
# Offload all models
|
| 1287 |
+
self.maybe_free_model_hooks()
|
| 1288 |
+
|
| 1289 |
+
if not return_dict:
|
| 1290 |
+
return (image,)
|
| 1291 |
+
|
| 1292 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
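The denoising loop above combines two simple per-step rules: a keep mask derived from `control_guidance_start`/`control_guidance_end` that zeroes the ControlNet contribution outside the chosen window, and the standard classifier-free-guidance mix weighted by `guidance_scale`. A minimal, self-contained sketch of both (toy shapes and values, for illustration only, not part of the committed file):

```py
import torch

# Per-step ControlNet gating, mirroring the `controlnet_keep` computation above.
num_steps = 10
start, end = 0.0, 0.5  # apply ControlNet only during the first half of sampling
controlnet_keep = [
    1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
    for i in range(num_steps)
]
# -> [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# Classifier-free guidance: move the prediction away from the unconditional
# branch and toward the text-conditioned branch by `guidance_scale`.
guidance_scale = 5.0
noise_pred_uncond = torch.zeros(1, 4, 8, 8)
noise_pred_text = torch.ones(1, 4, 8, 8)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```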
| 1293 |
+
|
| 1294 |
+
# Override to properly handle the loading and unloading of the additional text encoder.
|
| 1295 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
|
| 1296 |
+
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
| 1297 |
+
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
| 1298 |
+
# it here explicitly to be able to tell that it's coming from an SDXL
|
| 1299 |
+
# pipeline.
|
| 1300 |
+
|
| 1301 |
+
# Remove any existing hooks.
|
| 1302 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
| 1303 |
+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
| 1304 |
+
else:
|
| 1305 |
+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
|
| 1306 |
+
|
| 1307 |
+
is_model_cpu_offload = False
|
| 1308 |
+
is_sequential_cpu_offload = False
|
| 1309 |
+
recursive = False
|
| 1310 |
+
for _, component in self.components.items():
|
| 1311 |
+
if isinstance(component, torch.nn.Module):
|
| 1312 |
+
if hasattr(component, "_hf_hook"):
|
| 1313 |
+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
| 1314 |
+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
| 1315 |
+
logger.info(
|
| 1316 |
+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
| 1317 |
+
)
|
| 1318 |
+
recursive = is_sequential_cpu_offload
|
| 1319 |
+
remove_hook_from_module(component, recurse=recursive)
|
| 1320 |
+
state_dict, network_alphas = self.lora_state_dict(
|
| 1321 |
+
pretrained_model_name_or_path_or_dict,
|
| 1322 |
+
unet_config=self.unet.config,
|
| 1323 |
+
**kwargs,
|
| 1324 |
+
)
|
| 1325 |
+
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
| 1326 |
+
|
| 1327 |
+
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
| 1328 |
+
if len(text_encoder_state_dict) > 0:
|
| 1329 |
+
self.load_lora_into_text_encoder(
|
| 1330 |
+
text_encoder_state_dict,
|
| 1331 |
+
network_alphas=network_alphas,
|
| 1332 |
+
text_encoder=self.text_encoder,
|
| 1333 |
+
prefix="text_encoder",
|
| 1334 |
+
lora_scale=self.lora_scale,
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
| 1338 |
+
if len(text_encoder_2_state_dict) > 0:
|
| 1339 |
+
self.load_lora_into_text_encoder(
|
| 1340 |
+
text_encoder_2_state_dict,
|
| 1341 |
+
network_alphas=network_alphas,
|
| 1342 |
+
text_encoder=self.text_encoder_2,
|
| 1343 |
+
prefix="text_encoder_2",
|
| 1344 |
+
lora_scale=self.lora_scale,
|
| 1345 |
+
)
|
| 1346 |
+
|
| 1347 |
+
# Offload back.
|
| 1348 |
+
if is_model_cpu_offload:
|
| 1349 |
+
self.enable_model_cpu_offload()
|
| 1350 |
+
elif is_sequential_cpu_offload:
|
| 1351 |
+
self.enable_sequential_cpu_offload()
|
| 1352 |
+
|
| 1353 |
+
@classmethod
|
| 1354 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
|
| 1355 |
+
def save_lora_weights(
|
| 1356 |
+
self,
|
| 1357 |
+
save_directory: Union[str, os.PathLike],
|
| 1358 |
+
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
| 1359 |
+
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
| 1360 |
+
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
| 1361 |
+
is_main_process: bool = True,
|
| 1362 |
+
weight_name: str = None,
|
| 1363 |
+
save_function: Callable = None,
|
| 1364 |
+
safe_serialization: bool = True,
|
| 1365 |
+
):
|
| 1366 |
+
state_dict = {}
|
| 1367 |
+
|
| 1368 |
+
def pack_weights(layers, prefix):
|
| 1369 |
+
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
| 1370 |
+
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
| 1371 |
+
return layers_state_dict
|
| 1372 |
+
|
| 1373 |
+
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
| 1374 |
+
raise ValueError(
|
| 1375 |
+
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
| 1376 |
+
)
|
| 1377 |
+
|
| 1378 |
+
if unet_lora_layers:
|
| 1379 |
+
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
| 1380 |
+
|
| 1381 |
+
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
| 1382 |
+
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
| 1383 |
+
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
| 1384 |
+
|
| 1385 |
+
self.write_lora_layers(
|
| 1386 |
+
state_dict=state_dict,
|
| 1387 |
+
save_directory=save_directory,
|
| 1388 |
+
is_main_process=is_main_process,
|
| 1389 |
+
weight_name=weight_name,
|
| 1390 |
+
save_function=save_function,
|
| 1391 |
+
safe_serialization=safe_serialization,
|
| 1392 |
+
)
|
| 1393 |
+
|
| 1394 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
|
| 1395 |
+
def _remove_text_encoder_monkey_patch(self):
|
| 1396 |
+
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
| 1397 |
+
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
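For reference, the key prefixing performed by `pack_weights` inside `save_lora_weights` above can be reproduced in isolation. The toy state dict below is only an illustration (it is not the Space's actual LoRA weights):

```py
import torch

def pack_weights(layers, prefix):
    # Same logic as in `save_lora_weights`: accept a module or a plain dict
    # and namespace every key with the component prefix.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in layers_weights.items()}

unet_lora_layers = {"down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 4)}
state_dict = pack_weights(unet_lora_layers, "unet")
# -> keys like 'unet.down_blocks.0.attentions.0.to_q.lora_A.weight'
```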
utils/pipeline_stable_diffusion_switcher.py
ADDED
|
@@ -0,0 +1,1240 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
import numpy as np
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torch
|
| 20 |
+
from packaging import version
|
| 21 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 22 |
+
import torchvision.transforms.functional as TF
|
| 23 |
+
|
| 24 |
+
from diffusers.configuration_utils import FrozenDict
|
| 25 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 26 |
+
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
| 27 |
+
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
| 28 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 29 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 30 |
+
from diffusers.utils import (
|
| 31 |
+
USE_PEFT_BACKEND,
|
| 32 |
+
deprecate,
|
| 33 |
+
logging,
|
| 34 |
+
replace_example_docstring,
|
| 35 |
+
scale_lora_layers,
|
| 36 |
+
unscale_lora_layers,
|
| 37 |
+
)
|
| 38 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 39 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 40 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
| 41 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```py
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import StableDiffusionPipeline
|
| 51 |
+
|
| 52 |
+
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
| 53 |
+
>>> pipe = pipe.to("cuda")
|
| 54 |
+
|
| 55 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
| 56 |
+
>>> image = pipe(prompt).images[0]
|
| 57 |
+
```
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def scale_latents_rm(latents):
|
| 62 |
+
latents = latents * 0.9702 - 0.5742
|
| 63 |
+
return latents
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def unscale_latents_rm(latents):
|
| 67 |
+
latents = (latents + 0.5742) / 0.9702
|
| 68 |
+
return latents
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def scale_latents_bump(latents):
|
| 72 |
+
latents = latents * 0.9462 + 0.3770
|
| 73 |
+
return latents
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def unscale_latents_bump(latents):
|
| 77 |
+
latents = (latents - 0.3770) / 0.9462
|
| 78 |
+
return latents
|
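The two `scale_*`/`unscale_*` pairs above are plain affine maps (y = a*x + b, undone by x = (y - b) / a), so each unscale function exactly inverts its scale counterpart. A quick sanity check, assuming the functions are imported from this module:

```py
import torch
from utils.pipeline_stable_diffusion_switcher import (
    scale_latents_rm, unscale_latents_rm,
    scale_latents_bump, unscale_latents_bump,
)

x = torch.randn(2, 4, 8, 8)
assert torch.allclose(unscale_latents_rm(scale_latents_rm(x)), x, atol=1e-6)
assert torch.allclose(unscale_latents_bump(scale_latents_bump(x)), x, atol=1e-6)
```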
| 79 |
+
|
| 80 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 81 |
+
"""
|
| 82 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
| 83 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
| 84 |
+
"""
|
| 85 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 86 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 87 |
+
# rescale the results from guidance (fixes overexposure)
|
| 88 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 89 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 90 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 91 |
+
return noise_cfg
|
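A short usage sketch of `rescale_noise_cfg` (illustration only, with random tensors): it is meant to run right after the classifier-free-guidance combination, pulling the standard deviation of the guided prediction back toward that of the text branch, as described in "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Section 3.4):

```py
import torch
from utils.pipeline_stable_diffusion_switcher import rescale_noise_cfg

guidance_scale, guidance_rescale = 7.5, 0.7
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)

# Standard CFG combination, then rescale to avoid over-exposed results.
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=guidance_rescale)
```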
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def retrieve_timesteps(
|
| 95 |
+
scheduler,
|
| 96 |
+
num_inference_steps: Optional[int] = None,
|
| 97 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 98 |
+
timesteps: Optional[List[int]] = None,
|
| 99 |
+
**kwargs,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 103 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
scheduler (`SchedulerMixin`):
|
| 107 |
+
The scheduler to get timesteps from.
|
| 108 |
+
num_inference_steps (`int`):
|
| 109 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
| 110 |
+
`timesteps` must be `None`.
|
| 111 |
+
device (`str` or `torch.device`, *optional*):
|
| 112 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 113 |
+
timesteps (`List[int]`, *optional*):
|
| 114 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
| 115 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
| 116 |
+
must be `None`.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 120 |
+
second element is the number of inference steps.
|
| 121 |
+
"""
|
| 122 |
+
if timesteps is not None:
|
| 123 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 124 |
+
if not accepts_timesteps:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 127 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 128 |
+
)
|
| 129 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 130 |
+
timesteps = scheduler.timesteps
|
| 131 |
+
num_inference_steps = len(timesteps)
|
| 132 |
+
else:
|
| 133 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 134 |
+
timesteps = scheduler.timesteps
|
| 135 |
+
return timesteps, num_inference_steps
|
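Usage sketch for `retrieve_timesteps` (illustration only): with a scheduler whose `set_timesteps` does not accept a `timesteps` argument, only the `num_inference_steps` path is available; passing custom `timesteps` to such a scheduler raises the ValueError above.

```py
from diffusers import DDIMScheduler
from utils.pipeline_stable_diffusion_switcher import retrieve_timesteps

scheduler = DDIMScheduler()
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
# num_steps == 30; `timesteps` is the scheduler's 30-step schedule on CPU.
```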
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class StableDiffusionPipeline(
|
| 139 |
+
DiffusionPipeline,
|
| 140 |
+
StableDiffusionMixin,
|
| 141 |
+
TextualInversionLoaderMixin,
|
| 142 |
+
LoraLoaderMixin,
|
| 143 |
+
IPAdapterMixin,
|
| 144 |
+
FromSingleFileMixin,
|
| 145 |
+
):
|
| 146 |
+
r"""
|
| 147 |
+
Pipeline for text-to-image generation using Stable Diffusion.
|
| 148 |
+
|
| 149 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 150 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 151 |
+
|
| 152 |
+
The pipeline also inherits the following loading methods:
|
| 153 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 154 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 155 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 156 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 157 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
vae ([`AutoencoderKL`]):
|
| 161 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 162 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 163 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 164 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 165 |
+
A `CLIPTokenizer` to tokenize text.
|
| 166 |
+
unet ([`UNet2DConditionModel`]):
|
| 167 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 168 |
+
scheduler ([`SchedulerMixin`]):
|
| 169 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 170 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 171 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 172 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 173 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
| 174 |
+
about a model's potential harms.
|
| 175 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 176 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 180 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
| 181 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
| 182 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
vae: AutoencoderKL,
|
| 187 |
+
text_encoder: CLIPTextModel,
|
| 188 |
+
tokenizer: CLIPTokenizer,
|
| 189 |
+
unet: UNet2DConditionModel,
|
| 190 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 191 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 192 |
+
feature_extractor: CLIPImageProcessor,
|
| 193 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 194 |
+
requires_safety_checker: bool = True,
|
| 195 |
+
):
|
| 196 |
+
super().__init__()
|
| 197 |
+
|
| 198 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
| 199 |
+
deprecation_message = (
|
| 200 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
| 201 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
| 202 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
| 203 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
| 204 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
| 205 |
+
" file"
|
| 206 |
+
)
|
| 207 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
| 208 |
+
new_config = dict(scheduler.config)
|
| 209 |
+
new_config["steps_offset"] = 1
|
| 210 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 211 |
+
|
| 212 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
| 213 |
+
deprecation_message = (
|
| 214 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
| 215 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
| 216 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
| 217 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
| 218 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
| 219 |
+
)
|
| 220 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
| 221 |
+
new_config = dict(scheduler.config)
|
| 222 |
+
new_config["clip_sample"] = False
|
| 223 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 224 |
+
|
| 225 |
+
if safety_checker is None and requires_safety_checker:
|
| 226 |
+
logger.warning(
|
| 227 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 228 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 229 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 230 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 231 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 232 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if safety_checker is not None and feature_extractor is None:
|
| 236 |
+
raise ValueError(
|
| 237 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 238 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
| 242 |
+
version.parse(unet.config._diffusers_version).base_version
|
| 243 |
+
) < version.parse("0.9.0.dev0")
|
| 244 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
| 245 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
| 246 |
+
deprecation_message = (
|
| 247 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
| 248 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
| 249 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
| 250 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
| 251 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
| 252 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
| 253 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
| 254 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
| 255 |
+
" the `unet/config.json` file"
|
| 256 |
+
)
|
| 257 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
| 258 |
+
new_config = dict(unet.config)
|
| 259 |
+
new_config["sample_size"] = 64
|
| 260 |
+
unet._internal_dict = FrozenDict(new_config)
|
| 261 |
+
|
| 262 |
+
self.register_modules(
|
| 263 |
+
vae=vae,
|
| 264 |
+
text_encoder=text_encoder,
|
| 265 |
+
tokenizer=tokenizer,
|
| 266 |
+
unet=unet,
|
| 267 |
+
scheduler=scheduler,
|
| 268 |
+
safety_checker=safety_checker,
|
| 269 |
+
feature_extractor=feature_extractor,
|
| 270 |
+
image_encoder=image_encoder,
|
| 271 |
+
)
|
| 272 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 273 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 274 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
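A quick worked example of the `vae_scale_factor` computed in `__init__` above: with the standard Stable Diffusion 1.x VAE, `block_out_channels` has four entries, so images are downscaled by a factor of 8 in latent space (e.g. 512x512 pixels -> 64x64 latents). The config tuple below is assumed, not read from a checkpoint:

```py
block_out_channels = (128, 256, 512, 512)  # typical SD 1.x AutoencoderKL config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8
```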
| 275 |
+
|
| 276 |
+
def _encode_prompt(
|
| 277 |
+
self,
|
| 278 |
+
prompt,
|
| 279 |
+
device,
|
| 280 |
+
num_images_per_prompt,
|
| 281 |
+
do_classifier_free_guidance,
|
| 282 |
+
negative_prompt=None,
|
| 283 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 284 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 285 |
+
lora_scale: Optional[float] = None,
|
| 286 |
+
**kwargs,
|
| 287 |
+
):
|
| 288 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| 289 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
| 290 |
+
|
| 291 |
+
prompt_embeds_tuple = self.encode_prompt(
|
| 292 |
+
prompt=prompt,
|
| 293 |
+
device=device,
|
| 294 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 295 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 296 |
+
negative_prompt=negative_prompt,
|
| 297 |
+
prompt_embeds=prompt_embeds,
|
| 298 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 299 |
+
lora_scale=lora_scale,
|
| 300 |
+
**kwargs,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# concatenate for backwards comp
|
| 304 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
| 305 |
+
|
| 306 |
+
return prompt_embeds
|
| 307 |
+
|
| 308 |
+
def encode_prompt(
|
| 309 |
+
self,
|
| 310 |
+
prompt,
|
| 311 |
+
device,
|
| 312 |
+
num_images_per_prompt,
|
| 313 |
+
do_classifier_free_guidance,
|
| 314 |
+
negative_prompt=None,
|
| 315 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 316 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 317 |
+
lora_scale: Optional[float] = None,
|
| 318 |
+
clip_skip: Optional[int] = None,
|
| 319 |
+
):
|
| 320 |
+
r"""
|
| 321 |
+
Encodes the prompt into text encoder hidden states.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 325 |
+
prompt to be encoded
|
| 326 |
+
device: (`torch.device`):
|
| 327 |
+
torch device
|
| 328 |
+
num_images_per_prompt (`int`):
|
| 329 |
+
number of images that should be generated per prompt
|
| 330 |
+
do_classifier_free_guidance (`bool`):
|
| 331 |
+
whether to use classifier free guidance or not
|
| 332 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 333 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 334 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 335 |
+
less than `1`).
|
| 336 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 337 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 338 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 339 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 340 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 341 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 342 |
+
argument.
|
| 343 |
+
lora_scale (`float`, *optional*):
|
| 344 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 345 |
+
clip_skip (`int`, *optional*):
|
| 346 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 347 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 348 |
+
"""
|
| 349 |
+
# set lora scale so that monkey patched LoRA
|
| 350 |
+
# function of text encoder can correctly access it
|
| 351 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
| 352 |
+
self._lora_scale = lora_scale
|
| 353 |
+
|
| 354 |
+
# dynamically adjust the LoRA scale
|
| 355 |
+
if not USE_PEFT_BACKEND:
|
| 356 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 357 |
+
else:
|
| 358 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 359 |
+
|
| 360 |
+
if prompt is not None and isinstance(prompt, str):
|
| 361 |
+
batch_size = 1
|
| 362 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 363 |
+
batch_size = len(prompt)
|
| 364 |
+
else:
|
| 365 |
+
batch_size = prompt_embeds.shape[0]
|
| 366 |
+
|
| 367 |
+
if prompt_embeds is None:
|
| 368 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 369 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 370 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 371 |
+
|
| 372 |
+
text_inputs = self.tokenizer(
|
| 373 |
+
prompt,
|
| 374 |
+
padding="max_length",
|
| 375 |
+
max_length=self.tokenizer.model_max_length,
|
| 376 |
+
truncation=True,
|
| 377 |
+
return_tensors="pt",
|
| 378 |
+
)
|
| 379 |
+
text_input_ids = text_inputs.input_ids
|
| 380 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 381 |
+
|
| 382 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 383 |
+
text_input_ids, untruncated_ids
|
| 384 |
+
):
|
| 385 |
+
removed_text = self.tokenizer.batch_decode(
|
| 386 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 387 |
+
)
|
| 388 |
+
logger.warning(
|
| 389 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 390 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 394 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 395 |
+
else:
|
| 396 |
+
attention_mask = None
|
| 397 |
+
|
| 398 |
+
if clip_skip is None:
|
| 399 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
| 400 |
+
prompt_embeds = prompt_embeds[0]
|
| 401 |
+
else:
|
| 402 |
+
prompt_embeds = self.text_encoder(
|
| 403 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
| 404 |
+
)
|
| 405 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 406 |
+
# all the hidden states from the encoder layers. Then index into
|
| 407 |
+
# the tuple to access the hidden states from the desired layer.
|
| 408 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
| 409 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 410 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 411 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 412 |
+
# layer.
|
| 413 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
| 414 |
+
|
| 415 |
+
if self.text_encoder is not None:
|
| 416 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
| 417 |
+
elif self.unet is not None:
|
| 418 |
+
prompt_embeds_dtype = self.unet.dtype
|
| 419 |
+
else:
|
| 420 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
| 421 |
+
|
| 422 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 423 |
+
|
| 424 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 425 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 426 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 427 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 428 |
+
|
| 429 |
+
# get unconditional embeddings for classifier free guidance
|
| 430 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 431 |
+
uncond_tokens: List[str]
|
| 432 |
+
if negative_prompt is None:
|
| 433 |
+
uncond_tokens = [""] * batch_size
|
| 434 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
| 435 |
+
raise TypeError(
|
| 436 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 437 |
+
f" {type(prompt)}."
|
| 438 |
+
)
|
| 439 |
+
elif isinstance(negative_prompt, str):
|
| 440 |
+
uncond_tokens = [negative_prompt]
|
| 441 |
+
elif batch_size != len(negative_prompt):
|
| 442 |
+
raise ValueError(
|
| 443 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 444 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 445 |
+
" the batch size of `prompt`."
|
| 446 |
+
)
|
| 447 |
+
else:
|
| 448 |
+
uncond_tokens = negative_prompt
|
| 449 |
+
|
| 450 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 451 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 452 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
| 453 |
+
|
| 454 |
+
max_length = prompt_embeds.shape[1]
|
| 455 |
+
uncond_input = self.tokenizer(
|
| 456 |
+
uncond_tokens,
|
| 457 |
+
padding="max_length",
|
| 458 |
+
max_length=max_length,
|
| 459 |
+
truncation=True,
|
| 460 |
+
return_tensors="pt",
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 464 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 465 |
+
else:
|
| 466 |
+
attention_mask = None
|
| 467 |
+
|
| 468 |
+
negative_prompt_embeds = self.text_encoder(
|
| 469 |
+
uncond_input.input_ids.to(device),
|
| 470 |
+
attention_mask=attention_mask,
|
| 471 |
+
)
|
| 472 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 473 |
+
|
| 474 |
+
if do_classifier_free_guidance:
|
| 475 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 476 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 477 |
+
|
| 478 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 479 |
+
|
| 480 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 481 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 482 |
+
|
| 483 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 484 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 485 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 486 |
+
|
| 487 |
+
return prompt_embeds, negative_prompt_embeds
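# --- illustrative sketch (not part of the pipeline code above) ---
# The `clip_skip` branch in `encode_prompt` selects an earlier CLIP hidden state
# via `hidden_states[-(clip_skip + 1)]` and then re-applies the final LayerNorm.
# A torch-only toy of that indexing, with a fake tuple of per-layer hidden
# states standing in for the text-encoder output:
import torch

num_layers, batch, seq_len, dim = 12, 1, 77, 768
hidden_states = tuple(torch.randn(batch, seq_len, dim) for _ in range(num_layers + 1))
final_layer_norm = torch.nn.LayerNorm(dim)

clip_skip = 2
selected = hidden_states[-(clip_skip + 1)]   # -1 is the last layer; walk clip_skip layers back
prompt_embeds = final_layer_norm(selected)
assert prompt_embeds.shape == (batch, seq_len, dim)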
|
| 488 |
+
|
| 489 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 490 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 491 |
+
|
| 492 |
+
if not isinstance(image, torch.Tensor):
|
| 493 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 494 |
+
|
| 495 |
+
image = image.to(device=device, dtype=dtype)
|
| 496 |
+
if output_hidden_states:
|
| 497 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 498 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 499 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 500 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 501 |
+
).hidden_states[-2]
|
| 502 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 503 |
+
num_images_per_prompt, dim=0
|
| 504 |
+
)
|
| 505 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 506 |
+
else:
|
| 507 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 508 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 509 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 510 |
+
|
| 511 |
+
return image_embeds, uncond_image_embeds
|
| 512 |
+
|
| 513 |
+
def prepare_cond_image_latents(self, image, normal, mask, cond_vae, device, num_images_per_prompt, do_classifier_free_guidance):
|
| 514 |
+
dtype = self.vae.dtype
|
| 515 |
+
|
| 516 |
+
if isinstance(image, list):
|
| 517 |
+
image = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(device=device, dtype=dtype)
|
| 518 |
+
elif isinstance(image, torch.Tensor):
|
| 519 |
+
image = image.to(device=device, dtype=dtype)
|
| 520 |
+
|
| 521 |
+
if isinstance(normal, list):
|
| 522 |
+
normal = torch.stack([TF.to_tensor(img) for img in normal], dim=0).to(device=device, dtype=dtype)
|
| 523 |
+
elif isinstance(normal, torch.Tensor):
|
| 524 |
+
normal = normal.to(device=device, dtype=dtype)
|
| 525 |
+
|
| 526 |
+
if isinstance(mask, list):
|
| 527 |
+
if isinstance(mask[0], np.ndarray):
|
| 528 |
+
mask = [Image.fromarray((img*255).astype(np.uint8), mode='L') for img in mask]
|
| 529 |
+
mask = [img.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST) for img in mask]
|
| 530 |
+
elif isinstance(mask[0], Image.Image):
|
| 531 |
+
mask = [img.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST) for img in mask]
|
| 532 |
+
mask = torch.stack([TF.to_tensor(img) for img in mask], dim=0).to(device=device, dtype=dtype)
|
| 533 |
+
elif isinstance(mask, torch.Tensor):
|
| 534 |
+
mask = Image.fromarray((mask.cpu().numpy()*255).astype(np.uint8), mode='L')
|
| 535 |
+
mask = mask.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST)
|
| 536 |
+
mask = TF.to_tensor(mask).to(device=device, dtype=dtype)
|
| 537 |
+
|
| 538 |
+
if cond_vae is not None:
|
| 539 |
+
image = image * 2.0 - 1.0
|
| 540 |
+
if normal is not None:
|
| 541 |
+
normal = normal * 2.0 - 1.0
|
| 542 |
+
image = torch.cat([image, normal], dim=1)
|
| 543 |
+
latents = cond_vae(image) * self.vae.config.scaling_factor
|
| 544 |
+
else:
|
| 545 |
+
# vae encoder
|
| 546 |
+
image = image * 2.0 - 1.0
|
| 547 |
+
latents = self.vae.encode(image).latent_dist.mode() * self.vae.config.scaling_factor
|
| 548 |
+
latents = latents.repeat(num_images_per_prompt, 1, 1, 1)
|
| 549 |
+
|
| 550 |
+
if normal is not None:
|
| 551 |
+
normal = normal * 2.0 - 1.0
|
| 552 |
+
normal_latents = self.vae.encode(normal).latent_dist.mode() * self.vae.config.scaling_factor
|
| 553 |
+
normal_latents = normal_latents.repeat(num_images_per_prompt, 1, 1, 1)
|
| 554 |
+
latents = torch.cat([latents, normal_latents], dim=1)
|
| 555 |
+
|
| 556 |
+
if mask is not None:
|
| 557 |
+
# mask = torch.ones_like(mask)
|
| 558 |
+
mask = mask * 2.0 - 1.0
|
| 559 |
+
mask_latents = mask.repeat(num_images_per_prompt, 1, 1, 1)
|
| 560 |
+
latents = torch.cat([latents, mask_latents.to(latents)], dim=1)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
if do_classifier_free_guidance:
|
| 564 |
+
# uncond_latens = self.vae.encode(torch.zeros_like(image)).latent_dist.mode() * self.vae.config.scaling_factor
|
| 565 |
+
# uncond_latens.repeat(num_images_per_prompt, 1, 1, 1)
|
| 566 |
+
uncond_latents = torch.zeros_like(latents)  # note: currently unused; the conditional latents below are duplicated for both CFG branches
|
| 567 |
+
latents = torch.cat([latents, latents])
|
| 568 |
+
|
| 569 |
+
return latents
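# --- illustrative sketch (not part of the pipeline code above) ---
# `prepare_cond_image_latents` stacks the VAE latent of the shaded condition
# image, the VAE latent of the normal render, and a nearest-downsampled mask
# into extra channels that are concatenated to the denoised latents at every
# step. Channel-layout toy (random tensors stand in for real VAE encodings):
import torch

b, h8, w8 = 1, 64, 64
cond_latents = torch.randn(b, 4, h8, w8)    # latent of the condition image
normal_latents = torch.randn(b, 4, h8, w8)  # latent of the normal render
mask_latents = torch.rand(b, 1, h8, w8)     # mask resized to latent resolution
extra = torch.cat([cond_latents, normal_latents, mask_latents], dim=1)
assert extra.shape == (b, 9, h8, w8)        # 4 + 4 + 1 conditioning channels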
|
| 570 |
+
|
| 571 |
+
def prepare_init_latents(self, init_materials, device, num_images_per_prompt, do_classifier_free_guidance):
|
| 572 |
+
dtype = self.vae.dtype
|
| 573 |
+
|
| 574 |
+
image = torch.cat([
|
| 575 |
+
init_materials['albedo'][...,:3].permute(0, 3, 1, 2),
|
| 576 |
+
init_materials['roughness_metallic'][...,:3].permute(0, 3, 1, 2),
|
| 577 |
+
init_materials['bump'][...,:3].permute(0, 3, 1, 2),
|
| 578 |
+
], dim=0).to(device=device, dtype=dtype)
|
| 579 |
+
|
| 580 |
+
from einops import rearrange
|
| 581 |
+
# vae encoder
|
| 582 |
+
image = image * 2.0 - 1.0
|
| 583 |
+
latents = self.vae.encode(image).latent_dist.mode() * self.vae.config.scaling_factor
|
| 584 |
+
latents = rearrange(latents, '(s b) c h w -> b (s c) h w', s=3)
|
| 585 |
+
latents = latents.repeat(num_images_per_prompt, 1, 1, 1)
|
| 586 |
+
|
| 587 |
+
# if do_classifier_free_guidance:
|
| 588 |
+
# # uncond_latens = self.vae.encode(torch.zeros_like(image)).latent_dist.mode() * self.vae.config.scaling_factor
|
| 589 |
+
# # uncond_latens.repeat(num_images_per_prompt, 1, 1, 1)
|
| 590 |
+
# # uncond_latens = torch.zeros_like(latents)
|
| 591 |
+
# latents = torch.cat([latents, latents])
|
| 592 |
+
|
| 593 |
+
return latents
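# --- illustrative sketch (not part of the pipeline code above) ---
# `prepare_init_latents` encodes albedo / roughness-metallic / bump stacked on
# the batch axis and then regroups them into a single 12-channel latent with
# einops. A torch-only toy of that regrouping (fake 4-channel latents stand in
# for the VAE outputs):
import torch
from einops import rearrange

b, c, h, w = 2, 4, 64, 64
albedo, rough_metal, bump = (torch.randn(b, c, h, w) for _ in range(3))

stacked = torch.cat([albedo, rough_metal, bump], dim=0)         # (3*b, 4, h, w)
packed = rearrange(stacked, '(s b) c h w -> b (s c) h w', s=3)  # (b, 12, h, w)

assert packed.shape == (b, 3 * c, h, w)
assert torch.equal(packed[:, 4:8], rough_metal)  # channels 4:8 hold roughness/metallic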
|
| 594 |
+
|
| 595 |
+
def prepare_ip_adapter_image_embeds(
|
| 596 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
| 597 |
+
):
|
| 598 |
+
if ip_adapter_image_embeds is None:
|
| 599 |
+
if not isinstance(ip_adapter_image, list):
|
| 600 |
+
ip_adapter_image = [ip_adapter_image]
|
| 601 |
+
|
| 602 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
| 603 |
+
raise ValueError(
|
| 604 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
image_embeds = []
|
| 608 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
| 609 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
| 610 |
+
):
|
| 611 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
| 612 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
| 613 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
| 614 |
+
)
|
| 615 |
+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 616 |
+
single_negative_image_embeds = torch.stack(
|
| 617 |
+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
if do_classifier_free_guidance:
|
| 621 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
| 622 |
+
single_image_embeds = single_image_embeds.to(device)
|
| 623 |
+
|
| 624 |
+
image_embeds.append(single_image_embeds)
|
| 625 |
+
else:
|
| 626 |
+
repeat_dims = [1]
|
| 627 |
+
image_embeds = []
|
| 628 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 629 |
+
if do_classifier_free_guidance:
|
| 630 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
| 631 |
+
single_image_embeds = single_image_embeds.repeat(
|
| 632 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
| 633 |
+
)
|
| 634 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
| 635 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
| 636 |
+
)
|
| 637 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
| 638 |
+
else:
|
| 639 |
+
single_image_embeds = single_image_embeds.repeat(
|
| 640 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
| 641 |
+
)
|
| 642 |
+
image_embeds.append(single_image_embeds)
|
| 643 |
+
|
| 644 |
+
return image_embeds
|
| 645 |
+
|
| 646 |
+
def run_safety_checker(self, image, device, dtype):
|
| 647 |
+
if self.safety_checker is None:
|
| 648 |
+
has_nsfw_concept = None
|
| 649 |
+
else:
|
| 650 |
+
if torch.is_tensor(image):
|
| 651 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| 652 |
+
else:
|
| 653 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 654 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
| 655 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 656 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 657 |
+
)
|
| 658 |
+
return image, has_nsfw_concept
|
| 659 |
+
|
| 660 |
+
def decode_latents(self, latents):
|
| 661 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 662 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 663 |
+
|
| 664 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 665 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 666 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 667 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 668 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 669 |
+
return image
|
| 670 |
+
|
| 671 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 672 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 673 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 674 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 675 |
+
# and should be between [0, 1]
|
| 676 |
+
|
| 677 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 678 |
+
extra_step_kwargs = {}
|
| 679 |
+
if accepts_eta:
|
| 680 |
+
extra_step_kwargs["eta"] = eta
|
| 681 |
+
|
| 682 |
+
# check if the scheduler accepts generator
|
| 683 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 684 |
+
if accepts_generator:
|
| 685 |
+
extra_step_kwargs["generator"] = generator
|
| 686 |
+
return extra_step_kwargs
|
| 687 |
+
|
| 688 |
+
def check_inputs(
|
| 689 |
+
self,
|
| 690 |
+
prompt,
|
| 691 |
+
height,
|
| 692 |
+
width,
|
| 693 |
+
callback_steps,
|
| 694 |
+
negative_prompt=None,
|
| 695 |
+
prompt_embeds=None,
|
| 696 |
+
negative_prompt_embeds=None,
|
| 697 |
+
ip_adapter_image=None,
|
| 698 |
+
ip_adapter_image_embeds=None,
|
| 699 |
+
callback_on_step_end_tensor_inputs=None,
|
| 700 |
+
):
|
| 701 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 702 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 703 |
+
|
| 704 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 705 |
+
raise ValueError(
|
| 706 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 707 |
+
f" {type(callback_steps)}."
|
| 708 |
+
)
|
| 709 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 710 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 711 |
+
):
|
| 712 |
+
raise ValueError(
|
| 713 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
if prompt is not None and prompt_embeds is not None:
|
| 717 |
+
raise ValueError(
|
| 718 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 719 |
+
" only forward one of the two."
|
| 720 |
+
)
|
| 721 |
+
elif prompt is None and prompt_embeds is None:
|
| 722 |
+
raise ValueError(
|
| 723 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 724 |
+
)
|
| 725 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 726 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 727 |
+
|
| 728 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 729 |
+
raise ValueError(
|
| 730 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 731 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 735 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 736 |
+
raise ValueError(
|
| 737 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 738 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 739 |
+
f" {negative_prompt_embeds.shape}."
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| 743 |
+
raise ValueError(
|
| 744 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
if ip_adapter_image_embeds is not None:
|
| 748 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 749 |
+
raise ValueError(
|
| 750 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| 751 |
+
)
|
| 752 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| 753 |
+
raise ValueError(
|
| 754 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, copy_noise=False):
|
| 758 |
+
if copy_noise:
|
| 759 |
+
shape = (batch_size, num_channels_latents//3, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 760 |
+
else:
|
| 761 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 762 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 763 |
+
raise ValueError(
|
| 764 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 765 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
if latents is None:
|
| 769 |
+
if copy_noise:
|
| 770 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 771 |
+
latents = torch.cat([latents, latents, latents], dim=1)
|
| 772 |
+
else:
|
| 773 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 774 |
+
else:
|
| 775 |
+
latents = latents.to(device)
|
| 776 |
+
|
| 777 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 778 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 779 |
+
return latents
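# --- illustrative sketch (not part of the pipeline code above) ---
# With `copy_noise=True`, `prepare_latents` samples noise for one 4-channel
# group and tiles it across the three material groups, so albedo,
# roughness/metallic and bump start denoising from identical initial noise:
import torch

batch, latent_channels, h, w = 1, 12, 64, 64
group = torch.randn(batch, latent_channels // 3, h, w)
latents = torch.cat([group, group, group], dim=1)  # (1, 12, h, w)
assert torch.equal(latents[:, :4], latents[:, 8:])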
|
| 780 |
+
|
| 781 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
| 782 |
+
def get_guidance_scale_embedding(
|
| 783 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
| 784 |
+
) -> torch.FloatTensor:
|
| 785 |
+
"""
|
| 786 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 787 |
+
|
| 788 |
+
Args:
|
| 789 |
+
w (`torch.Tensor`):
|
| 790 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
| 791 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 792 |
+
Dimension of the embeddings to generate.
|
| 793 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
| 794 |
+
Data type of the generated embeddings.
|
| 795 |
+
|
| 796 |
+
Returns:
|
| 797 |
+
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
| 798 |
+
"""
|
| 799 |
+
assert len(w.shape) == 1
|
| 800 |
+
w = w * 1000.0
|
| 801 |
+
|
| 802 |
+
half_dim = embedding_dim // 2
|
| 803 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 804 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 805 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 806 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 807 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 808 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 809 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 810 |
+
return emb
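# --- illustrative sketch (not part of the pipeline code above) ---
# Standalone version of the sinusoidal guidance-scale embedding computed by
# `get_guidance_scale_embedding`, for a quick shape check:
import torch

def guidance_embedding(w: torch.Tensor, embedding_dim: int = 512) -> torch.Tensor:
    w = w.float() * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
    emb = torch.cat([torch.sin(w[:, None] * freqs), torch.cos(w[:, None] * freqs)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

print(guidance_embedding(torch.tensor([7.5 - 1.0])).shape)  # torch.Size([1, 512])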
|
| 811 |
+
|
| 812 |
+
def _get_add_time_ids(
|
| 813 |
+
self, albedo_label, rough_meta_label, bump_label, dtype
|
| 814 |
+
):
|
| 815 |
+
add_time_ids = list(albedo_label + rough_meta_label + bump_label)
|
| 816 |
+
|
| 817 |
+
passed_add_embed_dim = (
|
| 818 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) // 3
|
| 819 |
+
)
|
| 820 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
| 821 |
+
|
| 822 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 823 |
+
raise ValueError(
|
| 824 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
| 828 |
+
return add_time_ids
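# --- illustrative sketch (not part of the pipeline code above) ---
# `_get_add_time_ids` simply concatenates the three one-hot material labels
# into a (1, 9) tensor; the UNet's `add_embedding` consumes the projected
# result, and the dimension check above guards against a mismatched config:
import torch

albedo_label, rough_meta_label, bump_label = (1, 0, 0), (0, 1, 0), (0, 0, 1)
add_time_ids = torch.tensor([list(albedo_label + rough_meta_label + bump_label)], dtype=torch.float32)
assert add_time_ids.shape == (1, 9)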
|
| 829 |
+
|
| 830 |
+
@property
|
| 831 |
+
def guidance_scale(self):
|
| 832 |
+
return self._guidance_scale
|
| 833 |
+
|
| 834 |
+
@property
|
| 835 |
+
def guidance_rescale(self):
|
| 836 |
+
return self._guidance_rescale
|
| 837 |
+
|
| 838 |
+
@property
|
| 839 |
+
def clip_skip(self):
|
| 840 |
+
return self._clip_skip
|
| 841 |
+
|
| 842 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 843 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 844 |
+
# corresponds to doing no classifier free guidance.
|
| 845 |
+
@property
|
| 846 |
+
def do_classifier_free_guidance(self):
|
| 847 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
| 848 |
+
|
| 849 |
+
@property
|
| 850 |
+
def cross_attention_kwargs(self):
|
| 851 |
+
return self._cross_attention_kwargs
|
| 852 |
+
|
| 853 |
+
@property
|
| 854 |
+
def num_timesteps(self):
|
| 855 |
+
return self._num_timesteps
|
| 856 |
+
|
| 857 |
+
@property
|
| 858 |
+
def interrupt(self):
|
| 859 |
+
return self._interrupt
|
| 860 |
+
|
| 861 |
+
@torch.no_grad()
|
| 862 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 863 |
+
def __call__(
|
| 864 |
+
self,
|
| 865 |
+
prompt: Union[str, List[str]] = None,
|
| 866 |
+
cond_image: Optional[PipelineImageInput] = None,
|
| 867 |
+
normal_image: Optional[PipelineImageInput] = None,
|
| 868 |
+
mask_image: Optional[PipelineImageInput] = None,
|
| 869 |
+
init_materials: Optional[dict] = None,
|
| 870 |
+
masks: Optional[torch.FloatTensor] = None,
|
| 871 |
+
cond_vae = None,
|
| 872 |
+
height: Optional[int] = None,
|
| 873 |
+
width: Optional[int] = None,
|
| 874 |
+
num_inference_steps: int = 50,
|
| 875 |
+
timesteps: List[int] = None,
|
| 876 |
+
guidance_scale: float = 7.5,
|
| 877 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 878 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 879 |
+
eta: float = 0.0,
|
| 880 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 881 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 882 |
+
unscale_latents: bool = False,
|
| 883 |
+
copy_noise: bool = False,
|
| 884 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 885 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 886 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 887 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 888 |
+
output_type: Optional[str] = "pil",
|
| 889 |
+
return_dict: bool = True,
|
| 890 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 891 |
+
guidance_rescale: float = 0.0,
|
| 892 |
+
clip_skip: Optional[int] = None,
|
| 893 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 894 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 895 |
+
**kwargs,
|
| 896 |
+
):
|
| 897 |
+
r"""
|
| 898 |
+
The call function to the pipeline for generation.
|
| 899 |
+
|
| 900 |
+
Args:
|
| 901 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 902 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 903 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 904 |
+
The height in pixels of the generated image.
|
| 905 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 906 |
+
The width in pixels of the generated image.
|
| 907 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 908 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 909 |
+
expense of slower inference.
|
| 910 |
+
timesteps (`List[int]`, *optional*):
|
| 911 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 912 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 913 |
+
passed will be used. Must be in descending order.
|
| 914 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 915 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 916 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 917 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 918 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 919 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 920 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 921 |
+
The number of images to generate per prompt.
|
| 922 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 923 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 924 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 925 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 926 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 927 |
+
generation deterministic.
|
| 928 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 929 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 930 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 931 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 932 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 933 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 934 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 935 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 936 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 937 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 938 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 939 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 940 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
|
| 941 |
+
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
|
| 942 |
+
if `do_classifier_free_guidance` is set to `True`.
|
| 943 |
+
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 944 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 945 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 946 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 947 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 948 |
+
plain tuple.
|
| 949 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 950 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 951 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 952 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 953 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
| 954 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
| 955 |
+
using zero terminal SNR.
|
| 956 |
+
clip_skip (`int`, *optional*):
|
| 957 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 958 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 959 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 960 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 961 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 962 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 963 |
+
`callback_on_step_end_tensor_inputs`.
|
| 964 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 965 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 966 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 967 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 968 |
+
|
| 969 |
+
Examples:
|
| 970 |
+
|
| 971 |
+
Returns:
|
| 972 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 973 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 974 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 975 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 976 |
+
"not-safe-for-work" (nsfw) content.
|
| 977 |
+
"""
|
| 978 |
+
|
| 979 |
+
callback = kwargs.pop("callback", None)
|
| 980 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 981 |
+
|
| 982 |
+
if callback is not None:
|
| 983 |
+
deprecate(
|
| 984 |
+
"callback",
|
| 985 |
+
"1.0.0",
|
| 986 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 987 |
+
)
|
| 988 |
+
if callback_steps is not None:
|
| 989 |
+
deprecate(
|
| 990 |
+
"callback_steps",
|
| 991 |
+
"1.0.0",
|
| 992 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
# 0. Default height and width to unet
|
| 996 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 997 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 998 |
+
# to deal with lora scaling and other possible forward hooks
|
| 999 |
+
|
| 1000 |
+
# 1. Check inputs. Raise error if not correct
|
| 1001 |
+
self.check_inputs(
|
| 1002 |
+
prompt,
|
| 1003 |
+
height,
|
| 1004 |
+
width,
|
| 1005 |
+
callback_steps,
|
| 1006 |
+
negative_prompt,
|
| 1007 |
+
prompt_embeds,
|
| 1008 |
+
negative_prompt_embeds,
|
| 1009 |
+
ip_adapter_image,
|
| 1010 |
+
ip_adapter_image_embeds,
|
| 1011 |
+
callback_on_step_end_tensor_inputs,
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
self._guidance_scale = guidance_scale
|
| 1015 |
+
self._guidance_rescale = guidance_rescale
|
| 1016 |
+
self._clip_skip = clip_skip
|
| 1017 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 1018 |
+
self._interrupt = False
|
| 1019 |
+
|
| 1020 |
+
# 2. Define call parameters
|
| 1021 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1022 |
+
batch_size = 1
|
| 1023 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1024 |
+
batch_size = len(prompt)
|
| 1025 |
+
else:
|
| 1026 |
+
batch_size = prompt_embeds.shape[0] // 3
|
| 1027 |
+
|
| 1028 |
+
device = self._execution_device
|
| 1029 |
+
|
| 1030 |
+
# 3. Encode input prompt
|
| 1031 |
+
lora_scale = (
|
| 1032 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 1036 |
+
prompt,
|
| 1037 |
+
device,
|
| 1038 |
+
num_images_per_prompt,
|
| 1039 |
+
self.do_classifier_free_guidance,
|
| 1040 |
+
negative_prompt,
|
| 1041 |
+
prompt_embeds=prompt_embeds,
|
| 1042 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1043 |
+
lora_scale=lora_scale,
|
| 1044 |
+
clip_skip=self.clip_skip,
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 1048 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 1049 |
+
# to avoid doing two forward passes
|
| 1050 |
+
if self.do_classifier_free_guidance:
|
| 1051 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 1052 |
+
|
| 1053 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1054 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1055 |
+
ip_adapter_image,
|
| 1056 |
+
ip_adapter_image_embeds,
|
| 1057 |
+
device,
|
| 1058 |
+
batch_size * num_images_per_prompt,
|
| 1059 |
+
self.do_classifier_free_guidance,
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
# 4. Prepare timesteps
|
| 1063 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 1064 |
+
|
| 1065 |
+
# 4.1 Prepare additional class embedding
|
| 1066 |
+
if self.unet.config.addition_time_embed_dim is not None:
|
| 1067 |
+
albedo_label = (1, 0, 0)
|
| 1068 |
+
rough_meta_label = (0, 1, 0)
|
| 1069 |
+
bump_label = (0, 0, 1)
|
| 1070 |
+
add_time_ids = self._get_add_time_ids(
|
| 1071 |
+
albedo_label,
|
| 1072 |
+
rough_meta_label,
|
| 1073 |
+
bump_label,
|
| 1074 |
+
dtype=prompt_embeds.dtype,
|
| 1075 |
+
)
|
| 1076 |
+
negative_add_time_ids = add_time_ids
|
| 1077 |
+
|
| 1078 |
+
if self.do_classifier_free_guidance:
|
| 1079 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 1080 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
| 1081 |
+
|
| 1082 |
+
# 5. Prepare latent variables
|
| 1083 |
+
num_channels_latents = self.unet.config.in_channels_no_cond
|
| 1084 |
+
latents = self.prepare_latents(
|
| 1085 |
+
batch_size * num_images_per_prompt,
|
| 1086 |
+
num_channels_latents,
|
| 1087 |
+
height,
|
| 1088 |
+
width,
|
| 1089 |
+
prompt_embeds.dtype,
|
| 1090 |
+
device,
|
| 1091 |
+
generator,
|
| 1092 |
+
latents,
|
| 1093 |
+
copy_noise,
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
# 5.1 Prepare conditional image latents
|
| 1097 |
+
cond_latents = None
|
| 1098 |
+
mask_image = [mask.cpu().numpy() for mask in masks]
|
| 1099 |
+
if cond_image is not None:
|
| 1100 |
+
cond_latents = self.prepare_cond_image_latents(
|
| 1101 |
+
cond_image,
|
| 1102 |
+
normal_image,
|
| 1103 |
+
mask_image,
|
| 1104 |
+
cond_vae,
|
| 1105 |
+
device,
|
| 1106 |
+
num_images_per_prompt,
|
| 1107 |
+
self.do_classifier_free_guidance
|
| 1108 |
+
)
|
| 1109 |
+
|
| 1110 |
+
init_latents = None
|
| 1111 |
+
if init_materials is not None:
|
| 1112 |
+
init_latents = self.prepare_init_latents(
|
| 1113 |
+
init_materials,
|
| 1114 |
+
device,
|
| 1115 |
+
num_images_per_prompt,
|
| 1116 |
+
self.do_classifier_free_guidance
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
import cv2
|
| 1120 |
+
import numpy as np
|
| 1121 |
+
from PIL import Image
|
| 1122 |
+
masks = cv2.erode((masks[0].cpu().numpy()*255).astype(np.uint8), kernel=np.ones((5, 5), np.uint8), iterations=4)
|
| 1123 |
+
masks = Image.fromarray(masks.astype(np.uint8)).convert("L")
|
| 1124 |
+
masks = masks.resize((height // 8, width // 8), Image.NEAREST)
|
| 1125 |
+
masks = TF.to_tensor(masks).to(init_latents.device, init_latents.dtype).unsqueeze(1)
|
| 1126 |
+
# masks = torch.zeros_like(masks)
|
| 1127 |
+
|
| 1128 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1129 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1130 |
+
|
| 1131 |
+
# # 6.1 Add image embeds for IP-Adapter
|
| 1132 |
+
# added_cond_kwargs = (
|
| 1133 |
+
# {"image_embeds": image_embeds}
|
| 1134 |
+
# if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
|
| 1135 |
+
# else None
|
| 1136 |
+
# )
|
| 1137 |
+
|
| 1138 |
+
# 6.2 Optionally get Guidance Scale Embedding
|
| 1139 |
+
timestep_cond = None
|
| 1140 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 1141 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 1142 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 1143 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 1144 |
+
).to(device=device, dtype=latents.dtype)
|
| 1145 |
+
|
| 1146 |
+
# 7. Denoising loop
|
| 1147 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1148 |
+
self._num_timesteps = len(timesteps)
|
| 1149 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1150 |
+
for i, t in enumerate(timesteps):
|
| 1151 |
+
if self.interrupt:
|
| 1152 |
+
continue
|
| 1153 |
+
|
| 1154 |
+
# expand the latents if we are doing classifier free guidance
|
| 1155 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1156 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1157 |
+
|
| 1158 |
+
if cond_latents is not None:
|
| 1159 |
+
latent_model_input = torch.cat([latent_model_input, cond_latents], dim=1)
|
| 1160 |
+
|
| 1161 |
+
# predict the noise residual
|
| 1162 |
+
added_cond_kwargs = {}
|
| 1163 |
+
if self.unet.config.addition_time_embed_dim is not None:
|
| 1164 |
+
added_cond_kwargs["time_ids"] = add_time_ids
|
| 1165 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1166 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
| 1167 |
+
noise_pred = self.unet(
|
| 1168 |
+
latent_model_input,
|
| 1169 |
+
t,
|
| 1170 |
+
encoder_hidden_states=prompt_embeds,
|
| 1171 |
+
timestep_cond=timestep_cond,
|
| 1172 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 1173 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1174 |
+
return_dict=False,
|
| 1175 |
+
)[0]
|
| 1176 |
+
|
| 1177 |
+
# perform guidance
|
| 1178 |
+
if self.do_classifier_free_guidance:
|
| 1179 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1180 |
+
# only do cfg for roughness, metallic and bump
|
| 1181 |
+
noise_pred = noise_pred_uncond[:,4:] + self.guidance_scale * (noise_pred_text[:,4:] - noise_pred_uncond[:,4:])
|
| 1182 |
+
noise_pred = torch.cat([noise_pred_text[:, :4], noise_pred], dim=1)
|
| 1183 |
+
|
| 1184 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 1185 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 1186 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
| 1187 |
+
|
| 1188 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1189 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False, init_latents=init_latents, masks=masks)[0]
|
| 1190 |
+
|
| 1191 |
+
if callback_on_step_end is not None:
|
| 1192 |
+
callback_kwargs = {}
|
| 1193 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1194 |
+
callback_kwargs[k] = locals()[k]
|
| 1195 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1196 |
+
|
| 1197 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1198 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1199 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1200 |
+
|
| 1201 |
+
# call the callback, if provided
|
| 1202 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1203 |
+
progress_bar.update()
|
| 1204 |
+
if callback is not None and i % callback_steps == 0:
|
| 1205 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 1206 |
+
callback(step_idx, t, latents)
|
| 1207 |
+
|
| 1208 |
+
if not output_type == "latent":
|
| 1209 |
+
if num_channels_latents == 12:
|
| 1210 |
+
latents = latents / self.vae.config.scaling_factor
|
| 1211 |
+
if unscale_latents:
|
| 1212 |
+
latents[:, 4:8] = unscale_latents_rm(latents[:, 4:8])
|
| 1213 |
+
latents[:, 8:] = unscale_latents_bump(latents[:, 8:])
|
| 1214 |
+
latents = torch.cat([latents[:, :4], latents[:, 4:8], latents[:, 8:]], dim=0)
|
| 1215 |
+
image = self.vae.decode(latents, return_dict=False, generator=generator)[
|
| 1216 |
+
0
|
| 1217 |
+
]
|
| 1218 |
+
else:
|
| 1219 |
+
image = self.vae.decode(latents/ self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
| 1220 |
+
0
|
| 1221 |
+
]
|
| 1222 |
+
has_nsfw_concept = None
|
| 1223 |
+
else:
|
| 1224 |
+
image = latents
|
| 1225 |
+
has_nsfw_concept = None
|
| 1226 |
+
|
| 1227 |
+
if has_nsfw_concept is None:
|
| 1228 |
+
do_denormalize = [True] * image.shape[0]
|
| 1229 |
+
else:
|
| 1230 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1231 |
+
|
| 1232 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 1233 |
+
|
| 1234 |
+
# Offload all models
|
| 1235 |
+
self.maybe_free_model_hooks()
|
| 1236 |
+
|
| 1237 |
+
if not return_dict:
|
| 1238 |
+
return (image, has_nsfw_concept)
|
| 1239 |
+
|
| 1240 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
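# --- illustrative sketch (not part of the pipeline code above) ---
# Two details of the sampling loop, restated with plain tensors: classifier-free
# guidance is applied only to the roughness/metallic and bump channels (4:12),
# keeping the text-conditioned prediction for albedo (0:4); and at the end the
# packed 12-channel latent is split into three 4-channel latents stacked on the
# batch axis before VAE decoding.
import torch

guidance_scale = 7.5
noise_pred_uncond = torch.randn(1, 12, 64, 64)
noise_pred_text = torch.randn(1, 12, 64, 64)

guided_tail = noise_pred_uncond[:, 4:] + guidance_scale * (noise_pred_text[:, 4:] - noise_pred_uncond[:, 4:])
noise_pred = torch.cat([noise_pred_text[:, :4], guided_tail], dim=1)

latents = torch.randn(1, 12, 64, 64)
decodable = torch.cat([latents[:, :4], latents[:, 4:8], latents[:, 8:]], dim=0)
assert decodable.shape == (3, 4, 64, 64)  # albedo, roughness/metallic, bump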
|
utils/rasterize.py
ADDED
|
@@ -0,0 +1,166 @@
import nvdiffrast.torch as dr
import torch

from torch import Tensor
from jaxtyping import Float, Integer
from typing import Union, Tuple

class NVDiffRasterizerContext:
    def __init__(self, context_type: str, device: torch.device) -> None:
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        with torch.amp.autocast("cuda", enabled=False):
            verts_homo = torch.cat(
                [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
            )
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        # rasterize one single mesh under a single viewpoint
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)

def texture_map_to_rgb(tex_map, uv_coordinates):
    return dr.texture(tex_map.float(), uv_coordinates)

def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    batch_size = mvp_matrix.shape[0]
    tex_map = tex_map.contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing

    vertex_positions_clip = ctx.vertex_transform(mesh.v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(vertex_positions_clip, mesh.t_pos_idx, (image_height, image_width))
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(mask.float(), rasterized_output, vertex_positions_clip, mesh.t_pos_idx)

    interpolated_texture_coords, _ = ctx.interpolate_one(mesh._v_tex, rasterized_output, mesh._t_tex_idx)
    rgb_foreground = texture_map_to_rgb(tex_map.float(), interpolated_texture_coords)
    rgb_foreground_batched = torch.zeros(batch_size, image_height, image_width, 3).to(rgb_foreground)
    rgb_background_batched = torch.zeros(batch_size, image_height, image_width, 3).to(rgb_foreground)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(rgb_background_batched, rgb_foreground_batched, mask_antialiased)
    final_rgb_aa = ctx.antialias(final_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx)

    return final_rgb_aa, selector


def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    device = mvp_matrix.device
    vertex_positions_clip = ctx.vertex_transform(mesh.v_pos.to(device), mvp_matrix)
    rasterized_output, _ = ctx.rasterize(vertex_positions_clip, mesh.t_pos_idx.to(device), (image_height, image_width))
    interpolated_positions, _ = ctx.interpolate_one(mesh.v_pos.to(device), rasterized_output, mesh.t_pos_idx.to(device))
    interpolated_normals, _ = ctx.interpolate_one(mesh.v_normal.to(device).contiguous(), rasterized_output, mesh.t_pos_idx.to(device))

    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(mask.float(), rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))

    batch_size = mvp_matrix.shape[0]
    rgb_foreground_pos_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)
    rgb_foreground_norm_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)
    rgb_background_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)

    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    final_pos_rgb = torch.lerp(rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased)
    final_norm_rgb = torch.lerp(rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased)
    final_pos_rgb_aa = ctx.antialias(final_pos_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))
    final_norm_rgb_aa = ctx.antialias(final_norm_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))

    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased

def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    device = ctx.device
    # Move mesh data to the rasterizer device
    mesh_v = mesh.v_pos.to(device)
    mesh_f = mesh.t_pos_idx.to(device)
    uvs_tensor = mesh._v_tex.to(device)
    indices_tensor = mesh._t_tex_idx.to(device)
    normal_v = mesh.v_normal.to(device).contiguous()

    # Rasterize the mesh in UV space
    uv_clip = uvs_tensor[None, ...] * 2.0 - 1.0
    uv_clip_padded = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., :1]), torch.ones_like(uv_clip[..., :1])), dim=-1)
    rasterized_output, _ = ctx.rasterize(uv_clip_padded, indices_tensor.int(), (rasterize_height, rasterize_width))

    # Interpolate positions and normals into the UV atlas
    position_map, _ = ctx.interpolate_one(mesh_v, rasterized_output, mesh_f.int())
    normal_map, _ = ctx.interpolate_one(normal_v, rasterized_output, mesh_f.int())
    rasterization_mask = rasterized_output[..., 3:4] > 0

    return position_map, normal_map, rasterization_mask
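One detail of `rasterize_position_and_normal_maps` above is worth spelling out: per-vertex UVs in [0, 1] are remapped to [-1, 1] and padded with z = 0, w = 1, so nvdiffrast rasterizes the mesh "flattened" onto its UV atlas, after which vertex positions and normals can be interpolated into texture-space maps. A torch-only sketch of that coordinate preparation (the rasterization call itself needs nvdiffrast and a CUDA device, so it is omitted here):

import torch

uvs = torch.rand(1000, 2)                      # per-vertex UVs in [0, 1]
uv_clip = uvs[None, ...] * 2.0 - 1.0           # remap to [-1, 1]
uv_clip_padded = torch.cat(
    (uv_clip, torch.zeros_like(uv_clip[..., :1]), torch.ones_like(uv_clip[..., :1])),
    dim=-1,
)
assert uv_clip_padded.shape == (1, 1000, 4)    # x, y, z=0, w=1 clip-space positions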
utils/render_utils.py
ADDED
|
@@ -0,0 +1,352 @@
import math
from functools import cache
from typing import Dict, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToPILImage

from .rasterize import (NVDiffRasterizerContext,
                        rasterize_position_and_normal_maps,
                        render_geo_from_mesh,
                        render_rgb_from_texture_mesh_with_mask)

CTX = NVDiffRasterizerContext('cuda', 'cuda')

def setup_lights():
    """
    Set three random point lights in the scene.
    """
    raise NotImplementedError("setup_lights function is not implemented yet.")

def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
    """
    Render the RGB color images of the mesh. The background will be transparent.
    :param mesh: The mesh to be rendered. Class: Mesh.
    :param texture: The texture of the mesh, a tensor of shape (H, W, 3).
    :param mvp_matrix: The Model-View-Projection matrices for rendering, a tensor of shape (n_v, 4, 4).
    :param lights: The lights in the scene.
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A concatenated PIL Image.
    """
    if texture.shape[-1] != 3:
        texture = texture.permute(1, 2, 0)
    image_height, image_width = img_size
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
        CTX, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device='cuda'))

    if mvp_matrix.shape[0] == 0:
        return None

    pil_images = []
    for i in range(mvp_matrix.shape[0]):
        rgba_img = torch.cat([rgb_cond[i], mask[i].unsqueeze(-1)], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
        rgba_img = (rgba_img * 255).to(torch.uint8)  # convert to uint8
        rgba_img = rgba_img.cpu().numpy()  # convert to a numpy array
        pil_images.append(Image.fromarray(rgba_img, mode='RGBA'))

    if not pil_images:
        return None

    total_width = sum(img.width for img in pil_images)
    max_height = max(img.height for img in pil_images)

    concatenated_image = Image.new('RGBA', (total_width, max_height))

    current_x = 0
    for img in pil_images:
        concatenated_image.paste(img, (current_x, 0))
        current_x += img.width

    return concatenated_image

def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Render the geometry information (position and normal) from the views implied by the MVP matrices.
    """
    image_height, image_width = img_size
    position_images, normal_images, mask_images = render_geo_from_mesh(CTX, mesh, mvp_matrix, image_height, image_width)
    return position_images, normal_images, mask_images

def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the geometry information (position and normal) in the UV parameterization.
    """
    map_height, map_width = map_size
    position_images, normal_images, mask = rasterize_position_and_normal_maps(CTX, mesh, map_height, map_width)
    # out_imgs = []
    # if mask.ndim == 4:
    #     mask = mask[0]
    # for img_map in [position_images, normal_images]:
    #     if img_map.ndim == 4:
    #         img_map = img_map[0]
    #     # normalize to [0, 1]
    #     img_map = (img_map - img_map.min()) / (img_map.max() - img_map.min() + 1e-6)
    #
    #     rgba_img = torch.cat([img_map, mask], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
    #     rgba_img = (rgba_img * 255).to(torch.uint8)  # convert to uint8
    #     rgba_img = rgba_img.cpu().numpy()  # convert to a numpy array
    #     out_imgs.append(Image.fromarray(rgba_img, mode='RGBA'))
    return position_images, normal_images

@cache
def get_pure_texture(uv_size, color=(int("0x55", 16), int("0x55", 16), int("0x55", 16))) -> torch.Tensor:
    """
    Get a solid-color texture image with the specified color.
    :param uv_size: The size of the UV map (height, width).
    :param color: The color of the texture, default is "0x555555" (light gray).
    :return: A texture image tensor of shape (height, width, 3).
    """
    height, width = uv_size

    color = torch.tensor(color, dtype=torch.float32).view(1, 1, 3) / 255.0
    texture = color.repeat(height, width, 1)

    return texture

def get_c2w(
        azimuth_deg,
        elevation_deg,
        camera_distances,):
    assert len(azimuth_deg) == len(elevation_deg) == len(camera_distances)
    n_views = len(azimuth_deg)
    # camera_distances = torch.full_like(elevation_deg, dis)
    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )
    center = torch.zeros_like(camera_positions)
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0
    return c2w

def camera_strategy_test_4_90deg(
        mesh: Dict,
        num_views: int = 4,
        **kwargs) -> Dict:
    """
    Place num_views cameras at evenly spaced azimuths (90 degrees apart for 4 views) at a fixed elevation,
    with the distance derived from the mesh bounding box and a fixed, fairly narrow FoV.
    :param num_views: number of supervision views
    :param kwargs: additional arguments
    """
    # Default camera intrinsics
    default_elevation = 10
    default_camera_lens = 50
    default_camera_sensor_width = 36
    default_fovy = 2 * np.arctan(default_camera_sensor_width / (2 * default_camera_lens))

    bbox_size = mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0]
    distance = default_camera_lens / default_camera_sensor_width * \
        math.sqrt(bbox_size[0] ** 2 + bbox_size[1] ** 2 + bbox_size[2] ** 2)

    all_azimuth_deg = torch.linspace(0, 360.0, num_views + 1)[:num_views] - 90

    all_elevation_deg = torch.full_like(all_azimuth_deg, default_elevation)

    # Get the corresponding azimuth and elevation
    view_idxs = torch.arange(0, num_views)
    azimuth = all_azimuth_deg[view_idxs]
    elevation = all_elevation_deg[view_idxs]
    camera_distances = torch.full_like(elevation, distance)
    c2w = get_c2w(azimuth, elevation, camera_distances)

    if c2w.ndim == 2:
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0

    fovy = torch.full_like(azimuth, default_fovy)

    return {
        'cond_sup_view_idxs': view_idxs,
        'cond_sup_c2w': c2w,
        'cond_sup_w2c': w2c,
        'cond_sup_fovy': fovy,
        # 'cond_sup_azimuth': azimuth,
        # 'cond_sup_elevation': elevation,
    }

def _get_projection_matrix(
    fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
    if isinstance(fovy, float):
        proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
        proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[1, 1] = -1.0 / math.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[2, 2] = -(far + near) / (far - near)
        proj_mtx[2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[3, 2] = -1.0
    else:
        batch_size = fovy.shape[0]
        proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
        proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[:, 1, 1] = -1.0 / torch.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[:, 2, 2] = -(far + near) / (far - near)
        proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[:, 3, 2] = -1.0
    return proj_mtx

def _get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    # calculate w2c from c2w: R' = Rt, t' = -Rt * t
    # mathematically equivalent to (c2w)^-1
    if c2w.ndim == 2:
        assert proj_mtx.ndim == 2
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0
    # calculate the mvp matrix as proj_mtx @ w2c (mv_mtx)
    mvp_mtx = proj_mtx @ w2c
    return mvp_mtx

def get_mvp_matrix(mesh, num_views=4, width=512, height=512, strategy="strategy_test_4_90deg"):
    """
    Get the Model-View-Projection (MVP) matrices for rendering views.
    :param mesh: The mesh object used to determine camera positioning.
    :param num_views: Number of views to generate, default is 4.
    :param width: Image width for projection matrix calculation.
    :param height: Image height for projection matrix calculation.
    :param strategy: Camera positioning strategy, default is "strategy_test_4_90deg".
    :return: MVP matrices and world-to-camera transformation matrices.
    """
    if strategy == "strategy_test_4_90deg":
        camera_info = camera_strategy_test_4_90deg(
            mesh=mesh,  # used to derive the camera distance from the bounding box
            num_views=num_views,
        )
        cond_sup_fovy = camera_info["cond_sup_fovy"]
        cond_sup_c2w = camera_info["cond_sup_c2w"]
        cond_sup_w2c = camera_info["cond_sup_w2c"]
        # cond_sup_azimuth = camera_info["cond_sup_azimuth"]
        # cond_sup_elevation = camera_info["cond_sup_elevation"]
    else:
        raise ValueError(f"Unsupported camera strategy: {strategy}")
    cond_sup_proj_mtx: Float[Tensor, "B 4 4"] = _get_projection_matrix(
        cond_sup_fovy, width / height, 0.1, 1000.0
    )
    mvp_mtx: Float[Tensor, "B 4 4"] = _get_mvp_matrix(cond_sup_c2w, cond_sup_proj_mtx)
    return mvp_mtx, cond_sup_w2c

@torch.cuda.amp.autocast(enabled=False)
def _get_depth_normal_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda", background_color=(0, 0, 0)):
    """
    Get the depth and normal maps, plus the mask, from position and normal images.
    :param xyz_map: Position images in world coordinates, shape [B, Nv, H, W, 3]. Returned by `render_geo_views_tensor`.
    :param normal_map: Normal images in world coordinates, shape [B, Nv, H, W, 3]. Returned by `render_geo_views_tensor`.
    :param mask: Mask for the images, shape [B, Nv, H, W]. Returned by `render_geo_views_tensor`.
    :param w2c: World-to-camera transformation matrices, shape [B, Nv, 4, 4].
    :param device: Device to run the computation on, default is "cuda".
    :param background_color: Background color for the depth and normal maps.
    :return: depth_map, normal_map, mask
    """
    w2c = w2c.to(device)

    # Render the world-coordinate position map and mask
    B, Nv, H, W, C = xyz_map.shape  # B: batch size, Nv: number of views, H/W: height/width, C: channels
    assert Nv == 1
    # Rearrange tensors for batch processing
    xyz_map = rearrange(xyz_map, "B Nv H W C -> (B Nv) (H W) C")
    normal_map = rearrange(normal_map, "B Nv H W C -> (B Nv) (H W) C")
    w2c = rearrange(w2c, "B Nv C1 C2 -> (B Nv) C1 C2")

    # Build homogeneous coordinates and transform to the camera coordinate system:
    # points in world coordinates are multiplied by the world-to-camera transformation matrix.
    B_Nv, N, C = xyz_map.shape
    ones = torch.ones(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_xyz = torch.cat([xyz_map, ones], dim=2)  # [x, y, z, 1]
    zeros = torch.zeros(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_normal = torch.cat([normal_map, zeros], dim=2)  # [x, y, z, 0] (directions, no translation)

    camera_coords = torch.bmm(homogeneous_xyz, w2c.transpose(1, 2))
    camera_normals = torch.bmm(homogeneous_normal, w2c.transpose(1, 2))

    depth_map = camera_coords[..., 2:3]  # the Z axis is the depth direction in the camera coordinate system
    depth_map = rearrange(depth_map, "(B Nv) (H W) 1 -> B Nv H W", B=B, Nv=Nv, H=H, W=W)
    normal_map = camera_normals[..., :3]  # keep only the x, y, z components
    normal_map = rearrange(normal_map, "(B Nv) (H W) c -> B Nv H W c", B=B, Nv=Nv, H=H, W=W)
    assert depth_map.dtype == torch.float32, f"depth_map must be float32, otherwise there will be artifacts in the ControlNet-generated pictures, but got {depth_map.dtype}"

    # Calculate min and max values
    min_depth = depth_map.amin((1, 2, 3), keepdim=True)
    max_depth = depth_map.amax((1, 2, 3), keepdim=True)

    depth_map = (depth_map - min_depth) / (max_depth - min_depth + 1e-6)  # normalize to [0, 1]

    depth_map = depth_map.repeat(1, 3, 1, 1)  # [B, 1, H, W] -> [B, 3, H, W] (Nv == 1, so the view dim doubles as the channel dim)
    normal_map = normal_map * 0.5 + 0.5  # normalize to [0, 1], [B, Nv, H, W, 3]
    normal_map = normal_map[:, 0].permute(0, 3, 1, 2)  # [B, 3, H, W]

    rgb_background_batched = torch.tensor(background_color, dtype=torch.float32, device=device).view(1, 3, 1, 1)
    depth_map = torch.lerp(rgb_background_batched, depth_map, mask)
    normal_map = torch.lerp(rgb_background_batched, normal_map, mask)

    return depth_map, normal_map, mask

def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image, Image.Image]:
    """
    Get the silhouette images based on the geometry images.
    :param position_imgs: Position images from different views, shape [Nv, H, W, 3].
    :param normal_imgs: Normal images from different views, shape [Nv, H, W, 3].
    :param mask_imgs: Mask for the images, shape [Nv, H, W]. Returned by `render_geo_views_tensor`.
    :param w2c: World-to-camera transformation matrices, shape [Nv, 4, 4].
    :param selected_view: The view selected for generating the image condition.
    :return: silhouettes (depth and normal maps in the camera coordinate system) and the mask.
    """
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]
    position_view = position_imgs[view_id: view_id + 1]
    normal_view = normal_imgs[view_id: view_id + 1]
    mask_view = mask_imgs[view_id: view_id + 1]
    w2c = w2c[view_id: view_id + 1]  # select the corresponding w2c for the view

    depth_img, normal_img, mask = _get_depth_normal_map_with_mask(
        position_view.unsqueeze(0),  # add batch dimension
        normal_view.unsqueeze(0),
        mask_view.unsqueeze(0),
        w2c.unsqueeze(0),
    )

    to_img = ToPILImage()
    return to_img(depth_img.squeeze(0)), to_img(normal_img.squeeze(0)), to_img(mask.squeeze(0))
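Taken together, a typical call sequence for this module looks roughly like the sketch below. It is illustrative only, not taken from app.py; it assumes `mesh` is a CUDA-resident Mesh from utils.mesh_utils, and the explicit `.to("cuda")` moves are an assumption about device handling.

# Hedged usage sketch for utils/render_utils.py (illustrative only).
from utils.render_utils import (get_mvp_matrix, get_pure_texture, get_silhouette_image,
                                render_geo_map, render_geo_views_tensor, render_views)

mvp_mtx, w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
mvp_mtx, w2c = mvp_mtx.to("cuda"), w2c.to("cuda")  # assumption: rasterization expects CUDA tensors

# Per-view geometry buffers and UV-space geometry maps.
position_imgs, normal_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_mtx, img_size=(512, 512))
position_map, normal_map = render_geo_map(mesh, map_size=(1024, 1024))

# Camera-space depth/normal silhouettes for the view picked in the UI.
depth_pil, normal_pil, mask_pil = get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, "First View")

# Preview of the untextured mesh with the flat 0x555555 placeholder texture.
preview = render_views(mesh, get_pure_texture((1024, 1024)).to("cuda"), mvp_mtx)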
utils/texture_generation.py
ADDED
@@ -0,0 +1,309 @@
import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from peft import LoraConfig
from PIL import Image
from torch import Tensor

from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel

TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()

@dataclass
class Config:
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_path: str = "https://huggingface.co/VAST-AI/SeqTex/resolve/main/.gitattributes/edm2_ema_12176_clean.pth"
    min_noise_level_index: int = 15  # the same setting as the paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)

    use_causal_mask: bool = False
    addtional_qk_geometry: bool = False
    use_normal: bool = True
    use_position: bool = True
    randomly_init: bool = True  # we load the weights from a corresponding ckpt

    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024

    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42

    lora_rank: int = 128
    lora_alpha: int = 64

cfg = Config()

def load_model_weights(model_path: str, map_location="cpu"):
    """
    Load model weights from either a URL or a local file path.

    Args:
        model_path (str): Path to the model weights; can be a URL or a local file path
        map_location (str): Device to map the model to

    Returns:
        Dict: Loaded state dictionary
    """
    # Check if the path is a URL
    parsed_url = urlparse(model_path)
    if parsed_url.scheme in ('http', 'https'):
        # Load from URL using torch.hub
        try:
            state_dict = torch.hub.load_state_dict_from_url(
                model_path,
                map_location=map_location,
                progress=True
            )
            return state_dict
        except Exception as e:
            gr.Warning(f"Failed to load from URL: {e}")
            raise e
    else:
        # Load from a local file path
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Local model file not found: {model_path}")
        return torch.load(model_path, map_location=map_location)

def lazy_get_seqtex_pipe():
    """
    Lazily load the SeqTex pipeline for texture generation.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE
    gr.Info("First call: loading the SeqTex pipeline... This may take about 1 minute.")
    with TEX_PIPE_LOCK:
        if TEX_PIPE is not None:
            return TEX_PIPE

        # Pipeline
        TEX_PIPE = WanT2TexPipeline.from_pretrained(cfg.video_base_name)

        # Models
        transformer = WanT2TexTransformer3DModel(
            TEX_PIPE.transformer,
            use_causal_mask=cfg.use_causal_mask,
            addtional_qk_geo=cfg.addtional_qk_geometry,
            use_normal=cfg.use_normal,
            use_position=cfg.use_position,
            randomly_init=cfg.randomly_init,
        )
        transformer.add_adapter(
            LoraConfig(
                r=cfg.lora_rank,
                lora_alpha=cfg.lora_alpha,
                init_lora_weights=True,
                target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0", "attn1.to_out.2",
                                "ffn.net.0.proj", "ffn.net.2"],
            )
        )
        # load the transformer weights
        state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
        transformer.load_state_dict(state_dict, strict=True)
        TEX_PIPE.transformer = transformer

        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32).to("cuda").requires_grad_(False)
        TEX_PIPE.vae = VAE

        # Some useful parameters
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to("cuda", dtype=torch.float32)
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to("cuda", dtype=torch.float32)

        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                TEX_PIPE.scheduler.config, shift=cfg.flow_shift
            )
        )
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index  # in our scheduler, the first timestep is pure noise; typically set to 1000 - 15
        setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)

        TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32)  # use float32 for inference
        return TEX_PIPE

@torch.amp.autocast('cuda', dtype=torch.float32)
def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images into the latent space using the VAE.
    Every frame is treated as a separate image, without any awareness of the temporal dimension.
    :param images: Input images tensor with shape [B, F, H, W, C].
    :param encode_as_first: Whether to encode all frames as the first frame.
    :return: Encoded latents with shape [B, C', F, H/8, W/8].
    """
    if images.min() < -0.1:
        # images are in the [-1, 1] range
        images = (images + 1.0) / 2.0  # normalize to the [0, 1] range
    if encode_as_first:
        # encode every frame as if it were the first one
        B = images.shape[0]
        images = rearrange(images, "B F H W C -> (B F) C 1 H W")
        latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
        latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
    else:
        raise NotImplementedError("Currently only encoding as the first frame is supported.")

    return latents

@torch.amp.autocast('cuda', dtype=torch.float32)
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using the VAE.
    :param latents: Input latents with shape [B, C, F, H, W].
    :param decode_as_first: Whether to decode all frames as the first frame.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
    """
    if decode_as_first:
        F = latents.shape[2]
        latents = latents.to(VAE.dtype)
        latents = latents / LATENTS_STD + LATENTS_MEAN
        latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
        images = VAE.decode(latents, return_dict=False)[0]
        images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
    else:
        raise NotImplementedError("Currently only decoding as the first frame is supported.")
    return images

def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL Image to a tensor. If the image is RGBA, composite it onto a black background using the alpha channel as a mask.
    :param image: PIL Image to convert, with values in [0, 255].
    :return: Tensor representation of the image, with values in [0.0, 1.0] and shape [H, W, C].
    """
    # Convert to RGBA to ensure an alpha channel exists
    image = image.convert("RGBA")
    np_img = np.array(image)
    rgb = np_img[..., :3]
    alpha = np_img[..., 3:4] / 255.0  # normalize alpha to [0, 1]
    # Blend with a black background using the alpha mask
    rgb = rgb * alpha
    rgb = rgb.astype(np.float32) / 255.0  # normalize to [0, 1]
    tensor = torch.from_numpy(rgb).to(device)
    return tensor

@spaces.GPU(duration=120)
@torch.cuda.amp.autocast(dtype=torch.float32)
@torch.inference_mode
@torch.no_grad
def generate_texture(position_map, normal_map, position_images, normal_images, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
    """
    Use SeqTex to generate a texture for the mesh based on the image condition.
    :param position_map: Position map rendered in UV space.
    :param normal_map: Normal map rendered in UV space.
    :param position_images: Position images from different views.
    :param normal_images: Normal images from different views.
    :param condition_image: Image condition generated from the selected view.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: The view selected for generating the image condition.
    :return: Generated texture map and multi-view frames as tensors.
    """
    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = lazy_get_seqtex_pipe()
    progress(0.2, desc="SeqTex pipeline loaded successfully.")
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]

    progress(0.3, desc="Encoding position and normal images...")
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0)  # 1 F H W C
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
    nat_latents = encode_images(nat_seq, encode_as_first=True)  # B C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True)  # B C F' H' W'
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
    cond_model_latents = (nat_geo_latents, uv_geo_latents)

    num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
    uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))

    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert the PIL Image to a tensor
        condition_image = convert_img_to_tensor(condition_image, device=device)
    condition_image = condition_image.unsqueeze(0).unsqueeze(0)
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)

    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",

        cond_model_latents=cond_model_latents,
        # mask_indices=test_mask_indices,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex",  # img2tex
        progress=progress,
    ).frames

    mv_latents, uv_latents = latents

    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True)  # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True)  # B C 1 H W

    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
    mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]

    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)

    progress(1, desc="Texture generated successfully.")
    return uv_map_pred.float(), mv_out.float(), "Step 3: Texture generated successfully."
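For orientation, the sketch below shows how the render_utils outputs plug into generate_texture. It is a hedged example rather than code from app.py: the tensor shapes ([1, H, W, 3] for the UV maps, [4, H, W, 3] for the view images) are inferred from the unsqueeze/cat calls above, and condition_image stands for any 512x512 RGBA PIL image such as the one produced in utils/image_generation.py.

# Hedged wiring sketch (illustrative only; shape comments are assumptions).
position_map, normal_map = render_geo_map(mesh, map_size=(1024, 1024))    # assumed [1, 1024, 1024, 3]
position_imgs, normal_imgs, _ = render_geo_views_tensor(mesh, mvp_mtx)    # assumed [4, 512, 512, 3]

uv_texture, mv_grid, status = generate_texture(
    position_map, normal_map,
    position_imgs, normal_imgs,
    condition_image,                      # PIL RGBA image of the selected view
    text_prompt="A rustic birdhouse with a snow-covered roof ...",
    selected_view="First View",
)
# uv_texture: [3, 1024, 1024] in [0, 1]; mv_grid: the four predicted views tiled along the width.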
wan/__init__.py
ADDED
File without changes

wan/pipeline_wan_t2tex_extra.py
ADDED
@@ -0,0 +1,366 @@
import copy
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from einops import rearrange
import regex as re
import torch
from diffusers.pipelines.wan.pipeline_wan import WanPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.utils.torch_utils import randn_tensor
from torch import Tensor
from transformers import AutoTokenizer, UMT5EncoderModel
from jaxtyping import Float
import gradio as gr

def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
    sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    return sigma

class WanT2TexPipeline(WanPipeline):
    def __init__(self, tokenizer, text_encoder, transformer, vae, scheduler):
        super().__init__(tokenizer, text_encoder, transformer, vae, scheduler)
        self.uv_scheduler = copy.deepcopy(scheduler)

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        treat_as_first: Optional[bool] = True,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        ####################
        if treat_as_first:
            num_latent_frames = num_frames // self.vae_scale_factor_temporal
        else:
            num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        ####################

        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        cond_model_latents: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        uv_height=None,
        uv_width=None,
        uv_num_frames=None,
        # multi_task_cond=None,
        treat_as_first=True,
        gt_condition: Tuple[Optional[Float[Tensor, "B C F H W"]], Optional[Float[Tensor, "B C F H W"]]] = None,
        inference_img_cond_frame=None,
        use_qk_geometry=False,
        task_type="all",
        progress=gr.Progress()
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, defaults to `480`):
                The height in pixels of the generated image.
            width (`int`, defaults to `832`):
                The width in pixels of the generated image.
            num_frames (`int`, defaults to `81`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end
                of each denoising step during inference, with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                The dtype to use for torch.amp.autocast.

        Examples:

        Returns:
            [`~WanPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
                the first element is a list with the generated images and the second element is a list of `bool`s
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise an error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        # ATTENTION: the inputs here are images, so num_frames stays small (e.g. 5) and no temporal compression is applied.
        # if num_frames % self.vae_scale_factor_temporal != 1:
        #     raise ValueError(
        #         f"num_frames should be divisible by {self.vae_scale_factor_temporal} + 1, but got {num_frames}."
        #     )
        # num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        # num_frames = max(num_frames, 1)

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode the input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        if self.do_classifier_free_guidance:
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        self.uv_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        mv_latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            treat_as_first=treat_as_first,
        )
        uv_latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            uv_height,
            uv_width,
            uv_num_frames,
            torch.float32,
            device,
            generator,
            treat_as_first=True  # UV latents always differ from the others, so treat them as the first frame
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # with progress.tqdm(total=num_inference_steps, desc="Diffusing...") as progress_bar:
        for i, t in progress.tqdm(enumerate(timesteps), desc="Diffusing..."):
            if self.interrupt:
                continue

            # set conditions
            timestep_df = torch.ones((batch_size, num_frames // self.vae_scale_factor_temporal + 1)).to(device) * t
            sigmas = get_sigmas(self.scheduler, rearrange(timestep_df, "B F -> (B F)"), dtype=transformer_dtype, device=device)
            sigmas = rearrange(sigmas, "(B F) -> B 1 F 1 1", B=batch_size)
            match task_type:
                case "geo+mv2tex":
                    timestep_df[:, :num_frames // self.vae_scale_factor_temporal] = self.min_noise_level_timestep
                    sigmas[:, :, :num_frames // self.vae_scale_factor_temporal, ...] = self.min_noise_level_sigma
                    mv_noise = torch.randn_like(mv_latents)  # B C 4 H W
                    mv_latents = (1.0 - sigmas[:, :, :-1, ...]) * gt_condition[0] + sigmas[:, :, :-1, ...] * mv_noise
                case "img2tex":
                    assert inference_img_cond_frame is not None, "inference_img_cond_frame should be specified for the img2tex task"
                    # Use the specified frame index as the condition instead of always using the first frame
                    timestep_df[:, inference_img_cond_frame: inference_img_cond_frame + 1] = self.min_noise_level_timestep
                    sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] = self.min_noise_level_sigma
                    mv_noise = randn_tensor(mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1].shape, generator=generator, device=device, dtype=self.dtype)
                    # mv_noise = torch.randn_like(mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1], generator=generator)  # B C selected_frames H W
                    mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] = (1.0 - sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...]) * gt_condition[0] + sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] * mv_noise
                case "soft_render":
                    timestep_df[:, -1:] = self.min_noise_level_timestep
                    sigmas[:, :, -1:, ...] = self.min_noise_level_sigma
                    uv_noise = torch.randn_like(uv_latents)  # B C 1 H W
                    uv_latents = (1.0 - sigmas[:, :, -1:, ...]) * gt_condition[1] + sigmas[:, :, -1:, ...] * uv_noise
                case "geo2mv":
                    timestep_df[:, -1:] = 1000.
                    sigmas[:, :, -1:, ...] = 1.
                case _:
                    pass

            # add geometry information to channel C
            mv_latents_input = torch.cat([mv_latents, cond_model_latents[0]], dim=1)
            uv_latents_input = torch.cat([uv_latents, cond_model_latents[1]], dim=1)
            if self.do_classifier_free_guidance:
                mv_latents_input = torch.cat([mv_latents_input, mv_latents_input], dim=0)
                uv_latents_input = torch.cat([uv_latents_input, uv_latents_input], dim=0)

            self._current_timestep = t
            latent_model_input = (mv_latents_input.to(transformer_dtype), uv_latents_input.to(transformer_dtype))
            # timestep = t.expand(mv_latents.shape[0])

            noise_out = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep_df,
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=attention_kwargs,
                # task_cond=multi_task_cond,
                return_dict=False,
                use_qk_geometry=use_qk_geometry
            )[0]
            mv_noise_out, uv_noise_out = noise_out

            if self.do_classifier_free_guidance:
                mv_noise_uncond, mv_noise_pred = mv_noise_out.chunk(2)
                uv_noise_uncond, uv_noise_pred = uv_noise_out.chunk(2)
                mv_noise_pred = mv_noise_uncond + guidance_scale * (mv_noise_pred - mv_noise_uncond)
                uv_noise_pred = uv_noise_uncond + guidance_scale * (uv_noise_pred - uv_noise_uncond)
            else:
                mv_noise_pred = mv_noise_out
                uv_noise_pred = uv_noise_out

            # compute the previous noisy sample x_t -> x_t-1
            # The conditions will be replaced anyway, so perhaps we don't need to step the frames separately
            mv_latents = self.scheduler.step(mv_noise_pred, t, mv_latents, return_dict=False)[0]
            uv_latents = self.uv_scheduler.step(uv_noise_pred, t, uv_latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                raise NotImplementedError()
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # # call the callback, if provided
            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            #     progress_bar.update()

        self._current_timestep = None

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            # video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = (mv_latents, uv_latents)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)
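The conditioning step in the img2tex branch above deserves a note: at every denoising iteration the selected condition frame is not carried over from the evolving latents but is rebuilt from the encoded ground-truth image at the fixed minimum noise level (min_noise_level_index, set in utils/texture_generation.py), while its timestep entry is pinned to min_noise_level_timestep. Schematically, with short names standing in for the variables used above:

# Schematic of the per-step condition injection for task_type == "img2tex" (illustration only;
# z_gt = gt_condition[0], sigma_min = self.min_noise_level_sigma, t_min = self.min_noise_level_timestep).
f = inference_img_cond_frame
timestep_df[:, f:f + 1] = t_min  # the transformer sees this frame as almost clean
noise = randn_tensor(mv_latents[:, :, f:f + 1].shape, generator=generator, device=device, dtype=self.dtype)
mv_latents[:, :, f:f + 1] = (1.0 - sigma_min) * z_gt + sigma_min * noise  # flow-matching interpolation toward the GT latent
# every other frame keeps the full timestep t and is denoised normally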
wan/wan_t2tex_transformer_3d_extra.py
ADDED
@@ -0,0 +1,634 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
from typing import Any, Dict, Optional, Tuple, Union
from functools import cache

from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models import WanTransformer3DModel
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import (PixArtAlphaTextProjection,
                                         TimestepEmbedding, Timesteps,
                                         get_1d_rotary_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from diffusers.models.transformers.transformer_wan import \
    WanTimeTextImageEmbedding
from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
                             unscale_lora_layers)


class WanT2TexAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
        geometry_embedding: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            encoder_hidden_states_img = encoder_hidden_states[:, :257]
            encoder_hidden_states = encoder_hidden_states[:, 257:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if geometry_embedding is not None:
            # add-type geometry embedding (active path)
            if isinstance(geometry_embedding, Tuple):
                query = query + geometry_embedding[0]
                key = key + geometry_embedding[1]
            else:
                query = query + geometry_embedding
                key = key + geometry_embedding
            # mul-type geometry embedding (unused alternative, kept for reference):
            #   query = query * (1 + geometry_embedding); key = key * (1 + geometry_embedding)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)  # [B, F*H*W, 2C] -> [B, H, F*H*W, 2C//H]
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            if isinstance(rotary_emb, Tuple):
                query = apply_rotary_emb(query, rotary_emb[0])
                key = apply_rotary_emb(key, rotary_emb[1])
            else:
                query = apply_rotary_emb(query, rotary_emb)
                key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

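# Illustration only (not part of the original file): how the processor above is typically exercised
# by the blocks below. Sizes here are assumptions for a toy configuration (dim=64, 4 heads, 8 tokens).
#
#     attn = Attention(query_dim=64, heads=4, kv_heads=4, dim_head=16, qk_norm="rms_norm_across_heads",
#                      eps=1e-6, bias=True, out_bias=True, processor=WanT2TexAttnProcessor2_0())
#     x = torch.randn(1, 8, 64)          # [B, tokens, dim]
#     geo = torch.randn(1, 8, 64)        # add-type geometry embedding, same shape as the q/k inputs
#     out = attn(hidden_states=x, geometry_embedding=geo)   # geometry is added to q and k before RoPE
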
class WanTimeTaskTextImageEmbedding(WanTimeTextImageEmbedding):
    def __init__(
        self,
        original_model,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        randomly_init: bool = False,
    ):
        super(WanTimeTaskTextImageEmbedding, self).__init__(dim, time_freq_dim, time_proj_dim, text_embed_dim, image_embed_dim)
        if not randomly_init:
            self.load_state_dict(original_model.state_dict(), strict=True)
        # cond_proj = nn.Linear(512, original_model.timesteps_proj.num_channels, bias=False)
        # setattr(self.time_embedder, "cond_proj", cond_proj)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        # time_cond: Optional[torch.Tensor] = None,
    ):
        B = timestep.shape[0]
        timestep = rearrange(timestep, "B F -> (B F)")
        timestep = self.timesteps_proj(timestep)
        timestep = rearrange(timestep, "(B F) D -> B F D", B=B)

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image

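# Illustration only (not in the original file): the embedder above expects one timestep per frame
# rather than one per sample, which is what lets the MV frames and the UV frame run on different
# noise schedules. A rough shape walk-through, assuming Wan's freq_dim=256 and inner dim 1536:
#
#     timestep = torch.randint(0, 1000, (2, 5))   # [B=2, F=5]: e.g. 4 MV frames + 1 UV frame
#     # timesteps_proj: [B*F] -> [B*F, 256], reshaped back to [B, F, 256]
#     # time_embedder:  [B, F, 256]  -> temb [B, F, 1536]
#     # time_proj:      [B, F, 1536] -> timestep_proj [B, F, 6*1536], later unflattened to [B, F, 6, 1536]
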
class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0, addtional_qk_geo: bool = False
    ):
        super().__init__()

        if addtional_qk_geo:  # to add PE to the geometry embedding as well
            attention_head_dim = attention_head_dim * 2
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor, uv_hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        _, _, uv_num_frames, uv_height, uv_width = uv_hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
        uppf, upph, uppw = uv_num_frames // p_t, uv_height // p_h, uv_width // p_w

        self.freqs = self.freqs.to(hidden_states.device)
        freqs = self.freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        uv_freqs_f = freqs[0][ppf:ppf + uppf].view(uppf, 1, 1, -1).expand(uppf, upph, uppw, -1)
        uv_freqs_h = freqs[1][:upph].view(1, upph, 1, -1).expand(uppf, upph, uppw, -1)
        uv_freqs_w = freqs[2][:uppw].view(1, 1, uppw, -1).expand(uppf, upph, uppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        uv_freqs = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(1, 1, uppf * upph * uppw, -1)
        return torch.cat([freqs, uv_freqs], dim=-2)


# def pseudo_code(freqs, mv_tokens_shape, uv_tokens_shape, dimension):
#     """
#     Input:
#         freqs: [S, D/2], S is the number of tokens, D is the dimension of tokens, 2 indicates Cos and Sin in original RoPE.
#         mv_tokens_shape: (mv_num_frames, mv_height, mv_width)
#         uv_tokens_shape: (uv_num_frames, uv_height, uv_width)
#         dimension: the dimension of tokens
#     Output:
#     """
#     mpf, mph, mpw = mv_tokens_shape  # mv_num_frames, mv_height, mv_width
#     upf, uph, upw = uv_tokens_shape  # uv_num_frames, uv_height, uv_width
#
#     # 1. Evenly split the freqs into 3 parts
#     freqs = freqs.split_with_sizes(
#         [
#             dimension // 2 - 2 * (dimension // 6),
#             dimension // 6,
#             dimension // 6,
#         ],
#         dim=1,
#     )
#
#     # 2. In the time dimension, the freqs for UV follow directly after the freqs for MV
#     freqs_f = freqs[0][:mpf].view(mpf, 1, 1, -1).expand(mpf, mph, mpw, -1)
#     uv_freqs_f = freqs[0][mpf:mpf + upf].view(upf, 1, 1, -1).expand(upf, uph, upw, -1)
#
#     # 3. The freqs in the height and width dimensions are shared between MV and UV
#     freqs_h = freqs[1][:mph].view(1, mph, 1, -1).expand(mpf, mph, mpw, -1)
#     uv_freqs_h = freqs[1][:uph].view(1, uph, 1, -1).expand(upf, uph, upw, -1)
#     freqs_w = freqs[2][:mpw].view(1, 1, mpw, -1).expand(mpf, mph, mpw, -1)
#     uv_freqs_w = freqs[2][:upw].view(1, 1, upw, -1).expand(upf, uph, upw, -1)
#
#     # 4. Rearrange the three 1D RoPEs into a 3D RoPE along the channel dimension
#     mv_rope = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(mpf * mph * mpw, -1)
#     uv_rope = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(upf * uph * upw, -1)
#     return torch.cat([mv_rope, uv_rope], dim=-2)

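# Worked example (not in the original file), assuming Wan 1.3B's attention_head_dim = 128:
#     h_dim = w_dim = 2 * (128 // 6) = 42, and t_dim = 128 - 84 = 44
#     split_with_sizes over the complex freqs (128 // 2 = 64 columns) gives sizes [22, 21, 21],
#     i.e. half of t_dim / h_dim / w_dim each, since every complex entry covers a (cos, sin) pair.
#     MV tokens occupy temporal rows [0, ppf) and UV tokens occupy rows [ppf, ppf + uppf), so the
#     UV "frame" is placed after the multi-view frames on the same temporal axis, while the
#     height / width frequencies are shared by both token sets.
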
class WanT2TexTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
        addtional_qk_geo: bool = False,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=WanT2TexAttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=WanT2TexAttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        self.geometry_caster = nn.Linear(dim, dim)
        nn.init.zeros_(self.geometry_caster.weight.data)
        nn.init.zeros_(self.geometry_caster.bias.data)

        self.attnuv = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=WanT2TexAttnProcessor2_0(),
        )
        self.normuv2 = FP32LayerNorm(dim, eps, elementwise_affine=True)
        self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.ffnuv = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        attn_bias: Optional[torch.Tensor] = None,
        geometry_embedding: Optional[torch.Tensor] = None,
        token_shape: Optional[Tuple[int, int, int, int, int, int]] = None,
    ) -> torch.Tensor:
        post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width = token_shape
        mv_temb, uv_temb = temb[:, :post_patch_num_frames], temb[:, post_patch_num_frames:]
        mv_temb = repeat(mv_temb, "B F N D -> B N (F H W) D", H=post_patch_height, W=post_patch_width)
        uv_temb = repeat(uv_temb, "B F N D -> B N (F H W) D", H=post_uv_height, W=post_uv_width)
        dit_ssg = rearrange(self.scale_shift_table, "1 N D -> 1 N 1 D") + mv_temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = torch.unbind(dit_ssg, dim=1)
        dit_ssg_uv = rearrange(self.scale_shift_table_uv, "1 N D -> 1 N 1 D") + uv_temb.float()
        shift_msa_uv, scale_msa_uv, gate_msa_uv, c_shift_msa_uv, c_scale_msa_uv, c_gate_msa_uv = torch.unbind(dit_ssg_uv, dim=1)

        geometry_embedding = self.geometry_caster(geometry_embedding)

        n_mv, n_uv = post_patch_num_frames * post_patch_height * post_patch_width, post_uv_num_frames * post_uv_height * post_uv_width
        assert hidden_states.shape[1] == n_mv + n_uv, f"hidden_states shape {hidden_states.shape} is not equal to {n_mv + n_uv}"
        mv_hidden_states, uv_hidden_states = hidden_states[:, :n_mv], hidden_states[:, n_mv:]

        # 1. Self-attention
        mv_norm_hidden_states = (self.norm1(mv_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(mv_hidden_states)
        uv_norm_hidden_states = (self.norm1(uv_hidden_states.float()) * (1 + scale_msa_uv) + shift_msa_uv).type_as(uv_hidden_states)

        mv_attn_output = self.attn1(hidden_states=mv_norm_hidden_states, rotary_emb=rotary_emb[:, :, :n_mv], attention_mask=attn_bias, geometry_embedding=geometry_embedding[:, :n_mv])
        mv_hidden_states = (mv_hidden_states.float() + mv_attn_output * gate_msa).type_as(mv_hidden_states)
        uv_attn_output = self.attnuv(hidden_states=uv_norm_hidden_states, encoder_hidden_states=torch.cat([mv_hidden_states, uv_norm_hidden_states], dim=1),
                                     rotary_emb=(rotary_emb[:, :, n_mv:], rotary_emb), geometry_embedding=(geometry_embedding[:, n_mv:], geometry_embedding))
        uv_hidden_states = (uv_hidden_states.float() + uv_attn_output * gate_msa_uv).type_as(uv_hidden_states)

        # 2. Cross-attention
        mv_norm_hidden_states = self.norm2(mv_hidden_states.float()).type_as(mv_hidden_states)
        uv_norm_hidden_states = self.normuv2(uv_hidden_states.float()).type_as(uv_hidden_states)
        attn_output = self.attn2(hidden_states=torch.cat([mv_norm_hidden_states, uv_norm_hidden_states], dim=1), encoder_hidden_states=encoder_hidden_states)
        mv_attn_output, uv_attn_output = attn_output[:, :n_mv], attn_output[:, n_mv:]
        mv_hidden_states.add_(mv_attn_output)
        uv_hidden_states.add_(uv_attn_output)

        # 3. Feed-forward
        mv_norm_hidden_states = (self.norm3(mv_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            mv_hidden_states
        )
        uv_norm_hidden_states = (self.norm3(uv_hidden_states.float()) * (1 + c_scale_msa_uv) + c_shift_msa_uv).type_as(
            uv_hidden_states
        )
        ff_output = self.ffn(mv_norm_hidden_states)
        mv_hidden_states = (mv_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(mv_hidden_states)
        ff_output_uv = self.ffnuv(uv_norm_hidden_states)
        uv_hidden_states = (uv_hidden_states.float() + ff_output_uv.float() * c_gate_msa_uv).type_as(uv_hidden_states)
        hidden_states = torch.cat([mv_hidden_states, uv_hidden_states], dim=1)

        return hidden_states

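# Illustration only (not in the original file): information flow inside one block above.
# MV tokens self-attend among themselves (attn1); UV tokens query the updated MV tokens plus
# themselves (attnuv); both streams then share the text cross-attention (attn2) and use separate
# FFNs and modulation tables. A toy forward pass under assumed sizes for the 1.3B config:
#
#     block = WanT2TexTransformerBlock(dim=1536, ffn_dim=8960, num_heads=12)
#     n_mv, n_uv = 4 * 16 * 16, 1 * 32 * 32          # token_shape = (4, 16, 16, 1, 32, 32)
#     x = torch.randn(1, n_mv + n_uv, 1536)
#     ctx = torch.randn(1, 512, 1536)                # projected text embeddings
#     temb = torch.randn(1, 5, 6, 1536)              # per-frame modulation (4 MV frames + 1 UV frame)
#     geo = torch.randn(1, n_mv + n_uv, 1536)
#     rope = ...                                     # complex freqs from WanRotaryPosEmbed, [1, 1, n_mv + n_uv, 64]
#     # y = block(x, ctx, temb, rope, geometry_embedding=geo, token_shape=(4, 16, 16, 1, 32, 32))
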
class WanT2TexTransformer3DModel(WanTransformer3DModel):
    """
    3D Transformer model for T2Tex.
    """
    def __init__(self, original_model, use_causal_mask=False, addtional_qk_geo=False, randomly_init=False, **kwargs):
        super(WanT2TexTransformer3DModel, self).__init__(**original_model.config)
        if not randomly_init:
            self.load_state_dict(original_model.state_dict(), strict=True)
        self.addtional_qk_geo = addtional_qk_geo
        if addtional_qk_geo:
            # Unsupported: it drastically increases memory usage and slows down training
            # without a significant performance gain.
            raise ValueError("addtional_qk_geo is not supported")

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len, addtional_qk_geo=addtional_qk_geo)
        self.use_normal, self.use_position = kwargs.get("use_normal", True), kwargs.get("use_position", True)
        if self.use_normal:
            self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
            # torch.nn.init.zeros_(self.norm_patch_embedding.weight.data)
            # torch.nn.init.zeros_(self.norm_patch_embedding.bias.data)
        if self.use_position:
            self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
            # torch.nn.init.zeros_(self.pos_patch_embedding.weight.data)
            # torch.nn.init.zeros_(self.pos_patch_embedding.bias.data)

        # 2. Condition embeddings
        inner_dim = original_model.config.num_attention_heads * original_model.config.attention_head_dim
        self.condition_embedder = WanTimeTaskTextImageEmbedding(
            original_model=self.condition_embedder,
            dim=inner_dim,
            time_freq_dim=original_model.config.freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=original_model.config.text_dim,
            image_embed_dim=original_model.config.image_dim,
            randomly_init=randomly_init,
        )

        # 3. Transformer blocks
        self.use_causal_mask = use_causal_mask
        self.num_attention_heads = original_model.config.num_attention_heads

        block = WanT2TexTransformerBlock(
            inner_dim,
            original_model.config.ffn_dim,
            original_model.config.num_attention_heads,
            original_model.config.qk_norm,
            original_model.config.cross_attn_norm,
            original_model.config.eps,
            original_model.config.added_kv_proj_dim,
        )
        self.blocks = None
        self.blocks = nn.ModuleList(
            [
                copy.deepcopy(block)
                for _ in range(original_model.config.num_layers)
            ]
        )
        self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
        if not randomly_init:
            self.scale_shift_table_uv.data.copy_(self.scale_shift_table.data)
            self.blocks.load_state_dict(original_model.blocks.state_dict(), strict=False)
            for block in self.blocks:
                block.attnuv.load_state_dict(block.attn1.state_dict())
                block.scale_shift_table_uv.data.copy_(block.scale_shift_table.data)
                block.normuv2.load_state_dict(block.norm2.state_dict())
                block.ffnuv.load_state_dict(block.ffn.state_dict())

        # 4. Output norm & projection: inherited from WanTransformer3DModel, left unchanged.

    @cache
    def get_attention_bias(self, mv_length, uv_length):
        total_len = mv_length + uv_length
        attention_mask = torch.ones((total_len, total_len), dtype=torch.bool)
        uv_start = mv_length
        attention_mask[:uv_start, uv_start:] = False

        attention_mask = repeat(attention_mask, "s l -> 1 h s l", h=self.num_attention_heads)
        # Additive bias: 0 where attention is allowed, -inf where MV queries would attend to UV keys.
        attention_bias = torch.zeros_like(attention_mask, dtype=torch.float32)
        attention_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        attention_bias = attention_bias.to("cuda").contiguous()
        return attention_bias

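    # Illustration only (not in the original file): the bias produced above for tiny lengths.
    # With mv_length=2 and uv_length=2, rows are queries and columns are keys; MV queries are
    # blocked from UV keys while UV queries may read everything:
    #
    #     [[0., 0., -inf, -inf],
    #      [0., 0., -inf, -inf],
    #      [0., 0.,   0.,   0.],
    #      [0., 0.,   0.,   0.]]
    #
    # Note this path is currently disabled in forward(), where attn_bias is hard-wired to None.
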
    def forward(
        self,
        hidden_states: Tuple[torch.Tensor, torch.Tensor],
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        # task_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        use_qk_geometry: Optional[bool] = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                raise NotImplementedError()

        assert timestep.ndim == 2, "Use Diffusion Forcing to set a separate timestep for each frame."

        mv_hidden_states, uv_hidden_states = hidden_states

        batch_size, num_channels, num_frames, height, width = mv_hidden_states.shape
        _, _, uv_num_frames, uv_height, uv_width = uv_hidden_states.shape

        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w
        post_uv_num_frames = uv_num_frames // p_t
        post_uv_height = uv_height // p_h
        post_uv_width = uv_width // p_w

        rotary_emb = self.rope(mv_hidden_states, uv_hidden_states)

        # Patchify
        if self.use_normal and self.use_position:
            mv_rgb_hidden_states, mv_pos_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 3, dim=1)
            uv_rgb_hidden_states, uv_pos_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 3, dim=1)
            mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states) + self.norm_patch_embedding(mv_norm_hidden_states)
            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
        elif self.use_normal:
            mv_rgb_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
            uv_rgb_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
            mv_geometry_embedding = self.norm_patch_embedding(mv_norm_hidden_states)
            uv_geometry_embedding = self.norm_patch_embedding(uv_norm_hidden_states)
        elif self.use_position:
            mv_rgb_hidden_states, mv_pos_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
            uv_rgb_hidden_states, uv_pos_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
            mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states)
            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states)
        else:
            raise ValueError("use_normal and use_position are both False, please set at least one of them to True.")

        mv_hidden_states = self.patch_embedding(mv_rgb_hidden_states)
        uv_hidden_states = self.patch_embedding(uv_rgb_hidden_states)
        if use_qk_geometry:
            mv_geometry_embedding = mv_geometry_embedding.flatten(2).transpose(1, 2)
            uv_geometry_embedding = uv_geometry_embedding.flatten(2).transpose(1, 2)  # [B, F*H*W, C]
            geometry_embedding = torch.cat([mv_geometry_embedding, uv_geometry_embedding], dim=1)
        else:
            raise NotImplementedError("please set use_qk_geometry to True")
            # geometry_embedding = None
            # mv_hidden_states = mv_hidden_states + mv_geometry_embedding
            # uv_hidden_states = uv_hidden_states + uv_geometry_embedding

        mv_hidden_states = mv_hidden_states.flatten(2).transpose(1, 2)
        uv_hidden_states = uv_hidden_states.flatten(2).transpose(1, 2)  # [B, F*H*W, C]
        hidden_states = torch.cat([mv_hidden_states, uv_hidden_states], dim=1)  # [B, F*H*W, C]

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        # temb [B, F, 6*D], timestep_proj [B, F, 6*D], used to be [B, 6*D]
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))  # [B, F, 6*D] -> [B, F, 6, D]

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # # Get attention bias
        # if self.use_causal_mask:
        #     # This may be gainless, because the patch embedding is not causal, which will leak information to MV
        #     attn_bias = self.get_attention_bias(post_patch_num_frames * post_patch_height * post_patch_width,
        #                                         post_uv_num_frames * post_uv_height * post_uv_width)
        # else:
        attn_bias = None

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb,
                    attn_bias, geometry_embedding, (post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width)
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb,
                                      attn_bias=attn_bias, geometry_embedding=geometry_embedding,
                                      token_shape=(post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width))

        # 5. Output norm, projection & unpatchify
        # [B, 2, D] chunk into [B, 1, D] and [B, 1, D], D is 1536
        inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        mv_temb, uv_temb = temb[:, :post_patch_num_frames], temb[:, post_patch_num_frames:]
        mv_temb = repeat(mv_temb, "B F D -> B 1 (F H W) D", H=post_patch_height, W=post_patch_width)
        uv_temb = repeat(uv_temb, "B F D -> B 1 (F H W) D", H=post_uv_height, W=post_uv_width)
        shift, scale = (self.scale_shift_table.view(1, 2, 1, inner_dim) + mv_temb).chunk(2, dim=1)
        shift_uv, scale_uv = (self.scale_shift_table_uv.view(1, 2, 1, inner_dim) + uv_temb).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.squeeze(1).to(hidden_states.device)
        scale = scale.squeeze(1).to(hidden_states.device)
        shift_uv = shift_uv.squeeze(1).to(hidden_states.device)
        scale_uv = scale_uv.squeeze(1).to(hidden_states.device)

        # Unpatchify
        uv_token_length = post_uv_num_frames * post_uv_height * post_uv_width
        mv_token_length = post_patch_num_frames * post_patch_height * post_patch_width
        assert uv_token_length + mv_token_length == hidden_states.shape[1]
        uv_hidden_states = hidden_states[:, mv_token_length:]
        mv_hidden_states = hidden_states[:, :mv_token_length]

        mv_hidden_states = (self.norm_out(mv_hidden_states.float()) * (1 + scale) + shift).type_as(mv_hidden_states)
        uv_hidden_states = (self.norm_out(uv_hidden_states.float()) * (1 + scale_uv) + shift_uv).type_as(uv_hidden_states)
        mv_hidden_states = self.proj_out(mv_hidden_states)
        uv_hidden_states = self.proj_out(uv_hidden_states)

        mv_hidden_states = mv_hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        mv_hidden_states = mv_hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        mv_output = mv_hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
        uv_hidden_states = uv_hidden_states.reshape(
            batch_size, post_uv_num_frames, post_uv_height, post_uv_width, p_t, p_h, p_w, -1
        )
        uv_hidden_states = uv_hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        uv_output = uv_hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        return ((mv_output, uv_output),)
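
# Minimal wrapping sketch (not part of the original file; the repo id and channel layout below are assumptions):
#
#     if __name__ == "__main__":
#         base = WanTransformer3DModel.from_pretrained(
#             "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
#         )
#         model = WanT2TexTransformer3DModel(base, use_normal=True, use_position=True)
#         # With use_normal and use_position both enabled, each latent stream carries
#         # 3 x in_channels planes (RGB latents + position-map latents + normal-map latents).
#         mv = torch.randn(1, 3 * base.config.in_channels, 4, 32, 32, dtype=torch.bfloat16)
#         uv = torch.randn(1, 3 * base.config.in_channels, 1, 64, 64, dtype=torch.bfloat16)
#         t = torch.randint(0, 1000, (1, 5))                  # one timestep per frame (4 MV + 1 UV)
#         text = torch.randn(1, 512, base.config.text_dim, dtype=torch.bfloat16)
#         ((mv_out, uv_out),) = model((mv, uv), t, text, use_qk_geometry=True)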