Lefei committed on
Commit
6ea6ac4
·
verified ·
1 Parent(s): 9f4d52c

update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -137
app.py CHANGED
@@ -26,12 +26,11 @@ if not os.path.exists(CKPT_PATH):
26
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
27
 
28
  # Load the model
29
- # NOTE: We assume the model was trained to predict these specific quantiles
30
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
31
  model = VisionTSpp(
32
  ARCH,
33
  ckpt_path=CKPT_PATH,
34
- # quantiles=QUANTILES, # Set the quantiles the model should predict
35
  clip_input=True,
36
  complete_no_clip=False,
37
  color=True
@@ -44,77 +43,53 @@ imagenet_std = np.array([0.229, 0.224, 0.225])
44
 
45
 
46
  # ========================
47
- # 2. Preset Datasets
48
  # ========================
 
 
 
 
 
49
  PRESET_DATASETS = {
50
- "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
51
- "ETTh1 (1-hour)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv",
52
- "Illness": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/illness.csv",
53
- "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv"
 
 
54
  }
55
 
56
- # Local cache path for presets
57
- PRESET_DIR = "./preset_data"
58
- os.makedirs(PRESET_DIR, exist_ok=True)
59
-
60
-
61
  def load_preset_data(name):
62
- """Loads a preset dataset, caching it locally."""
63
- url = PRESET_DATASETS[name]
64
- # Sanitize name for file path
65
- sanitized_name = name.split(' ')[0]
66
- path = os.path.join(PRESET_DIR, f"{sanitized_name}.csv")
67
  if not os.path.exists(path):
68
- print(f"Downloading preset dataset: {name}...")
69
- df = pd.read_csv(url)
70
- df.to_csv(path, index=False)
71
- else:
72
- df = pd.read_csv(path)
73
- return df
74
 
75
 
76
  # ========================
77
- # 3. Visualization Functions
78
  # ========================
79
-
80
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
81
- """
82
- Visualizes a tensor as an image, handling un-normalization.
83
- Returns a matplotlib Figure object for Gradio.
84
- """
85
  if image_tensor is None: return None
86
- # image_tensor is [C, H, W] but we expect [H, W, C] for imshow
87
- # The model outputs [1, 1, C, H, W], after indexing it's [C, H, W]
88
- image = image_tensor.permute(1, 2, 0).cpu() # H, W, C
89
-
90
  cur_image = torch.zeros_like(image)
91
  height_per_var = image.shape[0] // cur_nvars
92
-
93
- # Assign colors to variables for visualization
94
  for i in range(cur_nvars):
95
  cur_color_idx = cur_color_list[i]
96
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
97
- # Un-normalize only the used color channel
98
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
99
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
100
-
101
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
102
-
103
  fig, ax = plt.subplots(figsize=(6, 6))
104
  ax.imshow(cur_image)
105
  ax.set_title(title, fontsize=14)
106
  ax.axis('off')
107
  plt.tight_layout()
108
- plt.close(fig) # Close to prevent double display
109
  return fig
110
 
111
-
112
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
113
- """
114
- Visualizes time series with multiple quantile bands.
115
- pred_quantiles_list: list of tensors, one for each quantile.
116
- model_quantiles: The list of quantiles values, e.g., [0.1, 0.2, ..., 0.9].
117
- """
118
  if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
119
  if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
120
  for i, q in enumerate(pred_quantiles_list):
@@ -122,41 +97,25 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
122
  pred_quantiles_list[i] = q.cpu().numpy()
123
 
124
  nvars = true_data.shape[1]
125
- FIG_WIDTH = 15
126
- FIG_HEIGHT_PER_VAR = 2.0
127
-
128
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
129
  if nvars == 1: axes = [axes]
130
 
131
- # Combine quantiles and predictions
132
  sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
133
-
134
- # Filter out the median to get pairs for bands
135
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
136
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
137
-
138
  num_bands = len(quantile_preds) // 2
139
- # Colors from light to dark for bands from widest to narrowest
140
  quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
141
 
142
  for i, ax in enumerate(axes):
143
- # Plot ground truth and median prediction
144
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
145
  pred_range = np.arange(context_len, context_len + pred_len)
146
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
147
 
148
- # Plot quantile bands
149
  for j in range(num_bands):
150
- lower_quantile_pred = quantile_preds[j][:, i]
151
- upper_quantile_pred = quantile_preds[-(j+1)][:, i]
152
- q_low = quantile_vals[j]
153
- q_high = quantile_vals[-(j+1)]
154
-
155
- ax.fill_between(
156
- pred_range, lower_quantile_pred, upper_quantile_pred,
157
- color=quantile_colors[j], alpha=0.7,
158
- label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile'
159
- )
160
 
161
  y_min, y_max = ax.get_ylim()
162
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
@@ -165,7 +124,6 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
165
  ax.margins(x=0)
166
 
167
  handles, labels = axes[0].get_legend_handles_labels()
168
- # Create a unique legend
169
  unique_labels = dict(zip(labels, handles))
170
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
171
  plt.tight_layout(rect=[0, 0, 1, 0.95])
@@ -177,27 +135,33 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
177
  # 4. Prediction Logic
178
  # ========================
179
  class PredictionResult:
180
- """A data class to hold prediction results for easier handling."""
181
- def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
182
  self.ts_fig = ts_fig
183
  self.input_img_fig = input_img_fig
184
  self.recon_img_fig = recon_img_fig
185
  self.csv_path = csv_path
186
  self.total_samples = total_samples
 
187
 
188
-
189
- def predict_at_index(df, index, context_len, pred_len, freq):
190
- """Performs a full prediction cycle for a given sample index."""
191
- # === Data Validation ===
192
  if 'date' not in df.columns:
193
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
194
 
195
  try:
196
  df['date'] = pd.to_datetime(df['date'])
197
- except Exception:
198
- raise gr.Error("❌ The 'date' column could not be parsed. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
 
 
 
 
 
 
 
 
 
199
 
200
- df = df.sort_values('date').set_index('date')
201
  data = df.select_dtypes(include=np.number).values
202
  nvars = data.shape[1]
203
 
@@ -205,23 +169,20 @@ def predict_at_index(df, index, context_len, pred_len, freq):
205
  if total_samples <= 0:
206
  raise gr.Error(f"Data is too short. It needs at least {context_len + pred_len} rows, but has {len(data)}.")
207
 
208
- # Clamp index to valid range, defaulting to the last sample
209
  index = max(0, min(index, total_samples - 1))
210
 
211
- # Normalize data (simple train/test split for mean/std)
212
  train_len = int(len(data) * 0.7)
213
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
214
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
215
  data_norm = (data - x_mean) / x_std
216
 
217
- # Get data for the selected sample
218
  start_idx = index
219
  x_norm = data_norm[start_idx : start_idx + context_len]
220
  y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
221
  x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
222
 
223
- # Configure model and run prediction
224
- periodicity_list = freq_to_seasonality_list(freq)
225
  periodicity = periodicity_list[0] if periodicity_list else 1
226
  color_list = [i % 3 for i in range(nvars)]
227
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
@@ -230,40 +191,28 @@ def predict_at_index(df, index, context_len, pred_len, freq):
230
  y_pred, input_image, reconstructed_image, _, _ = model.forward(
231
  x_tensor, export_image=True, color_list=color_list
232
  )
233
- # The model returns a list of all quantile predictions including the median
234
- # The order depends on the model's internal quantile list
235
- # Let's separate median (0.5) from other quantiles
236
  all_preds = dict(zip(model.quantiles, y_pred))
237
- pred_median_norm = all_preds.pop(0.5)[0] # Shape [pred_len, nvars]
238
- pred_quantiles_norm = list(all_preds.values())
239
- pred_quantiles_norm = [q[0] for q in pred_quantiles_norm] # List of [pred_len, nvars]
240
 
241
- # Un-normalize results
242
  y_true = y_true_norm * x_std + x_mean
243
  pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
244
  pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
245
 
246
- # Create full series for plotting
247
  full_true_context = data[start_idx : start_idx + context_len]
248
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
249
 
250
- # === Visualization ===
251
  ts_fig = visual_ts_with_quantiles(
252
- true_data=full_true_series,
253
- pred_median=pred_median,
254
- pred_quantiles_list=pred_quantiles,
255
- model_quantiles=list(all_preds.keys()), # Quantiles without median
256
- context_len=context_len,
257
- pred_len=pred_len
258
  )
259
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
260
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
261
 
262
- # === Save CSV ===
263
  os.makedirs("outputs", exist_ok=True)
264
  csv_path = "outputs/prediction_result.csv"
265
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
266
-
267
  result_data = {'date': time_index}
268
  for i in range(nvars):
269
  result_data[f'True_Var{i+1}'] = y_true[:, i]
@@ -271,14 +220,13 @@ def predict_at_index(df, index, context_len, pred_len, freq):
271
  result_df = pd.DataFrame(result_data)
272
  result_df.to_csv(csv_path, index=False)
273
 
274
- return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
275
 
276
 
277
  # ========================
278
  # 5. Gradio Interface
279
  # ========================
280
- def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
281
- """Wrapper function for the Gradio interface."""
282
  if data_source == "Upload CSV":
283
  if upload_file is None:
284
  raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
@@ -287,12 +235,9 @@ def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
287
  df = load_preset_data(data_source)
288
 
289
  try:
290
- # Cast inputs to correct types
291
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
292
-
293
- result = predict_at_index(df, index, context_len, pred_len, freq)
294
 
295
- # On the first run, set the slider to the last sample
296
  if index >= result.total_samples:
297
  final_index = result.total_samples - 1
298
  else:
@@ -303,17 +248,15 @@ def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
303
  result.input_img_fig,
304
  result.recon_img_fig,
305
  result.csv_path,
306
- gr.update(maximum=result.total_samples - 1, value=final_index) # Update slider
 
307
  )
308
  except Exception as e:
309
- # Handle errors gracefully by displaying them
310
  error_fig = plt.figure(figsize=(10, 5))
311
  plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
312
  plt.axis('off')
313
  plt.close(error_fig)
314
- # Return empty plots and no file
315
- return error_fig, None, None, None, gr.update()
316
-
317
 
318
  # UI Layout
319
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
@@ -321,9 +264,9 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
321
  gr.Markdown(
322
  """
323
  An interactive platform to explore time series forecasting using the VisionTS++ model.
324
- ✅ **Select** from preset datasets or **upload** your own.
 
325
  - ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
326
- - ✅ **Inspect** the model's internal "image" representation of the time series.
327
  - ✅ **Slide** through different samples of the dataset for real-time forecasting.
328
  - ✅ **Download** the prediction results as a CSV file.
329
  """
@@ -334,26 +277,26 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
334
  gr.Markdown("### 1. Data & Model Configuration")
335
  data_source = gr.Dropdown(
336
  label="Select Data Source",
337
- choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
338
- value="ETTm1 (15-min)"
339
  )
340
  upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
341
  gr.Markdown(
342
  """
343
  **Upload Rules:**
344
  1. Must be a `.csv` file.
345
- 2. Must contain a time column named `date`.
346
  """
347
  )
348
 
349
  context_len = gr.Number(label="Context Length (History)", value=336)
350
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
351
- freq = gr.Textbox(label="Frequency (e.g., 15Min, H, D)", value="15Min")
 
352
 
353
  run_btn = gr.Button("🚀 Run Forecast", variant="primary")
354
 
355
  gr.Markdown("### 2. Sample Selection")
356
- # Set a high initial value to default to the last sample on first run.
357
  sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=1000, step=1, value=10000)
358
 
359
  with gr.Column(scale=3):
@@ -365,35 +308,19 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
365
  download_csv = gr.File(label="Download Prediction CSV")
366
 
367
  # --- Event Handlers ---
368
-
369
- # Show/hide upload button based on data source
370
  def toggle_upload_visibility(choice):
371
  return gr.update(visible=(choice == "Upload CSV"))
372
 
373
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
374
 
375
- # Define the inputs and outputs for the forecast function
376
- inputs = [data_source, upload_file, sample_index, context_len, pred_len, freq]
377
- outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]
378
 
