# app.py
import os

# Honour a GRADIO_TEMP_DIR already set in the environment and only fall back to
# the original hard-coded location otherwise. Kept above the gradio import
# because gradio may read this variable when it is imported.
os.environ.setdefault("GRADIO_TEMP_DIR", "/home/mouxiangchen/VisionTSpp/gradio_tmp")

import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import einops
import copy
from huggingface_hub import snapshot_download
from visionts import VisionTSpp, freq_to_seasonality_list

# ========================
# 1. Configuration
# ========================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
REPO_ID = "Lefei/VisionTSpp"
LOCAL_DIR = "./hf_models/VisionTSpp"
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
ARCH = 'mae_base'

# Download the model from Hugging Face Hub only when the checkpoint is missing.
if not os.path.exists(CKPT_PATH):
    os.makedirs(LOCAL_DIR, exist_ok=True)
    print("Downloading model from Hugging Face Hub...")
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)

# Quantile levels used to label the model's probabilistic outputs; they are not
# passed to the model itself. NOTE(review): the prediction code assumes the
# model's non-median quantile outputs are exactly these levels minus 0.5, in
# ascending order -- confirm against VisionTSpp.
QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Load the model
model = VisionTSpp(
    ARCH,
    ckpt_path=CKPT_PATH,
    quantile=True,
    clip_input=True,
    complete_no_clip=False,
    color=True,
).to(DEVICE)
print(f"Model loaded on {DEVICE}")

# ImageNet per-channel normalization constants, used below to un-normalize the
# model's images for display.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

# ========================
# 2. Preset Datasets (Loaded Locally)
# ========================
# Maps user-friendly names to local CSV paths.
# ASSUMPTION: these files exist in the 'datasets' subfolder.
data_dir = "./datasets/"
PRESET_DATASETS = {
    name: data_dir + f"{name}.csv"
    for name in ("ETTm1", "ETTm2", "ETTh1", "ETTh2", "Illness", "Weather")
}


def load_preset_data(name):
    """Load the preset dataset registered under ``name`` as a DataFrame.

    Raises:
        KeyError: if ``name`` is not a key of PRESET_DATASETS.
        FileNotFoundError: if the registered CSV file does not exist.
    """
    path = PRESET_DATASETS[name]
    if not os.path.exists(path):
        raise FileNotFoundError(f"Preset dataset file not found: {path}. Make sure it's uploaded to the 'datasets' folder.")
    return pd.read_csv(path)


# ========================
# 3. Visualization Functions
# ========================
def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
    """Render one model image (input or reconstruction) as a matplotlib figure.

    The image is split into ``cur_nvars`` horizontal bands of
    ``image.shape[0] // cur_nvars`` rows, one per variable. For each band only
    the channel listed for that variable in ``cur_color_list`` is
    un-normalized with the ImageNet mean/std and drawn; all other channels
    stay 0.

    Args:
        image_tensor: tensor indexed as (rows, cols, channel), channel axis
            last (assumed 3 channels -- TODO confirm), or None.
        title: figure title.
        cur_nvars: number of variables stacked vertically in the image.
        cur_color_list: channel index per variable. Defaults to ``i % 3``.

    Returns:
        A closed matplotlib Figure, or None if ``image_tensor`` is None.
    """
    if image_tensor is None:
        return None
    if cur_color_list is None:
        # Previously a None default crashed on indexing; use round-robin channels.
        cur_color_list = [i % 3 for i in range(cur_nvars)]

    image = image_tensor.cpu()
    cur_image = torch.zeros_like(image)
    height_per_var = image.shape[0] // cur_nvars
    for i in range(cur_nvars):
        color_idx = cur_color_list[i]
        rows = slice(i * height_per_var, (i + 1) * height_per_var)
        channel = image[rows, :, color_idx] * imagenet_std[color_idx] + imagenet_mean[color_idx]
        cur_image[rows, :, color_idx] = channel * 255
    cur_image = torch.clamp(cur_image, 0, 255).int().numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cur_image)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig


def _to_numpy(x):
    """Return ``x`` as a NumPy array if it is a torch tensor, otherwise unchanged."""
    if isinstance(x, torch.Tensor):
        return x.cpu().numpy()
    return x


