Lefei committed
Commit ead362b · verified · 1 parent: c1f4164

update app.py

Files changed (1)
  1. app.py +222 -160
app.py CHANGED
@@ -11,39 +11,40 @@ from huggingface_hub import snapshot_download
11
  from visionts import VisionTSpp, freq_to_seasonality_list
12
 
13
  # ========================
14
- # Configuration
15
  # ========================
16
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  REPO_ID = "Lefei/VisionTSpp"
18
  LOCAL_DIR = "./hf_models/VisionTSpp"
19
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
20
-
21
  ARCH = 'mae_base'
22
 
23
- # Download the model
24
  if not os.path.exists(CKPT_PATH):
25
  os.makedirs(LOCAL_DIR, exist_ok=True)
26
  print("Downloading model from Hugging Face Hub...")
27
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
28
 
29
- # Load the model
 
 
30
  model = VisionTSpp(
31
  ARCH,
32
  ckpt_path=CKPT_PATH,
33
- quantile=True,
34
  clip_input=True,
35
  complete_no_clip=False,
36
  color=True
37
  ).to(DEVICE)
38
  print(f"Model loaded on {DEVICE}")
39
 
40
- # Image normalization
41
  imagenet_mean = np.array([0.485, 0.456, 0.406])
42
  imagenet_std = np.array([0.229, 0.224, 0.225])
43
 
44
 
45
  # ========================
46
- # Preset datasets
47
  # ========================
48
  PRESET_DATASETS = {
49
  "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
@@ -52,15 +53,19 @@ PRESET_DATASETS = {
52
  "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv"
53
  }
54
 
55
- # Local cache path
56
  PRESET_DIR = "./preset_data"
57
  os.makedirs(PRESET_DIR, exist_ok=True)
58
 
59
 
60
  def load_preset_data(name):
 
61
  url = PRESET_DATASETS[name]
62
- path = os.path.join(PRESET_DIR, f"{name.split(' ')[0]}.csv")
 
 
63
  if not os.path.exists(path):
 
64
  df = pd.read_csv(url)
65
  df.to_csv(path, index=False)
66
  else:
@@ -69,91 +74,110 @@ def load_preset_data(name):
69
 
70
 
71
  # ========================
72
- # Visualization functions
73
  # ========================
74
 
75
- def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
76
  cur_image = torch.zeros_like(image)
77
  height_per_var = image.shape[0] // cur_nvars
 
 
78
  for i in range(cur_nvars):
79
- cur_color = cur_color_list[i]
80
- cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color] = \
81
- (image[i*height_per_var:(i+1)*height_per_var, :, cur_color] * imagenet_std[cur_color] + imagenet_mean[cur_color]) * 255
82
- cur_image = torch.clamp(cur_image, 0, 255).cpu().int()
83
 
84
  fig, ax = plt.subplots(figsize=(6, 6))
85
- ax.imshow(cur_image.numpy())
86
  ax.set_title(title, fontsize=14)
87
  ax.axis('off')
88
- plt.close(fig)
 
89
  return fig
90
 
91
 
92
- def visual_ts_with_quantiles(true, pred_median, pred_quantiles, lookback_len_visual=300, pred_len=96, quantile_colors=None):
93
  """
94
- Overlays multiple quantile bands on the visualization.
95
- pred_quantiles: list of [pred_len, nvars] tensors
 
96
  """
97
- if isinstance(true, torch.Tensor):
98
- true = true.cpu().numpy()
99
- if isinstance(pred_median, torch.Tensor):
100
- pred_median = pred_median.cpu().numpy()
101
- for i, q in enumerate(pred_quantiles):
102
  if isinstance(q, torch.Tensor):
