Spaces:
Running
Running
| # app.py | |
| import os, random | |
| from typing import Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from chronos import ChronosPipeline | |
| # our data pipeline | |
| import pipeline_v2 as pipe2 # update_ticker_csv(...) | |
# --------------------
# Config
# --------------------
# Hugging Face Hub id of the Chronos checkpoint loaded at import time.
MODEL_ID: str = "amazon/chronos-t5-large"
# Number of trailing realized-vol points held out and forecast
# (30 trading days when interval="1d"; 30 bars of the chosen interval in general).
PREDICTION_LENGTH: int = 30
# Sample paths requested from Chronos; only the first path is used as the
# step-by-step point prediction.
NUM_SAMPLES: int = 1
# Rolling window length, in bars, for realized volatility.
RV_WINDOW: int = 20
# Whether realized volatility is scaled to an annual figure.
ANNUALIZE: bool = True
# Small constant that keeps divisions away from zero denominators.
EPS: float = 1e-8
# --------------------
# Model load (once)
# --------------------
# The Chronos pipeline is built a single time at import; run_for_ticker reuses
# the module-level `pipe`. Weights are requested in bfloat16 on CUDA and in
# float32 otherwise.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32
# NOTE(review): device_map="auto" presumably leaves placement to the
# transformers/accelerate loader; in the visible code `device` only selects
# the dtype above.
pipe = ChronosPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto",
)
# --------------------
# Helpers
# --------------------
def _to_price_series(values) -> pd.Series:
    """Coerce a column (or the first column of a frame) to a float Series without NaNs."""
    if isinstance(values, pd.DataFrame):
        # Several tickers or duplicate labels: keep the first column. Using
        # iloc (not squeeze) keeps a Series even when there is only one row.
        values = values.iloc[:, 0]
    # Unparsable cells become NaN and are dropped together with missing values.
    return pd.to_numeric(values, errors="coerce").dropna().astype(float)


def _extract_close(df: pd.DataFrame) -> pd.Series:
    """Robustly extract the adjusted close (or close/price) as a numeric Series.

    Handles both flat and MultiIndex columns (yfinance often returns a
    MultiIndex such as ``('Adj Close', 'BMW.DE')``). The preference order is
    adjusted close, then close, then price. If none of those names exist, the
    last numeric column is used. All branches coerce to float and drop
    missing or unparsable values. The original row labels are kept on the
    returned Series.

    Raises:
        gr.Error: if no price-like or numeric column exists.
    """
    # --- Case 1: MultiIndex columns, e.g. ('Adj Close', 'BMW.DE') ---
    if isinstance(df.columns, pd.MultiIndex):
        top_level = set(df.columns.get_level_values(0))
        for name in ["Adj Close", "Adj_Close", "adj close", "adj_close",
                     "Close", "close", "Price", "price"]:
            if name in top_level:
                return _to_price_series(df.xs(name, axis=1, level=0))

    # --- Case 2: flat columns, matched case-insensitively ---
    # str() so non-string labels (ints, tuples) do not raise AttributeError.
    mapping = {str(c).lower(): c for c in df.columns}
    for name in ["adj close", "adj_close", "close", "price"]:
        if name in mapping:
            return _to_price_series(df[mapping[name]])

    # --- Fallback: last numeric column ---
    num_cols = df.select_dtypes(include=[np.number]).columns
    if len(num_cols) == 0:
        raise gr.Error("No numeric price column found in downloaded data.")
    return _to_price_series(df[num_cols[-1]])
def _extract_dates(df: pd.DataFrame):
    """Return one timestamp per row of ``df`` as a NumPy array.

    Preference order:
      1. the index itself, when it is already a DatetimeIndex;
      2. the first column named 'date', 'time' or 'timestamp' (case-insensitive)
         that ``pd.to_datetime`` can parse;
      3. otherwise a positional range ``0 .. len(df) - 1``.
    """
    if isinstance(df.index, pd.DatetimeIndex):
        return df.index.to_numpy()

    # Lower-cased label -> original label (a later duplicate wins).
    by_lower_name = dict((col.lower(), col) for col in df.columns)
    for wanted in ("date", "time", "timestamp"):
        if wanted not in by_lower_name:
            continue
        try:
            return pd.to_datetime(df[by_lower_name[wanted]]).to_numpy()
        except Exception:
            # Column exists but does not parse as dates: try the next name.
            continue

    return np.arange(len(df))