379
- # Trigger forecast on button click
380
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
381
-
382
- # Trigger forecast when the slider value changes
383
  sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
384
 
385
- # Examples
386
- gr.Examples(
387
- examples=[
388
- ["ETTm1 (15-min)", None, 0, 336, 96, "15Min"],
389
- ["Illness", None, 0, 36, 24, "D"],
390
- ["Weather", None, 0, 96, 192, "H"]
391
- ],
392
- inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
393
- fn=run_forecast, # The button click will trigger the run
394
- outputs=outputs,
395
- label="Click an example to load configuration, then click 'Run Forecast'"
396
- )
397
 
398
  demo.launch(debug=True)
399
-
 
26
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
27
 
28
  # Load the model
 
29
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
30
  model = VisionTSpp(
31
  ARCH,
32
  ckpt_path=CKPT_PATH,
33
+ quantiles=QUANTILES,
34
  clip_input=True,
35
  complete_no_clip=False,
36
  color=True
 
43
 
44
 
45
  # ========================
46
+ # 2. Preset Datasets (Now Loaded Locally)
47
  # ========================
48
+ # This dictionary maps user-friendly names to local file paths
49
+ # ASSUMPTION: These files exist in the directory named by data_dir (currently "./")
50
+
51
+ # data_dir = "./datasets/"
52
+ data_dir = "./"
53
  PRESET_DATASETS = {
54
+ "ETTm1": data_dir + "ETTm1.csv",
55
+ "ETTm2": data_dir + "ETTm2.csv",
56
+ "ETTh1": data_dir + "ETTh1.csv",
57
+ "ETTh2": data_dir + "ETTh2.csv",
58
+ "Illness": data_dir + "Illness.csv",
59
+ "Weather": data_dir + "Weather.csv",
60
  }
61
 
 
 
 
 
 
62
  def load_preset_data(name):
63
+ """Loads a preset dataset from a local path."""
64
+ path = PRESET_DATASETS[name]
 
 
 
65
  if not os.path.exists(path):
66
+ raise FileNotFoundError(f"Preset dataset file not found: {path}. Make sure it's uploaded to the 'datasets' folder.")
67
+ return pd.read_csv(path)
 
 
 
 
68
 
69
 
70
  # ========================
71
+ # 3. Visualization Functions (No changes needed)
72
  # ========================
 
73
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
 
 
 
 
74
  if image_tensor is None: return None
75
+ image = image_tensor.permute(1, 2, 0).cpu()
 
 
 
76
  cur_image = torch.zeros_like(image)
77
  height_per_var = image.shape[0] // cur_nvars
 
 
78
  for i in range(cur_nvars):
79
  cur_color_idx = cur_color_list[i]
80
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
 
81
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
82
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
 
83
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
 
84
  fig, ax = plt.subplots(figsize=(6, 6))
85
  ax.imshow(cur_image)
86
  ax.set_title(title, fontsize=14)
87
  ax.axis('off')
88
  plt.tight_layout()
89
+ plt.close(fig)
90
  return fig
91
 
 
92
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
 
 
 
 
 
93
  if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
94
  if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
95
  for i, q in enumerate(pred_quantiles_list):
 
97
  pred_quantiles_list[i] = q.cpu().numpy()
98
 
99
  nvars = true_data.shape[1]
100
+ FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
 
 
101
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
102
  if nvars == 1: axes = [axes]
103
 
 
104
  sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
 
 
105
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
106
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
 
107
  num_bands = len(quantile_preds) // 2
 
108
  quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
109
 
110
  for i, ax in enumerate(axes):
 
111
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
112
  pred_range = np.arange(context_len, context_len + pred_len)
113
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
114
 
 
115
  for j in range(num_bands):
116
+ lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
117
+ q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
118
+ ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
 
 
 
 
 
 
 
119
 
120
  y_min, y_max = ax.get_ylim()
121
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
 
124
  ax.margins(x=0)
125
 
126
  handles, labels = axes[0].get_legend_handles_labels()
 
127
  unique_labels = dict(zip(labels, handles))
128
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
129
  plt.tight_layout(rect=[0, 0, 1, 0.95])
 
135
  # 4. Prediction Logic
136
  # ========================
137
  class PredictionResult:
138
+ def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq):
 
