# VisionTSpp / app.py
# Last commit by Lefei: ead362b "update app.py" (verified)
# File size: 16.3 kB
# app.py
import os
import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import einops
from huggingface_hub import snapshot_download
from visionts import VisionTSpp, freq_to_seasonality_list
# ========================
# 1. Configuration
# ========================
# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# Hub repository hosting the VisionTS++ checkpoint, and the local mirror of it.
REPO_ID = "Lefei/VisionTSpp"
LOCAL_DIR = "./hf_models/VisionTSpp"
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
# Backbone architecture name handed to VisionTSpp.
ARCH = 'mae_base'
# Fetch the repository snapshot only when the checkpoint file is not on disk yet;
# later starts reuse the local copy.
if not os.path.exists(CKPT_PATH):
    os.makedirs(LOCAL_DIR, exist_ok=True)
    print("Downloading model from Hugging Face Hub...")
    snapshot_download(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
        # Real files instead of cache symlinks (deprecated/ignored by recent huggingface_hub).
        local_dir_use_symlinks=False,
    )
# Load the model
# NOTE: We assume the model was trained to predict these specific quantiles
QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
model = VisionTSpp(
    ARCH,
    ckpt_path=CKPT_PATH,
    quantiles=QUANTILES,  # Set the quantiles the model should predict
    clip_input=True,
    complete_no_clip=False,
    color=True,
).to(DEVICE)
# The app only runs inference. Switch training-mode layers (e.g. dropout) off;
# the torch.no_grad() used at call sites disables autograd only, not these layers.
model.eval()
print(f"Model loaded on {DEVICE}")
# Per-channel ImageNet mean/std, used to undo the image normalization when
# rendering the model's input/reconstructed images.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
# ========================
# 2. Preset Datasets
# ========================
_ETDATASET_RAW_URL = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main"
# Display name -> CSV download URL.
# NOTE(review): the ETDataset repository appears to ship only the ETT-small/*.csv
# files; the "Illness" and "Weather" paths below may not exist there (HTTP 404).
# Verify these URLs.
PRESET_DATASETS = {
    "ETTm1 (15-min)": f"{_ETDATASET_RAW_URL}/ETT-small/ETTm1.csv",
    "ETTh1 (1-hour)": f"{_ETDATASET_RAW_URL}/ETT-small/ETTh1.csv",
    "Illness": f"{_ETDATASET_RAW_URL}/illness.csv",
    "Weather": f"{_ETDATASET_RAW_URL}/weather.csv",
}
# Downloaded presets are cached in this directory, so each one is fetched only once.
PRESET_DIR = "./preset_data"
os.makedirs(PRESET_DIR, exist_ok=True)
def load_preset_data(name):
    """
    Return the preset dataset ``name`` as a DataFrame, downloading and caching it on first use.

    The cache file is ``PRESET_DIR/<first word of name>.csv`` (e.g. "ETTm1 (15-min)" ->
    "ETTm1.csv"). A download is first written to a temporary file and then renamed into
    place. An interrupted write therefore cannot leave a truncated cache file that would
    be reused on every later call.

    Args:
        name: A key of ``PRESET_DATASETS``.

    Returns:
        The dataset as a ``pandas.DataFrame``.

    Raises:
        gr.Error: if ``name`` is not a known preset or the download/parse fails.
    """
    url = PRESET_DATASETS.get(name)
    if url is None:
        raise gr.Error(f"Unknown preset dataset: {name!r}")
    # Sanitize name for file path: keep only the first word ("ETTm1 (15-min)" -> "ETTm1").
    path = os.path.join(PRESET_DIR, f"{name.split(' ')[0]}.csv")
    if os.path.exists(path):
        return pd.read_csv(path)

    print(f"Downloading preset dataset: {name}...")
    try:
        df = pd.read_csv(url)
    except (OSError, ValueError) as err:  # URL/HTTP errors are OSError; parse errors are ValueError
        raise gr.Error(f"Failed to download preset dataset '{name}' from {url}: {err}") from err
    tmp_path = f"{path}.tmp"
    df.to_csv(tmp_path, index=False)
    os.replace(tmp_path, path)  # atomic rename: the cache is either complete or absent
    return df