def compute_realized_vol(
    close: pd.Series,
    window: int = 20,
    annualize: bool = True,
    periods_per_year: float = 252.0,
) -> pd.Series:
    """Rolling realized volatility of log returns.

    Args:
        close: price series, assumed strictly positive (TODO confirm upstream);
            ``np.log`` of non-positive values yields -inf/NaN.
        window: number of log returns per rolling window. Windows with fewer
            than ``window`` observations are NaN and get dropped.
        annualize: if True, multiply by ``sqrt(periods_per_year)``.
        periods_per_year: bars per year used for annualization. The default
            of 252 fits daily trading bars; pass e.g. 52 for weekly or 12 for
            monthly data.

    Returns:
        Rolling sample standard deviation (pandas default ddof=1) of log
        returns, with NaNs removed and a fresh 0..k-1 RangeIndex.
    """
    log_ret = np.log(close).diff().dropna()
    rv = log_ret.rolling(window, min_periods=window).std()
    if annualize:
        rv = rv * np.sqrt(float(periods_per_year))
    return rv.dropna().reset_index(drop=True)
def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, np.ndarray]:
    """Fit one multiplicative factor that maps ``y_pred`` onto ``y_true``.

    ``alpha = sum(y_true * y_pred) / (sum(y_pred**2) + EPS)`` is the
    least-squares slope through the origin. Despite the name, there is no
    additive bias term. EPS keeps the division finite for an all-zero
    forecast.

    Returns:
        ``(alpha, alpha * y_pred)``, i.e. the factor as a Python float and the
        rescaled forecast.
    """
    numerator = np.sum(y_true * y_pred)
    denominator = np.sum(y_pred**2) + EPS
    scale = float(numerator / denominator)
    rescaled = scale * y_pred
    return scale, rescaled
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Score a forecast against actual values.

    Errors are ``y_pred - y_true``, so a positive MPE means the forecast is on
    average above the actuals.

    Returns:
        dict with keys ``"MAPE"`` and ``"MPE"`` (both in percent) and
        ``"RMSE"`` (in the units of the inputs).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    # Magnitude floored at EPS to avoid division by zero.
    abs_denom = np.maximum(EPS, np.abs(y_true))
    # MPE needs the sign of y_true. The former max(EPS, y_true) collapsed every
    # negative actual to EPS, which blew MPE up by ~1/EPS.
    signed_denom = np.where(y_true < 0, -abs_denom, abs_denom)
    mape = float((np.abs(err) / abs_denom).mean() * 100)
    mpe = float((err / signed_denom).mean() * 100)
    rmse = float(np.sqrt(np.mean(err**2)))
    return {"MAPE": mape, "MPE": mpe, "RMSE": rmse}
# --------------------
# Core routine
# --------------------
# Bars per year used to annualize realized volatility, keyed by the interval
# strings offered in the UI dropdown. Unknown intervals fall back to daily (252).
_PERIODS_PER_YEAR = {"1d": 252.0, "1wk": 52.0, "1mo": 12.0}


def _parse_first_ticker(tickers: str) -> str:
    """Return the first symbol from a comma/semicolon/pipe/whitespace separated string.

    Dots and dashes are kept (e.g. ``NESN.SW``, ``BRK-B``) and the case is left
    as typed. Raises gr.Error when no symbol is present.
    """
    text = tickers or ""
    for sep in (",", ";", "|"):
        text = text.replace(sep, " ")
    symbols = text.split()
    if not symbols:
        raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
    return symbols[0]


def _load_price_frame(csv_path) -> pd.DataFrame:
    """Read the pipeline CSV, using its first column as a DatetimeIndex when it parses."""
    try:
        df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    except Exception:
        # Best-effort retry as a plain table; if this also fails, the error propagates.
        return pd.read_csv(csv_path)
    if not isinstance(df.index, pd.DatetimeIndex):
        # The first column is not dates: re-read it as an ordinary column so
        # _extract_dates can still look it up by name.
        df = pd.read_csv(csv_path)
    return df


