File size: 8,966 Bytes
3718631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
#!/usr/bin/env python3
"""
RAG Database Generation Script for gprMax Documentation
Generates a ChromaDB vector database from gprMax documentation
"""

import os
import sys
import shutil
import argparse
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any
import json
import hashlib

import chromadb
import git
from tqdm import tqdm

# Script-wide logging: configure the root logger for timestamped INFO output,
# then grab this module's named logger for all messages below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class GprMaxDocumentProcessor:
    """Process gprMax documentation files for vectorization.

    Walks ``<repo_path>/docs`` for files with a supported extension and splits
    each one into overlapping fixed-size character windows ("chunks"). Each
    chunk carries metadata that traces it back to its source file and offset.
    """

    SUPPORTED_EXTENSIONS = {'.rst', '.md', '.txt'}
    CHUNK_SIZE = 1000  # Characters per chunk
    CHUNK_OVERLAP = 200  # Overlap between chunks
    # Chunks whose stripped text is shorter than this are skipped.
    MIN_CHUNK_CHARS = 50

    def __init__(self, repo_path: Path):
        """Remember the repository root; documentation is read from ``docs/``.

        Args:
            repo_path: Root of a gprMax checkout. Chunk ``source`` metadata is
                recorded relative to this directory.
        """
        self.repo_path = repo_path
        self.doc_path = repo_path / "docs"

    def extract_documents(self) -> List[Dict[str, Any]]:
        """Extract and chunk all documentation files.

        Files that fail to process are logged and skipped, so one bad file
        does not abort the whole extraction.

        Returns:
            All chunks from all files, in deterministic (sorted path) order.
            The list is empty if the docs directory does not exist.
        """
        documents = []

        if not self.doc_path.exists():
            logger.warning(f"Documentation path {self.doc_path} does not exist")
            return documents

        for file_path in self._find_doc_files():
            try:
                chunks = self._process_file(file_path)
                documents.extend(chunks)
            except Exception as e:
                logger.error(f"Error processing {file_path}: {e}")

        logger.info(f"Extracted {len(documents)} document chunks")
        return documents

    def _find_doc_files(self) -> List[Path]:
        """Return every regular file under the docs directory with a supported extension.

        The result is sorted so chunk order (and therefore batch contents)
        is reproducible between runs. Iterating the extension set alone would
        give a hash-dependent order. Directories whose names happen to match
        a pattern (e.g. ``foo.md/``) are excluded.
        """
        matches = {
            path
            for ext in self.SUPPORTED_EXTENSIONS
            for path in self.doc_path.rglob(f"*{ext}")
            if path.is_file()
        }
        return sorted(matches)

    def _process_file(self, file_path: Path) -> List[Dict[str, Any]]:
        """Split a single file into overlapping chunks.

        Windows of ``CHUNK_SIZE`` characters start every
        ``CHUNK_SIZE - CHUNK_OVERLAP`` characters. Windowing stops as soon as
        a window reaches the end of the file. Otherwise the file's tail would
        be emitted again as a chunk lying entirely inside the previous
        window's overlap.

        Args:
            file_path: File inside ``repo_path`` to read as UTF-8 text.

        Returns:
            One ``{"id", "text", "metadata"}`` dict per kept chunk. Chunks
            whose stripped text is shorter than ``MIN_CHUNK_CHARS`` are
            dropped.
        """
        # errors='ignore' drops undecodable bytes instead of failing the file.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()

        # Calculate relative path for metadata
        rel_path = file_path.relative_to(self.repo_path)
        step = self.CHUNK_SIZE - self.CHUNK_OVERLAP
        total = len(content)

        chunks = []
        for start in range(0, total, step):
            end = min(start + self.CHUNK_SIZE, total)
            chunk_text = content[start:end]

            if len(chunk_text.strip()) >= self.MIN_CHUNK_CHARS:
                # ID depends on path + offset + text prefix, so it is stable
                # across runs for unchanged content.
                chunk_id = hashlib.md5(f"{rel_path}_{start}_{chunk_text[:50]}".encode()).hexdigest()
                chunks.append({
                    "id": chunk_id,
                    "text": chunk_text,
                    "metadata": {
                        "source": str(rel_path),
                        "file_type": file_path.suffix,
                        "chunk_index": len(chunks),
                        "char_start": start,
                        "char_end": end,
                    },
                })

            if end >= total:
                # This window already covers EOF; any later window would be
                # pure overlap with it.
                break

        return chunks


# Embeddings are produced by ChromaDB's default embedding function; no custom embedding model is defined in this script.


