Novoo5 commited on
Commit
77bc432
Β·
verified Β·
1 Parent(s): 0aac165

Add CardioQA system files - API, database, and medical data

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
37
+ data/raw/medquad_full.csv filter=lfs diff=lfs merge=lfs -text
38
+ data/raw/medquad_raw.csv filter=lfs diff=lfs merge=lfs -text
chroma_db/c3ee0465-20cc-4bb1-bbae-cdf17ec4df3f/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0afd4a59a7369cd3375711f2768a3e7e41228bc63733e8658bcb04108952750c
3
+ size 167600
chroma_db/c3ee0465-20cc-4bb1-bbae-cdf17ec4df3f/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0e81c3b22454233bc12d0762f06dcca48261a75231cf87c79b75e69a6c00150
3
+ size 100
chroma_db/c3ee0465-20cc-4bb1-bbae-cdf17ec4df3f/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3f53d18cb2b932fc4ce3fe13761180f9983fc2c12f6da2fd8ac805cd4505ae5
3
+ size 400
chroma_db/c3ee0465-20cc-4bb1-bbae-cdf17ec4df3f/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
3
+ size 0
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:941a0e853a6b74b9d429be0f4bb877be8da948cd9b9cddf93025c7508b50bdb3
3
+ size 11190272
data/processed/cardiac_qa.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/processed/cardioqa_system_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "created_date": "2025-10-02T18:00:23.978687",
3
+ "total_documents": 364,
4
+ "embedding_model": "all-MiniLM-L6-v2",
5
+ "llm_model": "gemini-2.0-flash",
6
+ "vector_db_path": "../chroma_db",
7
+ "safety_features": [
8
+ "emergency_detection",
9
+ "professional_consultation",
10
+ "medical_disclaimers"
11
+ ],
12
+ "performance_metrics": {
13
+ "avg_response_time": "2-3 seconds",
14
+ "safety_validation": "enabled",
15
+ "confidence_scoring": "enabled"
16
+ }
17
+ }
data/processed/rag_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vector_db_path": "../chroma_db",
3
+ "collection_name": "cardiac_knowledge",
4
+ "embedding_model": "all-MiniLM-L6-v2",
5
+ "embedding_dimension": 384,
6
+ "total_documents": 364,
7
+ "data_source": "MedQuAD",
8
+ "specialty": "cardiology",
9
+ "created_date": "2025-10-02T17:33:34.943956"
10
+ }
data/raw/.gitkeep ADDED
File without changes
data/raw/dataset_statistics.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total_pairs": 16407,
3
+ "columns": [
4
+ "qtype",
5
+ "Question",
6
+ "Answer"
7
+ ],
8
+ "missing_values": {
9
+ "qtype": 0,
10
+ "Question": 0,
11
+ "Answer": 0
12
+ },
13
+ "data_types": {
14
+ "qtype": "object",
15
+ "Question": "object",
16
+ "Answer": "object"
17
+ }
18
+ }
data/raw/medquad_cardiac.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/raw/medquad_full.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fd20f0d2e946398b648b4cc56f3dc6111eb76b00a8eec6f4669ebebf1b701c1
3
+ size 22483298
data/raw/medquad_raw.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fd20f0d2e946398b648b4cc56f3dc6111eb76b00a8eec6f4669ebebf1b701c1
3
+ size 22483298
src/api/main.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CardioQA FastAPI Backend - PRODUCTION VERSION
3
+ AI-powered cardiac diagnostic assistant with RAG
4
+ Author: Novonil Basak
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ import time
10
+ from pathlib import Path
11
+ from typing import List, Optional
12
+ from contextlib import asynccontextmanager
13
+
14
+ from fastapi import FastAPI, HTTPException
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel, Field
17
+ import chromadb
18
+ from sentence_transformers import SentenceTransformer
19
+ import google.generativeai as genai
20
+
21
+ # Setup logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Global variables
26
+ collection = None
27
+ embedding_model = None
28
+ gemini_model = None
29
+ safety_validator = None
30
+
31
+ # Pydantic models
32
+ class QueryRequest(BaseModel):
33
+ query: str = Field(..., min_length=5, max_length=500)
34
+ include_metadata: bool = Field(default=True)
35
+
36
+ class QueryResponse(BaseModel):
37
+ response: str
38
+ safety_score: int
39
+ confidence: str
40
+ knowledge_sources: int
41
+ top_similarity: float
42
+ warnings: List[str]
43
+ response_time: float
44
+
45
+ class MedicalSafetyValidator:
46
+ """Medical safety validation system"""
47
+
48
+ def __init__(self):
49
+ self.emergency_keywords = [
50
+ 'heart attack', 'chest pain', 'shortness of breath', 'stroke',
51
+ 'severe pain', 'bleeding', 'unconscious', 'emergency', 'crushing pain'
52
+ ]
53
+
54
+ def validate_response(self, response_text: str, user_query: str) -> dict:
55
+ """Validate medical safety of AI response"""
56
+ safety_score = 85
57
+ warnings = []
58
+
59
+ # Check for emergency situations
60
+ if any(keyword in user_query.lower() for keyword in self.emergency_keywords):
61
+ if 'seek immediate medical attention' not in response_text.lower():
62
+ warnings.append("CRITICAL: Emergency situation detected")
63
+ safety_score -= 20
64
+ else:
65
+ safety_score += 10
66
+
67
+ # Check for professional consultation recommendation
68
+ consult_phrases = ['consult', 'doctor', 'physician', 'healthcare provider']
69
+ if any(phrase in response_text.lower() for phrase in consult_phrases):
70
+ safety_score += 10
71
+ else:
72
+ warnings.append("Added professional consultation recommendation")
73
+ safety_score -= 15
74
+
75
+ # Check response quality
76
+ if len(response_text) > 200:
77
+ safety_score += 5
78
+
79
+ # Check for dangerous statements
80
+ dangerous_phrases = ['you definitely have', 'this is certainly', 'never see a doctor']
81
+ if any(phrase in response_text.lower() for phrase in dangerous_phrases):
82
+ warnings.append("Contains potentially dangerous medical statements")
83
+ safety_score -= 25
84
+
85
+ safety_score = min(100, max(50, safety_score))
86
+
87
+ return {
88
+ 'safety_score': safety_score,
89
+ 'warnings': warnings,
90
+ 'is_safe': safety_score >= 70
91
+ }
92
+
93
+ def add_safety_disclaimers(self, response_text: str, safety_check: dict) -> str:
94
+ """Add medical disclaimers"""
95
+ disclaimers = "\n\n⚠️ **MEDICAL DISCLAIMER**: Educational purposes only.\nπŸ‘¨β€βš•οΈ **RECOMMENDATION**: Consult healthcare professionals."
96
+
97
+ if safety_check['safety_score'] < 80:
98
+ disclaimers += "\n🚨 **IMPORTANT**: For severe symptoms, seek immediate medical attention."
99
+
100
+ return response_text + disclaimers
101
+
102
+ @asynccontextmanager
103
+ async def lifespan(app: FastAPI):
104
+ """Initialize and cleanup application resources"""
105
+ global collection, embedding_model, gemini_model, safety_validator
106
+
107
+ logger.info("πŸ«€ Starting CardioQA API...")
108
+
109
+ try:
110
+ # FIXED: Force ChromaDB to create new compatible database
111
+ possible_paths = [
112
+ "./chroma_db",
113
+ "chroma_db",
114
+ "/opt/render/project/src/chroma_db",
115
+ Path.cwd() / "chroma_db",
116
+ Path(__file__).parent.parent.parent / "chroma_db"
117
+ ]
118
+
119
+ db_path = None
120
+ for path in possible_paths:
121
+ path_obj = Path(path)
122
+ logger.info(f"πŸ” Checking: {path_obj.absolute()}")
123
+ if path_obj.exists() and path_obj.is_dir():
124
+ db_path = str(path_obj)
125
+ logger.info(f"βœ… Found ChromaDB at: {db_path}")
126
+ break
127
+
128
+ if not db_path:
129
+ # Create new ChromaDB if not found
130
+ logger.info("πŸ“ Creating new ChromaDB...")
131
+ db_path = "./chroma_db_render"
132
+
133
+ # Initialize new ChromaDB and recreate collection
134
+ client = chromadb.PersistentClient(path=db_path)
135
+ try:
136
+ collection = client.get_collection(name="cardiac_knowledge")
137
+ logger.info(f"βœ… Using existing collection: {collection.count()} documents")
138
+ except:
139
+ logger.info("Creating new collection with sample data...")
140
+ collection = client.create_collection(name="cardiac_knowledge")
141
+
142
+ # Add sample cardiac Q&A data for demo
143
+ sample_data = [
144
+ {
145
+ "question": "What are the symptoms of heart attack?",
146
+ "answer": "Common heart attack symptoms include chest pain or discomfort, shortness of breath, pain in arms/back/neck/jaw, cold sweat, nausea, and lightheadedness. Seek immediate medical attention if experiencing these symptoms.",
147
+ "qtype": "symptoms"
148
+ },
149
+ {
150
+ "question": "How can I prevent heart disease?",
151
+ "answer": "Heart disease prevention includes regular exercise, healthy diet low in saturated fats, not smoking, limiting alcohol, managing stress, controlling blood pressure and cholesterol, and regular medical checkups.",
152
+ "qtype": "prevention"
153
+ },
154
+ {
155
+ "question": "What causes high blood pressure?",
156
+ "answer": "High blood pressure can be caused by genetics, age, diet high in sodium, lack of exercise, obesity, excessive alcohol consumption, stress, and certain medical conditions. Regular monitoring is important.",
157
+ "qtype": "causes"
158
+ }
159
+ ]
160
+
161
+ for i, item in enumerate(sample_data):
162
+ collection.add(
163
+ documents=[item["answer"]],
164
+ metadatas=[{
165
+ "question": item["question"],
166
+ "answer": item["answer"],
167
+ "qtype": item["qtype"]
168
+ }],
169
+ ids=[f"cardiac_{i}"]
170
+ )
171
+
172
+ logger.info(f"βœ… Created collection with {len(sample_data)} sample documents")
173
+ else:
174
+ # Try to use existing database
175
+ try:
176
+ client = chromadb.PersistentClient(path=db_path)
177
+ collection = client.get_collection(name="cardiac_knowledge")
178
+ logger.info(f"βœ… Loaded existing database: {collection.count()} documents")
179
+ except Exception as e:
180
+ logger.error(f"❌ ChromaDB compatibility issue: {e}")
181
+ # Fallback: create new database
182
+ logger.info("Creating fallback database...")
183
+ client = chromadb.PersistentClient(path="./chroma_db_fallback")
184
+ collection = client.create_collection(name="cardiac_knowledge")
185
+ # Add sample data (same as above)
186
+ sample_data = [
187
+ {
188
+ "question": "What are the symptoms of heart attack?",
189
+ "answer": "Common heart attack symptoms include chest pain or discomfort, shortness of breath, pain in arms/back/neck/jaw, cold sweat, nausea, and lightheadedness. Seek immediate medical attention.",
190
+ "qtype": "symptoms"
191
+ }
192
+ ]
193
+ collection.add(
194
+ documents=[sample_data[0]["answer"]],
195
+ metadatas=[sample_data[0]],
196
+ ids=["cardiac_0"]
197
+ )
198
+ logger.info("βœ… Created fallback database")
199
+
200
+ # Load embedding model
201
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
202
+ logger.info("βœ… Loaded embedding model")
203
+
204
+ # Configure Gemini API
205
+ api_key = os.getenv("GEMINI_API_KEY")
206
+ if not api_key:
207
+ raise Exception("❌ GEMINI_API_KEY environment variable not set")
208
+
209
+ genai.configure(api_key=api_key)
210
+ gemini_model = genai.GenerativeModel('gemini-2.0-flash')
211
+
212
+ # Test Gemini connection
213
+ test_response = gemini_model.generate_content("Say 'CardioQA ready!'")
214
+ logger.info(f"βœ… Gemini test: {test_response.text}")
215
+
216
+ # Initialize safety validator
217
+ safety_validator = MedicalSafetyValidator()
218
+ logger.info("βœ… Safety validator ready")
219
+
220
+ logger.info("πŸŽ‰ CardioQA API fully initialized!")
221
+
222
+ yield
223
+
224
+ except Exception as e:
225
+ logger.error(f"❌ Startup failed: {str(e)}")
226
+ raise
227
+
228
+ # Cleanup
229
+ logger.info("πŸ”„ Shutting down CardioQA API...")
230
+
231
+ # Initialize FastAPI with lifespan
232
+ app = FastAPI(
233
+ title="CardioQA API",
234
+ description="AI-powered cardiac diagnostic assistant with RAG",
235
+ version="1.0.0",
236
+ lifespan=lifespan
237
+ )
238
+
239
+ # Add CORS middleware
240
+ app.add_middleware(
241
+ CORSMiddleware,
242
+ allow_origins=["*"],
243
+ allow_credentials=True,
244
+ allow_methods=["GET", "POST"],
245
+ allow_headers=["*"],
246
+ )
247
+
248
+ @app.get("/")
249
+ async def root():
250
+ """API root endpoint"""
251
+ return {
252
+ "message": "CardioQA API - AI-Powered Cardiac Diagnostic Assistant",
253
+ "version": "1.0.0",
254
+ "status": "running",
255
+ "endpoints": {
256
+ "health": "/health",
257
+ "query": "/query",
258
+ "docs": "/docs",
259
+ "stats": "/stats"
260
+ }
261
+ }
262
+
263
+ @app.get("/health")
264
+ async def health_check():
265
+ """Health check endpoint"""
266
+ try:
267
+ db_count = collection.count() if collection else 0
268
+ model_status = "ready" if gemini_model else "not loaded"
269
+
270
+ return {
271
+ "status": "healthy",
272
+ "database_count": db_count,
273
+ "model_status": model_status,
274
+ "api_version": "1.0.0",
275
+ "deployment": "render-production"
276
+ }
277
+ except Exception as e:
278
+ raise HTTPException(status_code=500, detail=str(e))
279
+
280
+ @app.post("/query", response_model=QueryResponse)
281
+ async def query_cardioqa(request: QueryRequest):
282
+ """Main CardioQA query endpoint"""
283
+ start_time = time.time()
284
+
285
+ try:
286
+ if not collection or not gemini_model or not safety_validator:
287
+ raise HTTPException(status_code=503, detail="System not fully initialized")
288
+
289
+ logger.info(f"Processing query: {request.query[:100]}...")
290
+
291
+ # Search knowledge base
292
+ results = collection.query(
293
+ query_texts=[request.query],
294
+ n_results=3
295
+ )
296
+
297
+ if not results['documents'][0]:
298
+ raise HTTPException(status_code=404, detail="No relevant cardiac information found")
299
+
300
+ # Format knowledge context
301
+ knowledge_context = []
302
+ for doc, metadata, distance in zip(
303
+ results['documents'][0],
304
+ results['metadatas'][0],
305
+ results['distances'][0]
306
+ ):
307
+ knowledge_context.append({
308
+ 'question': metadata['question'],
309
+ 'answer': metadata['answer'],
310
+ 'similarity': 1 - distance
311
+ })
312
+
313
+ # Create medical prompt
314
+ context_text = f"Medical Evidence:\nQ: {knowledge_context[0]['question']}\nA: {knowledge_context[0]['answer']}"
315
+
316
+ prompt = f"""You are CardioQA, a specialized cardiac health assistant.
317
+
318
+ MEDICAL RESPONSE RULES:
319
+ - Never provide definitive diagnoses
320
+ - Always recommend consulting healthcare professionals
321
+ - Use **bold** for important medical points
322
+ - Be educational and evidence-based
323
+ - Include appropriate medical caution
324
+
325
+ USER QUESTION: {request.query}
326
+
327
+ {context_text}
328
+
329
+ Provide a helpful, evidence-based response with proper **bold** formatting:"""
330
+
331
+ # Generate AI response
332
+ response = gemini_model.generate_content(
333
+ prompt,
334
+ generation_config={
335
+ 'temperature': 0.1,
336
+ 'max_output_tokens': 800,
337
+ }
338
+ )
339
+ ai_response = response.text
340
+
341
+ # Apply safety validation
342
+ safety_check = safety_validator.validate_response(ai_response, request.query)
343
+ safe_response = safety_validator.add_safety_disclaimers(ai_response, safety_check)
344
+
345
+ # Calculate confidence level
346
+ similarity = knowledge_context[0]['similarity']
347
+ if similarity > 0.6:
348
+ confidence = 'High'
349
+ elif similarity > 0.4:
350
+ confidence = 'Medium'
351
+ elif similarity > 0.2:
352
+ confidence = 'Low'
353
+ else:
354
+ confidence = 'Very Low'
355
+
356
+ response_time = time.time() - start_time
357
+
358
+ return QueryResponse(
359
+ response=safe_response,
360
+ safety_score=safety_check['safety_score'],
361
+ confidence=confidence,
362
+ knowledge_sources=len(knowledge_context),
363
+ top_similarity=knowledge_context[0]['similarity'],
364
+ warnings=safety_check['warnings'],
365
+ response_time=round(response_time, 2)
366
+ )
367
+
368
+ except HTTPException:
369
+ raise
370
+ except Exception as e:
371
+ logger.error(f"Query processing error: {str(e)}")
372
+ raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
373
+
374
+ @app.get("/stats")
375
+ async def get_system_stats():
376
+ """System statistics endpoint"""
377
+ try:
378
+ return {
379
+ "total_documents": collection.count() if collection else 0,
380
+ "embedding_model": "all-MiniLM-L6-v2",
381
+ "llm_model": "gemini-2.0-flash",
382
+ "specialty": "cardiology",
383
+ "safety_features": [
384
+ "emergency_detection",
385
+ "professional_consultation",
386
+ "medical_disclaimers",
387
+ "confidence_scoring"
388
+ ],
389
+ "deployment": "render-production",
390
+ "chromadb_version": "compatible"
391
+ }
392
+ except Exception as e:
393
+ raise HTTPException(status_code=500, detail=str(e))
394
+
395
+ # FIXED: Proper port binding for Render deployment
396
+ if __name__ == "__main__":
397
+ import uvicorn
398
+ # Railway uses PORT environment variable
399
+ port = int(os.environ.get("PORT", 8000))
400
+ logger.info(f"πŸš€ Starting CardioQA on port {port}")
401
+ uvicorn.run(
402
+ app,
403
+ host="0.0.0.0",
404
+ port=port,
405
+ log_level="info"
406
+ )
407
+
src/data_pipeline/collect_medquad.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CardioQA Data Collection Module
3
+ Collects and processes medical Q&A data from MedQuAD dataset
4
+ Author: Novonil Basak
5
+ Date: October 2, 2025
6
+ """
7
+
8
+ import os
9
+ import pandas as pd
10
+ import requests
11
+ from datasets import load_dataset
12
+ from pathlib import Path
13
+ import json
14
+ from tqdm import tqdm
15
+ import logging
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
+ logger = logging.getLogger(__name__)
20
+
21
+ class MedicalDataCollector:
22
+ """Collect and process medical datasets for CardioQA RAG system"""
23
+
24
+ def __init__(self, data_dir="data/raw"):
25
+ self.data_dir = Path(data_dir)
26
+ self.data_dir.mkdir(parents=True, exist_ok=True)
27
+
28
+ def collect_medquad_dataset(self):
29
+ """Collect MedQuAD dataset from HuggingFace"""
30
+ logger.info("Starting MedQuAD dataset collection...")
31
+
32
+ try:
33
+ # Load MedQuAD dataset from HuggingFace
34
+ logger.info("Loading MedQuAD from HuggingFace...")
35
+ dataset = load_dataset("keivalya/MedQuad-MedicalQnADataset")
36
+
37
+ # Convert to pandas DataFrame
38
+ df = pd.DataFrame(dataset['train'])
39
+ logger.info(f"Loaded {len(df)} medical Q&A pairs")
40
+
41
+ # Basic data inspection
42
+ logger.info("Dataset columns: " + str(df.columns.tolist()))
43
+ logger.info("Dataset shape: " + str(df.shape))
44
+
45
+ # Save raw dataset
46
+ raw_file_path = self.data_dir / "medquad_raw.csv"
47
+ df.to_csv(raw_file_path, index=False)
48
+ logger.info(f"Saved raw MedQuAD to {raw_file_path}")
49
+
50
+ return df
51
+
52
+ except Exception as e:
53
+ logger.error(f"Error collecting MedQuAD dataset: {str(e)}")
54
+ return None
55
+
56
+ def filter_cardiac_data(self, df):
57
+ """Filter dataset for cardiology-related content"""
58
+ logger.info("Filtering for cardiology-related content...")
59
+
60
+ # Cardiac-related keywords
61
+ cardiac_keywords = [
62
+ 'heart', 'cardiac', 'cardiology', 'cardiovascular', 'coronary',
63
+ 'arrhythmia', 'hypertension', 'blood pressure', 'chest pain',
64
+ 'heart attack', 'myocardial', 'atrial', 'ventricular', 'valve',
65
+ 'pacemaker', 'ECG', 'EKG', 'angina', 'stroke', 'circulation'
66
+ ]
67
+
68
+ # Create cardiac filter mask
69
+ cardiac_mask = df.apply(
70
+ lambda row: any(
71
+ keyword.lower() in str(row).lower()
72
+ for keyword in cardiac_keywords
73
+ ), axis=1
74
+ )
75
+
76
+ cardiac_df = df[cardiac_mask].copy()
77
+ logger.info(f"Found {len(cardiac_df)} cardiac-related Q&A pairs")
78
+
79
+ # Save filtered cardiac data
80
+ cardiac_file_path = self.data_dir / "medquad_cardiac.csv"
81
+ cardiac_df.to_csv(cardiac_file_path, index=False)
82
+ logger.info(f"Saved cardiac data to {cardiac_file_path}")
83
+
84
+ return cardiac_df
85
+
86
+ def display_sample_data(self, df, n_samples=3):
87
+ """Display sample Q&A pairs"""
88
+ logger.info(f"Sample {n_samples} Q&A pairs:")
89
+ print("\n" + "="*80)
90
+
91
+ for i, row in df.head(n_samples).iterrows():
92
+ print(f"Q{i+1}: {row.iloc[0] if len(row) > 0 else 'No question'}")
93
+ print(f"A{i+1}: {row.iloc[1] if len(row) > 1 else 'No answer'}")
94
+ print("-" * 60)
95
+
96
+ def get_dataset_statistics(self, df):
97
+ """Generate basic statistics about the dataset"""
98
+ stats = {
99
+ 'total_pairs': len(df),
100
+ 'columns': df.columns.tolist(),
101
+ 'missing_values': df.isnull().sum().to_dict(),
102
+ 'data_types': df.dtypes.to_dict()
103
+ }
104
+
105
+ # Save statistics
106
+ stats_file = self.data_dir / "dataset_statistics.json"
107
+ with open(stats_file, 'w') as f:
108
+ json.dump(stats, f, indent=2, default=str)
109
+
110
+ logger.info("Dataset Statistics:")
111
+ logger.info(f"- Total Q&A pairs: {stats['total_pairs']}")
112
+ logger.info(f"- Columns: {stats['columns']}")
113
+ logger.info(f"- Statistics saved to {stats_file}")
114
+
115
+ return stats
116
+
117
+ def main():
118
+ """Main execution function"""
119
+ print("πŸ«€ CardioQA Data Collection Pipeline")
120
+ print("=" * 50)
121
+
122
+ # Initialize collector
123
+ collector = MedicalDataCollector()
124
+
125
+ # Step 1: Collect MedQuAD dataset
126
+ print("\nπŸ“Š Step 1: Collecting MedQuAD Dataset...")
127
+ medquad_df = collector.collect_medquad_dataset()
128
+
129
+ if medquad_df is not None:
130
+ # Step 2: Generate statistics
131
+ print("\nπŸ“ˆ Step 2: Analyzing Dataset...")
132
+ stats = collector.get_dataset_statistics(medquad_df)
133
+
134
+ # Step 3: Display samples
135
+ print("\nπŸ‘€ Step 3: Sample Data Preview...")
136
+ collector.display_sample_data(medquad_df, n_samples=3)
137
+
138
+ # Step 4: Filter cardiac data
139
+ print("\nπŸ«€ Step 4: Filtering Cardiac Data...")
140
+ cardiac_df = collector.filter_cardiac_data(medquad_df)
141
+
142
+ # Step 5: Display cardiac samples
143
+ if len(cardiac_df) > 0:
144
+ print("\nπŸ’“ Step 5: Cardiac Data Preview...")
145
+ collector.display_sample_data(cardiac_df, n_samples=2)
146
+
147
+ print("\nβœ… Data collection completed successfully!")
148
+ print(f"πŸ“ Files saved in: {collector.data_dir}")
149
+
150
+ else:
151
+ print("❌ Data collection failed!")
152
+
153
+ if __name__ == "__main__":
154
+ main()