leonelhs commited on
Commit
4b2e55f
·
verified ·
1 Parent(s): d8aa408

Update app.py

Browse files

code minimized

Files changed (1) hide show
  1. app.py +120 -186
app.py CHANGED
@@ -1,196 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- import gradio as gr
3
- from pathlib import Path
4
- from PIL import Image
5
  import numpy as np
6
  import torch
7
- from torch.autograd import Variable
8
- from torchvision import transforms
9
  import torch.nn.functional as F
10
- import matplotlib.pyplot as plt
11
- import warnings
12
- from zipfile import ZipFile
13
-
14
- warnings.filterwarnings("ignore")
15
-
16
 
17
  # project imports
18
- from data_loader_cache import normalize, im_reader, im_preprocess
19
- from models import *
 
20
 
21
- #Helpers
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
 
24
-
25
- class GOSNormalize(object):
26
- '''
27
- Normalize the Image using torch.transforms
28
- '''
29
- def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
30
- self.mean = mean
31
- self.std = std
32
-
33
- def __call__(self,image):
34
- image = normalize(image,self.mean,self.std)
35
- return image
36
-
37
-
38
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
39
-
40
- def load_image(im_path, hypar):
41
- im = im_reader(im_path)
42
- im, im_shp = im_preprocess(im, hypar["cache_size"])
43
- im = torch.divide(im,255.0)
44
- shape = torch.from_numpy(np.array(im_shp))
45
- return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
46
-
47
-
48
- def build_model(hypar,device):
49
- net = hypar["model"]#GOSNETINC(3,1)
50
-
51
- # convert to half precision
52
- if(hypar["model_digit"]=="half"):
53
- net.half()
54
- for layer in net.modules():
55
- if isinstance(layer, nn.BatchNorm2d):
56
- layer.float()
57
-
58
- net.to(device)
59
-
60
- if(hypar["restore_model"]!=""):
61
- net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
62
- net.to(device)
63
- net.eval()
64
- return net
65
-
66
-
67
- def predict(net, inputs_val, shapes_val, hypar, device):
68
- '''
69
- Given an Image, predict the mask
70
- '''
71
- net.eval()
72
-
73
- if(hypar["model_digit"]=="full"):
74
- inputs_val = inputs_val.type(torch.FloatTensor)
75
  else:
76
- inputs_val = inputs_val.type(torch.HalfTensor)
77
-
78
-
79
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
80
-
81
- ds_val = net(inputs_val_v)[0] # list of 6 results
82
-
83
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
84
-
85
- ## recover the prediction spatial size to the orignal image size
86
- pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
87
-
88
- ma = torch.max(pred_val)
89
- mi = torch.min(pred_val)
90
- pred_val = (pred_val-mi)/(ma-mi) # max = 1
91
-
92
- if device == 'cuda': torch.cuda.empty_cache()
93
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
94
-
95
- # Set Parameters
96
- hypar = {} # paramters for inferencing
97
-
98
-
99
- hypar["model_path"] ="./saved_models" ## load trained weights from this path
100
- hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
101
- hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
102
-
103
- ## choose floating point accuracy --
104
- hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
105
- hypar["seed"] = 0
106
-
107
- hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
108
-
109
- ## data augmentation parameters ---
110
- hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
111
- hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
112
-
113
- hypar["model"] = ISNetDIS()
114
-
115
- # Build Model
116
- net = build_model(hypar, device)
117
-
118
-
119
- def inference(image_path):
120
-
121
- image_tensor, orig_size = load_image(image_path, hypar)
122
- mask = predict(net, image_tensor, orig_size, hypar, device)
123
-
124
- pil_mask = Image.fromarray(mask).convert('L')
125
- im_rgb = Image.open(image_path).convert("RGB")
126
-
127
- im_rgba = im_rgb.copy()
128
- im_rgba.putalpha(pil_mask)
129
- file_name = Path(image_path).stem+"_nobg.png"
130
- file_path = Path(Path(image_path).parent,file_name)
131
- im_rgba.save(file_path)
132
- return str(file_path.resolve())
133
-
134
- def bw(image_files):
135
- print(image_files)
136
- output = []
137
- for idx, file in enumerate(image_files):
138
- print(file.name)
139
- img = Image.open(file.name)
140
- img = img.convert("L")
141
- output.append(img)
142
- print(output)
143
- return output
144
-
145
- def bw_single(image_file):
146
- img = Image.open(image_file)
147
- img = img.convert("L")
148
- return img
149
-
150
- def batch(image_files):
151
- output = []
152
- for idx, file in enumerate(image_files):
153
- file = inference(file.name)
154
- output.append(file)
155
-
156
- with ZipFile("tmp.zip", "w") as zipObj:
157
- for idx, file in enumerate(output):
158
- zipObj.write(file, file.split("/")[-1])
159
- return output,"tmp.zip"
160
-
161
- with gr.Blocks() as iface:
162
- gr.Markdown("# Remove Background")
163
- gr.HTML("Uses <a href='https://github.com/xuebinqin/DIS'>DIS</a> to remove background")
164
- with gr.Tab("Single Image"):
165
- with gr.Row():
166
- with gr.Column():
167
- image = gr.Image(type='filepath')
168
- with gr.Column():
169
- image_output = gr.Image(interactive=False)
170
- with gr.Row():
171
- with gr.Column():
172
- single_removebg = gr.Button("Remove Bg")
173
- with gr.Column():
174
- single_clear = gr.Button("Clear")
175
-
176
-
177
- with gr.Tab("Batch"):
178
- with gr.Row():
179
- with gr.Column():
180
- images = gr.File(file_count="multiple", file_types=["image"])
181
- with gr.Column():
182
- gallery = gr.Gallery()
183
- file_list = gr.Files(interactive=False)
184
-
185
- with gr.Row():
186
- with gr.Column():
187
- batch_removebg = gr.Button("Batch Process")
188
- with gr.Column():
189
- batch_clear = gr.Button("Clear")
190
- #Events
191
- single_removebg.click(inference, inputs=image, outputs=image_output)
192
- batch_removebg.click(batch, inputs=images, outputs=[gallery,file_list])
193
- single_clear.click(lambda: None, None, image, queue=False)
194
- batch_clear.click(lambda: None, None, images, queue=False)
195
-
196
- iface.launch()
 
