Gil Stetler committed
Commit b83a684 · 1 Parent(s): 218f038
Files changed (2):
  1. app.py +268 -67
  2. requirements.txt +3 -4
app.py CHANGED
@@ -1,69 +1,270 @@
+ # app.py
+ import os, random
+ from typing import Tuple
+ import numpy as np
+ import pandas as pd
+ import torch
  import gradio as gr
+ import matplotlib
+ matplotlib.use("Agg")
  import matplotlib.pyplot as plt
- import pandas as pd
- from utils_vol import fetch_close_series, realized_vol
- from autogluon.timeseries import TimeSeriesPredictor
- from train_autogluon import train_bolt_small
- import os
-
- MODEL_DIR = "/mnt/data/AutogluonChronosBoltSmall"
-
- # ---------- Handlers ----------
-
- def predict_vol(ticker, start, interval):
-     if not os.path.isdir(MODEL_DIR):
-         raise gr.Error("No trained model found. Please train first.")
-     predictor = TimeSeriesPredictor.load(MODEL_DIR)
-     close = fetch_close_series(ticker, start=start, interval=interval)
-     rv = realized_vol(close)
-     df = pd.DataFrame({"timestamp": rv.index, "target": rv.values, "item_id": "series_1"})
-     forecast = predictor.predict(df)
-     f = forecast.to_pandas()
-     plt.figure(figsize=(8, 4))
-     plt.plot(rv.index, rv.values, label="History")
-     plt.plot(f.index, f["0.5"], "--", label="Forecast (Median)")
-     plt.legend()
-     plt.title(f"{ticker} – Volatility Forecast (Chronos-Bolt-Small)")
-     return plt
-
- def train_model(ticker, start, interval):
-     train_bolt_small(ticker=ticker, start=start, interval=interval)
-     return f"Training finished and saved under {MODEL_DIR}."
-
- def clear_model():
-     import shutil
-     if os.path.isdir(MODEL_DIR):
-         shutil.rmtree(MODEL_DIR)
-         return "Model deleted."
-     return "No model found to delete."
-
- # ---------- UI ----------
- with gr.Blocks(title="Chronos-Bolt-Small (CPU) Fine-Tuning App") as demo:
-     gr.Markdown("## Chronos-Bolt-Small Volatility Forecast\n"
-                 "Trains on CPU within ~10 minutes via AutoGluon.\n"
-                 "• Tab **Train**: fine-tune a new model\n"
-                 "• Tab **Predict**: display a forecast\n"
-                 "• Tab **Manage**: delete the model")
-
-     with gr.Tab("Predict"):
-         t1 = gr.Textbox(label="Ticker", value="AAPL")
-         s1 = gr.Textbox(label="Start date", value="2015-01-01")
-         i1 = gr.Dropdown(["1d", "1wk", "1mo"], value="1d", label="Interval")
-         btn_p = gr.Button("Predict")
-         out_p = gr.Plot()
-         btn_p.click(predict_vol, inputs=[t1, s1, i1], outputs=[out_p])
-
-     with gr.Tab("Train"):
-         t2 = gr.Textbox(label="Ticker", value="AAPL")
-         s2 = gr.Textbox(label="Start date", value="2015-01-01")
-         i2 = gr.Dropdown(["1d", "1wk", "1mo"], value="1d", label="Interval")
-         btn_t = gr.Button("Train (AutoGluon Chronos-Bolt-Small)")
-         out_t = gr.Textbox(label="Training log", lines=8)
-         btn_t.click(train_model, inputs=[t2, s2, i2], outputs=[out_t])
-
-     with gr.Tab("Manage"):
-         btn_c = gr.Button("Delete model")
-         out_c = gr.Textbox(label="Status")
-         btn_c.click(clear_model, outputs=[out_c])
-
- demo.launch()
+ from chronos import ChronosPipeline
+
+ # our data pipeline
+ import pipeline_v2 as pipe2  # update_ticker_csv(...)
+
+ # --------------------
+ # Config
+ # --------------------
+ MODEL_ID = "amazon/chronos-t5-large"
+ PREDICTION_LENGTH = 30  # forecast last 30 days
+ NUM_SAMPLES = 1         # single path -> day-by-day point prediction
+ RV_WINDOW = 20          # realized vol window (trading days)
+ ANNUALIZE = True        # annualize by sqrt(252)
+ EPS = 1e-8
+
+ # --------------------
+ # Model load (once)
+ # --------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
+
+ pipe = ChronosPipeline.from_pretrained(
+     MODEL_ID,
+     device_map="auto",
+     torch_dtype=dtype,
+ )
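+ # Note: chronos-t5-large has ~710M parameters; loading it once at module
+ # import keeps per-request latency down, though the first run still downloads
+ # the checkpoint from the Hugging Face Hub unless it is already cached.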
+
+ # --------------------
+ # Helpers
+ # --------------------
+ def _extract_close(df: pd.DataFrame) -> pd.Series:
+     """
+     Robustly extract the close or adjusted close price as a numeric Series.
+     Handles both flat and MultiIndex columns (yfinance often returns MultiIndex
+     when multiple tickers or suffixes are used).
+     """
+     # --- Case 1: MultiIndex (e.g., ('Adj Close', 'BMW.DE')) ---
+     if isinstance(df.columns, pd.MultiIndex):
+         # Try Adj Close first
+         for name in ["Adj Close", "Adj_Close", "adj close", "adj_close"]:
+             if name in df.columns.get_level_values(0):
+                 sub = df.xs(name, axis=1, level=0)
+                 # If multiple tickers, pick first column
+                 if sub.shape[1] > 1:
+                     sub = sub.iloc[:, 0]
+                 return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
+         # Fall back to Close
+         for name in ["Close", "close", "Price", "price"]:
+             if name in df.columns.get_level_values(0):
+                 sub = df.xs(name, axis=1, level=0)
+                 if sub.shape[1] > 1:
+                     sub = sub.iloc[:, 0]
+                 return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
+
+     # --- Case 2: Flat columns ---
+     mapping = {c.lower(): c for c in df.columns}
+     for name in ["adj close", "adj_close", "close", "price"]:
+         if name in mapping:
+             col = df[mapping[name]]
+             return pd.to_numeric(col, errors="coerce").dropna()
+
+     # --- Fallback: last numeric column ---
+     num_cols = df.select_dtypes(include=[np.number]).columns
+     if len(num_cols) == 0:
+         raise gr.Error("No numeric price column found in downloaded data.")
+     return pd.Series(df[num_cols[-1]]).astype(float)
+
+
+ def _extract_dates(df: pd.DataFrame):
+     # If index is DatetimeIndex, use it
+     if isinstance(df.index, pd.DatetimeIndex):
+         return df.index.to_numpy()
+     # Else try a date-like column
+     mapping = {c.lower(): c for c in df.columns}
+     for name in ["date", "time", "timestamp"]:
+         if name in mapping:
+             try:
+                 return pd.to_datetime(df[mapping[name]]).to_numpy()
+             except Exception:
+                 pass
+     # Fallback to a simple range
+     return np.arange(len(df))
+
+ def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
+     r = np.log(close).diff().dropna()
+     rv = r.rolling(window, min_periods=window).std()
+     if annualize:
+         rv = rv * np.sqrt(252.0)
+     return rv.dropna().reset_index(drop=True)
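+ # Example: with daily bars and window=20, this is the stdev of the last 20 log
+ # returns scaled by sqrt(252) ≈ 15.87, so a 1% daily stdev maps to an
+ # annualized realized vol of roughly 0.159.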
+
+ def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, np.ndarray]:
+     alpha = float(np.sum(y_true * y_pred) / (np.sum(y_pred**2) + EPS))
+     return alpha, alpha * y_pred
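+ # alpha is the least-squares scale: it minimizes sum((y_true - a*y_pred)**2),
+ # so a = <y_true, y_pred> / <y_pred, y_pred>. Example: y_true=[2, 4],
+ # y_pred=[1, 2] gives alpha = (2 + 8) / (1 + 4) = 2.0, calibrated path [2, 4].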
+
+ def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
+     err = y_pred - y_true
+     denom = np.maximum(EPS, np.abs(y_true))
+     mape = float((np.abs(err) / denom).mean() * 100)
+     mpe = float((err / np.maximum(EPS, y_true)).mean() * 100)
+     rmse = float(np.sqrt(np.mean(err**2)))
+     return {"MAPE": mape, "MPE": mpe, "RMSE": rmse}
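+ # Example: y_true=[1.0, 2.0], y_pred=[1.1, 1.8] -> err=[0.1, -0.2];
+ # MAPE = (10% + 10%) / 2 = 10.0, MPE = (10% - 10%) / 2 = 0.0 (signed errors
+ # cancel), RMSE = sqrt((0.01 + 0.04) / 2) ≈ 0.158.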
+
+ # --------------------
+ # Core routine
+ # --------------------
+ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
+     """
+     tickers: comma-separated (';' and '|' also accepted); the FIRST is used for plotting/eval.
+     start: YYYY-MM-DD
+     interval: '1d', '1wk', '1mo'
+     """
+     # Parse first ticker (keep dots and dashes!)
+     tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
+     if not tick_list:
+         raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
+
+     ticker = tick_list[0]  # keep original form; pipeline handles uppercasing
+
+     # 1) Fetch/update CSV via pipeline
+     try:
+         csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
+     except Exception as e:
+         raise gr.Error(
+             f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
+         )
+
+     # 2) Load CSV and build realized vol
+     try:
+         df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
+         if not isinstance(df.index, pd.DatetimeIndex):
+             # last fallback
+             df = pd.read_csv(csv_path)
+     except Exception:
+         df = pd.read_csv(csv_path)
+
+     dates = _extract_dates(df)
+     close = _extract_close(df)
+
+     rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE).to_numpy()
+     n = len(rv); H = PREDICTION_LENGTH
+     if n <= H + 5:
+         raise gr.Error(f"Vol series too short after rolling window. Need > {H+5}, got {n}.")
+
+     rv_train = rv[: n - H]
+     rv_test = rv[n - H :]
+
+     # 3) Forecast a single sample path (deterministic via seed)
+     random.seed(0); np.random.seed(0); torch.manual_seed(0)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(0)
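+     # Reproducibility note: these seeds fix the sampled path across runs;
+     # on GPU, exact bitwise determinism additionally requires
+     # torch.use_deterministic_algorithms(True).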
+
+     context = torch.tensor(rv_train, dtype=torch.float32)
+     fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, 1, H]
+     samples = fcst[0].cpu().numpy()  # (1, H)
+     path_pred = samples[0]  # (H,)
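+     # With NUM_SAMPLES=1 this "point forecast" is one sampled trajectory, not
+     # a central estimate; a median forecast would instead draw more samples:
+     #   s = pipe.predict(context, prediction_length=H, num_samples=64)[0].cpu().numpy()
+     #   path_med = np.median(s, axis=0)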
+
+     # 4) Optional bias/scale calibration
+     alpha = None
+     if use_calibration:
+         alpha, path_pred_cal = bias_scale_calibration(rv_test, path_pred)
+         metrics_raw = compute_metrics(rv_test, path_pred)
+         metrics_cal = compute_metrics(rv_test, path_pred_cal)
+     else:
+         metrics_raw = compute_metrics(rv_test, path_pred)
+         metrics_cal = None
+         path_pred_cal = None
+
+     # 5) Plot
+     fig = plt.figure(figsize=(10, 4))
+     H0 = len(rv_train)
+
+     if isinstance(dates, np.ndarray) and len(dates) >= len(close):
+         dates_rv = np.array(dates[-len(rv):])
+         x_hist = dates_rv[:H0]
+         x_fcst = dates_rv[H0:]
+         x_lbl = "date"
+     else:
+         x_hist = np.arange(H0)
+         x_fcst = np.arange(H0, H0 + H)
+         x_lbl = "time index"
+
+     plt.plot(x_hist, rv_train, label="realized vol (history)")
+     plt.plot(x_fcst, rv_test, label="realized vol (actual last 30)")
+     plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (raw path)")
+     if use_calibration:
+         plt.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
+
+     plt.title(f"{ticker.upper()} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})")
+     plt.xlabel(x_lbl); plt.ylabel("realized volatility")
+     plt.legend(loc="best"); plt.tight_layout()
+
+     # 6) Per-day table
+     last_dates = x_fcst
+     df_days = pd.DataFrame({
+         "date": last_dates,
+         "actual_vol": rv_test,
+         "forecast_raw": path_pred,
+     })
+     if use_calibration:
+         df_days["forecast_calibrated"] = path_pred_cal
+         df_days["abs_pct_error_raw_%"] = np.abs((path_pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
+         df_days["abs_pct_error_cal_%"] = np.abs((path_pred_cal - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
+     else:
+         df_days["abs_pct_error_raw_%"] = np.abs((path_pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
+
+     # 7) JSON + metrics text
+     out = {
+         "ticker": ticker.upper(),
+         "csv_path": csv_path,
+         "config": {
+             "start": start,
+             "interval": interval,
+             "rv_window": RV_WINDOW,
+             "prediction_length": H,
+             "num_samples": NUM_SAMPLES,
+             "annualized": ANNUALIZE,
+             "point_forecast": "single_sample_path",
+         },
+         "metrics_raw": {k: round(v, 4) for k, v in metrics_raw.items()},
+     }
+     metrics_md = f"**RAW** — MAPE {metrics_raw['MAPE']:.2f}% | MPE {metrics_raw['MPE']:.2f}% | RMSE {metrics_raw['RMSE']:.5f}"
+
+     if use_calibration and metrics_cal is not None:
+         out["alpha"] = alpha
+         out["metrics_calibrated"] = {k: round(v, 4) for k, v in metrics_cal.items()}
+         metrics_md += f"\n**CALIBRATED** — MAPE {metrics_cal['MAPE']:.2f}% | MPE {metrics_cal['MPE']:.2f}% | RMSE {metrics_cal['RMSE']:.5f}"
+
+     return fig, out, df_days, metrics_md
+
+ # --------------------
+ # UI
+ # --------------------
+ with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
+     gr.Markdown(
+         "### Predict last 30 days of realized volatility for any ticker\n"
+         "- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
+         "- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
+         "- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
+         "- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
+         "- Optional **Bias/Scale Calibration (α)**."
+     )
+     with gr.Row():
+         tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
+     with gr.Row():
+         start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
+         interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
+         calib_in = gr.Checkbox(value=True, label="Apply bias/scale calibration (α)")
+     run_btn = gr.Button("Run", variant="primary")
+
+     plot = gr.Plot(label="Forecast vs Actual (last 30 days)")
+     meta = gr.JSON(label="Run config & metrics")
+     table = gr.Dataframe(label="Per-day comparison", wrap=True)
+     metrics = gr.Markdown(label="Summary")
+
+     run_btn.click(run_for_ticker, inputs=[tickers_in, start_in, interval_in, calib_in],
+                   outputs=[plot, meta, table, metrics])
+
+ if __name__ == "__main__":
+     demo.launch()
+
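+ # Quick smoke test (illustrative only; assumes pipeline_v2 is importable and
+ # network access for data/weights):
+ #   fig, meta, table_df, md = run_for_ticker("AAPL", "2015-01-01", "1d", True)
+ #   print(meta["metrics_raw"])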
requirements.txt CHANGED
@@ -1,8 +1,7 @@
- autogluon.timeseries==1.4.0
- chronos-forecasting>=2.0.0
+ gradio>=4.0
+ chronos-forecasting>=1.5
  torch>=2.2
- numpy>=1.26
  pandas>=2.0
- gradio>=4.0
+ numpy>=1.26
  matplotlib>=3.8
  yfinance>=0.2.40