Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| #mls affine inv | |
def mls_affine_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
    ''' Affine inverse deformation

    For every point of a regular grid laid over the image, computes the
    (row, col) position of the original image that should be sampled there.

    ### Params:
        * image - ndarray: original image, indexed as image[row, col]
        * height - int: number of rows (grid extent and clamping bound)
        * width - int: number of columns (grid extent and clamping bound)
        * channel - unused, kept for interface compatibility
        * p - ndarray: an array with size [n, 2], original control points (x, y)
        * q - ndarray: an array with size [n, 2], final control points (x, y)
        * alpha - float: parameter used by weights
        * density - float: density of the grids
    ### Return:
        A tuple (transformers, transformed_image):
        * transformers - float64 ndarray [2, grow, gcol]: source row (index 0)
          and source column (index 1) for every grid point
        * transformed_image - image sampled at the integer-truncated transformers
    '''
    # Change (x, y) to (row, col)
    q = q[:, [1, 0]]
    p = p[:, [1, 0]]

    # Make grids on the original image
    gridX = np.linspace(0, width, num=int(width * density), endpoint=False)
    gridY = np.linspace(0, height, num=int(height * density), endpoint=False)
    vy, vx = np.meshgrid(gridX, gridY)
    grow = vx.shape[0]  # grid rows
    gcol = vx.shape[1]  # grid cols
    ctrls = p.shape[0]  # control points

    # Compute
    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_q = q.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))      # [2, grow, gcol]

    # A grid point lying exactly on a control point divides by zero; that is
    # expected, and the resulting inf is replaced by a large finite weight.
    with np.errstate(divide='ignore'):
        w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1) ** alpha               # [ctrls, grow, gcol]
    w[w == np.inf] = 2**31 - 1

    pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    phat = reshaped_p - pstar                                                           # [ctrls, 2, grow, gcol]
    qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    qhat = reshaped_q - qstar                                                           # [ctrls, 2, grow, gcol]

    reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol)                               # [ctrls, 2, 1, grow, gcol]
    reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol)                              # [ctrls, 1, 2, grow, gcol]
    reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol)                               # [ctrls, 1, 2, grow, gcol]
    reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol)                                     # [ctrls, 1, 1, grow, gcol]
    pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0)                   # [2, 2, grow, gcol]

    try:
        inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1))                            # [grow, gcol, 2, 2]
        flag = False
    except np.linalg.LinAlgError:
        # At least one grid point has a singular matrix: invert the 2x2
        # matrices by hand and zero out the singular ones (division by inf).
        flag = True
        det = np.linalg.det(pTwq.transpose(2, 3, 0, 1))                                 # [grow, gcol]
        # NOTE(review): negative determinants are also treated as singular here.
        det[det < 1e-8] = np.inf
        reshaped_det = det.reshape(1, 1, grow, gcol)                                    # [1, 1, grow, gcol]
        adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :]                        # [2, 2, grow, gcol]
        adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :]                  # [2, 2, grow, gcol]
        inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1)                       # [grow, gcol, 2, 2]

    mul_left = reshaped_v - qstar                                                       # [2, grow, gcol]
    reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1)        # [grow, gcol, 1, 2]
    mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0)             # [2, 2, grow, gcol]
    reshaped_mul_right = mul_right.transpose(2, 3, 0, 1)                                # [grow, gcol, 2, 2]
    temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right)        # [grow, gcol, 1, 2]
    reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1)                      # [2, grow, gcol]

    # Get final image transformer -- 3-D array
    transformers = reshaped_temp + pstar                                                # [2, grow, gcol]

    # Correct the points where pTwq is singular
    if flag:
        blidx = det == np.inf  # bool index
        # NOTE(review): the fallback adds (q* - p*); confirm the sign is the
        # intended one for an inverse mapping.
        transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
        transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]

    # Points outside the border are redirected to row / column 0
    transformers[transformers < 0] = 0
    transformers[0][transformers[0] > height - 1] = 0
    transformers[1][transformers[1] > width - 1] = 0

    # Mapping original image (coordinates are truncated to integers)
    transformed_image = image[tuple(transformers.astype(np.int16))]                     # [grow, gcol]

    # np.float was removed in NumPy 1.24; float64 is what it used to alias.
    return transformers.astype(np.float64), transformed_image