def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
    """Plot ground truth, the median forecast and symmetric quantile bands per variable.

    Args:
        true_data: 2-D array (time, nvars): context followed by the true future.
        pred_median: 2-D array (pred_len, nvars) median forecast.
        pred_quantiles_list: quantile forecasts, each (pred_len, nvars);
            ``pred_quantiles_list[k]`` belongs to level ``model_quantiles[k]``.
            The list passed in is not modified.
        model_quantiles: quantile levels matching ``pred_quantiles_list``. A
            0.5 entry, if present, is ignored because the median is drawn from
            ``pred_median``.
        context_len: number of context steps; the forecast starts at this x.
        pred_len: number of forecast steps.

    Returns:
        A closed matplotlib Figure with one subplot per variable.
    """
    true_data = _to_numpy(true_data)
    pred_median = _to_numpy(pred_median)
    # Build a new list so the caller's list is left untouched.
    pred_quantiles = [_to_numpy(q) for q in pred_quantiles_list]

    # Pair every level with its own prediction, then sort by level so that band
    # j spans the j-th lowest and j-th highest quantile.
    pairs = sorted(
        ((q, p) for q, p in zip(model_quantiles, pred_quantiles) if q != 0.5),
        key=lambda item: item[0],
    )
    quantile_vals = [q for q, _ in pairs]
    quantile_preds = [p for _, p in pairs]
    num_bands = len(quantile_preds) // 2
    quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]

    nvars = true_data.shape[1]
    FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
    fig, axes = plt.subplots(
        nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True, squeeze=False
    )
    axes = axes[:, 0]  # always a 1-D array of Axes, even when nvars == 1

    pred_range = np.arange(context_len, context_len + pred_len)
    for i, ax in enumerate(axes):
        ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
        ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
        for j in range(num_bands):
            q_low, q_high = quantile_vals[j], quantile_vals[-(j + 1)]
            # round() instead of int(): int(0.29 * 100) would truncate to 28.
            ax.fill_between(
                pred_range,
                quantile_preds[j][:, i],
                quantile_preds[-(j + 1)][:, i],
                color=quantile_colors[j],
                alpha=0.7,
                label=f'{round(q_low * 100)}-{round(q_high * 100)}% Quantile',
            )
        # Dashed marker at the boundary between context and forecast.
        y_min, y_max = ax.get_ylim()
        ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
        ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.margins(x=0)

    # One shared legend, de-duplicated by label.
    handles, labels = axes[0].get_legend_handles_labels()
    unique_labels = dict(zip(labels, handles))
    fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.close(fig)
    return fig
# ========================
# 4. Prediction Logic
# ========================
class PredictionResult:
    """Bundle of everything one forecast run produces for the UI.

    Attributes are stored exactly as passed to the constructor.
    """

    def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq):
        self.ts_fig = ts_fig                # forecast plot with quantile bands
        self.input_img_fig = input_img_fig  # figure of the model's input image
        self.recon_img_fig = recon_img_fig  # figure of the reconstructed image
        self.csv_path = csv_path            # path of the written prediction CSV
        self.total_samples = total_samples  # number of valid window start indices
        self.inferred_freq = inferred_freq  # pandas frequency string of the 'date' index


def _prepare_time_index(df):
    """Return a copy of ``df`` sorted and indexed by 'date', plus its frequency string.

    Raises:
        gr.Error: if there is no 'date' column or the dates cannot be processed.
    """
    if 'date' not in df.columns:
        raise gr.Error("❌ Input CSV must contain a 'date' column.")
    try:
        # assign() returns a new frame, so the caller's DataFrame is left untouched.
        df = df.assign(date=pd.to_datetime(df['date']))
        df = df.sort_values('date').set_index('date')
        inferred_freq = pd.infer_freq(df.index)
        if inferred_freq is None:
            # Fallback if inference fails: use the spacing of the first two timestamps.
            time_diff = df.index[1] - df.index[0]
            inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
            gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
        print(f"Inferred frequency: {inferred_freq}")
    except Exception as e:
        raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).") from e
    return df, inferred_freq


