# server / locally-hosted models implementation (extended applications demo)
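# This demo wires four local models into one Gradio app:
#   - Descriptor: FastVLM (callgg/fastvlm-0.5b-bf16) for image captioning / visual QA
#   - Generator: SD3.5 GGUF transformer + T5-XXL GGUF encoder via StableDiffusion3Pipeline
#   - Transformer: Qwen Image Edit Plus with a Nunchaku-quantized transformer for multi-image editing
#   - Discriminator: LPIPS (squeeze net) perceptual distance between two images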
import torch
import lpips
import gradio as gr
import numpy as np
from PIL import Image
from dequantor import (
    StableDiffusion3Pipeline,
    GGUFQuantizationConfig,
    SD3Transformer2DModel,
    QwenImageEditPlusPipeline,
    AutoencoderKLQwenImage,
)
from transformers import (
    T5EncoderModel,
    Qwen2_5_VLForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from nunchaku import (
    NunchakuQwenImageTransformer2DModel,
)
from gguf_connector.vrm import get_gpu_vram

def launch_app(model_path1, model_path, dtype):
    # image recognition model
    MODEL_ID = "callgg/fastvlm-0.5b-bf16"
    IMAGE_TOKEN_INDEX = -200  # sentinel token id that marks where image features are spliced in (LLaVA-style convention)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )
    def describe_image(img: Image.Image, prompt, num_tokens) -> str:
        if img is None:
            return "Please upload an image."
        messages = [{"role": "user", "content": f"<image>\n{prompt}."}]
        rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
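        # split the rendered prompt at the <image> placeholder so the sentinel token id can be spliced between the two halves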
        pre, post = rendered.split("<image>", 1)
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)
        px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
        px = px.to(model.device, dtype=model.dtype)
        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=num_tokens
            )
        return tok.decode(out[0], skip_special_tokens=True)
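    # example call (hypothetical file name): describe_image(Image.open("cat.png"), "describe this image in detail", 128)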
    sample1_prompts = [
        'describe this image in detail',
        'describe what you see in few words',
        'tell me the difference',
    ]
    sample1_prompts = [[x] for x in sample1_prompts]
    # image generation model
    transformer1 = SD3Transformer2DModel.from_single_file(
        model_path1,
        quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
        torch_dtype=dtype,
        config="callgg/sd3-decoder",
        subfolder="transformer_2"
    )
    text_encoder1 = T5EncoderModel.from_pretrained(
        "chatpig/t5-v1_1-xxl-encoder-fp32-gguf",
        gguf_file="t5xxl-encoder-fp32-q2_k.gguf",
        dtype=dtype
    )
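    # note: the q2_k-quantized T5-XXL GGUF above keeps the text encoder's memory footprint small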
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "callgg/sd3-decoder",
        transformer=transformer1,
        text_encoder_3=text_encoder1,
        torch_dtype=dtype
    )
    pipeline.enable_model_cpu_offload()  # move submodules to the GPU only while they are in use to save VRAM
    # Inference function
    def generate_image2(prompt, num_steps, guidance):
        result = pipeline(
            prompt,
            height=1024,
            width=1024,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
        ).images[0]
        return result
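    # example call (matches the UI defaults below): generate_image2("a cat in a hat", 8, 2.5)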
    sample_prompts2 = [
        'a cat in a hat',
        'a pig in a hat',
        'a raccoon in a hat',
        'a dog walking with joy',
    ]
    sample_prompts2 = [[x] for x in sample_prompts2]
    # image transformation model
    transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
        model_path
    )
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "callgg/qi-decoder",
        subfolder="text_encoder",
        dtype=dtype
    )
    vae = AutoencoderKLQwenImage.from_pretrained(
        "callgg/qi-decoder",
        subfolder="vae",
        torch_dtype=dtype
    )
    pipe = QwenImageEditPlusPipeline.from_pretrained(
        "callgg/image-edit-plus",
        transformer=transformer,
        text_encoder=text_encoder,
        vae=vae,
        torch_dtype=dtype
    )
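    # pick an offload strategy based on available VRAM: plain model CPU offload when there is headroom (>18 GB),
    # otherwise Nunchaku per-block offload plus sequential CPU offload, with the transformer excluded so Nunchaku manages it itself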
    if get_gpu_vram() > 18:
        pipe.enable_model_cpu_offload()
    else:
        transformer.set_offload(
            True, use_pin_memory=False, num_blocks_on_gpu=1
        )
        pipe._exclude_from_cpu_offload.append("transformer")
        pipe.enable_sequential_cpu_offload()
    def generate_image(prompt, img1, img2, img3, steps, guidance):
        images = []
        for img in [img1, img2, img3]:
            if img is not None:
                if not isinstance(img, Image.Image):
                    img = Image.open(img)
                images.append(img.convert("RGB"))
        if not images:
            return None
        inputs = {
            "image": images,
            "prompt": prompt,
            "true_cfg_scale": guidance,
            "negative_prompt": " ",
            "num_inference_steps": steps,
            "num_images_per_prompt": 1,
        }
        with torch.inference_mode():
            output = pipe(**inputs)
            return output.images[0]
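    # example call (hypothetical file names): generate_image("merge it", Image.open("a.png"), Image.open("b.png"), None, 4, 1.0)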
    sample_prompts = [
        'merge it',
        'color it',
        'use image 1 as background of image 2',
    ]
    sample_prompts = [[x] for x in sample_prompts]
    # image discrimination model
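    # LPIPS measures perceptual distance between two images: 0 means perceptually identical, larger means more different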
    def compare_images(img1, img2):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # note: the LPIPS model is rebuilt on every call, which is acceptable for an occasional demo comparison
        lpips_model = lpips.LPIPS(net='squeeze').to(device)
        if img1 is None or img2 is None:
            return "Please upload both images."
        # resize the second image to match the first so the element-wise comparison is valid
        if img1.size != img2.size:
            img2 = img2.resize(img1.size)
        # lpips.im2tensor expects pixel values in [0, 255] and maps them to [-1, 1]
        img1_np = np.array(img1).astype(np.float32)
        img2_np = np.array(img2).astype(np.float32)
        # convert to tensor in LPIPS format
        img1_tensor = lpips.im2tensor(img1_np).to(device)
        img2_tensor = lpips.im2tensor(img2_np).to(device)
        # compute LPIPS distance
        with torch.no_grad():
            distance = lpips_model(img1_tensor, img2_tensor)
        score = distance.item()
        similarity = max(0.0, 1.0 - score)  # LPIPS distance is roughly 0-1; clamp to a non-negative similarity
        result_text = (
            f"LPIPS Distance: {score:.4f}\n"
            f"Estimated Similarity: {similarity*100:.4f}%"
        )
        return result_text
    # UI
    block = gr.Blocks(title="image studio").queue()
    with block:
        gr.Markdown("## Discriminator")
        with gr.Row():
            img1 = gr.Image(type="pil", label="Image 1")
            img2 = gr.Image(type="pil", label="Image 2")
        compare_btn = gr.Button("Discriminate")
        output_box = gr.Textbox(label="Statistics", lines=2)
        compare_btn.click(compare_images, inputs=[img1, img2], outputs=output_box)
        gr.Markdown("## Descriptor")
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input Image")
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here (or click Sample Prompt)", value="")
                quick_prompts = gr.Dataset(samples=sample1_prompts, label='Sample Prompt', samples_per_page=1000, components=[prompt])
                quick_prompts.click(lambda x: x[0], inputs=[quick_prompts], outputs=prompt, show_progress=False, queue=False)
                btn = gr.Button("Describe")
                num_tokens = gr.Slider(minimum=64, maximum=1024, value=128, step=1, label="Output Token")
            with gr.Column():
                output = gr.Textbox(label="Description", lines=5)
        btn.click(fn=describe_image, inputs=[img_input, prompt, num_tokens], outputs=output)
        gr.Markdown("## Generator")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here (or click Sample Prompt)", value="")
                quick_prompts = gr.Dataset(samples=sample_prompts2, label='Sample Prompt', samples_per_page=1000, components=[prompt])
                quick_prompts.click(lambda x: x[0], inputs=[quick_prompts], outputs=prompt, show_progress=False, queue=False)
                submit_btn = gr.Button("Generate")
                num_steps = gr.Slider(minimum=4, maximum=100, value=8, step=1, label="Step")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=2.5, step=0.1, label="Scale")
            with gr.Column():
                output_image = gr.Image(type="pil", label="Output Image")
        submit_btn.click(fn=generate_image2, inputs=[prompt, num_steps, guidance], outputs=output_image)
        gr.Markdown("## Transformer")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    img1 = gr.Image(label="Image 1", type="pil")
                    img2 = gr.Image(label="Image 2", type="pil")
                    img3 = gr.Image(label="Image 3", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here (or click Sample Prompt)", value="")
                quick_prompts = gr.Dataset(samples=sample_prompts, label='Sample Prompt', samples_per_page=1000, components=[prompt])
                quick_prompts.click(lambda x: x[0], inputs=[quick_prompts], outputs=prompt, show_progress=False, queue=False)
                generate_btn = gr.Button("Transform")
                steps = gr.Slider(1, 50, value=4, step=1, label="Inference Steps", visible=False)
                guidance = gr.Slider(0.1, 10.0, value=1.0, step=0.1, label="Guidance Scale", visible=False)
            with gr.Column():
                output_image = gr.Image(label="Output", type="pil")
        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, img1, img2, img3, steps, guidance],
            outputs=output_image,
        )
    block.launch()  # serve the UI locally; Gradio defaults to http://127.0.0.1:7860

# detect your device and assign dtype accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# load the models from cache, or pull them from the huggingface repos if you don't have them yet
model_path1 = "https://huggingface.co/calcuis/sd3.5-lite-gguf/blob/main/sd3.5-8b-lite-mxfp4_moe.gguf"
model_path = "https://huggingface.co/calcuis/sketch/blob/main/sketch-s9-20b-int4.safetensors"
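# (assumption) a local path to an already-downloaded .gguf / .safetensors file should also work in place of the URLs above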

# launch the app; call the app function above
launch_app(model_path1, model_path, dtype)