Duplicate from befozg/stylematte
Co-authored-by: Karen Efremyan <befozg@users.noreply.huggingface.co>
- .gitattributes +34 -0
- .gitignore +1 -0
- README.md +13 -0
- app.py +30 -0
- base.yaml +61 -0
- logo.jpeg +0 -0
- models.py +481 -0
- requirements.txt +38 -0
- stylematte.pth +3 -0
- test.py +1002 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
*.pth
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Stylematte
emoji: 💻
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 3.29.0
app_file: app.py
pinned: false
duplicated_from: befozg/stylematte
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,30 @@
import gradio as gr
from test import inference_img
from models import *

device = 'cpu'
model = StyleMatte()
model = model.to(device)
checkpoint = f"stylematte.pth"
state_dict = torch.load(checkpoint, map_location=f'{device}')

model.load_state_dict(state_dict)
model.eval()

def predict(inp):
    print("***********Inference****************")
    res = inference_img(model, inp)
    print("***********Inference finish****************")

    return res

print("MODEL LOADED")
print("************************************")

iface = gr.Interface(fn=predict,
                     inputs=gr.Image(type="numpy"),
                     outputs=gr.Image(type="numpy"),
                     examples=["./logo.jpeg"])
print("****************Interface created******************")

iface.launch()
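For debugging outside the Gradio UI, the same pieces can be exercised directly. A minimal sketch, assuming stylematte.pth sits in the working directory and that inference_img from test.py takes an RGB numpy array exactly as app.py calls it; "sample.jpg" is only a placeholder path:

import numpy as np
import torch
from PIL import Image

from models import StyleMatte
from test import inference_img

device = 'cpu'
model = StyleMatte().to(device)
model.load_state_dict(torch.load("stylematte.pth", map_location=device))
model.eval()

img = np.array(Image.open("sample.jpg").convert("RGB"))  # placeholder input image
with torch.no_grad():
    matte = inference_img(model, img)  # same call the Gradio predict() wrapper makes
print(type(matte))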
base.yaml
ADDED
@@ -0,0 +1,61 @@
world_size: 1
experiment_name: "test"
datasets:
  synthetic_fg: "/home/jovyan/datasets/synthetic_psi/"
  synthetic_animals: "/home/jovyan/datasets/synthetic_psiny/"
  bg: "/home/jovyan/datasets/matting/background/testval/"
  ppm100: "/home/jovyan/kvanchiani/stylegan3/PPM-100/image"
  aim500: "/home/jovyan/datasets/AIM-500"
  am2k:
    train_original: "/home/jovyan/datasets/matting/am-2k/train/original"
    train_mask: "/home/jovyan/datasets/matting/am-2k/train/mask"
    background: "/home/jovyan/datasets/matting/am-2k/background/train"
    validation_original: "/home/jovyan/datasets/matting/am-2k/validation/original/"
    validation_mask: "/home/jovyan/datasets/matting/am-2k/validation/mask/"
    validation_trimap: "/home/jovyan/datasets/matting/am-2k/validation/trimap/"
  tiktok: "/home/jovyan/datasets/tiktokdataset/dataset"
  p3m10k: "/home/jovyan/datasets/matting/P3M-10k"
  p3m10k_test:
    VAL500P:
      ROOT_PATH: "P3M-500-P/"
      ORIGINAL_PATH: "P3M-500-P/blurred_image/"
      MASK_PATH: "P3M-500-P/mask/"
      TRIMAP_PATH: "P3M-500-P/trimap/"
      SAMPLE_NUMBER: 500
    VAL500NP:
      ROOT_PATH: "P3M-500-NP/"
      ORIGINAL_PATH: "P3M-500-NP/original_image/"
      MASK_PATH: "P3M-500-NP/mask/"
      TRIMAP_PATH: "P3M-500-NP/trimap/"
      SAMPLE_NUMBER: 500
MAX_SIZE_H: 1600
MAX_SIZE_W: 1600
image_crop: 800
max_image_count: 10000
dataset_to_use: MixedDataset
pretrained_model: "microsoft/swinv2-tiny-patch4-window8-256" #"nielsr/mask2former-swin-base-youtubevis-2021" #"nvidia/mit-b2"
batch_size: 4
num_workers: 4
log_dir: "log"
checkpoint_dir: "checkpoints"
checkpoint: "best-89.pth"
distributed_addr: "localhost"
distributed_port: "12357"
image_size: 800
lr: 1e-7
epochs: 200
disable_validation: False
warmup_steps: 2
validate_each_epoch: 5
max_images_for_validation: 500
disable_mixed_precision: True
log_image_interval: 500
log_image_number: 8
save_model_interval: 10000
switch: 3
lambda_losses:
  default: 1.
  Laplassian: 1.
  Grad: 3.
  L1: 1.
  switch: 1e-6
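test.py loads this file with OmegaConf, so the keys above become attribute-style config fields. A minimal sketch of reading them; note the nesting of the dataset paths shown above is reconstructed from a flattened listing, so the exact hierarchy is an assumption:

from omegaconf import OmegaConf

config = OmegaConf.load("base.yaml")  # same call test.py makes

print(config.pretrained_model)        # "microsoft/swinv2-tiny-patch4-window8-256"
print(config.lambda_losses.Grad)      # 3.0
print(config.datasets.am2k.train_original)  # path access, if nested as shown above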
logo.jpeg
ADDED
models.py
ADDED
@@ -0,0 +1,481 @@
import cv2
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List
from itertools import chain

from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
device='cpu'
class EncoderDecoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        prefix=nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True),
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prefix = prefix

    def forward(self, x):
        if self.prefix is not None:
            x = self.prefix(x)
        x = self.encoder(x)["hidden_states"] #transformers
        return self.decoder(x)


def conv2d_relu(input_filters,output_filters,kernel_size=3, bias=True):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
        nn.LeakyReLU(0.2, inplace=True),
        nn.BatchNorm2d(output_filters)
    )

def up_and_add(x, y):
    return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y

class FPN_fuse(nn.Module):
    def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
        super(FPN_fuse, self).__init__()
        assert feature_channels[0] == fpn_out
        self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
                                      for ft_size in feature_channels[1:]])
        self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
                                         * (len(feature_channels)-1))
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):

        features[:-1] = [conv1x1(feature) for feature, conv1x1 in zip(features[:-1], self.conv1x1)]##
        feature=up_and_add(self.smooth_conv[0](features[0]),features[1])
        feature=up_and_add(self.smooth_conv[1](feature),features[2])
        feature=up_and_add(self.smooth_conv[2](feature),features[3])


        H, W = features[-1].size(2), features[-1].size(3)
        x = [feature,features[-1]]
        x = [F.interpolate(x_el, size=(H, W), mode='bilinear', align_corners=True) for x_el in x]

        x = self.conv_fusion(torch.cat(x, dim=1))
        #x = F.interpolate(x, size=(H*4, W*4), mode='bilinear', align_corners=True)
        return x

class PSPModule(nn.Module):
    # In the original implementation they use precise RoI pooling
    # instead of adaptive average pooling
    def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
        super(PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
                                     for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)

    def forward(self, features):
        h, w = features.size()[2], features.size()[3]
        pyramids = [features]
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
                                       align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output
class UperNet_swin(nn.Module):
    # Implementing only the object path
    def __init__(self, backbone,pretrained=True):
        super(UperNet_swin, self).__init__()


        self.backbone = backbone
        feature_channels = [192,384,768,768]
        self.PPN = PSPModule(feature_channels[-1])
        self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
        self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)


    def forward(self, x):
        input_size = (x.size()[2], x.size()[3])
        features = self.backbone(x)["hidden_states"]
        features[-1] = self.PPN(features[-1])
        x = self.head(self.FPN(features))

        x = F.interpolate(x, size=input_size, mode='bilinear')
        return x

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())

class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels= (3,192,384,768,768),
        decoder_channels=(512,256,128,64),
        n_blocks=4,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]

        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        upscale_factor=4
        self.matting_head = nn.Sequential(
            nn.Conv2d(64,1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
        )

    def preprocess_features(self,x):
        features=[]
        for out_tensor in x:
            bs,n,f=out_tensor.size()
            h = int(n**0.5)
            feature = out_tensor.view(-1,h,h,f).permute(0, 3, 1, 2).contiguous()
            features.append(feature)
        return features

    def forward(self, features):
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        features = self.preprocess_features(features)

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        #y_i = self.upsample1(y_i)
        #hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)
        x = self.matting_head(x)
        x=1-nn.ReLU()(1-x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = conv2d_relu(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3
        )
        self.conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        self.in_channels=in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels
    def forward(self, x, skip=None):
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            if x.shape[-1]!=skip.shape[-1]:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            #print(x.shape,skip.shape)
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class CenterBlock(nn.Sequential):
    # use_batchnorm is accepted to match the call site in UnetDecoder; conv2d_relu always applies BatchNorm
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = conv2d_relu(
            in_channels,
            out_channels,
            kernel_size=3,
        )
        conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        super().__init__(conv1, conv2)



class SegForm(nn.Module):
    def __init__(self):
        super(SegForm, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1 ## set output as 1
        # self.model = SegformerForSemanticSegmentation(config=configuration)

        self.model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
        )
    def forward(self, image):
        img_segs = self.model(image)
        upsampled_logits = nn.functional.interpolate(img_segs.logits,
                                                     scale_factor=4,
                                                     mode='nearest',
                                                     )
        return upsampled_logits


class StyleMatte(nn.Module):
    def __init__(self):
        super(StyleMatte, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1 ## set output as 1
        self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
        self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
        self.fgf = FastGuidedFilter()
        self.conv = nn.Conv2d(256,1,kernel_size=3,padding=1)
        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_mean', self.mean)
        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_std', self.std)
    def forward(self, image, normalize=False):
        # if normalize:
        #     image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))

        decoder_out = self.pixel_decoder(image)
        decoder_states=list(decoder_out.decoder_hidden_states)
        decoder_states.append(decoder_out.decoder_last_hidden_state)
        out_pure=self.fpn(decoder_states)

        image_lr=nn.functional.interpolate(image.mean(1, keepdim=True),
                                           scale_factor=0.25,
                                           mode='bicubic',
                                           align_corners=True
                                           )
        out = self.conv(out_pure)
        out = self.fgf(image_lr,out,image.mean(1, keepdim=True))#.clip(0,1)
        # out = nn.Sigmoid()(out)
        # out = nn.functional.interpolate(out,
        #                                 scale_factor=4,
        #                                 mode='bicubic',
        #                                 align_corners=True
        #                                 )

        return torch.sigmoid(out)

    def get_training_params(self):
        return list(self.fpn.parameters())+list(self.conv.parameters())#+list(self.fgf.parameters())

class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N
        N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))

        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b
class FastGuidedFilter(nn.Module):
    def __init__(self, r=1, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))

        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b
class DeepGuidedFilterRefiner(nn.Module):
    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha

def diff_x(input, r):
    assert input.dim() == 4

    left = input[:, :, r:2 * r + 1]
    middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1]
    right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output

def diff_y(input, r):
    assert input.dim() == 4

    left = input[:, :, :, r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1]
    right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)

if __name__ == '__main__':
    model = StyleMatte().to(device)
    out = model(torch.randn(1,3,640,480).to(device))
    print(out.shape)
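The guided-filter upsampling in StyleMatte relies on BoxFilter, which gets local window sums from two cumulative sums plus the diff_x/diff_y slicing trick instead of an explicit sliding window. A small self-contained check of that equivalence, using only BoxFilter as defined above and comparing against a plain convolution with an all-ones kernel:

import torch
import torch.nn.functional as F
from models import BoxFilter

r = 2
x = torch.rand(1, 1, 16, 16)

box = BoxFilter(r)(x)  # cumulative-sum implementation from models.py

# Reference: full (2r+1) x (2r+1) window sums via an explicit convolution.
kernel = torch.ones(1, 1, 2 * r + 1, 2 * r + 1)
ref = F.conv2d(F.pad(x, (r, r, r, r)), kernel)

# BoxFilter truncates the window at the image borders, so compare interiors only.
print(torch.allclose(box[..., r:-r, r:-r], ref[..., r:-r, r:-r], atol=1e-5))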
requirements.txt
ADDED
@@ -0,0 +1,38 @@
gradio==3.30.0
gradio_client==0.2.4
huggingface-hub==0.14.1
imageio==2.25.1
imgcat==0.5.0
ipykernel==6.16.0
ipython==8.5.0
ipywidgets==8.0.2
kiwisolver==1.4.2
kornia==0.6.9
legacy==0.1.6
numpy==1.21.6
omegaconf==2.2.3
opencv-python==4.5.5.62
opencv-python-headless==4.7.0.68
packaging==21.3
pandas==1.4.2
parso==0.8.3
Pillow==9.4.0
protobuf==3.20.1
Pygments==2.13.0
PyMatting==1.1.8
pyparsing==3.0.9
pyrsistent==0.19.3
scikit-image==0.19.3
scikit-learn==1.1.1
scipy==1.10.0
seaborn==0.12.2
sklearn==0.0
sniffio==1.3.0
soupsieve==2.4
timm==0.6.12
torch==1.11.0
torchaudio==0.11.0
torchvision==0.12.0
tornado==6.2
tqdm==4.64.1
transformers==4.28.1
stylematte.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ce985571e909b6677d7d25e560216fa3f620e5cd337a8382ee0799c6d9af16c
size 140040541
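stylematte.pth is stored through Git LFS, so a clone without LFS support only yields the three pointer lines above rather than the ~140 MB checkpoint. One way to fetch the real file is huggingface_hub; a minimal sketch, where the repo id comes from the duplicated_from field above and repo_type="space" is an assumption about which repo actually hosts the weight:

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="befozg/stylematte",
    filename="stylematte.pth",
    repo_type="space",  # assumption: pulling from the Space rather than a model repo
)
print(path)  # local cache path of the downloaded checkpoint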
test.py
ADDED
|
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#modified from Github repo: https://github.com/JizhiziLi/P3M
|
| 2 |
+
#added inference code for other networks
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
import argparse
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from skimage.transform import resize
|
| 12 |
+
from torchvision import transforms,models
|
| 13 |
+
import os
|
| 14 |
+
from models import *
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import math
|
| 19 |
+
from torch.autograd import Variable
|
| 20 |
+
import torch.nn.functional as fnn
|
| 21 |
+
import glob
|
| 22 |
+
import tqdm
|
| 23 |
+
from torch.autograd import Variable
|
| 24 |
+
from typing import Type, Any, Callable, Union, List, Optional
|
| 25 |
+
import logging
|
| 26 |
+
import time
|
| 27 |
+
from omegaconf import OmegaConf
|
| 28 |
+
config = OmegaConf.load("base.yaml")
|
| 29 |
+
device = "cpu"
|
| 30 |
+
|
| 31 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 32 |
+
"3x3 convolution with padding"
|
| 33 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 34 |
+
padding=1, bias=False)
|
| 35 |
+
class TFI(nn.Module):
|
| 36 |
+
expansion = 1
|
| 37 |
+
def __init__(self, planes,stride=1):
|
| 38 |
+
super(TFI, self).__init__()
|
| 39 |
+
middle_planes = int(planes/2)
|
| 40 |
+
self.transform = conv1x1(planes, middle_planes)
|
| 41 |
+
self.conv1 = conv3x3(middle_planes*3, planes, stride)
|
| 42 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 43 |
+
self.relu = nn.ReLU(inplace=True)
|
| 44 |
+
self.stride = stride
|
| 45 |
+
def forward(self, input_s_guidance, input_m_decoder, input_m_encoder):
|
| 46 |
+
input_s_guidance_transform = self.transform(input_s_guidance)
|
| 47 |
+
input_m_decoder_transform = self.transform(input_m_decoder)
|
| 48 |
+
input_m_encoder_transform = self.transform(input_m_encoder)
|
| 49 |
+
x = torch.cat((input_s_guidance_transform,input_m_decoder_transform,input_m_encoder_transform),1)
|
| 50 |
+
out = self.conv1(x)
|
| 51 |
+
out = self.bn1(out)
|
| 52 |
+
out = self.relu(out)
|
| 53 |
+
return out
|
| 54 |
+
class SBFI(nn.Module):
|
| 55 |
+
def __init__(self, planes,stride=1):
|
| 56 |
+
super(SBFI, self).__init__()
|
| 57 |
+
self.stride = stride
|
| 58 |
+
self.transform1 = conv1x1(planes, int(planes/2))
|
| 59 |
+
self.transform2 = conv1x1(64, int(planes/2))
|
| 60 |
+
self.maxpool = nn.MaxPool2d(2, stride=stride)
|
| 61 |
+
self.conv1 = conv3x3(planes, planes, 1)
|
| 62 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 63 |
+
self.relu = nn.ReLU(inplace=True)
|
| 64 |
+
def forward(self, input_m_decoder,e0):
|
| 65 |
+
input_m_decoder_transform = self.transform1(input_m_decoder)
|
| 66 |
+
e0_maxpool = self.maxpool(e0)
|
| 67 |
+
e0_transform = self.transform2(e0_maxpool)
|
| 68 |
+
x = torch.cat((input_m_decoder_transform,e0_transform),1)
|
| 69 |
+
out = self.conv1(x)
|
| 70 |
+
out = self.bn1(out)
|
| 71 |
+
out = self.relu(out)
|
| 72 |
+
out = out+input_m_decoder
|
| 73 |
+
return out
|
| 74 |
+
class DBFI(nn.Module):
|
| 75 |
+
def __init__(self, planes,stride=1):
|
| 76 |
+
super(DBFI, self).__init__()
|
| 77 |
+
self.stride = stride
|
| 78 |
+
self.transform1 = conv1x1(planes, int(planes/2))
|
| 79 |
+
self.transform2 = conv1x1(512, int(planes/2))
|
| 80 |
+
self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
|
| 81 |
+
self.conv1 = conv3x3(planes, planes, 1)
|
| 82 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 83 |
+
self.relu = nn.ReLU(inplace=True)
|
| 84 |
+
self.conv2 = conv3x3(planes, 3, 1)
|
| 85 |
+
self.upsample2 = nn.Upsample(scale_factor=int(32/stride), mode='bilinear')
|
| 86 |
+
def forward(self, input_s_decoder,e4):
|
| 87 |
+
input_s_decoder_transform = self.transform1(input_s_decoder)
|
| 88 |
+
e4_transform = self.transform2(e4)
|
| 89 |
+
e4_upsample = self.upsample(e4_transform)
|
| 90 |
+
x = torch.cat((input_s_decoder_transform,e4_upsample),1)
|
| 91 |
+
out = self.conv1(x)
|
| 92 |
+
out = self.bn1(out)
|
| 93 |
+
out = self.relu(out)
|
| 94 |
+
out = out+input_s_decoder
|
| 95 |
+
out_side = self.conv2(out)
|
| 96 |
+
out_side = self.upsample2(out_side)
|
| 97 |
+
return out, out_side
|
| 98 |
+
class P3mNet(nn.Module):
|
| 99 |
+
def __init__(self):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.resnet = resnet34_mp()
|
| 102 |
+
############################
|
| 103 |
+
### Encoder part - RESNETMP
|
| 104 |
+
############################
|
| 105 |
+
self.encoder0 = nn.Sequential(
|
| 106 |
+
self.resnet.conv1,
|
| 107 |
+
self.resnet.bn1,
|
| 108 |
+
self.resnet.relu,
|
| 109 |
+
)
|
| 110 |
+
self.mp0 = self.resnet.maxpool1
|
| 111 |
+
self.encoder1 = nn.Sequential(
|
| 112 |
+
self.resnet.layer1)
|
| 113 |
+
self.mp1 = self.resnet.maxpool2
|
| 114 |
+
self.encoder2 = self.resnet.layer2
|
| 115 |
+
self.mp2 = self.resnet.maxpool3
|
| 116 |
+
self.encoder3 = self.resnet.layer3
|
| 117 |
+
self.mp3 = self.resnet.maxpool4
|
| 118 |
+
self.encoder4 = self.resnet.layer4
|
| 119 |
+
self.mp4 = self.resnet.maxpool5
|
| 120 |
+
|
| 121 |
+
self.tfi_3 = TFI(256)
|
| 122 |
+
self.tfi_2 = TFI(128)
|
| 123 |
+
self.tfi_1 = TFI(64)
|
| 124 |
+
self.tfi_0 = TFI(64)
|
| 125 |
+
|
| 126 |
+
self.sbfi_2 = SBFI(128, 8)
|
| 127 |
+
self.sbfi_1 = SBFI(64, 4)
|
| 128 |
+
self.sbfi_0 = SBFI(64, 2)
|
| 129 |
+
|
| 130 |
+
self.dbfi_2 = DBFI(128, 4)
|
| 131 |
+
self.dbfi_1 = DBFI(64, 8)
|
| 132 |
+
self.dbfi_0 = DBFI(64, 16)
|
| 133 |
+
|
| 134 |
+
##########################
|
| 135 |
+
### Decoder part - GLOBAL
|
| 136 |
+
##########################
|
| 137 |
+
self.decoder4_g = nn.Sequential(
|
| 138 |
+
nn.Conv2d(512,512,3,padding=1),
|
| 139 |
+
nn.BatchNorm2d(512),
|
| 140 |
+
nn.ReLU(inplace=True),
|
| 141 |
+
nn.Conv2d(512,512,3,padding=1),
|
| 142 |
+
nn.BatchNorm2d(512),
|
| 143 |
+
nn.ReLU(inplace=True),
|
| 144 |
+
nn.Conv2d(512,256,3,padding=1),
|
| 145 |
+
nn.BatchNorm2d(256),
|
| 146 |
+
nn.ReLU(inplace=True),
|
| 147 |
+
nn.Upsample(scale_factor=2, mode='bilinear') )
|
| 148 |
+
self.decoder3_g = nn.Sequential(
|
| 149 |
+
nn.Conv2d(256,256,3,padding=1),
|
| 150 |
+
nn.BatchNorm2d(256),
|
| 151 |
+
nn.ReLU(inplace=True),
|
| 152 |
+
nn.Conv2d(256,256,3,padding=1),
|
| 153 |
+
nn.BatchNorm2d(256),
|
| 154 |
+
nn.ReLU(inplace=True),
|
| 155 |
+
nn.Conv2d(256,128,3,padding=1),
|
| 156 |
+
nn.BatchNorm2d(128),
|
| 157 |
+
nn.ReLU(inplace=True),
|
| 158 |
+
nn.Upsample(scale_factor=2, mode='bilinear') )
|
| 159 |
+
self.decoder2_g = nn.Sequential(
|
| 160 |
+
nn.Conv2d(128,128,3,padding=1),
|
| 161 |
+
nn.BatchNorm2d(128),
|
| 162 |
+
nn.ReLU(inplace=True),
|
| 163 |
+
nn.Conv2d(128,128,3,padding=1),
|
| 164 |
+
nn.BatchNorm2d(128),
|
| 165 |
+
nn.ReLU(inplace=True),
|
| 166 |
+
nn.Conv2d(128,64,3,padding=1),
|
| 167 |
+
nn.BatchNorm2d(64),
|
| 168 |
+
nn.ReLU(inplace=True),
|
| 169 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
| 170 |
+
self.decoder1_g = nn.Sequential(
|
| 171 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 172 |
+
nn.BatchNorm2d(64),
|
| 173 |
+
nn.ReLU(inplace=True),
|
| 174 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 175 |
+
nn.BatchNorm2d(64),
|
| 176 |
+
nn.ReLU(inplace=True),
|
| 177 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 178 |
+
nn.BatchNorm2d(64),
|
| 179 |
+
nn.ReLU(inplace=True),
|
| 180 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
| 181 |
+
self.decoder0_g = nn.Sequential(
|
| 182 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 183 |
+
nn.BatchNorm2d(64),
|
| 184 |
+
nn.ReLU(inplace=True),
|
| 185 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 186 |
+
nn.BatchNorm2d(64),
|
| 187 |
+
nn.ReLU(inplace=True),
|
| 188 |
+
nn.Conv2d(64,3,3,padding=1),
|
| 189 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
| 190 |
+
|
| 191 |
+
##########################
|
| 192 |
+
### Decoder part - LOCAL
|
| 193 |
+
##########################
|
| 194 |
+
self.decoder4_l = nn.Sequential(
|
| 195 |
+
nn.Conv2d(512,512,3,padding=1),
|
| 196 |
+
nn.BatchNorm2d(512),
|
| 197 |
+
nn.ReLU(inplace=True),
|
| 198 |
+
nn.Conv2d(512,512,3,padding=1),
|
| 199 |
+
nn.BatchNorm2d(512),
|
| 200 |
+
nn.ReLU(inplace=True),
|
| 201 |
+
nn.Conv2d(512,256,3,padding=1),
|
| 202 |
+
nn.BatchNorm2d(256),
|
| 203 |
+
nn.ReLU(inplace=True))
|
| 204 |
+
self.decoder3_l = nn.Sequential(
|
| 205 |
+
nn.Conv2d(256,256,3,padding=1),
|
| 206 |
+
nn.BatchNorm2d(256),
|
| 207 |
+
nn.ReLU(inplace=True),
|
| 208 |
+
nn.Conv2d(256,256,3,padding=1),
|
| 209 |
+
nn.BatchNorm2d(256),
|
| 210 |
+
nn.ReLU(inplace=True),
|
| 211 |
+
nn.Conv2d(256,128,3,padding=1),
|
| 212 |
+
nn.BatchNorm2d(128),
|
| 213 |
+
nn.ReLU(inplace=True))
|
| 214 |
+
self.decoder2_l = nn.Sequential(
|
| 215 |
+
nn.Conv2d(128,128,3,padding=1),
|
| 216 |
+
nn.BatchNorm2d(128),
|
| 217 |
+
nn.ReLU(inplace=True),
|
| 218 |
+
nn.Conv2d(128,128,3,padding=1),
|
| 219 |
+
nn.BatchNorm2d(128),
|
| 220 |
+
nn.ReLU(inplace=True),
|
| 221 |
+
nn.Conv2d(128,64,3,padding=1),
|
| 222 |
+
nn.BatchNorm2d(64),
|
| 223 |
+
nn.ReLU(inplace=True))
|
| 224 |
+
self.decoder1_l = nn.Sequential(
|
| 225 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 226 |
+
nn.BatchNorm2d(64),
|
| 227 |
+
nn.ReLU(inplace=True),
|
| 228 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 229 |
+
nn.BatchNorm2d(64),
|
| 230 |
+
nn.ReLU(inplace=True),
|
| 231 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 232 |
+
nn.BatchNorm2d(64),
|
| 233 |
+
nn.ReLU(inplace=True))
|
| 234 |
+
self.decoder0_l = nn.Sequential(
|
| 235 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 236 |
+
nn.BatchNorm2d(64),
|
| 237 |
+
nn.ReLU(inplace=True),
|
| 238 |
+
nn.Conv2d(64,64,3,padding=1),
|
| 239 |
+
nn.BatchNorm2d(64),
|
| 240 |
+
nn.ReLU(inplace=True))
|
| 241 |
+
self.decoder_final_l = nn.Conv2d(64,1,3,padding=1)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def forward(self, input):
|
| 245 |
+
##########################
|
| 246 |
+
### Encoder part - RESNET
|
| 247 |
+
##########################
|
| 248 |
+
e0 = self.encoder0(input)
|
| 249 |
+
e0p, id0 = self.mp0(e0)
|
| 250 |
+
e1p, id1 = self.mp1(e0p)
|
| 251 |
+
e1 = self.encoder1(e1p)
|
| 252 |
+
e2p, id2 = self.mp2(e1)
|
| 253 |
+
e2 = self.encoder2(e2p)
|
| 254 |
+
e3p, id3 = self.mp3(e2)
|
| 255 |
+
e3 = self.encoder3(e3p)
|
| 256 |
+
e4p, id4 = self.mp4(e3)
|
| 257 |
+
e4 = self.encoder4(e4p)
|
| 258 |
+
###########################
|
| 259 |
+
### Decoder part - Global
|
| 260 |
+
###########################
|
| 261 |
+
d4_g = self.decoder4_g(e4)
|
| 262 |
+
d3_g = self.decoder3_g(d4_g)
|
| 263 |
+
d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4)
|
| 264 |
+
d2_g = self.decoder2_g(d2_g)
|
| 265 |
+
d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4)
|
| 266 |
+
d1_g = self.decoder1_g(d1_g)
|
| 267 |
+
d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4)
|
| 268 |
+
d0_g = self.decoder0_g(d0_g)
|
| 269 |
+
global_sigmoid = d0_g
|
| 270 |
+
###########################
|
| 271 |
+
### Decoder part - Local
|
| 272 |
+
###########################
|
| 273 |
+
d4_l = self.decoder4_l(e4)
|
| 274 |
+
d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2)
|
| 275 |
+
d3_l = self.tfi_3(d4_g, d4_l, e3)
|
| 276 |
+
d3_l = self.decoder3_l(d3_l)
|
| 277 |
+
d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2)
|
| 278 |
+
d2_l = self.tfi_2(d3_g, d3_l, e2)
|
| 279 |
+
d2_l = self.sbfi_2(d2_l, e0)
|
| 280 |
+
d2_l = self.decoder2_l(d2_l)
|
| 281 |
+
d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2)
|
| 282 |
+
d1_l = self.tfi_1(d2_g, d2_l, e1)
|
| 283 |
+
d1_l = self.sbfi_1(d1_l, e0)
|
| 284 |
+
d1_l = self.decoder1_l(d1_l)
|
| 285 |
+
d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2)
|
| 286 |
+
d0_l = self.tfi_0(d1_g, d1_l, e0p)
|
| 287 |
+
d0_l = self.sbfi_0(d0_l, e0)
|
| 288 |
+
d0_l = self.decoder0_l(d0_l)
|
| 289 |
+
d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2)
|
| 290 |
+
d0_l = self.decoder_final_l(d0_l)
|
| 291 |
+
local_sigmoid = F.sigmoid(d0_l)
|
| 292 |
+
##########################
|
| 293 |
+
### Fusion net - G/L
|
| 294 |
+
##########################
|
| 295 |
+
fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
|
| 296 |
+
return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 301 |
+
"""3x3 convolution with padding"""
|
| 302 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 303 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 307 |
+
"""1x1 convolution"""
|
| 308 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class BasicBlock(nn.Module):
|
| 312 |
+
expansion: int = 1
|
| 313 |
+
|
| 314 |
+
def __init__(
|
| 315 |
+
self,
|
| 316 |
+
inplanes: int,
|
| 317 |
+
planes: int,
|
| 318 |
+
stride: int = 1,
|
| 319 |
+
downsample: Optional[nn.Module] = None,
|
| 320 |
+
groups: int = 1,
|
| 321 |
+
base_width: int = 64,
|
| 322 |
+
dilation: int = 1,
|
| 323 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None
|
| 324 |
+
) -> None:
|
| 325 |
+
super(BasicBlock, self).__init__()
|
| 326 |
+
if norm_layer is None:
|
| 327 |
+
norm_layer = nn.BatchNorm2d
|
| 328 |
+
if groups != 1 or base_width != 64:
|
| 329 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
| 330 |
+
if dilation > 1:
|
| 331 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 332 |
+
|
| 333 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 334 |
+
self.bn1 = norm_layer(planes)
|
| 335 |
+
self.relu = nn.ReLU(inplace=True)
|
| 336 |
+
self.conv2 = conv3x3(planes, planes)
|
| 337 |
+
self.bn2 = norm_layer(planes)
|
| 338 |
+
self.downsample = downsample
|
| 339 |
+
self.stride = stride
|
| 340 |
+
|
| 341 |
+
def forward(self, x):
|
| 342 |
+
identity = x
|
| 343 |
+
|
| 344 |
+
out = self.conv1(x)
|
| 345 |
+
out = self.bn1(out)
|
| 346 |
+
out = self.relu(out)
|
| 347 |
+
|
| 348 |
+
out = self.conv2(out)
|
| 349 |
+
out = self.bn2(out)
|
| 350 |
+
|
| 351 |
+
if self.downsample is not None:
|
| 352 |
+
identity = self.downsample(x)
|
| 353 |
+
out += identity
|
| 354 |
+
out = self.relu(out)
|
| 355 |
+
|
| 356 |
+
return out
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class Bottleneck(nn.Module):
|
| 360 |
+
expansion = 4
|
| 361 |
+
__constants__ = ['downsample']
|
| 362 |
+
|
| 363 |
+
def __init__(self, inplanes, planes,stride=1, downsample=None, groups=1,
|
| 364 |
+
base_width=64, dilation=1, norm_layer=None):
|
| 365 |
+
super(Bottleneck, self).__init__()
|
| 366 |
+
if norm_layer is None:
|
| 367 |
+
norm_layer = nn.BatchNorm2d
|
| 368 |
+
width = int(planes * (base_width / 64.)) * groups
|
| 369 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 370 |
+
self.bn1 = norm_layer(width)
|
| 371 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 372 |
+
self.bn2 = norm_layer(width)
|
| 373 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 374 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 375 |
+
self.relu = nn.ReLU(inplace=True)
|
| 376 |
+
self.downsample = downsample
|
| 377 |
+
self.stride = stride
|
| 378 |
+
|
| 379 |
+
def forward(self, x):
|
| 380 |
+
identity = x
|
| 381 |
+
|
| 382 |
+
out = self.conv1(x)
|
| 383 |
+
out = self.bn1(out)
|
| 384 |
+
out = self.relu(out)
|
| 385 |
+
|
| 386 |
+
out = self.conv2(out)
|
| 387 |
+
out = self.bn2(out)
|
| 388 |
+
out = self.relu(out)
|
| 389 |
+
out = self.attention(out)
|
| 390 |
+
|
| 391 |
+
out = self.conv3(out)
|
| 392 |
+
out = self.bn3(out)
|
| 393 |
+
if self.downsample is not None:
|
| 394 |
+
identity = self.downsample(x)
|
| 395 |
+
out += identity
|
| 396 |
+
out = self.relu(out)
|
| 397 |
+
|
| 398 |
+
return out
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class ResNet(nn.Module):
|
| 402 |
+
|
| 403 |
+
def __init__(self, block, layers, zero_init_residual=False,
|
| 404 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
| 405 |
+
norm_layer=None):
|
| 406 |
+
super(ResNet, self).__init__()
|
| 407 |
+
if norm_layer is None:
|
| 408 |
+
norm_layer = nn.BatchNorm2d
|
| 409 |
+
self._norm_layer = norm_layer
|
| 410 |
+
self.inplanes = 64
|
| 411 |
+
self.dilation = 1
|
| 412 |
+
if replace_stride_with_dilation is None:
|
| 413 |
+
replace_stride_with_dilation = [False, False, False]
|
| 414 |
+
if len(replace_stride_with_dilation) != 3:
|
| 415 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
| 416 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
| 417 |
+
self.groups = groups
|
| 418 |
+
self.base_width = width_per_group
|
| 419 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
|
| 420 |
+
bias=False)
|
| 421 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 422 |
+
self.relu = nn.ReLU(inplace=True)
|
| 423 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
| 424 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
| 425 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
| 426 |
+
self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
| 427 |
+
self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
| 428 |
+
#pdb.set_trace()
|
| 429 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 430 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
|
| 431 |
+
dilate=replace_stride_with_dilation[0])
|
| 432 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
|
| 433 |
+
dilate=replace_stride_with_dilation[1])
|
| 434 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
|
| 435 |
+
dilate=replace_stride_with_dilation[2])
|
| 436 |
+
|
| 437 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 438 |
+
self.fc = nn.Linear(512 * block.expansion, 1000)
|
| 439 |
+
|
| 440 |
+
for m in self.modules():
|
| 441 |
+
if isinstance(m, nn.Conv2d):
|
| 442 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 443 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 444 |
+
nn.init.constant_(m.weight, 1)
|
| 445 |
+
nn.init.constant_(m.bias, 0)
|
| 446 |
+
if zero_init_residual:
|
| 447 |
+
for m in self.modules():
|
| 448 |
+
if isinstance(m, Bottleneck):
|
| 449 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 450 |
+
elif isinstance(m, BasicBlock):
|
| 451 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 452 |
+
|
| 453 |
+
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu(x1)
        x1, idx1 = self.maxpool1(x1)

        x2, idx2 = self.maxpool2(x1)
        x2 = self.layer1(x2)

        x3, idx3 = self.maxpool3(x2)
        x3 = self.layer2(x3)

        x4, idx4 = self.maxpool4(x3)
        x4 = self.layer3(x4)

        x5, idx5 = self.maxpool5(x4)
        x5 = self.layer4(x5)

        x_cls = self.avgpool(x5)
        x_cls = torch.flatten(x_cls, 1)
        x_cls = self.fc(x_cls)

        return x_cls

    def forward(self, x):
        return self._forward_impl(x)

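# Usage sketch for the encoder above (illustrative only; BasicBlock, Bottleneck and conv1x1
# are assumed to be defined or imported earlier in this file):
#   encoder = ResNet(BasicBlock, [3, 4, 6, 3])
#   logits = encoder(torch.randn(1, 3, 224, 224))   # classifier head -> shape (1, 1000)
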
def resnet34_mp(**kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    checkpoint = torch.load("checkpoints/r34mp_pretrained_imagenet.pth.tar")
    model.load_state_dict(checkpoint)
    return model

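# Note: resnet34_mp() loads "checkpoints/r34mp_pretrained_imagenet.pth.tar" relative to the
# working directory, so that file must exist before calling it. Illustrative sketch:
#   backbone = resnet34_mp()
#   backbone.eval()
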
##############################
### Training losses for P3M-NET
##############################
def get_crossentropy_loss(gt, pre):
    gt_copy = gt.clone()
    gt_copy[gt_copy == 0] = 0
    gt_copy[gt_copy == 255] = 2
    gt_copy[gt_copy > 2] = 1
    gt_copy = gt_copy.long()
    gt_copy = gt_copy[:, 0, :, :]
    criterion = nn.CrossEntropyLoss()
    entropy_loss = criterion(pre, gt_copy)
    return entropy_loss

def get_alpha_loss(predict, alpha, trimap):
    weighted = torch.zeros(trimap.shape).to(device)
    weighted[trimap == 128] = 1.
    alpha_f = alpha / 255.
    alpha_f = alpha_f.to(device)
    diff = predict - alpha_f
    diff = diff * weighted
    alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
    alpha_loss_weighted = alpha_loss.sum() / (weighted.sum() + 1.)
    return alpha_loss_weighted

def get_alpha_loss_whole_img(predict, alpha):
    weighted = torch.ones(alpha.shape).to(device)
    alpha_f = alpha / 255.
    alpha_f = alpha_f.to(device)
    diff = predict - alpha_f
    alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
    alpha_loss = alpha_loss.sum() / (weighted.sum())
    return alpha_loss

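# The two alpha losses above differ only in their pixel weighting: get_alpha_loss() applies
# the Charbonnier-style |pred - gt| penalty only inside the unknown trimap band
# (trimap == 128), while get_alpha_loss_whole_img() averages it over every pixel.
# Illustrative sketch with dummy tensors (shapes are assumptions, not taken from this repo):
#   pred   = torch.rand(1, 1, 64, 64, device=device)
#   alpha  = torch.rand(1, 1, 64, 64, device=device) * 255
#   trimap = torch.full((1, 1, 64, 64), 128., device=device)
#   l_unknown = get_alpha_loss(pred, alpha, trimap)
#   l_global  = get_alpha_loss_whole_img(pred, alpha)
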
## Laplacian loss refers to
## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270
def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
    if size % 2 != 1:
        raise ValueError("kernel size must be uneven")
    grid = np.float32(np.mgrid[0:size, 0:size].T)
    gaussian = lambda x: np.exp((x - size//2)**2/(-2*sigma**2))**2
    kernel = np.sum(gaussian(grid), axis=2)
    kernel /= np.sum(kernel)
    kernel = np.tile(kernel, (n_channels, 1, 1))
    kernel = torch.FloatTensor(kernel[:, None, :, :]).to(device)
    return Variable(kernel, requires_grad=False)

def conv_gauss(img, kernel):
    """ convolve img with a gaussian kernel that has been built with build_gauss_kernel """
    n_channels, _, kw, kh = kernel.shape
    img = fnn.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
    return fnn.conv2d(img, kernel, groups=n_channels)

def laplacian_pyramid(img, kernel, max_levels=5):
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        diff = current - filtered
        pyr.append(diff)
        current = fnn.avg_pool2d(filtered, 2)
    pyr.append(current)
    return pyr

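# Sketch of the pyramid helpers above (illustrative only): build_gauss_kernel() returns a
# depthwise Gaussian kernel, conv_gauss() applies it with replicate padding, and
# laplacian_pyramid() collects the band-pass differences plus the final low-pass level.
#   kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1)
#   levels = laplacian_pyramid(torch.rand(1, 1, 64, 64, device=device), kernel, max_levels=5)
#   # len(levels) == 6: five difference images plus the coarsest residual
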
def get_laplacian_loss(predict, alpha, trimap):
    weighted = torch.zeros(trimap.shape).to(device)
    weighted[trimap == 128] = 1.
    alpha_f = alpha / 255.
    alpha_f = alpha_f.to(device)
    alpha_f = alpha_f.clone() * weighted
    predict = predict.clone() * weighted
    gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
    pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
    pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
    laplacian_loss_weighted = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
    return laplacian_loss_weighted

def get_laplacian_loss_whole_img(predict, alpha):
    alpha_f = alpha / 255.
    alpha_f = alpha_f.to(device)
    gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
    pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
    pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
    laplacian_loss = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
    return laplacian_loss

def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
    weighted = torch.ones(alpha.shape).to(device)
    predict_3 = torch.cat((predict, predict, predict), 1)
    comp = predict_3 * fg + (1. - predict_3) * bg
    comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
    comp_loss = comp_loss.sum() / (weighted.sum())
    return comp_loss

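# How the pieces above are typically combined during training (an unweighted sketch; the
# variable names are placeholders and the actual loss weighting is not defined in this file):
#   loss = (get_crossentropy_loss(trimap_gt, pred_trimap)
#           + get_alpha_loss(pred_alpha, alpha_gt, trimap_gt)
#           + get_laplacian_loss(pred_alpha, alpha_gt, trimap_gt)
#           + get_composition_loss_whole_img(img, alpha_gt, fg, bg, pred_alpha))
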
##############################
### Test loss for matting
##############################
def calculate_sad_mse_mad(predict_old, alpha, trimap):
    predict = np.copy(predict_old)
    pixel = float((trimap == 128).sum())
    predict[trimap == 255] = 1.
    predict[trimap == 0] = 0.
    sad_diff = np.sum(np.abs(predict - alpha))/1000
    if pixel == 0:
        pixel = trimap.shape[0]*trimap.shape[1] - float((trimap == 255).sum()) - float((trimap == 0).sum())
    mse_diff = np.sum((predict - alpha) ** 2)/pixel
    mad_diff = np.sum(np.abs(predict - alpha))/pixel
    return sad_diff, mse_diff, mad_diff

def calculate_sad_mse_mad_whole_img(predict, alpha):
    pixel = predict.shape[0]*predict.shape[1]
    sad_diff = np.sum(np.abs(predict - alpha))/1000
    mse_diff = np.sum((predict - alpha) ** 2)/pixel
    mad_diff = np.sum(np.abs(predict - alpha))/pixel
    return sad_diff, mse_diff, mad_diff

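# Conventions for the metrics above: predictions and ground truth are expected in [0, 1],
# SAD is reported in thousands of alpha units (hence the /1000), and MSE/MAD are averaged
# per pixel. For instance, a constant error of 0.01 over a 1000x1000 alpha map gives
# SAD = 10.0, MSE = 1e-4 and MAD = 0.01.
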
def calculate_sad_fgbg(predict, alpha, trimap):
    sad_diff = np.abs(predict - alpha)
    weight_fg = np.zeros(predict.shape)
    weight_bg = np.zeros(predict.shape)
    weight_trimap = np.zeros(predict.shape)
    weight_fg[trimap == 255] = 1.
    weight_bg[trimap == 0] = 1.
    weight_trimap[trimap == 128] = 1.
    sad_fg = np.sum(sad_diff*weight_fg)/1000
    sad_bg = np.sum(sad_diff*weight_bg)/1000
    sad_trimap = np.sum(sad_diff*weight_trimap)/1000
    return sad_fg, sad_bg

def compute_gradient_whole_image(pd, gt):
    from scipy.ndimage import gaussian_filter
    pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32)
    pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32)
    gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32)
    gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32)
    pd_mag = np.sqrt(pd_x**2 + pd_y**2)
    gt_mag = np.sqrt(gt_x**2 + gt_y**2)

    error_map = np.square(pd_mag - gt_mag)
    loss = np.sum(error_map) / 10
    return loss

def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
    from scipy.ndimage import morphology
    from skimage.measure import label, regionprops
    h, w = pd.shape
    thresh_steps = np.arange(0, 1.1, step)
    l_map = -1 * np.ones((h, w), dtype=np.float32)
    lambda_map = np.ones((h, w), dtype=np.float32)
    for i in range(1, thresh_steps.size):
        pd_th = pd >= thresh_steps[i]
        gt_th = gt >= thresh_steps[i]
        label_image = label(pd_th & gt_th, connectivity=1)
        cc = regionprops(label_image)
        size_vec = np.array([c.area for c in cc])
        if len(size_vec) == 0:
            continue
        max_id = np.argmax(size_vec)
        coords = cc[max_id].coords
        omega = np.zeros((h, w), dtype=np.float32)
        omega[coords[:, 0], coords[:, 1]] = 1
        flag = (l_map == -1) & (omega == 0)
        l_map[flag == 1] = thresh_steps[i-1]
        dist_maps = morphology.distance_transform_edt(omega == 0)
        dist_maps = dist_maps / dist_maps.max()
    l_map[l_map == -1] = 1
    d_pd = pd - l_map
    d_gt = gt - l_map
    phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
    phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
    loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
    return loss

def gen_trimap_from_segmap_e2e(segmap):
    trimap = np.argmax(segmap, axis=1)[0]
    trimap = trimap.astype(np.int64)
    trimap[trimap == 1] = 128
    trimap[trimap == 2] = 255
    return trimap.astype(np.uint8)

def get_masked_local_from_global(global_sigmoid, local_sigmoid):
    values, index = torch.max(global_sigmoid, 1)
    index = index[:, None, :, :].float()
    ### index <===> [0, 1, 2]
    ### bg_mask <===> [1, 0, 0]
    bg_mask = index.clone()
    bg_mask[bg_mask == 2] = 1
    bg_mask = 1 - bg_mask
    ### trimap_mask <===> [0, 1, 0]
    trimap_mask = index.clone()
    trimap_mask[trimap_mask == 2] = 0
    ### fg_mask <===> [0, 0, 1]
    fg_mask = index.clone()
    fg_mask[fg_mask == 1] = 0
    fg_mask[fg_mask == 2] = 1
    fusion_sigmoid = local_sigmoid*trimap_mask + fg_mask
    return fusion_sigmoid

def get_masked_local_from_global_test(global_result, local_result):
    weighted_global = np.ones(global_result.shape)
    weighted_global[global_result == 255] = 0
    weighted_global[global_result == 0] = 0
    fusion_result = global_result*(1.-weighted_global)/255 + local_result*weighted_global
    return fusion_result

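# Fusion convention used above: the global branch predicts three classes per pixel
# (index 0 = background, 1 = unknown, 2 = foreground, i.e. 0 / 128 / 255 after
# gen_trimap_from_segmap_e2e), and the local alpha is kept only inside the unknown band.
# Illustrative call, assuming `global_scores` has shape (N, 3, H, W) and `local_alpha`
# has shape (N, 1, H, W):
#   fused = get_masked_local_from_global(global_scores, local_alpha)
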
def inference_once(model, scale_img, scale_trimap=None):
    pred_list = []
    tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).to(device)
    input_t = tensor_img
    input_t = input_t/255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    # pred_global, pred_local, pred_fusion = model(input_t)[:3]
    pred_fusion = model(input_t)[:3]
    pred_global = pred_fusion
    pred_local = pred_fusion

    pred_global = pred_global.data.cpu().numpy()
    pred_global = gen_trimap_from_segmap_e2e(pred_global)
    pred_local = pred_local.data.cpu().numpy()[0, 0, :, :]
    pred_fusion = pred_fusion.data.cpu().numpy()[0, 0, :, :]
    return pred_global, pred_local, pred_fusion

# def inference_img(test_choice, model, img):
#     h, w, c = img.shape
#     new_h = min(config['datasets'].MAX_SIZE_H, h - (h % 32))
#     new_w = min(config['datasets'].MAX_SIZE_W, w - (w % 32))
#     if test_choice=='HYBRID':
#         global_ratio = 1/2
#         local_ratio = 1
#         resize_h = int(h*global_ratio)
#         resize_w = int(w*global_ratio)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img,(new_h,new_w))*255.0
#         pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once(model, scale_img)
#         pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0
#         resize_h = int(h*local_ratio)
#         resize_w = int(w*local_ratio)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img,(new_h,new_w))*255.0
#         pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once(model, scale_img)
#         pred_retouching_2 = resize(pred_retouching_2,(h,w))
#         pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2)
#         return pred_fusion
#     else:
#         resize_h = int(h/2)
#         resize_w = int(w/2)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img,(new_h,new_w))*255.0
#         pred_global, pred_local, pred_fusion = inference_once(model, scale_img)
#         pred_local = resize(pred_local,(h,w))
#         pred_global = resize(pred_global,(h,w))*255.0
#         pred_fusion = resize(pred_fusion,(h,w))
#         return pred_fusion

def inference_img(model, img):
    h, w, _ = img.shape
    # print(img.shape)
    if h % 8 != 0 or w % 8 != 0:
        img = cv2.copyMakeBorder(img, 8-h%8, 0, 8-w%8, 0, cv2.BORDER_REFLECT)
    # print(img.shape)

    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img
    input_t = input_t/255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    # print("out", out.shape)
    result = out[0][:, -h:, -w:].cpu().numpy()
    # print(result.shape)

    return result[0]

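# inference_img() expects an HxWx3 RGB array, reflect-pads the top/left so height and width
# become multiples of 8, normalizes with ImageNet statistics, and crops the prediction back
# to HxW. Minimal usage sketch (the file name below is only an example):
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   alpha = inference_img(model, img)
#   Image.fromarray(np.uint8(alpha * 255)).save("example_alpha.png")
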
def test_am2k(model):
    ############################
    # Some initial settings for paths
    ############################
    ORIGINAL_PATH = config['datasets']['am2k']['validation_original']
    MASK_PATH = config['datasets']['am2k']['validation_mask']
    TRIMAP_PATH = config['datasets']['am2k']['validation_trimap']
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")

    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    grad_diffs = 0.
    conn_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.

    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t--Dataset: AM2k\n\t-\n\t--Number: {total_number}')

    for img_path in tqdm.tqdm(img_paths):
        img_name = (img_path.split("/")[-1])[:-4]
        alpha_path = MASK_PATH+img_name+'.png'
        trimap_path = TRIMAP_PATH+img_name+'.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)
        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        img = img[:, :, :3] if img.ndim > 2 else img
        trimap = trimap[:, :, 0] if trimap.ndim > 2 else trimap
        alpha = alpha[:, :, 0] if alpha.ndim > 2 else alpha

        with torch.no_grad():
            # torch.cuda.empty_cache()
            predict = inference_img(model, img)

        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)

        log(f"[{img_paths.index(img_path)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")

        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff
        Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")

    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    return sad_diffs/total_number, mse_diffs/total_number, grad_diffs/total_number

def test_p3m10k(model, dataset_choice, max_image=-1):
    ############################
    # Some initial settings for paths
    ############################
    if dataset_choice == 'P3M_500_P':
        val_option = 'VAL500P'
    else:
        val_option = 'VAL500NP'
    ORIGINAL_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['ORIGINAL_PATH']
    MASK_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['MASK_PATH']
    TRIMAP_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['TRIMAP_PATH']
    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.
    conn_diffs = 0.
    grad_diffs = 0.
    model.eval()
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")
    if max_image > 1:
        img_paths = img_paths[:max_image]
    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t----Test: {dataset_choice}\n\t--Number: {total_number}')

    for img_path in tqdm.tqdm(img_paths):
        img_name = (img_path.split("/")[-1])[:-4]
        alpha_path = MASK_PATH+img_name+'.png'
        trimap_path = TRIMAP_PATH+img_name+'.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)

        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        img = img[:, :, :3] if img.ndim > 2 else img
        trimap = trimap[:, :, 0] if trimap.ndim > 2 else trimap
        alpha = alpha[:, :, 0] if alpha.ndim > 2 else alpha
        with torch.no_grad():
            # torch.cuda.empty_cache()
            start = time.time()

            predict = inference_img(model, img)  # HYBRID shows less accuracy

            # tensorimg = transforms.ToTensor()(pil_img)
            # input_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensorimg)

            # predict = model(input_img.unsqueeze(0).to(device))[0][0].detach().cpu().numpy()
            # if predict.shape!=(pil_img.height,pil_img.width):
            #     print("resize for ",img_path)
            #     predict = resize(predict,(pil_img.height,pil_img.width))

        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)

        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)
        log(f"[{img_paths.index(img_path)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")
        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff

        Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")
    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))

    return sad_diffs/total_number, mse_diffs/total_number, grad_diffs/total_number

def log(str):
    print(str)
    logging.info(str)

if __name__ == '__main__':
    print('*********************************')
    config = OmegaConf.load("base.yaml")
    config = OmegaConf.merge(config, OmegaConf.from_cli())
    print(config)
    model = MaskForm()
    model = model.to(device)
    checkpoint = f"{config.checkpoint_dir}/{config.checkpoint}"
    state_dict = torch.load(checkpoint, map_location=f'{device}')
    print("loaded", checkpoint)
    model.load_state_dict(state_dict)
    model.eval()
    logging.basicConfig(filename=f'report/{config.checkpoint.replace("/","--")}.report', encoding='utf-8', filemode='w', level=logging.INFO)
    # ckpt = torch.load("checkpoints/p3mnet_pretrained_on_p3m10k.pth")
    # model.load_state_dict(ckpt['state_dict'], strict=True)
    # model = model.cuda()
    if config.dataset_to_use == "AM2K":
        test_am2k(model)
    else:
        for dataset_choice in ['P3M_500_P', 'P3M_500_NP']:
            test_p3m10k(model, dataset_choice)
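
# Example invocation (a sketch; keys come from base.yaml and can be overridden on the
# command line through OmegaConf.from_cli(); the values shown are placeholders):
#   python test.py dataset_to_use=AM2K checkpoint=stylematte.pth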