shagatoo commited on
Commit
1450fb5
Β·
verified Β·
1 Parent(s): 451a054

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +351 -0
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import yfinance as yf
5
+ import joblib
6
+ from tensorflow.keras.models import load_model
7
+ import plotly.graph_objects as go
8
+ from datetime import datetime, timedelta
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
+
12
+ class StockPredictorApp:
13
+ def __init__(self, arima_path='arima_model.pkl',
14
+ lstm_path='lstm_model.h5',
15
+ scaler_path='scaler.pkl'):
16
+ """
17
+ Initialize the stock predictor with pre-trained models
18
+ """
19
+ try:
20
+ # Load models
21
+ self.arima_model = joblib.load(arima_path)
22
+ self.lstm_model = load_model(lstm_path)
23
+ self.scaler = joblib.load(scaler_path)
24
+ self.lookback = 60 # Default lookback period for LSTM
25
+ print("Models loaded successfully!")
26
+ except Exception as e:
27
+ print(f"Error loading models: {e}")
28
+ self.arima_model = None
29
+ self.lstm_model = None
30
+ self.scaler = None
31
+
32
+ def fetch_stock_data(self, ticker, days_back=365):
33
+ """
34
+ Fetch recent stock data for prediction
35
+ """
36
+ try:
37
+ end_date = datetime.now()
38
+ start_date = end_date - timedelta(days=days_back)
39
+
40
+ # Download stock data
41
+ stock_data = yf.download(ticker,
42
+ start=start_date.strftime('%Y-%m-%d'),
43
+ end=end_date.strftime('%Y-%m-%d'),
44
+ progress=False)
45
+
46
+ if stock_data.empty:
47
+ return None, "No data found for this ticker"
48
+
49
+ # Extract closing prices
50
+ prices = stock_data[['Close']].copy()
51
+ prices.columns = ['price']
52
+
53
+ return prices, None
54
+ except Exception as e:
55
+ return None, f"Error fetching data: {str(e)}"
56
+
57
+ def prepare_lstm_input(self, data):
58
+ """
59
+ Prepare data for LSTM prediction
60
+ """
61
+ # Scale the data
62
+ scaled_data = self.scaler.transform(data[['price']])
63
+
64
+ # Create sequences
65
+ if len(scaled_data) < self.lookback:
66
+ # Pad with the first value if not enough data
67
+ padding = np.tile(scaled_data[0], (self.lookback - len(scaled_data), 1))
68
+ scaled_data = np.vstack([padding, scaled_data])
69
+
70
+ # Take the last lookback values
71
+ sequence = scaled_data[-self.lookback:].reshape(1, self.lookback, 1)
72
+
73
+ return sequence
74
+
75
+ def predict_next_days(self, ticker, num_days):
76
+ """
77
+ Predict stock prices for the next n days
78
+ """
79
+ if not all([self.arima_model, self.lstm_model, self.scaler]):
80
+ return None, None, "Models not loaded properly. Please check model files."
81
+
82
+ # Fetch historical data
83
+ historical_data, error = self.fetch_stock_data(ticker, days_back=365)
84
+
85
+ if error:
86
+ return None, None, error
87
+
88
+ try:
89
+ # ARIMA Predictions
90
+ arima_forecast = self.arima_model.forecast(steps=num_days)
91
+
92
+ # LSTM Predictions
93
+ lstm_predictions = []
94
+ current_data = historical_data.copy()
95
+
96
+ for _ in range(num_days):
97
+ # Prepare input
98
+ lstm_input = self.prepare_lstm_input(current_data)
99
+
100
+ # Make prediction
101
+ scaled_pred = self.lstm_model.predict(lstm_input, verbose=0)
102
+ pred = self.scaler.inverse_transform(scaled_pred)[0, 0]
103
+ lstm_predictions.append(pred)
104
+
105
+ # Add prediction to data for next iteration
106
+ next_date = current_data.index[-1] + timedelta(days=1)
107
+ new_row = pd.DataFrame({'price': [pred]}, index=[next_date])
108
+ current_data = pd.concat([current_data, new_row])
109
+
110
+ # Create future dates
111
+ last_date = historical_data.index[-1]
112
+ future_dates = pd.date_range(start=last_date + timedelta(days=1),
113
+ periods=num_days, freq='D')
114
+
115
+ # Create prediction DataFrames
116
+ arima_df = pd.DataFrame({
117
+ 'Date': future_dates,
118
+ 'ARIMA_Prediction': arima_forecast.values
119
+ })
120
+
121
+ lstm_df = pd.DataFrame({
122
+ 'Date': future_dates,
123
+ 'LSTM_Prediction': lstm_predictions
124
+ })
125
+
126
+ # Combine predictions
127
+ predictions_df = pd.merge(arima_df, lstm_df, on='Date')
128
+ predictions_df['Average_Prediction'] = (predictions_df['ARIMA_Prediction'] +
129
+ predictions_df['LSTM_Prediction']) / 2
130
+
131
+ return historical_data, predictions_df, None
132
+
133
+ except Exception as e:
134
+ return None, None, f"Prediction error: {str(e)}"
135
+
136
+ def create_plot(self, historical_data, predictions_df, ticker):
137
+ """
138
+ Create an interactive plot using Plotly
139
+ """
140
+ fig = go.Figure()
141
+
142
+ # Plot historical data
143
+ fig.add_trace(go.Scatter(
144
+ x=historical_data.index,
145
+ y=historical_data['price'],
146
+ mode='lines',
147
+ name='Historical Price',
148
+ line=dict(color='black', width=2)
149
+ ))
150
+
151
+ # Plot ARIMA predictions
152
+ fig.add_trace(go.Scatter(
153
+ x=predictions_df['Date'],
154
+ y=predictions_df['ARIMA_Prediction'],
155
+ mode='lines+markers',
156
+ name='ARIMA Forecast',
157
+ line=dict(color='blue', width=2, dash='dash'),
158
+ marker=dict(size=6)
159
+ ))
160
+
161
+ # Plot LSTM predictions
162
+ fig.add_trace(go.Scatter(
163
+ x=predictions_df['Date'],
164
+ y=predictions_df['LSTM_Prediction'],
165
+ mode='lines+markers',
166
+ name='LSTM Forecast',
167
+ line=dict(color='red', width=2, dash='dash'),
168
+ marker=dict(size=6)
169
+ ))
170
+
171
+ # Plot average predictions
172
+ fig.add_trace(go.Scatter(
173
+ x=predictions_df['Date'],
174
+ y=predictions_df['Average_Prediction'],
175
+ mode='lines+markers',
176
+ name='Ensemble (Average)',
177
+ line=dict(color='green', width=2, dash='dot'),
178
+ marker=dict(size=8)
179
+ ))
180
+
181
+ # Update layout
182
+ fig.update_layout(
183
+ title=f'{ticker} Stock Price Forecast',
184
+ xaxis_title='Date',
185
+ yaxis_title='Price ($)',
186
+ hovermode='x unified',
187
+ showlegend=True,
188
+ template='plotly_white',
189
+ height=600
190
+ )
191
+
192
+ # Add a vertical line to separate historical and predicted
193
+ fig.add_vline(x=historical_data.index[-1],
194
+ line_dash="solid",
195
+ line_color="gray",
196
+ annotation_text="Forecast Start")
197
+
198
+ return fig
199
+
200
+ # Initialize the app
201
+ predictor = StockPredictorApp()
202
+
203
+ def predict_stock_price(ticker, num_days):
204
+ """
205
+ Main prediction function for Gradio interface
206
+ """
207
+ if not ticker:
208
+ return None, "Please enter a stock ticker symbol"
209
+
210
+ # Convert ticker to uppercase
211
+ ticker = ticker.upper()
212
+
213
+ # Validate number of days
214
+ if num_days < 1 or num_days > 90:
215
+ return None, "Please enter a number of days between 1 and 90"
216
+
217
+ # Get predictions
218
+ historical_data, predictions_df, error = predictor.predict_next_days(ticker, num_days)
219
+
220
+ if error:
221
+ return None, error
222
+
223
+ # Create plot
224
+ fig = predictor.create_plot(historical_data, predictions_df, ticker)
225
+
226
+ # Format predictions table
227
+ predictions_display = predictions_df.copy()
228
+ predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
229
+ predictions_display = predictions_display.round(2)
230
+
231
+ # Calculate summary statistics
232
+ summary = f"""
233
+ ### Prediction Summary for {ticker}
234
+
235
+ **Forecast Period**: {num_days} days
236
+
237
+ **ARIMA Model**:
238
+ - First Day: ${predictions_df['ARIMA_Prediction'].iloc[0]:.2f}
239
+ - Last Day: ${predictions_df['ARIMA_Prediction'].iloc[-1]:.2f}
240
+ - Average: ${predictions_df['ARIMA_Prediction'].mean():.2f}
241
+ - Trend: {'πŸ“ˆ Upward' if predictions_df['ARIMA_Prediction'].iloc[-1] > predictions_df['ARIMA_Prediction'].iloc[0] else 'πŸ“‰ Downward'}
242
+
243
+ **LSTM Model**:
244
+ - First Day: ${predictions_df['LSTM_Prediction'].iloc[0]:.2f}
245
+ - Last Day: ${predictions_df['LSTM_Prediction'].iloc[-1]:.2f}
246
+ - Average: ${predictions_df['LSTM_Prediction'].mean():.2f}
247
+ - Trend: {'πŸ“ˆ Upward' if predictions_df['LSTM_Prediction'].iloc[-1] > predictions_df['LSTM_Prediction'].iloc[0] else 'πŸ“‰ Downward'}
248
+
249
+ **Ensemble (Average)**:
250
+ - First Day: ${predictions_df['Average_Prediction'].iloc[0]:.2f}
251
+ - Last Day: ${predictions_df['Average_Prediction'].iloc[-1]:.2f}
252
+ - Average: ${predictions_df['Average_Prediction'].mean():.2f}
253
+
254
+ **Current Price**: ${historical_data['price'].iloc[-1]:.2f}
255
+ **Expected Change**: {'+' if predictions_df['Average_Prediction'].iloc[-1] > historical_data['price'].iloc[-1] else ''}{((predictions_df['Average_Prediction'].iloc[-1] / historical_data['price'].iloc[-1] - 1) * 100):.2f}%
256
+ """
257
+
258
+ return fig, summary, predictions_display
259
+
260
+ # Create Gradio interface
261
+ with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
262
+ gr.Markdown(
263
+ """
264
+ # πŸ“ˆ Stock Price Forecaster
265
+
266
+ This app uses pre-trained ARIMA and LSTM models to predict stock prices.
267
+ Enter a stock ticker symbol and the number of days to forecast.
268
+
269
+ **Models:**
270
+ - πŸ”΅ ARIMA: Statistical time series model
271
+ - πŸ”΄ LSTM: Deep learning sequential model
272
+ - 🟒 Ensemble: Average of both models
273
+ """
274
+ )
275
+
276
+ with gr.Row():
277
+ with gr.Column(scale=1):
278
+ ticker_input = gr.Textbox(
279
+ label="Stock Ticker Symbol",
280
+ placeholder="e.g., AAPL, GOOGL, MSFT",
281
+ value="AAPL"
282
+ )
283
+
284
+ days_input = gr.Slider(
285
+ minimum=1,
286
+ maximum=30,
287
+ value=7,
288
+ step=1,
289
+ label="Number of Days to Forecast"
290
+ )
291
+
292
+ predict_button = gr.Button("πŸš€ Generate Forecast", variant="primary")
293
+
294
+ gr.Markdown(
295
+ """
296
+ ### Popular Tickers:
297
+ - **Tech**: AAPL, GOOGL, MSFT, AMZN, TSLA
298
+ - **Finance**: JPM, BAC, V, MA
299
+ - **Healthcare**: JNJ, UNH, PFE
300
+ - **Energy**: XOM, CVX
301
+ """
302
+ )
303
+
304
+ with gr.Row():
305
+ with gr.Column(scale=2):
306
+ plot_output = gr.Plot(label="Price Forecast Chart")
307
+
308
+ with gr.Row():
309
+ summary_output = gr.Markdown(label="Forecast Summary")
310
+
311
+ with gr.Row():
312
+ predictions_table = gr.Dataframe(
313
+ label="Detailed Predictions",
314
+ headers=["Date", "ARIMA_Prediction", "LSTM_Prediction", "Average_Prediction"],
315
+ datatype=["str", "number", "number", "number"]
316
+ )
317
+
318
+ # Add examples
319
+ gr.Examples(
320
+ examples=[
321
+ ["AAPL", 7],
322
+ ["GOOGL", 14],
323
+ ["TSLA", 5],
324
+ ["MSFT", 10]
325
+ ],
326
+ inputs=[ticker_input, days_input],
327
+ outputs=[plot_output, summary_output, predictions_table],
328
+ fn=predict_stock_price,
329
+ cache_examples=False
330
+ )
331
+
332
+ # Connect the prediction function
333
+ predict_button.click(
334
+ fn=predict_stock_price,
335
+ inputs=[ticker_input, days_input],
336
+ outputs=[plot_output, summary_output, predictions_table]
337
+ )
338
+
339
+ gr.Markdown(
340
+ """
341
+ ---
342
+ ### πŸ“Š About the Models
343
+ - **ARIMA**: Auto-Regressive Integrated Moving Average model trained on historical price data
344
+ - **LSTM**: Long Short-Term Memory neural network with 3 layers and dropout regularization
345
+ - **Training Data**: Historical stock prices from Yahoo Finance
346
+ """
347
+ )
348
+
349
+ # Launch the app
350
+ if __name__ == "__main__":
351
+ app.launch(share=True)