Lefei commited on
Commit
c1f4164
·
verified ·
1 Parent(s): 4d14e5a

update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -116
app.py CHANGED
@@ -37,28 +37,48 @@ model = VisionTSpp(
37
  ).to(DEVICE)
38
  print(f"Model loaded on {DEVICE}")
39
 
40
- # Image normalization constants
41
  imagenet_mean = np.array([0.485, 0.456, 0.406])
42
  imagenet_std = np.array([0.229, 0.224, 0.225])
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # ========================
46
  # 可视化函数
47
  # ========================
48
 
49
  def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
50
- """
51
- image: [H, W, 3] tensor
52
- 返回 matplotlib figure
53
- """
54
  cur_image = torch.zeros_like(image)
55
  height_per_var = image.shape[0] // cur_nvars
56
-
57
  for i in range(cur_nvars):
58
  cur_color = cur_color_list[i]
59
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color] = \
60
  (image[i*height_per_var:(i+1)*height_per_var, :, cur_color] * imagenet_std[cur_color] + imagenet_mean[cur_color]) * 255
61
-
62
  cur_image = torch.clamp(cur_image, 0, 255).cpu().int()
63
 
64
  fig, ax = plt.subplots(figsize=(6, 6))
@@ -69,14 +89,18 @@ def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
69
  return fig
70
 
71
 
72
- def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
73
  """
74
- 绘制时间序列预测图(多变量)
 
75
  """
76
  if isinstance(true, torch.Tensor):
77
  true = true.cpu().numpy()
78
- if isinstance(preds, torch.Tensor):
79
- preds = preds.cpu().numpy()
 
 
 
80
 
81
  nvars = true.shape[1]
82
  FIG_WIDTH = 12
@@ -92,55 +116,69 @@ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
92
 
93
  lookback_len = true.shape[0] - pred_len
94
 
 
 
 
 
95
  for i, ax in enumerate(axes):
96
- ax.plot(true[:, i], label='Ground Truth', color='gray', linewidth=1.8)
97
- if preds is not None:
98
- ax.plot(np.arange(lookback_len, len(true)), preds[lookback_len:, i],
99
- label='Prediction (Median)', color='blue', linewidth=1.8)
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  y_min, y_max = ax.get_ylim()
102
- ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max,
103
- colors='gray', linestyles='--', alpha=0.7, linewidth=1)
104
-
105
  ax.set_yticks([])
106
  ax.set_xticks([])
107
  ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
108
 
109
- if preds is not None:
110
- handles, labels = axes[0].get_legend_handles_labels()
111
- fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
112
-
113
- if preds is not None:
114
- true_eval = true[-pred_len:]
115
- pred_eval = preds[-pred_len:]
116
- mse = np.mean((true_eval - pred_eval) ** 2)
117
- mae = np.mean(np.abs(true_eval - pred_eval))
118
- fig.suptitle(f'MSE: {mse:.4f}, MAE: {mae:.4f}', fontsize=12, y=0.95)
119
-
120
  plt.subplots_adjust(hspace=0)
121
  plt.close(fig)
122
  return fig
123
 
124
 
125
  # ========================
126
- # 数据预处理与预测
127
  # ========================
 
 
 
 
 
 
 
 
128
 
129
  def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
130
- """
131
- 在指定 index 处预测
132
- index: index 个样本(从 0 开始)
133
- 返回: (ts_fig, input_img_fig, recon_img_fig)
134
- """
135
- if 'date' in df.columns:
136
- df = df.set_index(pd.to_datetime(df['date'])).drop(columns=['date'])
137
 
138
- data = df.values # [T, nvars]
 
 
 
 
 
 
139
  nvars = data.shape[1]
140
- total_samples = len(data) - context_len - pred_len + 1
141
 
 
142
  if total_samples <= 0:
143
- raise ValueError(f"数据太短,无法构造任何样本(需要至少 {context_len + pred_len} 行)")
144
  if index >= total_samples:
145
  raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")
146
 
@@ -150,141 +188,149 @@ def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
150
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
151
  data_norm = (data - x_mean) / x_std
152
 
153
- # 提取当前样本
154
  start_idx = index
155
- x = data_norm[start_idx:start_idx + context_len] # [context_len, nvars]
156
- y_true = data_norm[start_idx + context_len:start_idx + context_len + pred_len] # [pred_len, nvars]
157
 
158
- # 周期性
159
  periodicity_list = freq_to_seasonality_list(freq)
