Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import numpy as np | |
| import paddle | |
| import paddle.nn.functional as F | |
def reverse_transform(alpha, trans_info):
    """Recover the prediction to the original (pre-transform) shape.

    Undoes the steps recorded in ``trans_info`` in reverse order, so the
    last transform applied during preprocessing is undone first.

    Args:
        alpha: Prediction tensor. It is indexed as ``alpha[:, :, h, w]``, so
            it must have at least four dimensions (presumably N, C, H, W —
            confirm against the model output).
        trans_info: Sequence of records appended by the preprocessing
            transforms. ``record[0]`` is the operation name (``"resize"`` or
            ``"padding"``) and ``record[1][0]`` / ``record[1][1]`` give the
            height/width to restore.

    Returns:
        ``alpha`` after every recorded step has been undone: bilinearly
        resized back for ``"resize"`` records, cropped to the top-left
        ``h x w`` region for ``"padding"`` records.

    Raises:
        ValueError: If a record names any other operation. ``ValueError`` is
            a subclass of ``Exception``, so existing ``except Exception``
            handlers keep working.
    """
    for item in reversed(trans_info):
        op = item[0]
        if op == "resize":
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
        elif op == "padding":
            # Padding was added on the bottom/right, so the original content
            # is the top-left h x w window.
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            # Use a specific exception type instead of bare Exception.
            raise ValueError(f"Unexpected info '{op}' in im_info")
    return alpha
def preprocess(img, transforms, trimap=None):
    """Package an image (and optional trimap) into the dict the model consumes.

    Builds a sample dict and runs ``transforms`` on it. Before the transforms,
    the dict holds ``"img"``, plus ``"trimap"`` and ``"gt_fields"`` when a
    trimap is given, plus an empty ``"trans_info"`` list. Afterwards, the
    image is converted to a Paddle tensor with one leading axis added. The
    trimap, if present, is converted to a tensor with two leading axes added
    (axes 0 and 1).

    Args:
        img: The input image.
        transforms: Callable that takes the sample dict and returns the
            transformed dict. Presumably it records its steps in
            ``"trans_info"``; confirm against the transform implementation.
        trimap: Optional trimap. When given, it is listed in ``"gt_fields"``.

    Returns:
        The transformed sample dict with tensor-valued ``"img"`` (and
        ``"trimap"`` when one was supplied).
    """
    with_trimap = trimap is not None

    sample = {"img": img}
    if with_trimap:
        sample["trimap"] = trimap
        sample["gt_fields"] = ["trimap"]
    sample["trans_info"] = []

    sample = transforms(sample)

    # Add a leading batch-like axis to the image.
    sample["img"] = paddle.to_tensor(sample["img"]).unsqueeze(0)
    if with_trimap:
        # The trimap gets two leading axes (0 and 1).
        sample["trimap"] = paddle.to_tensor(sample["trimap"]).unsqueeze((0, 1))
    return sample
def predict(
    model,
    transforms,
    image: np.ndarray,
    trimap: Optional[np.ndarray] = None,
):
    """Run matting inference on a single image and return its alpha matte.

    Args:
        model: Callable invoked with the dict returned by ``preprocess``.
            Its output is mapped back to the input geometry with
            ``reverse_transform``.
        transforms: Preprocessing callable forwarded to ``preprocess``.
        image: The input image.
        trimap: Optional trimap. It is NOT passed to ``preprocess`` or the
            model (``trimap=None`` is always forwarded). It is only used
            after inference to clamp the matte: pixels where
            ``trimap == 0`` become 0.0 and pixels where ``trimap == 255``
            become 1.0.

    Returns:
        The alpha prediction as a NumPy array with size-1 axes squeezed out.
    """
    with paddle.no_grad():
        # NOTE(review): the caller's trimap is withheld from preprocessing,
        # so the model always runs without it. Confirm this is intended
        # (e.g. a trimap-free model).
        sample = preprocess(img=image, transforms=transforms, trimap=None)
        prediction = reverse_transform(model(sample), sample["trans_info"])
        matte = prediction.numpy().squeeze()

    if trimap is None:
        return matte

    # Assumes trimap has the same shape as the squeezed matte. TODO: confirm.
    matte[trimap == 0] = 0
    matte[trimap == 255] = 1.
    return matte