# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist


class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):
        # verbosity
        self.verbose = verbose
        self.tictoc = []
        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = faiss.is_similarity_metric(index.metric_type)
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        self.t_accu = [0] * 6
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        if self.verbose == 1 or (
            self.verbose == 2 and (
                l > 1000 and time.time() < self.t_display + 1.0
            )
        ):
            return
        t = time.time() - self.t0
        print(
            f"[{t:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait in {self.t_accu[4]:.3f} "
            f"wait out {self.t_accu[5]:.3f} "
            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
            f"mem {faiss.get_mem_usage_kb()}",
            end="\r" if self.verbose <= 2 else "\n",
            flush=True,
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1  # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign  # inplace so let's forget about the old version...
        self.toc()
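
    # After reorder_assign, the queries assigned to inverted list l are
    # query_ids[bucket_lims[l]:bucket_lims[l + 1]]; prepare_bucket below
    # relies on this layout.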

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)
        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)
        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')
        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure"""
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0

    def sizes_in_checkpoint(self):
        return (self.xq.shape, self.index.nprobe, self.index.nlist)

    def write_checkpoint(self, fname, completed):
        # write to temp file then move to final file
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump(
                {
                    "sizes": self.sizes_in_checkpoint(),
                    "completed": completed,
                    "rh": (self.rh.D, self.rh.I),
                }, f, -1)
        os.replace(tmpname, fname)

    def read_checkpoint(self, fname):
        with open(fname, "rb") as f:
            ckp = pickle.load(f)
        assert ckp["sizes"] == self.sizes_in_checkpoint()
        self.rh.D[:] = ckp["rh"][0]
        self.rh.I[:] = ckp["rh"][1]
        return ckp["completed"]


class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):
        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
        return D, I
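

# A minimal sketch of a custom knn callable that can be passed to
# big_batch_search(..., method="knn_function", knn=numpy_knn). The name
# numpy_knn and its brute-force numpy implementation are illustrative and
# not part of the original module; any callable with the same signature
# returning (D, I) arrays works. When computation_threads > 1, block_search
# also passes a thread_id keyword argument, absorbed here by **kwargs.
def numpy_knn(xq, xb, k, metric=faiss.METRIC_L2, **kwargs):
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')
    k = min(k, len(xb))  # a bucket may contain fewer than k vectors
    if metric == faiss.METRIC_L2:
        # squared L2 distances, smaller is better
        dis = (
            (xq ** 2).sum(1)[:, None]
            + (xb ** 2).sum(1)[None, :]
            - 2 * xq @ xb.T
        )
        I = np.argsort(dis, axis=1)[:, :k]
    else:
        # treat other metrics as inner product, larger is better
        dis = xq @ xb.T
        I = np.argsort(-dis, axis=1)[:, :k]
    D = np.take_along_axis(dis, I, axis=1)
    return D, I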


def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=1,
        computation_threads=1,
        q_assign=None,
        checkpoint=None,
        checkpoint_freq=7200,
        start_list=0,
        end_list=None,
        crash_at=-1
        ):
| """ | |
| Search queries xq in the IVF index, with a search function that collects | |
| batches of query vectors per inverted list. This can be faster than the | |
| regular search indexes. | |
| Supports IVFFlat, IVFPQ and IVFScalarQuantizer. | |
| Supports three computation methods: | |
| method = "index": | |
| build a flat index and populate it separately for each index | |
| method = "pairwise_distances": | |
| decompress codes and compute all pairwise distances for the queries | |
| and index and add result to heap | |
| method = "knn_function": | |
| decompress codes and compute knn results for the queries | |
| threaded=0: sequential execution | |
| threaded=1: prefetch next bucket while computing the current one | |
| threaded=2: prefetch prefetch_threads buckets at a time. | |
| compute_threads>1: the knn function will get an additional thread_no that | |
| tells which worker should handle this. | |
| In threaded mode, the computation is tiled with the bucket perparation and | |
| the writeback of results (useful to maximize GPU utilization). | |
| use_float16: convert all matrices to float16 (faster for GPU gemm) | |
| q_assign: override coarse assignment, should be a matrix of size nq * nprobe | |
| checkpointing (only for threaded > 1): | |
| checkpoint: file where the checkpoints are stored | |
| checkpoint_freq: when to perform checkpoinging. Should be a multiple of threaded | |
| start_list, end_list: process only a subset of invlists | |
| """ | |
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        logging.info(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if end_list is None:
        end_list = index.nlist

    completed = set()
    if checkpoint is not None:
        assert (start_list, end_list) == (0, index.nlist)
        if os.path.exists(checkpoint):
            logging.info(f"recovering checkpoint: {checkpoint}")
            completed = bbs.read_checkpoint(checkpoint)
            logging.info(f" already completed: {len(completed)}")
        else:
            logging.info("no checkpoint: starting from scratch")

    if threaded == 0:
        # simple sequential version
        for l in range(start_list, end_list):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
    elif threaded == 1:
        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(start_list)
        to_add = None
        pool = ThreadPool(1)

        for l in range(start_list, end_list):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)

        bbs.add_results_to_heap(*to_add)
        pool.close()
    else:

        def task_manager_thread(
            task,
            pool_size,
            start_task,
            end_task,
            completed,
            output_queue,
            input_queue,
        ):
            try:
                with ThreadPool(pool_size) as pool:
                    res = [pool.apply_async(
                        task,
                        args=(i, output_queue, input_queue))
                        for i in range(start_task, end_task)
                        if i not in completed]
                    for r in res:
                        r.get()
                    pool.close()
                    pool.join()
                output_queue.put(None)
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def task_manager(*args):
            task_manager = threading.Thread(
                target=task_manager_thread,
                args=args,
            )
            task_manager.daemon = True
            task_manager.start()
            return task_manager

        def prepare_task(task_id, output_queue, input_queue=None):
            try:
                logging.info(f"Prepare start: {task_id}")
                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
                logging.info(f"Prepare end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def compute_task(task_id, output_queue, input_queue):
            try:
                logging.info(f"Compute start: {task_id}")
                t_wait_out = 0
                while True:
                    t0 = time.time()
                    logging.info(f'Compute input: task {task_id}')
                    input_value = input_queue.get()
                    t_wait_in = time.time() - t0
                    if input_value is None:
                        # signal for other compute tasks
                        input_queue.put(None)
                        break
                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
                    logging.info(
                        f'Compute work: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    if computation_threads > 1:
                        D, I = comp.block_search(
                            xq_l, xb_l, list_ids, k, thread_id=task_id
                        )
                    else:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    t_compute = time.time() - t0
                    logging.info(
                        f'Compute output: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    output_queue.put((
                        centroid, t_wait_in, t_wait_out, t_compute,
                        q_subset, D, list_ids, I,
                    ))
                    t_wait_out = time.time() - t0
                logging.info(f"Compute end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise
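
        # How the pieces connect: prefetch_threads "prepare" workers fill
        # prepare_to_compute_queue with prepared buckets, computation_threads
        # "compute" workers consume them and push (centroid, timings, results)
        # tuples to compute_to_main_queue, and the main loop below pops the
        # results, adds them to the heap and periodically writes checkpoints.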
        prepare_to_compute_queue = Queue(2)
        compute_to_main_queue = Queue(2)

        compute_task_manager = task_manager(
            compute_task,
            computation_threads,
            0,
            computation_threads,
            set(),
            compute_to_main_queue,
            prepare_to_compute_queue,
        )
        prepare_task_manager = task_manager(
            prepare_task,
            prefetch_threads,
            start_list,
            end_list,
            completed,
            prepare_to_compute_queue,
            None,
        )

        t_checkpoint = time.time()
        while True:
            logging.info("Waiting for result")
            value = compute_to_main_queue.get()
            if not value:
                break
            (centroid, t_wait_in, t_wait_out, t_compute,
             q_subset, D, list_ids, I) = value
            # to test checkpointing
            if centroid == crash_at:
                1 / 0
            bbs.t_accu[2] += t_compute
            bbs.t_accu[4] += t_wait_in
            bbs.t_accu[5] += t_wait_out
            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
            logging.info(f"Adding to heap end: centroid {centroid}")
            completed.add(centroid)
            bbs.report(centroid)
            if checkpoint is not None:
                if time.time() - t_checkpoint > checkpoint_freq:
                    logging.info("writing checkpoint")
                    bbs.write_checkpoint(checkpoint, completed)
                    t_checkpoint = time.time()

        prepare_task_manager.join()
        compute_task_manager.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I
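

if __name__ == "__main__":
    # Usage sketch, not part of the original module: build a small IVFFlat
    # index on random data and run big_batch_search over it. All sizes,
    # parameters and the checkpoint path below are arbitrary examples.
    d, nb, nq, k = 64, 100_000, 10_000, 10
    rng = np.random.default_rng(123)
    xb = rng.standard_normal((nb, d)).astype('float32')
    xq_demo = rng.standard_normal((nq, d)).astype('float32')

    quantizer = faiss.IndexFlatL2(d)
    demo_index = faiss.IndexIVFFlat(quantizer, d, 1024, faiss.METRIC_L2)
    demo_index.train(xb)
    demo_index.add(xb)
    demo_index.nprobe = 16

    # sequential run
    D, I = big_batch_search(
        demo_index, xq_demo, k, method="knn_function", verbose=1)

    # overlapped prefetch / compute, with a checkpoint written every 10 minutes
    D, I = big_batch_search(
        demo_index, xq_demo, k,
        method="knn_function",
        threaded=2, prefetch_threads=4, computation_threads=1,
        checkpoint="/tmp/big_batch_checkpoint.pickle", checkpoint_freq=600,
    )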