def mls_affine_deformation_inv_final(height, width, channel, p, q, alpha=1.0, density=1.0):
    ''' Affine inverse deformation (grid only)

    Computes, for every point of a regular grid, the (row, col) position of
    the original image that should be sampled there. No image is touched.

    ### Params:
        * height - int: number of rows (grid extent and clamping bound)
        * width - int: number of columns (grid extent and clamping bound)
        * channel - unused, kept for interface compatibility
        * p - ndarray: an array with size [n, 2], original control points (x, y)
        * q - ndarray: an array with size [n, 2], final control points (x, y)
        * alpha - float: parameter used by weights
        * density - float: density of the grids
    ### Return:
        ndarray [2, grow, gcol]: source row (index 0) and source column
        (index 1) for every grid point.
    '''
    # Change (x, y) to (row, col)
    q = q[:, [1, 0]]
    p = p[:, [1, 0]]

    # Make grids on the original image
    gridX = np.linspace(0, width, num=int(width * density), endpoint=False)
    gridY = np.linspace(0, height, num=int(height * density), endpoint=False)
    vy, vx = np.meshgrid(gridX, gridY)
    grow = vx.shape[0]  # grid rows
    gcol = vx.shape[1]  # grid cols
    ctrls = p.shape[0]  # control points

    # Compute
    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_q = q.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))      # [2, grow, gcol]

    # A grid point lying exactly on a control point divides by zero; that is
    # expected, and the resulting inf is replaced by a large finite weight.
    with np.errstate(divide='ignore'):
        w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1) ** alpha               # [ctrls, grow, gcol]
    w[w == np.inf] = 2**31 - 1

    pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    phat = reshaped_p - pstar                                                           # [ctrls, 2, grow, gcol]
    qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    qhat = reshaped_q - qstar                                                           # [ctrls, 2, grow, gcol]

    reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol)                               # [ctrls, 2, 1, grow, gcol]
    reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol)                              # [ctrls, 1, 2, grow, gcol]
    reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol)                               # [ctrls, 1, 2, grow, gcol]
    reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol)                                     # [ctrls, 1, 1, grow, gcol]
    pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0)                   # [2, 2, grow, gcol]

    try:
        inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1))                            # [grow, gcol, 2, 2]
        flag = False
    except np.linalg.LinAlgError:
        # At least one grid point has a singular matrix: invert the 2x2
        # matrices by hand and zero out the singular ones (division by inf).
        flag = True
        det = np.linalg.det(pTwq.transpose(2, 3, 0, 1))                                 # [grow, gcol]
        # NOTE(review): negative determinants are also treated as singular here.
        det[det < 1e-8] = np.inf
        reshaped_det = det.reshape(1, 1, grow, gcol)                                    # [1, 1, grow, gcol]
        adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :]                        # [2, 2, grow, gcol]
        adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :]                  # [2, 2, grow, gcol]
        inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1)                       # [grow, gcol, 2, 2]

    mul_left = reshaped_v - qstar                                                       # [2, grow, gcol]
    reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1)        # [grow, gcol, 1, 2]
    mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0)             # [2, 2, grow, gcol]
    reshaped_mul_right = mul_right.transpose(2, 3, 0, 1)                                # [grow, gcol, 2, 2]
    temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right)        # [grow, gcol, 1, 2]
    reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1)                      # [2, grow, gcol]

    # Get final image transformer -- 3-D array
    transformers = reshaped_temp + pstar                                                # [2, grow, gcol]

    # Correct the points where pTwq is singular
    if flag:
        blidx = det == np.inf  # bool index
        # NOTE(review): the fallback adds (q* - p*); confirm the sign is the
        # intended one for an inverse mapping.
        transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
        transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]

    # Points outside the border are redirected to row / column 0
    transformers[transformers < 0] = 0
    transformers[0][transformers[0] > height - 1] = 0
    transformers[1][transformers[1] > width - 1] = 0

    return transformers
