mtyrrell committed on
Commit
a20c6b6
·
1 Parent(s): 51e7715

updated for qdrant cloud data source

Browse files
Files changed (4) hide show
  1. app/retriever.py +7 -1
  2. app/utils.py +92 -1
  3. app/vectorstore_interface.py +93 -33
  4. params.cfg +10 -10
app/retriever.py CHANGED
@@ -4,7 +4,7 @@ from langchain.schema import Document
4
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
5
  from langchain.retrievers.document_compressors import CrossEncoderReranker
6
  from .utils import getconfig
7
- from .vectorstore_interface import create_vectorstore, VectorStoreInterface
8
  import logging
9
 
10
  # Load configuration
@@ -198,6 +198,12 @@ def get_context(
198
  "model_name": config.get("embeddings", "MODEL_NAME")
199
  }
200
 
 
 
 
 
 
 
201
  # Perform initial retrieval
202
  retrieved_docs = vectorstore.search(query, top_k, **search_kwargs)
203
 
 
4
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
5
  from langchain.retrievers.document_compressors import CrossEncoderReranker
6
  from .utils import getconfig
7
+ from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore
8
  import logging
9
 
10
  # Load configuration
 
198
  "model_name": config.get("embeddings", "MODEL_NAME")
199
  }
200
 
201
+ # filter support for QdrantVectorStore
202
+ if isinstance(vectorstore, QdrantVectorStore):
203
+ filter_obj = create_filter(reports, sources, subtype, year)
204
+ if filter_obj:
205
+ search_kwargs["filter"] = filter_obj
206
+
207
  # Perform initial retrieval
208
  retrieved_docs = vectorstore.search(query, top_k, **search_kwargs)
209
 
app/utils.py CHANGED
@@ -1,5 +1,12 @@
1
  import configparser
2
  import logging
 
 
 
 
 
 
 
3
 
4
  def getconfig(configfile_path: str):
5
  """
@@ -13,4 +20,88 @@ def getconfig(configfile_path: str):
13
  config.read_file(open(configfile_path))
14
  return config
15
  except:
16
- logging.warning("config file not found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import configparser
2
  import logging
3
+ import os
4
+ import ast
5
+ import re
6
+ from dotenv import load_dotenv
7
+
8
+ # Local .env file
9
+ load_dotenv()
10
 
11
  def getconfig(configfile_path: str):
12
  """
 
20
  config.read_file(open(configfile_path))
21
  return config
22
  except:
23
+ logging.warning("config file not found")
24
+
25
+
26
def get_auth(provider: str) -> dict:
    """Get authentication configuration for a supported provider.

    Args:
        provider: Provider name, matched case-insensitively. Supported
            values are "huggingface" and "qdrant".

    Returns:
        A dict with a single "api_key" entry holding the value of the
        provider's environment variable, or None when that variable is
        unset or empty.

    Raises:
        ValueError: If the provider is not one of the supported names.
    """
    # Provider name -> environment variable that carries its API key.
    # Only the variable for the requested provider is read.
    env_vars = {
        "huggingface": "HF_TOKEN",
        "qdrant": "QDRANT_API_KEY",
    }

    provider = provider.lower()  # Normalize to lowercase

    if provider not in env_vars:
        raise ValueError(f"Unsupported provider: {provider}")

    env_var = env_vars[provider]
    # An empty value is treated the same as a missing one.
    api_key = os.getenv(env_var) or None

    if api_key is None:
        # Name the exact variable so the warning is actionable.
        logging.warning(
            "No API key found for provider '%s'. Please set the %s environment variable.",
            provider,
            env_var,
        )

    return {"api_key": api_key}