def _forecast_path(context_values: np.ndarray, horizon: int) -> np.ndarray:
    """Draw a seeded Chronos forecast and return its first sample path, shape (horizon,)."""
    # Seed every RNG the sampler might use so the drawn path is reproducible.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
    context = torch.tensor(context_values, dtype=torch.float32)
    # Chronos documents the output as (num_series, num_samples, horizon); a
    # 1-D context is a single series. .float() guards against a reduced-precision output.
    fcst = pipe.predict(context, prediction_length=horizon, num_samples=NUM_SAMPLES)
    return fcst[0].float().cpu().numpy()[0]


def _plot_forecast(x_hist, x_fcst, x_lbl, rv_train, rv_test, path_pred, path_pred_cal, alpha, title):
    """Build the history/actual/forecast figure without touching pyplot's global state."""
    # A bare Figure is not registered with pyplot, so it is garbage-collected
    # after Gradio renders it (pyplot figures opened per request would pile up).
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(x_hist, rv_train, label="realized vol (history)")
    ax.plot(x_fcst, rv_test, label=f"realized vol (actual last {len(rv_test)})")
    ax.plot(x_fcst, path_pred, linestyle="--", label="forecast (raw path)")
    if path_pred_cal is not None:
        ax.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
    ax.set_title(title)
    ax.set_xlabel(x_lbl)
    ax.set_ylabel("realized volatility")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig


