Spaces:
Sleeping
Sleeping
Gil Stetler
commited on
Commit
·
89cf40b
1
Parent(s):
682cd17
fix
Browse files- app.py +20 -14
- pipeline_v2.py +38 -23
app.py
CHANGED
|
@@ -524,7 +524,7 @@ matplotlib.use("Agg")
|
|
| 524 |
import matplotlib.pyplot as plt
|
| 525 |
from chronos import ChronosPipeline
|
| 526 |
|
| 527 |
-
#
|
| 528 |
import pipeline_v2 as pipe2 # update_ticker_csv(...)
|
| 529 |
|
| 530 |
# --------------------
|
|
@@ -568,7 +568,7 @@ def _extract_dates(df: pd.DataFrame):
|
|
| 568 |
# If index is DatetimeIndex, use it
|
| 569 |
if isinstance(df.index, pd.DatetimeIndex):
|
| 570 |
return df.index.to_numpy()
|
| 571 |
-
# Else
|
| 572 |
mapping = {c.lower(): c for c in df.columns}
|
| 573 |
for name in ["date", "time", "timestamp"]:
|
| 574 |
if name in mapping:
|
|
@@ -607,23 +607,28 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
|
|
| 607 |
start: YYYY-MM-DD
|
| 608 |
interval: '1d', '1wk', '1mo'
|
| 609 |
"""
|
| 610 |
-
# Parse first ticker
|
| 611 |
tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
|
| 612 |
if not tick_list:
|
| 613 |
-
raise gr.Error("Please enter at least one ticker, e.g. AAPL")
|
| 614 |
-
|
|
|
|
| 615 |
|
| 616 |
# 1) Fetch/update CSV via pipeline
|
| 617 |
try:
|
| 618 |
csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
|
| 619 |
except Exception as e:
|
| 620 |
-
raise gr.Error(
|
|
|
|
|
|
|
| 621 |
|
| 622 |
# 2) Load CSV and build realized vol
|
| 623 |
try:
|
| 624 |
-
df = pd.read_csv(csv_path, index_col=0, parse_dates=
|
|
|
|
|
|
|
|
|
|
| 625 |
except Exception:
|
| 626 |
-
# Fallback if index parsing fails
|
| 627 |
df = pd.read_csv(csv_path)
|
| 628 |
|
| 629 |
dates = _extract_dates(df)
|
|
@@ -662,7 +667,6 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
|
|
| 662 |
fig = plt.figure(figsize=(10, 4))
|
| 663 |
H0 = len(rv_train)
|
| 664 |
|
| 665 |
-
# Align dates to rv length if we have real dates
|
| 666 |
if isinstance(dates, np.ndarray) and len(dates) >= len(close):
|
| 667 |
dates_rv = np.array(dates[-len(rv):])
|
| 668 |
x_hist = dates_rv[:H0]
|
|
@@ -727,13 +731,14 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
|
|
| 727 |
with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
|
| 728 |
gr.Markdown(
|
| 729 |
"### Predict last 30 days of realized volatility for any ticker\n"
|
| 730 |
-
"-
|
| 731 |
-
"-
|
| 732 |
-
"-
|
| 733 |
-
"-
|
|
|
|
| 734 |
)
|
| 735 |
with gr.Row():
|
| 736 |
-
tickers_in = gr.Textbox(value="AAPL", label="
|
| 737 |
with gr.Row():
|
| 738 |
start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
|
| 739 |
interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
|
|
@@ -750,3 +755,4 @@ with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as d
|
|
| 750 |
|
| 751 |
if __name__ == "__main__":
|
| 752 |
demo.launch()
|
|
|
|
|
|
| 524 |
import matplotlib.pyplot as plt
|
| 525 |
from chronos import ChronosPipeline
|
| 526 |
|
| 527 |
+
# our data pipeline
|
| 528 |
import pipeline_v2 as pipe2 # update_ticker_csv(...)
|
| 529 |
|
| 530 |
# --------------------
|
|
|
|
| 568 |
# If index is DatetimeIndex, use it
|
| 569 |
if isinstance(df.index, pd.DatetimeIndex):
|
| 570 |
return df.index.to_numpy()
|
| 571 |
+
# Else try a date-like column
|
| 572 |
mapping = {c.lower(): c for c in df.columns}
|
| 573 |
for name in ["date", "time", "timestamp"]:
|
| 574 |
if name in mapping:
|
|
|
|
| 607 |
start: YYYY-MM-DD
|
| 608 |
interval: '1d', '1wk', '1mo'
|
| 609 |
"""
|
| 610 |
+
# Parse first ticker (keep dots and dashes!)
|
| 611 |
tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
|
| 612 |
if not tick_list:
|
| 613 |
+
raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
|
| 614 |
+
|
| 615 |
+
ticker = tick_list[0] # keep original form; pipeline handles uppercasing
|
| 616 |
|
| 617 |
# 1) Fetch/update CSV via pipeline
|
| 618 |
try:
|
| 619 |
csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
|
| 620 |
except Exception as e:
|
| 621 |
+
raise gr.Error(
|
| 622 |
+
f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
|
| 623 |
+
)
|
| 624 |
|
| 625 |
# 2) Load CSV and build realized vol
|
| 626 |
try:
|
| 627 |
+
df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
|
| 628 |
+
if not isinstance(df.index, pd.DatetimeIndex):
|
| 629 |
+
# last fallback
|
| 630 |
+
df = pd.read_csv(csv_path)
|
| 631 |
except Exception:
|
|
|
|
| 632 |
df = pd.read_csv(csv_path)
|
| 633 |
|
| 634 |
dates = _extract_dates(df)
|
|
|
|
| 667 |
fig = plt.figure(figsize=(10, 4))
|
| 668 |
H0 = len(rv_train)
|
| 669 |
|
|
|
|
| 670 |
if isinstance(dates, np.ndarray) and len(dates) >= len(close):
|
| 671 |
dates_rv = np.array(dates[-len(rv):])
|
| 672 |
x_hist = dates_rv[:H0]
|
|
|
|
| 731 |
with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
|
| 732 |
gr.Markdown(
|
| 733 |
"### Predict last 30 days of realized volatility for any ticker\n"
|
| 734 |
+
"- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
|
| 735 |
+
"- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
|
| 736 |
+
"- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
|
| 737 |
+
"- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
|
| 738 |
+
"- Optional **Bias/Scale Calibration (α)**."
|
| 739 |
)
|
| 740 |
with gr.Row():
|
| 741 |
+
tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
|
| 742 |
with gr.Row():
|
| 743 |
start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
|
| 744 |
interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
|
|
|
|
| 755 |
|
| 756 |
if __name__ == "__main__":
|
| 757 |
demo.launch()
|
| 758 |
+
|
pipeline_v2.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# pipeline_v2.py
|
| 2 |
import os
|
| 3 |
-
|
| 4 |
import pandas as pd
|
| 5 |
|
| 6 |
try:
|
|
@@ -15,50 +15,65 @@ def _ensure_dir(path: str) -> None:
|
|
| 15 |
os.makedirs(path, exist_ok=True)
|
| 16 |
|
| 17 |
|
| 18 |
-
def
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def update_ticker_csv(
|
| 23 |
ticker: str,
|
| 24 |
start: str = "2015-01-01",
|
| 25 |
interval: str = "1d",
|
| 26 |
-
dst_dir: str = "/mnt/data"
|
| 27 |
) -> str:
|
| 28 |
"""
|
| 29 |
Download OHLCV for `ticker` using yfinance and save as CSV.
|
| 30 |
Returns the CSV file path.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
ticker: e.g. "AAPL"
|
| 34 |
-
start: "YYYY-MM-DD"
|
| 35 |
-
interval: "1d", "1wk", "1mo"
|
| 36 |
-
dst_dir: directory to write CSVs (default: /mnt/data for Spaces)
|
| 37 |
"""
|
| 38 |
_ensure_dir(dst_dir)
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
df = yf.download(
|
| 42 |
-
|
| 43 |
start=start,
|
| 44 |
interval=interval,
|
| 45 |
-
auto_adjust=False,
|
| 46 |
progress=False,
|
| 47 |
threads=True,
|
| 48 |
)
|
| 49 |
|
| 50 |
if df is None or df.empty:
|
| 51 |
-
raise ValueError(
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
# Ensure a clean
|
| 54 |
-
if isinstance(df.index, pd.DatetimeIndex):
|
| 55 |
-
df = df.
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
csv_path = os.path.join(dst_dir, f"{tkr}_{interval}.csv")
|
| 62 |
-
df.to_csv(csv_path)
|
| 63 |
|
|
|
|
|
|
|
| 64 |
return csv_path
|
|
|
|
| 1 |
# pipeline_v2.py
|
| 2 |
import os
|
| 3 |
+
import re
|
| 4 |
import pandas as pd
|
| 5 |
|
| 6 |
try:
|
|
|
|
| 15 |
os.makedirs(path, exist_ok=True)
|
| 16 |
|
| 17 |
|
| 18 |
+
def _ticker_for_query(t: str) -> str:
|
| 19 |
+
"""
|
| 20 |
+
Prepare ticker for yfinance:
|
| 21 |
+
- strip spaces
|
| 22 |
+
- uppercase
|
| 23 |
+
- DO NOT alter '.' or '-' (yfinance relies on them, e.g. NESN.SW, BRK-B)
|
| 24 |
+
"""
|
| 25 |
+
return t.strip().upper()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _ticker_for_filename(t: str) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Prepare a safe filename:
|
| 31 |
+
- replace any char not [A-Za-z0-9] with '_'
|
| 32 |
+
"""
|
| 33 |
+
return re.sub(r"[^A-Za-z0-9]", "_", t)
|
| 34 |
|
| 35 |
|
| 36 |
def update_ticker_csv(
|
| 37 |
ticker: str,
|
| 38 |
start: str = "2015-01-01",
|
| 39 |
interval: str = "1d",
|
| 40 |
+
dst_dir: str = "/mnt/data"
|
| 41 |
) -> str:
|
| 42 |
"""
|
| 43 |
Download OHLCV for `ticker` using yfinance and save as CSV.
|
| 44 |
Returns the CSV file path.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
"""
|
| 46 |
_ensure_dir(dst_dir)
|
| 47 |
+
|
| 48 |
+
tkr_query = _ticker_for_query(ticker)
|
| 49 |
+
tkr_file = _ticker_for_filename(tkr_query)
|
| 50 |
|
| 51 |
df = yf.download(
|
| 52 |
+
tkr_query,
|
| 53 |
start=start,
|
| 54 |
interval=interval,
|
| 55 |
+
auto_adjust=False,
|
| 56 |
progress=False,
|
| 57 |
threads=True,
|
| 58 |
)
|
| 59 |
|
| 60 |
if df is None or df.empty:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
f"No data returned for ticker '{tkr_query}' (start={start}, interval={interval}). "
|
| 63 |
+
"Check the symbol and exchange suffix (e.g., NESN.SW, BMW.DE, VOD.L)."
|
| 64 |
+
)
|
| 65 |
|
| 66 |
+
# Ensure a clean Date index
|
| 67 |
+
if not isinstance(df.index, pd.DatetimeIndex):
|
| 68 |
+
df = df.reset_index()
|
| 69 |
+
if "Date" in df.columns:
|
| 70 |
+
df = df.set_index("Date")
|
| 71 |
+
else:
|
| 72 |
+
df.columns = ["Date"] + list(df.columns[1:])
|
| 73 |
+
df = df.set_index("Date")
|
| 74 |
|
| 75 |
+
df.index.name = "Date"
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
csv_path = os.path.join(dst_dir, f"{tkr_file}_{interval}.csv")
|
| 78 |
+
df.to_csv(csv_path)
|
| 79 |
return csv_path
|