cesparzaf commited on
Commit
a8ce7ec
·
1 Parent(s): 09a0967

App de pronóstico con Chronos-Bolt (UI+API)

Browse files
Files changed (2) hide show
  1. app.py +104 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+
7
+ # 🔽 Import del pipeline de Chronos (lib 'chronos-forecasting')
8
+ from chronos import ChronosPipeline
9
+
10
+ # Modelo recomendado para CPU free tier (rápido y estable)
11
+ MODEL_ID = "amazon/chronos-bolt-base"
12
+
13
+ # Cargar el modelo UNA sola vez
14
+ PIPELINE = ChronosPipeline.from_pretrained(
15
+ MODEL_ID,
16
+ device_map="auto",
17
+ torch_dtype=torch.float32, # en CPU va bien
18
+ )
19
+
20
+ def _prepare_series(df: pd.DataFrame, freq: str | None):
21
+ """
22
+ Espera columnas: date,value
23
+ - Ordena por fecha
24
+ - Infere o aplica frecuencia
25
+ - Interpola huecos
26
+ """
27
+ if "date" not in df.columns or "value" not in df.columns:
28
+ raise gr.Error("El CSV debe tener columnas: date,value")
29
+
30
+ df = df.copy()
31
+ df["date"] = pd.to_datetime(df["date"])
32
+ df = df.sort_values("date")
33
+
34
+ if freq and freq.strip():
35
+ df = df.set_index("date").asfreq(freq).reset_index()
36
+ else:
37
+ inferred = pd.infer_freq(df["date"])
38
+ if inferred is None:
39
+ # fallback: tamaño de paso por mediana en días
40
+ step = max(int((df["date"].diff().median() / pd.Timedelta(days=1)) or 1), 1)
41
+ df = df.set_index("date").asfreq(f"{step}D").reset_index()
42
+ else:
43
+ df = df.set_index("date").asfreq(inferred).reset_index()
44
+
45
+ # Rellenar faltantes
46
+ df["value"] = pd.to_numeric(df["value"], errors="coerce")
47
+ df["value"] = df["value"].interpolate("linear").bfill().ffill()
48
+ return df
49
+
50
+ def forecast_fn(file, horizon: int = 12, freq: str = "MS"):
51
+ if file is None:
52
+ raise gr.Error("Sube un CSV con columnas: date,value")
53
+ df = pd.read_csv(file.name)
54
+ df = _prepare_series(df, freq.strip() or None)
55
+
56
+ # Serie a tensor
57
+ y = torch.tensor(df["value"].values, dtype=torch.float32)
58
+
59
+ # Predicción probabilística (múltiples trayectorias -> cuantiles)
60
+ samples = PIPELINE.predict(y, prediction_length=horizon, num_samples=200) # [1, N, H]
61
+ samples = samples[0].numpy() # [N, H]
62
+ p10, p50, p90 = np.quantile(samples, [0.10, 0.50, 0.90], axis=0)
63
+
64
+ # Fechas futuras
65
+ inferred = pd.infer_freq(df["date"])
66
+ if inferred is None:
67
+ step = max(int((df["date"].diff().median() / pd.Timedelta(days=1)) or 1), 1)
68
+ future_index = pd.date_range(df["date"].iloc[-1], periods=horizon+1, freq=f"{step}D")[1:]
69
+ else:
70
+ future_index = pd.date_range(df["date"].iloc[-1], periods=horizon+1, freq=inferred)[1:]
71
+
72
+ out = pd.DataFrame({
73
+ "date": future_index,
74
+ "p10": np.round(p10, 4),
75
+ "p50": np.round(p50, 4),
76
+ "p90": np.round(p90, 4),
77
+ })
78
+
79
+ # Gráfica
80
+ fig = plt.figure(figsize=(8, 4))
81
+ plt.plot(df["date"], df["value"], label="Histórico")
82
+ plt.plot(out["date"], out["p50"], label="Pronóstico (P50)")
83
+ plt.fill_between(out["date"], out["p10"], out["p90"], alpha=0.3, label="Banda P10–P90")
84
+ plt.title("Pronóstico con Chronos-Bolt (P10 / P50 / P90)")
85
+ plt.xlabel("Fecha"); plt.ylabel("Valor")
86
+ plt.legend()
87
+
88
+ return out, fig
89
+
90
+ with gr.Blocks(title="Pronóstico de Demanda (Chronos-Bolt)") as demo:
91
+ gr.Markdown("## Análisis predictivo de mercado (Hugging Face + Chronos-Bolt)\nSube un CSV con columnas **date,value**. Elige horizonte y frecuencia.")
92
+ with gr.Row():
93
+ file = gr.File(label="CSV: date,value", file_types=[".csv"])
94
+ horizon = gr.Slider(1, 36, value=12, step=1, label="Horizonte (pasos)")
95
+ freq = gr.Dropdown(choices=["", "D", "W", "MS", "M"], value="MS",
96
+ label="Frecuencia (opcional). ''=inferir, MS=mensual")
97
+ btn = gr.Button("Generar pronóstico")
98
+ out_table = gr.Dataframe(label="Tabla de pronóstico")
99
+ out_plot = gr.Plot(label="Gráfica")
100
+ # api_name → te da un endpoint gratis para consumir como API
101
+ btn.click(forecast_fn, inputs=[file, horizon, freq], outputs=[out_table, out_plot], api_name="/forecast")
102
+
103
+ if __name__ == "__main__":
104
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ pandas>=2.0.0
3
+ numpy>=1.24.0
4
+ torch>=2.2.0
5
+ chronos-forecasting>=1.2.0
6
+ matplotlib>=3.7.0