class ChromaRAGDatabase:
    """ChromaDB-based RAG database.

    Wraps a ``chromadb.PersistentClient`` stored at *db_path* and a single
    versioned collection holding the documentation chunks. No embedding
    function is passed, so ChromaDB's default one generates embeddings.
    """

    def __init__(self, db_path: Path):
        """Open the persistent ChromaDB store at *db_path*.

        Args:
            db_path: Directory where ChromaDB persists its data. The
                ``metadata.json`` summary written by ``save_metadata`` also
                goes here.
        """
        self.db_path = db_path

        # Initialize ChromaDB with persistent storage
        self.client = chromadb.PersistentClient(path=str(db_path))

        # Collection name with version for easy updates
        self.collection_name = "gprmax_docs_v1"

    def create_collection(self, recreate: bool = False):
        """Create or get the document collection.

        Args:
            recreate: If True, delete any existing collection with the same
                name first so the data is rebuilt from scratch. If False, an
                existing collection is reused instead of raising
                "already exists".
        """
        if recreate:
            try:
                self.client.delete_collection(self.collection_name)
                logger.info(f"Deleted existing collection: {self.collection_name}")
            except Exception as e:
                # Expected on a first run (nothing to delete). The exception
                # type for a missing collection differs between ChromaDB
                # releases, hence the broad catch.
                logger.debug(f"No existing collection to delete: {e}")

        # Let ChromaDB use its default embedding function.
        # NOTE: depending on the ChromaDB version, passing metadata for an
        # already-existing collection may overwrite its stored created_at.
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"created_at": datetime.now().isoformat()}
        )
        logger.info(f"Using collection: {self.collection_name}")

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Insert or update documents in the collection, in batches of 100.

        ``upsert`` is used rather than ``add`` so re-running against a reused
        collection refreshes chunks whose IDs already exist.

        Args:
            documents: Dicts with ``"id"``, ``"text"`` and ``"metadata"`` keys,
                as produced by ``GprMaxDocumentProcessor``.
        """
        if not documents:
            logger.warning("No documents to add")
            return

        # Prepare data for ChromaDB
        ids = [doc["id"] for doc in documents]
        texts = [doc["text"] for doc in documents]
        metadatas = [doc["metadata"] for doc in documents]

        # ChromaDB generates embeddings itself since none are passed.
        batch_size = 100
        logger.info(f"Adding {len(documents)} documents to database...")
        for i in tqdm(range(0, len(ids), batch_size), desc="Adding to database"):
            end_idx = min(i + batch_size, len(ids))
            self.collection.upsert(
                ids=ids[i:end_idx],
                documents=texts[i:end_idx],
                metadatas=metadatas[i:end_idx],
            )

        logger.info(f"Added {len(documents)} documents to database")

        # Verify documents were added
        actual_count = self.collection.count()
        logger.info(f"Verified collection now contains {actual_count} documents")

    def save_metadata(self):
        """Write a ``metadata.json`` summary of the build next to the database.

        The summary records the build time, collection name, chunking
        parameters and current document count. The ``embedding_model`` entry
        is a fixed label and is not queried from ChromaDB.
        """
        # Get fresh count
        doc_count = self.collection.count()

        metadata = {
            "created_at": datetime.now().isoformat(),
            "embedding_model": "ChromaDB Default (all-MiniLM-L6-v2)",
            "collection_name": self.collection_name,
            "chunk_size": GprMaxDocumentProcessor.CHUNK_SIZE,
            "chunk_overlap": GprMaxDocumentProcessor.CHUNK_OVERLAP,
            "total_documents": doc_count
        }

        metadata_path = self.db_path / "metadata.json"
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Saved metadata to {metadata_path}")


def clone_gprmax_repo(target_dir: Path) -> Path:
    """Return a local gprMax checkout located at ``target_dir / "gprMax"``.

    If that directory already exists, it is opened as a git repository and
    refreshed by pulling from its ``origin`` remote. Otherwise a fresh
    depth-1 clone of the GitHub repository is created there.

    Args:
        target_dir: Parent directory for the checkout.

    Returns:
        Path of the checkout directory.
    """
    checkout = target_dir / "gprMax"

    if not checkout.exists():
        logger.info(f"Cloning gprMax repository to {checkout}")
        # depth=1: only the latest commit is needed to read the docs.
        git.Repo.clone_from("https://github.com/gprMax/gprMax.git", checkout, depth=1)
        return checkout

    logger.info(f"Updating existing repository at {checkout}")
    git.Repo(checkout).remotes.origin.pull()
    return checkout


def _is_same_or_inside(path: Path, directory: Path) -> bool:
    """Return True if *path* resolves to *directory* itself or to somewhere beneath it."""
    try:
        path.resolve().relative_to(directory.resolve())
    except ValueError:
        return False
    return True


def main():
    """Build the gprMax documentation vector database from the command line.

    Steps: fetch the gprMax repository, chunk its documentation, load the
    chunks into ChromaDB, and write a metadata summary. On success the temp
    directory is removed, unless the database itself lives inside it.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Generate RAG database from gprMax documentation")
    parser.add_argument(
        "--db-path",
        type=Path,
        default=Path(__file__).parent / "chroma_db",
        help="Path to store the ChromaDB database"
    )
    parser.add_argument(
        "--temp-dir",
        type=Path,
        default=Path(__file__).parent / "temp",
        help="Temporary directory for cloning repository"
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Recreate database from scratch (delete existing)"
    )

    args = parser.parse_args()

    try:
        # Step 1: Clone/update gprMax repository
        logger.info("Step 1: Fetching gprMax repository...")
        repo_path = clone_gprmax_repo(args.temp_dir)

        # Step 2: Process documentation
        logger.info("Step 2: Processing documentation files...")
        processor = GprMaxDocumentProcessor(repo_path)
        documents = processor.extract_documents()

        if not documents:
            logger.error("No documents found to process")
            return 1

        # Step 3: Create database
        logger.info("Step 3: Creating vector database...")
        db = ChromaRAGDatabase(args.db_path)
        db.create_collection(recreate=args.recreate)

        # Step 4: Add documents
        logger.info("Step 4: Adding documents to database...")
        db.add_documents(documents)

        # Step 5: Save metadata
        db.save_metadata()

        logger.info(f"✅ Database successfully created at {args.db_path}")
        logger.info(f"Total documents: {len(documents)}")

        # Cleanup temp files. Never rmtree a temp dir that contains (or is)
        # the database directory, or the freshly built database is destroyed.
        if args.temp_dir.exists():
            if _is_same_or_inside(args.db_path, args.temp_dir):
                logger.info(f"Keeping {args.temp_dir}: it contains the database")
            else:
                logger.info("Cleaning up temporary files...")
                shutil.rmtree(args.temp_dir, ignore_errors=True)

        return 0

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Failed to generate database: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())