from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

# --- Expect the user's U^2-Net codebase available as a local module folder "u2net"
try:
    from u2net.model import U2NET, U2NETP
except Exception as e:
    raise RuntimeError(
        "Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
        "next to app.py (containing model.py, data_loader.py, ...)."
    ) from e

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

# Cache for loaded models so each network is built and loaded at most once.
_MODEL_CACHE = {"u2net": None, "u2netp": None}


def _pil_to_np(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to a float32 RGB array in [0, 255]."""
    return np.array(img.convert("RGB"), dtype=np.float32)


def _np_to_pil(arr: np.ndarray) -> Image.Image:
    """Clip a float array to [0, 255] and convert back to an 8-bit RGB image."""
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")


def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Min-max normalize each map over its spatial dimensions to [0, 1]."""
    t_min = t.amin(dim=(-2, -1), keepdim=True)
    t_max = t.amax(dim=(-2, -1), keepdim=True)
    return (t - t_min) / (t_max - t_min + eps)


def _find_weight_file(model_type: str) -> Path:
    """
    model_type in {"u2net", "u2netp"}.
    Looks under saved_models/{model_type}/{model_type}.pth (preferred),
    falling back to the first *.pth found in that subfolder.
    """
    base = Path("saved_models").expanduser().resolve()
    sub = base / model_type
    preferred = sub / f"{model_type}.pth"
    if preferred.exists():
        return preferred
    # Fallback: first .pth in the subdirectory.
    if sub.exists():
        for p in sorted(sub.glob("*.pth")):
            return p
    raise FileNotFoundError(
        f"Could not find weights for '{model_type}'. "
        f"Expected '{preferred}' or any .pth in '{sub}'."
    )


def load_u2net(model_type: str = "u2netp"):
    """Build the requested network, load the local weights, and cache the result."""
    assert model_type in {"u2net", "u2netp"}
    if _MODEL_CACHE.get(model_type) is not None:
        return _MODEL_CACHE[model_type]
    weights_path = _find_weight_file(model_type)
    net = U2NET(3, 1) if model_type == "u2net" else U2NETP(3, 1)
    state = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state)
    net.to(DEVICE)
    net.eval()
    _MODEL_CACHE[model_type] = net
    return net


def get_u2net_mask_with_local_weights(
    pil_img: Image.Image,
    model_type: str = "u2netp",
    resize_to: int = 320,
) -> Image.Image:
    """
    Single-image inference using the user's local U^2-Net / U^2-NetP weights.
    Returns an 8-bit 'L' mask resized back to the original (W, H).
    """
    W, H = pil_img.size
    net = load_u2net(model_type)
    tr = transforms.Compose([
        transforms.Resize(
            (resize_to, resize_to),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),  # [0, 1], CxHxW
    ])
    x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)  # 1x3xRxR, R = resize_to
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(x)
        pred = d1[:, 0, :, :]      # fused side output, 1xRxR
        pred = _minmax_norm(pred)  # min-max normalize per image
    pred_np = pred.squeeze(0).detach().cpu().numpy()  # RxR, values in [0, 1]
    mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
    return mask_small.resize((W, H), resample=Image.BILINEAR)
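
# Optional sanity check for the segmentation path. A minimal sketch, not wired
# into the Gradio app: it assumes weights are already present under
# ./saved_models/, and the solid-gray 257x131 probe image is an arbitrary
# choice to exercise a non-square input that is not a multiple of 320.
def _mask_smoke_test(model_type: str = "u2netp") -> bool:
    """Run one forward pass on a synthetic image and check the mask's mode/size."""
    probe = Image.new("RGB", (257, 131), color=(128, 128, 128))
    mask = get_u2net_mask_with_local_weights(probe, model_type=model_type)
    return mask.mode == "L" and mask.size == probe.size
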
""" random.seed(seed) np.random.seed(seed) W, H = pil_img.size # compute crop that is divisible by N H2 = (H // N) * N W2 = (W // N) * N pad_canvas = Image.fromarray(np.array(pil_img), mode="RGB") if H2 == 0 or W2 == 0: # too small; just return original return pil_img # center crop box y0 = (H - H2) // 2 x0 = (W - W2) // 2 crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2)) arr = np.array(crop).copy() out = np.empty_like(arr) ph = H2 // N pw = W2 // N # build coordinates coords = [] for i in range(N): for j in range(N): y1 = i * ph x1 = j * pw coords.append((y1, y1 + ph, x1, x1 + pw)) perm = np.random.permutation(len(coords)) for dst_idx, src_idx in enumerate(perm): yd0, yd1, xd0, xd1 = coords[dst_idx] ys0, ys1, xs0, xs1 = coords[src_idx] out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :] # paste back into original canvas pad = np.array(pad_canvas) pad[y0:y0 + H2, x0:x0 + W2, :] = out return Image.fromarray(pad.astype(np.uint8), mode="RGB") def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray: """ Add Gaussian noise only within a randomly selected rectangular region *inside the foreground*. If no foreground is found, no noise is added. """ rng = np.random.default_rng(seed) H, W = mask_hard.shape ys, xs = np.where(mask_hard > 0.5) if len(ys) == 0: return base y_min, y_max = ys.min(), ys.max() x_min, x_max = xs.min(), xs.max() # choose box size as a fraction of the FG bbox (20% ~ 60% of width/height) box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6)))) box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6)))) # random top-left so that box fits within FG bbox y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1)) x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1)) # slice region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w] if region_mask.size == 0: return base noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3)) base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None] return base def dual_region_augment_dra( img: Image.Image, grid_n: int = 8, fg_noise_std: float = 20.0, seed: int = 0, model_type: str = "u2netp", ): """ DRA: - Background: jigsaw-shuffle full image on an N×N grid (no overlap), then use only on background. - Foreground: add Gaussian noise to a single randomly selected rectangular box (inside FG). - Fusion: FG from noisy image, BG from jigsaw image, using a hard U^2-Net mask. 
""" random.seed(seed) np.random.seed(seed) base = _pil_to_np(img) # (H, W, 3), float32 [0..255] H, W = base.shape[:2] # 1) Mask from local weights (no feather; use hard threshold at 0.5) raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type) mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5 mask_hard = mask.astype(np.float32) # (H,W) in {0,1} # 2) Foreground: noise in a random FG rectangle img_fg = base.copy() img_fg = add_noise_in_random_fg_box(img_fg, mask_hard, sigma=fg_noise_std, seed=seed) # 3) Background: jigsaw-shuffle full image on N×N grid jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed) img_bg = _pil_to_np(jig) # 4) Fusion: BG where mask==0, FG where mask==1 m3 = np.repeat(mask_hard[:, :, None], 3, axis=2) out = img_bg * (1.0 - m3) + img_fg * m3 return _np_to_pil(out), raw_mask_L # ---- Gradio UI ---- GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"] def parse_grid_choice(s: str) -> int: try: n = int(s.lower().split('x')[0]) return max(2, min(16, n)) except Exception: return 8 def run_demo( image, grid_choice, fg_noise_std, seed, model_type, ): if image is None: raise gr.Error("Please upload an image or pick one from the examples.") n = parse_grid_choice(grid_choice) out_img, mask_L = dual_region_augment_dra( image, grid_n=n, fg_noise_std=fg_noise_std, seed=seed, model_type=model_type, ) return out_img, mask_L def list_example_images(): ex_dir = Path("examples") ex_dir.mkdir(exist_ok=True) files = [ str(p) for p in sorted(ex_dir.iterdir()) if p.suffix.lower() in IMG_EXTS and p.is_file() ] return files if files else None examples = list_example_images() with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo: gr.Markdown( "### Dual-Region Augmentation (DRA)\n" "- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n" "- **Foreground**: Gaussian noise in the foreground region.\n" "- **Mask**: U²-Net / U²-NetP" ) with gr.Row(): with gr.Column(): in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]) grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)") noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ") seed = gr.Slider(0, 9999, value=69, step=1, label="Seed") model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type") btn = gr.Button("Augment") with gr.Column(): out_img = gr.Image(type="pil", label="Augmented Output") out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)") btn.click( fn=run_demo, inputs=[in_img, grid_choice, noise_std, seed, model_type], outputs=[out_img, out_mask], concurrency_limit=3, api_name="augment", ) if examples: gr.Examples( examples=examples, inputs=[in_img], examples_per_page=12, label="Examples (loaded from ./examples)" ) if __name__ == "__main__": demo.launch()