Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from multiprocessing.pool import ThreadPool | |
| import faiss | |
| from typing import List, Tuple | |
| from . import rpc | |
| ############################################################ | |
| # Server implementation | |
| ############################################################ | |
class SearchServer(rpc.Server):
    """RPC server wrapper around a faiss index.

    Calls that are not defined on this class are forwarded to the wrapped
    index, so index methods can be invoked through the RPC layer.
    """

    def __init__(self, s: int, index: faiss.Index):
        """Wrap `index` behind the connection handle `s`.

        `s` is handed unchanged to rpc.Server.__init__ (presumably the
        accepted connection -- confirm against rpc.Server).
        """
        rpc.Server.__init__(self, s)
        self.index = index
        # nprobe lives on the IVF part of the index, so keep a handle to it.
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """Set the nprobe field of the underlying IVF index."""
        self.index_ivf.nprobe = nprobe

    def get_ntotal(self) -> int:
        """Return the ntotal field of the wrapped index."""
        return self.index.ntotal

    def __getattr__(self, f):
        """Forward every attribute not found on the server to the index."""
        # __getattr__ only runs when normal lookup fails. If "index" itself
        # is missing (before __init__ assigned it, or on a copied/unpickled
        # instance), reading self.index below would re-enter this method
        # without end, so fail with a regular AttributeError instead.
        if f == "index":
            raise AttributeError(f)
        return getattr(self.index, f)
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """Serve requests for `index` on `port` through rpc.run_server.

    A new SearchServer bound to the same `index` is built for every handle
    that rpc.run_server passes to the factory. `v6` is forwarded to
    rpc.run_server as a keyword argument. The call is meant to serve
    forever (the loop itself lives in rpc.run_server, not shown here).
    """

    def make_server(s):
        # All servers share the single index object given by the caller.
        return SearchServer(s, index)

    rpc.run_server(make_server, port, v6=v6)
| ############################################################ | |
| # Client implementation | |
| ############################################################ | |
class ClientIndex:
    """Manages a set of remote sub-indexes.

    Each sub-index is reached through its own rpc.Client and searches a
    subset of the inverted lists. Calls are fanned out to all sub-indexes
    over a thread pool; search results are merged afterwards.
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """Connect to a series of (host, port) pairs.

        Raises:
            ValueError: if `machine_ports` yields no (host, port) pair.
        """
        self.sub_indexes = [
            rpc.Client(machine, port, v6) for machine, port in machine_ports
        ]
        self.ni = len(self.sub_indexes)
        if self.ni == 0:
            # ThreadPool(0) would raise ValueError with a message about
            # "processes"; report the actual problem instead.
            raise ValueError(
                "machine_ports must contain at least one (host, port) pair"
            )
        # Pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        # Querying ntotal doubles as a test of every connection.
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        """Call set_nprobe(nprobe) on every sub-index."""
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        """Call set_omp_num_threads(nt) on every sub-index."""
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        """Return the sum of get_ntotal() over all sub-indexes."""
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int):
        """Search every sub-index for `x` and merge the results.

        `x.shape[0]` is used as the number of queries and `k` as the number
        of results kept per query. Returns the (D, I) pair held by the
        merged faiss.ResultHeap.
        """
        rh = faiss.ResultHeap(x.shape[0], k)
        # imap hands back each sub-index result as soon as it is next in
        # order, so merging overlaps with the remaining searches.
        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I