# DR-Augmentation / app.py
import random
from pathlib import Path
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
# --- Expects the U^2-Net codebase to be available as a local module folder "u2net"
try:
from u2net.model import U2NET, U2NETP
except Exception as e:
raise RuntimeError(
"Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
"next to app.py (containing model.py, data_loader.py, ...)."
) from e
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
# Cache for loaded models
_MODEL_CACHE = {"u2net": None, "u2netp": None}
def _pil_to_np(img: Image.Image) -> np.ndarray:
return np.array(img.convert("RGB"), dtype=np.float32)
def _np_to_pil(arr: np.ndarray) -> Image.Image:
return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")
def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
t_min = t.amin(dim=(-2, -1), keepdim=True)
t_max = t.amax(dim=(-2, -1), keepdim=True)
return (t - t_min) / (t_max - t_min + eps)
def _find_weight_file(model_type: str) -> Path:
"""
model_type in {"u2net", "u2netp"}.
Looks under: saved_models/{model_type}/{model_type}.pth (preferred)
or first *.pth under that subfolder.
"""
base = Path("saved_models").expanduser().resolve()
sub = base / model_type
preferred = sub / f"{model_type}.pth"
if preferred.exists():
return preferred
# fallback: first .pth in subdir
if sub.exists():
for p in sorted(sub.glob("*.pth")):
return p
raise FileNotFoundError(
f"Could not find weights for '{model_type}'. Expected at '{preferred}' or any .pth in '{sub}'."
)
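
# Expected weights layout (a sketch of this app's convention; if the preferred
# filename is absent, the first *.pth found in the subfolder is used instead):
#
#   saved_models/
#     u2net/u2net.pth      # full model (~176 MB in the upstream U^2-Net release)
#     u2netp/u2netp.pth    # lightweight model (~4.7 MB)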
def load_u2net(model_type: str = "u2netp"):
assert model_type in {"u2net", "u2netp"}
if _MODEL_CACHE.get(model_type) is not None:
return _MODEL_CACHE[model_type]
weights_path = _find_weight_file(model_type)
if model_type == "u2net":
net = U2NET(3, 1)
else:
net = U2NETP(3, 1)
state = torch.load(weights_path, map_location="cpu")
net.load_state_dict(state)
net.to(DEVICE)
net.eval()
_MODEL_CACHE[model_type] = net
return net
def get_u2net_mask_with_local_weights(
pil_img: Image.Image,
model_type: str = "u2netp",
resize_to: int = 320,
) -> Image.Image:
"""
Single-image inference using user's local U^2-Net/U^2-NetP weights.
Returns 8-bit 'L' mask resized back to original W,H.
"""
W, H = pil_img.size
net = load_u2net(model_type)
    tr = transforms.Compose([
        transforms.Resize((resize_to, resize_to), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),  # [0, 1], CxHxW
    ])
    x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)  # 1 x 3 x resize_to x resize_to
with torch.no_grad():
        d1, *_ = net(x)  # d1 is the fused final output; d2..d7 are side outputs (unused)
pred = d1[:, 0, :, :] # 1xHxW
pred = _minmax_norm(pred) # min-max normalize per-batch
pred_np = pred.squeeze(0).detach().cpu().numpy() # HxW, [0..1]
mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
mask = mask_small.resize((W, H), resample=Image.BILINEAR)
return mask
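
# Minimal usage sketch (assumes weights exist under ./saved_models and a
# hypothetical image at "examples/sample.jpg"):
#
#   img = Image.open("examples/sample.jpg")
#   mask = get_u2net_mask_with_local_weights(img, model_type="u2netp")
#   mask.save("mask.png")  # 8-bit grayscale, same W x H as the input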
def jigsaw_shuffle_full_image(pil_img: Image.Image, N: int, seed: int) -> Image.Image:
"""
Create a jigsaw-shuffled version of the input image by splitting into an N×N grid
with *no overlap*, permuting patches uniformly at random, and reassembling.
To keep uniform patch sizes, we center-crop the image to (H2,W2) divisible by N,
then paste back to the original canvas.
"""
    random.seed(seed)
    np.random.seed(seed)
    pil_img = pil_img.convert("RGB")  # guard against non-RGB inputs (e.g. RGBA, L)
    W, H = pil_img.size
    # largest crop whose sides are divisible by N
    H2 = (H // N) * N
    W2 = (W // N) * N
    if H2 == 0 or W2 == 0:
        # image smaller than the grid; return it unchanged
        return pil_img
    # center-crop box
    y0 = (H - H2) // 2
    x0 = (W - W2) // 2
    crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2))
    arr = np.array(crop)
out = np.empty_like(arr)
ph = H2 // N
pw = W2 // N
# build coordinates
coords = []
for i in range(N):
for j in range(N):
y1 = i * ph
x1 = j * pw
coords.append((y1, y1 + ph, x1, x1 + pw))
perm = np.random.permutation(len(coords))
for dst_idx, src_idx in enumerate(perm):
yd0, yd1, xd0, xd1 = coords[dst_idx]
ys0, ys1, xs0, xs1 = coords[src_idx]
out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :]
    # paste the shuffled crop back onto the original canvas
    pad = np.array(pil_img)
    pad[y0:y0 + H2, x0:x0 + W2, :] = out
    return Image.fromarray(pad, mode="RGB")
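
# Worked example of the grid bookkeeping above (illustrative only): for a
# 100x100 image with N=3, H2 = W2 = 99, so a 99x99 center crop is split into
# nine 33x33 patches; the single row/column of pixels left outside the crop is
# untouched when the shuffled crop is pasted back onto the canvas.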
def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray:
"""
Add Gaussian noise only within a randomly selected rectangular region *inside the foreground*.
If no foreground is found, no noise is added.
"""
rng = np.random.default_rng(seed)
H, W = mask_hard.shape
ys, xs = np.where(mask_hard > 0.5)
if len(ys) == 0:
return base
y_min, y_max = ys.min(), ys.max()
x_min, x_max = xs.min(), xs.max()
    # box size: a random fraction (20%-60%) of the FG bounding-box height/width
box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6))))
box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6))))
# random top-left so that box fits within FG bbox
y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1))
x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1))
# slice
region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w]
if region_mask.size == 0:
return base
noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3))
base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None]
return base
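
# Note on value ranges: `base` is float32 in [0, 255], so the added noise may
# push pixels outside that range; clipping is deferred to _np_to_pil at the end
# of the pipeline. sigma is therefore in 8-bit intensity units (sigma=20 is
# roughly 8% of full scale).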
def dual_region_augment_dra(
img: Image.Image,
grid_n: int = 8,
fg_noise_std: float = 20.0,
seed: int = 0,
model_type: str = "u2netp",
):
"""
DRA:
- Background: jigsaw-shuffle full image on an N×N grid (no overlap), then use only on background.
- Foreground: add Gaussian noise to a single randomly selected rectangular box (inside FG).
- Fusion: FG from noisy image, BG from jigsaw image, using a hard U^2-Net mask.
"""
random.seed(seed)
np.random.seed(seed)
base = _pil_to_np(img) # (H, W, 3), float32 [0..255]
H, W = base.shape[:2]
# 1) Mask from local weights (no feather; use hard threshold at 0.5)
raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type)
mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5
mask_hard = mask.astype(np.float32) # (H,W) in {0,1}
# 2) Foreground: noise in a random FG rectangle
img_fg = base.copy()
img_fg = add_noise_in_random_fg_box(img_fg, mask_hard, sigma=fg_noise_std, seed=seed)
# 3) Background: jigsaw-shuffle full image on N×N grid
jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed)
img_bg = _pil_to_np(jig)
# 4) Fusion: BG where mask==0, FG where mask==1
m3 = np.repeat(mask_hard[:, :, None], 3, axis=2)
out = img_bg * (1.0 - m3) + img_fg * m3
return _np_to_pil(out), raw_mask_L
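
# End-to-end sketch without the UI (assumes local weights and a hypothetical
# input file "examples/sample.jpg"):
#
#   img = Image.open("examples/sample.jpg")
#   aug, mask = dual_region_augment_dra(img, grid_n=8, fg_noise_std=20.0, seed=0)
#   aug.save("augmented.png")
#   mask.save("mask_preview.png")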
# ---- Gradio UI ----
GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]
def parse_grid_choice(s: str) -> int:
try:
n = int(s.lower().split('x')[0])
return max(2, min(16, n))
except Exception:
return 8
def run_demo(
image,
grid_choice,
fg_noise_std,
seed,
model_type,
):
if image is None:
raise gr.Error("Please upload an image or pick one from the examples.")
n = parse_grid_choice(grid_choice)
out_img, mask_L = dual_region_augment_dra(
image,
grid_n=n,
fg_noise_std=fg_noise_std,
seed=seed,
model_type=model_type,
)
return out_img, mask_L
def list_example_images():
ex_dir = Path("examples")
ex_dir.mkdir(exist_ok=True)
files = [
str(p) for p in sorted(ex_dir.iterdir())
if p.suffix.lower() in IMG_EXTS and p.is_file()
]
return files if files else None
examples = list_example_images()
with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo:
gr.Markdown(
"### Dual-Region Augmentation (DRA)\n"
"- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n"
"- **Foreground**: Gaussian noise in the foreground region.\n"
"- **Mask**: U²-Net / U²-NetP"
)
with gr.Row():
with gr.Column():
in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)")
noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ")
seed = gr.Slider(0, 9999, value=69, step=1, label="Seed")
model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type")
btn = gr.Button("Augment")
with gr.Column():
out_img = gr.Image(type="pil", label="Augmented Output")
out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)")
btn.click(
fn=run_demo,
inputs=[in_img, grid_choice, noise_std, seed, model_type],
outputs=[out_img, out_mask],
concurrency_limit=3,
api_name="augment",
)
if examples:
gr.Examples(
examples=examples,
inputs=[in_img],
examples_per_page=12,
label="Examples (loaded from ./examples)"
)
if __name__ == "__main__":
demo.launch()