139
  self.ts_fig = ts_fig
140
  self.input_img_fig = input_img_fig
141
  self.recon_img_fig = recon_img_fig
142
  self.csv_path = csv_path
143
  self.total_samples = total_samples
144
+ self.inferred_freq = inferred_freq
145
 
146
+ def predict_at_index(df, index, context_len, pred_len):
147
+ # === Data Validation & Frequency Inference ===
 
 
148
  if 'date' not in df.columns:
149
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
150
 
151
  try:
152
  df['date'] = pd.to_datetime(df['date'])
153
+ df = df.sort_values('date').set_index('date')
154
+ # *** NEW: Infer frequency ***
155
+ inferred_freq = pd.infer_freq(df.index)
156
+ if inferred_freq is None:
157
+ # Fallback if inference fails
158
+ time_diff = df.index[1] - df.index[0]
159
+ inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
160
+ gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
161
+ print(f"Inferred frequency: {inferred_freq}")
162
+ except Exception as e:
163
+ raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
164
 
 
165
  data = df.select_dtypes(include=np.number).values
166
  nvars = data.shape[1]
167
 
 
169
  if total_samples <= 0:
170
  raise gr.Error(f"Data is too short. It needs at least {context_len + pred_len} rows, but has {len(data)}.")
171
 
 
172
  index = max(0, min(index, total_samples - 1))