def mls_similarity_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
    ''' Similarity inverse deformation

    For every point of a regular grid laid over the image, computes the
    (row, col) position of the original image that should be sampled there.

    ### Params:
        * image - ndarray: original image, indexed as image[row, col]
        * height - ignored: overwritten by image.shape[0]
        * width - ignored: overwritten by image.shape[1]
        * channel - unused, kept for interface compatibility
        * p - ndarray: an array with size [n, 2], original control points (x, y)
        * q - ndarray: an array with size [n, 2], final control points (x, y)
        * alpha - float: parameter used by weights
        * density - float: density of the grids
    ### Return:
        A tuple (transformers, transformed_image):
        * transformers - ndarray [2, grow, gcol]: source row (index 0) and
          source column (index 1) for every grid point
        * transformed_image - image sampled at the integer-truncated transformers
    '''
    height = image.shape[0]
    width = image.shape[1]

    # Change (x, y) to (row, col)
    q = q[:, [1, 0]]
    p = p[:, [1, 0]]

    # Make grids on the original image
    gridX = np.linspace(0, width, num=int(width * density), endpoint=False)
    gridY = np.linspace(0, height, num=int(height * density), endpoint=False)
    vy, vx = np.meshgrid(gridX, gridY)
    grow = vx.shape[0]  # grid rows
    gcol = vx.shape[1]  # grid cols
    ctrls = p.shape[0]  # control points

    # Compute
    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_q = q.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))      # [2, grow, gcol]

    # A grid point lying exactly on a control point divides by zero; that is
    # expected, and the resulting inf is replaced by a large finite weight.
    with np.errstate(divide='ignore'):
        w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1) ** alpha               # [ctrls, grow, gcol]
    w[w == np.inf] = 2**31 - 1

    pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    phat = reshaped_p - pstar                                                           # [ctrls, 2, grow, gcol]
    qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    qhat = reshaped_q - qstar                                                           # [ctrls, 2, grow, gcol]

    reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol)                              # [ctrls, 1, 2, grow, gcol]
    reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol)                              # [ctrls, 2, 1, grow, gcol]
    reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol)                               # [ctrls, 1, 2, grow, gcol]
    reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol)                                     # [ctrls, 1, 1, grow, gcol]

    # mu = sum_i w_i * |phat_i|^2
    mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) *
                          reshaped_phat1.transpose(0, 3, 4, 1, 2),
                          reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0)             # [grow, gcol, 1, 1]
    reshaped_mu = mu.reshape(1, grow, gcol)                                             # [1, grow, gcol]

    # Fancy indexing copies, so negating in place leaves phat untouched.
    neg_phat_verti = phat[:, [1, 0], ...]                                               # [ctrls, 2, grow, gcol]
    neg_phat_verti[:, 1, ...] = -neg_phat_verti[:, 1, ...]
    reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol)           # [ctrls, 1, 2, grow, gcol]
    mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1)       # [ctrls, 2, 2, grow, gcol]
    mul_left = reshaped_qhat * reshaped_w                                               # [ctrls, 1, 2, grow, gcol]
    Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2),
                             mul_right.transpose(0, 3, 4, 1, 2)),
                   axis=0).transpose(0, 1, 3, 2)                                        # [grow, gcol, 2, 1]
    Delta_verti = Delta[..., [1, 0], :]                                                 # [grow, gcol, 2, 1]
    Delta_verti[..., 0, :] = -Delta_verti[..., 0, :]
    B = np.concatenate((Delta, Delta_verti), axis=3)                                    # [grow, gcol, 2, 2]

    try:
        inv_B = np.linalg.inv(B)                                                        # [grow, gcol, 2, 2]
        flag = False
    except np.linalg.LinAlgError:
        # At least one grid point has a singular matrix: invert the 2x2
        # matrices by hand and zero out the singular ones (division by inf).
        flag = True
        det = np.linalg.det(B)                                                          # [grow, gcol]
        det[det < 1e-8] = np.inf
        reshaped_det = det.reshape(grow, gcol, 1, 1)                                    # [grow, gcol, 1, 1]
        adjoint = B[:, :, [[1, 0], [1, 0]], [[1, 1], [0, 0]]]                           # [grow, gcol, 2, 2]
        adjoint[:, :, [0, 1], [1, 0]] = -adjoint[:, :, [0, 1], [1, 0]]                  # [grow, gcol, 2, 2]
        # Keep the [grow, gcol, 2, 2] layout expected by the matmul below
        # (the same layout np.linalg.inv returns).
        inv_B = adjoint / reshaped_det                                                  # [grow, gcol, 2, 2]

    v_minus_qstar_mul_mu = (reshaped_v - qstar) * reshaped_mu                           # [2, grow, gcol]

    # Get final image transformer -- 3-D array
    reshaped_v_minus_qstar_mul_mu = v_minus_qstar_mul_mu.reshape(1, 2, grow, gcol)      # [1, 2, grow, gcol]
    transformers = np.matmul(reshaped_v_minus_qstar_mul_mu.transpose(2, 3, 0, 1),
                             inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) + pstar   # [2, grow, gcol]

    # Correct the points where B is singular
    if flag:
        blidx = det == np.inf  # bool index
        # NOTE(review): the fallback adds (q* - p*); confirm the sign is the
        # intended one for an inverse mapping.
        transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
        transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]

    # Points outside the border are redirected to row / column 0
    transformers[transformers < 0] = 0
    transformers[0][transformers[0] > height - 1] = 0
    transformers[1][transformers[1] > width - 1] = 0

    # Mapping original image (coordinates are truncated to integers)
    transformed_image = image[tuple(transformers.astype(np.int16))]                     # [grow, gcol]

    return transformers, transformed_image