# ========================
# 3. Visualization Functions
# ========================
def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
    """
    Render a model image tensor as a matplotlib Figure, undoing ImageNet normalization.

    The image is split into ``cur_nvars`` equal-height horizontal strips, one per
    variable. For strip ``i``, only channel ``cur_color_list[i]`` is un-normalized and
    kept; all other channels stay black. If the height is not divisible by
    ``cur_nvars``, the leftover bottom rows are left black.

    Args:
        image_tensor: Channel-first tensor ``[C, H, W]``, or None.
        title: Figure title.
        cur_nvars: Number of variables stacked vertically in the image (>= 1).
        cur_color_list: Channel index for each variable. Defaults to ``i % 3``, the
            same assignment ``predict_at_index`` passes to the model.

    Returns:
        A closed matplotlib Figure (for ``gr.Plot``), or None if ``image_tensor`` is None.

    Raises:
        ValueError: if ``cur_nvars`` is smaller than 1.
    """
    if image_tensor is None:
        return None
    if cur_nvars < 1:
        raise ValueError(f"cur_nvars must be >= 1, got {cur_nvars}")
    if cur_color_list is None:
        # Previously `None[i]` raised a TypeError; fall back to cycling channels.
        cur_color_list = [i % 3 for i in range(cur_nvars)]

    # imshow expects [H, W, C]; the caller passes [C, H, W].
    image = image_tensor.permute(1, 2, 0).cpu()
    canvas = torch.zeros_like(image)
    strip_height = image.shape[0] // cur_nvars
    for i in range(cur_nvars):
        channel = cur_color_list[i]
        rows = slice(i * strip_height, (i + 1) * strip_height)
        # Un-normalize only this variable's color channel and scale to 0-255.
        canvas[rows, :, channel] = (
            image[rows, :, channel] * imagenet_std[channel] + imagenet_mean[channel]
        ) * 255
    pixels = torch.clamp(canvas, 0, 255).int().numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pixels)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    fig.tight_layout()
    plt.close(fig)  # Close to prevent double display; Gradio renders the returned Figure
    return fig
def _quantile_bands(model_quantiles, pred_quantiles):
    """
    Pair non-median quantile forecasts into symmetric (low, high) bands, widest first.

    ``model_quantiles`` may match ``pred_quantiles`` one-to-one. It may instead contain
    one extra 0.5 entry for a median forecast that is passed separately; that level is
    dropped. Any prediction whose level is 0.5 is skipped.

    Returns:
        List of ``(q_low, q_high, pred_low, pred_high)``. The lowest level is paired
        with the highest, the second lowest with the second highest, and so on. With an
        odd number of non-median levels, the middle one is unused.

    Raises:
        ValueError: if the number of levels does not match the number of predictions.
    """
    levels = list(model_quantiles)
    if len(levels) == len(pred_quantiles) + 1 and 0.5 in levels:
        levels.remove(0.5)
    if len(levels) != len(pred_quantiles):
        raise ValueError(
            f"Expected one quantile level per prediction, got {len(levels)} levels "
            f"for {len(pred_quantiles)} predictions."
        )
    ordered = sorted(
        ((q, pred) for q, pred in zip(levels, pred_quantiles) if q != 0.5),
        key=lambda item: item[0],
    )
    return [
        (ordered[j][0], ordered[-(j + 1)][0], ordered[j][1], ordered[-(j + 1)][1])
        for j in range(len(ordered) // 2)
    ]


def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
    """
    Plot ground truth, the median forecast, and quantile uncertainty bands for each variable.

    Args:
        true_data: ``[context_len + pred_len, nvars]`` ground-truth series (array or tensor).
        pred_median: ``[pred_len, nvars]`` median forecast.
        pred_quantiles_list: Non-median quantile forecasts, each ``[pred_len, nvars]``.
            The list is not modified.
        model_quantiles: Quantile level for each entry of ``pred_quantiles_list``. It may
            also include 0.5 for the separately passed median, e.g. ``[0.1, ..., 0.9]``.
        context_len: Number of history steps; the forecast starts at this x position.
        pred_len: Number of forecast steps.

    Returns:
        A closed matplotlib Figure with one subplot per variable.

    Raises:
        ValueError: if ``model_quantiles`` cannot be matched to ``pred_quantiles_list``.
    """
    def _as_numpy(x):
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    true_data = _as_numpy(true_data)
    pred_median = _as_numpy(pred_median)
    # Convert into a new list so the caller's list is left untouched.
    bands = _quantile_bands(model_quantiles, [_as_numpy(q) for q in pred_quantiles_list])
    num_bands = len(bands)

    nvars = true_data.shape[1]
    FIG_WIDTH = 15
    FIG_HEIGHT_PER_VAR = 2.0
    fig, axes = plt.subplots(
        nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True, squeeze=False
    )
    axes = axes[:, 0]
    # Widest band (drawn first) gets the lightest blue; narrower bands get darker blues.
    quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))
    pred_range = np.arange(context_len, context_len + pred_len)

    for i, ax in enumerate(axes):
        ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
        ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
        for (q_low, q_high, pred_low, pred_high), color in zip(bands, quantile_colors):
            ax.fill_between(
                pred_range, pred_low[:, i], pred_high[:, i],
                color=color, alpha=0.7,
                # round(): int() would truncate values such as 0.57 * 100 == 56.999...
                label=f'{round(q_low * 100)}-{round(q_high * 100)}% Quantile',
            )
        # Dashed separator between history and forecast.
        y_min, y_max = ax.get_ylim()
        ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
        ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.margins(x=0)

    # Every subplot shares the same labels, so one de-duplicated figure legend suffices.
    handles, labels = axes[0].get_legend_handles_labels()
    unique_labels = dict(zip(labels, handles))
    fig.legend(
        unique_labels.values(), unique_labels.keys(),
        loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2,
    )
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.close(fig)
    return fig