173
 
 
174
  train_len = int(len(data) * 0.7)
175
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
176
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
177
  data_norm = (data - x_mean) / x_std
178
 
 
179
  start_idx = index
180
  x_norm = data_norm[start_idx : start_idx + context_len]
181
  y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
182
  x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
183
 
184
+ # *** Use inferred frequency ***
185
+ periodicity_list = freq_to_seasonality_list(inferred_freq)
186
  periodicity = periodicity_list[0] if periodicity_list else 1
187
  color_list = [i % 3 for i in range(nvars)]
188
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
 
191
  y_pred, input_image, reconstructed_image, _, _ = model.forward(
192
  x_tensor, export_image=True, color_list=color_list
193
  )
 
 
 
194
  all_preds = dict(zip(model.quantiles, y_pred))
195
+ pred_median_norm = all_preds.pop(0.5)[0]
196
+ pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
 
197
 
 
198
  y_true = y_true_norm * x_std + x_mean
199
  pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
200
  pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
201
 
 
202
  full_true_context = data[start_idx : start_idx + context_len]
203
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
204
 
 
205
  ts_fig = visual_ts_with_quantiles(
206
+ true_data=full_true_series, pred_median=pred_median,
207
+ pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
208
+ context_len=context_len, pred_len=pred_len
 
 
 
209
  )
