# NOTE: removed non-source artifact lines ("Spaces:", "Runtime error") left by extraction.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg import utils
from paddleseg.cvlibs import manager
from paddleseg.models import layers

from ppmatting.models.losses import MRSD
class DIM(nn.Layer):
    """
    The DIM implementation based on PaddlePaddle.

    The original article refers to
    Ning Xu, et al. "Deep Image Matting"
    (https://arxiv.org/abs/1703.03872).

    Args:
        backbone: Backbone model; must return a list of at least five
            multi-scale feature maps when called.
        stage (int, optional): Training stage of the model. Stage 2 freezes
            the backbone and decoder (refinement head only); any stage >= 2
            attaches the refinement head. Default: 3.
        decoder_input_channels (int, optional): The channel number of the
            decoder input. Default: 512.
        pretrained (str, optional): The path of the pretrained model.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 stage=3,
                 decoder_input_channels=512,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.stage = stage
        # Built lazily in loss(): maps loss name -> list of loss callables.
        self.loss_func_dict = None

        decoder_output_channels = [64, 128, 256, 512]
        self.decoder = Decoder(
            input_channels=decoder_input_channels,
            output_channels=decoder_output_channels)
        if self.stage == 2:
            # Stage 2 trains only the refinement head, so freeze everything
            # upstream of it.
            for param in self.backbone.parameters():
                param.stop_gradient = True
            for param in self.decoder.parameters():
                param.stop_gradient = True
        if self.stage >= 2:
            self.refine = Refine()
        self.init_weight()

    def forward(self, inputs):
        """Run the matting network.

        Args:
            inputs (dict): Must contain 'img' (RGB image) and 'trimap'
                (values in [0, 255]); during training also the keys used
                by loss() ('alpha', 'fg', 'bg').

        Returns:
            Stage 1: dict with 'alpha_raw'.
            Stage >= 2, training: (logit_dict, loss_dict).
            Stage >= 2, inference: the refined alpha, clipped to [0, 1].
        """
        input_shape = paddle.shape(inputs['img'])[-2:]
        # Trimap arrives in [0, 255]; normalize before concatenating.
        x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
        fea_list = self.backbone(x)

        # decoder stage
        up_shape = []
        for i in range(5):
            up_shape.append(paddle.shape(fea_list[i])[-2:])
        alpha_raw = self.decoder(fea_list, up_shape)
        alpha_raw = F.interpolate(
            alpha_raw, input_shape, mode='bilinear', align_corners=False)
        logit_dict = {'alpha_raw': alpha_raw}
        if self.stage < 2:
            return logit_dict

        # refine stage (the original `if self.stage >= 2` here was always
        # true after the early return above, so the guard is dropped)
        refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
        alpha_refine = self.refine(refine_input)

        # The refinement head predicts a residual on top of the raw alpha.
        alpha_pred = alpha_refine + alpha_raw
        alpha_pred = F.interpolate(
            alpha_pred, input_shape, mode='bilinear', align_corners=False)
        if not self.training:
            alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
        logit_dict['alpha_pred'] = alpha_pred
        if self.training:
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return alpha_pred

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the stage-dependent matting losses.

        All losses are masked to the unknown region of the trimap
        (trimap value 128). Returns a dict with the individual terms and
        their sum under 'all'.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['alpha_raw'].append(MRSD())
                self.loss_func_dict['comp'].append(MRSD())
                self.loss_func_dict['alpha_pred'].append(MRSD())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}
        mask = label_dict['trimap'] == 128
        loss['all'] = 0

        if self.stage != 2:
            loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
                logit_dict['alpha_raw'], label_dict['alpha'], mask)
            loss['alpha_raw'] = 0.5 * loss['alpha_raw']
            loss['all'] = loss['all'] + loss['alpha_raw']

        if self.stage == 1 or self.stage == 3:
            # Compositional loss: recompose the image from the predicted
            # alpha and the ground-truth foreground/background.
            comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
                (1 - logit_dict['alpha_raw']) * label_dict['bg']
            loss['comp'] = self.loss_func_dict['comp'][0](
                comp_pred, label_dict['img'], mask)
            loss['comp'] = 0.5 * loss['comp']
            loss['all'] = loss['all'] + loss['comp']

        if self.stage == 2 or self.stage == 3:
            loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
                logit_dict['alpha_pred'], label_dict['alpha'], mask)
            loss['all'] = loss['all'] + loss['alpha_pred']

        return loss

    def init_weight(self):
        """Load pretrained weights into the whole model when a path is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
# bilinear interpolate skip connect
class Up(nn.Layer):
    """Upsampling block: bilinear resize, add the skip feature, then conv.

    Args:
        input_channels (int): Channels of the input (and the skip tensor,
            which is added element-wise, so they must match).
        output_channels (int): Channels produced by the convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            input_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias_attr=False)

    def forward(self, x, skip, output_shape):
        """Resize `x` to `output_shape`, add `skip`, and convolve."""
        x = F.interpolate(
            x, size=output_shape, mode='bilinear', align_corners=False)
        x = x + skip
        # ConvBNReLU already ends in a ReLU, so the original's trailing
        # F.relu(x) was redundant (ReLU is idempotent) and is removed.
        x = self.conv(x)
        return x
class Decoder(nn.Layer):
    """DIM decoder: a 1x1 bottleneck conv followed by five Up blocks that
    successively merge the encoder features, ending in a sigmoid-activated
    single-channel alpha prediction.

    Args:
        input_channels (int): Channels of the deepest encoder feature.
        output_channels (tuple, optional): Channels produced by the four
            deepest Up blocks, listed shallow-to-deep.
            Default: (64, 128, 256, 512).
    """

    def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
        super().__init__()
        self.deconv6 = nn.Conv2D(
            input_channels, input_channels, kernel_size=1, bias_attr=False)
        self.deconv5 = Up(input_channels, output_channels[-1])
        self.deconv4 = Up(output_channels[-1], output_channels[-2])
        self.deconv3 = Up(output_channels[-2], output_channels[-3])
        self.deconv2 = Up(output_channels[-3], output_channels[-4])
        self.deconv1 = Up(output_channels[-4], 64)
        self.alpha_conv = nn.Conv2D(
            64, 1, kernel_size=5, padding=2, bias_attr=False)

    def forward(self, fea_list, shape_list):
        """Decode the feature pyramid into a raw alpha matte in (0, 1)."""
        feat = self.deconv6(fea_list[-1])
        # Walk back up the pyramid: deconvN consumes encoder level N-1.
        for level in (5, 4, 3, 2, 1):
            up_block = getattr(self, 'deconv{}'.format(level))
            feat = up_block(feat, fea_list[level - 1], shape_list[level - 1])
        return F.sigmoid(self.alpha_conv(feat))
class Refine(nn.Layer):
    """DIM refinement head: three 3x3 ConvBNReLU layers over the 4-channel
    input (RGB image concatenated with the raw alpha), followed by a
    single-channel prediction layer."""

    def __init__(self):
        super().__init__()
        # Attribute names are kept so checkpoint state-dict keys match.
        self.conv1 = layers.ConvBNReLU(
            4, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv3 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        # NOTE(review): the prediction layer is also ConvBNReLU, so the
        # residual it emits is non-negative — verify against the reference
        # implementation that this is intentional.
        self.alpha_pred = layers.ConvBNReLU(
            64, 1, kernel_size=3, padding=1, bias_attr=False)

    def forward(self, x):
        """Predict an alpha residual from the image + raw-alpha input."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3):
            out = stage(out)
        return self.alpha_pred(out)