# ========================
# 4. Prediction Logic
# ========================
class PredictionResult:
    """Everything a single forecast run produces, bundled for ``run_forecast``."""

    def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
        # The three figures shown in the UI: forecast plot, input image, reconstruction.
        self.ts_fig, self.input_img_fig, self.recon_img_fig = ts_fig, input_img_fig, recon_img_fig
        # Location of the saved forecast CSV.
        self.csv_path = csv_path
        # Number of valid sliding-window samples (used to size the sample slider).
        self.total_samples = total_samples
def _prepare_series(df):
    """Validate ``df`` and return (date-sorted, date-indexed copy, float array of its numeric columns)."""
    if 'date' not in df.columns:
        raise gr.Error("❌ Input CSV must contain a 'date' column.")
    # Work on a copy so the caller's DataFrame is not mutated.
    df = df.copy()
    try:
        df['date'] = pd.to_datetime(df['date'])
    except Exception as err:
        raise gr.Error("❌ The 'date' column could not be parsed. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).") from err
    df = df.sort_values('date').set_index('date')
    # Each numeric column becomes one variable; non-numeric columns are ignored.
    data = df.select_dtypes(include=np.number).to_numpy(dtype=np.float64)
    if data.shape[1] == 0:
        raise gr.Error("❌ Input CSV must contain at least one numeric column besides 'date'.")
    return df, data


def _save_prediction_csv(time_index, y_true, pred_median, csv_path="outputs/prediction_result.csv"):
    """Write per-variable ground truth and median forecast to ``csv_path`` and return the path."""
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    result_data = {'date': time_index}
    for i in range(y_true.shape[1]):
        result_data[f'True_Var{i+1}'] = y_true[:, i]
        result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
    pd.DataFrame(result_data).to_csv(csv_path, index=False)
    return csv_path