103
- pred_quantiles[i] = q.cpu().numpy()
104
-
105
- nvars = true.shape[1]
106
- FIG_WIDTH = 12
107
- FIG_HEIGHT_PER_VAR = 1.8
108
- FONT_S = 10
109
-
110
- fig, axes = plt.subplots(
111
- nrows=nvars, ncols=1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True,
112
- gridspec_kw={'height_ratios': [1] * nvars}
113
- )
114
- if nvars == 1:
115
- axes = [axes]
116
-
117
- lookback_len = true.shape[0] - pred_len
118
-
119
- # Quantile colors (from outermost to innermost)
120
- if quantile_colors is None:
121
- quantile_colors = ['lightblue', 'skyblue', 'deepskyblue']
122
 
123
  for i, ax in enumerate(axes):
124
- ax.plot(true[:, i], label='Ground Truth', color='gray', linewidth=2)
125
- ax.plot(np.arange(lookback_len, len(true)), pred_median[lookback_len:, i],
126
- label='Prediction (Median)', color='blue', linewidth=2)
127
-
128
- # Draw quantile bands (from outermost to innermost)
129
- base = pred_median[lookback_len:]
130
- quantiles_sorted = sorted(zip(PREDS.quantiles, pred_quantiles), key=lambda x: x[0])
131
- for (q, pred_q), color in zip(quantiles_sorted, quantile_colors):
132
- upper = pred_q[lookback_len:]
133
- lower = 2 * base - upper # symmetric-interval assumption
 
 
134
  ax.fill_between(
135
- np.arange(lookback_len, len(true)),
136
- lower[:, i], upper[:, i],
137
- color=color, alpha=0.5, label=f'Quantile {q:.1f}'
138
  )
139
 
140
  y_min, y_max = ax.get_ylim()
141
- ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
142
- ax.set_yticks([])
143
- ax.set_xticks([])
144
- ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
145
 
146
  handles, labels = axes[0].get_legend_handles_labels()
147
- fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
148
- plt.subplots_adjust(hspace=0)
 
 
149
  plt.close(fig)
150
  return fig
151
 
152
 
153
  # ========================
154
- # Prediction result wrapper (for easy reuse)
155
  # ========================
156
  class PredictionResult:
 
157
  def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
158
  self.ts_fig = ts_fig
159
  self.input_img_fig = input_img_fig
@@ -162,175 +186,213 @@ class PredictionResult:
162
  self.total_samples = total_samples
163
 
164
 
165
- def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
166
- # === Data validation ===
 
167
  if 'date' not in df.columns:
168
- raise ValueError("❌ 数据集必须包含名为 'date' 的时间列。")
169
 
170
  try:
171
  df['date'] = pd.to_datetime(df['date'])
172
  except Exception:
173
- raise ValueError("❌ 'date' 列格式无法解析为时间,请检查日期格式。")
174
 
175
  df = df.sort_values('date').set_index('date')
176
- data = df.values
177
  nvars = data.shape[1]
178
 
179
  total_samples = len(data) - context_len - pred_len + 1
180
  if total_samples <= 0:
181
- raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {len(data)} 行。")
182
- if index >= total_samples:
183
- raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")
 
184
 
185
- # Normalize
186
  train_len = int(len(data) * 0.7)
187
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
188
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
189
  data_norm = (data - x_mean) / x_std
190
 
 
191
  start_idx = index
192
- x = data_norm[start_idx:start_idx + context_len]
193
- y_true = data_norm[start_idx + context_len:start_idx + context_len + pred_len]
 
194
 
 
195
  periodicity_list = freq_to_seasonality_list(freq)
196
  periodicity = periodicity_list[0] if periodicity_list else 1
197
  color_list = [i % 3 for i in range(nvars)]
198
-
199
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
200
 
201
- x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE)
202
-
203
  with torch.no_grad():
204
- y_pred, input_image, reconstructed_image, nvars_out, color_list_out = model.forward(
205
  x_tensor, export_image=True, color_list=color_list
206
  )