46
+
47
+
48
def process_content(content: str) -> str:
    """
    Process and clean malformed content that may contain stringified nested lists.

    Content that looks like a Python list literal ("[...]") is parsed with
    ``ast.literal_eval`` and flattened into a single space-separated string.
    Anything else (including list-looking text that cannot be parsed) only
    has its whitespace collapsed.

    Args:
        content: Raw content from vector store

    Returns:
        Cleaned, readable text content. Falsy input is returned unchanged,
        and a parsed list that yields no text returns the original content.
    """
    if not content:
        return content

    def extract_text_from_nested(obj):
        """Recursively collect the non-blank strings found in ``obj``."""
        if isinstance(obj, list):
            text_items = []
            for item in obj:
                extracted = extract_text_from_nested(item)
                if extracted and extracted.strip():
                    text_items.append(extracted)
            return ' '.join(text_items)
        if isinstance(obj, str) and obj.strip():
            return obj.strip()
        if isinstance(obj, dict):
            # Only direct string values are kept, rendered as "key: value".
            return ' '.join(
                f"{key}: {value}"
                for key, value in obj.items()
                if isinstance(value, str) and value.strip()
            )
        return ''

    # Check if content looks like a stringified list/nested structure
    content_stripped = content.strip()
    if content_stripped.startswith('[') and content_stripped.endswith(']'):
        try:
            # Parse as literal list structure
            parsed_content = ast.literal_eval(content_stripped)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            # literal_eval can also fail with TypeError (e.g. an unhashable
            # dict key) or RecursionError/MemoryError on pathological input;
            # all of these mean "not a usable list literal".
            logging.debug(f"Content looks like list but failed to parse: {e}")
            # Fall through to the plain whitespace cleanup below
        else:
            if isinstance(parsed_content, list):
                extracted_text = extract_text_from_nested(parsed_content)

                if extracted_text and extracted_text.strip():
                    # Clean up extra whitespace and format nicely
                    cleaned_text = re.sub(r'\s+', ' ', extracted_text).strip()
                    logging.debug(f"Successfully processed nested list content: {len(cleaned_text)} chars")
                    return cleaned_text

                logging.warning("Parsed list content but no meaningful text found")
                return content  # Return original if no meaningful text extracted

    # For regular text content, just clean up whitespace
    return re.sub(r'\s+', ' ', content).strip()
app/vectorstore_interface.py CHANGED
@@ -2,11 +2,11 @@ from abc import ABC, abstractmethod
2
  from typing import List, Dict, Any, Optional
3
  from gradio_client import Client
4
  import logging
5
- import os
6
- import time
7
  from dotenv import load_dotenv
 
8
  load_dotenv()
9
 
 
10
  class VectorStoreInterface(ABC):
11
  """Abstract interface for different vector store implementations."""
12
 
@@ -15,17 +15,17 @@ class VectorStoreInterface(ABC):
15
  """Search for similar documents."""
16
  pass
17
 
 
18
  class HuggingFaceSpacesVectorStore(VectorStoreInterface):
19
  """Vector store implementation for Hugging Face Spaces with MCP endpoints."""
20
 
21
- def __init__(self, space_url: str, collection_name: str, hf_token: Optional[str] = None):
22
- token = os.getenv("HF_TOKEN")
23
- repo_id = space_url
24
 
25
  logging.info(f"Connecting to Hugging Face Space: {repo_id}")
26
 
27
- if token:
28
- self.client = Client(repo_id, hf_token=token)
29
  else:
30
  self.client = Client(repo_id)
31
 
@@ -50,42 +50,102 @@ class HuggingFaceSpacesVectorStore(VectorStoreInterface):
50
  logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
51
  raise e
52
 
53
- # class QdrantVectorStore(VectorStoreInterface):
54
- # """Vector store implementation for direct Qdrant connection."""
55
- # # needs to be generalized for other vector stores (or add a new class for each vector store)
56
- # def __init__(self, host: str, port: int, collection_name: str, api_key: Optional[str] = None):
57
- # from qdrant_client import QdrantClient
58
- # from langchain_community.vectorstores import Qdrant
 
 
 
 
 
 
 
 
 
59
 