160
  periodicity = periodicity_list[0] if periodicity_list else 1
161
  color_list = [i % 3 for i in range(nvars)]
162
 
163
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
164
 
165
- x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE) # [1, T, N]
166
 
167
  with torch.no_grad():
168
  y_pred, input_image, reconstructed_image, nvars_out, color_list_out = model.forward(
169
  x_tensor, export_image=True, color_list=color_list
170
  )
171
- y_pred_median = y_pred[0] # median prediction
172
 
173
  # 反归一化
174
- y_true_original = y_true * x_std + x_mean
175
- y_pred_original = y_pred_median[0].cpu().numpy() * x_std + x_mean
 
176
 
177
- # 完整序列(用于可视化)
178
- full_true = np.concatenate([x * x_std + x_mean, y_true_original], axis=0)
179
- full_pred = np.concatenate([x * x_std + x_mean, y_pred_original], axis=0)
180
 
181
  # === 可视化 ===
182
- ts_fig = visual_ts(true=full_true, preds=full_pred, lookback_len_visual=context_len, pred_len=pred_len)
183
-
184
- input_img_fig = show_image_tensor(
185
- input_image[0, 0], title=f'Input Image (Sample {index})', cur_nvars=nvars, cur_color_list=color_list
186
- )
187
- recon_img_fig = show_image_tensor(
188
- reconstructed_image[0, 0], title=f'Reconstructed Image', cur_nvars=nvars, cur_color_list=color_list
189
  )
190
 
191
- return ts_fig, input_img_fig, recon_img_fig, total_samples
 
192
 
 
 
 
 
 
 
 
 
193
 
194
- # ========================
195
- # 默认数据
196
- # ========================
197
- def load_default_data():
198
- data_path = "./datasets/ETTm1.csv"
199
- if not os.path.exists(data_path):
200
- os.makedirs("./datasets", exist_ok=True)
201
- url = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv"
202
- df = pd.read_csv(url)
203
- df.to_csv(data_path, index=False)
204
- else:
205
- df = pd.read_csv(data_path)
206
- return df
207
 
208
 
209
  # ========================
210
- # Gradio 接口
211
  # ========================
212
- def run_forecast(file_input, sample_index, context_len, pred_len, freq):
213
- if file_input is not None:
214
- df = pd.read_csv(file_input.name)
215
- title_prefix = "Uploaded Data"
 
216
  else:
217
- df = load_default_data()
218
- title_prefix = "ETTm1 Dataset"
219
 
220
  try:
221
- ts_fig, input_img_fig, recon_img_fig, total_samples = predict_at_index(
222
- df, int(sample_index), context_len=int(context_len), pred_len=int(pred_len), freq=freq
 
 
 
 
 
223
  )
224
-
225
- # 修改标题
226
- ts_fig.suptitle(f"{title_prefix} - Sample {int(sample_index)}", fontsize=14, y=0.98)
227
-
228
- return ts_fig, input_img_fig, recon_img_fig, gr.update(maximum=total_samples - 1, value=total_samples - 1)
229
  except Exception as e:
230
- # 错误图
231
- def error_fig(msg):
232
- fig, ax = plt.subplots()
233
- ax.text(0.5, 0.5, msg, ha='center', va='center', wrap=True)
234
- ax.axis('off')
235
- plt.close(fig)
236
- return fig
237
-
238
- return error_fig("Error"), error_fig("Error"), error_fig("Error"), gr.Number()
239
 
240
 
241
  # ========================
242
  # Gradio UI
243
  # ========================
244
- with gr.Blocks(title="VisionTS++ 多变量预测") as demo:
245
- gr.Markdown("# 🕰️ VisionTS++ 时间序列预测平台")
246
- gr.Markdown("上传 CSV 或使用默认 ETTm1 数据。滑动选择不同样本进行预测,并查看原始图像表示。")
 
 
 
 
 
 
247
 
248
  with gr.Row():
249
  with gr.Column(scale=2):
250
- file_input = gr.File(label="上传 CSV 文件", file_types=['.csv'])
 
 
 
 
 
251
  context_len = gr.Number(label="历史长度", value=960)
252
  pred_len = gr.Number(label="预测长度", value=394)
253
- freq = gr.Textbox(label="频率 (如 15Min)", value="15Min")
254
  sample_index = gr.Slider(label="样本索引", minimum=0, maximum=100, step=1, value=0)
255
 
256
  with gr.Column(scale=3):
