import random
from pathlib import Path
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
# --- Expect the user's u2net codebase available as a local module folder "u2net"
try:
    from u2net.model import U2NET, U2NETP
except Exception as e:
    raise RuntimeError(
        "Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
        "next to app.py (containing model.py, data_loader.py, ...)."
    ) from e

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
# Cache for loaded models
_MODEL_CACHE = {"u2net": None, "u2netp": None}


def _pil_to_np(img: Image.Image) -> np.ndarray:
    return np.array(img.convert("RGB"), dtype=np.float32)


def _np_to_pil(arr: np.ndarray) -> Image.Image:
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")


def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    t_min = t.amin(dim=(-2, -1), keepdim=True)
    t_max = t.amax(dim=(-2, -1), keepdim=True)
    return (t - t_min) / (t_max - t_min + eps)
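# Example for _minmax_norm: a map whose values span [2.0, 6.0] is rescaled to
# approximately [0, 1] per map; eps keeps a constant-valued map from dividing
# by zero.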


def _find_weight_file(model_type: str) -> Path:
    """
    model_type in {"u2net", "u2netp"}.
    Looks under: saved_models/{model_type}/{model_type}.pth (preferred)
    or the first *.pth under that subfolder.
    """
    base = Path("saved_models").expanduser().resolve()
    sub = base / model_type
    preferred = sub / f"{model_type}.pth"
    if preferred.exists():
        return preferred
    # fallback: first .pth in subdir
    if sub.exists():
        for p in sorted(sub.glob("*.pth")):
            return p
    raise FileNotFoundError(
        f"Could not find weights for '{model_type}'. Expected at '{preferred}' or any .pth in '{sub}'."
    )
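# Expected on-disk layout for _find_weight_file (the .pth filenames shown are
# the preferred defaults; any *.pth in the subfolder is accepted as fallback):
#   saved_models/
#       u2net/u2net.pth
#       u2netp/u2netp.pth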


def load_u2net(model_type: str = "u2netp"):
    assert model_type in {"u2net", "u2netp"}
    if _MODEL_CACHE.get(model_type) is not None:
        return _MODEL_CACHE[model_type]
    weights_path = _find_weight_file(model_type)
    if model_type == "u2net":
        net = U2NET(3, 1)
    else:
        net = U2NETP(3, 1)
    state = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state)
    net.to(DEVICE)
    net.eval()
    _MODEL_CACHE[model_type] = net
    return net
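# Usage sketch for load_u2net (assumes weights exist under ./saved_models):
#   net = load_u2net("u2netp")  # first call loads and caches; later calls reuse it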


def get_u2net_mask_with_local_weights(
    pil_img: Image.Image,
    model_type: str = "u2netp",
    resize_to: int = 320,
) -> Image.Image:
    """
    Single-image inference using the user's local U^2-Net/U^2-NetP weights.
    Returns an 8-bit 'L' mask resized back to the original (W, H).
    """
    W, H = pil_img.size
    net = load_u2net(model_type)
    tr = transforms.Compose([
        transforms.Resize((resize_to, resize_to),
                          interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),  # [0,1], CxHxW
    ])
    x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)  # 1x3x320x320 with the default resize_to
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(x)
        pred = d1[:, 0, :, :]      # 1xHxW
        pred = _minmax_norm(pred)  # min-max normalize per map
    pred_np = pred.squeeze(0).detach().cpu().numpy()  # HxW, [0..1]
    mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
    mask = mask_small.resize((W, H), resample=Image.BILINEAR)
    return mask
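# Usage sketch for get_u2net_mask_with_local_weights (hypothetical file path):
#   img = Image.open("examples/dog.jpg")
#   mask = get_u2net_mask_with_local_weights(img, model_type="u2netp")
#   mask.save("dog_mask.png")  # 8-bit grayscale, same W,H as the input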


def jigsaw_shuffle_full_image(pil_img: Image.Image, N: int, seed: int) -> Image.Image:
    """
    Create a jigsaw-shuffled version of the input image by splitting it into an N×N grid
    with *no overlap*, permuting the patches uniformly at random, and reassembling.
    To keep uniform patch sizes, we center-crop the image to (H2, W2) divisible by N,
    then paste the shuffled crop back onto the original canvas.
    """
    random.seed(seed)
    np.random.seed(seed)
    pil_img = pil_img.convert("RGB")  # ensure 3 channels before any array work
    W, H = pil_img.size
    # compute a crop whose sides are divisible by N
    H2 = (H // N) * N
    W2 = (W // N) * N
    if H2 == 0 or W2 == 0:
        # too small; just return the original
        return pil_img
    # center-crop box
    y0 = (H - H2) // 2
    x0 = (W - W2) // 2
    crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2))
    arr = np.array(crop)
    out = np.empty_like(arr)
    ph = H2 // N
    pw = W2 // N
    # build patch coordinates (top, bottom, left, right)
    coords = []
    for i in range(N):
        for j in range(N):
            y1 = i * ph
            x1 = j * pw
            coords.append((y1, y1 + ph, x1, x1 + pw))
    perm = np.random.permutation(len(coords))
    for dst_idx, src_idx in enumerate(perm):
        yd0, yd1, xd0, xd1 = coords[dst_idx]
        ys0, ys1, xs0, xs1 = coords[src_idx]
        out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :]
    # paste the shuffled crop back onto the original canvas
    pad = np.array(pil_img)
    pad[y0:y0 + H2, x0:x0 + W2, :] = out
    return Image.fromarray(pad, mode="RGB")
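# Worked example for jigsaw_shuffle_full_image: a 513x511 (WxH) input with N=8
# is center-cropped to 512x504, shuffled as 64 patches of 64x63 (WxH) each,
# and the thin border strips left by the crop stay unshuffled.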


def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray:
    """
    Add Gaussian noise only within a randomly selected rectangular region *inside the foreground*.
    If no foreground is found, no noise is added.
    """
    rng = np.random.default_rng(seed)
    H, W = mask_hard.shape
    ys, xs = np.where(mask_hard > 0.5)
    if len(ys) == 0:
        return base
    y_min, y_max = ys.min(), ys.max()
    x_min, x_max = xs.min(), xs.max()
    # choose box size as a fraction of the FG bbox (20% ~ 60% of height/width)
    box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6))))
    box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6))))
    # random top-left so that the box fits within the FG bbox
    y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1))
    x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1))
    # mask the noise so only foreground pixels inside the box are perturbed
    region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w]
    if region_mask.size == 0:
        return base
    noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3))
    base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None]
    return base
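# Worked example for add_noise_in_random_fg_box: with a foreground bbox of
# 200x100 pixels (HxW), box_h is drawn from [40, 120], box_w from [20, 60],
# and the top-left corner is placed uniformly so the box stays inside the bbox.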


def dual_region_augment_dra(
    img: Image.Image,
    grid_n: int = 8,
    fg_noise_std: float = 20.0,
    seed: int = 0,
    model_type: str = "u2netp",
):
    """
    DRA:
    - Background: jigsaw-shuffle the full image on an N×N grid (no overlap), then use it only in the background.
    - Foreground: add Gaussian noise to a single randomly selected rectangular box (inside the FG).
    - Fusion: FG from the noisy image, BG from the jigsaw image, using a hard U^2-Net mask.
    """
    random.seed(seed)
    np.random.seed(seed)
    base = _pil_to_np(img)  # (H, W, 3), float32 [0..255]
    H, W = base.shape[:2]
    # 1) Mask from local weights (no feathering; hard threshold at 0.5)
    raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type)
    mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5
    mask_hard = mask.astype(np.float32)  # (H, W) in {0, 1}
    # 2) Foreground: noise in a random FG rectangle
    img_fg = base.copy()
    img_fg = add_noise_in_random_fg_box(img_fg, mask_hard, sigma=fg_noise_std, seed=seed)
    # 3) Background: jigsaw-shuffle the full image on an N×N grid
    jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed)
    img_bg = _pil_to_np(jig)
    # 4) Fusion: BG where mask==0, FG where mask==1
    m3 = np.repeat(mask_hard[:, :, None], 3, axis=2)
    out = img_bg * (1.0 - m3) + img_fg * m3
    return _np_to_pil(out), raw_mask_L
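# Usage sketch for dual_region_augment_dra (hypothetical file path):
#   aug, mask = dual_region_augment_dra(Image.open("examples/cat.png"),
#                                       grid_n=8, fg_noise_std=20.0, seed=0)
#   aug.save("cat_dra.png")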
# ---- Gradio UI ----
GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]


def parse_grid_choice(s: str) -> int:
    try:
        n = int(s.lower().split('x')[0])
        return max(2, min(16, n))
    except Exception:
        return 8
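# Examples: parse_grid_choice("8x8") -> 8; parse_grid_choice("32x32") -> 16
# (clamped to [2, 16]); parse_grid_choice("not-a-grid") -> 8 (fallback).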


def run_demo(
    image,
    grid_choice,
    fg_noise_std,
    seed,
    model_type,
):
    if image is None:
        raise gr.Error("Please upload an image or pick one from the examples.")
    n = parse_grid_choice(grid_choice)
    out_img, mask_L = dual_region_augment_dra(
        image,
        grid_n=n,
        fg_noise_std=fg_noise_std,
        seed=seed,
        model_type=model_type,
    )
    return out_img, mask_L


def list_example_images():
    ex_dir = Path("examples")
    ex_dir.mkdir(exist_ok=True)
    files = [
        str(p) for p in sorted(ex_dir.iterdir())
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    ]
    return files if files else None

examples = list_example_images()

with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo:
    gr.Markdown(
        "### Dual-Region Augmentation (DRA)\n"
        "- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n"
        "- **Foreground**: Gaussian noise in the foreground region.\n"
        "- **Mask**: U²-Net / U²-NetP"
    )
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
            grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)")
            noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ")
            seed = gr.Slider(0, 9999, value=69, step=1, label="Seed")
            model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type")
            btn = gr.Button("Augment")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Augmented Output")
            out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)")
    btn.click(
        fn=run_demo,
        inputs=[in_img, grid_choice, noise_std, seed, model_type],
        outputs=[out_img, out_mask],
        concurrency_limit=3,
        api_name="augment",
    )
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[in_img],
            examples_per_page=12,
            label="Examples (loaded from ./examples)",
        )

if __name__ == "__main__":
    demo.launch()
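# Note: the bare launch() is sufficient on Hugging Face Spaces; for serving on
# a LAN, one option is demo.launch(server_name="0.0.0.0") (a deployment
# assumption, not something this app requires).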