def mls_rigid_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
    ''' Rigid inverse deformation

    For every point of a regular grid laid over the image, computes the
    (row, col) position of the original image that should be sampled there.

    ### Params:
        * image - ndarray: original image, indexed as image[row, col]
        * height - ignored: overwritten by image.shape[0]
        * width - ignored: overwritten by image.shape[1]
        * channel - unused, kept for interface compatibility
        * p - ndarray: an array with size [n, 2], original control points (x, y)
        * q - ndarray: an array with size [n, 2], final control points (x, y)
        * alpha - float: parameter used by weights
        * density - float: density of the grids
    ### Return:
        A tuple (transformers, transformed_image):
        * transformers - ndarray [2, grow, gcol]: source row (index 0) and
          source column (index 1) for every grid point
        * transformed_image - image sampled at the integer-truncated transformers
    '''
    height = image.shape[0]
    width = image.shape[1]

    # Change (x, y) to (row, col)
    q = q[:, [1, 0]]
    p = p[:, [1, 0]]

    # Make grids on the original image
    gridX = np.linspace(0, width, num=int(width * density), endpoint=False)
    gridY = np.linspace(0, height, num=int(height * density), endpoint=False)
    vy, vx = np.meshgrid(gridX, gridY)
    grow = vx.shape[0]  # grid rows
    gcol = vx.shape[1]  # grid cols
    ctrls = p.shape[0]  # control points

    # Compute
    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_q = q.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))      # [2, grow, gcol]

    # A grid point lying exactly on a control point divides by zero; that is
    # expected, and the resulting inf is replaced by a large finite weight.
    with np.errstate(divide='ignore'):
        w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1) ** alpha               # [ctrls, grow, gcol]
    w[w == np.inf] = 2**31 - 1

    pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    phat = reshaped_p - pstar                                                           # [ctrls, 2, grow, gcol]
    qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0)    # [2, grow, gcol]
    qhat = reshaped_q - qstar                                                           # [ctrls, 2, grow, gcol]

    reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol)                              # [ctrls, 1, 2, grow, gcol]
    reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol)                               # [ctrls, 1, 2, grow, gcol]
    reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol)                                     # [ctrls, 1, 1, grow, gcol]

    # Fancy indexing copies, so negating in place leaves phat untouched.
    neg_phat_verti = phat[:, [1, 0], ...]                                               # [ctrls, 2, grow, gcol]
    neg_phat_verti[:, 1, ...] = -neg_phat_verti[:, 1, ...]
    reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol)           # [ctrls, 1, 2, grow, gcol]
    mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1)       # [ctrls, 2, 2, grow, gcol]
    mul_left = reshaped_qhat * reshaped_w                                               # [ctrls, 1, 2, grow, gcol]
    Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2),
                             mul_right.transpose(0, 3, 4, 1, 2)),
                   axis=0).transpose(0, 1, 3, 2)                                        # [grow, gcol, 2, 1]
    Delta_verti = Delta[..., [1, 0], :]                                                 # [grow, gcol, 2, 1]
    Delta_verti[..., 0, :] = -Delta_verti[..., 0, :]
    B = np.concatenate((Delta, Delta_verti), axis=3)                                    # [grow, gcol, 2, 2]

    try:
        inv_B = np.linalg.inv(B)                                                        # [grow, gcol, 2, 2]
        flag = False
    except np.linalg.LinAlgError:
        # At least one grid point has a singular matrix: invert the 2x2
        # matrices by hand and zero out the singular ones (division by inf).
        flag = True
        det = np.linalg.det(B)                                                          # [grow, gcol]
        det[det < 1e-8] = np.inf
        reshaped_det = det.reshape(grow, gcol, 1, 1)                                    # [grow, gcol, 1, 1]
        adjoint = B[:, :, [[1, 0], [1, 0]], [[1, 1], [0, 0]]]                           # [grow, gcol, 2, 2]
        adjoint[:, :, [0, 1], [1, 0]] = -adjoint[:, :, [0, 1], [1, 0]]                  # [grow, gcol, 2, 2]
        # Keep the [grow, gcol, 2, 2] layout expected by the matmul below
        # (the same layout np.linalg.inv returns).
        inv_B = adjoint / reshaped_det                                                  # [grow, gcol, 2, 2]

    vqstar = reshaped_v - qstar                                                         # [2, grow, gcol]
    reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol)                                  # [1, 2, grow, gcol]

    # Get final image transformer -- 3-D array
    temp = np.matmul(reshaped_vqstar.transpose(2, 3, 0, 1),
                     inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1)                   # [2, grow, gcol]
    norm_temp = np.linalg.norm(temp, axis=0, keepdims=True)                             # [1, grow, gcol]
    norm_vqstar = np.linalg.norm(vqstar, axis=0, keepdims=True)                         # [1, grow, gcol]
    # A zero-length temp would give 0 / 0 = NaN; dividing by 1 instead keeps
    # the (zero) vector and leaves every other point unchanged.
    safe_norm_temp = np.where(norm_temp > 0, norm_temp, 1.0)
    transformers = temp / safe_norm_temp * norm_vqstar + pstar                          # [2, grow, gcol]

    # Correct the points where B is singular
    if flag:
        blidx = det == np.inf  # bool index
        # NOTE(review): the fallback adds (q* - p*); confirm the sign is the
        # intended one for an inverse mapping.
        transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
        transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]

    # Points outside the border are redirected to row / column 0
    transformers[transformers < 0] = 0
    transformers[0][transformers[0] > height - 1] = 0
    transformers[1][transformers[1] > width - 1] = 0

    # Mapping original image (coordinates are truncated to integers)
    transformed_image = image[tuple(transformers.astype(np.int16))]                     # [grow, gcol]

    return transformers, transformed_image
