import torch
import argparse
import os
import numpy as np
from lightning_fabric import seed_everything
from tqdm import tqdm
import random
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from plyfile import PlyData
from pathlib import Path
import multiprocessing
from chamfer_distance import ChamferDistance
from eval.eval_pc_set import *

N_POINTS = 2000

def find_files(folder, extension):
    return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])

def read_ply(path):
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
        x = np.array(plydata['vertex']['x'])
        y = np.array(plydata['vertex']['y'])
        z = np.array(plydata['vertex']['z'])
        vertex = np.stack([x, y, z], axis=1)
    return vertex

def distChamfer(a, b):
    x, y = a, b
    bs, num_points, points_dim = x.size()
    xx = torch.bmm(x, x.transpose(2, 1))
    yy = torch.bmm(y, y.transpose(2, 1))
    zz = torch.bmm(x, y.transpose(2, 1))
    diag_ind = torch.arange(0, num_points).to(a).long()
    rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
    ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
    P = (rx.transpose(2, 1) + ry - 2 * zz)
    return P.min(1)[0], P.min(2)[0]
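
# distChamfer expands the pairwise squared distances via the identity
#   ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>,
# so P[batch, i, j] is the squared distance between point i of a and point j of b.
# It appears unused by the pipeline below (which relies on the compiled
# ChamferDistance extension); a minimal usage sketch (illustrative only, not
# executed, shapes are placeholders) would be:
#   p = torch.rand(4, N_POINTS, 3)
#   q = torch.rand(4, N_POINTS, 3)
#   d_q2p, d_p2q = distChamfer(p, q)             # each of shape (4, N_POINTS)
#   cd = d_p2q.mean(dim=1) + d_q2p.mean(dim=1)   # symmetric (squared) Chamfer distance
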
def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    all_cd = []
    all_emd = []
    iterator = range(N_sample)
    matched_gt = []
    pbar = tqdm(iterator)
    chamfer_dist = ChamferDistance()
    for sample_b_start in pbar:
        sample_batch = sample_pcs[sample_b_start]
        cd_lst = []
        emd_lst = []
        for ref_b_start in range(0, N_ref, batch_size):
            ref_b_end = min(N_ref, ref_b_start + batch_size)
            ref_batch = ref_pcs[ref_b_start:ref_b_end]
            batch_size_ref = ref_batch.size(0)
            sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
            sample_batch_exp = sample_batch_exp.contiguous()
            dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp, ref_batch)
            cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
        cd_lst = torch.cat(cd_lst, dim=1)
        all_cd.append(cd_lst)
        hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
        matched_gt.append(hit)
        pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
    all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
    return all_cd
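
# _pairwise_CD returns an (N_sample, N_ref) matrix whose (i, j) entry is the Chamfer
# distance between generated cloud i and reference cloud j; the "cov" shown in the
# progress bar is the running fraction of reference clouds that are currently the
# best match of at least one processed sample.
# Illustrative sketch (not executed; small random tensors are placeholders):
#   samples = torch.rand(8, N_POINTS, 3).cuda()
#   refs = torch.rand(5, N_POINTS, 3).cuda()
#   D = _pairwise_CD(samples, refs, batch_size=4)   # D.shape == (8, 5)
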
def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
    all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
    N_sample, N_ref = all_dist.size(0), all_dist.size(1)
    min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
    min_val, _ = torch.min(all_dist, dim=0)
    mmd = min_val.mean()
    cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
    cov = torch.tensor(cov).to(all_dist)
    return {
        'MMD-CD': mmd.item(),
        'COV-CD': cov.item(),
    }, min_idx.cpu().numpy()
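
# For the (N_sample, N_ref) distance matrix D produced by _pairwise_CD:
#   MMD-CD = mean_j min_i D[i, j]
#            (how closely each reference shape is matched by its best generated sample),
#   COV-CD = |{ argmin_j D[i, j] : i }| / N_ref
#            (fraction of reference shapes that are the nearest match of at least one sample).
# These are the coverage/MMD metrics of "Learning Representations and Generative Models
# for 3D Point Clouds" (Achlioptas et al.), computed here with the Chamfer distance.
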
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
    '''Computes the JSD between two sets of point-clouds, as introduced in the paper
    "Learning Representations And Generative Models For 3D Point Clouds".
    Args:
        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
        in_unit_sphere: (bool) whether the point-clouds are expected to lie in the unit sphere.
        resolution: (int) grid-resolution. Affects granularity of measurements.
    '''
    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)

def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    '''Given a collection of point-clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.
    Inputs:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
        grid_resolution (int) size of occupancy grid that will be used.
    '''
    epsilon = 10e-4
    bound = 1 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        print(abs(np.max(pclouds)), abs(np.min(pclouds)))
        warnings.warn('Point-clouds are not in unit cube.')
    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
        warnings.warn('Point-clouds are not in unit sphere.')
    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
    grid_counters = np.zeros(len(grid_coordinates))
    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
    for pc in pclouds:
        _, indices = nn.kneighbors(pc)
        indices = np.squeeze(indices)
        for i in indices:
            grid_counters[i] += 1
        indices = np.unique(indices)
        for i in indices:
            grid_bernoulli_rvars[i] += 1
    acc_entropy = 0.0
    n = float(len(pclouds))
    for g in grid_bernoulli_rvars:
        p = 0.0
        if g > 0:
            p = float(g) / n
        acc_entropy += entropy([p, 1.0 - p])
    return acc_entropy / len(grid_counters), grid_counters
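
# Each grid cell defines a Bernoulli variable "this cell is occupied by a given cloud";
# its probability is estimated as the fraction of clouds that activate the cell, and the
# first return value is the mean Bernoulli entropy over all cells.  The JSD computation
# above only consumes the second return value: the raw per-point occupancy counts
# accumulated over all clouds.
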
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    '''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
    that is placed in the unit-cube.
    If clip_sphere is True, it drops the "corner" cells that lie outside the unit-sphere.
    '''
    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
    spacing = 1.0 / float(resolution - 1) * 2
    for i in range(resolution):
        for j in range(resolution):
            for k in range(resolution):
                grid[i, j, k, 0] = i * spacing - 0.5 * 2
                grid[i, j, k, 1] = j * spacing - 0.5 * 2
                grid[i, j, k, 2] = k * spacing - 0.5 * 2
    if clip_sphere:
        grid = grid.reshape(-1, 3)
        grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
    return grid, spacing
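
# Note: with spacing = 2 / (resolution - 1) and the -1 offset, the returned grid spans
# [-1, 1]^3, which matches normalize_pc's output range and the bound check in
# entropy_of_occupancy_grid.  The sphere clip still uses radius 0.5 (presumably the
# convention of the original [-0.5, 0.5]^3 grid in the reference implementation); that
# branch is not exercised here because main() evaluates the JSD with in_unit_sphere=False.
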
def jensen_shannon_divergence(P, Q):
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')
    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)
    e1 = entropy(P_, base=2)
    e2 = entropy(Q_, base=2)
    e_sum = entropy((P_ + Q_) / 2.0, base=2)
    res = e_sum - ((e1 + e2) / 2.0)
    res2 = _jsdiv(P_, Q_)
    if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn('Numerical values of two JSD methods don\'t agree.')
    return res

def _jsdiv(P, Q):
    '''another way of computing JSD'''
    def _kldiv(A, B):
        a = A.copy()
        b = B.copy()
        idx = np.logical_and(a > 0, b > 0)
        a = a[idx]
        b = b[idx]
        return np.sum(a * np.log2(a / b))

    P_ = P / np.sum(P)
    Q_ = Q / np.sum(Q)
    M = 0.5 * (P_ + Q_)
    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
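
# Both implementations above compute
#   JSD(P, Q) = H((P + Q) / 2) - (H(P) + H(Q)) / 2
#             = 0.5 * KL(P || M) + 0.5 * KL(Q || M),  with M = (P + Q) / 2,
# using base-2 logarithms, so the result lies in [0, 1].
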
def downsample_pc(points, n):
    sample_idx = random.sample(list(range(points.shape[0])), n)
    return points[sample_idx]

def normalize_pc(points):
    # center at the origin
    mean = np.mean(points, axis=0)
    points = points - mean
    # fit into the [-1, 1] cube
    scale = np.max(np.abs(points))
    points = points / scale
    return points
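
# Illustrative sketch (not executed): normalize_pc maps any cloud into [-1, 1]^3,
# matching the bound check in entropy_of_occupancy_grid (bound = 1 + epsilon).
#   pts = np.random.rand(100, 3) * 50 + 10
#   pts = normalize_pc(pts)
#   assert np.abs(pts).max() <= 1.0
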
def align_pc(points):
    # 1. Center the point cloud
    centroid = np.mean(points, axis=0)
    centered_points = points - centroid
    # 2. Compute the three edge lengths of the axis-aligned bounding box
    min_coords = np.min(centered_points, axis=0)
    max_coords = np.max(centered_points, axis=0)
    dimensions = max_coords - min_coords
    # 3. Sort axes by edge length to get the axis order
    axis_order = np.argsort(dimensions)[::-1]  # from longest to shortest
    # 4. Build a permutation matrix (longest edge -> x, shortest -> y, medium -> z)
    perm_matrix = np.zeros((3, 3))
    perm_matrix[0, axis_order[0]] = 1  # longest edge -> x
    perm_matrix[1, axis_order[2]] = 1  # shortest edge -> y
    perm_matrix[2, axis_order[1]] = 1  # medium edge -> z
    # 5. Apply the permutation
    aligned_points = np.dot(centered_points, perm_matrix.T)
    # 6. Enforce a consistent orientation: flip z if the mean z-coordinate is negative
    if np.mean(aligned_points[:, 2]) < 0:
        aligned_points[:, 2] *= -1
    return aligned_points
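
# align_pc canonicalises the pose before real and generated clouds are compared, so the
# metrics are not dominated by differing axis conventions between datasets.
# Illustrative sketch (not executed; the box dimensions are placeholders):
#   box = np.random.rand(500, 3) * np.array([4.0, 1.0, 2.0])
#   aligned = align_pc(box)
#   dims = aligned.max(axis=0) - aligned.min(axis=0)
#   # dims is roughly (4, 1, 2): longest edge on x, shortest on y, medium on z
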
def collect_pc(cad_folder):
    pc_path = find_files(os.path.join(cad_folder, 'pcd'), 'final_pcd.ply')
    if len(pc_path) == 0:
        return []
    pc_path = pc_path[-1]  # final pcd
    pc = read_ply(pc_path)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    return pc

def collect_pc2(cad_folder):
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    pc = align_pc(pc)
    return pc
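
# collect_pc2 is the loader used by main(): read a .ply file, randomly downsample to
# at most N_POINTS points, normalise into [-1, 1]^3, and canonicalise the pose with align_pc.
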
theta_x = np.radians(90)   # Rotation angle around X-axis
theta_y = np.radians(90)   # Rotation angle around Y-axis
theta_z = np.radians(180)  # Rotation angle around Z-axis
# Create individual rotation matrices
Rx = np.array([[1, 0, 0],
               [0, np.cos(theta_x), -np.sin(theta_x)],
               [0, np.sin(theta_x), np.cos(theta_x)]])
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
               [0, 1, 0],
               [-np.sin(theta_y), 0, np.cos(theta_y)]])
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
               [np.sin(theta_z), np.cos(theta_z), 0],
               [0, 0, 1]])
rotation_matrix = np.dot(np.dot(Rz, Ry), Rx)
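
# rotation_matrix = Rz @ Ry @ Rx composes the three axis rotations (X applied first,
# then Y, then Z).  collect_pc3 below applies it to row-vector points as
# points @ rotation_matrix.T, i.e. p' = rotation_matrix @ p for each point p.
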
def collect_pc3(cad_folder):
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    # transpose the rotation matrix so it applies correctly to row-vector points
    rotated_point_cloud = np.dot(pc, rotation_matrix.T).astype(np.float32)
    return rotated_point_cloud

def load_data_with_prefix(root_folder, prefix):
    data_files = []
    # Walk through the directory tree starting from the root folder
    for root, dirs, files in os.walk(root_folder):
        for filename in files:
            # Keep files whose name ends with the given extension
            # (despite the argument name, 'prefix' is matched as a suffix)
            if filename.endswith(prefix):
                file_path = os.path.join(root, filename)
                data_files.append(file_path)
    data_files.sort()
    return data_files

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fake", type=str)
    parser.add_argument("--real", type=str)
    parser.add_argument("--n_test", type=int, default=1000)
    parser.add_argument("--multi", type=float, default=3)
    parser.add_argument("--times", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()
    seed_everything(0)
    print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times))
    args.output = args.fake + '_results.txt'
    seed_everything(0)

    # Load reference pcd
    num_cpus = multiprocessing.cpu_count()
    ref_pcs = []
    gt_shape_paths = load_data_with_prefix(args.real, '.ply')
    load_iter = multiprocessing.Pool(num_cpus).imap(collect_pc2, gt_shape_paths)
    for pc in tqdm(load_iter, total=len(gt_shape_paths)):
        if len(pc) > 0:
            ref_pcs.append(pc)
    ref_pcs = np.stack(ref_pcs, axis=0)
    print("real point clouds: {}".format(ref_pcs.shape))

    # Load fake pcd
    sample_pcs = []
    shape_paths = load_data_with_prefix(args.fake, '.ply')
    load_iter = multiprocessing.Pool(num_cpus).imap(collect_pc2, shape_paths)
    for pc in tqdm(load_iter, total=len(shape_paths)):
        if len(pc) > 0:
            sample_pcs.append(pc)
    sample_pcs = np.stack(sample_pcs, axis=0)
    print("fake point clouds: {}".format(sample_pcs.shape))

    # Testing
    cov_on_gt = []
    fp = open(args.output, "w")
    result_list = []
    for i in range(args.times):
        print("iteration {}...".format(i))
        select_idx1 = random.sample(list(range(len(sample_pcs))), int(args.multi * args.n_test))
        rand_sample_pcs = sample_pcs[select_idx1]
        select_idx2 = random.sample(list(range(len(ref_pcs))), args.n_test)
        rand_ref_pcs = ref_pcs[select_idx2]
        jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
        with torch.no_grad():
            rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda().float()
            rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda().float()
            result, idx = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
        result.update({"JSD": jsd})
        cov_on_gt.extend(list(np.array(select_idx2)[np.unique(idx)]))
        if False:  # debugging snippet, disabled
            unique_idx = np.unique(idx, return_counts=True)
            id_gts = unique_idx[0][np.argsort(unique_idx[1])[::-1][:100]]
            gt_prefixes = [os.path.basename(gt_shape_paths[i])[:8] for i in select_idx2]
            pred_prefixes = [os.path.basename(shape_paths[i])[:8] for i in select_idx1]
        print(result)
        print(result, file=fp)
        result_list.append(result)

    avg_result = {}
    for k in result_list[0].keys():
        avg_result.update({"avg-" + k: np.mean([x[k] for x in result_list])})
    print("average result:")
    print(avg_result)
    print(avg_result, file=fp)
    fp.close()

    cov_on_gt = list(set(cov_on_gt))
    cov_on_gt = [gt_shape_paths[i] for i in cov_on_gt]
    np.save(args.fake + '_cov_on_gt.npy', cov_on_gt)


if __name__ == '__main__':
    main()