from pymilvus import MilvusClient as Client
from pymilvus import DataType
import json
import logging
from typing import Optional

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
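        # MILVUS_URI typically points at a running Milvus server
        # (e.g. "http://localhost:19530"); recent pymilvus releases also accept
        # a local file path for embedded Milvus Lite. MILVUS_TOKEN is only
        # passed when set, so unauthenticated deployments keep working.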
    def _result_to_get_result(self, result) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))
            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                # Normalize the Milvus score from [-1, 1] to the [0, 1] range.
                # https://milvus.io/docs/de/metric.md
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))

            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )
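    # Sanity check for the score normalization above: with the COSINE metric,
    # Milvus returns similarity scores in [-1, 1], so (score + 1) / 2 maps
    # -1 -> 0.0, 0 -> 0.5, and 1 -> 1.0.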
    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_type="HNSW",
            metric_type="COSINE",
            params={"M": 16, "efConstruction": 100},
        )

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
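    # Note on the HNSW parameters above: "M" bounds the number of graph edges
    # per node and "efConstruction" sizes the candidate list used while building
    # the graph; larger values generally trade memory and build time for recall.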
    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest-neighbor items based on the vectors and return 'limit' results.
        collection_name = collection_name.replace("-", "_")
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Construct the filter string for querying.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            return None

        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )
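        # For example, {"file_id": "abc", "page": 1} renders as:
        #   metadata["file_id"] == "abc" && metadata["page"] == 1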
        max_limit = 16383  # The maximum number of records per request
        all_results = []

        if limit is None:
            limit = float("inf")  # Use infinity as a placeholder for no limit

        # Initialize offset and remaining to handle pagination.
        offset = 0
        remaining = limit

        try:
            # Loop until there are no more items to fetch or the desired limit is reached.
            while remaining > 0:
                log.info(f"remaining: {remaining}")
                # Determine how many items to fetch in this iteration.
                current_fetch = min(max_limit, remaining)

                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    output_fields=["*"],
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    break

                all_results.extend(results)
                results_count = len(results)
                remaining -= results_count  # Decrease remaining by the number of items fetched
                offset += results_count

                # Break the loop if fewer results were returned than requested.
                if results_count < current_fetch:
                    break

            log.debug(all_results)
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {collection_name} with limit {limit}: {e}"
            )
            return None
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        collection_name = collection_name.replace("-", "_")
        result = self.client.query(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            filter='id != ""',
        )
        return self._result_to_get_result([result])

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection; if the collection does not exist, it is created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection; items that are not present are inserted.
        # If the collection does not exist, it is created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete items from the collection, either by explicit ids or by a metadata filter.
        collection_name = collection_name.replace("-", "_")
        if ids:
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            # Convert the filter dictionary to a Milvus boolean expression of
            # equality checks on the metadata JSON field.
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
    def reset(self):
        # Reset the database. This deletes all collections and their item entries.
        collection_names = self.client.list_collections()
        for collection_name in collection_names:
            if collection_name.startswith(self.collection_prefix):
                self.client.drop_collection(collection_name=collection_name)
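

# A minimal usage sketch, not part of the adapter itself: it assumes a reachable
# Milvus instance configured via MILVUS_URI/MILVUS_DB/MILVUS_TOKEN. The "demo"
# collection name, ids, 4-dimensional vectors, and metadata below are made up
# for illustration; real embeddings come from the configured embedding model.
if __name__ == "__main__":
    client = MilvusClient()

    # Items are plain dicts shaped like VectorItem, matching how the adapter
    # subscripts them (item["id"], item["vector"], ...).
    items = [
        {
            "id": "doc-1",
            "text": "hello world",
            "vector": [0.1, 0.2, 0.3, 0.4],
            "metadata": {"file_id": "abc"},
        },
        {
            "id": "doc-2",
            "text": "goodbye world",
            "vector": [0.4, 0.3, 0.2, 0.1],
            "metadata": {"file_id": "def"},
        },
    ]
    client.insert("demo", items)  # Creates the collection on first use.

    # Nearest-neighbor search; distances are normalized to [0, 1].
    print(client.search("demo", vectors=[[0.1, 0.2, 0.3, 0.4]], limit=1))

    # Metadata-filtered query over the JSON "metadata" field.
    print(client.query("demo", filter={"file_id": "abc"}))

    # Clean up the demo data.
    client.delete("demo", ids=["doc-1", "doc-2"])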