207
- pred_median, pred_quantiles = y_pred # list of quantiles
208
-
209
- # Un-normalize
210
- y_true_orig = y_true * x_std + x_mean
211
- pred_med_orig = pred_median[0].cpu().numpy() * x_std + x_mean
212
- pred_quants_orig = [q[0].cpu().numpy() * x_std + x_mean for q in pred_quantiles]
213
-
214
- # Full series
215
- full_true = np.concatenate([x * x_std + x_mean, y_true_orig], axis=0)
216
- full_pred_med = np.concatenate([x * x_std + x_mean, pred_med_orig], axis=0)
217
-
218
- # === Visualization ===
219
  ts_fig = visual_ts_with_quantiles(
220
- true=full_true,
221
- pred_median=full_pred_med,
222
- pred_quantiles=pred_quants_orig,
223
- lookback_len_visual=context_len,
 
224
  pred_len=pred_len
225
  )
226
-
227
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
228
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
229
 
230
- # === Save CSV ===
231
  os.makedirs("outputs", exist_ok=True)
232
  csv_path = "outputs/prediction_result.csv"
233
- time_index = df.index[start_idx:start_idx + context_len + pred_len]
234
- combined = np.concatenate([full_true, full_pred_med], axis=1) # [T, 2*nvars]
235
- col_names = [f"True_Var{i+1}" for i in range(nvars)] + [f"Pred_Var{i+1}" for i in range(nvars)]
236
- result_df = pd.DataFrame(combined, index=time_index, columns=col_names)
237
- result_df.to_csv(csv_path)
238
 
239
  return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
240
 
241
 
242
  # ========================
243
- # Gradio interface functions
244
  # ========================
245
  def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
 
246
  if data_source == "Upload CSV":
247
  if upload_file is None:
248
- raise ValueError("请上传一个 CSV 文件")
249
  df = pd.read_csv(upload_file.name)
250
  else:
251
  df = load_preset_data(data_source)
252
 
253
  try:
254
- result = predict_at_index(df, int(index), context_len=int(context_len), pred_len=int(pred_len), freq=freq)
255
  return (
256
  result.ts_fig,
257
  result.input_img_fig,
258
  result.recon_img_fig,
259
  result.csv_path,
260
- gr.update(maximum=result.total_samples - 1, value=min(index, result.total_samples - 1))
261
  )
262
  except Exception as e:
263
- fig_err = plt.figure(figsize=(6, 4))
264
- plt.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
 
265
  plt.axis('off')
266
- plt.close(fig_err)
267
- return fig_err, fig_err, fig_err, None, gr.update()
268
-
269
-
270
- # ========================
271
- # Gradio UI
272
- # ========================
273
- with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform") as demo:
274
- gr.Markdown("# 🕰️ VisionTS++ Multivariate Time Series Forecasting Platform")
275
- gr.Markdown("""
276
- - Supports preset datasets or locally uploaded files
277
- - ✅ Upload rules: the file must be a `.csv` and contain a `date` column
278
- - ✅ Displays multi-quantile prediction intervals
279
- - ✅ Prediction results can be downloaded
280
- - ✅ Slide through samples for real-time forecasting
281
- """)
282
-
 
 
283
  with gr.Row():
284
- with gr.Column(scale=2):
 
285
  data_source = gr.Dropdown(
286
- label="选择数据源",
287
  choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
288
  value="ETTm1 (15-min)"
289
  )
290
- upload_file = gr.File(label="Upload CSV file", file_types=['.csv'], visible=False)
291
- context_len = gr.Number(label="History length", value=960)
292
- pred_len = gr.Number(label="Prediction length", value=394)
293
- freq = gr.Textbox(label="Frequency (e.g., 15Min, H)", value="15Min")
294
- sample_index = gr.Slider(label="Sample index", minimum=0, maximum=100, step=1, value=0)
295
 
296
  with gr.Column(scale=3):
297
- ts_plot = gr.Plot(label="Time series forecast (with quantile bands)")
 
298
  with gr.Row():
299
- input_img_plot = gr.Plot(label="Input Image")
300
  recon_img_plot = gr.Plot(label="Reconstructed Image")
