Gil Stetler committed on
Commit
89cf40b
·
1 Parent(s): 682cd17
Files changed (2)
  1. app.py +20 -14
  2. pipeline_v2.py +38 -23
app.py CHANGED
@@ -524,7 +524,7 @@ matplotlib.use("Agg")
524
  import matplotlib.pyplot as plt
525
  from chronos import ChronosPipeline
526
 
527
- # --- our data pipeline ---
528
  import pipeline_v2 as pipe2 # update_ticker_csv(...)
529
 
530
  # --------------------
@@ -568,7 +568,7 @@ def _extract_dates(df: pd.DataFrame):
568
  # If index is DatetimeIndex, use it
569
  if isinstance(df.index, pd.DatetimeIndex):
570
  return df.index.to_numpy()
571
- # Else look for a date-like column
572
  mapping = {c.lower(): c for c in df.columns}
573
  for name in ["date", "time", "timestamp"]:
574
  if name in mapping:
@@ -607,23 +607,28 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
607
  start: YYYY-MM-DD
608
  interval: '1d', '1wk', '1mo'
609
  """
610
- # Parse first ticker
611
  tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
612
  if not tick_list:
613
- raise gr.Error("Please enter at least one ticker, e.g. AAPL")
614
- ticker = tick_list[0]
 
615
 
616
  # 1) Fetch/update CSV via pipeline
617
  try:
618
  csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
619
  except Exception as e:
620
- raise gr.Error(f"Data fetch failed for '{ticker}': {e}")
 
 
621
 
622
  # 2) Load CSV and build realized vol
623
  try:
624
- df = pd.read_csv(csv_path, index_col=0, parse_dates=[0])
 
 
 
625
  except Exception:
626
- # Fallback if index parsing fails
627
  df = pd.read_csv(csv_path)
628
 
629
  dates = _extract_dates(df)
@@ -662,7 +667,6 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
662
  fig = plt.figure(figsize=(10, 4))
663
  H0 = len(rv_train)
664
 
665
- # Align dates to rv length if we have real dates
666
  if isinstance(dates, np.ndarray) and len(dates) >= len(close):
667
  dates_rv = np.array(dates[-len(rv):])
668
  x_hist = dates_rv[:H0]
@@ -727,13 +731,14 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: boo
727
  with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
728
  gr.Markdown(
729
  "### Predict last 30 days of realized volatility for any ticker\n"
730
- "- Fetches data via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
731
- "- Forecast uses **Chronos-T5-Large** (single path, no mean/median).\n"
732
- "- Compares day-by-day to actual RV and reports **MAPE/MPE/RMSE**.\n"
733
- "- Optional **Bias/Scale Calibration (α)** to remove systematic bias."
 
734
  )
735
  with gr.Row():
736
- tickers_in = gr.Textbox(value="AAPL", label="Tickers (comma-separated; first is evaluated)")
737
  with gr.Row():
738
  start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
739
  interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
@@ -750,3 +755,4 @@ with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as d
750
 
751
  if __name__ == "__main__":
752
  demo.launch()
 
 
524
  import matplotlib.pyplot as plt
525
  from chronos import ChronosPipeline
526
 
527
+ # our data pipeline
528
  import pipeline_v2 as pipe2 # update_ticker_csv(...)
529
 
530
  # --------------------
 
568
  # If index is DatetimeIndex, use it
569
  if isinstance(df.index, pd.DatetimeIndex):
570
  return df.index.to_numpy()
571
+ # Else try a date-like column
572
  mapping = {c.lower(): c for c in df.columns}
573
  for name in ["date", "time", "timestamp"]:
574
  if name in mapping:
 
607
  start: YYYY-MM-DD
608
  interval: '1d', '1wk', '1mo'
609
  """
610
+ # Parse first ticker (keep dots and dashes!)
611
  tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
612
  if not tick_list:
613
+ raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
614
+
615
+ ticker = tick_list[0] # keep original form; pipeline handles uppercasing
616
 
617
  # 1) Fetch/update CSV via pipeline
618
  try:
619
  csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
620
  except Exception as e:
621
+ raise gr.Error(
622
+ f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
623
+ )
624
 
625
  # 2) Load CSV and build realized vol
626
  try:
627
+ df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
628
+ if not isinstance(df.index, pd.DatetimeIndex):
629
+ # last fallback
630
+ df = pd.read_csv(csv_path)
631
  except Exception:
 
632
  df = pd.read_csv(csv_path)
633
 
634
  dates = _extract_dates(df)
 
667
  fig = plt.figure(figsize=(10, 4))
668
  H0 = len(rv_train)
669
 
 
670
  if isinstance(dates, np.ndarray) and len(dates) >= len(close):
671
  dates_rv = np.array(dates[-len(rv):])
672
  x_hist = dates_rv[:H0]
 
731
  with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
732
  gr.Markdown(
733
  "### Predict last 30 days of realized volatility for any ticker\n"
734
+ "- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
735
+ "- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
736
+ "- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
737
+ "- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
738
+ "- Optional **Bias/Scale Calibration (α)**."
739
  )
