# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import faiss

from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
from .exhaustive_search import knn


class Dataset:
    """ Generic abstract class for a test dataset """

    def __init__(self):
        """ the constructor should set the following fields: """
        self.d = -1
        self.metric = 'L2'   # or 'IP'
        self.nq = -1
        self.nb = -1
        self.nt = -1

    def get_queries(self):
        """ return the queries as a (nq, d) array """
        raise NotImplementedError()

    def get_train(self, maxtrain=None):
        """ return the training vectors as a (nt, d) array """
        raise NotImplementedError()

    def get_database(self):
        """ return the database vectors as a (nb, d) array """
        raise NotImplementedError()

    def database_iterator(self, bs=128, split=(1, 0)):
        """returns an iterator on database vectors.
        bs is the number of vectors per batch
        split = (nsplit, rank) means the dataset is split in nsplit
        shards and we want shard number rank
        The default implementation just iterates over the full matrix
        returned by get_database.
        """
        xb = self.get_database()
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield xb[j0: min(j0 + bs, i1)]

    def get_groundtruth(self, k=None):
        """ return the ground truth for k-nearest neighbor search """
        raise NotImplementedError()

    def get_groundtruth_range(self, thresh=None):
        """ return the ground truth for range search """
        raise NotImplementedError()

    def __str__(self):
        return (f"dataset in dimension {self.d}, with metric {self.metric}, "
                f"size: Q {self.nq} B {self.nb} T {self.nt}")

    def check_sizes(self):
        """ runs the previous methods and checks the sizes of the matrices """
        assert self.get_queries().shape == (self.nq, self.d)
        if self.nt > 0:
            xt = self.get_train(maxtrain=123)
            assert xt.shape == (123, self.d), "shape=%s" % (xt.shape, )
        assert self.get_database().shape == (self.nb, self.d)
        assert self.get_groundtruth(k=13).shape == (self.nq, 13)
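
# A minimal usage sketch of the Dataset API (illustrative, not part of the
# library). Concrete subclasses implement get_queries / get_database /
# get_groundtruth; database_iterator streams the database in batches,
# optionally restricted to one of nsplit shards:
#
#   ds = SyntheticDataset(d=32, nt=1000, nb=10000, nq=100)
#   print(ds)                  # "dataset in dimension 32, ..."
#   for xb_batch in ds.database_iterator(bs=256, split=(4, 0)):
#       pass                   # process shard 0 of 4, 256 vectors at a time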


class SyntheticDataset(Dataset):
    """A dataset that is not completely random but still challenging to
    index
    """

    def __init__(self, d, nt, nb, nq, metric='L2', seed=1338):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
        d1 = 10     # intrinsic dimension (more or less)
        n = nb + nt + nq
        rs = np.random.RandomState(seed)
        x = rs.normal(size=(n, d1))
        x = np.dot(x, rs.rand(d1, d))
        # now we have a d1-dim ellipsoid in d-dimensional space
        # higher factor (>4) -> higher frequency -> less linear
        x = x * (rs.rand(d) * 4 + 0.1)
        x = np.sin(x)
        x = x.astype('float32')
        self.metric = metric
        self.xt = x[:nt]
        self.xb = x[nt:nt + nb]
        self.xq = x[nt + nb:]

    def get_queries(self):
        return self.xq

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return self.xt[:maxtrain]

    def get_database(self):
        return self.xb

    def get_groundtruth(self, k=100):
        return knn(
            self.xq, self.xb, k,
            faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
        )[1]
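
# Sketch: sanity-checking a SyntheticDataset. Illustrative only; it assumes
# this module is importable as faiss.contrib.datasets (adjust to your setup):
#
#   from faiss.contrib.datasets import SyntheticDataset
#   ds = SyntheticDataset(d=32, nt=2000, nb=5000, nq=200, metric='L2')
#   ds.check_sizes()                 # verifies all matrix shapes
#   gt = ds.get_groundtruth(k=10)    # exact 10-NN ids, shape (200, 10)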


############################################################################
# The following datasets are a few standard open-source datasets.
# They should be stored in a directory, and we start by guessing where
# that directory is
############################################################################

for dataset_basedir in (
        '/datasets01/simsearch/041218/',
        '/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/'):
    if os.path.exists(dataset_basedir):
        break
else:
    # users can link their data directory to `./data`
    dataset_basedir = 'data/'


class DatasetSIFT1M(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'sift1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "sift_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "sift_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt
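
# Sketch: measuring 1-recall@1 on SIFT1M with a flat index. Illustrative
# only; it assumes the sift1M files are present under dataset_basedir.
#
#   ds = DatasetSIFT1M()
#   index = faiss.IndexFlatL2(ds.d)
#   index.add(ds.get_database())
#   _, I = index.search(ds.get_queries(), 1)
#   gt = ds.get_groundtruth(k=1)
#   recall = (I[:, 0] == gt[:, 0]).mean()   # ~1.0 for exact search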


def sanitize(x):
    """ convert to a contiguous float32 array, as required by faiss """
    return np.ascontiguousarray(x, dtype='float32')


class DatasetBigANN(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1B)
    """

    def __init__(self, nb_M=1000):
        Dataset.__init__(self)
        assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
        self.nb_M = nb_M
        nb = nb_M * 10**6
        self.d, self.nt, self.nb, self.nq = 128, 10**8, nb, 10000
        self.basedir = dataset_basedir + 'bigann/'

    def get_queries(self):
        return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb_M < 100, "dataset too large, use iterator"
        return sanitize(bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])
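
# Sketch: filling an index from a database too large to load in RAM, using
# database_iterator. Illustrative only; the index string, training size and
# batch size are arbitrary choices:
#
#   ds = DatasetBigANN(nb_M=100)
#   index = faiss.index_factory(ds.d, "IVF4096,PQ32")
#   index.train(ds.get_train(maxtrain=10**6))
#   for xb_batch in ds.database_iterator(bs=65536):
#       index.add(xb_batch)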


class DatasetDeep1B(Dataset):
    """
    See https://github.com/facebookresearch/faiss/tree/main/benchs#getting-deep1b
    on how to get the data
    """

    def __init__(self, nb=10**9):
        Dataset.__init__(self)
        nb_to_name = {
            10**5: '100k',
            10**6: '1M',
            10**7: '10M',
            10**8: '100M',
            10**9: '1B'
        }
        assert nb in nb_to_name
        self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
        self.basedir = dataset_basedir + 'deep1b/'
        self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
            self.basedir, nb_to_name[self.nb])

    def get_queries(self):
        return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.gt_fname)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb <= 10**8, "dataset too large, use iterator"
        return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = fvecs_mmap(self.basedir + "base.fvecs")
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])


class DatasetGlove(Dataset):
    """
    Data from http://ann-benchmarks.com/glove-100-angular.hdf5
    """

    def __init__(self, loc=None, download=False):
        import h5py
        Dataset.__init__(self)
        assert not download, "not implemented"
        if not loc:
            loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
        self.glove_h5py = h5py.File(loc, 'r')
        # IP and L2 are equivalent on normalized vectors, but this is
        # traditionally seen as an IP dataset
        self.metric = 'IP'
        self.d, self.nt = 100, 0
        self.nb = self.glove_h5py['train'].shape[0]
        self.nq = self.glove_h5py['test'].shape[0]

    def get_queries(self):
        xq = np.array(self.glove_h5py['test'])
        faiss.normalize_L2(xq)
        return xq

    def get_database(self):
        xb = np.array(self.glove_h5py['train'])
        faiss.normalize_L2(xb)
        return xb

    def get_groundtruth(self, k=None):
        gt = self.glove_h5py['neighbors']
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt
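
# Why IP and L2 give the same ranking here: for unit-norm x and q,
# ||x - q||^2 = ||x||^2 + ||q||^2 - 2 <x, q> = 2 - 2 <x, q>,
# so sorting by increasing L2 distance equals sorting by decreasing
# inner product. A quick numpy check (illustrative only):
#
#   x = np.random.rand(10, 100).astype('float32')
#   q = np.random.rand(1, 100).astype('float32')
#   faiss.normalize_L2(x); faiss.normalize_L2(q)
#   l2 = ((x - q) ** 2).sum(1)
#   ip = (x * q).sum(1)
#   assert (np.argsort(l2) == np.argsort(-ip)).all()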


class DatasetMusic100(Dataset):
    """
    get dataset from
    https://github.com/stanis-morozov/ip-nsw#dataset
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
        self.metric = 'IP'
        self.basedir = dataset_basedir + 'music-100/'

    def get_queries(self):
        xq = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
        xq = xq.reshape(-1, 100)
        return xq

    def get_database(self):
        xb = np.fromfile(self.basedir + 'database_music100.bin', dtype='float32')
        xb = xb.reshape(-1, 100)
        return xb

    def get_groundtruth(self, k=None):
        gt = np.load(self.basedir + 'gt.npy')
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


class DatasetGIST1M(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_GIST1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 960, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'gist1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "gist_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "gist_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "gist_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "gist_groundtruth.ivecs")
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


def dataset_from_name(dataset='deep1M', download=False):
    """ converts a string describing a dataset to a Dataset object
    Supports sift1M, gist1M, bigann1M..bigann1B, deep100k..deep1B,
    music-100 and glove
    """
    if dataset == 'sift1M':
        return DatasetSIFT1M()

    elif dataset == 'gist1M':
        return DatasetGIST1M()

    elif dataset.startswith('bigann'):
        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
        return DatasetBigANN(nb_M=dbsize)

    elif dataset.startswith("deep"):
        szsuf = dataset[4:]
        if szsuf[-1] == 'M':
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf[-1] == 'k':
            dbsize = 1000 * int(szsuf[:-1])
        else:
            assert False, "did not recognize suffix " + szsuf
        return DatasetDeep1B(nb=dbsize)

    elif dataset == "music-100":
        return DatasetMusic100()

    elif dataset == "glove":
        return DatasetGlove(download=download)

    else:
        raise RuntimeError("unknown dataset " + dataset)
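
# Sketch: typical entry point for a benchmark script (illustrative only):
#
#   ds = dataset_from_name("bigann10M")   # parsed as DatasetBigANN(nb_M=10)
#   print(ds)
#   xq = ds.get_queries()
#   gt = ds.get_groundtruth(k=100)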