File size: 5,680 Bytes
77bc432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""

CardioQA Data Collection Module

Collects and processes medical Q&A data from the MedQuAD dataset

Author: Novonil Basak

Date: October 2, 2025

"""

import os
import pandas as pd
import requests
from datasets import load_dataset
from pathlib import Path
import json
from tqdm import tqdm
import logging

# Configure root logging at import time: INFO level, timestamped lines.
# NOTE(review): basicConfig here affects any program that imports this module;
# consider moving it into main() if the module is reused as a library.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger shared by everything in this file.
logger = logging.getLogger(__name__)

class MedicalDataCollector:
    """Collect and process medical datasets for the CardioQA RAG system.

    Every artifact produced (raw CSV, filtered CSV, statistics JSON) is
    written under ``data_dir``, which is created on construction.
    """

    # Default substrings used by ``filter_cardiac_data``; matched
    # case-insensitively as plain substrings.
    # NOTE(review): substring matching means 'heart' also hits 'heartburn',
    # and 'stroke' is cerebrovascular rather than cardiac -- confirm intended.
    CARDIAC_KEYWORDS = (
        'heart', 'cardiac', 'cardiology', 'cardiovascular', 'coronary',
        'arrhythmia', 'hypertension', 'blood pressure', 'chest pain',
        'heart attack', 'myocardial', 'atrial', 'ventricular', 'valve',
        'pacemaker', 'ECG', 'EKG', 'angina', 'stroke', 'circulation',
    )

    def __init__(self, data_dir="data/raw"):
        """Create the collector and make sure ``data_dir`` exists.

        Args:
            data_dir: Output directory (str or path-like); created together
                with any missing parents.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def collect_medquad_dataset(self, dataset_name="keivalya/MedQuad-MedicalQnADataset",
                                split="train"):
        """Download MedQuAD from the HuggingFace Hub and save it as CSV.

        Args:
            dataset_name: Hub identifier passed to ``load_dataset``.
            split: Which split of the loaded dataset to convert.

        Returns:
            The split as a pandas DataFrame (also written to
            ``<data_dir>/medquad_raw.csv``), or ``None`` if anything failed.
            Failures are logged with a traceback rather than raised, because
            ``main`` treats ``None`` as the failure signal.
        """
        logger.info("Starting MedQuAD dataset collection...")

        try:
            logger.info("Loading MedQuAD from HuggingFace...")
            dataset = load_dataset(dataset_name)

            # Arrow-backed conversion; equivalent to building a DataFrame
            # from the row dicts but without iterating in Python.
            df = dataset[split].to_pandas()
            logger.info(f"Loaded {len(df)} medical Q&A pairs")

            logger.info("Dataset columns: " + str(df.columns.tolist()))
            logger.info("Dataset shape: " + str(df.shape))

            raw_file_path = self.data_dir / "medquad_raw.csv"
            df.to_csv(raw_file_path, index=False)
            logger.info(f"Saved raw MedQuAD to {raw_file_path}")

            return df

        except Exception as e:
            # Pipeline boundary: keep the traceback, report failure via None.
            logger.exception("Error collecting MedQuAD dataset: %s", e)
            return None

    def filter_cardiac_data(self, df, keywords=None):
        """Keep rows that mention any cardiology keyword and save them to CSV.

        All cells of a row are converted with ``str`` and joined; a row
        matches when any keyword occurs in that text (case-insensitive
        substring test).

        Args:
            df: DataFrame of Q&A pairs (any columns).
            keywords: A string or iterable of strings to look for; defaults
                to ``CARDIAC_KEYWORDS``.

        Returns:
            A copy of the matching rows with the original index preserved.
            The same rows are written to ``<data_dir>/medquad_cardiac.csv``.
        """
        logger.info("Filtering for cardiology-related content...")

        if keywords is None:
            keywords = self.CARDIAC_KEYWORDS
        elif isinstance(keywords, str):
            # A bare string would otherwise be iterated character by character.
            keywords = [keywords]
        needles = [keyword.lower() for keyword in keywords]

        def mentions_keyword(values):
            # Join the full cell values. str(row) would go through pandas'
            # repr, which truncates long cells (display.max_colwidth) and
            # adds index labels / a dtype footer to the searched text.
            text = " ".join(str(value) for value in values).lower()
            return any(needle in text for needle in needles)

        # to_numpy(dtype=object) yields one array per row even when the frame
        # has zero rows or zero columns, so the mask always matches df.index.
        cardiac_mask = pd.Series(
            [mentions_keyword(values) for values in df.to_numpy(dtype=object)],
            index=df.index,
            dtype=bool,
        )

        cardiac_df = df.loc[cardiac_mask].copy()
        logger.info(f"Found {len(cardiac_df)} cardiac-related Q&A pairs")

        cardiac_file_path = self.data_dir / "medquad_cardiac.csv"
        cardiac_df.to_csv(cardiac_file_path, index=False)
        logger.info(f"Saved cardiac data to {cardiac_file_path}")

        return cardiac_df

    def display_sample_data(self, df, n_samples=3):
        """Print the first ``n_samples`` Q&A pairs to stdout.

        Columns named 'question' and 'answer' (case-insensitive, surrounding
        whitespace ignored) are used when both exist; otherwise the first
        column is shown as the question and the second as the answer. Pairs
        are numbered from 1 regardless of the frame's index labels.
        """
        logger.info(f"Sample {n_samples} Q&A pairs:")
        print("\n" + "=" * 80)

        labels = [str(col).strip().lower() for col in df.columns]
        if "question" in labels and "answer" in labels:
            # The question need not be the first column (the frame may carry
            # other columns, e.g. a category, before it).
            q_pos, a_pos = labels.index("question"), labels.index("answer")
        else:
            q_pos, a_pos = 0, 1
        n_cols = len(labels)

        # enumerate() instead of the index label: filtered frames keep their
        # original (sparse or non-integer) index.
        for number, (_, row) in enumerate(df.head(n_samples).iterrows(), start=1):
            question = row.iloc[q_pos] if n_cols > q_pos else 'No question'
            answer = row.iloc[a_pos] if n_cols > a_pos else 'No answer'
            print(f"Q{number}: {question}")
            print(f"A{number}: {answer}")
            print("-" * 60)

    def get_dataset_statistics(self, df):
        """Compute basic dataset statistics and save them as JSON.

        Returns:
            dict with 'total_pairs' (row count), 'columns' (column labels),
            'missing_values' (label -> null count) and 'data_types'
            (label -> dtype), as computed from ``df``. The JSON copy at
            ``<data_dir>/dataset_statistics.json`` stores the counts as
            integers and the dtypes by their string names.
        """
        stats = {
            'total_pairs': len(df),
            'columns': df.columns.tolist(),
            'missing_values': df.isnull().sum().to_dict(),
            'data_types': df.dtypes.to_dict()
        }

        # numpy integers are not JSON-serializable, so relying on default=str
        # would store null counts as strings ("0"); convert them explicitly.
        # Keys are stringified because JSON object keys must be strings.
        json_stats = {
            'total_pairs': stats['total_pairs'],
            'columns': stats['columns'],
            'missing_values': {
                str(col): int(count) for col, count in stats['missing_values'].items()
            },
            'data_types': {
                str(col): str(dtype) for col, dtype in stats['data_types'].items()
            },
        }

        stats_file = self.data_dir / "dataset_statistics.json"
        with open(stats_file, 'w', encoding='utf-8') as f:
            # default=str stays as a fallback for unusual labels in 'columns'.
            json.dump(json_stats, f, indent=2, default=str)

        logger.info("Dataset Statistics:")
        logger.info(f"- Total Q&A pairs: {stats['total_pairs']}")
        logger.info(f"- Columns: {stats['columns']}")
        logger.info(f"- Statistics saved to {stats_file}")

        return stats

