# Snapshot of a Hugging Face Space file view (Space status: sleeping; file size: 5,680 bytes; revision 77bc432).
"""
CardioQA Data Collection Module
Collects and processes medical Q&A data from MedQuAD dataset
Author: Novonil Basak
Date: October 2, 2025
"""
import os
import pandas as pd
import requests
from datasets import load_dataset
from pathlib import Path
import json
from tqdm import tqdm
import logging
# Configure logging once at import time: INFO and above, each record
# prefixed with a timestamp and its severity.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger used throughout this file.
logger = logging.getLogger(__name__)
class MedicalDataCollector:
    """Collect and process medical datasets for CardioQA RAG system.

    Every file the collector produces (raw CSV, cardiac-only CSV and the
    statistics JSON) is written beneath ``data_dir``.
    """

    # Substrings, matched case-insensitively, that mark a Q&A pair as
    # cardiology-related.
    CARDIAC_KEYWORDS = (
        'heart', 'cardiac', 'cardiology', 'cardiovascular', 'coronary',
        'arrhythmia', 'hypertension', 'blood pressure', 'chest pain',
        'heart attack', 'myocardial', 'atrial', 'ventricular', 'valve',
        'pacemaker', 'ECG', 'EKG', 'angina', 'stroke', 'circulation',
    )

    def __init__(self, data_dir="data/raw"):
        """Remember the output directory and create it if it is missing.

        Args:
            data_dir: Directory (string or path) that receives every file
                this collector writes. Missing parents are created too.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def collect_medquad_dataset(self):
        """Collect MedQuAD dataset from HuggingFace and save it as CSV.

        Returns:
            The ``train`` split as a pandas DataFrame, or ``None`` when
            loading or saving fails (the failure is logged with its
            traceback instead of being raised).
        """
        logger.info("Starting MedQuAD dataset collection...")
        try:
            logger.info("Loading MedQuAD from HuggingFace...")
            dataset = load_dataset("keivalya/MedQuad-MedicalQnADataset")

            df = pd.DataFrame(dataset['train'])
            logger.info(f"Loaded {len(df)} medical Q&A pairs")

            # Basic data inspection
            logger.info("Dataset columns: " + str(df.columns.tolist()))
            logger.info("Dataset shape: " + str(df.shape))

            raw_file_path = self.data_dir / "medquad_raw.csv"
            df.to_csv(raw_file_path, index=False)
            logger.info(f"Saved raw MedQuAD to {raw_file_path}")
            return df
        except Exception as e:
            # Deliberate best-effort boundary: callers test the result for
            # None, so report the failure (with traceback) and carry on.
            logger.exception(f"Error collecting MedQuAD dataset: {str(e)}")
            return None

    def filter_cardiac_data(self, df):
        """Filter dataset for cardiology-related content.

        A row is kept when the text of any of its cells contains one of
        ``CARDIAC_KEYWORDS`` as a case-insensitive substring. The filtered
        rows are also written to ``medquad_cardiac.csv`` in ``data_dir``.

        Args:
            df: DataFrame of Q&A pairs; every column is searched.

        Returns:
            A copy of the matching rows, keeping their original index
            labels. Empty input yields an empty frame.
        """
        logger.info("Filtering for cardiology-related content...")
        keywords = tuple(keyword.lower() for keyword in self.CARDIAC_KEYWORDS)

        def mentions_keyword(value):
            text = str(value).lower()
            return any(keyword in text for keyword in keywords)

        # Inspect the full text of each cell. The printed form of a row
        # cannot be used here because pandas elides long values in it, which
        # would hide keywords that appear late in a long answer.
        cardiac_mask = pd.Series(False, index=df.index, dtype=bool)
        for _, column in df.items():
            cardiac_mask = cardiac_mask | column.map(mentions_keyword).astype(bool)

        cardiac_df = df[cardiac_mask].copy()
        logger.info(f"Found {len(cardiac_df)} cardiac-related Q&A pairs")

        cardiac_file_path = self.data_dir / "medquad_cardiac.csv"
        cardiac_df.to_csv(cardiac_file_path, index=False)
        logger.info(f"Saved cardiac data to {cardiac_file_path}")
        return cardiac_df

    @staticmethod
    def _column_position(df, wanted, fallback):
        """Return the position of the column named ``wanted``.

        The comparison ignores case and surrounding whitespace. When no
        column matches, ``fallback`` is returned if that position exists in
        ``df``, otherwise ``None``.
        """
        for position, label in enumerate(df.columns):
            if str(label).strip().lower() == wanted:
                return position
        return fallback if fallback < len(df.columns) else None

    def display_sample_data(self, df, n_samples=3):
        """Display sample Q&A pairs on stdout.

        Pairs are numbered 1..n in display order, independent of the
        frame's index labels (which are arbitrary after filtering).

        Args:
            df: DataFrame of Q&A pairs. Columns named "question"/"answer"
                (any case) are used when present; otherwise the first and
                second columns are shown.
            n_samples: Number of leading rows to print.
        """
        logger.info(f"Sample {n_samples} Q&A pairs:")
        print("\n" + "=" * 80)

        # NOTE(review): the loaded dataset may carry metadata columns ahead
        # of the question text, so prefer named columns over positions.
        question_pos = self._column_position(df, "question", 0)
        answer_pos = self._column_position(df, "answer", 1)

        for number, (_, row) in enumerate(df.head(n_samples).iterrows(), start=1):
            question = row.iloc[question_pos] if question_pos is not None else 'No question'
            answer = row.iloc[answer_pos] if answer_pos is not None else 'No answer'
            print(f"Q{number}: {question}")
            print(f"A{number}: {answer}")
            print("-" * 60)

    def get_dataset_statistics(self, df):
        """Generate basic statistics about the dataset and save them as JSON.

        Args:
            df: DataFrame to summarise.

        Returns:
            dict with keys ``total_pairs`` (row count), ``columns`` (list of
            column labels), ``missing_values`` (null count per column, as
            plain ints) and ``data_types`` (dtype per column). The same
            content is written to ``dataset_statistics.json`` in
            ``data_dir``.
        """
        stats = {
            'total_pairs': len(df),
            'columns': df.columns.tolist(),
            # Plain ints so the counts are stored as JSON numbers; NumPy
            # integers would otherwise fall through to ``default=str``.
            'missing_values': {
                column: int(count) for column, count in df.isnull().sum().items()
            },
            'data_types': df.dtypes.to_dict(),
        }

        stats_file = self.data_dir / "dataset_statistics.json"
        with open(stats_file, 'w', encoding='utf-8') as f:
            # default=str is still needed for the dtype objects.
            json.dump(stats, f, indent=2, default=str)

        logger.info("Dataset Statistics:")
        logger.info(f"- Total Q&A pairs: {stats['total_pairs']}")
        logger.info(f"- Columns: {stats['columns']}")
        logger.info(f"- Statistics saved to {stats_file}")
        return stats
def main():
    """Run the CardioQA data-collection pipeline end to end.

    Steps: collect MedQuAD, write dataset statistics, preview a few pairs,
    filter down to cardiology-related pairs and preview those. Progress is
    reported on stdout. When collection yields no data, a failure message
    is printed and the remaining steps are skipped.
    """
    print("🫀 CardioQA Data Collection Pipeline")
    print("=" * 50)

    collector = MedicalDataCollector()

    # Step 1: Collect MedQuAD dataset
    print("\nStep 1: Collecting MedQuAD Dataset...")
    medquad_df = collector.collect_medquad_dataset()
    if medquad_df is None:
        print("❌ Data collection failed!")
        return

    # Step 2: Generate statistics (the method logs and saves them itself)
    print("\nStep 2: Analyzing Dataset...")
    collector.get_dataset_statistics(medquad_df)

    # Step 3: Display samples
    print("\nStep 3: Sample Data Preview...")
    collector.display_sample_data(medquad_df, n_samples=3)

    # Step 4: Filter cardiac data
    print("\n🫀 Step 4: Filtering Cardiac Data...")
    cardiac_df = collector.filter_cardiac_data(medquad_df)

    # Step 5: Display cardiac samples, only when the filter kept something
    if len(cardiac_df) > 0:
        print("\nStep 5: Cardiac Data Preview...")
        collector.display_sample_data(cardiac_df, n_samples=2)

    print("\n✅ Data collection completed successfully!")
    print(f"Files saved in: {collector.data_dir}")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|