210
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
211
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
212
 
 
213
  os.makedirs("outputs", exist_ok=True)
214
  csv_path = "outputs/prediction_result.csv"
215
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
 
216
  result_data = {'date': time_index}
217
  for i in range(nvars):
218
  result_data[f'True_Var{i+1}'] = y_true[:, i]
 
220
  result_df = pd.DataFrame(result_data)
221
  result_df.to_csv(csv_path, index=False)
222
 
223
+ return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq)
224
 
225
 
226
  # ========================
227
  # 5. Gradio Interface
228
  # ========================
229
+ def run_forecast(data_source, upload_file, index, context_len, pred_len):
 
230
  if data_source == "Upload CSV":
231
  if upload_file is None:
232
  raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
 
235
  df = load_preset_data(data_source)
236
 
237
  try:
 
238
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
239
+ result = predict_at_index(df, index, context_len, pred_len)
 
240
 
 
241
  if index >= result.total_samples:
242
  final_index = result.total_samples - 1
243
  else:
 
248
  result.input_img_fig,
249
  result.recon_img_fig,
250
  result.csv_path,
251
+ gr.update(maximum=result.total_samples - 1, value=final_index),
252
+ gr.update(value=result.inferred_freq) # *** Update frequency textbox ***
253
  )
254
  except Exception as e:
 