60
- # self.client = QdrantClient(
61
- # host=host,
62
- # port=port,
63
- # api_key=api_key
64
- # )
65
- # self.collection_name = collection_name
66
- # # Embedding model not implemented
 
 
 
 
 
 
 
67
 
68
- # def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
69
- # """Search using direct Qdrant connection."""
70
- # # Embedding model not implemented
71
- # raise NotImplementedError("Direct Qdrant search needs embedding model configuration")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def create_vectorstore(config: Any) -> VectorStoreInterface:
74
  """Factory function to create appropriate vector store based on configuration."""
75
- vectorstore_type = config.get("vectorstore", "TYPE")
 
 
 
76
 
77
- if vectorstore_type.lower() == "huggingface_spaces":
78
- space_url = config.get("vectorstore", "SPACE_URL")
79
  collection_name = config.get("vectorstore", "COLLECTION_NAME")
80
- hf_token = config.get("vectorstore", "HF_TOKEN", fallback=None)
81
- return HuggingFaceSpacesVectorStore(space_url, collection_name, hf_token)
82
 
83
  elif vectorstore_type.lower() == "qdrant":
84
- host = config.get("vectorstore", "HOST")
85
- port = int(config.get("vectorstore", "PORT"))
86
  collection_name = config.get("vectorstore", "COLLECTION_NAME")
87
- api_key = config.get("vectorstore", "API_KEY", fallback=None)
88
- return QdrantVectorStore(host, port, collection_name, api_key)
 
89
 
90
  else:
91
  raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
 
2
  from typing import List, Dict, Any, Optional
3
  from gradio_client import Client
4
  import logging
 
 
5
  from dotenv import load_dotenv
6
+ from .utils import get_auth, process_content
7
  load_dotenv()
8
 
9
+
10
  class VectorStoreInterface(ABC):
11
  """Abstract interface for different vector store implementations."""
12
 
 
15
  """Search for similar documents."""
16
  pass
17
 
18
+
19
  class HuggingFaceSpacesVectorStore(VectorStoreInterface):
20
  """Vector store implementation for Hugging Face Spaces with MCP endpoints."""
21
 
22
+ def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
23
+ repo_id = url
 
24
 
25
  logging.info(f"Connecting to Hugging Face Space: {repo_id}")
26
 
27
+ if api_key:
28
+ self.client = Client(repo_id, hf_token=api_key)
29
  else:
30
  self.client = Client(repo_id)
31
 
 
50
  logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
51
  raise e
52
 
