# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import unittest
import time
import faiss

from multiprocessing.pool import ThreadPool

###############################################################
# Simple functions to evaluate knn results

def knn_intersection_measure(I1, I2):
    """ computes the intersection measure of two result tables
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    ninter = sum(
        np.intersect1d(I1[i], I2[i]).size
        for i in range(nq)
    )
    return ninter / I1.size
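
# Usage sketch (not part of the original module): compare the top-10 results
# of an exact index with those of an HNSW index on synthetic data. The
# dataset sizes and the HNSW parameter below are arbitrary assumptions.
def _example_knn_intersection_measure():
    d = 32
    xb = np.random.rand(10000, d).astype('float32')
    xq = np.random.rand(100, d).astype('float32')
    index_ref = faiss.IndexFlatL2(d)
    index_ref.add(xb)
    _, Iref = index_ref.search(xq, 10)
    index_new = faiss.IndexHNSWFlat(d, 32)
    index_new.add(xb)
    _, Inew = index_new.search(xq, 10)
    # 1.0 means the two indexes return exactly the same id sets per query
    return knn_intersection_measure(Iref, Inew)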

###############################################################
# Range search results can be compared with Precision-Recall

def filter_range_results(lims, D, I, thresh):
    """ select a set of results """
    nq = lims.size - 1
    mask = D < thresh
    new_lims = np.zeros_like(lims)
    for i in range(nq):
        new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum()
    return new_lims, D[mask], I[mask]
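
# Usage sketch (not part of the original module): restrict range-search
# results to a smaller radius without re-running the search. The dataset and
# radii below are arbitrary assumptions.
def _example_filter_range_results():
    d = 16
    xb = np.random.rand(2000, d).astype('float32')
    xq = np.random.rand(10, d).astype('float32')
    index = faiss.IndexFlatL2(d)
    index.add(xb)
    lims, D, I = index.range_search(xq, 2.0)   # loose radius
    # keep only the results at squared L2 distance < 1.5
    return filter_range_results(lims, D, I, 1.5)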

def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """compute the precision and recall of range search results. The
    function does not take the distances into account. """

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        return Inew[lims_new[i]:lims_new[i + 1]]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    ninter = np.zeros(nq, dtype="int64")

    def compute_PR_for(q):
        # ground truth results for this query
        gt_ids = ref_result_for(q)

        # results for this query
        new_ids = new_result_for(q)

        # there are no set functions in numpy so let's do this
        inter = np.intersect1d(gt_ids, new_ids)
        ninter[q] = len(inter)

    # run in a thread pool, which helps in spite of the GIL
    pool = ThreadPool(20)
    pool.map(compute_PR_for, range(nq))

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )
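
# Worked example (not part of the original module): two queries given directly
# in the CSR-style (lims, I) layout produced by range_search. Query 0 gets 2
# of its 3 returned ids right, query 1 gets 1 of 1.
def _example_range_PR():
    lims_ref = np.array([0, 3, 5])
    Iref = np.array([1, 2, 3, 7, 8])
    lims_new = np.array([0, 3, 4])
    Inew = np.array([2, 3, 9, 7])
    precision, recall = range_PR(lims_ref, Iref, lims_new, Inew, mode="overall")
    # precision = 3/4, recall = 3/5
    return precision, recall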

def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """ computes a precision-recall for a set of queries.
    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (smaller than nres of course)
    """
    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        if nres > 0:
            precision = ninter / nres
        else:
            precision = 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    elif mode == "average":
        # average precision and recall over queries
        mask = ngt == 0
        ngt[mask] = 1

        recalls = ninter / ngt
        recalls[mask] = (nres[mask] == 0).astype(float)

        # avoid division by 0
        mask = nres == 0
        assert np.all(ninter[mask] == 0)
        ninter[mask] = 1
        nres[mask] = 1

        precisions = ninter / nres

        return precisions.mean(), recalls.mean()

    else:
        raise AssertionError()
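
# Worked example (not part of the original module): 3 queries with 2/0/4
# ground-truth results, 2/1/3 returned results and 1/0/3 correct ones. The
# copies are there because mode="average" modifies its arguments in place.
def _example_counts_to_PR():
    ngt = np.array([2, 0, 4])
    nres = np.array([2, 1, 3])
    ninter = np.array([1, 0, 3])
    # overall: precision = 4/6, recall = 4/6
    p_overall, r_overall = counts_to_PR(ngt, nres, ninter, mode="overall")
    # average: per-query precisions (0.5, 0, 1), recalls (0.5, 0, 0.75)
    p_avg, r_avg = counts_to_PR(
        ngt.copy(), nres.copy(), ninter.copy(), mode="average")
    return (p_overall, r_overall), (p_avg, r_avg)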

def sort_range_res_2(lims, D, I):
    """ sort 2 arrays using the first as key """
    I2 = np.empty_like(I)
    D2 = np.empty_like(D)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        ii = I[l0:l1]
        di = D[l0:l1]
        o = di.argsort()
        I2[l0:l1] = ii[o]
        D2[l0:l1] = di[o]
    return I2, D2

def sort_range_res_1(lims, I):
    I2 = np.empty_like(I)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        I2[l0:l1] = I[l0:l1]
        I2[l0:l1].sort()
    return I2

def range_PR_multiple_thresholds(
        lims_ref, Iref,
        lims_new, Dnew, Inew,
        thresholds,
        mode="overall", do_sort="ref,new"
):
    """ compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        l0, l1 = lims_new[i], lims_new[i + 1]
        return Inew[l0:l1], Dnew[l0:l1]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = ref_result_for(q)
        res_ids, res_dis = new_result_for(q)

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    pool = ThreadPool(20)
    pool.map(compute_PR_for, range(nq))
    # print(counts.transpose(2, 1, 0))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls
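
# Usage sketch (not part of the original module): PR curve for the range
# search of an IVF index against the exact result. The dataset, radii and
# index parameters below are arbitrary assumptions.
def _example_range_PR_multiple_thresholds():
    d = 16
    xb = np.random.rand(5000, d).astype('float32')
    xq = np.random.rand(20, d).astype('float32')
    index_ref = faiss.IndexFlatL2(d)
    index_ref.add(xb)
    lims_ref, _, Iref = index_ref.range_search(xq, 2.5)
    quantizer = faiss.IndexFlatL2(d)
    index_new = faiss.IndexIVFFlat(quantizer, d, 50)
    index_new.train(xb)
    index_new.add(xb)
    index_new.nprobe = 4
    # search with a larger radius, then sweep the threshold over the results
    lims_new, Dnew, Inew = index_new.range_search(xq, 3.0)
    thresholds = np.linspace(1.0, 3.0, 20)
    precisions, recalls = range_PR_multiple_thresholds(
        lims_ref, Iref, lims_new, Dnew, Inew, thresholds, mode="average"
    )
    return thresholds, precisions, recalls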

###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests

def _cluster_tables_with_tolerance(tab1, tab2, thr):
    """ for two tables, cluster them by merging values closer than thr.
    Returns the cluster ids for each table element """
    tab = np.hstack([tab1, tab2])
    tab.sort()
    n = len(tab)
    diffs = np.ones(n)
    diffs[1:] = tab[1:] - tab[:-1]
    unique_vals = tab[diffs > thr]
    idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
    idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
    return idx1, idx2

def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
    """ test that knn search results are identical, with possible ties.
    Raise if not. """
    np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
    # here we have to be careful because of draws
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue
        # otherwise collect elements per distance
        r = rtol * Dref[i].max()
        DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)
        for dis in np.unique(DrefC):
            if dis == DrefC[-1]:
                continue
            mask = DrefC == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))

def check_ref_range_results(Lref, Dref, Iref,
                            Lnew, Dnew, Inew):
    """ compare range search results wrt. a reference result,
    throw if it fails """
    np.testing.assert_array_equal(Lref, Lnew)
    nq = len(Lref) - 1
    for i in range(nq):
        l0, l1 = Lref[i], Lref[i + 1]
        Ii_ref = Iref[l0:l1]
        Ii_new = Inew[l0:l1]
        Di_ref = Dref[l0:l1]
        Di_new = Dnew[l0:l1]
        if np.all(Ii_ref == Ii_new):  # easy
            pass
        else:
            def sort_by_ids(I, D):
                o = I.argsort()
                return I[o], D[o]
            # sort both
            (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
            (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
        np.testing.assert_array_equal(Ii_ref, Ii_new)
        np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)
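
# Usage sketch (not part of the original module): typical use in a unit test,
# checking that a serialized + deserialized index returns the same results as
# the original. The dataset and radius are arbitrary assumptions.
def _example_check_ref_results():
    d = 16
    xb = np.random.rand(1000, d).astype('float32')
    xq = np.random.rand(10, d).astype('float32')
    index = faiss.IndexFlatL2(d)
    index.add(xb)
    Dref, Iref = index.search(xq, 10)
    Lref, DRref, IRref = index.range_search(xq, 2.5)
    index2 = faiss.deserialize_index(faiss.serialize_index(index))
    Dnew, Inew = index2.search(xq, 10)
    Lnew, DRnew, IRnew = index2.range_search(xq, 2.5)
    # both raise (via numpy / unittest assertions) if the results differ
    check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)
    check_ref_range_results(Lref, DRref, IRref, Lnew, DRnew, IRnew)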

###############################################################
# OperatingPoints functions
# this is the Python version of the AutoTune object in C++

class OperatingPoints:
    """
    Manages a set of search parameters with associated performance and time.
    Keeps the Pareto optimal points.
    """

    def __init__(self):
        # list of (key, perf, t)
        self.operating_points = [
            # (self.do_nothing_key(), 0.0, 0.0)
        ]
        self.suboptimal_points = []

    def compare_keys(self, k1, k2):
        """ return 1 if k1 >= k2 (k1 is at least as slow and as accurate as k2),
        -1 if k2 >= k1, 0 if the keys are not comparable """
        raise NotImplementedError

    def do_nothing_key(self):
        """ parameters to say we do nothing, takes 0 time and has 0 performance"""
        raise NotImplementedError

    def is_pareto_optimal(self, perf_new, t_new):
        for _, perf, t in self.operating_points:
            if perf >= perf_new and t <= t_new:
                return False
        return True

    def predict_bounds(self, key):
        """ predicts the bound on time and performance """
        min_time = 0.0
        max_perf = 1.0
        for key2, perf, t in self.operating_points + self.suboptimal_points:
            cmp = self.compare_keys(key, key2)
            if cmp > 0:  # key >= key2, so key is at least as slow as key2
                if t > min_time:
                    min_time = t
            if cmp < 0:  # key <= key2, so key is at most as accurate as key2
                if perf < max_perf:
                    max_perf = perf
        return max_perf, min_time

    def should_run_experiment(self, key):
        (max_perf, min_time) = self.predict_bounds(key)
        return self.is_pareto_optimal(max_perf, min_time)

    def add_operating_point(self, key, perf, t):
        if self.is_pareto_optimal(perf, t):
            i = 0
            # maybe it shadows some other operating point completely?
            while i < len(self.operating_points):
                op_Ls, perf2, t2 = self.operating_points[i]
                if perf >= perf2 and t < t2:
                    self.suboptimal_points.append(
                        self.operating_points.pop(i))
                else:
                    i += 1
            self.operating_points.append((key, perf, t))
            return True
        else:
            self.suboptimal_points.append((key, perf, t))
            return False

class OperatingPointsWithRanges(OperatingPoints):
    """
    Set of parameters that are each picked from a discrete range of values.
    An increase of each parameter is assumed to make the operation slower
    and more accurate.
    A key = int array of indices in the ordered set of parameters.
    """

    def __init__(self):
        OperatingPoints.__init__(self)
        # list of (name, values)
        self.ranges = []

    def add_range(self, name, values):
        self.ranges.append((name, values))

    def compare_keys(self, k1, k2):
        if np.all(k1 >= k2):
            return 1
        if np.all(k2 >= k1):
            return -1
        return 0

    def do_nothing_key(self):
        return np.zeros(len(self.ranges), dtype=int)

    def num_experiments(self):
        return int(np.prod([len(values) for name, values in self.ranges]))

    def sample_experiments(self, n_autotune, rs=np.random):
        """ sample a set of experiments of max size n_autotune
        (run all experiments in random order if n_autotune is 0)
        """
        assert n_autotune == 0 or n_autotune >= 2
        totex = self.num_experiments()
        rs = np.random.RandomState(123)
        if n_autotune == 0 or totex < n_autotune:
            experiments = rs.permutation(totex - 2)
        else:
            experiments = rs.choice(
                totex - 2, size=n_autotune - 2, replace=False)
        experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments]
        return experiments

    def cno_to_key(self, cno):
        """Convert a sequential experiment number to a key"""
        k = np.zeros(len(self.ranges), dtype=int)
        for i, (name, values) in enumerate(self.ranges):
            k[i] = cno % len(values)
            cno //= len(values)
        assert cno == 0
        return k

    def get_parameters(self, k):
        """Convert a key to a dictionary with parameter values"""
        return {
            name: values[k[i]]
            for i, (name, values) in enumerate(self.ranges)
        }

    def restrict_range(self, name, max_val):
        """ remove too large values from a range"""
        for name2, values in self.ranges:
            if name == name2:
                val2 = [v for v in values if v < max_val]
                values[:] = val2
                return
        raise RuntimeError(f"parameter {name} not found")
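
# Usage sketch (not part of the original module): explore a small parameter
# grid. eval_perf_and_time is a hypothetical user-supplied function that runs
# the search for a parameter dict and returns (accuracy, time_in_s).
def _example_operating_points(eval_perf_and_time, n_autotune=20):
    op = OperatingPointsWithRanges()
    op.add_range("nprobe", [1, 2, 4, 8, 16, 32])
    op.add_range("quantizer_efSearch", [4, 8, 16, 32])
    for cno in op.sample_experiments(n_autotune):
        key = op.cno_to_key(cno)
        # skip experiments that cannot improve on the current Pareto front
        if not op.should_run_experiment(key):
            continue
        parameters = op.get_parameters(key)
        perf, t = eval_perf_and_time(parameters)
        op.add_operating_point(key, perf, t)
    # Pareto-optimal (key, perf, time) triplets found so far
    return op.operating_points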

###############################################################
# Timer object

class TimerIter:
    def __init__(self, timer):
        self.ts = []
        self.runs = timer.runs
        self.timer = timer
        if timer.nt >= 0:
            faiss.omp_set_num_threads(timer.nt)

    def __next__(self):
        timer = self.timer
        self.runs -= 1
        self.ts.append(time.time())
        total_time = self.ts[-1] - self.ts[0] if len(self.ts) >= 2 else 0
        if self.runs == -1 or total_time > timer.max_secs:
            if timer.nt >= 0:
                faiss.omp_set_num_threads(timer.remember_nt)
            ts = np.array(self.ts)
            times = ts[1:] - ts[:-1]
            if len(times) == timer.runs:
                timer.times = times[timer.warmup:]
            else:
                # if timeout, we use all the runs
                timer.times = times[:]
            raise StopIteration

class RepeatTimer:
    """
    This is yet another timer object. It is adapted to Faiss by
    taking a number of openmp threads to set on input. It should be called
    in an explicit loop as:

        timer = RepeatTimer(warmup=1, nt=1, runs=6)
        for _ in timer:
            # perform operation
        print(f"time={timer.ms():.1f} ± {timer.ms_std():.1f} ms")

    The same timer can be re-used. In that case it is reset each time it
    enters a loop. It focuses on ms-scale times because for second scale
    it's usually less relevant to repeat the operation.
    """

    def __init__(self, warmup=0, nt=-1, runs=1, max_secs=np.inf):
        assert warmup < runs
        self.warmup = warmup
        self.nt = nt
        self.runs = runs
        self.max_secs = max_secs
        self.remember_nt = faiss.omp_get_max_threads()

    def __iter__(self):
        return TimerIter(self)

    def ms(self):
        return np.mean(self.times) * 1000

    def ms_std(self):
        return np.std(self.times) * 1000 if len(self.times) > 1 else 0.0

    def nruns(self):
        """ effective number of runs (may be lower than runs - warmup due to timeout)"""
        return len(self.times)
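
# Usage sketch (not part of the original module): time a knn search with a
# single openmp thread, 2 warmup runs and a 10 second budget. The index and
# query set are assumed to be provided by the caller.
def _example_repeat_timer(index, xq, k=10):
    timer = RepeatTimer(warmup=2, nt=1, runs=10, max_secs=10)
    for _ in timer:
        index.search(xq, k)
    print(f"search time: {timer.ms():.1f} ms ± {timer.ms_std():.1f} ms "
          f"({timer.nruns()} runs)")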