def main():
    """Run the CardioQA data collection pipeline end to end.

    Downloads MedQuAD, saves statistics, previews a few pairs, then filters
    and previews cardiology-related pairs. All files are written to the
    collector's ``data_dir`` (``data/raw`` by default).
    """
    # Emoji literals below were repaired from mojibake (UTF-8 bytes that had
    # been decoded with a legacy single-byte code page).
    print("🫀 CardioQA Data Collection Pipeline")
    print("=" * 50)

    collector = MedicalDataCollector()

    # Step 1: Collect MedQuAD dataset
    print("\n📊 Step 1: Collecting MedQuAD Dataset...")
    medquad_df = collector.collect_medquad_dataset()

    # collect_medquad_dataset() signals failure by returning None.
    if medquad_df is None:
        print("❌ Data collection failed!")
        return

    # Step 2: Generate statistics (saved to JSON; the returned dict is unused here)
    print("\n📈 Step 2: Analyzing Dataset...")
    collector.get_dataset_statistics(medquad_df)

    # Step 3: Display samples
    print("\n👀 Step 3: Sample Data Preview...")
    collector.display_sample_data(medquad_df, n_samples=3)

    # Step 4: Filter cardiac data
    print("\n🫀 Step 4: Filtering Cardiac Data...")
    cardiac_df = collector.filter_cardiac_data(medquad_df)

    # Step 5: Display cardiac samples
    if len(cardiac_df) > 0:
        print("\n💓 Step 5: Cardiac Data Preview...")
        collector.display_sample_data(cardiac_df, n_samples=2)

    print("\n✅ Data collection completed successfully!")
    print(f"📁 Files saved in: {collector.data_dir}")


if __name__ == "__main__":
    main()