Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -157,7 +157,12 @@ def predict(images, resolution, weights_file):
|
|
| 157 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
| 158 |
image_masked.save(save_file_path)
|
| 159 |
save_paths.append(save_file_path)
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
if tab_is_batch:
|
| 163 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
|
|
|
| 157 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
| 158 |
image_masked.save(save_file_path)
|
| 159 |
save_paths.append(save_file_path)
|
| 160 |
+
|
| 161 |
+
# Apply the prediction mask to the original image
|
| 162 |
+
pred = torch.nn.functional.interpolate(pred, size=image_shape, mode='bilinear', align_corners=True).numpy()
|
| 163 |
+
image_pil = image_pil.resize(pred.shape[::-1])
|
| 164 |
+
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
| 165 |
+
image_masked = (pred * np.array(image_pil)).astype(np.uint8)
|
| 166 |
|
| 167 |
if tab_is_batch:
|
| 168 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|