740
  with gr.Row():
741
+ tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
742
  with gr.Row():
743
  start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
744
  interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
 
755
 
756
  if __name__ == "__main__":
757
  demo.launch()
758
+
pipeline_v2.py CHANGED
@@ -1,6 +1,6 @@
1
  # pipeline_v2.py
2
  import os
3
- from typing import Tuple
4
  import pandas as pd
5
 
6
  try:
@@ -15,50 +15,65 @@ def _ensure_dir(path: str) -> None:
15
  os.makedirs(path, exist_ok=True)
16
 
17
 
18
- def _sanitize_ticker(t: str) -> str:
19
- return t.strip().upper().replace(" ", "").replace("/", "-").replace(".", "-")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def update_ticker_csv(
23
  ticker: str,
24
  start: str = "2015-01-01",
25
  interval: str = "1d",
26
- dst_dir: str = "/mnt/data" # HF Spaces writeable path
27
  ) -> str:
28
  """
29
  Download OHLCV for `ticker` using yfinance and save as CSV.
30
  Returns the CSV file path.
31
-
32
- Args:
33
- ticker: e.g. "AAPL"
34
- start: "YYYY-MM-DD"
35
- interval: "1d", "1wk", "1mo"
36
- dst_dir: directory to write CSVs (default: /mnt/data for Spaces)
37
  """
38
  _ensure_dir(dst_dir)
39
- tkr = _sanitize_ticker(ticker)
 
 
40
 
41
  df = yf.download(
42
- tkr,
43
  start=start,
44
  interval=interval,
45
- auto_adjust=False, # keep explicit Adj Close; we’ll pick Close / Adj Close later
46
  progress=False,
47
  threads=True,
48
  )
49
 
50
  if df is None or df.empty:
51
- raise ValueError(f"No data returned for ticker '{tkr}' with start={start}, interval={interval}.")
 
 
 
52
 
53
- # Ensure a clean, single-index Date column
54
- if isinstance(df.index, pd.DatetimeIndex):
55
- df = df.copy()
56
- df.index.name = "Date"
57
- else:
58
- df = df.reset_index().rename(columns={df.columns[0]: "Date"}).set_index("Date")
 
 
59
 
60
- # Save
61
- csv_path = os.path.join(dst_dir, f"{tkr}_{interval}.csv")
62
- df.to_csv(csv_path)
63
 
 
 
64
  return csv_path
 
1
  # pipeline_v2.py
2
  import os
3
+ import re
4
  import pandas as pd
5
 
6
  try:
 
15
  os.makedirs(path, exist_ok=True)
16
 
17
 
18
+ def _ticker_for_query(t: str) -> str:
19
+ """
20
+ Prepare ticker for yfinance:
21
+ - strip spaces
22
+ - uppercase
23
+ - DO NOT alter '.' or '-' (yfinance relies on them, e.g. NESN.SW, BRK-B)
24
+ """
25
+ return t.strip().upper()
26
+
27
+
28
+ def _ticker_for_filename(t: str) -> str:
29
+ """
30
+ Prepare a safe filename:
31
+ - replace any char not [A-Za-z0-9] with '_'
32
+ """
33
+ return re.sub(r"[^A-Za-z0-9]", "_", t)
34
 
35
 
36
  def update_ticker_csv(
37
  ticker: str,
38
  start: str = "2015-01-01",
39
  interval: str = "1d",
40
+ dst_dir: str = "/mnt/data"
41
  ) -> str:
42
  """
43
  Download OHLCV for `ticker` using yfinance and save as CSV.
44
  Returns the CSV file path.
 
 
 
 
 
 
45
  """
46
  _ensure_dir(dst_dir)
47
+
48
+ tkr_query = _ticker_for_query(ticker)
49
+ tkr_file = _ticker_for_filename(tkr_query)
50
 
51
  df = yf.download(
52
+ tkr_query,
53
  start=start,
54
  interval=interval,
55
+ auto_adjust=False,
56
  progress=False,
57
  threads=True,
58
  )
59
 
60
  if df is None or df.empty:
61
+ raise ValueError(
62
+ f"No data returned for ticker '{tkr_query}' (start={start}, interval={interval}). "
63
+ "Check the symbol and exchange suffix (e.g., NESN.SW, BMW.DE, VOD.L)."
64
+ )
65
 
66
+ # Ensure a clean Date index
67
+ if not isinstance(df.index, pd.DatetimeIndex):
68
+ df = df.reset_index()
69
+ if "Date" in df.columns:
70
+ df = df.set_index("Date")
71
+ else:
72
+ df.columns = ["Date"] + list(df.columns[1:])
73
+ df = df.set_index("Date")
74
 
75
+ df.index.name = "Date"
 
 
76
 
77
+ csv_path = os.path.join(dst_dir, f"{tkr_file}_{interval}.csv")
78
+ df.to_csv(csv_path)
79
  return csv_path