257
- ts_plot = gr.Plot(label="时间序列预测")
258
  with gr.Row():
259
  input_img_plot = gr.Plot(label="Input Image")
260
  recon_img_plot = gr.Plot(label="Reconstructed Image")
 
 
 
261
 
262
- btn = gr.Button("🚀 更新预测")
 
 
263
 
264
- # 点击按钮或滑动条变化时更新
 
 
265
  btn.click(
266
  fn=run_forecast,
267
- inputs=[file_input, sample_index, context_len, pred_len, freq],
268
- outputs=[ts_plot, input_img_plot, recon_img_plot, sample_index]
269
  )
270
 
271
- # 滑动条变化时也触发(但只在点击后才允许滑动)
272
- # 我们用 sample_index.change 依赖于前一次运行的结果
273
- demo.load(
274
- fn=lambda: gr.update(maximum=100, value=0),
275
- outputs=sample_index
276
  )
277
 
278
  # 示例
279
  gr.Examples(
280
  examples=[
281
- [None, 960, 394, "15Min"]
 
282
  ],
283
- inputs=[file_input, context_len, pred_len, freq],
284
- outputs=[ts_plot, input_img_plot, recon_img_plot, sample_index],
285
- fn=lambda f, i, c, p, fr: run_forecast(f, 0, c, p, fr), # 默认 index=0
286
- label="运行默认示例"
287
  )
288
 
289
- # 启动
290
  demo.launch()
 
37
  ).to(DEVICE)
38
  print(f"Model loaded on {DEVICE}")
39
 
40
+ # Image normalization
41
  imagenet_mean = np.array([0.485, 0.456, 0.406])
42
  imagenet_std = np.array([0.229, 0.224, 0.225])
43
 
44
 
45
+ # ========================
46
+ # 预设数据集
47
+ # ========================
48
+ PRESET_DATASETS = {
49
+ "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
50
+ "ETTh1 (1-hour)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv",
51
+ "Illness": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/illness.csv",
52
+ "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv"
53
+ }
54
+
55
+ # 本地缓存路径
56
+ PRESET_DIR = "./preset_data"
57
+ os.makedirs(PRESET_DIR, exist_ok=True)
58
+
59
+
60
+ def load_preset_data(name):
61
+ url = PRESET_DATASETS[name]
62
+ path = os.path.join(PRESET_DIR, f"{name.split(' ')[0]}.csv")
63
+ if not os.path.exists(path):
64
+ df = pd.read_csv(url)
65
+ df.to_csv(path, index=False)
66
+ else:
67
+ df = pd.read_csv(path)
68
+ return df
69
+
70
+
71
  # ========================
72
  # 可视化函数
73
  # ========================
74
 
75
  def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
 
 
 
 
76
  cur_image = torch.zeros_like(image)
77
  height_per_var = image.shape[0] // cur_nvars
 
78
  for i in range(cur_nvars):
79
  cur_color = cur_color_list[i]
80
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color] = \
81
  (image[i*height_per_var:(i+1)*height_per_var, :, cur_color] * imagenet_std[cur_color] + imagenet_mean[cur_color]) * 255
 
82
  cur_image = torch.clamp(cur_image, 0, 255).cpu().int()
83
 
84
  fig, ax = plt.subplots(figsize=(6, 6))
 
89
  return fig
90
 
91
 
92
+ def visual_ts_with_quantiles(true, pred_median, pred_quantiles, lookback_len_visual=300, pred_len=96, quantile_colors=None):
93
  """
94
+ 可视化中叠加多个 quantile 区间
95
+ pred_quantiles: list of [pred_len, nvars] tensors
96
  """
97
  if isinstance(true, torch.Tensor):
98
  true = true.cpu().numpy()
99
+ if isinstance(pred_median, torch.Tensor):
100
+ pred_median = pred_median.cpu().numpy()
101
+ for i, q in enumerate(pred_quantiles):
102
+ if isinstance(q, torch.Tensor):
103
+ pred_quantiles[i] = q.cpu().numpy()
104
 
105
  nvars = true.shape[1]
106
  FIG_WIDTH = 12
 
116
 
117
  lookback_len = true.shape[0] - pred_len
118
 
119
+ # Quantile 颜色(从外到内)
120
+ if quantile_colors is None:
121
+ quantile_colors = ['lightblue', 'skyblue', 'deepskyblue']
122
+
123
  for i, ax in enumerate(axes):
