from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image

import torch
from torchvision import transforms

# --- Expects the U^2-Net codebase to be available as a local module folder "u2net"
try:
    from u2net.model import U2NET, U2NETP
except Exception as e:
    raise RuntimeError(
        "Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
        "next to app.py (containing model.py, data_loader.py, ...)."
    ) from e

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

# Cache for loaded models
_MODEL_CACHE = {"u2net": None, "u2netp": None}


def _pil_to_np(img: Image.Image) -> np.ndarray:
    return np.array(img.convert("RGB"), dtype=np.float32)


def _np_to_pil(arr: np.ndarray) -> Image.Image:
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")


def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Min-max normalize each map to [0, 1] over its spatial dims."""
    t_min = t.amin(dim=(-2, -1), keepdim=True)
    t_max = t.amax(dim=(-2, -1), keepdim=True)
    return (t - t_min) / (t_max - t_min + eps)
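
# Worked example (values assumed): for t = [[0., 2.], [4., 8.]],
# _minmax_norm(t) -> [[0.00, 0.25], [0.50, 1.00]] (up to the eps term).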


def _find_weight_file(model_type: str) -> Path:
    """
    model_type in {"u2net", "u2netp"}.
    Looks under: saved_models/{model_type}/{model_type}.pth (preferred)
    or first *.pth under that subfolder.
    """
    base = Path("saved_models").expanduser().resolve()
    sub = base / model_type
    preferred = sub / f"{model_type}.pth"
    if preferred.exists():
        return preferred
    # fallback: first .pth in subdir
    if sub.exists():
        for p in sorted(sub.glob("*.pth")):
            return p
    raise FileNotFoundError(
        f"Could not find weights for '{model_type}'. Expected at '{preferred}' or any .pth in '{sub}'."
    )
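
# Expected on-disk layout (example, per the lookup above):
#   saved_models/
#     u2net/u2net.pth      <- full model weights
#     u2netp/u2netp.pth    <- small model weights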


def load_u2net(model_type: str = "u2netp"):
    assert model_type in {"u2net", "u2netp"}
    if _MODEL_CACHE.get(model_type) is not None:
        return _MODEL_CACHE[model_type]

    weights_path = _find_weight_file(model_type)
    if model_type == "u2net":
        net = U2NET(3, 1)
    else:
        net = U2NETP(3, 1)

    state = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state)
    net.to(DEVICE)
    net.eval()
    _MODEL_CACHE[model_type] = net
    return net
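
# Usage sketch (assumes weights exist under ./saved_models):
#   net = load_u2net("u2netp")  # first call: loads weights, moves to DEVICE
#   net = load_u2net("u2netp")  # later calls: served from _MODEL_CACHE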


def get_u2net_mask_with_local_weights(
    pil_img: Image.Image,
    model_type: str = "u2netp",
    resize_to: int = 320,
) -> Image.Image:
    """
    Single-image inference using user's local U^2-Net/U^2-NetP weights.
    Returns 8-bit 'L' mask resized back to original W,H.
    """
    W, H = pil_img.size
    net = load_u2net(model_type)

    tr = transforms.Compose([
        transforms.Resize(
            (resize_to, resize_to),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),  # [0,1], CxHxW
    ])
    x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)  # 1 x 3 x resize_to x resize_to

    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(x)
        pred = d1[:, 0, :, :]  # 1xHxW
        pred = _minmax_norm(pred)  # min-max normalize each prediction map to [0, 1]

    pred_np = pred.squeeze(0).detach().cpu().numpy()  # HxW, [0..1]
    mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
    mask = mask_small.resize((W, H), resample=Image.BILINEAR)
    return mask
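
# Usage sketch (hypothetical file "photo.jpg"; weights assumed present):
#   mask = get_u2net_mask_with_local_weights(Image.open("photo.jpg"), "u2netp")
#   mask.save("photo_mask.png")  # 8-bit 'L' mask at the original resolution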


def jigsaw_shuffle_full_image(pil_img: Image.Image, N: int, seed: int) -> Image.Image:
    """
    Create a jigsaw-shuffled version of the input image by splitting into an N×N grid
    with *no overlap*, permuting patches uniformly at random, and reassembling.
    To keep uniform patch sizes, we center-crop the image to (H2,W2) divisible by N,
    then paste back to the original canvas.
    """
    np.random.seed(seed)  # drives np.random.permutation below

    W, H = pil_img.size
    # compute crop that is divisible by N
    H2 = (H // N) * N
    W2 = (W // N) * N
    if H2 == 0 or W2 == 0:
        # image smaller than one N-th per axis; nothing to shuffle
        return pil_img

    # center crop box
    y0 = (H - H2) // 2
    x0 = (W - W2) // 2
    crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2))

    arr = np.array(crop).copy()
    out = np.empty_like(arr)

    ph = H2 // N
    pw = W2 // N

    # build coordinates
    coords = []
    for i in range(N):
        for j in range(N):
            y1 = i * ph
            x1 = j * pw
            coords.append((y1, y1 + ph, x1, x1 + pw))

    perm = np.random.permutation(len(coords))
    for dst_idx, src_idx in enumerate(perm):
        yd0, yd1, xd0, xd1 = coords[dst_idx]
        ys0, ys1, xs0, xs1 = coords[src_idx]
        out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :]

    # paste back into original canvas
    pad = np.array(pil_img.convert("RGB"))
    pad[y0:y0 + H2, x0:x0 + W2, :] = out
    return Image.fromarray(pad.astype(np.uint8), mode="RGB")
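
# Worked example: a 100x80 (WxH) image with N=2 needs no crop (both sides are
# divisible by N), so ph=40, pw=50, and coords lists the four patch boxes row
# by row:
#   [(0, 40, 0, 50), (0, 40, 50, 100), (40, 80, 0, 50), (40, 80, 50, 100)]
# A permutation such as [2, 0, 3, 1] then writes coords[2] into slot 0, etc.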


def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray:
    """
    Add Gaussian noise only within a randomly selected rectangular region *inside the foreground*.
    If no foreground is found, no noise is added.
    """
    rng = np.random.default_rng(seed)
    H, W = mask_hard.shape
    ys, xs = np.where(mask_hard > 0.5)
    if len(ys) == 0:
        return base

    y_min, y_max = ys.min(), ys.max()
    x_min, x_max = xs.min(), xs.max()

    # choose box size as a fraction (20-60%) of the FG bounding box width/height
    box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6))))
    box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6))))

    # random top-left so that box fits within FG bbox
    y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1))
    x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1))

    # slice
    region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w]
    if region_mask.size == 0:
        return base

    noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3))
    base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None]
    return base
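
# Scale note: `base` is float32 in [0..255], so sigma is in raw pixel units;
# with sigma=20 roughly 95% of the added noise values lie within +/-40
# (2 sigma) before the final clip in _np_to_pil.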


def dual_region_augment_dra(
    img: Image.Image,
    grid_n: int = 8,
    fg_noise_std: float = 20.0,
    seed: int = 0,
    model_type: str = "u2netp",
):
    """
    DRA:
    - Background: jigsaw-shuffle full image on an N×N grid (no overlap), then use only on background.
    - Foreground: add Gaussian noise to a single randomly selected rectangular box (inside FG).
    - Fusion: FG from noisy image, BG from jigsaw image, using a hard U^2-Net mask.
    """
    # Determinism: each step below derives its randomness from `seed`
    # (np.random.seed inside the jigsaw, np.random.default_rng for the noise box).

    base = _pil_to_np(img)  # (H, W, 3), float32 [0..255]
    H, W = base.shape[:2]

    # 1) Mask from local weights (no feather; use hard threshold at 0.5)
    raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type)
    mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5
    mask_hard = mask.astype(np.float32)  # (H,W) in {0,1}

    # 2) Foreground: noise in a random FG rectangle
    img_fg = base.copy()
    img_fg = add_noise_in_random_fg_box(img_fg, mask_hard, sigma=fg_noise_std, seed=seed)

    # 3) Background: jigsaw-shuffle full image on N×N grid
    jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed)
    img_bg = _pil_to_np(jig)

    # 4) Fusion: BG where mask==0, FG where mask==1
    m3 = np.repeat(mask_hard[:, :, None], 3, axis=2)
    out = img_bg * (1.0 - m3) + img_fg * m3

    return _np_to_pil(out), raw_mask_L
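
# Usage sketch (hypothetical file "cat.png"; weights assumed present):
#   aug, mask = dual_region_augment_dra(
#       Image.open("cat.png"), grid_n=4, fg_noise_std=20.0, seed=0
#   )
#   aug.save("cat_dra.png"); mask.save("cat_mask.png")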


# ---- Gradio UI ----
GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]

def parse_grid_choice(s: str) -> int:
    try:
        n = int(s.lower().split('x')[0])
        return max(2, min(16, n))
    except Exception:
        return 8
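
# e.g. parse_grid_choice("16x16") -> 16; malformed input falls back to 8.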


def run_demo(
    image,
    grid_choice,
    fg_noise_std,
    seed,
    model_type,
):
    if image is None:
        raise gr.Error("Please upload an image or pick one from the examples.")
    n = parse_grid_choice(grid_choice)
    out_img, mask_L = dual_region_augment_dra(
        image,
        grid_n=n,
        fg_noise_std=fg_noise_std,
        seed=seed,
        model_type=model_type,
    )
    return out_img, mask_L


def list_example_images():
    ex_dir = Path("examples")
    ex_dir.mkdir(exist_ok=True)
    files = [
        str(p) for p in sorted(ex_dir.iterdir())
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    ]
    return files if files else None


examples = list_example_images()

with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo:
    gr.Markdown(
        "### Dual-Region Augmentation (DRA)\n"
        "- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n"
        "- **Foreground**: Gaussian noise in the foreground region.\n"
        "- **Mask**: U²-Net / U²-NetP"
    )
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
            grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)")
            noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ")
            seed = gr.Slider(0, 9999, value=69, step=1, label="Seed")
            model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type")
            btn = gr.Button("Augment")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Augmented Output")
            out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)")

    btn.click(
        fn=run_demo,
        inputs=[in_img, grid_choice, noise_std, seed, model_type],
        outputs=[out_img, out_mask],
        concurrency_limit=3,
        api_name="augment",
    )

    if examples:
        gr.Examples(
            examples=examples,
            inputs=[in_img],
            examples_per_page=12,
            label="Examples (loaded from ./examples)"
        )

if __name__ == "__main__":
    demo.launch()
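    # Optional (standard Gradio launch parameters): to expose the demo on the
    # LAN, demo.launch(server_name="0.0.0.0", server_port=7860) also works.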