Spaces:
Runtime error
Runtime error
Commit
Β·
53ff575
1
Parent(s):
fa3042b
Change all predicted results as refined RGBA images. Fix a typo in device specification.
Browse files
app.py
CHANGED
|
@@ -21,7 +21,7 @@ import zipfile
|
|
| 21 |
torch.set_float32_matmul_precision('high')
|
| 22 |
torch.jit.script = lambda f: f
|
| 23 |
|
| 24 |
-
device = "cuda" if torch.cuda.is_available() else "
|
| 25 |
|
| 26 |
### image_proc.py
|
| 27 |
def refine_foreground(image, mask, r=90):
|
|
@@ -125,20 +125,18 @@ def predict(images, resolution, weights_file):
|
|
| 125 |
for idx_image, image_src in enumerate(images):
|
| 126 |
if isinstance(image_src, str):
|
| 127 |
if os.path.isfile(image_src):
|
| 128 |
-
|
| 129 |
else:
|
| 130 |
response = requests.get(image_src)
|
| 131 |
image_data = BytesIO(response.content)
|
| 132 |
-
|
| 133 |
else:
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
image_shape = image.shape[:2]
|
| 137 |
-
image_pil = array_to_pil_image(image, tuple(resolution))
|
| 138 |
|
|
|
|
| 139 |
# Preprocess the image
|
| 140 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 141 |
-
image_proc = image_preprocessor.proc(
|
| 142 |
image_proc = image_proc.unsqueeze(0)
|
| 143 |
|
| 144 |
# Prediction
|
|
@@ -148,8 +146,8 @@ def predict(images, resolution, weights_file):
|
|
| 148 |
|
| 149 |
# Show Results
|
| 150 |
pred_pil = transforms.ToPILImage()(pred)
|
| 151 |
-
image_masked = refine_foreground(
|
| 152 |
-
image_masked.putalpha(pred_pil.resize(
|
| 153 |
|
| 154 |
torch.cuda.empty_cache()
|
| 155 |
|
|
@@ -158,12 +156,6 @@ def predict(images, resolution, weights_file):
|
|
| 158 |
image_masked.save(save_file_path)
|
| 159 |
save_paths.append(save_file_path)
|
| 160 |
|
| 161 |
-
# Apply the prediction mask to the original image
|
| 162 |
-
pred = torch.nn.functional.interpolate(preds, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
| 163 |
-
image_pil = image_pil.resize(pred.shape[::-1])
|
| 164 |
-
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
| 165 |
-
image_masked = (pred * np.array(image_pil)).astype(np.uint8)
|
| 166 |
-
|
| 167 |
if tab_is_batch:
|
| 168 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 169 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
|
@@ -171,7 +163,7 @@ def predict(images, resolution, weights_file):
|
|
| 171 |
zipf.write(file, os.path.basename(file))
|
| 172 |
return save_paths, zip_file_path
|
| 173 |
else:
|
| 174 |
-
return
|
| 175 |
|
| 176 |
|
| 177 |
examples = [[_] for _ in glob('examples/*')][:]
|
|
|
|
| 21 |
torch.set_float32_matmul_precision('high')
|
| 22 |
torch.jit.script = lambda f: f
|
| 23 |
|
| 24 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
|
| 26 |
### image_proc.py
|
| 27 |
def refine_foreground(image, mask, r=90):
|
|
|
|
| 125 |
for idx_image, image_src in enumerate(images):
|
| 126 |
if isinstance(image_src, str):
|
| 127 |
if os.path.isfile(image_src):
|
| 128 |
+
image_ori = Image.open(image_src)
|
| 129 |
else:
|
| 130 |
response = requests.get(image_src)
|
| 131 |
image_data = BytesIO(response.content)
|
| 132 |
+
image_ori = Image.open(image_data)
|
| 133 |
else:
|
| 134 |
+
image_ori = Image.fromarray(image_src)
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
image = image_ori.convert('RGB')
|
| 137 |
# Preprocess the image
|
| 138 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 139 |
+
image_proc = image_preprocessor.proc(image)
|
| 140 |
image_proc = image_proc.unsqueeze(0)
|
| 141 |
|
| 142 |
# Prediction
|
|
|
|
| 146 |
|
| 147 |
# Show Results
|
| 148 |
pred_pil = transforms.ToPILImage()(pred)
|
| 149 |
+
image_masked = refine_foreground(image, pred_pil)
|
| 150 |
+
image_masked.putalpha(pred_pil.resize(image.size))
|
| 151 |
|
| 152 |
torch.cuda.empty_cache()
|
| 153 |
|
|
|
|
| 156 |
image_masked.save(save_file_path)
|
| 157 |
save_paths.append(save_file_path)
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
if tab_is_batch:
|
| 160 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 161 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
|
|
|
| 163 |
zipf.write(file, os.path.basename(file))
|
| 164 |
return save_paths, zip_file_path
|
| 165 |
else:
|
| 166 |
+
return (image_ori, image_masked)
|
| 167 |
|
| 168 |
|
| 169 |
examples = [[_] for _ in glob('examples/*')][:]
|