53
class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for direct Qdrant connection.

    Queries are embedded locally with a SentenceTransformer model and sent
    to the configured Qdrant collection as a vector search.
    """

    # Model used when the caller does not pass ``model_name`` to ``search``.
    # NOTE(review): params.cfg names a different model; callers are expected
    # to pass the configured name explicitly — confirm this fallback is wanted.
    _DEFAULT_MODEL_NAME = "BAAI/bge-m3"

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        """Connect to Qdrant.

        Args:
            url: Full Qdrant URL, including protocol (and port if needed).
            collection_name: Collection to search.
            api_key: Optional API key passed to the Qdrant client.
        """
        # Imported here so qdrant_client is only required for this backend.
        # sentence_transformers is deliberately NOT imported here; it is
        # loaded on first use in _get_embedding_model.
        from qdrant_client import QdrantClient

        self.client = QdrantClient(
            url=url,  # Use url parameter which handles full URLs with protocol
            api_key=api_key,
        )
        self.collection_name = collection_name
        # Embedding model is loaded on first use
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Return the embedding model, loading it lazily.

        Args:
            model_name: SentenceTransformer model name; falls back to the
                class default when None.

        Returns:
            The cached model, reloaded only when the requested name differs
            from the one currently loaded.
        """
        if model_name is None:
            model_name = self._DEFAULT_MODEL_NAME

        # Only reload if model name changed
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info(f"Loading embedding model: {model_name}")
            from sentence_transformers import SentenceTransformer
            self._embedding_model = SentenceTransformer(model_name)
            self._current_model_name = model_name
            logging.info(f"Successfully loaded embedding model: {model_name}")

        return self._embedding_model

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using direct Qdrant connection.

        Args:
            query: Query text to embed and search for.
            top_k: Maximum number of hits to return.
            **kwargs: Optional ``model_name`` (embedding model to use) and
                ``filter`` (passed to Qdrant as ``query_filter``).

        Returns:
            One dict per hit with keys 'answer' (cleaned page content),
            'answer_metadata' and 'score'.

        Raises:
            Exception: Any error from embedding or the Qdrant client is
                logged and re-raised unchanged.
        """
        try:
            # Get embedding model
            embedding_model = self._get_embedding_model(kwargs.get('model_name'))

            # Convert query to embedding
            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
            query_embedding = embedding_model.encode(query).tolist()

            # Perform vector search
            # NOTE(review): newer qdrant-client releases appear to deprecate
            # `search` in favour of `query_points` — confirm against the
            # pinned client version before upgrading.
            logging.info(f"Searching Qdrant collection '{self.collection_name}' for top {top_k} results")
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=kwargs.get('filter'),
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            )

            # Format results to match expected output format
            results = []
            for hit in search_result:
                # Guard against points stored without a payload.
                payload = hit.payload or {}
                results.append({
                    # Process content to handle malformed nested list structures
                    'answer': process_content(payload.get('page_content', '')),
                    'answer_metadata': payload.get('metadata', {}),
                    'score': hit.score,
                })

            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
            return results

        except Exception as e:
            logging.error(f"Error searching Qdrant: {str(e)}")
            raise
129
 
130
  def create_vectorstore(config: Any) -> VectorStoreInterface:
131
  """Factory function to create appropriate vector store based on configuration."""
132
+ vectorstore_type = config.get("vectorstore", "PROVIDER")
133
+
134
+ # Get authentication config based on provider
135
+ auth_config = get_auth(vectorstore_type.lower())
136
 
137
+ if vectorstore_type.lower() == "huggingface":
138
+ url = config.get("vectorstore", "URL")
139
  collection_name = config.get("vectorstore", "COLLECTION_NAME")
140
+ api_key = auth_config["api_key"]
141
+ return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
142
 
143
  elif vectorstore_type.lower() == "qdrant":
144
+ url = config.get("vectorstore", "URL") # Use the full URL
 
145
  collection_name = config.get("vectorstore", "COLLECTION_NAME")
146
+ api_key = auth_config["api_key"]
147
+ # Remove port parameter since it's included in the URL
148
+ return QdrantVectorStore(url, collection_name, api_key)
149
 
150
  else:
151
  raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
params.cfg CHANGED
@@ -1,16 +1,16 @@
1
  [vectorstore]
2
- TYPE = huggingface_spaces
3
- SPACE_URL = GIZ/audit_data
4
- COLLECTION_NAME = docling
5
- # For future direct Qdrant usage:
6
- # TYPE = qdrant
7
- # HOST = ip address
8
- # PORT = 6333
9
- # COLLECTION_NAME = "collection name"
10
- # API_KEY = api key for source
11
 
12
  [embeddings]
13
- MODEL_NAME = BAAI/bge-m3
14
  # DEVICE = cpu
15
 
16
  [retriever]
 
1
  [vectorstore]
2
+ # huggingface_spaces usage:
3
+ # PROVIDER = huggingface
4
+ # URL = GIZ/audit_data
5
+ # COLLECTION_NAME = docling
6
+
7
+ # direct Qdrant usage:
8
+ PROVIDER = qdrant
9
+ URL = https://de438521-e2dd-43d9-b41b-b2e18299a2c0.europe-west3-0.gcp.cloud.qdrant.io:6333
10
+ COLLECTION_NAME = allreports
11
 
12
  [embeddings]
13
+ MODEL_NAME = BAAI/bge-large-en-v1.5
14
  # DEVICE = cpu
15
 
16
  [retriever]