| # mls rigid algorithm | |
def mls_rigid_deformation_inv_wy(image, height, width, channel, p, q, alpha=1.0, density=1.0):
    '''
    Rigid inverse deformation (PyTorch implementation)

    For every point of a regular grid laid over the image, computes the
    (row, col) position of the original image that should be sampled there.

    ### Params:
        * image - array indexed as image[row, col]: original image
        * height - int: number of rows (grid extent and clamping bound)
        * width - int: number of columns (grid extent and clamping bound)
        * channel - unused, kept for interface compatibility
        * p - tensor: size [n, 2], original control points (x, y)
        * q - tensor: size [n, 2], final control points (x, y)
        * alpha - float: parameter used by weights
        * density - float: density of the grids
    ### Return:
        A tuple (transformers, transformed_image):
        * transformers - ndarray [2, grow, gcol]: source row (index 0) and
          source column (index 1) for every grid point
        * transformed_image - image sampled at the integer-truncated transformers
    '''
    # Change (x, y) to (row, col)
    q = q[:, [1, 0]]
    p = p[:, [1, 0]]

    # Make grids on the original image. torch.linspace always includes its
    # end point, so one extra step is generated and dropped; this matches
    # np.linspace(..., endpoint=False) and keeps the grid inside the image.
    gridX = torch.linspace(0, width, steps=int(width * density) + 1)[:-1]
    gridY = torch.linspace(0, height, steps=int(height * density) + 1)[:-1]
    vx, vy = torch.meshgrid(gridY, gridX)
    grow = vx.shape[0]  # grid rows
    gcol = vx.shape[1]  # grid cols
    ctrls = p.shape[0]  # control points

    # Compute
    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_q = q.reshape(ctrls, 2, 1, 1)                                              # [ctrls, 2, 1, 1]
    reshaped_v = torch.stack((vx, vy), dim=0)                                           # [2, grow, gcol]

    w = 1.0 / torch.sum((reshaped_p - reshaped_v) ** 2, dim=1) ** alpha                 # [ctrls, grow, gcol]
    # Grid points lying exactly on a control point get a large finite weight.
    w[w == float("inf")] = 2 ** 31 - 1

    pstar = torch.sum(w * reshaped_p.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0)  # [2, grow, gcol]
    phat = reshaped_p - pstar                                                           # [ctrls, 2, grow, gcol]
    qstar = torch.sum(w * reshaped_q.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0)  # [2, grow, gcol]
    qhat = reshaped_q - qstar                                                           # [ctrls, 2, grow, gcol]

    reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol)                              # [ctrls, 1, 2, grow, gcol]
    reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol)                               # [ctrls, 1, 2, grow, gcol]
    reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol)                                     # [ctrls, 1, 1, grow, gcol]

    # List indexing copies, so negating in place leaves phat untouched.
    neg_phat_verti = phat[:, [1, 0], ...]                                               # [ctrls, 2, grow, gcol]
    neg_phat_verti[:, 1, ...] = -neg_phat_verti[:, 1, ...]
    reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol)           # [ctrls, 1, 2, grow, gcol]
    mul_right = torch.cat((reshaped_phat1, reshaped_neg_phat_verti), dim=1)             # [ctrls, 2, 2, grow, gcol]
    mul_left = reshaped_qhat * reshaped_w                                               # [ctrls, 1, 2, grow, gcol]
    Delta = torch.sum(torch.matmul(mul_left.permute(0, 3, 4, 1, 2),
                                   mul_right.permute(0, 3, 4, 1, 2)),
                      dim=0).permute(0, 1, 3, 2)                                        # [grow, gcol, 2, 1]
    Delta_verti = Delta[..., [1, 0], :]                                                 # [grow, gcol, 2, 1]
    Delta_verti[..., 0, :] = -Delta_verti[..., 0, :]
    B = torch.cat((Delta, Delta_verti), dim=3)                                          # [grow, gcol, 2, 2]

    try:
        inv_B = torch.inverse(B)                                                        # [grow, gcol, 2, 2]
        flag = False
    except RuntimeError:
        # torch reports singular inputs with RuntimeError (or a subclass).
        # Invert the 2x2 matrices by hand and zero out the singular ones.
        flag = True
        det = B[..., 0, 0] * B[..., 1, 1] - B[..., 0, 1] * B[..., 1, 0]                 # [grow, gcol]
        det[det < 1e-8] = float("inf")
        reshaped_det = det.reshape(grow, gcol, 1, 1)                                    # [grow, gcol, 1, 1]
        adjoint = B[:, :, [[1, 0], [1, 0]], [[1, 1], [0, 0]]]                           # [grow, gcol, 2, 2]
        adjoint[:, :, [0, 1], [1, 0]] = -adjoint[:, :, [0, 1], [1, 0]]                  # [grow, gcol, 2, 2]
        # Keep the [grow, gcol, 2, 2] layout expected by the matmul below
        # (the same layout torch.inverse returns).
        inv_B = adjoint / reshaped_det                                                  # [grow, gcol, 2, 2]

    vqstar = reshaped_v - qstar                                                         # [2, grow, gcol]
    reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol)                                  # [1, 2, grow, gcol]

    # Get final image transformer -- 3-D array
    temp = torch.matmul(reshaped_vqstar.permute(2, 3, 0, 1),
                        inv_B).reshape(grow, gcol, 2).permute(2, 0, 1)                  # [2, grow, gcol]
    norm_temp = torch.norm(temp, dim=0, keepdim=True)                                   # [1, grow, gcol]
    norm_vqstar = torch.norm(vqstar, dim=0, keepdim=True)                               # [1, grow, gcol]
    # The epsilon avoids 0 / 0 when temp has zero length.
    transformers = temp / (norm_temp + 1e-10) * norm_vqstar + pstar                     # [2, grow, gcol]

    # Correct the points where B is singular
    if flag:
        blidx = det == float("inf")  # bool index
        # NOTE(review): the fallback adds (q* - p*); confirm the sign is the
        # intended one for an inverse mapping.
        transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
        transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]

    # Points outside the border are redirected to row / column 0
    transformers[transformers < 0] = 0
    transformers[0][transformers[0] > height - 1] = 0
    transformers[1][transformers[1] > width - 1] = 0

    # Mapping original image (coordinates are truncated to integers)
    transformed_image = image[tuple(transformers.numpy().astype(np.int16))]             # [grow, gcol]

    return transformers.numpy(), transformed_image
