Commit b9754a2
Parent(s): c98de0d

Align lib_name as birefnet and add inference endpoint option.

Changed files:
- README.md +1 -1
- birefnet.py +28 -24
- handler.py +138 -0
README.md
CHANGED

@@ -1,5 +1,5 @@
 ---
-library_name:
+library_name: birefnet
 tags:
 - background-removal
 - mask-generation
birefnet.py
CHANGED

@@ -615,6 +615,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +740,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )   # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
@@ -974,8 +976,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn int to torch.tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
@@ -1961,6 +1964,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,6 +1978,18 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
@@ -2124,18 +2140,6 @@ class Decoder(nn.Module):
             self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2150,10 @@ class Decoder(nn.Module):
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2171,10 @@ class Decoder(nn.Module):
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2197,10 @@ class Decoder(nn.Module):
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2218,17 @@ class Decoder(nn.Module):
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
             outs.append(m4)
             outs.append(m3)
             outs.append(m2)
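For reference, a minimal sketch of what the new einops-based helpers do; the 2x2 grid and the tensor sizes below are illustrative only, not taken from the model config:

import torch
from einops import rearrange

image = torch.randn(1, 3, 512, 512)                        # b c H W

# Default image2patches pattern: cut the image into a 2x2 grid and stack the
# patches along the batch dimension.
patches = rearrange(image, 'b c (hg h) (wg w) -> (b hg wg) c h w', hg=2, wg=2)
print(patches.shape)                                        # torch.Size([4, 3, 256, 256])

# patches2image applies the inverse pattern and restores the original layout.
restored = rearrange(patches, '(b hg wg) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
print(torch.equal(image, restored))                         # True

# The Decoder calls image2patches with a different pattern that keeps the batch
# dimension and folds the grid into the channel dimension instead.
stacked = rearrange(image, 'b c (hg h) (wg w) -> b (c hg wg) h w', hg=2, wg=2)
print(stacked.shape)                                        # torch.Size([1, 12, 256, 256])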
    	
handler.py
ADDED

@@ -0,0 +1,138 @@
# These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
from typing import Dict, List, Any, Tuple
import os
import requests
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

torch.set_float32_matmul_precision(["high", "highest"][0])

device = "cuda" if torch.cuda.is_available() else "cpu"

### image_proc.py
def refine_foreground(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
    return image_masked


def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]


def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * \
        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B


class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        image = self.transform_image(image)
        return image

usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy'
}

# Choose the version of BiRefNet here.
usage = 'Portrait'

# Set resolution
if usage in ['General-Lite-2K']:
    resolution = (2560, 1440)
elif usage in ['General-reso_512']:
    resolution = (512, 512)
elif usage in ['General-HR']:
    resolution = (2048, 2048)
else:
    resolution = (1024, 1024)

half_precision = True

class EndpointHandler():
    def __init__(self, path=''):
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()

    def __call__(self, data: Dict[str, Any]):
        """
        data args:
            inputs (:obj: `str`)
            date (:obj: `str`)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        print('data["inputs"] = ', data["inputs"])
        image_src = data["inputs"]
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                response = requests.get(image_src)
                image_data = BytesIO(response.content)
                image_ori = Image.open(image_data)
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image)
        image_proc = image_proc.unsqueeze(0)

        # Prediction
        with torch.no_grad():
            preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Show Results
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
        return image_masked
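A minimal local smoke test of the new handler, assuming it is saved as handler.py next to the model files; the image URL is a placeholder:

from handler import EndpointHandler

# Instantiating the handler downloads the checkpoint selected by `usage`
# ('Portrait' above) and moves it to the available device.
handler = EndpointHandler()

# "inputs" may be an image URL, a local file path, or a NumPy array.
result = handler({"inputs": "https://example.com/portrait.jpg"})

# The handler returns a PIL image whose alpha channel is the predicted matte.
result.save("portrait_rgba.png")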