Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import mne | |
| import numpy as np | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import os | |
# Hugging Face Hub checkpoint used to turn the numeric EEG summary into prose.
model_name = "tiiuae/falcon-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Options forwarded to ``from_pretrained`` when loading the weights.
# NOTE(review): ``trust_remote_code=True`` permits running Python code shipped
# with the checkpoint repository -- confirm it is still required for this
# model with the installed transformers version.
_model_load_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.float16,  # load the weights in half precision
    "device_map": "auto",  # let the library choose device placement
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside the inclusive band ``[fmin, fmax]``.

    The result is the plain average of every PSD entry whose frequency bin
    falls inside the band, pooled over all leading axes (for example over
    channels when ``psd`` is ``(n_channels, n_freqs)``).

    Args:
        psd: Array-like power spectral density with frequency on the last
            axis. Any number of leading axes is accepted.
        freqs: 1-D array-like with the frequency of each bin on the last
            axis of ``psd``.
        fmin: Lower band edge (inclusive).
        fmax: Upper band edge (inclusive).

    Returns:
        The mean band power as a Python ``float``, or ``nan`` when no
        frequency bin lies inside the band.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    in_band = (freqs >= fmin) & (freqs <= fmax)
    if not in_band.any():
        # Averaging an empty selection would also give NaN, but only after
        # emitting a RuntimeWarning; make the "no data" result explicit.
        return float("nan")
    # ``...`` keeps the selection on the last axis whatever the rank of psd.
    return float(psd[..., in_band].mean())
def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load an EEG recording from a FIF or CSV file as an MNE ``Raw`` object.

    FIF files are read with ``mne.io.read_raw_fif``. CSV files are handled
    flexibly:

    * If a column named ``time_col`` exists, it is used as the time axis and
      the sampling frequency is estimated as the inverse of the mean
      difference between consecutive time values.
    * Otherwise the samples are assumed to be uniformly spaced at
      ``default_sfreq``.
    * Every remaining numeric column becomes one EEG channel; non-numeric
      columns are discarded.

    Args:
        file_path: Path to a ``.fif`` or ``.csv`` file (extension is matched
            case-insensitively).
        default_sfreq: Sampling frequency in Hz used for CSV files that have
            no ``time_col`` column.
        time_col: Name of the CSV column holding time stamps.

    Returns:
        The recording returned by ``mne.io.read_raw_fif`` or
        ``mne.io.RawArray``.

    Raises:
        ValueError: If the extension is unsupported, the time column is
            non-numeric, has fewer than two samples or does not increase on
            average, ``default_sfreq`` is not a positive finite number, or
            the CSV contains no numeric channel column.
    """
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext != '.csv':
        raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")

    df = pd.read_csv(file_path)

    if time_col in df.columns:
        time_values = df[time_col]
        if not pd.api.types.is_numeric_dtype(time_values):
            raise ValueError(
                f"Time column '{time_col}' must be numeric to estimate the "
                "sampling frequency."
            )
        if len(time_values) < 2:
            raise ValueError("Not enough time points to estimate sampling frequency.")
        mean_step = np.mean(np.diff(time_values.to_numpy(dtype=float)))
        # A constant, decreasing or NaN-containing time axis would otherwise
        # yield an infinite, negative or NaN sampling frequency.
        if not np.isfinite(mean_step) or mean_step <= 0:
            raise ValueError(
                f"Time column '{time_col}' must contain finite, increasing "
                "values to estimate the sampling frequency."
            )
        sfreq = float(1.0 / mean_step)
        candidates = df.drop(columns=[time_col])
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = float(default_sfreq)
        if not np.isfinite(sfreq) or sfreq <= 0:
            raise ValueError("Default sampling frequency must be a positive number.")
        candidates = df

    # Only numeric columns can be interpreted as channel signals.
    channel_cols = [
        col for col in candidates.columns
        if pd.api.types.is_numeric_dtype(candidates[col])
    ]
    if not channel_cols:
        raise ValueError("The CSV file contains no numeric channel columns.")

    ch_names = [str(col) for col in channel_cols]
    # Transpose so that rows are channels and columns are samples.
    data = candidates[channel_cols].to_numpy(dtype=float).T

    info = mne.create_info(
        ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names)
    )
    return mne.io.RawArray(data, info)
def _welch_psd(raw, fmin=1, fmax=40):
    """Return ``(psd, freqs)`` from Welch's method on old and new MNE versions."""
    if hasattr(raw, "compute_psd"):
        # Current MNE API; the module-level ``psd_welch`` function is no
        # longer available in recent releases.
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def process_eeg(file, default_sfreq, time_col):
    """Summarise an uploaded EEG recording in plain language.

    The recording is loaded, its Welch PSD is computed between 1 and 40 Hz,
    the mean alpha (8-12 Hz) and beta (13-30 Hz) power are extracted, and the
    language model is asked to interpret those numbers.

    Args:
        file: The uploaded file, either a path (``str`` / ``os.PathLike``) or
            an object exposing the path as ``.name``.
        default_sfreq: Sampling frequency in Hz (number or numeric string)
            used for CSV files without a time column.
        time_col: Name of the CSV time column.

    Returns:
        The text generated by the model, without the prompt.

    Raises:
        ValueError: If no file was provided or ``default_sfreq`` is not a
            number; errors from ``load_eeg_data`` propagate unchanged.
    """
    if file is None:
        raise ValueError("No file uploaded. Please provide a FIF or CSV file.")
    try:
        sfreq = float(default_sfreq)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Default sampling frequency must be a number, got {default_sfreq!r}."
        ) from err

    # Depending on the Gradio version the upload arrives as a path string or
    # as a temp-file wrapper whose ``.name`` is the path.
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name

    raw = load_eeg_data(file_path, default_sfreq=sfreq, time_col=time_col)
    psd, freqs = _welch_psd(raw, fmin=1, fmax=40)

    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)
    ratio = alpha_power / beta_power if beta_power > 0 else float("nan")

    # Report only measured quantities: a fixed "finding" would be stated for
    # every recording regardless of its content. Scientific notation keeps
    # very small PSD values from printing as 0.000.
    data_summary = (
        f"Alpha power: {alpha_power:.3e}, Beta power: {beta_power:.3e}, "
        f"Alpha/Beta power ratio: {ratio:.2f}."
    )
    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings in simple terms.\n"
    )

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # Budget for generated tokens only, so the prompt length cannot eat
        # into (or exceed) the limit as it does with ``max_length``.
        max_new_tokens=200,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt; keep only the new tokens.
    completion = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
# Input widgets, in the positional order expected by ``process_eeg``.
_eeg_upload = gr.File(label="Upload your EEG data (FIF or CSV)")
_sfreq_box = gr.Textbox(
    label="Default Sampling Frequency if no time column (Hz)",
    value="256",
)
_time_col_box = gr.Textbox(label="Time column name (if exists)", value="time")

# Help text displayed with the interface.
_description = (
    "Upload EEG data in FIF or CSV format. "
    "If CSV, either include a 'time' column or specify a default sampling frequency. "
    "Non-numeric columns will be removed (except the chosen time column)."
)

# Web UI wiring the three inputs to ``process_eeg`` and showing its text output.
iface = gr.Interface(
    fn=process_eeg,
    inputs=[_eeg_upload, _sfreq_box, _time_col_box],
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=_description,
)

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()