prasannareddyp committed (verified)
Commit 941ee5b · 1 Parent(s): 910ea5a

Upload 10 files

LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Prasanna Reddy Pulakurthi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,25 @@
1
  ---
2
- title: DR Augmentation
3
- emoji: 👁
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: 'Dual-Region Foreground-Background Augmentation '
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Dual-Region Augmentation (DRA, Local U²-Net)
3
+ emoji: ✂️
4
+ colorFrom: gray
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.22.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: An interactive demo for DRA augmentation.
12
  ---
13
 
14
+ ### Dual-Region Foreground-Background Augmentation (DRA)
15
+
16
+ An interactive demo for DRA augmentation.
17
+
18
+ GitHub repo: https://github.com/PrasannaPulakurthi/Foreground-Background-Augmentation
19
+
20
+ Paper: https://arxiv.org/abs/2504.13077
21
+
22
+ - **Grid control**: choose number of patches (2×2, 4×4, 8×8, 16×16). No overlap.
23
+ - **Background**: jigsaw shuffle on the full image, fused only into background via U²‑Net mask.
24
+ - **Foreground**: Gaussian noise applied to **one random rectangular box** inside the foreground.
25
+ - **Weights**: loads local weights from `saved_models/u2net/u2net.pth` or `saved_models/u2netp/u2netp.pth` (see the usage sketch below).
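
Below is a minimal, illustrative sketch of running DRA programmatically by importing `dual_region_augment_dra` from this Space's `app.py`. It assumes the `u2net` module and the local weights are in place; `examples/example_1.jpg` is the sample image shipped with the Space.

```python
# Illustrative sketch (not part of the uploaded files): apply DRA to one image
# using the function defined in app.py. Importing app builds the Gradio UI but
# does not launch it.
from PIL import Image
from app import dual_region_augment_dra

img = Image.open("examples/example_1.jpg").convert("RGB")
augmented, mask = dual_region_augment_dra(
    img,
    grid_n=8,             # 8x8 jigsaw grid for the background
    fg_noise_std=50.0,    # sigma of the Gaussian noise box in the foreground
    seed=69,
    model_type="u2netp",  # or "u2net" if the full model weights are present
)
augmented.save("augmented.png")
mask.save("mask_preview.png")
```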
app.py ADDED
@@ -0,0 +1,309 @@
1
+ import io, os, random
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+ # --- Expect the user's u2net codebase available as a local module folder "u2net"
12
+ try:
13
+ from u2net.model import U2NET, U2NETP
14
+ except Exception as e:
15
+ raise RuntimeError(
16
+ "Could not import 'u2net'. Please place the U^2-Net code folder named 'u2net' "
17
+ "next to app.py (containing model.py, data_loader.py, ...)."
18
+ ) from e
19
+
20
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
22
+
23
+ # Cache for loaded models
24
+ _MODEL_CACHE = {"u2net": None, "u2netp": None}
25
+
26
+
27
+ def _pil_to_np(img: Image.Image) -> np.ndarray:
28
+ return np.array(img.convert("RGB"), dtype=np.float32)
29
+
30
+
31
+ def _np_to_pil(arr: np.ndarray) -> Image.Image:
32
+ return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8), mode="RGB")
33
+
34
+
35
+ def _minmax_norm(t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
36
+ t_min = t.amin(dim=(-2, -1), keepdim=True)
37
+ t_max = t.amax(dim=(-2, -1), keepdim=True)
38
+ return (t - t_min) / (t_max - t_min + eps)
39
+
40
+
41
+ def _find_weight_file(model_type: str) -> Path:
42
+ """
43
+ model_type in {"u2net", "u2netp"}.
44
+ Looks under: saved_models/{model_type}/{model_type}.pth (preferred)
45
+ or first *.pth under that subfolder.
46
+ """
47
+ base = Path("saved_models").expanduser().resolve()
48
+ sub = base / model_type
49
+ preferred = sub / f"{model_type}.pth"
50
+ if preferred.exists():
51
+ return preferred
52
+ # fallback: first .pth in subdir
53
+ if sub.exists():
54
+ for p in sorted(sub.glob("*.pth")):
55
+ return p
56
+ raise FileNotFoundError(
57
+ f"Could not find weights for '{model_type}'. Expected at '{preferred}' or any .pth in '{sub}'."
58
+ )
59
+
60
+
61
+ def load_u2net(model_type: str = "u2netp"):
62
+ assert model_type in {"u2net", "u2netp"}
63
+ if _MODEL_CACHE.get(model_type) is not None:
64
+ return _MODEL_CACHE[model_type]
65
+
66
+ weights_path = _find_weight_file(model_type)
67
+ if model_type == "u2net":
68
+ net = U2NET(3, 1)
69
+ else:
70
+ net = U2NETP(3, 1)
71
+
72
+ state = torch.load(weights_path, map_location="cpu")
73
+ net.load_state_dict(state)
74
+ net.to(DEVICE)
75
+ net.eval()
76
+ _MODEL_CACHE[model_type] = net
77
+ return net
78
+
79
+
80
+ def get_u2net_mask_with_local_weights(
81
+ pil_img: Image.Image,
82
+ model_type: str = "u2netp",
83
+ resize_to: int = 320,
84
+ ) -> Image.Image:
85
+ """
86
+ Single-image inference using user's local U^2-Net/U^2-NetP weights.
87
+ Returns 8-bit 'L' mask resized back to original W,H.
88
+ """
89
+ W, H = pil_img.size
90
+ net = load_u2net(model_type)
91
+
92
+ tr = transforms.Compose([
93
+ transforms.Resize((resize_to, resize_to), interpolation=Image.BILINEAR),
94
+ transforms.ToTensor(), # [0,1], CxHxW
95
+ ])
96
+ x = tr(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE) # 1x3x320x320
97
+
98
+ with torch.no_grad():
99
+ d1, d2, d3, d4, d5, d6, d7 = net(x)
100
+ pred = d1[:, 0, :, :] # 1xHxW
101
+ pred = _minmax_norm(pred) # min-max normalize per-batch
102
+
103
+ pred_np = pred.squeeze(0).detach().cpu().numpy() # HxW, [0..1]
104
+ mask_small = Image.fromarray((pred_np * 255).astype(np.uint8), mode="L")
105
+ mask = mask_small.resize((W, H), resample=Image.BILINEAR)
106
+ return mask
107
+
108
+
109
+ def jigsaw_shuffle_full_image(pil_img: Image.Image, N: int, seed: int) -> Image.Image:
110
+ """
111
+ Create a jigsaw-shuffled version of the input image by splitting into an N×N grid
112
+ with *no overlap*, permuting patches uniformly at random, and reassembling.
113
+ To keep uniform patch sizes, we center-crop the image to (H2,W2) divisible by N,
114
+ then paste back to the original canvas.
115
+ """
116
+ random.seed(seed)
117
+ np.random.seed(seed)
118
+
119
+ W, H = pil_img.size
120
+ # compute crop that is divisible by N
121
+ H2 = (H // N) * N
122
+ W2 = (W // N) * N
123
+ pad_canvas = Image.fromarray(np.array(pil_img), mode="RGB")
124
+ if H2 == 0 or W2 == 0:
125
+ # too small; just return original
126
+ return pil_img
127
+
128
+ # center crop box
129
+ y0 = (H - H2) // 2
130
+ x0 = (W - W2) // 2
131
+ crop = pil_img.crop((x0, y0, x0 + W2, y0 + H2))
132
+
133
+ arr = np.array(crop).copy()
134
+ out = np.empty_like(arr)
135
+
136
+ ph = H2 // N
137
+ pw = W2 // N
138
+
139
+ # build coordinates
140
+ coords = []
141
+ for i in range(N):
142
+ for j in range(N):
143
+ y1 = i * ph
144
+ x1 = j * pw
145
+ coords.append((y1, y1 + ph, x1, x1 + pw))
146
+
147
+ perm = np.random.permutation(len(coords))
148
+ for dst_idx, src_idx in enumerate(perm):
149
+ yd0, yd1, xd0, xd1 = coords[dst_idx]
150
+ ys0, ys1, xs0, xs1 = coords[src_idx]
151
+ out[yd0:yd1, xd0:xd1, :] = arr[ys0:ys1, xs0:xs1, :]
152
+
153
+ # paste back into original canvas
154
+ pad = np.array(pad_canvas)
155
+ pad[y0:y0 + H2, x0:x0 + W2, :] = out
156
+ return Image.fromarray(pad.astype(np.uint8), mode="RGB")
157
+
158
+
159
+ def add_noise_in_random_fg_box(base: np.ndarray, mask_hard: np.ndarray, sigma: float, seed: int) -> np.ndarray:
160
+ """
161
+ Add Gaussian noise only within a randomly selected rectangular region *inside the foreground*.
162
+ If no foreground is found, no noise is added.
163
+ """
164
+ rng = np.random.default_rng(seed)
165
+ H, W = mask_hard.shape
166
+ ys, xs = np.where(mask_hard > 0.5)
167
+ if len(ys) == 0:
168
+ return base
169
+
170
+ y_min, y_max = ys.min(), ys.max()
171
+ x_min, x_max = xs.min(), xs.max()
172
+
173
+ # choose box size as a fraction of the FG bbox (20% ~ 60% of width/height)
174
+ box_h = max(1, int((y_max - y_min + 1) * float(rng.uniform(0.2, 0.6))))
175
+ box_w = max(1, int((x_max - x_min + 1) * float(rng.uniform(0.2, 0.6))))
176
+
177
+ # random top-left so that box fits within FG bbox
178
+ y0 = int(rng.integers(y_min, max(y_min, y_max - box_h + 1) + 1))
179
+ x0 = int(rng.integers(x_min, max(x_min, x_max - box_w + 1) + 1))
180
+
181
+ # slice
182
+ region_mask = mask_hard[y0:y0 + box_h, x0:x0 + box_w]
183
+ if region_mask.size == 0:
184
+ return base
185
+
186
+ noise = rng.normal(0.0, sigma, size=(box_h, box_w, 3))
187
+ base[y0:y0 + box_h, x0:x0 + box_w, :] += noise * region_mask[:, :, None]
188
+ return base
189
+
190
+
191
+ def dual_region_augment_dra(
192
+ img: Image.Image,
193
+ grid_n: int = 8,
194
+ fg_noise_std: float = 20.0,
195
+ seed: int = 0,
196
+ model_type: str = "u2netp",
197
+ ):
198
+ """
199
+ DRA:
200
+ - Background: jigsaw-shuffle the full image on an N×N grid (no overlap), then keep it only in the background region.
201
+ - Foreground: add Gaussian noise to a single randomly selected rectangular box (inside FG).
202
+ - Fusion: FG from noisy image, BG from jigsaw image, using a hard U^2-Net mask.
203
+ """
204
+ random.seed(seed)
205
+ np.random.seed(seed)
206
+
207
+ base = _pil_to_np(img) # (H, W, 3), float32 [0..255]
208
+ H, W = base.shape[:2]
209
+
210
+ # 1) Mask from local weights (no feather; use hard threshold at 0.5)
211
+ raw_mask_L = get_u2net_mask_with_local_weights(img, model_type=model_type)
212
+ mask = (np.array(raw_mask_L, dtype=np.float32) / 255.0) >= 0.5
213
+ mask_hard = mask.astype(np.float32) # (H,W) in {0,1}
214
+
215
+ # 2) Foreground: noise in a random FG rectangle
216
+ img_fg = base.copy()
217
+ img_fg = add_noise_in_random_fg_box(img_fg, mask_hard, sigma=fg_noise_std, seed=seed)
218
+
219
+ # 3) Background: jigsaw-shuffle full image on N×N grid
220
+ jig = jigsaw_shuffle_full_image(Image.fromarray(base.astype(np.uint8)), N=grid_n, seed=seed)
221
+ img_bg = _pil_to_np(jig)
222
+
223
+ # 4) Fusion: BG where mask==0, FG where mask==1
224
+ m3 = np.repeat(mask_hard[:, :, None], 3, axis=2)
225
+ out = img_bg * (1.0 - m3) + img_fg * m3
226
+
227
+ return _np_to_pil(out), raw_mask_L
228
+
229
+
230
+ # ---- Gradio UI ----
231
+ GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]
232
+
233
+ def parse_grid_choice(s: str) -> int:
234
+ try:
235
+ n = int(s.lower().split('x')[0])
236
+ return max(2, min(16, n))
237
+ except Exception:
238
+ return 8
239
+
240
+
241
+ def run_demo(
242
+ image,
243
+ grid_choice,
244
+ fg_noise_std,
245
+ seed,
246
+ model_type,
247
+ ):
248
+ if image is None:
249
+ raise gr.Error("Please upload an image or pick one from the examples.")
250
+ n = parse_grid_choice(grid_choice)
251
+ out_img, mask_L = dual_region_augment_dra(
252
+ image,
253
+ grid_n=n,
254
+ fg_noise_std=fg_noise_std,
255
+ seed=seed,
256
+ model_type=model_type,
257
+ )
258
+ return out_img, mask_L
259
+
260
+
261
+ def list_example_images():
262
+ ex_dir = Path("examples")
263
+ ex_dir.mkdir(exist_ok=True)
264
+ files = [
265
+ str(p) for p in sorted(ex_dir.iterdir())
266
+ if p.suffix.lower() in IMG_EXTS and p.is_file()
267
+ ]
268
+ return files if files else None
269
+
270
+
271
+ examples = list_example_images()
272
+
273
+ with gr.Blocks(title="Dual-Region Augmentation (DRA, Local U²-Net Weights)") as demo:
274
+ gr.Markdown(
275
+ "### Dual-Region Augmentation (DRA)\n"
276
+ "- **Background**: random patch shuffle on an N×N grid (no overlap) in the background region.\n"
277
+ "- **Foreground**: Gaussian noise in the foreground region.\n"
278
+ "- **Mask**: U²-Net / U²-NetP"
279
+ )
280
+ with gr.Row():
281
+ with gr.Column():
282
+ in_img = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
283
+ grid_choice = gr.Dropdown(GRID_CHOICES, value="8x8", label="Grid (number of patches)")
284
+ noise_std = gr.Slider(0, 100, value=50, step=1, label="Foreground Noise σ")
285
+ seed = gr.Slider(0, 9999, value=69, step=1, label="Seed")
286
+ model_type = gr.Dropdown(choices=["u2netp", "u2net"], value="u2netp", label="Model Type")
287
+ btn = gr.Button("Augment")
288
+ with gr.Column():
289
+ out_img = gr.Image(type="pil", label="Augmented Output")
290
+ out_mask = gr.Image(type="pil", label="U²-Net Mask (preview)")
291
+
292
+ btn.click(
293
+ fn=run_demo,
294
+ inputs=[in_img, grid_choice, noise_std, seed, model_type],
295
+ outputs=[out_img, out_mask],
296
+ concurrency_limit=3,
297
+ api_name="augment",
298
+ )
299
+
300
+ if examples:
301
+ gr.Examples(
302
+ examples=examples,
303
+ inputs=[in_img],
304
+ examples_per_page=12,
305
+ label="Examples (loaded from ./examples)"
306
+ )
307
+
308
+ if __name__ == "__main__":
309
+ demo.launch()
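
Because the click handler above registers `api_name="augment"`, a running instance of this Space can also be called programmatically. A hedged client-side sketch follows; the Space ID is a placeholder, and depending on your `gradio_client` version the image may be passed with `handle_file(...)` or as a plain file path.

```python
# Illustrative sketch: call the /augment endpoint of a deployed instance.
# "username/space-name" is a placeholder, not the actual Space ID.
from gradio_client import Client, handle_file

client = Client("username/space-name")
result = client.predict(
    handle_file("example_1.jpg"),  # Input Image (older clients accept a plain path)
    "8x8",                         # Grid (number of patches)
    50,                            # Foreground Noise sigma
    69,                            # Seed
    "u2netp",                      # Model Type
    api_name="/augment",
)
print(result)  # local paths / URLs of the augmented output and the mask preview
```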
examples/example_1.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ gradio>=4.22,<5
2
+ numpy
3
+ Pillow
4
+ torch
5
+ torchvision
saved_models/u2net/u2net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10025a17f49cd3208afc342b589890e402ee63123d6f2d289a4a0903695cce58
3
+ size 176290937
saved_models/u2netp/u2netp.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7567cde013fb64813973ce6e1ecc25a80c05c3ca7adbc5a54f3c3d90991b854
3
+ size 4683258
u2net/data_loader.py ADDED
@@ -0,0 +1,266 @@
1
+ # data loader
2
+ from __future__ import print_function, division
3
+ import glob
4
+ import torch
5
+ from skimage import io, transform, color
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ import matplotlib.pyplot as plt
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ from PIL import Image
13
+
14
+ #==========================dataset load==========================
15
+ class RescaleT(object):
16
+
17
+ def __init__(self,output_size):
18
+ assert isinstance(output_size,(int,tuple))
19
+ self.output_size = output_size
20
+
21
+ def __call__(self,sample):
22
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
23
+
24
+ h, w = image.shape[:2]
25
+
26
+ if isinstance(self.output_size,int):
27
+ if h > w:
28
+ new_h, new_w = self.output_size*h/w,self.output_size
29
+ else:
30
+ new_h, new_w = self.output_size,self.output_size*w/h
31
+ else:
32
+ new_h, new_w = self.output_size
33
+
34
+ new_h, new_w = int(new_h), int(new_w)
35
+
36
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
37
+ # img = transform.resize(image,(new_h,new_w),mode='constant')
38
+ # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
39
+
40
+ img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
41
+ lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
42
+
43
+ return {'imidx':imidx, 'image':img,'label':lbl}
44
+
45
+ class Rescale(object):
46
+
47
+ def __init__(self,output_size):
48
+ assert isinstance(output_size,(int,tuple))
49
+ self.output_size = output_size
50
+
51
+ def __call__(self,sample):
52
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
53
+
54
+ if random.random() >= 0.5:
55
+ image = image[::-1]
56
+ label = label[::-1]
57
+
58
+ h, w = image.shape[:2]
59
+
60
+ if isinstance(self.output_size,int):
61
+ if h > w:
62
+ new_h, new_w = self.output_size*h/w,self.output_size
63
+ else:
64
+ new_h, new_w = self.output_size,self.output_size*w/h
65
+ else:
66
+ new_h, new_w = self.output_size
67
+
68
+ new_h, new_w = int(new_h), int(new_w)
69
+
70
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
71
+ img = transform.resize(image,(new_h,new_w),mode='constant')
72
+ lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
73
+
74
+ return {'imidx':imidx, 'image':img,'label':lbl}
75
+
76
+ class RandomCrop(object):
77
+
78
+ def __init__(self,output_size):
79
+ assert isinstance(output_size, (int, tuple))
80
+ if isinstance(output_size, int):
81
+ self.output_size = (output_size, output_size)
82
+ else:
83
+ assert len(output_size) == 2
84
+ self.output_size = output_size
85
+ def __call__(self,sample):
86
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
87
+
88
+ if random.random() >= 0.5:
89
+ image = image[::-1]
90
+ label = label[::-1]
91
+
92
+ h, w = image.shape[:2]
93
+ new_h, new_w = self.output_size
94
+
95
+ top = np.random.randint(0, h - new_h)
96
+ left = np.random.randint(0, w - new_w)
97
+
98
+ image = image[top: top + new_h, left: left + new_w]
99
+ label = label[top: top + new_h, left: left + new_w]
100
+
101
+ return {'imidx':imidx,'image':image, 'label':label}
102
+
103
+ class ToTensor(object):
104
+ """Convert ndarrays in sample to Tensors."""
105
+
106
+ def __call__(self, sample):
107
+
108
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
109
+
110
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
111
+ tmpLbl = np.zeros(label.shape)
112
+
113
+ image = image/np.max(image)
114
+ if(np.max(label)<1e-6):
115
+ label = label
116
+ else:
117
+ label = label/np.max(label)
118
+
119
+ if image.shape[2]==1:
120
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
121
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
122
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
123
+ else:
124
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
125
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
126
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
127
+
128
+ tmpLbl[:,:,0] = label[:,:,0]
129
+
130
+
131
+ tmpImg = tmpImg.transpose((2, 0, 1))
132
+ tmpLbl = label.transpose((2, 0, 1))
133
+
134
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
+
136
+ class ToTensorLab(object):
137
+ """Convert ndarrays in sample to Tensors."""
138
+ def __init__(self,flag=0):
139
+ self.flag = flag
140
+
141
+ def __call__(self, sample):
142
+
143
+ imidx, image, label =sample['imidx'], sample['image'], sample['label']
144
+
145
+ tmpLbl = np.zeros(label.shape)
146
+
147
+ if(np.max(label)<1e-6):
148
+ label = label
149
+ else:
150
+ label = label/np.max(label)
151
+
152
+ # change the color space
153
+ if self.flag == 2: # with rgb and Lab colors
154
+ tmpImg = np.zeros((image.shape[0],image.shape[1],6))
155
+ tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
156
+ if image.shape[2]==1:
157
+ tmpImgt[:,:,0] = image[:,:,0]
158
+ tmpImgt[:,:,1] = image[:,:,0]
159
+ tmpImgt[:,:,2] = image[:,:,0]
160
+ else:
161
+ tmpImgt = image
162
+ tmpImgtl = color.rgb2lab(tmpImgt)
163
+
164
+ # normalize image to range [0,1]
165
+ tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
166
+ tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
167
+ tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
168
+ tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
169
+ tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
170
+ tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
171
+
172
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
173
+
174
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
175
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
176
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
177
+ tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
178
+ tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
179
+ tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
180
+
181
+ elif self.flag == 1: #with Lab color
182
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
183
+
184
+ if image.shape[2]==1:
185
+ tmpImg[:,:,0] = image[:,:,0]
186
+ tmpImg[:,:,1] = image[:,:,0]
187
+ tmpImg[:,:,2] = image[:,:,0]
188
+ else:
189
+ tmpImg = image
190
+
191
+ tmpImg = color.rgb2lab(tmpImg)
192
+
193
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
194
+
195
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
196
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
197
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
198
+
199
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
200
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
201
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
202
+
203
+ else: # with rgb color
204
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
205
+ image = image/np.max(image)
206
+ if image.shape[2]==1:
207
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
208
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
209
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
210
+ else:
211
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
212
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
213
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
214
+
215
+ tmpLbl[:,:,0] = label[:,:,0]
216
+
217
+
218
+ tmpImg = tmpImg.transpose((2, 0, 1))
219
+ tmpLbl = label.transpose((2, 0, 1))
220
+
221
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
222
+
223
+ class SalObjDataset(Dataset):
224
+ def __init__(self,img_name_list,lbl_name_list,transform=None):
225
+ # self.root_dir = root_dir
226
+ # self.image_name_list = glob.glob(image_dir+'*.png')
227
+ # self.label_name_list = glob.glob(label_dir+'*.png')
228
+ self.image_name_list = img_name_list
229
+ self.label_name_list = lbl_name_list
230
+ self.transform = transform
231
+
232
+ def __len__(self):
233
+ return len(self.image_name_list)
234
+
235
+ def __getitem__(self,idx):
236
+
237
+ # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
238
+ # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
239
+
240
+ image = io.imread(self.image_name_list[idx])
241
+ imname = self.image_name_list[idx]
242
+ imidx = np.array([idx])
243
+
244
+ if(0==len(self.label_name_list)):
245
+ label_3 = np.zeros(image.shape)
246
+ else:
247
+ label_3 = io.imread(self.label_name_list[idx])
248
+
249
+ label = np.zeros(label_3.shape[0:2])
250
+ if(3==len(label_3.shape)):
251
+ label = label_3[:,:,0]
252
+ elif(2==len(label_3.shape)):
253
+ label = label_3
254
+
255
+ if(3==len(image.shape) and 2==len(label.shape)):
256
+ label = label[:,:,np.newaxis]
257
+ elif(2==len(image.shape) and 2==len(label.shape)):
258
+ image = image[:,:,np.newaxis]
259
+ label = label[:,:,np.newaxis]
260
+
261
+ sample = {'imidx':imidx, 'image':image, 'label':label}
262
+
263
+ if self.transform:
264
+ sample = self.transform(sample)
265
+
266
+ return sample
u2net/model.py ADDED
@@ -0,0 +1,525 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to the same spatial size as tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')  # F.upsample is deprecated
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
utils.py ADDED
@@ -0,0 +1,515 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ from PIL import Image
5
+ import yaml
6
+
7
+ from sklearn.metrics import confusion_matrix
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+ from torch.nn.parallel import DistributedDataParallel
12
+ from torchvision import transforms
13
+
14
+ from moco.loader import GaussianBlur
15
+ import numpy as np
16
+ from augmentations import JigsawPuzzle, JigsawPuzzle_l, JigsawPuzzle_all, RandomErasing, RandomPatchNoise, RandomPatchErase
17
+
18
+
19
+ LOG_FORMAT = "[%(levelname)s] %(asctime)s %(filename)s:%(lineno)s %(message)s"
20
+ LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
21
+
22
+ NUM_CLASSES = {"domainnet-126": 126, "VISDA-C": 12, "PACS": 7}
23
+
24
+ import torch
25
+ import numpy as np
26
+ from PIL import Image
27
+
28
+
29
+ def configure_logger(rank, log_path=None):
30
+ if log_path:
31
+ log_dir = os.path.dirname(log_path)
32
+ os.makedirs(log_dir, exist_ok=True)
33
+
34
+ # only master process will print & write
35
+ level = logging.INFO if rank in {-1, 0} else logging.WARNING
36
+ handlers = [logging.StreamHandler()]
37
+ if rank in {0, -1} and log_path:
38
+ handlers.append(logging.FileHandler(log_path, "w"))
39
+
40
+ logging.basicConfig(
41
+ level=level,
42
+ format=LOG_FORMAT,
43
+ datefmt=LOG_DATEFMT,
44
+ handlers=handlers,
45
+ )
46
+
47
+
48
+ class UnevenBatchLoader:
49
+ """Loader that loads data from multiple datasets with different length."""
50
+
51
+ def __init__(self, data_loaders, is_ddp=False):
52
+ # register N data loaders with epoch counters.
53
+ self.data_loaders = data_loaders
54
+ self.epoch_counters = [0 for _ in range(len(data_loaders))]
55
+
56
+ # set_epoch() needs to be called before creating the iterator
57
+ self.is_ddp = is_ddp
58
+ if is_ddp:
59
+ for data_loader in data_loaders:
60
+ data_loader.sampler.set_epoch(0)
61
+ self.iterators = [iter(data_loader) for data_loader in data_loaders]
62
+
63
+ def next_batch(self):
64
+ """Load the next batch by collecting from N data loaders.
65
+ Args:
66
+ None
67
+ Returns:
68
+ data: a list of N items from N data loaders. each item has the format
69
+ output by a single data loader.
70
+ """
71
+ data = []
72
+ for i, iterator in enumerate(self.iterators):
73
+ try:
74
+ batch_i = next(iterator)
75
+ except StopIteration:
76
+ self.epoch_counters[i] += 1
77
+ # create a new iterator
78
+ if self.is_ddp:
79
+ self.data_loaders[i].sampler.set_epoch(self.epoch_counters[i])
80
+ new_iterator = iter(self.data_loaders[i])
81
+ self.iterators[i] = new_iterator
82
+ batch_i = next(new_iterator)
83
+ data.append(batch_i)
84
+
85
+ return data
86
+
87
+ def update_loader(self, idx, loader, epoch=None):
88
+ if self.is_ddp and isinstance(epoch, int):
89
+ loader.sampler.set_epoch(epoch)
90
+ self.iterators[idx] = iter(loader)
91
+
92
+
93
+ class CustomDistributedDataParallel(DistributedDataParallel):
94
+ """A wrapper class over DDP that relay "module" attribute."""
95
+
96
+ def __init__(self, model, **kwargs):
97
+ super(CustomDistributedDataParallel, self).__init__(model, **kwargs)
98
+
99
+ def __getattr__(self, name):
100
+ try:
101
+ return super(CustomDistributedDataParallel, self).__getattr__(name)
102
+ except AttributeError:
103
+ return getattr(self.module, name)
104
+
105
+
106
+ @torch.no_grad()
107
+ def concat_all_gather(tensor):
108
+ """
109
+ Performs all_gather operation on the provided tensors.
110
+ *** Warning ***: torch.distributed.all_gather has no gradient.
111
+ """
112
+ tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
113
+ dist.all_gather(tensors_gather, tensor, async_op=False)
114
+
115
+ output = torch.cat(tensors_gather, dim=0)
116
+ return output
117
+
118
+
119
+ @torch.no_grad()
120
+ def remove_wrap_arounds(tensor, ranks):
121
+ if ranks == 0:
122
+ return tensor
123
+
124
+ world_size = dist.get_world_size()
125
+ single_length = len(tensor) // world_size
126
+ output = []
127
+ for rank in range(world_size):
128
+ sub_tensor = tensor[rank * single_length : (rank + 1) * single_length]
129
+ if rank >= ranks:
130
+ output.append(sub_tensor[:-1])
131
+ else:
132
+ output.append(sub_tensor)
133
+ output = torch.cat(output)
134
+
135
+ return output
136
+
137
+
138
+ def get_categories(category_file):
139
+ """Return a list of categories ordered by corresponding label.
140
+
141
+ Args:
142
+ category_file: str, path to the category file. can be .yaml or .txt
143
+
144
+ Returns:
145
+ categories: List[str], a list of categories ordered by label.
146
+ """
147
+ if category_file.endswith(".yaml"):
148
+ with open(category_file, "r") as fd:
149
+ cat_mapping = yaml.load(fd, Loader=yaml.SafeLoader)
150
+ categories = list(cat_mapping.keys())
151
+ categories.sort(key=lambda x: cat_mapping[x])
152
+ elif category_file.endswith(".txt"):
153
+ with open(category_file, "r") as fd:
154
+ categories = fd.readlines()
155
+ categories = [cat.strip() for cat in categories if cat]
156
+ else:
157
+ raise NotImplementedError()
158
+
159
+ categories = [cat.replace("_", " ") for cat in categories]
160
+ return categories
161
+
162
+
163
+ def get_augmentation(aug_type, patch_height=28, mix_prob=0.8, normalize=None):
164
+ if not normalize:
165
+ normalize = transforms.Normalize(
166
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
167
+ )
168
+ if aug_type == "moco-v2":
169
+ image_aug = transforms.Compose(
170
+ [
171
+ transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
172
+ transforms.RandomApply(
173
+ [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
174
+ p=0.8, # not strengthened
175
+ ),
176
+ transforms.RandomGrayscale(p=0.2),
177
+ transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
178
+ transforms.RandomHorizontalFlip(),
179
+ transforms.ToTensor(),
180
+ normalize,
181
+ ]
182
+ )
183
+ elif aug_type == "moco-v1":
184
+ image_aug = transforms.Compose(
185
+ [
186
+ transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
187
+ transforms.RandomGrayscale(p=0.2),
188
+ transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
189
+ transforms.RandomHorizontalFlip(),
190
+ transforms.ToTensor(),
191
+ normalize,
192
+ ]
193
+ )
194
+ elif aug_type == "plain":
195
+ image_aug = transforms.Compose(
196
+ [
197
+ transforms.Resize((256, 256)),
198
+ transforms.RandomCrop(224),
199
+ transforms.RandomHorizontalFlip(),
200
+ transforms.ToTensor(),
201
+ normalize,
202
+ ]
203
+ )
204
+ elif aug_type == "clip_inference":
205
+ image_aug = transforms.Compose(
206
+ [
207
+ transforms.Resize(224, interpolation=Image.BICUBIC),
208
+ transforms.CenterCrop(224),
209
+ transforms.ToTensor(),
210
+ normalize,
211
+ ]
212
+ )
213
+ elif aug_type == "test":
214
+ image_aug = transforms.Compose(
215
+ [
216
+ transforms.Resize((256, 256)),
217
+ transforms.CenterCrop(224),
218
+ transforms.ToTensor(),
219
+ normalize,
220
+ ]
221
+ )
222
+ elif aug_type == "jigsaw":
223
+ image_aug = transforms.Compose(
224
+ [
225
+ transforms.Resize((256, 256)),
226
+ transforms.CenterCrop(224),
227
+ # transforms.RandomHorizontalFlip(),
228
+ JigsawPuzzle(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
229
+ transforms.ToTensor(),
230
+ normalize,
231
+ ]
232
+ )
233
+ elif aug_type == "jigsaw_all":
234
+ image_aug = transforms.Compose(
235
+ [
236
+ transforms.Resize((256, 256)),
237
+ transforms.CenterCrop(224),
238
+ # transforms.RandomHorizontalFlip(),
239
+ JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
240
+ transforms.ToTensor(),
241
+ normalize,
242
+ ]
243
+ )
244
+ elif aug_type == "jigsaw_l":
245
+ image_aug = transforms.Compose(
246
+ [
247
+ transforms.Resize((256, 256)),
248
+ transforms.CenterCrop(224),
249
+ # transforms.RandomHorizontalFlip(),
250
+ JigsawPuzzle_l(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
251
+ transforms.ToTensor(),
252
+ normalize,
253
+ ]
254
+ )
255
+ elif aug_type == "rpe":
256
+ image_aug = transforms.Compose(
257
+ [
258
+ transforms.Resize((256, 256)),
259
+ transforms.CenterCrop(224),
260
+ # transforms.RandomHorizontalFlip(),
261
+ RandomPatchErase(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
262
+ transforms.ToTensor(),
263
+ normalize,
264
+ ]
265
+ )
266
+ elif aug_type == "rpn":
267
+ image_aug = transforms.Compose(
268
+ [
269
+ transforms.Resize((256, 256)),
270
+ transforms.CenterCrop(224),
271
+ # transforms.RandomHorizontalFlip(),
272
+ RandomPatchNoise(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
273
+ transforms.ToTensor(),
274
+ normalize,
275
+ ]
276
+ )
277
+ elif aug_type in ["ours", "ours_1"]:
278
+ image_aug = transforms.Compose(
279
+ [
280
+ transforms.Resize((256, 256)),
281
+ transforms.CenterCrop(224),
282
+ JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=mix_prob),
283
+ transforms.ToTensor(),
284
+ ]
285
+ )
286
+ else:
287
+ image_aug = None
288
+
289
+ return DualTransform(
290
+ aug_type=aug_type,
291
+ image_transform=image_aug,
292
+ patch_height=patch_height,
293
+ patch_width=patch_height,
294
+ mix_prob=mix_prob,
295
+ )
296
+
297
+ def fuse_foreground_background(img1, img2, mask):
298
+ """
299
+ Fuse two (C,H,W) image tensors with a (possibly 2D) mask:
+ take img1 where the mask is foreground and img2 where it is background.
301
+ Expects 0 as background in the mask.
302
+ """
303
+ mask = (mask > 0.5)
304
+ output = img1 * mask + img2 * (~mask)
305
+
306
+ return output
307
+
308
+ def normalize(tensor):
309
+ T = transforms.Normalize(
310
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
311
+ )
312
+ return T(tensor)
313
+
314
+ class DualTransform:
315
+ """
316
+ A wrapper that can apply image-only transforms or image+mask transforms.
317
+ """
318
+ def __init__(self, aug_type, image_transform=None, patch_height=28, patch_width=28,mix_prob=1.0):
319
+ self.image_transform = image_transform
320
+ self.aug_type = aug_type
321
+ self.base_transform = transforms.Compose(
322
+ [
323
+ transforms.Resize((256, 256)),
324
+ transforms.CenterCrop(224),
325
+ transforms.ToTensor(),
326
+ ]
327
+ )
328
+
329
+ self.moco_transform = transforms.Compose(
330
+ [
331
+ transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
332
+ transforms.RandomApply(
333
+ [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
334
+ p=0.8, # not strengthened
335
+ ),
336
+ transforms.RandomGrayscale(p=0.2),
337
+ transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
338
+ transforms.RandomHorizontalFlip(),
339
+ # RandomErasing(mode='soft_pixel'),
340
+ transforms.ToTensor(),
341
+ normalize,
342
+ ]
343
+ )
344
+ self.fpn = RandomPatchNoise(patch_height=28, patch_width=28, mix_prob=mix_prob)
345
+ self.to_pil = transforms.ToPILImage()
346
+ self.to_tensor = transforms.ToTensor()
347
+ self.jigsaw = JigsawPuzzle(patch_height=28, patch_width=28, mix_prob=mix_prob)
348
+ self.jigsaw_all = JigsawPuzzle_all(mix_prob=mix_prob)
349
+
350
+ def __call__(self, img, mask=None):
351
+ if self.aug_type == "mask":
352
+ mask = self.base_transform(mask)
353
+ return normalize(mask)
354
+ elif self.aug_type == "foreground":
355
+ mask = self.base_transform(mask)
356
+ img = self.base_transform(img)
357
+ return normalize(img * (mask>0.5).float())
358
+ elif self.aug_type == "fpn":
359
+ mask = self.base_transform(mask)
360
+ img = self.base_transform(img)
361
+ img_n = self.to_tensor(self.fpn(self.to_pil(img)))
362
+ return normalize(img_n * (mask>0.5).float())
363
+ elif self.aug_type == "bps":
364
+ mask = self.base_transform(mask)
365
+ img = self.base_transform(img)
366
+ img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
367
+ return normalize(img_jigsaw * (mask<0.5).float())
368
+ elif self.aug_type == "ours_raw":
369
+ mask = self.base_transform(mask)
370
+ img = self.base_transform(img)
371
+ img_n = self.to_tensor(self.fpn(self.to_pil(img)))
372
+ img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
373
+ img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
374
+ return normalize(img_out)
375
+ elif self.aug_type == "ours":
376
+ mask = self.base_transform(mask)
377
+ img = self.base_transform(img)
378
+ img_n = self.to_tensor(self.fpn(self.to_pil(img)))
379
+ img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
380
+ img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
381
+ return self.moco_transform(self.to_pil(img_out))
382
+ elif self.aug_type == "ours_fpn":
383
+ mask = self.base_transform(mask)
384
+ img = self.base_transform(img)
385
+ img_n = self.to_tensor(self.fpn(self.to_pil(img)))
386
+ img_out = fuse_foreground_background(img_n, img, mask)
387
+ return self.moco_transform(self.to_pil(img_out))
388
+ elif self.aug_type == "ours_bps":
389
+ mask = self.base_transform(mask)
390
+ img = self.base_transform(img)
391
+ # img_n = self.to_tensor(self.fpn(self.to_pil(img)))
392
+ img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
393
+ img_out = fuse_foreground_background(img, img_jigsaw, mask)
394
+ return self.moco_transform(self.to_pil(img_out))
395
+ # Always transform the image if we have an image_transform
396
+ else:
397
+ return self.image_transform(img)
398
+ '''
399
+ elif self.aug_type == "ours_old":
400
+ img_t = self.image_transform(img)
401
+ img = self.base_transform(img)
402
+ mask = self.base_transform(mask)
403
+ img_t1 = fuse_foreground_background(img, img_t, mask)
404
+ img_t1_pil = self.to_pil(img_t1)
405
+ output = self.moco_transform(img_t1_pil)
406
+ return output
407
+ elif self.aug_type == "ours_1":
408
+ img_t = self.image_transform(img)
409
+ img = self.base_transform(img)
410
+ mask = self.base_transform(mask)
411
+ img_t1 = fuse_foreground_background(img, img_t, mask)
412
+ return normalize(img_t1)
413
+ '''
414
+
415
+ class AverageMeter(object):
416
+ """Computes and stores the average and current value"""
417
+
418
+ def __init__(self, name, fmt=":f"):
419
+ self.name = name
420
+ self.fmt = fmt
421
+ self.reset()
422
+
423
+ def reset(self):
424
+ self.val = 0
425
+ self.avg = 0
426
+ self.sum = 0
427
+ self.count = 0
428
+
429
+ def update(self, val, n=1):
430
+ self.val = val
431
+ self.sum += val * n
432
+ self.count += n
433
+ self.avg = self.sum / self.count
434
+
435
+ def __str__(self):
436
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
437
+ return fmtstr.format(**self.__dict__)
438
+
439
+
440
+ class ProgressMeter(object):
441
+ def __init__(self, num_batches, meters, prefix=""):
442
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
443
+ self.meters = meters
444
+ self.prefix = prefix
445
+
446
+ def display(self, batch):
447
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
448
+ entries += [str(meter) for meter in self.meters]
449
+ logging.info("\t".join(entries))
450
+
451
+ def _get_batch_fmtstr(self, num_batches):
452
+ num_digits = len(str(num_batches // 1))
453
+ fmt = "{:" + str(num_digits) + "d}"
454
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
455
+
456
+
457
+ def save_checkpoint(model, optimizer, epoch, save_path="checkpoint.pth.tar"):
458
+ state = {
459
+ "state_dict": model.state_dict(),
460
+ "optimizer": optimizer.state_dict(),
461
+ "epoch": epoch,
462
+ }
463
+ torch.save(state, save_path)
464
+
465
+
466
+ def adjust_learning_rate(optimizer, progress, args):
467
+ """
468
+ Decay the learning rate based on epoch or iteration.
469
+ """
470
+ if args.optim.cos:
471
+ decay = 0.5 * (1.0 + math.cos(math.pi * progress / args.learn.full_progress))
472
+ elif args.optim.exp:
473
+ decay = (1 + 10 * progress / args.learn.full_progress) ** -0.75
474
+ else:
475
+ decay = 1.0
476
+ for milestone in args.optim.schedule:
477
+ decay *= args.optim.gamma if progress >= milestone else 1.0
478
+ for param_group in optimizer.param_groups:
479
+ param_group["lr"] = param_group["lr0"] * decay
480
+
481
+ return decay
482
+
483
+
484
+ def per_class_accuracy(y_true, y_pred):
485
+ matrix = confusion_matrix(y_true, y_pred)
486
+ acc_per_class = (matrix.diagonal() / matrix.sum(axis=1) * 100.0).round(2)
487
+ logging.info(
488
+ f"Accuracy per class: {acc_per_class}, mean: {acc_per_class.mean().round(2)}"
489
+ )
490
+
491
+ return acc_per_class
492
+
493
+
494
+ def get_distances(X, Y, dist_type="euclidean"):
495
+ """
496
+ Args:
497
+ X: (N, D) tensor
498
+ Y: (M, D) tensor
499
+ """
500
+ if dist_type == "euclidean":
501
+ distances = torch.cdist(X, Y)
502
+ elif dist_type == "cosine":
503
+ distances = 1 - torch.matmul(F.normalize(X, dim=1), F.normalize(Y, dim=1).T)
504
+ else:
505
+ raise NotImplementedError(f"{dist_type} distance not implemented.")
506
+
507
+ return distances
508
+
509
+
510
+ def is_master(args):
511
+ return args.rank % args.ngpus_per_node == 0
512
+
513
+
514
+ def use_wandb(args):
515
+ return is_master(args) and args.use_wandb
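
As a closing usage note, `get_augmentation` returns a `DualTransform` that can be applied to an image (and, for mask-aware modes, a foreground mask). A minimal hedged sketch, assuming the external `moco` and `augmentations` modules imported at the top of this file are available:

```python
# Illustrative sketch (file names are placeholders; requires the external
# 'moco' and 'augmentations' modules that utils.py imports).
from PIL import Image
from utils import get_augmentation

tf = get_augmentation("ours", patch_height=28, mix_prob=0.8)

img = Image.open("image.jpg").convert("RGB")
mask = Image.open("mask.png").convert("L")  # U^2-Net foreground mask

# "ours": patch noise in the foreground, jigsaw shuffle in the background,
# fused via the mask, then passed through the MoCo-v2 style transform.
augmented = tf(img, mask)  # -> normalized (3, 224, 224) tensor
print(augmented.shape)
```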