124
+ ax.plot(true[:, i], label='Ground Truth', color='gray', linewidth=2)
125
+ ax.plot(np.arange(lookback_len, len(true)), pred_median[lookback_len:, i],
126
+ label='Prediction (Median)', color='blue', linewidth=2)
127
+
128
+ # 绘制 quantile 区间(从外到内)
129
+ base = pred_median[lookback_len:]
130
+ quantiles_sorted = sorted(zip(PREDS.quantiles, pred_quantiles), key=lambda x: x[0])
131
+ for (q, pred_q), color in zip(quantiles_sorted, quantile_colors):
132
+ upper = pred_q[lookback_len:]
133
+ lower = 2 * base - upper # 对称假设
134
+ ax.fill_between(
135
+ np.arange(lookback_len, len(true)),
136
+ lower[:, i], upper[:, i],
137
+ color=color, alpha=0.5, label=f'Quantile {q:.1f}'
138
+ )
139
 
140
  y_min, y_max = ax.get_ylim()
141
+ ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
 
 
142
  ax.set_yticks([])
143
  ax.set_xticks([])
144
  ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
145
 
146
+ handles, labels = axes[0].get_legend_handles_labels()
147
+ fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
 
 
 
 
 
 
 
 
 
148
  plt.subplots_adjust(hspace=0)
149
  plt.close(fig)
150
  return fig
151
 
152
 
153
  # ========================
154
+ # 预测类封装(便于复用)
155
  # ========================
156
+ class PredictionResult:
157
+ def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
158
+ self.ts_fig = ts_fig
159
+ self.input_img_fig = input_img_fig
160
+ self.recon_img_fig = recon_img_fig
161
+ self.csv_path = csv_path
162
+ self.total_samples = total_samples
163
+
164
 
165
  def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
166
+ # === 数据校验 ===
167
+ if 'date' not in df.columns:
168
+ raise ValueError("❌ 数据集必须包含名为 'date' 的时间列。")
 
 
 
 
169
 
170
+ try:
171
+ df['date'] = pd.to_datetime(df['date'])
172
+ except Exception:
173
+ raise ValueError("❌ 'date' 列格式无法解析为时间,请检查日期格式。")
174
+
175
+ df = df.sort_values('date').set_index('date')
176
+ data = df.values
177
  nvars = data.shape[1]
 
178
 
179
+ total_samples = len(data) - context_len - pred_len + 1
180
  if total_samples <= 0:
181
+ raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {len(data)} 行。")
182
  if index >= total_samples:
183
  raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")
184
 
 
188
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
189
  data_norm = (data - x_mean) / x_std
190
 
 
191
  start_idx = index
192
+ x = data_norm[start_idx:start_idx + context_len]
193
+ y_true = data_norm[start_idx + context_len:start_idx + context_len + pred_len]
194
 
 
195
  periodicity_list = freq_to_seasonality_list(freq)
196
  periodicity = periodicity_list[0] if periodicity_list else 1
197
  color_list = [i % 3 for i in range(nvars)]
198
 
199
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
200
 
201
+ x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE)
202
 
203
  with torch.no_grad():
204
  y_pred, input_image, reconstructed_image, nvars_out, color_list_out = model.forward(
205
  x_tensor, export_image=True, color_list=color_list
206
  )
207
+ pred_median, pred_quantiles = y_pred # list of quantiles
208
 
209
  # 反归一化
210
+ y_true_orig = y_true * x_std + x_mean
211
+ pred_med_orig = pred_median[0].cpu().numpy() * x_std + x_mean
212
+ pred_quants_orig = [q[0].cpu().numpy() * x_std + x_mean for q in pred_quantiles]
213
 
214
+ # 完整序列
215
+ full_true = np.concatenate([x * x_std + x_mean, y_true_orig], axis=0)
216
+ full_pred_med = np.concatenate([x * x_std + x_mean, pred_med_orig], axis=0)
217
 
218
  # === 可视化 ===
219
+ ts_fig = visual_ts_with_quantiles(
220
+ true=full_true,
221
+ pred_median=full_pred_med,
222
+ pred_quantiles=pred_quants_orig,
223
+ lookback_len_visual=context_len,
224
+ pred_len=pred_len
 
225
  )
226
 
227
+ input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
228
+ recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
229
 