def _write_prediction_csv(time_index, y_true, pred_median, csv_path="outputs/prediction_result.csv"):
    """Write dates plus true/median-forecast columns per variable to ``csv_path``; return the path."""
    out_dir = os.path.dirname(csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    result_data = {'date': time_index}
    for i in range(y_true.shape[1]):
        result_data[f'True_Var{i+1}'] = y_true[:, i]
        result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
    pd.DataFrame(result_data).to_csv(csv_path, index=False)
    return csv_path


def predict_at_index(df, index, context_len, pred_len):
    """Forecast one sliding window of ``df`` and build all UI artifacts.

    The window starts at ``index``, clamped to the valid range. The model
    receives ``context_len`` rows, and the following ``pred_len`` rows are
    the ground truth. Every numeric column is treated as one variable.
    Values are z-scored with the mean/std of the first 70% of rows before
    inference and mapped back afterwards.

    Args:
        df: DataFrame with a 'date' column and numeric value columns. Not modified.
        index: requested window start.
        context_len: number of history rows fed to the model; must be > 0.
        pred_len: number of future rows to forecast; must be > 0.

    Returns:
        PredictionResult with the figures, the CSV path, the number of valid
        windows and the inferred frequency.

    Raises:
        gr.Error: on any of the following:
            - missing or unparsable dates;
            - non-positive lengths;
            - no numeric columns;
            - data shorter than context_len + pred_len;
            - a model output whose quantile count does not match QUANTILES.
    """
    df, inferred_freq = _prepare_time_index(df)

    if context_len <= 0 or pred_len <= 0:
        raise gr.Error(f"❌ Context length and prediction length must be positive, got {context_len} and {pred_len}.")

    data = df.select_dtypes(include=np.number).values
    nvars = data.shape[1]
    if nvars == 0:
        raise gr.Error("❌ Input CSV has no numeric columns to forecast.")

    total_samples = len(data) - context_len - pred_len + 1
    if total_samples <= 0:
        raise gr.Error(f"Data is too short. It needs at least context_len + pred_len = {context_len + pred_len} rows, but has {len(data)}.")
    index = max(0, min(index, total_samples - 1))

    # Normalize with statistics of the first 70% of rows; the epsilon keeps
    # constant columns from dividing by zero.
    train_len = int(len(data) * 0.7)
    x_mean = data[:train_len].mean(axis=0, keepdims=True)
    x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
    data_norm = (data - x_mean) / x_std

    start_idx = index
    x_norm = data_norm[start_idx:start_idx + context_len]
    y_true_norm = data_norm[start_idx + context_len:start_idx + context_len + pred_len]
    # Add a batch dimension: (1, context_len, nvars).
    x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)

    # Take the first entry freq_to_seasonality_list returns, or 1 if it returns none.
    periodicity_list = freq_to_seasonality_list(inferred_freq)
    periodicity = periodicity_list[0] if periodicity_list else 1
    # Assign each variable a colour channel index 0/1/2 in round-robin order.
    color_list = [i % 3 for i in range(nvars)]

    model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
    with torch.no_grad():
        (y_pred, y_pred_quantile_list), input_image, reconstructed_image, _, _ = model.forward(
            x_tensor, export_image=True, color_list=color_list
        )

    # y_pred is used as the 0.5 forecast; slot it back in at 0.5's position so
    # the list lines up with QUANTILES. NOTE(review): assumes
    # y_pred_quantile_list holds the remaining levels in ascending order --
    # confirm against VisionTSpp.
    all_y_pred_list = list(y_pred_quantile_list)  # shallow copy suffices: we only insert
    all_y_pred_list.insert(QUANTILES.index(0.5), y_pred)
    if len(all_y_pred_list) != len(QUANTILES):
        # zip() would silently truncate and mislabel the quantile bands.
        raise gr.Error(f"❌ Model returned {len(all_y_pred_list)} quantile forecasts, expected {len(QUANTILES)}.")
    all_preds = dict(zip(QUANTILES, all_y_pred_list))
    pred_median_norm = all_preds.pop(0.5)[0]  # [0] selects the single batch item
    pred_quantiles_norm = [q[0] for q in all_preds.values()]

    # Undo the z-score normalization.
    y_true = y_true_norm * x_std + x_mean
    pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
    pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]

    full_true_context = data[start_idx:start_idx + context_len]
    full_true_series = np.concatenate([full_true_context, y_true], axis=0)

    ts_fig = visual_ts_with_quantiles(
        true_data=full_true_series,
        pred_median=pred_median,
        pred_quantiles_list=pred_quantiles,
        model_quantiles=list(all_preds.keys()),
        context_len=context_len,
        pred_len=pred_len,
    )
    input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
    recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)

    time_index = df.index[start_idx + context_len:start_idx + context_len + pred_len]
    csv_path = _write_prediction_csv(time_index, y_true, pred_median)

    return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq)
