Commit
·
bdff8f5
1
Parent(s):
a2fe5a2
Align lib_name as birefnet and add inference endpoint option.
Browse files- README.md +1 -1
- birefnet.py +2 -1
- handler.py +7 -1
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
library_name:
|
| 3 |
tags:
|
| 4 |
- background-removal
|
| 5 |
- mask-generation
|
|
|
|
| 1 |
---
|
| 2 |
+
library_name: birefnet
|
| 3 |
tags:
|
| 4 |
- background-removal
|
| 5 |
- mask-generation
|
birefnet.py
CHANGED
|
@@ -1995,7 +1995,8 @@ class BiRefNet(
|
|
| 1995 |
):
|
| 1996 |
config_class = BiRefNetConfig
|
| 1997 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
| 1998 |
-
super(BiRefNet, self).__init__()
|
|
|
|
| 1999 |
self.config = Config()
|
| 2000 |
self.epoch = 1
|
| 2001 |
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
|
|
|
| 1995 |
):
|
| 1996 |
config_class = BiRefNetConfig
|
| 1997 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
| 1998 |
+
super(BiRefNet, self).__init__(config)
|
| 1999 |
+
bb_pretrained = config.bb_pretrained
|
| 2000 |
self.config = Config()
|
| 2001 |
self.epoch = 1
|
| 2002 |
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
handler.py
CHANGED
|
@@ -62,6 +62,7 @@ class ImagePreprocessor():
|
|
| 62 |
|
| 63 |
usage_to_weights_file = {
|
| 64 |
'General': 'BiRefNet',
|
|
|
|
| 65 |
'General-Lite': 'BiRefNet_lite',
|
| 66 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
| 67 |
'General-reso_512': 'BiRefNet-reso_512',
|
|
@@ -82,9 +83,12 @@ if usage in ['General-Lite-2K']:
|
|
| 82 |
resolution = (2560, 1440)
|
| 83 |
elif usage in ['General-reso_512']:
|
| 84 |
resolution = (512, 512)
|
|
|
|
|
|
|
| 85 |
else:
|
| 86 |
resolution = (1024, 1024)
|
| 87 |
|
|
|
|
| 88 |
|
| 89 |
class EndpointHandler():
|
| 90 |
def __init__(self, path=''):
|
|
@@ -93,6 +97,8 @@ class EndpointHandler():
|
|
| 93 |
)
|
| 94 |
self.birefnet.to(device)
|
| 95 |
self.birefnet.eval()
|
|
|
|
|
|
|
| 96 |
|
| 97 |
def __call__(self, data: Dict[str, Any]):
|
| 98 |
"""
|
|
@@ -122,7 +128,7 @@ class EndpointHandler():
|
|
| 122 |
|
| 123 |
# Prediction
|
| 124 |
with torch.no_grad():
|
| 125 |
-
preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
| 126 |
pred = preds[0].squeeze()
|
| 127 |
|
| 128 |
# Show Results
|
|
|
|
| 62 |
|
| 63 |
usage_to_weights_file = {
|
| 64 |
'General': 'BiRefNet',
|
| 65 |
+
'General-HR': 'BiRefNet_HR',
|
| 66 |
'General-Lite': 'BiRefNet_lite',
|
| 67 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
| 68 |
'General-reso_512': 'BiRefNet-reso_512',
|
|
|
|
| 83 |
resolution = (2560, 1440)
|
| 84 |
elif usage in ['General-reso_512']:
|
| 85 |
resolution = (512, 512)
|
| 86 |
+
elif usage in ['General-HR']:
|
| 87 |
+
resolution = (2048, 2048)
|
| 88 |
else:
|
| 89 |
resolution = (1024, 1024)
|
| 90 |
|
| 91 |
+
half_precision = True
|
| 92 |
|
| 93 |
class EndpointHandler():
|
| 94 |
def __init__(self, path=''):
|
|
|
|
| 97 |
)
|
| 98 |
self.birefnet.to(device)
|
| 99 |
self.birefnet.eval()
|
| 100 |
+
if half_precision:
|
| 101 |
+
self.birefnet.half()
|
| 102 |
|
| 103 |
def __call__(self, data: Dict[str, Any]):
|
| 104 |
"""
|
|
|
|
| 128 |
|
| 129 |
# Prediction
|
| 130 |
with torch.no_grad():
|
| 131 |
+
preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
|
| 132 |
pred = preds[0].squeeze()
|
| 133 |
|
| 134 |
# Show Results
|