Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import numpy as np | |
| from skimage import color | |
| import torch | |
| import torch.nn.functional as F | |
| from IPython import embed | |
def load_img(img_path):
    """Load an image file into a NumPy array that has a channel axis.

    Args:
        img_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        The decoded pixel array. A 2-D (single-plane) result is replicated
        into three identical channels, giving H x W x 3; an array that
        already has a channel axis is returned as decoded.
    """
    # Image.open is lazy and leaves the underlying file open. The context
    # manager releases the handle deterministically, including on decode
    # errors and for multi-frame formats that stay open after loading.
    with Image.open(img_path) as img:
        # np.asarray decodes and copies the pixels while the file is open.
        out_np = np.asarray(img)
    if out_np.ndim == 2:
        # Grayscale: repeat the single plane along a new trailing axis.
        out_np = np.tile(out_np[:, :, None], 3)
    return out_np
def resize_img(img, HW=(256,256), resample=3):
    """Resize an image array to a target (height, width) via Pillow.

    Args:
        img: Array accepted by ``Image.fromarray``.
        HW: Target size; index 0 is the height and index 1 the width.
        resample: Resampling filter forwarded to ``Image.resize``
            (3 is Pillow's bicubic filter constant).

    Returns:
        The resized image as a NumPy array.
    """
    pil_img = Image.fromarray(img)
    # Pillow expects (width, height), the reverse of the HW ordering.
    target_w = HW[1]
    target_h = HW[0]
    resized = pil_img.resize((target_w, target_h), resample=resample)
    return np.asarray(resized)
def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
    """Produce Lab lightness tensors at the original and resized resolution.

    Args:
        img_rgb_orig: RGB image array (assumed H x W x 3 - the first two
            axes are kept and channel 0 of the Lab result is selected).
        HW: Target (height, width) forwarded to ``resize_img``.
        resample: Resampling filter forwarded to ``resize_img``.

    Returns:
        Tuple ``(tens_orig_l, tens_rs_l)``: the L channel of the original
        image and of the resized image, each wrapped in a ``torch.Tensor``
        with two leading singleton axes (1 x 1 x H x W for 2-D L planes).
    """
    resized_rgb = resize_img(img_rgb_orig, HW=HW, resample=resample)

    # Convert both resolutions to Lab and keep only channel 0 (lightness).
    l_planes = [
        color.rgb2lab(rgb)[:, :, 0]
        for rgb in (img_rgb_orig, resized_rgb)
    ]

    # Prepend batch and channel axes to each lightness plane.
    tens_orig_l, tens_rs_l = (
        torch.Tensor(plane)[None, None, :, :] for plane in l_planes
    )
    return (tens_orig_l, tens_rs_l)
def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    """Merge an L tensor with ab channels and convert the result to RGB.

    Args:
        tens_orig_l: Lightness tensor, 1 x 1 x H_orig x W_orig.
        out_ab: ab-channel tensor, 1 x 2 x H x W.
        mode: Interpolation mode handed to ``F.interpolate`` when the ab
            channels must be resized to the L tensor's spatial size.

    Returns:
        The output of ``color.lab2rgb`` on the H_orig x W_orig x 3 Lab
        array built from the first (only) batch element.
    """
    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # Resize the ab channels only when the spatial sizes differ.
    if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
        # Honour the caller's ``mode``; it used to be ignored in favour of
        # a hard-coded 'bilinear'.
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
    else:
        out_ab_orig = out_ab

    # Stack along the channel axis: 1 x 3 x H_orig x W_orig (L, a, b).
    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)

    # Drop the batch axis and move channels last (C x H x W -> H x W x C).
    lab_hwc = out_lab_orig.detach().cpu().numpy()[0, ...].transpose((1, 2, 0))
    return color.lab2rgb(lab_hwc)