# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD, GradientLoss
from ppmatting.models.backbone import resnet_vd


@manager.MODELS.add_component
class PPMatting(nn.Layer):
    """
    The PP-Matting implementation based on PaddlePaddle.

    The original article refers to
    Guowei Chen, et al. "PP-Matting: High-Accuracy Natural Image Matting"
    (https://arxiv.org/pdf/2204.09433.pdf).

    Args:
        backbone: Backbone network.
        pretrained (str, optional): The path of the pretrained model. Default: None.
    """

    def __init__(self, backbone, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.loss_func_dict = self.get_loss_func_dict()

        self.backbone_channels = backbone.feat_channels
        self.scb = SCB(self.backbone_channels[-1])
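        # SCB takes the deepest backbone feature; HRDB takes the two shallowest
        # backbone features, which forward() upsamples to the input size and
        # concatenates, hence feat_channels[0] + feat_channels[1]. gf_index
        # selects which SCB stage outputs are used as guidance flows.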
        self.hrdb = HRDB(
            self.backbone_channels[0] + self.backbone_channels[1],
            scb_channels=self.scb.out_channels,
            gf_index=[0, 2, 4])

        self.init_weight()

    def forward(self, inputs):
        x = inputs['img']
        input_shape = paddle.shape(x)
        fea_list = self.backbone(x)

        # semantic branch
        scb_logits = self.scb(fea_list[-1])
        semantic_map = F.softmax(scb_logits[-1], axis=1)

        # detail branch
        fea0 = F.interpolate(
            fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
        fea1 = F.interpolate(
            fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
        hrdb_input = paddle.concat([fea0, fea1], 1)
        hrdb_logit = self.hrdb(hrdb_input, scb_logits)
        detail_map = F.sigmoid(hrdb_logit)
        fusion = self.fusion(semantic_map, detail_map)

        if self.training:
            logit_dict = {
                'semantic': semantic_map,
                'detail': detail_map,
                'fusion': fusion
            }
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return fusion

    def get_loss_func_dict(self):
        loss_func_dict = defaultdict(list)
        loss_func_dict['semantic'].append(nn.NLLLoss())
        loss_func_dict['detail'].append(MRSD())
        loss_func_dict['detail'].append(GradientLoss())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(GradientLoss())
        return loss_func_dict

    def loss(self, logit_dict, label_dict):
        loss = {}

        # semantic loss computation
        # get semantic label
        semantic_label = label_dict['trimap']
        semantic_label_trans = (semantic_label == 128).astype('int64')
        semantic_label_bg = (semantic_label == 0).astype('int64')
        semantic_label = semantic_label_trans + semantic_label_bg * 2
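        # resulting class indices: 0 = foreground, 1 = transition, 2 = background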
        loss_semantic = self.loss_func_dict['semantic'][0](
            paddle.log(logit_dict['semantic'] + 1e-6),
            semantic_label.squeeze(1))
        loss['semantic'] = loss_semantic

        # detail loss computation
        transparent = label_dict['trimap'] == 128
        detail_alpha_loss = self.loss_func_dict['detail'][0](
            logit_dict['detail'], label_dict['alpha'], transparent)
        # gradient loss
        detail_gradient_loss = self.loss_func_dict['detail'][1](
            logit_dict['detail'], label_dict['alpha'], transparent)
        loss_detail = detail_alpha_loss + detail_gradient_loss
        loss['detail'] = loss_detail
        loss['detail_alpha'] = detail_alpha_loss
        loss['detail_gradient'] = detail_gradient_loss

        # fusion loss
        loss_fusion_func = self.loss_func_dict['fusion']
        # fusion alpha loss
        fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
                                                label_dict['alpha'])
        # composition loss
        comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
            1 - logit_dict['fusion']) * label_dict['bg']
        comp_gt = label_dict['alpha'] * label_dict['fg'] + (
            1 - label_dict['alpha']) * label_dict['bg']
        fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)
        # gradient loss
        fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
                                               label_dict['alpha'])
        # total fusion loss
        loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
        loss['fusion'] = loss_fusion
        loss['fusion_alpha'] = fusion_alpha_loss
        loss['fusion_composition'] = fusion_composition_loss
        loss['fusion_gradient'] = fusion_grad_loss

        loss['all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion

        return loss

    def fusion(self, semantic_map, detail_map):
        # semantic_map: [N, 3, H, W]
        # In the index map, 0 is foreground, 1 is transition, 2 is background.
        # After fusion, the foreground is 1, the background is 0, and the
        # transition region takes the detail prediction, i.e. values in [0, 1].
        index = paddle.argmax(semantic_map, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        alpha = detail_map * transition_mask + fg
        return alpha

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class SCB(nn.Layer):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
        self.mid_channels = [512, 256, 128, 128, 64, 64]
        self.out_channels = [256, 128, 64, 64, 64, 3]
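        # Channel bookkeeping: stage 0 consumes the PSP output (512 channels)
        # concatenated with the raw backbone feature; each later stage consumes
        # the previous stage output concatenated with an upsampled PSP
        # projection of the same width (256 + 256, 128 + 128, 64 + 64, ...),
        # and the final stage consumes only the previous output (64).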

        self.psp_module = layers.PPModule(
            in_channels,
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)

        psp_upsamples = [2, 4, 8, 16]
        self.psps = nn.LayerList([
            self.conv_up_psp(512, self.out_channels[i], psp_upsamples[i])
            for i in range(4)
        ])

        scb_list = [
            self._make_stage(
                self.in_channels[i],
                self.mid_channels[i],
                self.out_channels[i],
                padding=int(i == 0) + 1,
                dilation=int(i == 0) + 1)
            for i in range(len(self.in_channels) - 1)
        ]
        scb_list += [
            nn.Sequential(
                layers.ConvBNReLU(
                    self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
                layers.ConvBNReLU(
                    self.mid_channels[-1], self.mid_channels[-1], 3, padding=1),
                nn.Conv2D(
                    self.mid_channels[-1], self.out_channels[-1], 3, padding=1))
        ]
        self.scb_stages = nn.LayerList(scb_list)

    def forward(self, x):
        psp_x = self.psp_module(x)
        psps = [psp(psp_x) for psp in self.psps]

        scb_logits = []
        for i, scb_stage in enumerate(self.scb_stages):
            if i == 0:
                x = scb_stage(paddle.concat((psp_x, x), 1))
            elif i <= len(psps):
                x = scb_stage(paddle.concat((psps[i - 1], x), 1))
            else:
                x = scb_stage(x)
            scb_logits.append(x)
        return scb_logits

    def conv_up_psp(self, in_channels, out_channels, up_sample):
        return nn.Sequential(
            layers.ConvBNReLU(
                in_channels, out_channels, 3, padding=1),
            nn.Upsample(
                scale_factor=up_sample, mode='bilinear', align_corners=False))

    def _make_stage(self,
                    in_channels,
                    mid_channels,
                    out_channels,
                    padding=1,
                    dilation=1):
        layer_list = [
            layers.ConvBNReLU(
                in_channels, mid_channels, 3, padding=1),
            layers.ConvBNReLU(
                mid_channels,
                mid_channels,
                3,
                padding=padding,
                dilation=dilation),
            layers.ConvBNReLU(
                mid_channels,
                out_channels,
                3,
                padding=padding,
                dilation=dilation),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False)
        ]
        return nn.Sequential(*layer_list)


class HRDB(nn.Layer):
    """
    The High-Resolution Detail Branch.

    Args:
        in_channels (int): The number of input channels.
        scb_channels (list|tuple): The channels of the SCB logits.
        gf_index (list|tuple, optional): Which SCB logits are selected as
            guidance flows. Default: (0, 2, 4).
    """

    def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
        super().__init__()
        self.gf_index = gf_index
        self.gf_list = nn.LayerList(
            [nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index])
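        # Each selected SCB logit is projected to a single-channel guidance map
        # by a 1x1 conv; with the gf_index and SCB out_channels used by
        # PPMatting above, these are the 256-, 64- and 64-channel SCB outputs.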

        channels = [64, 32, 16, 8]
        self.res_list = [
            resnet_vd.BasicBlock(
                in_channels, channels[0], stride=1, shortcut=False)
        ]
        self.res_list += [
            resnet_vd.BasicBlock(
                i, i, stride=1) for i in channels[1:-1]
        ]
        self.res_list = nn.LayerList(self.res_list)

        self.convs = nn.LayerList([
            nn.Conv2D(
                channels[i], channels[i + 1], kernel_size=1)
            for i in range(len(channels) - 1)
        ])
        self.gates = nn.LayerList(
            [GatedSpatailConv2d(i, i) for i in channels[1:]])

        self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)

    def forward(self, x, scb_logits):
        for i in range(len(self.res_list)):
            x = self.res_list[i](x)
            x = self.convs[i](x)
            gf = self.gf_list[i](scb_logits[self.gf_index[i]])
            gf = F.interpolate(
                gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False)
            x = self.gates[i](x, gf)
        return self.detail_conv(x)


class GatedSpatailConv2d(nn.Layer):
    """
    Gated spatial convolution: a gating branch predicts a single-channel
    attention map from the input features concatenated with a guidance map,
    and this map modulates the input features before the main convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=False):
        super().__init__()
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(in_channels + 1),
            nn.Conv2D(
                in_channels + 1, in_channels + 1, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(
                in_channels + 1, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())
        self.conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)

    def forward(self, input_features, gating_features):
        cat = paddle.concat([input_features, gating_features], axis=1)
        alphas = self._gate_conv(cat)
        x = input_features * (alphas + 1)
        x = self.conv(x)
        return x
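

# A minimal usage sketch, not part of the original module. The backbone
# constructor below is an assumption (check ppmatting.models.backbone.resnet_vd
# and the training configs for the backbone PP-Matting actually uses); the only
# requirements this file imposes on the backbone are a `feat_channels`
# attribute and a forward that returns the multi-scale feature list used above.
if __name__ == '__main__':
    backbone = resnet_vd.ResNet50_vd()  # assumed constructor; verify locally
    model = PPMatting(backbone=backbone)
    model.eval()

    # In eval mode, forward only needs the image tensor under the 'img' key and
    # returns the fused alpha matte with shape [N, 1, H, W].
    dummy = {'img': paddle.rand([1, 3, 512, 512])}
    with paddle.no_grad():
        alpha = model(dummy)
    print(alpha.shape)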