def predict_at_index(df, index, context_len, pred_len, freq):
    """
    Run a full forecast for the sliding-window sample that starts at row ``index``.

    Every numeric column is z-normalized with statistics from the first 70% of rows.
    The ``context_len`` rows starting at ``index`` are fed to the global ``model``, and
    its quantile forecasts are un-normalized. The function then builds the figures and
    writes the median forecast to ``outputs/prediction_result.csv``.

    Args:
        df: Input table with a ``date`` column. It is not modified.
        index: Requested sample index, clamped into ``[0, total_samples - 1]``.
        context_len: Number of history rows given to the model (must be > 0).
        pred_len: Number of future rows to forecast (must be > 0).
        freq: Frequency string passed to ``freq_to_seasonality_list``.

    Returns:
        ``PredictionResult`` with the three figures, the CSV path and ``total_samples``.

    Raises:
        gr.Error: if any of these hold:
            - ``context_len`` or ``pred_len`` is not positive;
            - the ``date`` column is missing or cannot be parsed;
            - there are no numeric columns;
            - there are too few rows;
            - the model's quantiles lack the median (0.5).
    """
    if context_len <= 0 or pred_len <= 0:
        raise gr.Error(f"❌ Context length and prediction length must be positive (got {context_len} and {pred_len}).")
    df, data = _prepare_series(df)
    nvars = data.shape[1]
    total_samples = len(data) - context_len - pred_len + 1
    if total_samples <= 0:
        raise gr.Error(f"Data is too short. It needs at least {context_len + pred_len} rows, but has {len(data)}.")
    # Clamp index to valid range, defaulting to the last sample
    index = max(0, min(index, total_samples - 1))

    # Normalize data (simple train/test split for mean/std)
    train_len = int(len(data) * 0.7)
    x_mean = data[:train_len].mean(axis=0, keepdims=True)
    x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8  # epsilon guards constant columns
    data_norm = (data - x_mean) / x_std

    context_end = index + context_len
    window_end = context_end + pred_len
    x_tensor = torch.as_tensor(data_norm[index:context_end], dtype=torch.float32).unsqueeze(0).to(DEVICE)

    # Configure model and run prediction
    periodicity_list = freq_to_seasonality_list(freq)
    periodicity = periodicity_list[0] if periodicity_list else 1
    color_list = [i % 3 for i in range(nvars)]  # cycle variables through channels 0, 1, 2
    model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
    with torch.no_grad():
        y_pred, input_image, reconstructed_image, _, _ = model.forward(
            x_tensor, export_image=True, color_list=color_list
        )

    # Assumes y_pred holds one prediction per entry of model.quantiles, in that order.
    all_preds = dict(zip(model.quantiles, y_pred))
    if 0.5 not in all_preds:
        raise gr.Error(f"❌ The model's quantiles {list(model.quantiles)} do not include the median (0.5).")
    pred_median_norm = all_preds.pop(0.5)[0]  # drop the batch dim -> [pred_len, nvars]
    quantile_levels = list(all_preds)
    pred_quantiles_norm = [q[0] for q in all_preds.values()]  # each [pred_len, nvars]

    # Un-normalize the forecasts. Take the ground truth straight from the raw data
    # instead of round-tripping it through the normalization.
    y_true = data[context_end:window_end]
    pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
    pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
    full_true_series = data[index:window_end]

    # === Visualization ===
    ts_fig = visual_ts_with_quantiles(
        true_data=full_true_series,
        pred_median=pred_median,
        pred_quantiles_list=pred_quantiles,
        model_quantiles=quantile_levels,  # Quantiles without median
        context_len=context_len,
        pred_len=pred_len,
    )
    input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
    recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)

    # === Save CSV ===
    csv_path = _save_prediction_csv(df.index[context_end:window_end], y_true, pred_median)
    return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