255
  error_fig = plt.figure(figsize=(10, 5))
256
  plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
257
  plt.axis('off')
258
  plt.close(error_fig)
259
+ return error_fig, None, None, None, gr.update(), gr.update(value="Error")
 
 
260
 
261
  # UI Layout
262
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
 
264
  gr.Markdown(
265
  """
266
  An interactive platform to explore time series forecasting using the VisionTS++ model.
267
+ - ✅ **Select** from local preset datasets or **upload** your own.
268
+ - ✅ **Frequency is auto-detected** from the 'date' column.
269
  - ✅ **Visualize** predictions with multiple **quantile uncertainty bands**.
 
270
  - ✅ **Slide** through different samples of the dataset for real-time forecasting.
271
  - ✅ **Download** the prediction results as a CSV file.
272
  """
 
277
  gr.Markdown("### 1. Data & Model Configuration")
278
  data_source = gr.Dropdown(
279
  label="Select Data Source",
280
+ choices=list(PRESET_DATASETS.keys()) + ["Upload CSV"],
281
+ value="ETTh1"
282
  )
283
  upload_file = gr.File(label="Upload CSV File", file_types=['.csv'], visible=False)
284
  gr.Markdown(
285
  """
286
  **Upload Rules:**
287
  1. Must be a `.csv` file.
288
+ 2. Must contain a time column named `date` with a consistent frequency.
289
  """
290
  )
291
 
292
  context_len = gr.Number(label="Context Length (History)", value=336)
293
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
294
+ # *** Textbox that displays the detected frequency (updated by run_forecast) ***
295
+ freq_display = gr.Textbox(label="Detected Frequency", interactive=True)
296
 
297
  run_btn = gr.Button("🚀 Run Forecast", variant="primary")
298
 
299
  gr.Markdown("### 2. Sample Selection")
 
300
  sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=1000, step=1, value=10000)
301
 
302
  with gr.Column(scale=3):
 
308
  download_csv = gr.File(label="Download Prediction CSV")
309
 
310
  # --- Event Handlers ---
 
 
311
  def toggle_upload_visibility(choice):
312
  return gr.update(visible=(choice == "Upload CSV"))
313
 
314
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
315
 
316
+ inputs = [data_source, upload_file, sample_index, context_len, pred_len]
317
+ outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display]
 
318
 
 
319
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
 
 
320
  sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
321
 
322
+ # Remove Examples block to avoid startup issues and rely on the button.
323
+ # If you still want examples, ensure `cache_examples=False`.
324
+ # For simplicity, we'll remove it as the 'Run' button is clear.
 
 
 
 
 
 
 
 
 
325
 
326
  demo.launch(debug=True)