301
- download_csv = gr.File(label="Download prediction results")
302
 
303
- btn = gr.Button("🚀 Initial Run")
304
 
305
- # Toggle upload visibility
306
- def toggle_upload(choice):
307
- return gr.update(visible=choice == "Upload CSV")
308
 
309
- data_source.change(fn=toggle_upload, inputs=data_source, outputs=upload_file)
 
 
310
 
311
- # Both the initial run and slider changes trigger a forecast
312
- btn.click(
313
- fn=run_forecast,
314
- inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
315
- outputs=[ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
316
- )
317
 
318
- # [Key] Re-run the prediction whenever the slider changes
319
- sample_index.change(
320
- fn=run_forecast,
321
- inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
322
- outputs=[ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
323
- )
324
-
325
- # Examples
326
  gr.Examples(
327
  examples=[
328
- ["ETTm1 (15-min)", None, 960, 394, "15Min"],
329
- ["Illness", None, 36, 24, "D"]
 
330
  ],
331
- inputs=[data_source, upload_file, context_len, pred_len, freq],
332
- fn=lambda a,b,c,d,e: run_forecast(a,b,0,c,d,e),
333
- label="点击运行示例"
 
334
  )
335
 
336
- demo.launch()
 
11
  from visionts import VisionTSpp, freq_to_seasonality_list
12
 
13
  # ========================
14
+ # 1. Configuration
15
  # ========================
16
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  REPO_ID = "Lefei/VisionTSpp"
18
  LOCAL_DIR = "./hf_models/VisionTSpp"
19
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
 
20
  ARCH = 'mae_base'
21
 
22
+ # Download the model from Hugging Face Hub
23
  if not os.path.exists(CKPT_PATH):
24
  os.makedirs(LOCAL_DIR, exist_ok=True)
25
  print("Downloading model from Hugging Face Hub...")
26
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
27
 
28
+ # Load the model
29
+ # NOTE: We assume the model was trained to predict these specific quantiles
30
+ QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
31
  model = VisionTSpp(
32
  ARCH,
33
  ckpt_path=CKPT_PATH,
34
+ quantiles=QUANTILES, # Set the quantiles the model should predict
35
  clip_input=True,
36
  complete_no_clip=False,
37
  color=True
38
  ).to(DEVICE)
39
  print(f"Model loaded on {DEVICE}")
40
 
41
+ # Image normalization constants
42
  imagenet_mean = np.array([0.485, 0.456, 0.406])
43
  imagenet_std = np.array([0.229, 0.224, 0.225])
44
 
45
 
46
  # ========================
47
+ # 2. Preset Datasets
48
  # ========================
49
  PRESET_DATASETS = {
50
  "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
 
53
  "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv"
54
  }
55
 
56
+ # Local cache path for presets
57
  PRESET_DIR = "./preset_data"
58
  os.makedirs(PRESET_DIR, exist_ok=True)
59
 
60
 
61
  def load_preset_data(name):
62
+ """Loads a preset dataset, caching it locally."""
63
  url = PRESET_DATASETS[name]
64
+ # Sanitize name for file path
65
+ sanitized_name = name.split(' ')[0]
66
+ path = os.path.join(PRESET_DIR, f"{sanitized_name}.csv")
67
  if not os.path.exists(path):
68
+ print(f"Downloading preset dataset: {name}...")
69
  df = pd.read_csv(url)
70
  df.to_csv(path, index=False)
71
  else:
 
74
 
75
 
76
  # ========================
77
+ # 3. Visualization Functions
78
  # ========================
79
 
80
+ def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
81
+ """
82
+ Visualizes a tensor as an image, handling un-normalization.
83
+ Returns a matplotlib Figure object for Gradio.
84
+ """
85
+ if image_tensor is None: return None
86
+ # image_tensor is [C, H, W] but we expect [H, W, C] for imshow
87
+ # The model outputs [1, 1, C, H, W], after indexing it's [C, H, W]
88
+ image = image_tensor.permute(1, 2, 0).cpu() # H, W, C
89
+
90
  cur_image = torch.zeros_like(image)
91
  height_per_var = image.shape[0] // cur_nvars
92
+
93
+ # Assign colors to variables for visualization
94
  for i in range(cur_nvars):
95
+ cur_color_idx = cur_color_list[i]
96
+ var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
97
+ # Un-normalize only the used color channel
98
+ unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
99
+ cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
100
+
101
+ cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
102
 
103
  fig, ax = plt.subplots(figsize=(6, 6))
104
+ ax.imshow(cur_image)
105
  ax.set_title(title, fontsize=14)
106
  ax.axis('off')
107
+ plt.tight_layout()
108
+ plt.close(fig) # Close to prevent double display
109
  return fig
110
 
111
 
112
+ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
113
  """
114
+ Visualizes time series with multiple quantile bands.
115
+ pred_quantiles_list: list of tensors, one for each quantile.
116
+ model_quantiles: The list of quantile values, e.g., [0.1, 0.2, ..., 0.9].
117
  """
118
+ if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
119
+ if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
120
+ for i, q in enumerate(pred_quantiles_list):
 
 
121
  if isinstance(q, torch.Tensor):
122
+ pred_quantiles_list[i] = q.cpu().numpy()
123
+
124
+ nvars = true_data.shape[1]
125
+ FIG_WIDTH = 15
126
+ FIG_HEIGHT_PER_VAR = 2.0
127
+
128
+ fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
129
+ if nvars == 1: axes = [axes]
130
+
131
+ # Combine quantiles and predictions
132
+ sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
133
+
134
+ # Filter out the median to get pairs for bands
135
+ quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
136
+ quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
137
+
138
+ num_bands = len(quantile_preds) // 2
139
+ # Colors from light to dark for bands from widest to narrowest
140
+ quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
141
 
142
  for i, ax in enumerate(axes):
143
+ # Plot ground truth and median prediction
144
+ ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
145
+ pred_range = np.arange(context_len, context_len + pred_len)
146
+ ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
147
+
148
+ # Plot quantile bands
149
+ for j in range(num_bands):
150
+ lower_quantile_pred = quantile_preds[j][:, i]
151
+ upper_quantile_pred = quantile_preds[-(j+1)][:, i]
152
+ q_low = quantile_vals[j]
153
+ q_high = quantile_vals[-(j+1)]
154
+
155
  ax.fill_between(
156
+ pred_range, lower_quantile_pred, upper_quantile_pred,
157
+ color=quantile_colors[j], alpha=0.7,
158
+ label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile'
159
  )
160
 
161
  y_min, y_max = ax.get_ylim()
162
+ ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
163
+ ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
164
+ ax.grid(True, which='both', linestyle='--', linewidth=0.5)
165
+ ax.margins(x=0)
166
 
167
  handles, labels = axes[0].get_legend_handles_labels()
168
+ # Create a unique legend
169
+ unique_labels = dict(zip(labels, handles))
170
+ fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
171
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
172
  plt.close(fig)
173
  return fig
174
 
175
 
176
  # ========================
177
+ # 4. Prediction Logic
178
  # ========================
179
  class PredictionResult:
180
+ """A data class to hold prediction results for easier handling."""
181
  def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
182
  self.ts_fig = ts_fig
183
  self.input_img_fig = input_img_fig
 
186
  self.total_samples = total_samples
187
 
188
 
189
+ def predict_at_index(df, index, context_len, pred_len, freq):
190
+ """Performs a full prediction cycle for a given sample index."""
191
+ # === Data Validation ===
192
  if 'date' not in df.columns:
193
+ raise gr.Error("❌ Input CSV must contain a 'date' column.")
194
 
195
  try:
196
  df['date'] = pd.to_datetime(df['date'])
197
  except Exception:
198
+ raise gr.Error("❌ The 'date' column could not be parsed. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
199
 
200
  df = df.sort_values('date').set_index('date')
201
+ data = df.select_dtypes(include=np.number).values
202
  nvars = data.shape[1]
203
 
204
  total_samples = len(data) - context_len - pred_len + 1
205
  if total_samples <= 0:
206
+ raise gr.Error(f"Data is too short. It needs at least {context_len + pred_len} rows, but has {len(data)}.")
207
+
208
+ # Clamp index to valid range, defaulting to the last sample
209
+ index = max(0, min(index, total_samples - 1))
210
 
211
+ # Normalize data (simple train/test split for mean/std)
212
  train_len = int(len(data) * 0.7)
213
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
214
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
215
  data_norm = (data - x_mean) / x_std
216
 
217
+ # Get data for the selected sample
218
  start_idx = index
219
+ x_norm = data_norm[start_idx : start_idx + context_len]
220
+ y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
221
+ x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
222
 
223
+ # Configure model and run prediction
224
  periodicity_list = freq_to_seasonality_list(freq)
225
  periodicity = periodicity_list[0] if periodicity_list else 1
226
  color_list = [i % 3 for i in range(nvars)]
 
227
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
228
 
 
 
229
  with torch.no_grad():
230
+ y_pred, input_image, reconstructed_image, _, _ = model.forward(
231
  x_tensor, export_image=True, color_list=color_list
232
  )
233
+ # The model returns a list of all quantile predictions including the median
234
+ # The order depends on the model's internal quantile list
235
+ # Let's separate median (0.5) from other quantiles
236
+ all_preds = dict(zip(model.quantiles, y_pred))
237
+ pred_median_norm = all_preds.pop(0.5)[0] # Shape [pred_len, nvars]
238
+ pred_quantiles_norm = list(all_preds.values())
239
+ pred_quantiles_norm = [q[0] for q in pred_quantiles_norm] # List of [pred_len, nvars]
240
+
241
+ # Un-normalize results
242
+ y_true = y_true_norm * x_std + x_mean
243
+ pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
244
+ pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
245
+
246
+ # Create full series for plotting
247
+ full_true_context = data[start_idx : start_idx + context_len]
248
+ full_true_series = np.concatenate([full_true_context, y_true], axis=0)
249
+
250
+ # === Visualization ===
251
  ts_fig = visual_ts_with_quantiles(
252
+ true_data=full_true_series,
253
+ pred_median=pred_median,
254
+ pred_quantiles_list=pred_quantiles,
255
+ model_quantiles=list(all_preds.keys()), # Quantiles without median
256
+ context_len=context_len,
257
  pred_len=pred_len
258
  )
 
259
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
260
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
261
 
262
+ # === Save CSV ===
263
  os.makedirs("outputs", exist_ok=True)
264
  csv_path = "outputs/prediction_result.csv"
265
+ time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
266
+
267
+ result_data = {'date': time_index}
268
+ for i in range(nvars):
269
+ result_data[f'True_Var{i+1}'] = y_true[:, i]
270
+ result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
271
+ result_df = pd.DataFrame(result_data)
272
+ result_df.to_csv(csv_path, index=False)
273
 
274
  return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
275
 
276
 
277
  # ========================
278
+ # 5. Gradio Interface
279
  # ========================
280
  def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
281
+ """Wrapper function for the Gradio interface."""
282
  if data_source == "Upload CSV":
283
  if upload_file is None:
284
+ raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
285
  df = pd.read_csv(upload_file.name)
286
  else:
287
  df = load_preset_data(data_source)
288
 
289
  try:
290
+ # Cast inputs to correct types
291
+ index, context_len, pred_len = int(index), int(context_len), int(pred_len)
292
+
293
+ result = predict_at_index(df, index, context_len, pred_len, freq)
294
+
295
+ # On the first run, set the slider to the last sample
296
+ if index >= result.total_samples:
297
+ final_index = result.total_samples - 1
298
+ else:
299
+ final_index = index
300
+
301
  return (
302
  result.ts_fig,
303
  result.input_img_fig,
304
  result.recon_img_fig,
305
  result.csv_path,
306
+ gr.update(maximum=result.total_samples - 1, value=final_index) # Update slider
307
  )
308
  except Exception as e:
309
+ # Handle errors gracefully by displaying them
310
+ error_fig = plt.figure(figsize=(10, 5))
311
+ plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
312
  plt.axis('off')
313
+ plt.close(error_fig)
314
+ # Show the error figure and clear the remaining outputs
315
+ return error_fig, None, None, None, gr.update()
316
+
317
+
318
+ # UI Layout
319
+ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
320
+ gr.Markdown("# 🕰️ VisionTS++: Multivariate Time Series Forecasting")
321
+ gr.Markdown(
322
+ """
323
+ An interactive platform to explore time series forecasting using the VisionTS++ model.
324
+ - ✅ **Select** from preset datasets or **upload** your own.
325
+ - ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
326
+ - ✅ **Inspect** the model's internal "image" representation of the time series.
327
+ - ✅ **Slide** through different samples of the dataset for real-time forecasting.
328
+ - ✅ **Download** the prediction results as a CSV file.
329
+ """
330
+ )
331
+
332
  with gr.Row():
333
+ with gr.Column(scale=1, min_width=300):
334
+ gr.Markdown("### 1. Data & Model Configuration")
335
  data_source = gr.Dropdown(
336
+ label="Select Data Source",
337
  choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
338
  value="ETTm1 (15-min)"
339
  )
340
+ upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
341
+ gr.Markdown(
342
+ """
343
+ **Upload Rules:**
344
+ 1. Must be a `.csv` file.
345
+ 2. Must contain a time column named `date`.
346
+ """
347
+ )
348
+
349
+ context_len = gr.Number(label="Context Length (History)", value=336)
350
+ pred_len = gr.Number(label="Prediction Length (Future)", value=96)
351
+ freq = gr.Textbox(label="Frequency (e.g., 15Min, H, D)", value="15Min")
352
+
353
+ run_btn = gr.Button("🚀 Run Forecast", variant="primary")
354
+
355
+ gr.Markdown("### 2. Sample Selection")
356
+ # Set a high initial value to default to the last sample on first run.
357
+ sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=1000, step=1, value=10000)
358
 
359
  with gr.Column(scale=3):
360
+ gr.Markdown("### 3. Prediction Results")
361
+ ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
362
  with gr.Row():
363
+ input_img_plot = gr.Plot(label="Input as Image")
364
  recon_img_plot = gr.Plot(label="Reconstructed Image")
365
+ download_csv = gr.File(label="Download Prediction CSV")
366
 
367
+ # --- Event Handlers ---
368
+
369
+ # Show/hide upload button based on data source
370
+ def toggle_upload_visibility(choice):
371
+ return gr.update(visible=(choice == "Upload CSV"))
372
 
373
+ data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
 
 
374
 
375
+ # Define the inputs and outputs for the forecast function
376
+ inputs = [data_source, upload_file, sample_index, context_len, pred_len, freq]
377
+ outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
378
 
379
+ # Trigger forecast on button click
380
+ run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
381
 
382
+ # Trigger forecast when the slider value changes
383
+ sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
384
+
385
+ # Examples
386
  gr.Examples(
387
  examples=[
388
+ ["ETTm1 (15-min)", None, 0, 336, 96, "15Min"],
389
+ ["Illness", None, 0, 36, 24, "D"],
390
+ ["Weather", None, 0, 96, 192, "H"]
391
  ],
392
+ inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
393
+ fn=run_forecast, # The button click will trigger the run
394
+ outputs=outputs,
395
+ label="Click an example to load configuration, then click 'Run Forecast'"
396
  )
397
 
398
+ demo.launch(debug=True)