def _per_day_table(x_fcst, rv_test, path_pred, path_pred_cal) -> pd.DataFrame:
    """Assemble the per-step actual/forecast/absolute-percent-error table."""
    def abs_pct_err(pred):
        return np.abs((pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100

    df_days = pd.DataFrame({
        "date": x_fcst,
        "actual_vol": rv_test,
        "forecast_raw": path_pred,
    })
    if path_pred_cal is not None:
        df_days["forecast_calibrated"] = path_pred_cal
    df_days["abs_pct_error_raw_%"] = abs_pct_err(path_pred)
    if path_pred_cal is not None:
        df_days["abs_pct_error_cal_%"] = abs_pct_err(path_pred_cal)
    return df_days


def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
    """Forecast the last PREDICTION_LENGTH realized-vol points for a ticker and score them.

    Args:
        tickers: one or more symbols separated by commas, semicolons, pipes or
            whitespace. Only the FIRST is used for plotting and evaluation.
        start: first date to fetch, ``YYYY-MM-DD``.
        interval: bar size, one of ``'1d'``, ``'1wk'``, ``'1mo'``.
        use_calibration: also report a scale-calibrated forecast.

    Returns:
        ``(figure, json_dict, per_step_dataframe, markdown_summary)``, in the
        order of the four Gradio output components.

    Raises:
        gr.Error: no ticker given, data fetch failed, no numeric price column,
            or the realized-vol series is too short for the holdout.
    """
    # Keep the original form; per the original note, the pipeline handles upper-casing.
    ticker = _parse_first_ticker(tickers)

    # 1) Fetch/update the CSV via the project pipeline.
    try:
        csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
    except Exception as e:
        raise gr.Error(
            f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
        ) from e

    # 2) Load prices and build the realized-vol series.
    df = _load_price_frame(csv_path)
    dates = np.asarray(_extract_dates(df))
    close = _extract_close(df)
    # Annualize with the requested bar frequency: sqrt(252) only fits daily bars.
    periods_per_year = _PERIODS_PER_YEAR.get(interval, 252.0)
    rv = compute_realized_vol(close, window=RV_WINDOW, annualize=False).to_numpy()
    if ANNUALIZE:
        rv = rv * np.sqrt(periods_per_year)

    n, H = len(rv), PREDICTION_LENGTH
    if n <= H + 5:
        raise gr.Error(f"Vol series too short after rolling window. Need > {H+5}, got {n}.")
    rv_train, rv_test = rv[: n - H], rv[n - H :]

    # 3) Forecast a single sample path (deterministic via seed).
    path_pred = _forecast_path(rv_train, H)

    # 4) Optional scale calibration.
    metrics_raw = compute_metrics(rv_test, path_pred)
    alpha = None
    path_pred_cal = None
    metrics_cal = None
    if use_calibration:
        # NOTE(review): alpha is fitted on rv_test and then scored on rv_test,
        # so the calibrated metrics are in-sample (optimistic), not a holdout score.
        alpha, path_pred_cal = bias_scale_calibration(rv_test, path_pred)
        metrics_cal = compute_metrics(rv_test, path_pred_cal)

    # 5) Plot. The last len(rv) dates are paired with the rv points; this
    #    assumes no close rows were dropped mid-series as NaN (TODO confirm).
    H0 = len(rv_train)
    # An integer array means _extract_dates fell back to a positional range.
    has_real_dates = len(dates) >= len(close) and not np.issubdtype(dates.dtype, np.integer)
    if has_real_dates:
        dates_rv = dates[-n:]
        x_hist, x_fcst, x_lbl = dates_rv[:H0], dates_rv[H0:], "date"
    else:
        x_hist, x_fcst, x_lbl = np.arange(H0), np.arange(H0, H0 + H), "time index"
    title = f"{ticker.upper()} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})"
    fig = _plot_forecast(x_hist, x_fcst, x_lbl, rv_train, rv_test, path_pred, path_pred_cal, alpha, title)

    # 6) Per-day table.
    df_days = _per_day_table(x_fcst, rv_test, path_pred, path_pred_cal)

    # 7) JSON + metrics text.
    out = {
        "ticker": ticker.upper(),
        # str() so gr.JSON can serialize it even if the pipeline returns a Path.
        "csv_path": str(csv_path),
        "config": {
            "start": start,
            "interval": interval,
            "rv_window": RV_WINDOW,
            "prediction_length": H,
            "num_samples": NUM_SAMPLES,
            "annualized": ANNUALIZE,
            "periods_per_year": periods_per_year,
            "point_forecast": "single_sample_path",
        },
        "metrics_raw": {k: round(v, 4) for k, v in metrics_raw.items()},
    }
    metrics_md = f"**RAW** — MAPE {metrics_raw['MAPE']:.2f}% | MPE {metrics_raw['MPE']:.2f}% | RMSE {metrics_raw['RMSE']:.5f}"
    if use_calibration and metrics_cal is not None:
        out["alpha"] = alpha
        out["metrics_calibrated"] = {k: round(v, 4) for k, v in metrics_cal.items()}
        metrics_md += f"\n**CALIBRATED** — MAPE {metrics_cal['MAPE']:.2f}% | MPE {metrics_cal['MPE']:.2f}% | RMSE {metrics_cal['RMSE']:.5f}"
    return fig, out, df_days, metrics_md
# --------------------
# UI
# --------------------
# Gradio interface. Components appear in the order they are created inside the
# Blocks/Row contexts, so statement order here defines the page layout.
# NOTE(review): the source's indentation was lost; which components sit inside
# the second gr.Row (calib_in, run_btn) was inferred — confirm against the Space.
with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
    # Header text shown above the inputs.
    gr.Markdown(
        "### Predict last 30 days of realized volatility for any ticker\n"
        "- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
        "- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
        "- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
        "- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
        "- Optional **Bias/Scale Calibration (α)**."
    )
    # Input row 1: ticker text (run_for_ticker only uses the first symbol).
    with gr.Row():
        tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
    # Input row 2: start date, bar interval and calibration toggle.
    with gr.Row():
        start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
        interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
        calib_in = gr.Checkbox(value=True, label="Apply bias/scale calibration (α)")
    run_btn = gr.Button("Run", variant="primary")
    # Output components. Their order must match the 4-tuple returned by
    # run_for_ticker: (figure, JSON dict, per-day DataFrame, markdown summary).
    plot = gr.Plot(label="Forecast vs Actual (last 30 days)")
    meta = gr.JSON(label="Run config & metrics")
    table = gr.Dataframe(label="Per-day comparison", wrap=True)
    metrics = gr.Markdown(label="Summary")
    # Inputs map positionally onto run_for_ticker(tickers, start, interval, use_calibration).
    run_btn.click(run_for_ticker, inputs=[tickers_in, start_in, interval_in, calib_in],
                  outputs=[plot, meta, table, metrics])
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()