# ========================
# 5. Gradio Interface
# ========================
def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
    """
    Gradio callback: load the selected data, forecast one sample and return the UI updates.

    Returns:
        A 5-tuple matching the UI ``outputs``:
            - the time-series figure;
            - the input-image figure;
            - the reconstructed-image figure;
            - the CSV path;
            - a ``gr.update`` for the sample slider (new maximum and the index actually
              used).
        If loading or forecasting fails, the first slot is an error figure and the
        others are empty.

    Raises:
        gr.Error: if "Upload CSV" is selected but no file was provided.
    """
    import traceback

    if data_source == "Upload CSV" and upload_file is None:
        raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
    try:
        if data_source == "Upload CSV":
            # Gradio 4+ passes a file path string; Gradio 3 passes a tempfile wrapper
            # whose `.name` is the path. getattr covers both.
            df = pd.read_csv(getattr(upload_file, "name", upload_file))
        else:
            df = load_preset_data(data_source)
        # Cast inputs to correct types
        index, context_len, pred_len = int(index), int(context_len), int(pred_len)
        result = predict_at_index(df, index, context_len, pred_len, freq)
    except Exception as e:
        # UI boundary: log the traceback for debugging, then show the message in the plot area.
        traceback.print_exc()
        error_fig, ax = plt.subplots(figsize=(10, 5))
        ax.text(
            0.5, 0.5, f"An error occurred:\n{str(e)}",
            ha='center', va='center', wrap=True, color='red', fontsize=12,
            transform=ax.transAxes,
        )
        ax.axis('off')
        plt.close(error_fig)
        # Return empty plots and no file
        return error_fig, None, None, None, gr.update()

    # Clamp exactly like predict_at_index so the slider shows the sample that was used.
    # On the first run, the oversized initial slider value snaps to the last sample.
    last_index = result.total_samples - 1
    final_index = max(0, min(index, last_index))
    return (
        result.ts_fig,
        result.input_img_fig,
        result.recon_img_fig,
        result.csv_path,
        gr.update(maximum=last_index, value=final_index),  # Update slider
    )
# UI Layout
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🕰️ VisionTS++: Multivariate Time Series Forecasting")
    gr.Markdown(
        """
An interactive platform to explore time series forecasting using the VisionTS++ model.
- ✅ **Select** from preset datasets or **upload** your own.
- ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
- ✅ **Inspect** the model's internal "image" representation of the time series.
- ✅ **Slide** through different samples of the dataset for real-time forecasting.
- ✅ **Download** the prediction results as a CSV file.
"""
    )
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Data & Model Configuration")
            data_source = gr.Dropdown(
                label="Select Data Source",
                choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
                value="ETTm1 (15-min)"
            )
            # Hidden until "Upload CSV" is selected (see toggle_upload_visibility below).
            upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
            gr.Markdown(
                """
**Upload Rules:**
1. Must be a `.csv` file.
2. Must contain a time column named `date`.
"""
            )
            context_len = gr.Number(label="Context Length (History)", value=336)
            pred_len = gr.Number(label="Prediction Length (Future)", value=96)
            freq = gr.Textbox(label="Frequency (e.g., 15Min, H, D)", value="15Min")
            run_btn = gr.Button("🚀 Run Forecast", variant="primary")
            gr.Markdown("### 2. Sample Selection")
            # Set a high initial value to default to the last sample on first run.
            # NOTE(review): the initial value (10000) exceeds `maximum` (1000). The browser
            # slider may clamp it to 1000, in which case the first run uses sample 1000
            # rather than the last one. Confirm against the installed Gradio version.
            sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=1000, step=1, value=10000)
        with gr.Column(scale=3):
            gr.Markdown("### 3. Prediction Results")
            ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
            with gr.Row():
                input_img_plot = gr.Plot(label="Input as Image")
                recon_img_plot = gr.Plot(label="Reconstructed Image")
            download_csv = gr.File(label="Download Prediction CSV")

    # --- Event Handlers ---
    # Show/hide upload button based on data source
    def toggle_upload_visibility(choice):
        return gr.update(visible=(choice == "Upload CSV"))

    data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)

    # Inputs/outputs in the exact order of run_forecast's parameters and return tuple.
    inputs = [data_source, upload_file, sample_index, context_len, pred_len, freq]
    outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
    # Trigger forecast on button click
    run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
    # Trigger forecast when the slider is released
    sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")

    # Examples
    gr.Examples(
        examples=[
            ["ETTm1 (15-min)", None, 0, 336, 96, "15Min"],
            ["Illness", None, 0, 36, 24, "D"],
            ["Weather", None, 0, 96, 192, "H"]
        ],
        inputs=inputs,
        fn=run_forecast,  # The button click will trigger the run
        outputs=outputs,
        label="Click an example to load configuration, then click 'Run Forecast'"
    )
demo.launch(debug=True)