wcy1122 committed on
Commit de27e62 · 1 Parent(s): 3afda02

initial commit
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,220 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ # Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ # poetry.lock
+ # poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ # pdm.lock
+ # pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ # pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # Redis
+ *.rdb
+ *.aof
+ *.pid
+
+ # RabbitMQ
+ mnesia/
+ rabbitmq/
+ rabbitmq-data/
+
+ # ActiveMQ
+ activemq-data/
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ # .idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ # Streamlit
+ .streamlit/secrets.toml
+
+ tmp_script.sh
+ .gradio/
+ gradio_tmp
README.md CHANGED
@@ -1,14 +1,18 @@
- ---
- title: DreamOmni2
- emoji: 📚
- colorFrom: purple
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: Multimodal Instruction-based Editing and Generation
- ---
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DreamOmni2
+ This project is the official implementation of "DreamOmni2: Multimodal Instruction-based Editing and Generation".
+
+ ## Web Demo
+ ```
+ CUDA_VISIBLE_DEVICES=0 python web_edit.py \
+     --vlm_path PATH_TO_VLM \
+     --edit_lora_path PATH_TO_EDIT_LORA \
+     --server_name "0.0.0.0" \
+     --server_port 7860
+
+ CUDA_VISIBLE_DEVICES=1 python web_generate.py \
+     --vlm_path PATH_TO_VLM \
+     --gen_lora_path PATH_TO_GENERATION_LORA \
+     --server_name "0.0.0.0" \
+     --server_port 7861
+ ```
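The commands above refer to the standalone demo scripts; on this Space the same editing flow is driven by `app.py`, added below. The following is a minimal programmatic sketch of that flow, under a few assumptions: `diffusers`, `transformers`, and `huggingface_hub` are installed, the bundled `pipeline_flux_kontext.py` is importable, and the VLM prompt-rewriting step that `app.py` performs is skipped in favor of a hand-written prompt. Paths and the prompt are placeholders, not part of the original commit.

```python
# Minimal sketch mirroring the editing flow in app.py below; not an official CLI.
# Assumes diffusers, transformers, and huggingface_hub are installed and that
# pipeline_flux_kontext.py (added in this commit) is on the Python path.
import torch
from huggingface_hub import snapshot_download
from diffusers.utils import load_image

from pipeline_flux_kontext import FluxKontextPipeline

# Fetch only the edit LoRA from the model repo that app.py uses.
lora_dir = snapshot_download(repo_id="xiabs/DreamOmni2", allow_patterns=["edit_lora/**"])

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(lora_dir, adapter_name="edit")  # same call app.py makes
pipe.set_adapters(["edit"], adapter_weights=[1])

# Two inputs: the image to edit and a reference image (placeholder paths).
src = load_image("edit_tests/src.jpg")
ref = load_image("edit_tests/ref.jpg")

# app.py snaps inputs to PREFERRED_KONTEXT_RESOLUTIONS and rewrites the
# instruction with the bundled VLM first; both steps are omitted here.
result = pipe(
    images=[src, ref],
    prompt="Make the woman from the second image stand on the road in the first image.",
    height=src.height,
    width=src.width,
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
result.save("edit_result.png")
```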
app.py ADDED
@@ -0,0 +1,233 @@
+ import torch
+ from pipeline_flux_kontext import FluxKontextPipeline
+ from diffusers.utils import load_image
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import os
+ import re
+ from PIL import Image
+ import gradio as gr
+ import uuid
+ import argparse
+ import spaces  # added: required by the @spaces.GPU() decorators below
+ from huggingface_hub import snapshot_download  # added: used in _load_model_processor()
+
+
+ def _load_model_processor():
+
+     local_vlm_dir = snapshot_download(
+         repo_id="xiabs/DreamOmni2",
+         revision="main",
+         allow_patterns=["vlm-model/**"],
+     )
+     local_lora_dir = snapshot_download(
+         repo_id="xiabs/DreamOmni2",
+         revision="main",
+         allow_patterns=["edit_lora/**"],
+     )
+
+     print(f"Loading models from vlm_path: {local_vlm_dir}, edit_lora_path: {local_lora_dir}")
+     pipe = FluxKontextPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-Kontext-dev",
+         torch_dtype=torch.bfloat16
+     )
+     pipe.load_lora_weights(local_lora_dir, adapter_name="edit")
+     pipe.set_adapters(["edit"], adapter_weights=[1])
+
+     vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         local_vlm_dir,
+         torch_dtype="bfloat16",
+         device_map="cuda"
+     )
+     processor = AutoProcessor.from_pretrained(local_vlm_dir)
+     return vlm_model, processor, pipe
+
+
+ def _launch_demo(vlm_model, processor, pipe):
+
+     @spaces.GPU()
+     def infer_vlm(input_img_path, input_instruction, prefix):
+         if not vlm_model or not processor:
+             raise gr.Error("VLM Model not loaded. Cannot process prompt.")
+         tp = []
+         for path in input_img_path:
+             tp.append({"type": "image", "image": path})
+         tp.append({"type": "text", "text": input_instruction + prefix})
+         messages = [{"role": "user", "content": tp}]
+
+         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
+         inputs = inputs.to("cuda")
+
+         generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
+         generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+         output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         return output_text[0]
+
+     PREFERRED_KONTEXT_RESOLUTIONS = [
+         (672, 1568),
+         (688, 1504),
+         (720, 1456),
+         (752, 1392),
+         (800, 1328),
+         (832, 1248),
+         (880, 1184),
+         (944, 1104),
+         (1024, 1024),
+         (1104, 944),
+         (1184, 880),
+         (1248, 832),
+         (1328, 800),
+         (1392, 752),
+         (1456, 720),
+         (1504, 688),
+         (1568, 672),
+     ]
+
+     def find_closest_resolution(width, height, preferred_resolutions):
+         input_ratio = width / height
+         closest_resolution = min(
+             preferred_resolutions,
+             key=lambda res: abs((res[0] / res[1]) - input_ratio)
+         )
+         return closest_resolution
+
+     def extract_gen_content(text):
+         text = text[6:-7]
+         return text
+
+     @spaces.GPU()
+     def perform_edit(input_img_paths, input_instruction, output_path):
+         prefix = " It is editing task."
+         source_imgs = [load_image(path) for path in input_img_paths]
+         resized_imgs = []
+         for img in source_imgs:
+             target_resolution = find_closest_resolution(img.width, img.height, PREFERRED_KONTEXT_RESOLUTIONS)
+             resized_img = img.resize(target_resolution, Image.LANCZOS)
+             resized_imgs.append(resized_img)
+         prompt = infer_vlm(input_img_paths, input_instruction, prefix)
+         prompt = extract_gen_content(prompt)
+         print(f"Generated Prompt for VLM: {prompt}")
+
+         image = pipe(
+             images=resized_imgs,
+             height=resized_imgs[0].height,
+             width=resized_imgs[0].width,
+             prompt=prompt,
+             num_inference_steps=30,
+             guidance_scale=3.5,
+         ).images[0]
+         image.save(output_path)
+         print(f"Edit result saved to {output_path}")
+
+     def process_request(image_file_1, image_file_2, instruction):
+         # debugpy.listen(5678)
+         # print("Waiting for debugger attach...")
+         # debugpy.wait_for_client()
+         if not image_file_1 or not image_file_2:
+             raise gr.Error("Please upload both images.")
+         if not instruction:
+             raise gr.Error("Please provide an instruction.")
+         if not pipe or not vlm_model:
+             raise gr.Error("Models not loaded. Check the console for errors.")
+
+         output_path = f"/tmp/{uuid.uuid4()}.png"
+         input_img_paths = [image_file_1, image_file_2]  # file paths from the two gr.Image inputs (type="filepath")
+
+         perform_edit(input_img_paths, instruction, output_path)
+         return output_path
+
+     css = """
+     .text-center { text-align: center; }
+     .result-img img {
+         max-height: 60vh !important;
+         min-height: 30vh !important;
+         width: auto !important;
+         object-fit: contain;
+     }
+     .input-img img {
+         max-height: 30vh !important;
+         width: auto !important;
+         object-fit: contain;
+     }
+     """
+
+     with gr.Blocks(theme=gr.themes.Soft(), title="DreamOmni2", css=css) as demo:
+         gr.HTML(
+             """
+             <h1 style="text-align:center; font-size:48px; font-weight:bold; margin-bottom:20px;">
+                 DreamOmni2: Omni-purpose Image Generation and Editing
+             </h1>
+             """
+         )
+         gr.Markdown(
+             "Select a mode, upload two images, provide an instruction, and click 'Run'.",
+             elem_classes="text-center"
+         )
+         with gr.Row():
+             with gr.Column(scale=2):
+                 gr.Markdown("⬆️ Upload images. Click or drag to upload.")
+
+                 with gr.Row():
+                     image_uploader_1 = gr.Image(
+                         label="Img 1",
+                         type="filepath",
+                         interactive=True,
+                         elem_classes="input-img",
+                     )
+                     image_uploader_2 = gr.Image(
+                         label="Img 2",
+                         type="filepath",
+                         interactive=True,
+                         elem_classes="input-img",
+                     )
+
+                 instruction_text = gr.Textbox(
+                     label="Instruction",
+                     lines=2,
+                     placeholder="Input your instruction for generation or editing here...",
+                 )
+                 run_button = gr.Button("Run", variant="primary")
+
+             with gr.Column(scale=2):
+                 gr.Markdown(
+                     "✏️ **Editing Mode**: Modify an existing image using instructions and references.\n\n"
+                     "Tip: If the result is not what you expect, try clicking **Run** again. "
+                 )
+                 output_image = gr.Image(
+                     label="Result",
+                     type="filepath",
+                     elem_classes="result-img",
+                 )
+
+         # --- Examples (unchanged) ---
+         gr.Markdown("## Examples")
+
+         gr.Examples(
+             label="Editing Examples",
+             examples=[
+                 ["edit_tests/4/ref_0.jpg", "edit_tests/4/ref_1.jpg", "Replace the first image have the same image style as the second image.", "edit_tests/4/res.jpg"],
+                 ["edit_tests/5/ref_0.jpg", "edit_tests/5/ref_1.jpg", "Make the person in the first image have the same hairstyle as the person in the second image.", "edit_tests/5/res.jpg"],
+                 ["edit_tests/src.jpg", "edit_tests/ref.jpg", "Make the woman from the second image stand on the road in the first image.", "edit_tests/edi_res.png"],
+                 ["edit_tests/1/ref_0.jpg", "edit_tests/1/ref_1.jpg", "Replace the lantern in the first image with the dog in the second image.", "edit_tests/1/res.jpg"],
+                 ["edit_tests/2/ref_0.jpg", "edit_tests/2/ref_1.jpg", "Replace the suit in the first image with the clothes in the second image.", "edit_tests/2/res.jpg"],
+                 ["edit_tests/3/ref_0.jpg", "edit_tests/3/ref_1.jpg", "Make the first image has the same light condition as the second image.", "edit_tests/3/res.jpg"],
+                 ["edit_tests/6/ref_0.jpg", "edit_tests/6/ref_1.jpg", "Make the words in the first image have the same font as the words in the second image.", "edit_tests/6/res.jpg"],
+                 ["edit_tests/7/ref_0.jpg", "edit_tests/7/ref_1.jpg", "Make the car in the first image have the same pattern as the mouse in the second image.", "edit_tests/7/res.jpg"],
+                 ["edit_tests/8/ref_0.jpg", "edit_tests/8/ref_1.jpg", "Make the dress in the first image have the same pattern in the second image.", "edit_tests/8/res.jpg"],
+             ],
+             inputs=[image_uploader_1, image_uploader_2, instruction_text, output_image],
+             cache_examples=False,
+         )
+
+         run_button.click(
+             fn=process_request,
+             inputs=[image_uploader_1, image_uploader_2, instruction_text],
+             outputs=output_image
+         )
+
+     demo.launch()  # added: the original built `demo` but never started the server
+
+
+ if __name__ == "__main__":
+     vlm_model, processor, pipe = _load_model_processor()
+     print("Launching Gradio Demo...")
+     _launch_demo(vlm_model, processor, pipe)
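As a small, self-contained illustration of the resolution handling above: `perform_edit` snaps each input image to the preferred FLUX.1-Kontext resolution whose aspect ratio is closest to the input's before calling the pipeline. The helper below is copied from `app.py` (the same list also appears in `pipeline_flux_kontext.py`); the `__main__` lines are only a usage example.

```python
# Resolution bucketing as used by app.py's perform_edit: choose the preferred
# FLUX.1-Kontext resolution whose aspect ratio best matches the input image.
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

def find_closest_resolution(width, height, preferred_resolutions):
    input_ratio = width / height
    return min(preferred_resolutions, key=lambda res: abs((res[0] / res[1]) - input_ratio))

if __name__ == "__main__":
    print(find_closest_resolution(1920, 1080, PREFERRED_KONTEXT_RESOLUTIONS))  # (1392, 752)
    print(find_closest_resolution(1080, 1920, PREFERRED_KONTEXT_RESOLUTIONS))  # (752, 1392)
```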
edit_tests/1/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 22698b9eee36955254029d0d84a8c3c3e13f8cb12d19cd361087a294f34fc18e
  • Pointer size: 130 Bytes
  • Size of remote file: 99.9 kB
edit_tests/1/ref_1.jpg ADDED

Git LFS Details

  • SHA256: fde235d7896c5175d335831b3c4124161b8d909ebbe1bf708922005db04c345a
  • Pointer size: 130 Bytes
  • Size of remote file: 27 kB
edit_tests/1/res.jpg ADDED

Git LFS Details

  • SHA256: 8bf7f1fb69bd052478d63643b0a71cc38ca0fe032508568c20dfcc0beca20a39
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
edit_tests/2/ref_0.jpg ADDED

Git LFS Details

  • SHA256: d18bd9e3d7a15d9ed65660b491a6960ce1794fbcd20ebb9ec1c8a2277820f8f9
  • Pointer size: 130 Bytes
  • Size of remote file: 54.1 kB
edit_tests/2/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 1c60313358e445081aab723047ca9e36f7e60c1a7e40e58db9ae6c3b46b87120
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
edit_tests/2/res.jpg ADDED

Git LFS Details

  • SHA256: 0d40aceef14fca011243e808305f2c23693a6250cca1d6d9060d26b83cce2f2a
  • Pointer size: 130 Bytes
  • Size of remote file: 61.9 kB
edit_tests/3/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 77d1b1d0f8d9177bcc44151384b41ee6fb3e4a6408049181d56bf261fa5892e9
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
edit_tests/3/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 3de52fbd95ff89eb97781eb88f85a897a39d516d41218bb8fc2a86cf2f005e1b
  • Pointer size: 130 Bytes
  • Size of remote file: 60.3 kB
edit_tests/3/res.jpg ADDED

Git LFS Details

  • SHA256: 61c6cad1b69a0d7665390bfe4b48dc6639d214bf49084e4f2df6795f8f091f35
  • Pointer size: 130 Bytes
  • Size of remote file: 47.8 kB
edit_tests/4/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 17df467e2a56748929f7bf5cdbd3b3f41c3fe3e504e07eef183ac3c7af8f64d7
  • Pointer size: 130 Bytes
  • Size of remote file: 87.6 kB
edit_tests/4/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 210c93974b8a9216320abe8ff34c03c74004cb8fafc8c691bb80291660559215
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
edit_tests/4/res.jpg ADDED

Git LFS Details

  • SHA256: f0db24e49a7bb341eec00388110a4af08aefbb5ab560dfbcbd30e08c210bfd55
  • Pointer size: 130 Bytes
  • Size of remote file: 64.8 kB
edit_tests/5/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 3421c8b427b5a8376bf16672ebcfe64f942a5a91fe6f5e579eb8e61bcfa10367
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
edit_tests/5/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 5bb7ab3bcdc17452fef64c311e0b4a56dcfe89ff37a7d64cfc63e766ce6baaf5
  • Pointer size: 130 Bytes
  • Size of remote file: 84 kB
edit_tests/5/res.jpg ADDED

Git LFS Details

  • SHA256: 2fcd5e9bda10daab54e6802774b810ba1956c8f0dc8dd4f0ee3b4d42c228005e
  • Pointer size: 130 Bytes
  • Size of remote file: 56.7 kB
edit_tests/6/ref_0.jpg ADDED

Git LFS Details

  • SHA256: b7dd6574ec93c156a31b152f557921002a0d4a33add61f905cb86d03f06b7f5e
  • Pointer size: 130 Bytes
  • Size of remote file: 57.1 kB
edit_tests/6/ref_1.jpg ADDED

Git LFS Details

  • SHA256: fd1d25c842ea30975b197e864338862efb6b81439da6929995fb2142f6adc49c
  • Pointer size: 130 Bytes
  • Size of remote file: 68.5 kB
edit_tests/6/res.jpg ADDED

Git LFS Details

  • SHA256: 04b49a1f3eab77008c4f31064599a8b96e8bcd9b8acd2e4687058e09a0962413
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
edit_tests/7/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 937bab58713aa2610839b6fe57c1b642aad23938dd769362739c6178a734bc5e
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
edit_tests/7/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 9a74d1289a992867620747bcb48c59c6c598100b319d8fe910a03519f3e7e370
  • Pointer size: 130 Bytes
  • Size of remote file: 86.6 kB
edit_tests/7/res.jpg ADDED

Git LFS Details

  • SHA256: 2ae24e25ef7606da785e2158abbbddae9d8077a1cbecf6c9dbf787675235ef7a
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
edit_tests/8/ref_0.jpg ADDED

Git LFS Details

  • SHA256: 5d6d515d3eb11732c4de3d7989b6315a93b2f754be937b9b7631e2c22043bf61
  • Pointer size: 130 Bytes
  • Size of remote file: 78.3 kB
edit_tests/8/ref_1.jpg ADDED

Git LFS Details

  • SHA256: 810b2e27db7c7965d90bd76d83fc6f088862f9034fdb407972bd49c577845b57
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
edit_tests/8/res.jpg ADDED

Git LFS Details

  • SHA256: 5adc6f7c8f1c5061909d367f5248be99117ae0f81164516c927a4f2a843ee691
  • Pointer size: 130 Bytes
  • Size of remote file: 47.1 kB
edit_tests/edi_res.png ADDED

Git LFS Details

  • SHA256: 3e5352a19b82623523b73e66a083689e3eb0b8fa738445892ea8af0ab733ed11
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB
edit_tests/ref.jpg ADDED

Git LFS Details

  • SHA256: 7e12827dd85e3b2bf39d11ad27121a09655deb60a500de469569335a3a60566b
  • Pointer size: 130 Bytes
  • Size of remote file: 70.2 kB
edit_tests/src.jpg ADDED

Git LFS Details

  • SHA256: 8b7231bdbf219b6313e95effbce355d11e2789c3a9fd9251d8723cd9a9900624
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
gen_tests/gen_res.png ADDED

Git LFS Details

  • SHA256: 649f4c45658120fffdaac58478d0b88f211a365f4aac9fc678aa7ba82d4da371
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
gen_tests/img1.jpg ADDED

Git LFS Details

  • SHA256: 2bfd72e3aa607a5cf05e3b52c30743921ff8862cb7dda26d14cef8c8180b4242
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
gen_tests/img2.jpg ADDED

Git LFS Details

  • SHA256: 82b414fafd19c3ec7763281e359c8863b72e5bf36561920c94a516539a013a14
  • Pointer size: 130 Bytes
  • Size of remote file: 92.3 kB
pipeline_flux_kontext.py ADDED
@@ -0,0 +1,1151 @@
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import FluxKontextPipeline
61
+ >>> from diffusers.utils import load_image
62
+
63
+ >>> pipe = FluxKontextPipeline.from_pretrained(
64
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
65
+ ... )
66
+ >>> pipe.to("cuda")
67
+
68
+ >>> image = load_image(
69
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
70
+ ... ).convert("RGB")
71
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
72
+ >>> image = pipe(
73
+ ... image=image,
74
+ ... prompt=prompt,
75
+ ... guidance_scale=2.5,
76
+ ... generator=torch.Generator().manual_seed(42),
77
+ ... ).images[0]
78
+ >>> image.save("output.png")
79
+ ```
80
+ """
81
+
82
+ PREFERRED_KONTEXT_RESOLUTIONS = [
83
+ (672, 1568),
84
+ (688, 1504),
85
+ (720, 1456),
86
+ (752, 1392),
87
+ (800, 1328),
88
+ (832, 1248),
89
+ (880, 1184),
90
+ (944, 1104),
91
+ (1024, 1024),
92
+ (1104, 944),
93
+ (1184, 880),
94
+ (1248, 832),
95
+ (1328, 800),
96
+ (1392, 752),
97
+ (1456, 720),
98
+ (1504, 688),
99
+ (1568, 672),
100
+ ]
101
+
102
+
103
+ def calculate_shift(
104
+ image_seq_len,
105
+ base_seq_len: int = 256,
106
+ max_seq_len: int = 4096,
107
+ base_shift: float = 0.5,
108
+ max_shift: float = 1.15,
109
+ ):
110
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
111
+ b = base_shift - m * base_seq_len
112
+ mu = image_seq_len * m + b
113
+ return mu
114
+
115
+
116
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
117
+ def retrieve_timesteps(
118
+ scheduler,
119
+ num_inference_steps: Optional[int] = None,
120
+ device: Optional[Union[str, torch.device]] = None,
121
+ timesteps: Optional[List[int]] = None,
122
+ sigmas: Optional[List[float]] = None,
123
+ **kwargs,
124
+ ):
125
+ r"""
126
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
127
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
128
+
129
+ Args:
130
+ scheduler (`SchedulerMixin`):
131
+ The scheduler to get timesteps from.
132
+ num_inference_steps (`int`):
133
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
134
+ must be `None`.
135
+ device (`str` or `torch.device`, *optional*):
136
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
137
+ timesteps (`List[int]`, *optional*):
138
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
139
+ `num_inference_steps` and `sigmas` must be `None`.
140
+ sigmas (`List[float]`, *optional*):
141
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
142
+ `num_inference_steps` and `timesteps` must be `None`.
143
+
144
+ Returns:
145
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
146
+ second element is the number of inference steps.
147
+ """
148
+ if timesteps is not None and sigmas is not None:
149
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
150
+ if timesteps is not None:
151
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
152
+ if not accepts_timesteps:
153
+ raise ValueError(
154
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
155
+ f" timestep schedules. Please check whether you are using the correct scheduler."
156
+ )
157
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
158
+ timesteps = scheduler.timesteps
159
+ num_inference_steps = len(timesteps)
160
+ elif sigmas is not None:
161
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
162
+ if not accept_sigmas:
163
+ raise ValueError(
164
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
165
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
166
+ )
167
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
168
+ timesteps = scheduler.timesteps
169
+ num_inference_steps = len(timesteps)
170
+ else:
171
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
172
+ timesteps = scheduler.timesteps
173
+ return timesteps, num_inference_steps
174
+
175
+
176
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
177
+ def retrieve_latents(
178
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
179
+ ):
180
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
181
+ return encoder_output.latent_dist.sample(generator)
182
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
183
+ return encoder_output.latent_dist.mode()
184
+ elif hasattr(encoder_output, "latents"):
185
+ return encoder_output.latents
186
+ else:
187
+ raise AttributeError("Could not access latents of provided encoder_output")
188
+
189
+
190
+ class FluxKontextPipeline(
191
+ DiffusionPipeline,
192
+ FluxLoraLoaderMixin,
193
+ FromSingleFileMixin,
194
+ TextualInversionLoaderMixin,
195
+ FluxIPAdapterMixin,
196
+ ):
197
+ r"""
198
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
199
+
200
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
201
+
202
+ Args:
203
+ transformer ([`FluxTransformer2DModel`]):
204
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
205
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
206
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
207
+ vae ([`AutoencoderKL`]):
208
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
209
+ text_encoder ([`CLIPTextModel`]):
210
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
211
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
212
+ text_encoder_2 ([`T5EncoderModel`]):
213
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
214
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
215
+ tokenizer (`CLIPTokenizer`):
216
+ Tokenizer of class
217
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
218
+ tokenizer_2 (`T5TokenizerFast`):
219
+ Second Tokenizer of class
220
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
221
+ """
222
+
223
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
224
+ _optional_components = ["image_encoder", "feature_extractor"]
225
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
226
+
227
+ def __init__(
228
+ self,
229
+ scheduler: FlowMatchEulerDiscreteScheduler,
230
+ vae: AutoencoderKL,
231
+ text_encoder: CLIPTextModel,
232
+ tokenizer: CLIPTokenizer,
233
+ text_encoder_2: T5EncoderModel,
234
+ tokenizer_2: T5TokenizerFast,
235
+ transformer: FluxTransformer2DModel,
236
+ image_encoder: CLIPVisionModelWithProjection = None,
237
+ feature_extractor: CLIPImageProcessor = None,
238
+ ):
239
+ super().__init__()
240
+
241
+ self.register_modules(
242
+ vae=vae,
243
+ text_encoder=text_encoder,
244
+ text_encoder_2=text_encoder_2,
245
+ tokenizer=tokenizer,
246
+ tokenizer_2=tokenizer_2,
247
+ transformer=transformer,
248
+ scheduler=scheduler,
249
+ image_encoder=image_encoder,
250
+ feature_extractor=feature_extractor,
251
+ )
252
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
253
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
254
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
255
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
256
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
257
+ self.tokenizer_max_length = (
258
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
259
+ )
260
+ self.default_sample_size = 128
261
+
262
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
263
+ def _get_t5_prompt_embeds(
264
+ self,
265
+ prompt: Union[str, List[str]] = None,
266
+ num_images_per_prompt: int = 1,
267
+ max_sequence_length: int = 512,
268
+ device: Optional[torch.device] = None,
269
+ dtype: Optional[torch.dtype] = None,
270
+ ):
271
+ device = device or self._execution_device
272
+ dtype = dtype or self.text_encoder.dtype
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ batch_size = len(prompt)
276
+
277
+ if isinstance(self, TextualInversionLoaderMixin):
278
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
279
+
280
+ text_inputs = self.tokenizer_2(
281
+ prompt,
282
+ padding="max_length",
283
+ max_length=max_sequence_length,
284
+ truncation=True,
285
+ return_length=False,
286
+ return_overflowing_tokens=False,
287
+ return_tensors="pt",
288
+ )
289
+ text_input_ids = text_inputs.input_ids
290
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
291
+
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
294
+ logger.warning(
295
+ "The following part of your input was truncated because `max_sequence_length` is set to "
296
+ f" {max_sequence_length} tokens: {removed_text}"
297
+ )
298
+
299
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
300
+
301
+ dtype = self.text_encoder_2.dtype
302
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
303
+
304
+ _, seq_len, _ = prompt_embeds.shape
305
+
306
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
307
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
308
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
309
+
310
+ return prompt_embeds
311
+
312
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
313
+ def _get_clip_prompt_embeds(
314
+ self,
315
+ prompt: Union[str, List[str]],
316
+ num_images_per_prompt: int = 1,
317
+ device: Optional[torch.device] = None,
318
+ ):
319
+ device = device or self._execution_device
320
+
321
+ prompt = [prompt] if isinstance(prompt, str) else prompt
322
+ batch_size = len(prompt)
323
+
324
+ if isinstance(self, TextualInversionLoaderMixin):
325
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
326
+
327
+ text_inputs = self.tokenizer(
328
+ prompt,
329
+ padding="max_length",
330
+ max_length=self.tokenizer_max_length,
331
+ truncation=True,
332
+ return_overflowing_tokens=False,
333
+ return_length=False,
334
+ return_tensors="pt",
335
+ )
336
+
337
+ text_input_ids = text_inputs.input_ids
338
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
339
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
340
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
341
+ logger.warning(
342
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
343
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
344
+ )
345
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
346
+
347
+ # Use pooled output of CLIPTextModel
348
+ prompt_embeds = prompt_embeds.pooler_output
349
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
350
+
351
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
352
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
353
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
354
+
355
+ return prompt_embeds
356
+
357
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
358
+ def encode_prompt(
359
+ self,
360
+ prompt: Union[str, List[str]],
361
+ prompt_2: Union[str, List[str]],
362
+ device: Optional[torch.device] = None,
363
+ num_images_per_prompt: int = 1,
364
+ prompt_embeds: Optional[torch.FloatTensor] = None,
365
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
366
+ max_sequence_length: int = 512,
367
+ lora_scale: Optional[float] = None,
368
+ ):
369
+ r"""
370
+
371
+ Args:
372
+ prompt (`str` or `List[str]`, *optional*):
373
+ prompt to be encoded
374
+ prompt_2 (`str` or `List[str]`, *optional*):
375
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
376
+ used in all text-encoders
377
+ device: (`torch.device`):
378
+ torch device
379
+ num_images_per_prompt (`int`):
380
+ number of images that should be generated per prompt
381
+ prompt_embeds (`torch.FloatTensor`, *optional*):
382
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
383
+ provided, text embeddings will be generated from `prompt` input argument.
384
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
385
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
386
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
387
+ lora_scale (`float`, *optional*):
388
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
389
+ """
390
+ device = device or self._execution_device
391
+
392
+ # set lora scale so that monkey patched LoRA
393
+ # function of text encoder can correctly access it
394
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
395
+ self._lora_scale = lora_scale
396
+
397
+ # dynamically adjust the LoRA scale
398
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
399
+ scale_lora_layers(self.text_encoder, lora_scale)
400
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
401
+ scale_lora_layers(self.text_encoder_2, lora_scale)
402
+
403
+ prompt = [prompt] if isinstance(prompt, str) else prompt
404
+
405
+ if prompt_embeds is None:
406
+ prompt_2 = prompt_2 or prompt
407
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
408
+
409
+ # We only use the pooled prompt output from the CLIPTextModel
410
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
411
+ prompt=prompt,
412
+ device=device,
413
+ num_images_per_prompt=num_images_per_prompt,
414
+ )
415
+ prompt_embeds = self._get_t5_prompt_embeds(
416
+ prompt=prompt_2,
417
+ num_images_per_prompt=num_images_per_prompt,
418
+ max_sequence_length=max_sequence_length,
419
+ device=device,
420
+ )
421
+
422
+ if self.text_encoder is not None:
423
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
424
+ # Retrieve the original scale by scaling back the LoRA layers
425
+ unscale_lora_layers(self.text_encoder, lora_scale)
426
+
427
+ if self.text_encoder_2 is not None:
428
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
429
+ # Retrieve the original scale by scaling back the LoRA layers
430
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
431
+
432
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
433
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
434
+
435
+ return prompt_embeds, pooled_prompt_embeds, text_ids
436
+
437
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
438
+ def encode_image(self, image, device, num_images_per_prompt):
439
+ dtype = next(self.image_encoder.parameters()).dtype
440
+
441
+ if not isinstance(image, torch.Tensor):
442
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
443
+
444
+ image = image.to(device=device, dtype=dtype)
445
+ image_embeds = self.image_encoder(image).image_embeds
446
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
447
+ return image_embeds
448
+
449
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
450
+ def prepare_ip_adapter_image_embeds(
451
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
452
+ ):
453
+ image_embeds = []
454
+ if ip_adapter_image_embeds is None:
455
+ if not isinstance(ip_adapter_image, list):
456
+ ip_adapter_image = [ip_adapter_image]
457
+
458
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
459
+ raise ValueError(
460
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
461
+ )
462
+
463
+ for single_ip_adapter_image in ip_adapter_image:
464
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
465
+ image_embeds.append(single_image_embeds[None, :])
466
+ else:
467
+ if not isinstance(ip_adapter_image_embeds, list):
468
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
469
+
470
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
471
+ raise ValueError(
472
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
473
+ )
474
+
475
+ for single_image_embeds in ip_adapter_image_embeds:
476
+ image_embeds.append(single_image_embeds)
477
+
478
+ ip_adapter_image_embeds = []
479
+ for single_image_embeds in image_embeds:
480
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
481
+ single_image_embeds = single_image_embeds.to(device=device)
482
+ ip_adapter_image_embeds.append(single_image_embeds)
483
+
484
+ return ip_adapter_image_embeds
485
+
486
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
487
+ def check_inputs(
488
+ self,
489
+ prompt,
490
+ prompt_2,
491
+ height,
492
+ width,
493
+ negative_prompt=None,
494
+ negative_prompt_2=None,
495
+ prompt_embeds=None,
496
+ negative_prompt_embeds=None,
497
+ pooled_prompt_embeds=None,
498
+ negative_pooled_prompt_embeds=None,
499
+ callback_on_step_end_tensor_inputs=None,
500
+ max_sequence_length=None,
501
+ ):
502
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
503
+ logger.warning(
504
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
505
+ )
506
+
507
+ if callback_on_step_end_tensor_inputs is not None and not all(
508
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
509
+ ):
510
+ raise ValueError(
511
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
512
+ )
513
+
514
+ if prompt is not None and prompt_embeds is not None:
515
+ raise ValueError(
516
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
517
+ " only forward one of the two."
518
+ )
519
+ elif prompt_2 is not None and prompt_embeds is not None:
520
+ raise ValueError(
521
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
522
+ " only forward one of the two."
523
+ )
524
+ elif prompt is None and prompt_embeds is None:
525
+ raise ValueError(
526
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
527
+ )
528
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
529
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
530
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
531
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
532
+
533
+ if negative_prompt is not None and negative_prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
536
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
537
+ )
538
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
539
+ raise ValueError(
540
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
541
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
542
+ )
543
+
544
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
545
+ raise ValueError(
546
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
547
+ )
548
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
549
+ raise ValueError(
550
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
551
+ )
552
+
553
+ if max_sequence_length is not None and max_sequence_length > 512:
554
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
555
+
556
+ @staticmethod
557
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
558
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
559
+ latent_image_ids = torch.zeros(height, width, 3)
560
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
561
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
562
+
563
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
564
+
565
+ latent_image_ids = latent_image_ids.reshape(
566
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
567
+ )
568
+
569
+ return latent_image_ids.to(device=device, dtype=dtype)
570
+
571
+ @staticmethod
572
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
573
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
574
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
575
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
576
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
577
+
578
+ return latents
579
+
580
+ @staticmethod
581
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
582
+ def _unpack_latents(latents, height, width, vae_scale_factor):
583
+ batch_size, num_patches, channels = latents.shape
584
+
585
+ # VAE applies 8x compression on images but we must also account for packing which requires
586
+ # latent height and width to be divisible by 2.
587
+ height = 2 * (int(height) // (vae_scale_factor * 2))
588
+ width = 2 * (int(width) // (vae_scale_factor * 2))
589
+
590
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
591
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
592
+
593
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
594
+
595
+ return latents
596
+
597
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
598
+ if isinstance(generator, list):
599
+ image_latents = [
600
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
601
+ for i in range(image.shape[0])
602
+ ]
603
+ image_latents = torch.cat(image_latents, dim=0)
604
+ else:
605
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
606
+
607
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
608
+
609
+ return image_latents
610
+
611
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
612
+ def enable_vae_slicing(self):
613
+ r"""
614
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
615
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
616
+ """
617
+ self.vae.enable_slicing()
618
+
619
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
620
+ def disable_vae_slicing(self):
621
+ r"""
622
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
623
+ computing decoding in one step.
624
+ """
625
+ self.vae.disable_slicing()
626
+
627
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
628
+ def enable_vae_tiling(self):
629
+ r"""
630
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
631
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
632
+ processing larger images.
633
+ """
634
+ self.vae.enable_tiling()
635
+
636
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
637
+ def disable_vae_tiling(self):
638
+ r"""
639
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
640
+ computing decoding in one step.
641
+ """
642
+ self.vae.disable_tiling()
643
+
644
+ def prepare_latents(
645
+ self,
646
+ images: Optional[torch.Tensor],
647
+ batch_size: int,
648
+ num_channels_latents: int,
649
+ height: int,
650
+ width: int,
651
+ dtype: torch.dtype,
652
+ device: torch.device,
653
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
654
+ latents: Optional[torch.Tensor] = None,
655
+ ):
656
+ if isinstance(generator, list) and len(generator) != batch_size:
657
+ raise ValueError(
658
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
659
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
660
+ )
661
+
662
+ # VAE applies 8x compression on images but we must also account for packing which requires
663
+ # latent height and width to be divisible by 2.
664
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
665
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
666
+ shape = (batch_size, num_channels_latents, height, width)
667
+ h_offset = 0
668
+ w_offset = 0
669
+ image_latents = image_ids = None
670
+ if images is not None:
671
+ tp_image_latents = []
672
+ tp_image_ids = []
673
+ for i, image in enumerate(images):
674
+ image = image.to(device=device, dtype=dtype)
675
+ if image.shape[1] != self.latent_channels:
676
+ image_latents = self._encode_vae_image(image=image, generator=generator)
677
+ else:
678
+ image_latents = image
679
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
680
+ # expand init_latents for batch_size
681
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
682
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
683
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
684
+ raise ValueError(
685
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
686
+ )
687
+ else:
688
+ image_latents = torch.cat([image_latents], dim=0)
689
+
690
+ image_latent_height, image_latent_width = image_latents.shape[2:]
691
+ image_latents = self._pack_latents(
692
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
693
+ )
694
+ image_ids = self._prepare_latent_image_ids(
695
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
696
+ )
697
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
698
+ # image_ids[..., 0] = 0.9+i*0.1
699
+ image_ids[..., 0] = i+1
700
+ # image_ids[..., 1] += h_offset
701
+ image_ids[..., 2] += w_offset
702
+ tp_image_latents.append(image_latents)
703
+ tp_image_ids.append(image_ids)
704
+ h_offset += image_latent_height //2
705
+ w_offset += image_latent_width //2
706
+ image_latents = torch.cat(tp_image_latents, dim=1)
707
+ image_ids = torch.cat(tp_image_ids, dim=0)
708
+
709
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
710
+
711
+ if latents is None:
712
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
713
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
714
+ else:
715
+ latents = latents.to(device=device, dtype=dtype)
716
+
717
+ return latents, image_latents, latent_ids, image_ids
718
+
719
+ @property
720
+ def guidance_scale(self):
721
+ return self._guidance_scale
722
+
723
+ @property
724
+ def joint_attention_kwargs(self):
725
+ return self._joint_attention_kwargs
726
+
727
+ @property
728
+ def num_timesteps(self):
729
+ return self._num_timesteps
730
+
731
+ @property
732
+ def current_timestep(self):
733
+ return self._current_timestep
734
+
735
+ @property
736
+ def interrupt(self):
737
+ return self._interrupt
738
+
739
+ @torch.no_grad()
740
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
741
+ def __call__(
742
+ self,
743
+ images: Optional[List[PipelineImageInput]] = None,
744
+ prompt: Union[str, List[str]] = None,
745
+ prompt_2: Optional[Union[str, List[str]]] = None,
746
+ negative_prompt: Union[str, List[str]] = None,
747
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
748
+ true_cfg_scale: float = 1.0,
749
+ height: Optional[int] = None,
750
+ width: Optional[int] = None,
751
+ num_inference_steps: int = 28,
752
+ sigmas: Optional[List[float]] = None,
753
+ guidance_scale: float = 3.5,
754
+ num_images_per_prompt: Optional[int] = 1,
755
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
756
+ latents: Optional[torch.FloatTensor] = None,
757
+ prompt_embeds: Optional[torch.FloatTensor] = None,
758
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
759
+ ip_adapter_image: Optional[PipelineImageInput] = None,
760
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
761
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
762
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
763
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
764
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
765
+ output_type: Optional[str] = "pil",
766
+ return_dict: bool = True,
767
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
768
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
769
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
770
+ max_sequence_length: int = 512,
771
+ max_area: int = 1024**2,
772
+ _auto_resize: bool = True,
773
+ ):
774
+ r"""
775
+ Function invoked when calling the pipeline for generation.
776
+
777
+ Args:
778
+ images (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
779
+ The reference image(s), as PIL images, numpy arrays or tensors, used to condition the generation. For both
780
+ numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list
781
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
782
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. Image latents can also be
783
+ passed as `images`; latents passed directly are not encoded again.
784
+ prompt (`str` or `List[str]`, *optional*):
785
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
786
+ instead.
787
+ prompt_2 (`str` or `List[str]`, *optional*):
788
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
789
+ will be used instead.
790
+ negative_prompt (`str` or `List[str]`, *optional*):
791
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
792
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
793
+ not greater than `1`).
794
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
795
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
796
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
797
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
798
+ When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
799
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
800
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
801
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
802
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
803
+ num_inference_steps (`int`, *optional*, defaults to 28):
804
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
805
+ expense of slower inference.
806
+ sigmas (`List[float]`, *optional*):
807
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
808
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
809
+ will be used.
810
+ guidance_scale (`float`, *optional*, defaults to 3.5):
811
+ Guidance scale as defined in [Classifier-Free Diffusion
812
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
813
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
814
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
815
+ the text `prompt`, usually at the expense of lower image quality.
816
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
817
+ The number of images to generate per prompt.
818
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
819
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
820
+ to make generation deterministic.
821
+ latents (`torch.FloatTensor`, *optional*):
822
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
823
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
824
+ tensor will be generated by sampling using the supplied random `generator`.
825
+ prompt_embeds (`torch.FloatTensor`, *optional*):
826
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
827
+ provided, text embeddings will be generated from `prompt` input argument.
828
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
829
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
830
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
831
+ ip_adapter_image (`PipelineImageInput`, *optional*):
832
+ Optional image input to work with IP Adapters.
833
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
834
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
835
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
836
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
837
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
838
+ Optional image input to work with IP Adapters.
839
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
840
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
841
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
842
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
843
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
844
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
845
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
846
+ argument.
847
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
848
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
849
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
850
+ input argument.
851
+ output_type (`str`, *optional*, defaults to `"pil"`):
852
+ The output format of the generated image. Choose between
853
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
856
+ joint_attention_kwargs (`dict`, *optional*):
857
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
858
+ `self.processor` in
859
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
860
+ callback_on_step_end (`Callable`, *optional*):
861
+ A function that is called at the end of each denoising step during inference. The function is called
862
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
863
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
864
+ `callback_on_step_end_tensor_inputs`.
865
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
866
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
867
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
868
+ `._callback_tensor_inputs` attribute of your pipeline class.
869
+ max_sequence_length (`int` defaults to 512):
870
+ Maximum sequence length to use with the `prompt`.
871
+ max_area (`int`, defaults to `1024 ** 2`):
872
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
873
+ area while maintaining the aspect ratio.
874
+
875
+ Examples:
876
+
877
+ Returns:
878
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
879
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
880
+ images.
881
+ """
882
+
883
+ height = height or self.default_sample_size * self.vae_scale_factor
884
+ width = width or self.default_sample_size * self.vae_scale_factor
885
+
886
+ original_height, original_width = height, width
887
+ aspect_ratio = width / height
888
+ width = round((max_area * aspect_ratio) ** 0.5)
889
+ height = round((max_area / aspect_ratio) ** 0.5)
890
+
891
+ multiple_of = self.vae_scale_factor * 2
892
+ width = width // multiple_of * multiple_of
893
+ height = height // multiple_of * multiple_of
894
+
895
+ if height != original_height or width != original_width:
896
+ logger.warning(
897
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
898
+ )
899
+
900
+ # 1. Check inputs. Raise error if not correct
901
+ self.check_inputs(
902
+ prompt,
903
+ prompt_2,
904
+ height,
905
+ width,
906
+ negative_prompt=negative_prompt,
907
+ negative_prompt_2=negative_prompt_2,
908
+ prompt_embeds=prompt_embeds,
909
+ negative_prompt_embeds=negative_prompt_embeds,
910
+ pooled_prompt_embeds=pooled_prompt_embeds,
911
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
912
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
913
+ max_sequence_length=max_sequence_length,
914
+ )
915
+
916
+ self._guidance_scale = guidance_scale
917
+ self._joint_attention_kwargs = joint_attention_kwargs
918
+ self._current_timestep = None
919
+ self._interrupt = False
920
+
921
+ # 2. Define call parameters
922
+ if prompt is not None and isinstance(prompt, str):
923
+ batch_size = 1
924
+ elif prompt is not None and isinstance(prompt, list):
925
+ batch_size = len(prompt)
926
+ else:
927
+ batch_size = prompt_embeds.shape[0]
928
+
929
+ device = self._execution_device
930
+
931
+ lora_scale = (
932
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
933
+ )
934
+ has_neg_prompt = negative_prompt is not None or (
935
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
936
+ )
937
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
938
+ (
939
+ prompt_embeds,
940
+ pooled_prompt_embeds,
941
+ text_ids,
942
+ ) = self.encode_prompt(
943
+ prompt=prompt,
944
+ prompt_2=prompt_2,
945
+ prompt_embeds=prompt_embeds,
946
+ pooled_prompt_embeds=pooled_prompt_embeds,
947
+ device=device,
948
+ num_images_per_prompt=num_images_per_prompt,
949
+ max_sequence_length=max_sequence_length,
950
+ lora_scale=lora_scale,
951
+ )
952
+ if do_true_cfg:
953
+ (
954
+ negative_prompt_embeds,
955
+ negative_pooled_prompt_embeds,
956
+ negative_text_ids,
957
+ ) = self.encode_prompt(
958
+ prompt=negative_prompt,
959
+ prompt_2=negative_prompt_2,
960
+ prompt_embeds=negative_prompt_embeds,
961
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
962
+ device=device,
963
+ num_images_per_prompt=num_images_per_prompt,
964
+ max_sequence_length=max_sequence_length,
965
+ lora_scale=lora_scale,
966
+ )
967
+
968
+ # 3. Preprocess image
969
+ if images is not None and not (isinstance(images[0], torch.Tensor) and images[0].size(1) == self.latent_channels):
970
+ tp_images = []
971
+ for img in images:
972
+ image = img
973
+ image_height, image_width = self.image_processor.get_default_height_width(img)
974
+ aspect_ratio = image_width / image_height
975
+ if _auto_resize:
976
+ # Kontext is trained on specific resolutions, using one of them is recommended
977
+ _, image_width, image_height = min(
978
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
979
+ )
980
+ image_width = image_width // multiple_of * multiple_of
981
+ image_height = image_height // multiple_of * multiple_of
982
+ image = self.image_processor.resize(image, image_height, image_width)
983
+ image = self.image_processor.preprocess(image, image_height, image_width)
984
+ tp_images.append(image)
985
+ images = tp_images
986
+
987
+ # 4. Prepare latent variables
988
+ num_channels_latents = self.transformer.config.in_channels // 4
989
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
990
+ images,
991
+ batch_size * num_images_per_prompt,
992
+ num_channels_latents,
993
+ height,
994
+ width,
995
+ prompt_embeds.dtype,
996
+ device,
997
+ generator,
998
+ latents,
999
+ )
1000
+ if image_ids is not None:
1001
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
1002
+
1003
+ # 5. Prepare timesteps
1004
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1005
+ image_seq_len = latents.shape[1]
1006
+ mu = calculate_shift(
1007
+ image_seq_len,
1008
+ self.scheduler.config.get("base_image_seq_len", 256),
1009
+ self.scheduler.config.get("max_image_seq_len", 4096),
1010
+ self.scheduler.config.get("base_shift", 0.5),
1011
+ self.scheduler.config.get("max_shift", 1.15),
1012
+ )
1013
+ timesteps, num_inference_steps = retrieve_timesteps(
1014
+ self.scheduler,
1015
+ num_inference_steps,
1016
+ device,
1017
+ sigmas=sigmas,
1018
+ mu=mu,
1019
+ )
1020
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1021
+ self._num_timesteps = len(timesteps)
1022
+
1023
+ # handle guidance
1024
+ if self.transformer.config.guidance_embeds:
1025
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1026
+ guidance = guidance.expand(latents.shape[0])
1027
+ else:
1028
+ guidance = None
1029
+
1030
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1031
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1032
+ ):
1033
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1034
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1035
+
1036
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1037
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1038
+ ):
1039
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1040
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1041
+
1042
+ if self.joint_attention_kwargs is None:
1043
+ self._joint_attention_kwargs = {}
1044
+
1045
+ image_embeds = None
1046
+ negative_image_embeds = None
1047
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1048
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1049
+ ip_adapter_image,
1050
+ ip_adapter_image_embeds,
1051
+ device,
1052
+ batch_size * num_images_per_prompt,
1053
+ )
1054
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1055
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1056
+ negative_ip_adapter_image,
1057
+ negative_ip_adapter_image_embeds,
1058
+ device,
1059
+ batch_size * num_images_per_prompt,
1060
+ )
1061
+
1062
+ # 6. Denoising loop
1063
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
1064
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
1065
+ self.scheduler.set_begin_index(0)
1066
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1067
+ for i, t in enumerate(timesteps):
1068
+ if self.interrupt:
1069
+ continue
1070
+
1071
+ self._current_timestep = t
1072
+ if image_embeds is not None:
1073
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1074
+
1075
+ latent_model_input = latents
1076
+ if image_latents is not None:
1077
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
1078
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1079
+
1080
+ noise_pred = self.transformer(
1081
+ hidden_states=latent_model_input,
1082
+ timestep=timestep / 1000,
1083
+ guidance=guidance,
1084
+ pooled_projections=pooled_prompt_embeds,
1085
+ encoder_hidden_states=prompt_embeds,
1086
+ txt_ids=text_ids,
1087
+ img_ids=latent_ids,
1088
+ joint_attention_kwargs=self.joint_attention_kwargs,
1089
+ return_dict=False,
1090
+ )[0]
1091
+ noise_pred = noise_pred[:, : latents.size(1)]
1092
+
1093
+ if do_true_cfg:
1094
+ if negative_image_embeds is not None:
1095
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1096
+ neg_noise_pred = self.transformer(
1097
+ hidden_states=latent_model_input,
1098
+ timestep=timestep / 1000,
1099
+ guidance=guidance,
1100
+ pooled_projections=negative_pooled_prompt_embeds,
1101
+ encoder_hidden_states=negative_prompt_embeds,
1102
+ txt_ids=negative_text_ids,
1103
+ img_ids=latent_ids,
1104
+ joint_attention_kwargs=self.joint_attention_kwargs,
1105
+ return_dict=False,
1106
+ )[0]
1107
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
1108
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1109
+
1110
+ # compute the previous noisy sample x_t -> x_t-1
1111
+ latents_dtype = latents.dtype
1112
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1113
+
1114
+ if latents.dtype != latents_dtype:
1115
+ if torch.backends.mps.is_available():
1116
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1117
+ latents = latents.to(latents_dtype)
1118
+
1119
+ if callback_on_step_end is not None:
1120
+ callback_kwargs = {}
1121
+ for k in callback_on_step_end_tensor_inputs:
1122
+ callback_kwargs[k] = locals()[k]
1123
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1124
+
1125
+ latents = callback_outputs.pop("latents", latents)
1126
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1127
+
1128
+ # call the callback, if provided
1129
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1130
+ progress_bar.update()
1131
+
1132
+ if XLA_AVAILABLE:
1133
+ xm.mark_step()
1134
+
1135
+ self._current_timestep = None
1136
+
1137
+ if output_type == "latent":
1138
+ image = latents
1139
+ else:
1140
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1141
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1142
+ image = self.vae.decode(latents, return_dict=False)[0]
1143
+ image = self.image_processor.postprocess(image, output_type=output_type)
1144
+
1145
+ # Offload all models
1146
+ self.maybe_free_model_hooks()
1147
+
1148
+ if not return_dict:
1149
+ return (image,)
1150
+
1151
+ return FluxPipelineOutput(images=image)
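Note: a minimal sketch of how the multi-image `__call__` above is driven, mirroring the `pipe(...)` invocation in `web_edit.py` below (the reference-image paths are placeholders, LoRA loading is omitted, and the prompt is taken from the bundled editing examples):

import torch
from diffusers.utils import load_image
from pipeline_flux_kontext import FluxKontextPipeline

# Load the base FLUX.1-Kontext-dev weights and move the pipeline to the GPU.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Both reference images are packed into one conditioning sequence by prepare_latents();
# each image's ids get index i + 1 in the first coordinate so the transformer can tell them apart.
refs = [load_image("ref_0.jpg"), load_image("ref_1.jpg")]  # placeholder paths

result = pipe(
    images=refs,
    prompt="Replace the lantern in the first image with the dog in the second image.",
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
result.save("result.png")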
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ timm
2
+ ujson
3
+ peft
4
+ datasets
5
+ transformers
6
+ opencv-python
7
+ qwen-vl-utils
8
+ lmdb
9
+ diffusers
10
+ numpy
11
+ gradio
script.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=3 GRADIO_TEMP_DIR=gradio_tmp python web_edit.py \
2
+ --vlm_path /gpfs/bhpeng/generation/do2/vlm-model \
3
+ --edit_lora_path /gpfs/bhpeng/generation/do2/edit_lora \
4
+ --server_name "0.0.0.0" \
5
+ --server_port 7869
6
+
7
+
8
+ CUDA_VISIBLE_DEVICES=1 python web_generate.py \
9
+ --vlm_path /gpfs/bhpeng/generation/do2/vlm-model \
10
+ --gen_lora_path /gpfs/bhpeng/generation/do2/gen_lora \
11
+ --server_name "0.0.0.0" \
12
+ --server_port 7861
web_edit.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pipeline_flux_kontext import FluxKontextPipeline
3
+ from diffusers.utils import load_image
4
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
+ from qwen_vl_utils import process_vision_info
6
+ import os
7
+ import re
8
+ from PIL import Image
9
+ import gradio as gr
10
+ import uuid
11
+ import argparse
12
+
13
+ def parse_args():
14
+ """Parses command-line arguments for model paths and server configuration."""
15
+ parser = argparse.ArgumentParser(description="Launch DreamOmni2 Editing Gradio Demo.")
16
+ parser.add_argument(
17
+ "--vlm_path",
18
+ type=str,
19
+ default="vlm-model",
20
+ help="Path to the Qwen2_5_VL VLM model directory."
21
+ )
22
+ parser.add_argument(
23
+ "--edit_lora_path",
24
+ type=str,
25
+ default="edit_lora",
26
+ help="Path to the FLUX.1-Kontext editing LoRA weights directory."
27
+ )
28
+ parser.add_argument(
29
+ "--server_name",
30
+ type=str,
31
+ default="0.0.0.0",
32
+ help="The server name (IP address) to host the Gradio demo."
33
+ )
34
+ parser.add_argument(
35
+ "--server_port",
36
+ type=int,
37
+ default=7860,
38
+ help="The port number to host the Gradio demo."
39
+ )
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+ ARGS = parse_args()
44
+ vlm_path = ARGS.vlm_path
45
+ edit_lora_path = ARGS.edit_lora_path
46
+ server_name = ARGS.server_name
47
+ server_port = ARGS.server_port
48
+ device = "cuda"
49
+
50
+ def extract_gen_content(text):
51
+ text = text[6:-7]
52
+ return text
53
+
54
+ print(f"Loading models from vlm_path: {vlm_path}, edit_lora_path: {edit_lora_path}")
55
+
56
+ pipe = FluxKontextPipeline.from_pretrained(
57
+ "black-forest-labs/FLUX.1-Kontext-dev",
58
+ torch_dtype=torch.bfloat16
59
+ )
60
+ pipe.to(device)
61
+ pipe.load_lora_weights(edit_lora_path, adapter_name="edit")
62
+ pipe.set_adapters(["edit"], adapter_weights=[1])
63
+
64
+ vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
65
+ vlm_path,
66
+ torch_dtype="bfloat16",
67
+ device_map="cuda"
68
+ )
69
+ processor = AutoProcessor.from_pretrained(vlm_path)
70
+
71
+
72
+ def infer_vlm(input_img_path, input_instruction, prefix):
73
+ if not vlm_model or not processor:
74
+ raise gr.Error("VLM Model not loaded. Cannot process prompt.")
75
+ tp = []
76
+ for path in input_img_path:
77
+ tp.append({"type": "image", "image": path})
78
+ tp.append({"type": "text", "text": input_instruction + prefix})
79
+ messages = [{"role": "user", "content": tp}]
80
+
81
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
+ image_inputs, video_inputs = process_vision_info(messages)
83
+ inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
84
+ inputs = inputs.to("cuda")
85
+
86
+ generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
87
+ generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
88
+ output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
89
+ return output_text[0]
90
+
91
+ PREFERRED_KONTEXT_RESOLUTIONS = [
92
+ (672, 1568),
93
+ (688, 1504),
94
+ (720, 1456),
95
+ (752, 1392),
96
+ (800, 1328),
97
+ (832, 1248),
98
+ (880, 1184),
99
+ (944, 1104),
100
+ (1024, 1024),
101
+ (1104, 944),
102
+ (1184, 880),
103
+ (1248, 832),
104
+ (1328, 800),
105
+ (1392, 752),
106
+ (1456, 720),
107
+ (1504, 688),
108
+ (1568, 672),
109
+ ]
110
+ def find_closest_resolution(width, height, preferred_resolutions):
111
+ input_ratio = width / height
112
+ closest_resolution = min(
113
+ preferred_resolutions,
114
+ key=lambda res: abs((res[0] / res[1]) - input_ratio)
115
+ )
116
+ return closest_resolution
117
+
118
+ def perform_edit(input_img_paths, input_instruction, output_path):
119
+ prefix = " It is editing task."
120
+ source_imgs = [load_image(path) for path in input_img_paths]
121
+ resized_imgs = []
122
+ for img in source_imgs:
123
+ target_resolution = find_closest_resolution(img.width, img.height, PREFERRED_KONTEXT_RESOLUTIONS)
124
+ resized_img = img.resize(target_resolution, Image.LANCZOS)
125
+ resized_imgs.append(resized_img)
126
+ prompt = infer_vlm(input_img_paths, input_instruction, prefix)
127
+ prompt = extract_gen_content(prompt)
128
+ print(f"Prompt generated by VLM: {prompt}")
129
+
130
+ image = pipe(
131
+ images=resized_imgs,
132
+ height=resized_imgs[0].height,
133
+ width=resized_imgs[0].width,
134
+ prompt=prompt,
135
+ num_inference_steps=30,
136
+ guidance_scale=3.5,
137
+ ).images[0]
138
+ image.save(output_path)
139
+ print(f"Edit result saved to {output_path}")
140
+
141
+
142
+ def process_request(image_file_1, image_file_2, instruction):
143
+ # debugpy.listen(5678)
144
+ # print("Waiting for debugger attach...")
145
+ # debugpy.wait_for_client()
146
+ if not image_file_1 or not image_file_2:
147
+ raise gr.Error("Please upload both images.")
148
+ if not instruction:
149
+ raise gr.Error("Please provide an instruction.")
150
+ if not pipe or not vlm_model:
151
+ raise gr.Error("Models not loaded. Check the console for errors.")
152
+
153
+ output_path = f"/tmp/{uuid.uuid4()}.png"
154
+ input_img_paths = [image_file_1, image_file_2] # List of file paths from the two gr.Image inputs
155
+
156
+ perform_edit(input_img_paths, instruction, output_path)
157
+ return output_path
158
+
159
+
160
+ css = """
161
+ .text-center { text-align: center; }
162
+ .result-img img {
163
+ max-height: 60vh !important;
164
+ min-height: 30vh !important;
165
+ width: auto !important;
166
+ object-fit: contain;
167
+ }
168
+ .input-img img {
169
+ max-height: 30vh !important;
170
+ width: auto !important;
171
+ object-fit: contain;
172
+ }
173
+ """
174
+
175
+
176
+ with gr.Blocks(theme=gr.themes.Soft(), title="DreamOmni2", css=css) as demo:
177
+ gr.HTML(
178
+ """
179
+ <h1 style="text-align:center; font-size:48px; font-weight:bold; margin-bottom:20px;">
180
+ DreamOmni2: Omni-purpose Image Generation and Editing
181
+ </h1>
182
+ """
183
+ )
184
+ gr.Markdown(
185
+ "Select a mode, upload two images, provide an instruction, and click 'Run'.",
186
+ elem_classes="text-center"
187
+ )
188
+ with gr.Row():
189
+ with gr.Column(scale=2):
190
+ gr.Markdown("⬆️ Upload images. Click or drag to upload.")
191
+
192
+ with gr.Row():
193
+ image_uploader_1 = gr.Image(
194
+ label="Img 1",
195
+ type="filepath",
196
+ interactive=True,
197
+ elem_classes="input-img",
198
+ )
199
+ image_uploader_2 = gr.Image(
200
+ label="Img 2",
201
+ type="filepath",
202
+ interactive=True,
203
+ elem_classes="input-img",
204
+ )
205
+
206
+ instruction_text = gr.Textbox(
207
+ label="Instruction",
208
+ lines=2,
209
+ placeholder="Input your instruction for generation or editing here...",
210
+ )
211
+ run_button = gr.Button("Run", variant="primary")
212
+
213
+ with gr.Column(scale=2):
214
+ gr.Markdown(
215
+ "✏️ **Editing Mode**: Modify an existing image using instructions and references.\n\n"
216
+ "Tip: If the result is not what you expect, try clicking **Run** again. "
217
+ )
218
+ output_image = gr.Image(
219
+ label="Result",
220
+ type="filepath",
221
+ elem_classes="result-img",
222
+ )
223
+
224
+ # --- Examples (unchanged) ---
225
+ gr.Markdown("## Examples")
226
+
227
+ gr.Examples(
228
+ label="Editing Examples",
229
+ examples=[
230
+ ["edit_tests/4/ref_0.jpg", "edit_tests/4/ref_1.jpg", "Replace the first image have the same image style as the second image.","edit_tests/4/res.jpg"],
231
+ ["edit_tests/5/ref_0.jpg", "edit_tests/5/ref_1.jpg", "Make the person in the first image have the same hairstyle as the person in the second image.","edit_tests/5/res.jpg"],
232
+ ["edit_tests/src.jpg", "edit_tests/ref.jpg", "Make the woman from the second image stand on the road in the first image.","edit_tests/edi_res.png"],
233
+ ["edit_tests/1/ref_0.jpg", "edit_tests/1/ref_1.jpg", "Replace the lantern in the first image with the dog in the second image.","edit_tests/1/res.jpg"],
234
+ ["edit_tests/2/ref_0.jpg", "edit_tests/2/ref_1.jpg", "Replace the suit in the first image with the clothes in the second image.","edit_tests/2/res.jpg"],
235
+ ["edit_tests/3/ref_0.jpg", "edit_tests/3/ref_1.jpg", "Make the first image has the same light condition as the second image.","edit_tests/3/res.jpg"],
236
+ ["edit_tests/6/ref_0.jpg", "edit_tests/6/ref_1.jpg", "Make the words in the first image have the same font as the words in the second image.","edit_tests/6/res.jpg"],
237
+ ["edit_tests/7/ref_0.jpg", "edit_tests/7/ref_1.jpg", "Make the car in the first image have the same pattern as the mouse in the second image.","edit_tests/7/res.jpg"],
238
+ ["edit_tests/8/ref_0.jpg", "edit_tests/8/ref_1.jpg", "Make the dress in the first image have the same pattern in the second image.","edit_tests/8/res.jpg"],
239
+ ],
240
+ inputs=[image_uploader_1, image_uploader_2, instruction_text, output_image],
241
+ cache_examples=False,
242
+ )
243
+
244
+ run_button.click(
245
+ fn=process_request,
246
+ inputs=[image_uploader_1, image_uploader_2, instruction_text],
247
+ outputs=output_image
248
+ )
249
+
250
+ if __name__ == "__main__":
251
+ print("Launching Gradio Demo...")
252
+ demo.launch(server_name=server_name, server_port=server_port)
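Note: the resolution bucketing in `find_closest_resolution` above is pure aspect-ratio matching and can be sanity-checked without loading any model; a standalone sketch (the input sizes are arbitrary examples):

# Copy of the demo's bucketing logic for a quick offline check.
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

def find_closest_resolution(width, height, preferred_resolutions):
    # Pick the preferred bucket whose aspect ratio is closest to the input's.
    input_ratio = width / height
    return min(preferred_resolutions, key=lambda res: abs((res[0] / res[1]) - input_ratio))

print(find_closest_resolution(1920, 1080, PREFERRED_KONTEXT_RESOLUTIONS))  # -> (1392, 752)
print(find_closest_resolution(1080, 1920, PREFERRED_KONTEXT_RESOLUTIONS))  # -> (752, 1392)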
web_generate.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pipeline_flux_kontext import FluxKontextPipeline
3
+ from diffusers.utils import load_image
4
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
+ from qwen_vl_utils import process_vision_info
6
+ import os
7
+ import re
8
+ from PIL import Image
9
+ import gradio as gr
10
+ import uuid
11
+ import argparse
12
+
13
+ def parse_args():
14
+ """Parses command-line arguments for model paths and server configuration."""
15
+ parser = argparse.ArgumentParser(description="Launch DreamOmni2 Generation Gradio Demo.")
16
+ parser.add_argument(
17
+ "--vlm_path",
18
+ type=str,
19
+ default="vlm-model",
20
+ help="Path to the Qwen2_5_VL VLM model directory."
21
+ )
22
+ parser.add_argument(
23
+ "--gen_lora_path",
24
+ type=str,
25
+ default="gen_lora",
26
+ help="Path to the FLUX.1-Kontext generation LoRA weights directory."
27
+ )
28
+ parser.add_argument(
29
+ "--server_name",
30
+ type=str,
31
+ default="0.0.0.0",
32
+ help="The server name (IP address) to host the Gradio demo."
33
+ )
34
+ parser.add_argument(
35
+ "--server_port",
36
+ type=int,
37
+ default=7860,
38
+ help="The port number to host the Gradio demo."
39
+ )
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+ ARGS = parse_args()
44
+ vlm_path = ARGS.vlm_path
45
+ gen_lora_path = ARGS.gen_lora_path
46
+ server_name = ARGS.server_name
47
+ server_port = ARGS.server_port
48
+ device = "cuda"
49
+
50
+ def extract_gen_content(text):
51
+ text = text[6:-7]
52
+ return text
53
+
54
+ print(f"Loading models from vlm_path: {vlm_path}, gen_lora_path: {gen_lora_path}")
55
+
56
+ pipe = FluxKontextPipeline.from_pretrained(
57
+ "black-forest-labs/FLUX.1-Kontext-dev",
58
+ torch_dtype=torch.bfloat16
59
+ )
60
+ pipe.to(device)
61
+ pipe.load_lora_weights(gen_lora_path, adapter_name="generation")
62
+ pipe.set_adapters(["generation"], adapter_weights=[1])
63
+
64
+ vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
65
+ vlm_path,
66
+ torch_dtype="bfloat16",
67
+ device_map="cuda"
68
+ )
69
+ processor = AutoProcessor.from_pretrained(vlm_path)
70
+
71
+
72
+ def infer_vlm(input_img_path, input_instruction, prefix):
73
+ if not vlm_model or not processor:
74
+ raise gr.Error("VLM Model not loaded. Cannot process prompt.")
75
+ tp = []
76
+ for path in input_img_path:
77
+ tp.append({"type": "image", "image": path})
78
+ tp.append({"type": "text", "text": input_instruction + prefix})
79
+ messages = [{"role": "user", "content": tp}]
80
+
81
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
+ image_inputs, video_inputs = process_vision_info(messages)
83
+ inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
84
+ inputs = inputs.to("cuda")
85
+
86
+ generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
87
+ generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
88
+ output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
89
+ return output_text[0]
90
+
91
+
92
+ PREFERRED_KONTEXT_RESOLUTIONS = [
93
+ (672, 1568),
94
+ (688, 1504),
95
+ (720, 1456),
96
+ (752, 1392),
97
+ (800, 1328),
98
+ (832, 1248),
99
+ (880, 1184),
100
+ (944, 1104),
101
+ (1024, 1024),
102
+ (1104, 944),
103
+ (1184, 880),
104
+ (1248, 832),
105
+ (1328, 800),
106
+ (1392, 752),
107
+ (1456, 720),
108
+ (1504, 688),
109
+ (1568, 672),
110
+ ]
111
+ def find_closest_resolution(width, height, preferred_resolutions):
112
+ input_ratio = width / height
113
+ closest_resolution = min(
114
+ preferred_resolutions,
115
+ key=lambda res: abs((res[0] / res[1]) - input_ratio)
116
+ )
117
+ return closest_resolution
118
+
119
+ def perform_generation(input_img_paths, input_instruction, output_path, height=1024, width=1024):
120
+ prefix = " It is generation task."
121
+ source_imgs = [load_image(path) for path in input_img_paths]
122
+ resized_imgs = []
123
+ for img in source_imgs:
124
+ target_resolution = find_closest_resolution(img.width, img.height, PREFERRED_KONTEXT_RESOLUTIONS)
125
+ resized_img = img.resize(target_resolution, Image.LANCZOS)
126
+ resized_imgs.append(resized_img)
127
+ prompt = infer_vlm(input_img_paths, input_instruction, prefix)
128
+ prompt = extract_gen_content(prompt)
129
+ print(f"Prompt generated by VLM: {prompt}")
130
+
131
+ image = pipe(
132
+ images=resized_imgs,
133
+ height=height,
134
+ width=width,
135
+ prompt=prompt,
136
+ num_inference_steps=30,
137
+ guidance_scale=3.5,
138
+ ).images[0]
139
+
140
+ image.save(output_path)
141
+ print(f"Generation result saved to {output_path}")
142
+
143
+ # --- Gradio Interface Logic ---
144
+
145
+ def process_request(image_file_1, image_file_2, instruction):
146
+ # debugpy.listen(5678)
147
+ # print("Waiting for debugger attach...")
148
+ # debugpy.wait_for_client()
149
+ if not image_file_1 or not image_file_2:
150
+ raise gr.Error("Please upload both images.")
151
+ if not instruction:
152
+ raise gr.Error("Please provide an instruction.")
153
+ if not pipe or not vlm_model:
154
+ raise gr.Error("Models not loaded. Check the console for errors.")
155
+
156
+ output_path = f"/tmp/{uuid.uuid4()}.png"
157
+ input_img_paths = [image_file_1, image_file_2] # List of file paths from the two gr.Image inputs
158
+
159
+ perform_generation(input_img_paths, instruction, output_path)
160
+
161
+ return output_path
162
+
163
+
164
+ css = """
165
+ .text-center { text-align: center; }
166
+ .result-img img {
167
+ max-height: 60vh !important;
168
+ min-height: 30vh !important;
169
+ width: auto !important;
170
+ object-fit: contain;
171
+ }
172
+ .input-img img {
173
+ max-height: 30vh !important;
174
+ width: auto !important;
175
+ object-fit: contain;
176
+ }
177
+ """
178
+
179
+
180
+ with gr.Blocks(theme=gr.themes.Soft(), title="DreamOmni2", css=css) as demo:
181
+ gr.HTML(
182
+ """
183
+ <h1 style="text-align:center; font-size:48px; font-weight:bold; margin-bottom:20px;">
184
+ DreamOmni2: Omni-purpose Image Generation and Editing
185
+ </h1>
186
+ """
187
+ )
188
+ gr.Markdown(
189
+ "Select a mode, upload two images, provide an instruction, and click 'Run'.",
190
+ elem_classes="text-center"
191
+ )
192
+ with gr.Row():
193
+ with gr.Column(scale=2):
194
+ gr.Markdown("⬆️ Upload images. Click or drag to upload.")
195
+
196
+ with gr.Row():
197
+ image_uploader_1 = gr.Image(
198
+ label="Img 1",
199
+ type="filepath",
200
+ interactive=True,
201
+ elem_classes="input-img",
202
+ )
203
+ image_uploader_2 = gr.Image(
204
+ label="Img 2",
205
+ type="filepath",
206
+ interactive=True,
207
+ elem_classes="input-img",
208
+ )
209
+
210
+ instruction_text = gr.Textbox(
211
+ label="Instruction",
212
+ lines=2,
213
+ placeholder="Input your instruction for generation or editing here...",
214
+ )
215
+ run_button = gr.Button("Run", variant="primary")
216
+
217
+ with gr.Column(scale=2):
218
+ gr.Markdown("🖼️ **Generation Mode**: Create new scenes from reference images.\n\n"
219
+ "Tip: If the result is not what you expect, try clicking **Run** again. ")
220
+ output_image = gr.Image(
221
+ label="Result",
222
+ type="filepath",
223
+ elem_classes="result-img",
224
+ )
225
+
226
+ # --- Examples ---
227
+ gr.Markdown("## Examples")
228
+ gr.Examples(
229
+ label="Generation Examples",
230
+ examples=[
231
+ [
232
+ "gen_tests/img1.jpg",
233
+ "gen_tests/img2.jpg",
234
+ "In the scene, the character from the first image stands on the left, and the character from the second image stands on the right. They are shaking hands against the backdrop of a spaceship interior.",
235
+ "gen_tests/gen_res.png"
236
+ ]
237
+ ],
238
+ inputs=[image_uploader_1, image_uploader_2, instruction_text, output_image],
239
+ cache_examples=False,
240
+ )
241
+
242
+ run_button.click(
243
+ fn=process_request,
244
+ inputs=[image_uploader_1, image_uploader_2, instruction_text],
245
+ outputs=output_image
246
+ )
247
+
248
+ if __name__ == "__main__":
249
+
250
+ print("Launching Gradio Demo...")
251
+ demo.launch(server_name=server_name, server_port=server_port)
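Note: the generation path can also be exercised without the Gradio UI. Once the module-level setup in `web_generate.py` has loaded `pipe`, `vlm_model` and `processor` (for example by commenting out `demo.launch(...)` and running the script interactively), `perform_generation` can be called directly; a minimal sketch using the example assets listed above, with `out.png` as an arbitrary output name:

# Assumes the model-loading block at the top of web_generate.py has already executed.
perform_generation(
    input_img_paths=["gen_tests/img1.jpg", "gen_tests/img2.jpg"],
    input_instruction=(
        "In the scene, the character from the first image stands on the left, and the character "
        "from the second image stands on the right. They are shaking hands against the backdrop "
        "of a spaceship interior."
    ),
    output_path="out.png",
)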