230
+ # === 保存 CSV ===
231
+ os.makedirs("outputs", exist_ok=True)
232
+ csv_path = "outputs/prediction_result.csv"
233
+ time_index = df.index[start_idx:start_idx + context_len + pred_len]
234
+ combined = np.concatenate([full_true, full_pred_med], axis=1) # [T, 2*nvars]
235
+ col_names = [f"True_Var{i+1}" for i in range(nvars)] + [f"Pred_Var{i+1}" for i in range(nvars)]
236
+ result_df = pd.DataFrame(combined, index=time_index, columns=col_names)
237
+ result_df.to_csv(csv_path)
238
 
239
+ return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
 
242
  # ========================
243
+ # Gradio 接口函数
244
  # ========================
245
+ def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
246
+ if data_source == "Upload CSV":
247
+ if upload_file is None:
248
+ raise ValueError("请上传一个 CSV 文件")
249
+ df = pd.read_csv(upload_file.name)
250
  else:
251
+ df = load_preset_data(data_source)
 
252
 
253
  try:
254
+ result = predict_at_index(df, int(index), context_len=int(context_len), pred_len=int(pred_len), freq=freq)
255
+ return (
256
+ result.ts_fig,
257
+ result.input_img_fig,
258
+ result.recon_img_fig,
259
+ result.csv_path,
260
+ gr.update(maximum=result.total_samples - 1, value=min(index, result.total_samples - 1))
261
  )
 
 
 
 
 
262
  except Exception as e:
263
+ fig_err = plt.figure(figsize=(6, 4))
264
+ plt.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
265
+ plt.axis('off')
266
+ plt.close(fig_err)
267
+ return fig_err, fig_err, fig_err, None, gr.update()
 
 
 
 
268
 
269
 
270
  # ========================
271
  # Gradio UI
272
  # ========================
273
+ with gr.Blocks(title="VisionTS++ 高级预测平台") as demo:
274
+ gr.Markdown("# 🕰️ VisionTS++ 多变量时间序列预测平台")
275
+ gr.Markdown("""
276
+ - ✅ 支持预设数据集或本地上传
277
+ - ✅ 上传规则:必须是 `.csv`,且包含 `date` 列
278
+ - ✅ 显示多分位数预测区间
279
+ - ✅ 支持下载预测结果
280
+ - ✅ 滑动样本实时预测
281
+ """)
282
 
283
  with gr.Row():
284
  with gr.Column(scale=2):
285
+ data_source = gr.Dropdown(
286
+ label="选择数据源",
287
+ choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
288
+ value="ETTm1 (15-min)"
289
+ )
290
+ upload_file = gr.File(label="上传 CSV 文件", file_types=['.csv'], visible=False)
291
  context_len = gr.Number(label="历史长度", value=960)
292
  pred_len = gr.Number(label="预测长度", value=394)
293
+ freq = gr.Textbox(label="频率 (如 15Min, H)", value="15Min")
294
  sample_index = gr.Slider(label="样本索引", minimum=0, maximum=100, step=1, value=0)
295
 
296
  with gr.Column(scale=3):
297
+ ts_plot = gr.Plot(label="时间序列预测(含分位数区间)")
298
  with gr.Row():
299
  input_img_plot = gr.Plot(label="Input Image")
300
  recon_img_plot = gr.Plot(label="Reconstructed Image")
301
+ download_csv = gr.File(label="下载预测结果")
302
+
303
+ btn = gr.Button("🚀 初始运行")
304
 
305
+ # 上传切换
306
+ def toggle_upload(choice):
307
+ return gr.update(visible=choice == "Upload CSV")
308
 
309
+ data_source.change(fn=toggle_upload, inputs=data_source, outputs=upload_file)
310
+
311
+ # 初始运行 + 滑动条变化都触发
312
  btn.click(
313
  fn=run_forecast,
314
+ inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
315
+ outputs=[ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
316
  )
317
 
318
+ # 【关键】滑动条变化时重新预测
319
+ sample_index.change(
320
+ fn=run_forecast,
321
+ inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
322
+ outputs=[ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
323
  )
324
 
325
  # 示例
326
  gr.Examples(
327
  examples=[
328
+ ["ETTm1 (15-min)", None, 960, 394, "15Min"],
329
+ ["Illness", None, 36, 24, "D"]
330
  ],
331
+ inputs=[data_source, upload_file, context_len, pred_len, freq],
332
+ fn=lambda a,b,c,d,e: run_forecast(a,b,0,c,d,e),
333
+ label="点击运行示例"
 
334
  )
335
 
 
336
  demo.launch()