1
+ #######################################################################################
2
+ #
3
+ # MIT License
4
+ #
5
+ # Copyright (c) [2025] [leonelhs@gmail.com]
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+ #
25
+ #######################################################################################
26
+
27
+ # This file implements an API endpoint for DIS background image removal system.
28
+ #
29
+ # Source code is based on or inspired by several projects.
30
+ # For more details and proper attribution, please refer to the following resources:
31
+ #
32
+ # - [DIS] - [https://github.com/xuebinqin/DIS]
33
+
34
  import gradio as gr
 
 
 
35
  import numpy as np
36
  import torch
 
 
37
  import torch.nn.functional as F
38
+ from PIL import Image
39
+ from huggingface_hub import hf_hub_download
40
+ from torch.autograd import Variable
41
+ from torchvision.transforms.functional import normalize
 
 
42
 
43
  # project imports
44
+ from models.isnet import ISNetDIS
45
+
46
+ REPO_ID = "leonelhs/removators"
47
 
 
48
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
49
 
50
+ net = ISNetDIS()
51
+
52
+ model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.pth')
53
+ net.load_state_dict(torch.load(model_path, map_location=device))
54
+ net.to(device)
55
+ net.eval()
56
+
57
+ def im_preprocess(im,size):
58
+ if len(im.shape) < 3:
59
+ im = im[:, :, np.newaxis]
60
+ if im.shape[2] == 1:
61
+ im = np.repeat(im, 3, axis=2)
62
+ im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
63
+ im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
64
+ if len(size)<2:
65
+ return im_tensor, im.shape[0:2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  else:
67
+ im_tensor = torch.unsqueeze(im_tensor,0)
68
+ im_tensor = F.interpolate(im_tensor, size, mode="bilinear")
69
+ im_tensor = torch.squeeze(im_tensor,0)
70
+
71
+ return im_tensor.type(torch.uint8), im.shape[0:2]
72
+
73
+
74
+ def predict(image):
75
+ """
76
+ Remove the background from an image.
77
+ The function extracts the foreground and generates both a background-removed
78
+ image and a binary mask.
79
+
80
+ Parameters:
81
+ image (string): File path to the input image.
82
+ Returns:
83
+ paths (tuple): paths for background-removed image and cutting mask.
84
+ """
85
+
86
+ im_tensor, shapes = im_preprocess(image, [1024, 1024])
87
+ shapes = torch.from_numpy(np.array(shapes)).unsqueeze(0)
88
+
89
+ im_tensor = torch.divide(im_tensor, 255.0)
90
+ im_tensor = normalize(im_tensor, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]).unsqueeze(0)
91
+ im_tensor_v = Variable(im_tensor, requires_grad=False) # wrap inputs in Variable
92
+ ds_val = net(im_tensor_v)[0] # list of 6 results
93
+ prediction = ds_val[0][0, :, :, :] # B x 1 x H x W # we want the first one which is the most accurate prediction
94
+ ## recover the prediction spatial size to the original image size
95
+ size = (shapes[0][0], shapes[0][1])
96
+ prediction = F.interpolate(torch.unsqueeze(prediction, 0), size, mode='bilinear')
97
+ prediction = torch.squeeze(prediction)
98
+
99
+ ma = torch.max(prediction)
100
+ mi = torch.min(prediction)
101
+ prediction = (prediction - mi) / (ma - mi) # max = 1
102
+
103
+ torch.cuda.empty_cache()
104
+ mask = (prediction.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need
105
+
106
+ mask = Image.fromarray(mask).convert('L')
107
+ image_rgb = Image.fromarray(image).convert("RGB")
108
+ image_rgb.putalpha(mask)
109
+ return image_rgb, mask
110
+
111
+ article = "<div><center>Unofficial demo from:<a href='https://github.com/xuebinqin/DIS'>DIS</<></center></div>"
112
+
113
+ with gr.Blocks(title="DIS") as app:
114
+ gr.Markdown("## Dichotomous Image Segmentation")
115
+ with gr.Row():
116
+ with gr.Column(scale=1):
117
+ inp = gr.Image(type="numpy", label="Upload Image")
118
+ btn_predict = gr.Button("Remove background")
119
+ with gr.Column(scale=2):
120
+ with gr.Row():
121
+ with gr.Column(scale=1):
122
+ out = gr.Image(type="filepath", label="Output image")
123
+ with gr.Accordion("See intermediates", open=False):
124
+ out_mask = gr.Image(type="filepath", label="Mask")
125
+
126
+ btn_predict.click(predict, inputs=inp, outputs=[out, out_mask])
127
+ gr.HTML(article)
128
+
129
+ app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
130
+ app.queue()