| # mls for whole feature, instead of roi align | |
def roi_mls_whole(feature, d_point, g_point, step=1):
    '''
    Warp a whole guidance feature map with an MLS affine deformation.

    :param feature: input guidance feature [C, H, W]
    :param d_point: landmark for degraded feature [N, 2]
    :param g_point: landmark for guidance feature [N, 2]
    :param step: step of landmark choose, number of control points: landmark_number/step
    :return: tuple (transformed feature [C, H, W],
             sampling grid [1, H, W, 2] in (x, y) order, on feature's device)
    '''
    channel = feature.size(0)
    height = feature.size(1)
    width = feature.size(2)

    # Keep every `step`-th landmark as a control point
    g_land = g_point[0::step, :]
    d_land = d_point[0::step, :]

    # mls (affine preferred); the sampled image it also returns is not needed
    featureTmp = feature.permute(1, 2, 0)  # CHW -> HWC
    grid, _ = mls_affine_deformation_inv(featureTmp.cpu(), height, width, channel,
                                         g_land.cpu().numpy(), d_land.cpu().numpy(),
                                         density=1.)  # [2, H, W] as (row, col)

    # Map pixel indices 0 .. N-1 onto [-1, 1] per axis. For a 256-pixel side
    # this is the original (grid - 127.5) / 127.5.
    half_h = (height - 1) / 2.0
    half_w = (width - 1) / 2.0
    norm_rows = (grid[0] - half_h) / half_h
    norm_cols = (grid[1] - half_w) / half_w

    # grid_sample expects (x, y) = (col, row) in the last dimension
    gridNew = torch.from_numpy(np.stack((norm_cols, norm_rows), axis=-1)).float().unsqueeze(0)
    # Follow the feature's device instead of calling .cuda() unconditionally,
    # which fails on CPU-only hosts.
    gridNew = gridNew.to(feature.device)

    featureNew = feature.unsqueeze(0)
    # NOTE(review): align_corners is left at the library default; the
    # normalisation above corresponds to align_corners=True -- confirm.
    warp_feature = F.grid_sample(featureNew, gridNew)

    # Drop only the batch axis so a single-channel feature stays [1, H, W]
    return warp_feature.squeeze(0), gridNew
