Lefei commited on
Commit
f76da22
·
verified ·
1 Parent(s): f741a5c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import einops
9
+
10
+ from huggingface_hub import snapshot_download
11
+ from visionts import VisionTSpp, freq_to_seasonality_list
12
+
13
+ # ========================
14
+ # 配置
15
+ # ========================
16
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+ REPO_ID = "Lefei/VisionTSpp"
18
+ LOCAL_DIR = "./hf_models/VisionTSpp"
19
+ CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
20
+
21
+ ARCH = 'mae_base' # 可选: 'mae_base', 'mae_large', 'mae_huge'
22
+
23
+ # 下载模型(Space 构建时执行一次)
24
+ if not os.path.exists(CKPT_PATH):
25
+ os.makedirs(LOCAL_DIR, exist_ok=True)
26
+ print("Downloading model from Hugging Face Hub...")
27
+ snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
28
+
29
+ # 加载模型(全局加载一次)
30
+ model = VisionTSpp(
31
+ ARCH,
32
+ ckpt_path=CKPT_PATH,
33
+ quantile=True,
34
+ clip_input=True,
35
+ complete_no_clip=False,
36
+ color=True
37
+ ).to(DEVICE)
38
+ print(f"Model loaded on {DEVICE}")
39
+
40
+ # ========================
41
+ # 核心预测与可视化函数
42
+ # ========================
43
+
44
+ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
45
+ """
46
+ 可视化真实值 vs 预测值
47
+ true: [T, nvars]
48
+ preds: [T, nvars],与 true 对齐
49
+ """
50
+ if isinstance(true, torch.Tensor):
51
+ true = true.cpu().numpy()
52
+ if isinstance(preds, torch.Tensor):
53
+ preds = preds.cpu().numpy()
54
+
55
+ nvars = true.shape[1]
56
+
57
+ FIG_WIDTH = 12
58
+ FIG_HEIGHT_PER_VAR = 1.8
59
+ FONT_S = 10
60
+
61
+ fig, axes = plt.subplots(
62
+ nrows=nvars, ncols=1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True,
63
+ gridspec_kw={'height_ratios': [1] * nvars}
64
+ )
65
+ if nvars == 1:
66
+ axes = [axes]
67
+
68
+ lookback_len = true.shape[0] - pred_len
69
+
70
+ for i, ax in enumerate(axes):
71
+ ax.plot(true[:, i], label='Ground Truth', color='gray', linewidth=1.8)
72
+ if preds is not None:
73
+ ax.plot(np.arange(lookback_len, len(true)), preds[lookback_len:, i],
74
+ label='Prediction (Median)', color='blue', linewidth=1.8)
75
+
76
+ # 分隔线
77
+ y_min, y_max = ax.get_ylim()
78
+ ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max,
79
+ colors='gray', linestyles='--', alpha=0.7, linewidth=1)
80
+
81
+ ax.set_yticks([])
82
+ ax.set_xticks([])
83
+ ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
84
+
85
+ # 图例
86
+ if preds is not None:
87
+ handles, labels = axes[0].get_legend_handles_labels()
88
+ fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
89
+
90
+ # 计算 MSE/MAE
91
+ if preds is not None:
92
+ true_eval = true[-pred_len:]
93
+ pred_eval = preds[-pred_len:]
94
+ mse = np.mean((true_eval - pred_eval) ** 2)
95
+ mae = np.mean(np.abs(true_eval - pred_eval))
96
+ fig.suptitle(f'MSE: {mse:.4f}, MAE: {mae:.4f}', fontsize=12, y=0.95)
97
+
98
+ plt.subplots_adjust(hspace=0)
99
+ return fig # 返回 matplotlib figure
100
+
101
+
102
+ def predict_and_visualize(df, context_len=960, pred_len=394, freq="15Min"):
103
+ """
104
+ 输入: df (pandas.DataFrame),必须包含 'date' 列和其他数值列
105
+ 输出: matplotlib 图像
106
+ """
107
+ if 'date' in df.columns:
108
+ df['date'] = pd.to_datetime(df['date'])
109
+ df = df.set_index('date')
110
+ else:
111
+ # 如果没有 date 列,假设是纯数值序列
112
+ df = df.copy()
113
+
114
+ data = df.values # [T, nvars]
115
+ nvars = data.shape[1]
116
+
117
+ if data.shape[0] < context_len + pred_len:
118
+ raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {data.shape[0]} 行。")
119
+
120
+ # 归一化(使用训练集前 70% 的统计量)
121
+ train_len = int(len(data) * 0.7)
122
+ x_mean = data[:train_len].mean(axis=0, keepdims=True)
123
+ x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
124
+ data_norm = (data - x_mean) / x_std
125
+
126
+ # 取最后一段作为测试窗口
127
+ end_idx = len(data_norm)
128
+ start_idx = end_idx - (context_len + pred_len)
129
+ x = data_norm[start_idx:start_idx + context_len] # [context_len, nvars]
130
+ y_true = data_norm[start_idx + context_len:end_idx] # [pred_len, nvars]
131
+
132
+ # 设置周期性
133
+ periodicity_list = freq_to_seasonality_list(freq)
134
+ periodicity = periodicity_list[0] if periodicity_list else 1
135
+ color_list = [i % 3 for i in range(nvars)] # RGB 循环着色
136
+
137
+ # 更新模型配置
138
+ model.update_config(
139
+ context_len=context_len,
140
+ pred_len=pred_len,
141
+ periodicity=periodicity,
142
+ num_patch_input=7,
143
+ padding_mode='constant'
144
+ )
145
+
146
+ # 转为 tensor
147
+ x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE) # [1, T, N]
148
+ y_true_tensor = torch.FloatTensor(y_true).unsqueeze(0).to(DEVICE)
149
+
150
+ # 预测
151
+ with torch.no_grad():
152
+ y_pred, _, _, _, _ = model.forward(x_tensor, export_image=True, color_list=color_list)
153
+ y_pred_median = y_pred[0] # median prediction
154
+
155
+ # 反归一化
156
+ y_true_original = y_true * x_std + x_mean
157
+ y_pred_original = y_pred_median[0].cpu().numpy() * x_std + x_mean
158
+
159
+ # 构造完整序列用于可视化
160
+ full_true = np.concatenate([x * x_std + x_mean, y_true_original], axis=0)
161
+ full_pred = np.concatenate([x * x_std + x_mean, y_pred_original], axis=0)
162
+
163
+ # 可视化
164
+ fig = visual_ts(true=full_true, preds=full_pred, lookback_len_visual=context_len, pred_len=pred_len)
165
+ return fig
166
+
167
+
168
+ # ========================
169
+ # 默认数据加载
170
+ # ========================
171
+ def load_default_data():
172
+ data_path = "./datasets/ETTm1.csv"
173
+ if not os.path.exists(data_path):
174
+ os.makedirs("./datasets", exist_ok=True)
175
+ url = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv"
176
+ df = pd.read_csv(url)
177
+ df.to_csv(data_path, index=False)
178
+ else:
179
+ df = pd.read_csv(data_path)
180
+ return df
181
+
182
+
183
+ # ========================
184
+ # Gradio 界面
185
+ # ========================
186
+ def run_forecast(file_input, context_len, pred_len, freq):
187
+ if file_input is not None:
188
+ df = pd.read_csv(file_input.name)
189
+ title = "Uploaded Data Prediction"
190
+ else:
191
+ df = load_default_data()
192
+ title = "Default ETTm1 Dataset Prediction"
193
+
194
+ try:
195
+ fig = predict_and_visualize(df, context_len=int(context_len), pred_len=int(pred_len), freq=freq)
196
+ fig.suptitle(title, fontsize=14, y=0.98)
197
+ plt.close(fig) # 防止重复显示
198
+ return fig
199
+ except Exception as e:
200
+ # 返回错误信息图像
201
+ fig, ax = plt.subplots()
202
+ ax.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
203
+ ax.axis('off')
204
+ plt.close(fig)
205
+ return fig
206
+
207
+
208
+ # Gradio UI
209
+ with gr.Blocks(title="VisionTS++ 时间序列预测") as demo:
210
+ gr.Markdown("# 🕰️ VisionTS++ 时间序列预测平台")
211
+ gr.Markdown("上传你的多变量时间序列 CSV 文件,或使用默认 ETTm1 数据进行预测。")
212
+
213
+ with gr.Row():
214
+ file_input = gr.File(label="上传 CSV 文件(含 date 列或纯数值)", file_types=['.csv'])
215
+ with gr.Column():
216
+ context_len = gr.Number(label="历史长度 (context_len)", value=960)
217
+ pred_len = gr.Number(label="预测长度 (pred_len)", value=394)
218
+ freq = gr.Textbox(label="时间频率 (如 15Min, H)", value="15Min")
219
+
220
+ btn = gr.Button("🚀 开始预测")
221
+
222
+ output_plot = gr.Plot(label="预测结果")
223
+
224
+ btn.click(
225
+ fn=run_forecast,
226
+ inputs=[file_input, context_len, pred_len, freq],
227
+ outputs=output_plot
228
+ )
229
+
230
+ # 示例:使用默认数据
231
+ gr.Examples(
232
+ examples=[
233
+ [None, 960, 394, "15Min"]
234
+ ],
235
+ inputs=[file_input, context_len, pred_len, freq],
236
+ outputs=output_plot,
237
+ fn=run_forecast,
238
+ label="点击运行默认示例"
239
+ )
240
+
241
+ # 启动
242
+ if __name__ == "__main__":
243
+ demo.launch()