Spaces:
Running
Running
Gil Stetler
commited on
Commit
·
b83a684
1
Parent(s):
218f038
rollback
Browse files- app.py +268 -67
- requirements.txt +3 -4
app.py
CHANGED
|
@@ -1,69 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import os, random
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
import gradio as gr
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
+
from chronos import ChronosPipeline
|
| 12 |
+
|
| 13 |
+
# our data pipeline
|
| 14 |
+
import pipeline_v2 as pipe2 # update_ticker_csv(...)
|
| 15 |
+
|
| 16 |
+
# --------------------
|
| 17 |
+
# Config
|
| 18 |
+
# --------------------
|
| 19 |
+
MODEL_ID = "amazon/chronos-t5-large"
|
| 20 |
+
PREDICTION_LENGTH = 30 # forecast last 30 days
|
| 21 |
+
NUM_SAMPLES = 1 # single path -> day-by-day point prediction
|
| 22 |
+
RV_WINDOW = 20 # realized vol window (trading days)
|
| 23 |
+
ANNUALIZE = True # annualize by sqrt(252)
|
| 24 |
+
EPS = 1e-8
|
| 25 |
+
|
| 26 |
+
# --------------------
|
| 27 |
+
# Model load (once)
|
| 28 |
+
# --------------------
|
| 29 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 30 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 31 |
+
|
| 32 |
+
pipe = ChronosPipeline.from_pretrained(
|
| 33 |
+
MODEL_ID,
|
| 34 |
+
device_map="auto",
|
| 35 |
+
torch_dtype=dtype,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# --------------------
|
| 39 |
+
# Helpers
|
| 40 |
+
# --------------------
|
| 41 |
+
def _extract_close(df: pd.DataFrame) -> pd.Series:
|
| 42 |
+
"""
|
| 43 |
+
Robustly extract the close or adjusted close price as a numeric Series.
|
| 44 |
+
Handles both flat and MultiIndex columns (yfinance often returns MultiIndex
|
| 45 |
+
when multiple tickers or suffixes are used).
|
| 46 |
+
"""
|
| 47 |
+
# --- Case 1: MultiIndex (e.g., ('Adj Close', 'BMW.DE')) ---
|
| 48 |
+
if isinstance(df.columns, pd.MultiIndex):
|
| 49 |
+
# Try Adj Close first
|
| 50 |
+
for name in ["Adj Close", "Adj_Close", "adj close", "adj_close"]:
|
| 51 |
+
if name in df.columns.get_level_values(0):
|
| 52 |
+
sub = df.xs(name, axis=1, level=0)
|
| 53 |
+
# If multiple tickers, pick first column
|
| 54 |
+
if sub.shape[1] > 1:
|
| 55 |
+
sub = sub.iloc[:, 0]
|
| 56 |
+
return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
|
| 57 |
+
# Fallback to Close
|
| 58 |
+
for name in ["Close", "close", "Price", "price"]:
|
| 59 |
+
if name in df.columns.get_level_values(0):
|
| 60 |
+
sub = df.xs(name, axis=1, level=0)
|
| 61 |
+
if sub.shape[1] > 1:
|
| 62 |
+
sub = sub.iloc[:, 0]
|
| 63 |
+
return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
|
| 64 |
+
|
| 65 |
+
# --- Case 2: Flat columns ---
|
| 66 |
+
mapping = {c.lower(): c for c in df.columns}
|
| 67 |
+
for name in ["adj close", "adj_close", "close", "price"]:
|
| 68 |
+
if name in mapping:
|
| 69 |
+
col = df[mapping[name]]
|
| 70 |
+
return pd.to_numeric(col, errors="coerce").dropna()
|
| 71 |
+
|
| 72 |
+
# --- Fallback: last numeric column ---
|
| 73 |
+
num_cols = df.select_dtypes(include=[np.number]).columns
|
| 74 |
+
if len(num_cols) == 0:
|
| 75 |
+
raise gr.Error("No numeric price column found in downloaded data.")
|
| 76 |
+
return pd.Series(df[num_cols[-1]]).astype(float)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _extract_dates(df: pd.DataFrame):
|
| 80 |
+
# If index is DatetimeIndex, use it
|
| 81 |
+
if isinstance(df.index, pd.DatetimeIndex):
|
| 82 |
+
return df.index.to_numpy()
|
| 83 |
+
# Else try a date-like column
|
| 84 |
+
mapping = {c.lower(): c for c in df.columns}
|
| 85 |
+
for name in ["date", "time", "timestamp"]:
|
| 86 |
+
if name in mapping:
|
| 87 |
+
try:
|
| 88 |
+
return pd.to_datetime(df[mapping[name]]).to_numpy()
|
| 89 |
+
except Exception:
|
| 90 |
+
pass
|
| 91 |
+
# Fallback to a simple range
|
| 92 |
+
return np.arange(len(df))
|
| 93 |
+
|
| 94 |
+
def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
|
| 95 |
+
r = np.log(close).diff().dropna()
|
| 96 |
+
rv = r.rolling(window, min_periods=window).std()
|
| 97 |
+
if annualize:
|
| 98 |
+
rv = rv * np.sqrt(252.0)
|
| 99 |
+
return rv.dropna().reset_index(drop=True)
|
| 100 |
+
|
| 101 |
+
def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, np.ndarray]:
|
| 102 |
+
alpha = float(np.sum(y_true * y_pred) / (np.sum(y_pred**2) + EPS))
|
| 103 |
+
return alpha, alpha * y_pred
|
| 104 |
+
|
| 105 |
+
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
|
| 106 |
+
err = y_pred - y_true
|
| 107 |
+
denom = np.maximum(EPS, np.abs(y_true))
|
| 108 |
+
mape = float((np.abs(err) / denom).mean() * 100)
|
| 109 |
+
mpe = float((err / np.maximum(EPS, y_true)).mean() * 100)
|
| 110 |
+
rmse = float(np.sqrt(np.mean(err**2)))
|
| 111 |
+
return {"MAPE": mape, "MPE": mpe, "RMSE": rmse}
|
| 112 |
+
|
| 113 |
+
# --------------------
|
| 114 |
+
# Core routine
|
| 115 |
+
# --------------------
|
| 116 |
+
def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
|
| 117 |
+
"""
|
| 118 |
+
tickers: comma/space separated; we use the FIRST for plotting/eval.
|
| 119 |
+
start: YYYY-MM-DD
|
| 120 |
+
interval: '1d', '1wk', '1mo'
|
| 121 |
+
"""
|
| 122 |
+
# Parse first ticker (keep dots and dashes!)
|
| 123 |
+
tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
|
| 124 |
+
if not tick_list:
|
| 125 |
+
raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
|
| 126 |
+
|
| 127 |
+
ticker = tick_list[0] # keep original form; pipeline handles uppercasing
|
| 128 |
+
|
| 129 |
+
# 1) Fetch/update CSV via pipeline
|
| 130 |
+
try:
|
| 131 |
+
csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
raise gr.Error(
|
| 134 |
+
f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# 2) Load CSV and build realized vol
|
| 138 |
+
try:
|
| 139 |
+
df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
|
| 140 |
+
if not isinstance(df.index, pd.DatetimeIndex):
|
| 141 |
+
# last fallback
|
| 142 |
+
df = pd.read_csv(csv_path)
|
| 143 |
+
except Exception:
|
| 144 |
+
df = pd.read_csv(csv_path)
|
| 145 |
+
|
| 146 |
+
dates = _extract_dates(df)
|
| 147 |
+
close = _extract_close(df)
|
| 148 |
+
|
| 149 |
+
rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE).to_numpy()
|
| 150 |
+
n = len(rv); H = PREDICTION_LENGTH
|
| 151 |
+
if n <= H + 5:
|
| 152 |
+
raise gr.Error(f"Vol series too short after rolling window. Need > {H+5}, got {n}.")
|
| 153 |
+
|
| 154 |
+
rv_train = rv[: n - H]
|
| 155 |
+
rv_test = rv[n - H :]
|
| 156 |
+
|
| 157 |
+
# 3) Forecast a single sample path (deterministic via seed)
|
| 158 |
+
random.seed(0); np.random.seed(0); torch.manual_seed(0)
|
| 159 |
+
if torch.cuda.is_available():
|
| 160 |
+
torch.cuda.manual_seed_all(0)
|
| 161 |
+
|
| 162 |
+
context = torch.tensor(rv_train, dtype=torch.float32)
|
| 163 |
+
fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES) # [1, 1, H]
|
| 164 |
+
samples = fcst[0].cpu().numpy() # (1, H)
|
| 165 |
+
path_pred = samples[0] # (H,)
|
| 166 |
+
|
| 167 |
+
# 4) Optional bias/scale calibration
|
| 168 |
+
alpha = None
|
| 169 |
+
if use_calibration:
|
| 170 |
+
alpha, path_pred_cal = bias_scale_calibration(rv_test, path_pred)
|
| 171 |
+
metrics_raw = compute_metrics(rv_test, path_pred)
|
| 172 |
+
metrics_cal = compute_metrics(rv_test, path_pred_cal)
|
| 173 |
+
else:
|
| 174 |
+
metrics_raw = compute_metrics(rv_test, path_pred)
|
| 175 |
+
metrics_cal = None
|
| 176 |
+
path_pred_cal = None
|
| 177 |
+
|
| 178 |
+
# 5) Plot
|
| 179 |
+
fig = plt.figure(figsize=(10, 4))
|
| 180 |
+
H0 = len(rv_train)
|
| 181 |
+
|
| 182 |
+
if isinstance(dates, np.ndarray) and len(dates) >= len(close):
|
| 183 |
+
dates_rv = np.array(dates[-len(rv):])
|
| 184 |
+
x_hist = dates_rv[:H0]
|
| 185 |
+
x_fcst = dates_rv[H0:]
|
| 186 |
+
x_lbl = "date"
|
| 187 |
+
else:
|
| 188 |
+
x_hist = np.arange(H0)
|
| 189 |
+
x_fcst = np.arange(H0, H0 + H)
|
| 190 |
+
x_lbl = "time index"
|
| 191 |
+
|
| 192 |
+
plt.plot(x_hist, rv_train, label="realized vol (history)")
|
| 193 |
+
plt.plot(x_fcst, rv_test, label="realized vol (actual last 30)")
|
| 194 |
+
plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (raw path)")
|
| 195 |
+
if use_calibration:
|
| 196 |
+
plt.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
|
| 197 |
+
|
| 198 |
+
plt.title(f"{ticker.upper()} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})")
|
| 199 |
+
plt.xlabel(x_lbl); plt.ylabel("realized volatility")
|
| 200 |
+
plt.legend(loc="best"); plt.tight_layout()
|
| 201 |
+
|
| 202 |
+
# 6) Per-day table
|
| 203 |
+
last_dates = x_fcst
|
| 204 |
+
df_days = pd.DataFrame({
|
| 205 |
+
"date": last_dates,
|
| 206 |
+
"actual_vol": rv_test,
|
| 207 |
+
"forecast_raw": path_pred,
|
| 208 |
+
})
|
| 209 |
+
if use_calibration:
|
| 210 |
+
df_days["forecast_calibrated"] = path_pred_cal
|
| 211 |
+
df_days["abs_pct_error_raw_%"] = np.abs((path_pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
|
| 212 |
+
df_days["abs_pct_error_cal_%"] = np.abs((path_pred_cal - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
|
| 213 |
+
else:
|
| 214 |
+
df_days["abs_pct_error_raw_%"] = np.abs((path_pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
|
| 215 |
+
|
| 216 |
+
# 7) JSON + metrics text
|
| 217 |
+
out = {
|
| 218 |
+
"ticker": ticker.upper(),
|
| 219 |
+
"csv_path": csv_path,
|
| 220 |
+
"config": {
|
| 221 |
+
"start": start,
|
| 222 |
+
"interval": interval,
|
| 223 |
+
"rv_window": RV_WINDOW,
|
| 224 |
+
"prediction_length": H,
|
| 225 |
+
"num_samples": NUM_SAMPLES,
|
| 226 |
+
"annualized": ANNUALIZE,
|
| 227 |
+
"point_forecast": "single_sample_path",
|
| 228 |
+
},
|
| 229 |
+
"metrics_raw": {k: round(v, 4) for k, v in metrics_raw.items()},
|
| 230 |
+
}
|
| 231 |
+
metrics_md = f"**RAW** — MAPE {metrics_raw['MAPE']:.2f}% | MPE {metrics_raw['MPE']:.2f}% | RMSE {metrics_raw['RMSE']:.5f}"
|
| 232 |
+
|
| 233 |
+
if use_calibration and metrics_cal is not None:
|
| 234 |
+
out["alpha"] = alpha
|
| 235 |
+
out["metrics_calibrated"] = {k: round(v, 4) for k, v in metrics_cal.items()}
|
| 236 |
+
metrics_md += f"\n**CALIBRATED** — MAPE {metrics_cal['MAPE']:.2f}% | MPE {metrics_cal['MPE']:.2f}% | RMSE {metrics_cal['RMSE']:.5f}"
|
| 237 |
+
|
| 238 |
+
return fig, out, df_days, metrics_md
|
| 239 |
+
|
| 240 |
+
# --------------------
|
| 241 |
+
# UI
|
| 242 |
+
# --------------------
|
| 243 |
+
with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
|
| 244 |
+
gr.Markdown(
|
| 245 |
+
"### Predict last 30 days of realized volatility for any ticker\n"
|
| 246 |
+
"- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
|
| 247 |
+
"- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
|
| 248 |
+
"- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
|
| 249 |
+
"- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
|
| 250 |
+
"- Optional **Bias/Scale Calibration (α)**."
|
| 251 |
+
)
|
| 252 |
+
with gr.Row():
|
| 253 |
+
tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
|
| 254 |
+
with gr.Row():
|
| 255 |
+
start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
|
| 256 |
+
interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
|
| 257 |
+
calib_in = gr.Checkbox(value=True, label="Apply bias/scale calibration (α)")
|
| 258 |
+
run_btn = gr.Button("Run", variant="primary")
|
| 259 |
+
|
| 260 |
+
plot = gr.Plot(label="Forecast vs Actual (last 30 days)")
|
| 261 |
+
meta = gr.JSON(label="Run config & metrics")
|
| 262 |
+
table = gr.Dataframe(label="Per-day comparison", wrap=True)
|
| 263 |
+
metrics = gr.Markdown(label="Summary")
|
| 264 |
+
|
| 265 |
+
run_btn.click(run_for_ticker, inputs=[tickers_in, start_in, interval_in, calib_in],
|
| 266 |
+
outputs=[plot, meta, table, metrics])
|
| 267 |
+
|
| 268 |
+
if __name__ == "__main__":
|
| 269 |
+
demo.launch()
|
| 270 |
+
|
requirements.txt
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
chronos-forecasting>=
|
| 3 |
torch>=2.2
|
| 4 |
-
numpy>=1.26
|
| 5 |
pandas>=2.0
|
| 6 |
-
|
| 7 |
matplotlib>=3.8
|
| 8 |
yfinance>=0.2.40
|
|
|
|
| 1 |
+
gradio>=4.0
|
| 2 |
+
chronos-forecasting>=1.5
|
| 3 |
torch>=2.2
|
|
|
|
| 4 |
pandas>=2.0
|
| 5 |
+
numpy>=1.26
|
| 6 |
matplotlib>=3.8
|
| 7 |
yfinance>=0.2.40
|