def roi_mls_whole_final(feature, d_point, g_point, step=1):
    '''
    Build the MLS affine sampling grid for a whole guidance feature map.

    :param feature: input guidance feature [C, H, W] (only its size is used)
    :param d_point: landmark for degraded feature [N, 2]
    :param g_point: landmark for guidance feature [N, 2]
    :param step: step of landmark choose, number of control points: landmark_number/step
    :return: sampling grid [1, H, W, 2] in (x, y) order, moved to CUDA when available
    '''
    channel = feature.size(0)
    height = feature.size(1)
    width = feature.size(2)

    # Keep every `step`-th landmark as a control point
    g_land = g_point[0::step, :]
    d_land = d_point[0::step, :]

    # mls (affine preferred)
    grid = mls_affine_deformation_inv_final(height, width, channel,
                                            g_land.cpu().numpy(), d_land.cpu().numpy(),
                                            density=1.)  # [2, H, W] as (row, col)

    # Normalise each axis with its own size; for square features this is the
    # original (grid - height/2) / (height/2).
    # NOTE(review): roi_mls_whole uses 127.5 for a 256-pixel side, i.e.
    # (N - 1) / 2 rather than N / 2 -- confirm which convention is intended.
    half_h = height / 2
    half_w = width / 2
    norm_rows = (grid[0] - half_h) / half_h
    norm_cols = (grid[1] - half_w) / half_w

    # Consumers such as grid_sample expect (x, y) = (col, row) in the last dimension
    gridNew = torch.from_numpy(np.stack((norm_cols, norm_rows), axis=-1)).float().unsqueeze(0)
    if torch.cuda.is_available():
        gridNew = gridNew.cuda()
    return gridNew