# app.py
import os
import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import einops
import copy
import uuid
import shutil
import time
import threading  # used by the background cleanup task
from pathlib import Path
from huggingface_hub import snapshot_download
from visionts import VisionTSpp, freq_to_seasonality_list
# ========================
# 0. Environment & Cleanup Configuration
# ========================

# --- Configuration for Session Cleanup ---
SESSION_DIR_ROOT = Path("user_sessions")
SESSION_DIR_ROOT.mkdir(exist_ok=True)
MAX_FILE_AGE_SECONDS = 24 * 60 * 60      # 24 hours
CLEANUP_INTERVAL_SECONDS = 60 * 60       # run the cleanup check every hour

# set the gradio tmp dir
# os.environ["GRADIO_TEMP_DIR"] = "./user_sessions"
def cleanup_old_sessions():
    """Deletes session folders older than MAX_FILE_AGE_SECONDS."""
    print(f"Running periodic cleanup of old session directories (interval: {CLEANUP_INTERVAL_SECONDS}s)...")
    now = time.time()
    deleted_count = 0
    for session_dir in SESSION_DIR_ROOT.iterdir():
        if session_dir.is_dir():
            try:
                # Use the directory's modification time as an indicator of last activity.
                dir_mod_time = session_dir.stat().st_mtime
                if (now - dir_mod_time) > MAX_FILE_AGE_SECONDS:
                    print(f"Cleaning up session directory older than {MAX_FILE_AGE_SECONDS}s: {session_dir}")
                    shutil.rmtree(session_dir)
                    deleted_count += 1
            except Exception as e:
                print(f"Error cleaning up directory {session_dir}: {e}")
    if deleted_count > 0:
        print(f"Cleanup complete. Removed {deleted_count} old session(s).")
    else:
        print("Cleanup complete. No old sessions found.")
# --- Run the cleanup periodically in the background ---
def periodic_cleanup_task():
    """Runs cleanup_old_sessions() in a loop, sleeping CLEANUP_INTERVAL_SECONDS between passes."""
    print("Starting background thread for periodic cleanup.")
    while True:
        cleanup_old_sessions()
        time.sleep(CLEANUP_INTERVAL_SECONDS)
# ========================
# 1. Model Configuration
# ========================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
REPO_ID = "Lefei/VisionTSpp"
LOCAL_DIR = "./hf_models/VisionTSpp"
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
ARCH = 'mae_base'

if not os.path.exists(CKPT_PATH):
    print("Downloading model from Hugging Face Hub...")
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR)
QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

model = VisionTSpp(ARCH, ckpt_path=CKPT_PATH, quantile=True, clip_input=True,
                   complete_no_clip=False, color=True).to(DEVICE)
print(f"Model loaded on {DEVICE}")

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
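# These are the standard ImageNet channel statistics. show_image_tensor below
# inverts the normalization per channel: pixel = (value * std + mean) * 255,
# e.g. a normalized red value of 0.0 maps back to 0.485 * 255 ≈ 124.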
# ========================
# 2. Preset Datasets (Loaded Locally)
# ========================
data_dir = "./datasets/"
PRESET_DATASETS = {
    "ETTm1": data_dir + "ETTm1.csv",
    "ETTm2": data_dir + "ETTm2.csv",
    "ETTh1": data_dir + "ETTh1.csv",
    "ETTh2": data_dir + "ETTh2.csv",
    "Illness": data_dir + "Illness.csv",
    "Weather": data_dir + "Weather.csv",
}

def load_preset_data(name):
    """Loads a preset dataset from a local path."""
    path = PRESET_DATASETS[name]
    if not os.path.exists(path):
        raise FileNotFoundError(f"Preset dataset file not found: {path}. Make sure it's uploaded to the 'datasets' folder.")
    return pd.read_csv(path)
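# Example usage (assumes ./datasets/ETTh1.csv is present):
#   df = load_preset_data("ETTh1")   # DataFrame with a 'date' column plus numeric series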
# ========================
# 3. Visualization Functions
# ========================
def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
    """Renders a normalized image tensor; each variable occupies a horizontal band and one RGB channel."""
    if image_tensor is None:
        return None
    image = image_tensor.cpu()
    cur_image = torch.zeros_like(image)
    height_per_var = image.shape[0] // cur_nvars
    for i in range(cur_nvars):
        cur_color_idx = cur_color_list[i]
        var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
        # Undo the ImageNet normalization for this variable's assigned channel.
        unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
        cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
    cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cur_image)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig
def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
    """Plots ground truth, the median prediction, and nested quantile bands for each variable."""
    if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
    if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
    for i, q in enumerate(pred_quantiles_list):
        if isinstance(q, torch.Tensor):
            pred_quantiles_list[i] = q.cpu().numpy()
    nvars = true_data.shape[1]
    FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
    fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
    if nvars == 1: axes = [axes]
    # Re-insert the median so the list lines up with QUANTILES, then pair off the rest into bands.
    pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
    sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
    quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
    quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
    num_bands = len(quantile_preds) // 2
    quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
    for i, ax in enumerate(axes):
        ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
        pred_range = np.arange(context_len, context_len + pred_len)
        ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
        for j in range(num_bands):
            lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
            q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
            ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
        y_min, y_max = ax.get_ylim()
        ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
        ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.margins(x=0)
    handles, labels = axes[0].get_legend_handles_labels()
    unique_labels = dict(zip(labels, handles))
    fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.close(fig)
    return fig
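# Band pairing in the function above, concretely: with QUANTILES = [0.1, ..., 0.9],
# the eight non-median forecasts form num_bands = 4 nested intervals:
# (10%, 90%), (20%, 80%), (30%, 70%), (40%, 60%). The widest interval gets the
# darkest blue, matching the "darker = more uncertainty" note in the UI text.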
# ========================
# 4. Prediction Logic
# ========================
class PredictionResult:
    def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq):
        self.ts_fig = ts_fig
        self.input_img_fig = input_img_fig
        self.recon_img_fig = recon_img_fig
        self.csv_path = csv_path
        self.total_samples = total_samples
        self.inferred_freq = inferred_freq
def predict_at_index(df, index, context_len, pred_len, session_dir):
    if 'date' not in df.columns:
        raise gr.Error("❌ Input CSV must contain a 'date' column.")
    try:
        df['date'] = pd.to_datetime(df['date'])
        df = df.sort_values('date').set_index('date')
        inferred_freq = pd.infer_freq(df.index)
        if inferred_freq is None:
            # Fall back to the spacing of the first two timestamps.
            time_diff = df.index[1] - df.index[0]
            inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
            gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
        print(f"Inferred frequency: {inferred_freq}")
    except Exception as e:
        raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
    data = df.select_dtypes(include=np.number).values
    nvars = data.shape[1]
    total_samples = len(data) - context_len - pred_len + 1
    if total_samples <= 0:
        raise gr.Error(f"Data is too short. It needs at least context_len + pred_len = {context_len + pred_len} rows, but has {len(data)}.")
    index = max(0, min(index, total_samples - 1))
    # Z-score normalization using statistics from the first 70% of the series (train split).
    train_len = int(len(data) * 0.7)
    x_mean = data[:train_len].mean(axis=0, keepdims=True)
    x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
    data_norm = (data - x_mean) / x_std
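    # Worked example of the normalization above (illustrative): with
    # data = [[10.], [20.], [30.], [40.]], train_len = int(4 * 0.7) = 2, so
    # x_mean = 15.0 and x_std ≈ 5.0, giving data_norm ≈ [-1.0, 1.0, 3.0, 5.0].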
    start_idx = index
    x_norm = data_norm[start_idx : start_idx + context_len]
    y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
    x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
    periodicity_list = freq_to_seasonality_list(inferred_freq)
    periodicity = periodicity_list[0] if periodicity_list else 1
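    # Note (assumption): for hourly data one would expect the leading seasonality
    # candidate to be the daily period (24); the exact list returned by
    # freq_to_seasonality_list depends on the installed visionts version.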
    # Map each variable to one of the three RGB channels, cycling.
    color_list = [i % 3 for i in range(nvars)]
    model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity,
                        num_patch_input=7, padding_mode='constant')
    with torch.no_grad():
        y_pred, input_image, reconstructed_image, _, _ = model.forward(
            x_tensor, export_image=True, color_list=color_list
        )
    y_pred, y_pred_quantile_list = y_pred
    # Insert the point forecast at the middle slot so the list aligns with QUANTILES (median = 0.5).
    all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
    all_y_pred_list.insert(len(QUANTILES)//2, y_pred)
    all_preds = dict(zip(QUANTILES, all_y_pred_list))
    pred_median_norm = all_preds.pop(0.5)[0]
    pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
    # De-normalize back to the original scale.
    y_true = y_true_norm * x_std + x_mean
    pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
    pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
    full_true_context = data[start_idx : start_idx + context_len]
    full_true_series = np.concatenate([full_true_context, y_true], axis=0)
    ts_fig = visual_ts_with_quantiles(
        true_data=full_true_series, pred_median=pred_median,
        pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
        context_len=context_len, pred_len=pred_len
    )
    input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
    recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
    csv_path = Path(session_dir) / "prediction_result.csv"
    time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
    result_data = {'date': time_index}
    for i in range(nvars):
        result_data[f'True_Var{i+1}'] = y_true[:, i]
        result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
    result_df = pd.DataFrame(result_data)
    result_df.to_csv(csv_path, index=False)
    return PredictionResult(ts_fig, input_img_fig, recon_img_fig, str(csv_path), total_samples, inferred_freq)
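# Illustrative call (assumes the ETTh1 preset exists; the output path is schematic):
#   df = load_preset_data("ETTh1")
#   res = predict_at_index(df, index=0, context_len=336, pred_len=96,
#                          session_dir=get_session_dir(None))
#   res.csv_path  # e.g. "user_sessions/<uuid>/prediction_result.csv"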
# ========================
# 5. Gradio Interface
# ========================
def get_session_dir(session_id):
    """Creates and returns a unique directory for the user session.

    session_id is the gr.State value: a path string from a previous call, or None.
    """
    if session_id is None or not Path(session_id).exists():
        session_uuid = str(uuid.uuid4())
        session_dir = Path(SESSION_DIR_ROOT) / session_uuid
        session_dir.mkdir(exist_ok=True, parents=True)
        session_id = str(session_dir)
    return session_id
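# Example: get_session_dir(None) mints a fresh folder such as "user_sessions/3f2a...";
# passing that same string back on later calls reuses the existing folder.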
def run_forecast(data_source, upload_file, index, context_len, pred_len, session_id):
    session_dir = get_session_dir(session_id)
    try:
        if data_source == "Upload CSV":
            if upload_file is None:
                raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
            # Copy the upload into this session's folder so concurrent users can't clash.
            uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
            shutil.copy(upload_file.name, uploaded_file_path)
            df = pd.read_csv(uploaded_file_path)
        else:
            df = load_preset_data(data_source)
        index, context_len, pred_len = int(index), int(context_len), int(pred_len)
        result = predict_at_index(df, index, context_len, pred_len, session_dir)
        final_index = min(index, result.total_samples - 1)
        return (
            result.ts_fig,
            result.input_img_fig,
            result.recon_img_fig,
            result.csv_path,
            gr.update(maximum=result.total_samples - 1, value=final_index),
            gr.update(value=result.inferred_freq),
            session_dir
        )
    except Exception as e:
        print(f"Error during forecast: {e}")
        # Render the error message as a figure so the plot area shows what went wrong.
        error_fig = plt.figure(figsize=(10, 5))
        plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
        plt.axis('off')
        plt.close(error_fig)
        return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
    session_id_state = gr.State(None)
    gr.Markdown("# 🛰️ VisionTS++: Multivariate Time Series Forecasting")
    gr.Markdown(
        """
        An interactive platform to explore time series forecasting using the VisionTS++ model.
        - ✅ **Select** from local preset datasets or **upload** your own.
        - ✅ **Frequency is auto-detected** from the 'date' column.
        - ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
        - ✅ **Slide** through different samples of the dataset for real-time forecasting.
        - ✅ **Download** the prediction results as a CSV file.
        - ✅ **User Isolation**: Each user session has its own temporary storage to prevent file conflicts. Old files are automatically cleaned up.
        """
    )
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Data & Model Configuration")
            data_source = gr.Dropdown(
                label="Select Data Source",
                choices=list(PRESET_DATASETS.keys()) + ["Upload CSV"],
                value="ETTh1"
            )
            upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
            gr.Markdown(
                """
                **Upload Rules:**
                1. Must be a `.csv` file.
                2. Must contain a time column named `date` with a consistent frequency.
                """
            )
            context_len = gr.Number(label="Context Length (History)", value=336)
            pred_len = gr.Number(label="Prediction Length (Future)", value=96)
            freq_display = gr.Textbox(label="Detected Frequency", interactive=False)
            run_btn = gr.Button("🚀 Run Forecast", variant="primary")
            gr.Markdown("### 2. Sample Selection")
            sample_index = gr.Slider(
                label="Sample Index",
                minimum=0,
                maximum=1000000,  # placeholder; clamped to the dataset's real maximum after the first run
                step=1,
                value=1000000,
                info="Drag the slider to select different starting points from the dataset for prediction."
            )
        with gr.Column(scale=3):
            gr.Markdown("### 3. Prediction Results")
            ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
            gr.Markdown(
                """
                **Plot Explanation:**
                - **⚫ Black Line:** Ground truth data. The left side is the input context, and the right side is the actual future value.
                - **🔴 Red Line:** The model's median prediction for the future.
                - **🔵 Blue Shaded Areas:** Represent the model's uncertainty. The darker the blue, the wider the prediction interval, indicating more uncertainty.
                """
            )
            with gr.Row():
                input_img_plot = gr.Plot(label="Input as Image")
                recon_img_plot = gr.Plot(label="Reconstructed Image")
            gr.Markdown(
                """
                **Image Explanation:**
                - **Input as Image:** The historical time series data (look-back window) transformed into the image format that the VisionTS++ model uses as input.
                - **Reconstructed Image:** The model's internal reconstruction of the input image. This helps visualize which features the model is focusing on.
                """
            )
            download_csv = gr.File(label="Download Prediction CSV")
            gr.Markdown(
                """
                **Download Prediction CSV:**
                - You can download the prediction results of VisionTS++ here.
                """
            )
    def toggle_upload_visibility(choice):
        return gr.update(visible=(choice == "Upload CSV"))

    data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)

    inputs = [data_source, upload_file, sample_index, context_len, pred_len, session_id_state]
    outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
    run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
    sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
# ========================
# 6. Main Execution Block
# ========================
if __name__ == "__main__":
    # Run an initial cleanup on startup.
    cleanup_old_sessions()
    # Start the periodic cleanup in a background daemon thread (exits with the main process).
    cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
    cleanup_thread.start()
    # Launch the Gradio app.
    demo.launch(debug=True)