import random
from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

# --- Expect the user's U^2-Net code available as a local module folder "u2net" ---
try:
    from u2net.model import U2NET, U2NETP
except Exception as e:
    raise RuntimeError(
        "Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
        "next to app.py (containing model.py, data_loader.py, ...)."
    ) from e

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

# Cache for loaded models, keyed by model type
_MODEL_CACHE = {"u2net": None, "u2netp": None}

def _pil_to_np(img: Image.Image) -> np.ndarray:
    """PIL image -> float32 RGB array in [0..255]."""
    return np.array(img.convert("RGB"), dtype=np.float32)

def _np_to_pil(arr: np.ndarray) -> Image.Image:
    """Clip to [0..255] and convert back to an RGB PIL image."""
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")

def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Min-max normalize each map over its spatial (H, W) dimensions."""
    t_min = t.amin(dim=(-2, -1), keepdim=True)
    t_max = t.amax(dim=(-2, -1), keepdim=True)
    return (t - t_min) / (t_max - t_min + eps)
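
# A minimal sanity example of _minmax_norm on a toy 1xHxW tensor (values here
# are illustrative only, not model outputs):
#   t = torch.tensor([[[0.2, 0.4], [0.6, 1.0]]])
#   _minmax_norm(t)  # ~ tensor([[[0.00, 0.25], [0.50, 1.00]]])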

def _find_weight_file(model_type: str) -> Path:
    """
    model_type in {"u2net", "u2netp"}.
    Looks under saved_models/{model_type}/{model_type}.pth (preferred),
    falling back to the first *.pth in that subfolder.
    """
    base = Path("saved_models").expanduser().resolve()
    sub = base / model_type
    preferred = sub / f"{model_type}.pth"
    if preferred.exists():
        return preferred
    # fallback: first .pth in the subfolder
    if sub.exists():
        for p in sorted(sub.glob("*.pth")):
            return p
    raise FileNotFoundError(
        f"Could not find weights for '{model_type}'. Expected at '{preferred}' or any .pth in '{sub}'."
    )
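
# Expected on-disk layout (a sketch of the default paths searched above; any
# other *.pth dropped into the matching subfolder also works):
#   saved_models/
#     u2net/u2net.pth      # full model (~176 MB)
#     u2netp/u2netp.pth    # small model (~4.7 MB)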

def load_u2net(model_type: str = "u2netp"):
    """Load (and cache) a U^2-Net or U^2-NetP model from local weights."""
    assert model_type in {"u2net", "u2netp"}
    if _MODEL_CACHE.get(model_type) is not None:
        return _MODEL_CACHE[model_type]
    weights_path = _find_weight_file(model_type)
    net = U2NET(3, 1) if model_type == "u2net" else U2NETP(3, 1)
    state = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state)
    net.to(DEVICE)
    net.eval()
    _MODEL_CACHE[model_type] = net
    return net
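
# Optional warm-up at startup (a sketch), so the first request doesn't pay the
# weight-loading cost:
#   load_u2net("u2netp")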

def get_u2net_mask_with_local_weights(
    pil_img: Image.Image,
    model_type: str = "u2netp",
    resize_to: int = 320,
) -> Image.Image:
    """
    Single-image inference using local U^2-Net/U^2-NetP weights.
    Returns an 8-bit 'L' mask resized back to the original (W, H).
    """
    W, H = pil_img.size
    net = load_u2net(model_type)
    tr = transforms.Compose([
        transforms.Resize((resize_to, resize_to),
                          interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),  # [0,1], CxHxW
    ])
    x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)  # 1 x 3 x resize_to x resize_to
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(x)  # U^2-Net returns 7 side outputs
    pred = d1[:, 0, :, :]      # 1xHxW, finest side output
    pred = _minmax_norm(pred)  # min-max normalize per map
    pred_np = pred.squeeze(0).detach().cpu().numpy()  # HxW, [0..1]
    mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
    return mask_small.resize((W, H), resample=Image.BILINEAR)
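
# Example usage (hypothetical file name; requires weights under ./saved_models):
#   img = Image.open("examples/dog.jpg")
#   mask = get_u2net_mask_with_local_weights(img, model_type="u2netp")
#   mask.save("mask.png")  # 8-bit 'L' mask, same W x H as the input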

def jigsaw_shuffle_full_image(pil_img: Image.Image, N: int, seed: int) -> Image.Image:
    """
    Create a jigsaw-shuffled version of the input by splitting it into an N x N grid
    with *no overlap*, permuting the patches uniformly at random, and reassembling.
    To keep patch sizes uniform, the image is center-cropped to (H2, W2) divisible
    by N, and the shuffled crop is pasted back onto the original canvas.
    """
    random.seed(seed)
    np.random.seed(seed)
    pil_img = pil_img.convert("RGB")
    W, H = pil_img.size
    # largest crop divisible by N
    H2 = (H // N) * N
    W2 = (W // N) * N
    if H2 == 0 or W2 == 0:
        # image too small for an N x N grid; return it unchanged
        return pil_img
    # center-crop box
    y0 = (H - H2) // 2
    x0 = (W - W2) // 2
    crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2))
    arr = np.array(crop)
    out = np.empty_like(arr)
    ph = H2 // N
    pw = W2 // N
    # build patch coordinates (y_start, y_end, x_start, x_end)
    coords = []
    for i in range(N):
        for j in range(N):
            y1 = i * ph
            x1 = j * pw
            coords.append((y1, y1 + ph, x1, x1 + pw))
    perm = np.random.permutation(len(coords))
    for dst_idx, src_idx in enumerate(perm):
        yd0, yd1, xd0, xd1 = coords[dst_idx]
        ys0, ys1, xs0, xs1 = coords[src_idx]
        out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :]
    # paste the shuffled crop back onto the original canvas
    pad = np.array(pil_img)
    pad[y0:y0 + H2, x0:x0 + W2, :] = out
    return Image.fromarray(pad, mode="RGB")
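
# Quick property check (a sketch; size chosen divisible by N so no crop border
# is left untouched): shuffling only permutes patches, so the sorted pixel
# values of input and output must match.
#   img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), np.uint8))
#   shuf = jigsaw_shuffle_full_image(img, N=4, seed=0)
#   assert np.array_equal(np.sort(np.asarray(img), axis=None),
#                         np.sort(np.asarray(shuf), axis=None))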

def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray:
    """
    Add Gaussian noise (in place) within one randomly placed rectangle inside the
    foreground bounding box; the noise is further gated by the per-pixel mask.
    If no foreground is found, `base` is returned unchanged.
    """
    rng = np.random.default_rng(seed)
    ys, xs = np.where(mask_hard > 0.5)
    if len(ys) == 0:
        return base
    y_min, y_max = ys.min(), ys.max()
    x_min, x_max = xs.min(), xs.max()
    # box size: 20%..60% of the FG bbox height/width
    box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6))))
    box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6))))
    # random top-left so the box fits inside the FG bbox
    y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1))
    x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1))
    region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w]
    if region_mask.size == 0:
        return base
    noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3))
    base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None]
    return base
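
# Minimal sketch of the expected inputs (float32 image in [0..255], binary
# float mask); noise lands only inside one random box within rows/cols 20..79:
#   img = np.zeros((100, 100, 3), np.float32)
#   fg = np.zeros((100, 100), np.float32); fg[20:80, 20:80] = 1.0
#   out = add_noise_in_random_fg_box(img, fg, sigma=20.0, seed=0)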

def dual_region_augment_dra(
    img: Image.Image,
    grid_n: int = 8,
    fg_noise_std: float = 20.0,
    seed: int = 0,
    model_type: str = "u2netp",
):
    """
    DRA:
      - Background: jigsaw-shuffle the full image on an N x N grid (no overlap);
        only the background region of the result is used.
      - Foreground: add Gaussian noise to a single randomly placed rectangle
        inside the foreground.
      - Fusion: FG pixels from the noisy image, BG pixels from the jigsaw image,
        combined via a hard U^2-Net mask.
    """
    random.seed(seed)
    np.random.seed(seed)
    base = _pil_to_np(img)  # (H, W, 3), float32 in [0..255]
    # 1) Mask from local weights (no feathering; hard threshold at 0.5)
    raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type)
    mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5
    mask_hard = mask.astype(np.float32)  # (H, W) in {0, 1}
    # 2) Foreground: noise in one random FG rectangle
    img_fg = add_noise_in_random_fg_box(base.copy(), mask_hard, sigma=fg_noise_std, seed=seed)
    # 3) Background: jigsaw-shuffle the full image on an N x N grid
    jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed)
    img_bg = _pil_to_np(jig)
    # 4) Fusion: BG where mask == 0, FG where mask == 1
    m3 = np.repeat(mask_hard[:, :, None], 3, axis=2)
    out = img_bg * (1.0 - m3) + img_fg * m3
    return _np_to_pil(out), raw_mask_L
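
# End-to-end example (hypothetical file name; needs local weights):
#   aug, mask = dual_region_augment_dra(
#       Image.open("examples/cat.jpg"), grid_n=8, fg_noise_std=20.0, seed=0
#   )
#   aug.save("augmented.png"); mask.save("mask.png")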

# ---- Gradio UI ----
GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]

def parse_grid_choice(s: str) -> int:
    """Parse an 'NxN' choice string; clamp N to [2, 16], falling back to 8."""
    try:
        n = int(s.lower().split("x")[0])
        return max(2, min(16, n))
    except Exception:
        return 8
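
# e.g. parse_grid_choice("8x8") -> 8, parse_grid_choice("16x16") -> 16,
# and malformed input falls back to 8.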

def run_demo(image, grid_choice, fg_noise_std, seed, model_type):
    if image is None:
        raise gr.Error("Please upload an image or pick one from the examples.")
    n = parse_grid_choice(grid_choice)
    out_img, mask_L = dual_region_augment_dra(
        image,
        grid_n=n,
        fg_noise_std=float(fg_noise_std),
        seed=int(seed),  # sliders may deliver floats; NumPy RNG seeding needs an int
        model_type=model_type,
    )
    return out_img, mask_L

def list_example_images():
    ex_dir = Path("examples")
    ex_dir.mkdir(exist_ok=True)
    files = [
        str(p) for p in sorted(ex_dir.iterdir())
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    ]
    return files if files else None

examples = list_example_images()

with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo:
    gr.Markdown(
        "### Dual-Region Augmentation (DRA)\n"
        "- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n"
        "- **Foreground**: Gaussian noise in the foreground region.\n"
        "- **Mask**: U²-Net / U²-NetP"
    )
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
            grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)")
            noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ")
            seed = gr.Slider(0, 9999, value=69, step=1, label="Seed")
            model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type")
            btn = gr.Button("Augment")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Augmented Output")
            out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)")
    btn.click(
        fn=run_demo,
        inputs=[in_img, grid_choice, noise_std, seed, model_type],
        outputs=[out_img, out_mask],
        concurrency_limit=3,
        api_name="augment",
    )
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[in_img],
            examples_per_page=12,
            label="Examples (loaded from ./examples)",
        )

if __name__ == "__main__":
    demo.launch()