# ========================
# 5. Gradio Interface
# ========================
def run_forecast(data_source, upload_file, index, context_len, pred_len):
    """Gradio callback: load the selected data, forecast one window and return UI updates.

    Returns a 6-tuple matching ``outputs`` below:
        1. forecast figure;
        2. input-image figure;
        3. reconstructed-image figure;
        4. CSV path;
        5. slider update;
        6. frequency textbox update.
    Any failure after the upload check, including unreadable files, is
    rendered as an error figure instead of being raised.

    Raises:
        gr.Error: when "Upload CSV" is selected but no file was provided.
    """
    if data_source == "Upload CSV" and upload_file is None:
        raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")

    try:
        if data_source == "Upload CSV":
            # Accept either a file object exposing `.name` or a plain path string.
            df = pd.read_csv(getattr(upload_file, "name", upload_file))
        else:
            df = load_preset_data(data_source)

        index, context_len, pred_len = int(index), int(context_len), int(pred_len)
        result = predict_at_index(df, index, context_len, pred_len)
        # Mirror predict_at_index's clamping so the slider shows the window actually used.
        final_index = max(0, min(index, result.total_samples - 1))
        return (
            result.ts_fig,
            result.input_img_fig,
            result.recon_img_fig,
            result.csv_path,
            gr.update(maximum=result.total_samples - 1, value=final_index),
            gr.update(value=result.inferred_freq),
        )
    except Exception as e:
        import traceback

        traceback.print_exc()  # keep the full stack trace in the server log
        error_fig, ax = plt.subplots(figsize=(10, 5))
        ax.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
        ax.axis('off')
        plt.close(error_fig)
        return error_fig, None, None, None, gr.update(), gr.update(value="Error")


# UI Layout
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🕰️ VisionTS++: Multivariate Time Series Forecasting")
    gr.Markdown(
        """
        An interactive platform to explore time series forecasting using the VisionTS++ model.

        - ✅ **Select** from local preset datasets or **upload** your own.
        - ✅ **Frequency is auto-detected** from the 'date' column.
        - ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
        - ✅ **Slide** through different samples of the dataset for real-time forecasting.
        - ✅ **Download** the prediction results as a CSV file.
        """
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Data & Model Configuration")
            data_source = gr.Dropdown(
                label="Select Data Source",
                choices=list(PRESET_DATASETS.keys()) + ["Upload CSV"],
                value="ETTh1",
            )
            upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
            gr.Markdown(
                """
                **Upload Rules:**

                1. Must be a `.csv` file.
                2. Must contain a time column named `date` with a consistent frequency.
                3. Every numeric column is forecast as a separate variable.
                """
            )
            context_len = gr.Number(label="Context Length (History)", value=336)
            pred_len = gr.Number(label="Prediction Length (Future)", value=96)
            # Read-only display of the frequency detected by the last run.
            freq_display = gr.Textbox(label="Detected Frequency", interactive=False)
            run_btn = gr.Button("🚀 Run Forecast", variant="primary")

            gr.Markdown("### 2. Sample Selection")
            # Starts at the maximum; each run resets maximum/value to the valid range.
            sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=100000, step=1, value=100000)

        with gr.Column(scale=3):
            gr.Markdown("### 3. Prediction Results")
            ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
            with gr.Row():
                input_img_plot = gr.Plot(label="Input as Image")
                recon_img_plot = gr.Plot(label="Reconstructed Image")
            download_csv = gr.File(label="Download Prediction CSV")

    # --- Event Handlers ---
    def toggle_upload_visibility(choice):
        """Show the upload widget only when 'Upload CSV' is selected."""
        return gr.update(visible=(choice == "Upload CSV"))

    data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)

    inputs = [data_source, upload_file, sample_index, context_len, pred_len]
    outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display]
    run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
    # Re-forecast whenever the user releases the slider.
    sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")

demo.launch(debug=True)