Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import yfinance as yf | |
| import joblib | |
| from tensorflow.keras.models import load_model | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
class StockPredictorApp:
    """Forecast stock prices with a pre-trained ARIMA model, LSTM model and scaler.

    If any of the three model files fails to load, all three model attributes
    are set to ``None`` and ``predict_next_days`` returns an error message
    instead of forecasting.
    """

    def __init__(self, arima_path='arima_model.pkl',
                 lstm_path='lstm_model.h5',
                 scaler_path='scaler.pkl',
                 lookback=60):
        """
        Initialize the stock predictor with pre-trained models.

        Args:
            arima_path: Path to the joblib-serialized ARIMA model.
            lstm_path: Path to the Keras LSTM model file.
            scaler_path: Path to the joblib-serialized scaler applied to prices
                before they are fed to the LSTM.
            lookback: Number of past prices in each LSTM input window. Must
                match the window length the LSTM was trained with.
        """
        # Set before loading so the attribute exists even if loading fails.
        self.lookback = lookback
        try:
            # NOTE(review): joblib.load unpickles arbitrary objects; only load
            # model files from a trusted source.
            self.arima_model = joblib.load(arima_path)
            self.lstm_model = load_model(lstm_path)
            self.scaler = joblib.load(scaler_path)
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            self.arima_model = None
            self.lstm_model = None
            self.scaler = None

    def models_loaded(self):
        """Return True when the ARIMA model, LSTM model and scaler are all loaded."""
        return all(model is not None
                   for model in (self.arima_model, self.lstm_model, self.scaler))

    def fetch_stock_data(self, ticker, days_back=365):
        """
        Fetch recent daily closing prices for ``ticker`` from Yahoo Finance.

        Args:
            ticker: Ticker symbol passed to ``yf.download``.
            days_back: Number of calendar days of history to request.

        Returns:
            ``(prices, None)`` on success, where ``prices`` is a DataFrame with a
            single ``'price'`` column indexed by date, or ``(None, message)``
            when no data is available or the download fails.
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(days=days_back)
            # yfinance treats ``end`` as exclusive; add a day so today's bar
            # (if any) is included.
            stock_data = yf.download(ticker,
                                     start=start_date.strftime('%Y-%m-%d'),
                                     end=(end_date + timedelta(days=1)).strftime('%Y-%m-%d'),
                                     progress=False)
            if stock_data is None or stock_data.empty:
                return None, "No data found for this ticker"
            close = stock_data['Close']
            # Newer yfinance versions return (field, ticker) MultiIndex columns,
            # in which case 'Close' selects a one-column DataFrame.
            if isinstance(close, pd.DataFrame):
                close = close.iloc[:, 0]
            # Drop missing closes so NaNs never reach the scaler or the LSTM.
            prices = close.dropna().to_frame(name='price')
            if prices.empty:
                return None, "No data found for this ticker"
            return prices, None
        except Exception as e:
            return None, f"Error fetching data: {str(e)}"

    def _window_from_scaled(self, scaled_data):
        """Left-pad (with the first value) or trim scaled prices to a (1, lookback, 1) window."""
        scaled_data = np.asarray(scaled_data).reshape(-1, 1)
        missing = self.lookback - len(scaled_data)
        if missing > 0:
            padding = np.repeat(scaled_data[:1], missing, axis=0)
            scaled_data = np.vstack([padding, scaled_data])
        return scaled_data[-self.lookback:].reshape(1, self.lookback, 1)

    def prepare_lstm_input(self, data):
        """
        Prepare one LSTM input window from the most recent prices.

        Args:
            data: DataFrame with a ``'price'`` column.

        Returns:
            Array of shape ``(1, lookback, 1)`` holding the scaled last
            ``lookback`` prices, left-padded with the first scaled value when
            fewer than ``lookback`` prices are available.
        """
        scaled_data = self.scaler.transform(data[['price']])
        return self._window_from_scaled(scaled_data)

    def predict_next_days(self, ticker, num_days):
        """
        Predict stock prices for the next ``num_days`` days.

        Returns:
            ``(historical_data, predictions_df, None)`` on success, where
            ``predictions_df`` has columns ``Date``, ``ARIMA_Prediction``,
            ``LSTM_Prediction`` and ``Average_Prediction``; otherwise
            ``(None, None, error_message)``.
        """
        if not self.models_loaded():
            return None, None, "Models not loaded properly. Please check model files."
        historical_data, error = self.fetch_stock_data(ticker, days_back=365)
        if error:
            return None, None, error
        try:
            num_days = int(num_days)

            # NOTE(review): the pre-trained ARIMA model is not refit on
            # ``historical_data``; its forecast does not depend on ``ticker``.
            # Confirm this is intended.
            # ``forecast`` may return a Series or an ndarray depending on the
            # library, so normalise to a flat array.
            arima_forecast = np.asarray(
                self.arima_model.forecast(steps=num_days), dtype=float
            ).ravel()

            # Recursive multi-step LSTM forecast: each prediction is scaled and
            # appended to the window that feeds the next step. History is
            # scaled once instead of on every iteration.
            window = np.asarray(self.scaler.transform(historical_data[['price']]))
            lstm_predictions = []
            for _ in range(num_days):
                lstm_input = self._window_from_scaled(window)
                scaled_pred = self.lstm_model.predict(lstm_input, verbose=0)
                pred = float(self.scaler.inverse_transform(scaled_pred)[0, 0])
                lstm_predictions.append(pred)
                next_scaled = self.scaler.transform(pd.DataFrame({'price': [pred]}))
                # Only the last ``lookback`` values are ever used, so keep the
                # window bounded.
                window = np.vstack([window[-self.lookback:], next_scaled])

            last_date = historical_data.index[-1]
            future_dates = pd.date_range(start=last_date + timedelta(days=1),
                                         periods=num_days, freq='D')
            predictions_df = pd.DataFrame({
                'Date': future_dates,
                'ARIMA_Prediction': arima_forecast,
                'LSTM_Prediction': lstm_predictions,
            })
            predictions_df['Average_Prediction'] = (predictions_df['ARIMA_Prediction'] +
                                                    predictions_df['LSTM_Prediction']) / 2
            return historical_data, predictions_df, None
        except Exception as e:
            return None, None, f"Prediction error: {str(e)}"

    def create_plot(self, historical_data, predictions_df, ticker):
        """
        Create an interactive Plotly chart of history and the three forecasts.

        Args:
            historical_data: DataFrame with a ``'price'`` column indexed by date.
            predictions_df: DataFrame with ``Date``, ``ARIMA_Prediction``,
                ``LSTM_Prediction`` and ``Average_Prediction`` columns.
            ticker: Label used in the chart title.

        Returns:
            A ``plotly.graph_objects.Figure``.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=historical_data.index,
            y=historical_data['price'],
            mode='lines',
            name='Historical Price',
            line=dict(color='black', width=2)
        ))
        fig.add_trace(go.Scatter(
            x=predictions_df['Date'],
            y=predictions_df['ARIMA_Prediction'],
            mode='lines+markers',
            name='ARIMA Forecast',
            line=dict(color='blue', width=2, dash='dash'),
            marker=dict(size=6)
        ))
        fig.add_trace(go.Scatter(
            x=predictions_df['Date'],
            y=predictions_df['LSTM_Prediction'],
            mode='lines+markers',
            name='LSTM Forecast',
            line=dict(color='red', width=2, dash='dash'),
            marker=dict(size=6)
        ))
        fig.add_trace(go.Scatter(
            x=predictions_df['Date'],
            y=predictions_df['Average_Prediction'],
            mode='lines+markers',
            name='Ensemble (Average)',
            line=dict(color='green', width=2, dash='dot'),
            marker=dict(size=8)
        ))
        fig.update_layout(
            title=f'{ticker} Stock Price Forecast',
            xaxis_title='Date',
            yaxis_title='Price ($)',
            hovermode='x unified',
            showlegend=True,
            template='plotly_white',
            height=600
        )

        # Vertical separator between history and forecast. A plain shape plus
        # a separate annotation is used instead of
        # ``add_vline(annotation_text=...)``: Plotly positions that annotation
        # by doing arithmetic on ``x``, which fails for date and string values
        # on affected versions.
        last_date = historical_data.index[-1]
        if hasattr(last_date, 'strftime'):
            last_date = last_date.strftime('%Y-%m-%d')
        fig.add_shape(type='line',
                      x0=last_date, x1=last_date, y0=0, y1=1,
                      xref='x', yref='paper',
                      line=dict(color='gray', dash='solid'))
        fig.add_annotation(x=last_date, y=1, xref='x', yref='paper',
                           text='Forecast Start', showarrow=False,
                           xanchor='left', yanchor='top')
        return fig
# Initialize the app: one shared predictor instance, created at import time,
# so the model files are loaded once and used by the prediction callbacks.
predictor = StockPredictorApp(
    arima_path='arima_model.pkl',
    lstm_path='lstm_model.h5',
    scaler_path='scaler.pkl',
)
def predict_stock_price(ticker, num_days):
    """
    Main prediction function for the Gradio interface.

    Args:
        ticker: Stock ticker symbol. Case-insensitive; surrounding whitespace
            is ignored.
        num_days: Number of days to forecast, between 1 and 90 inclusive.
            Non-integer values in range are truncated to an int.

    Returns:
        ``(figure, summary_markdown, predictions_table)``. On a validation or
        prediction error the figure is None, the summary is the error message
        and the table is an empty DataFrame.
    """
    empty_df = pd.DataFrame()

    ticker = (ticker or '').strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", empty_df

    # The UI slider may deliver a float, or None if the field is cleared.
    try:
        days = float(num_days)
    except (TypeError, ValueError):
        days = float('nan')
    if not 1 <= days <= 90:  # also rejects NaN
        return None, "Please enter a number of days between 1 and 90", empty_df
    num_days = int(days)

    historical_data, predictions_df, error = predictor.predict_next_days(ticker, num_days)
    if error:
        return None, error, empty_df

    fig = predictor.create_plot(historical_data, predictions_df, ticker)

    predictions_display = predictions_df.copy()
    predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
    predictions_display = predictions_display.round(2)

    def money(value):
        return f"${value:.2f}"

    def trend(series):
        return '📈 Upward' if series.iloc[-1] > series.iloc[0] else '📉 Downward'

    arima = predictions_df['ARIMA_Prediction']
    lstm = predictions_df['LSTM_Prediction']
    ensemble = predictions_df['Average_Prediction']
    current_price = historical_data['price'].iloc[-1]
    expected_change = (ensemble.iloc[-1] / current_price - 1) * 100

    # Build the Markdown line by line so no source indentation leaks into it.
    lines = [
        f"### Prediction Summary for {ticker}",
        "",
        f"**Forecast Period**: {num_days} days",
        "",
        "**ARIMA Model**:",
        f"- First Day: {money(arima.iloc[0])}",
        f"- Last Day: {money(arima.iloc[-1])}",
        f"- Average: {money(arima.mean())}",
        f"- Trend: {trend(arima)}",
        "",
        "**LSTM Model**:",
        f"- First Day: {money(lstm.iloc[0])}",
        f"- Last Day: {money(lstm.iloc[-1])}",
        f"- Average: {money(lstm.mean())}",
        f"- Trend: {trend(lstm)}",
        "",
        "**Ensemble (Average)**:",
        f"- First Day: {money(ensemble.iloc[0])}",
        f"- Last Day: {money(ensemble.iloc[-1])}",
        f"- Average: {money(ensemble.mean())}",
        "",
        f"**Current Price**: {money(current_price)}",
        "",
        # ``+`` format flag: gains read "+1.23%", losses "-1.23%".
        f"**Expected Change**: {expected_change:+.2f}%",
    ]
    summary = "\n".join(lines)
    return fig, summary, predictions_display
# Create demo mode for when models aren't available
def create_demo_predictions(ticker, num_days, seed=None):
    """
    Create random-walk demo predictions for when the models aren't loaded.

    Args:
        ticker: Accepted for symmetry with the real prediction path; the
            generated data does not depend on it.
        num_days: Number of future days to generate (coerced to int, so a
            float slider value such as 7.0 is accepted).
        seed: Optional seed for reproducible demo data. When None, NumPy's
            global random state is used, as before.

    Returns:
        ``(historical_data, predictions_df)``: 100 days of fake history in a
        ``'price'`` column indexed by date, and a DataFrame with ``Date``,
        ``ARIMA_Prediction``, ``LSTM_Prediction`` and ``Average_Prediction``.
    """
    num_days = int(num_days)
    # With a seed, use a private generator so the global state is untouched.
    rng = np.random if seed is None else np.random.RandomState(seed)

    history_days = 100
    base_price = 150.0
    dates = pd.date_range(end=datetime.now(), periods=history_days, freq='D')
    historical_data = pd.DataFrame({
        'price': base_price + np.cumsum(rng.randn(history_days) * 2)
    }, index=dates)

    future_dates = pd.date_range(start=dates[-1] + timedelta(days=1),
                                 periods=num_days, freq='D')
    last_price = historical_data['price'].iloc[-1]
    # Draw order (history, ARIMA, LSTM) matches the original implementation.
    arima_pred = last_price + np.cumsum(rng.randn(num_days) * 1.5)
    lstm_pred = last_price + np.cumsum(rng.randn(num_days) * 1.5)
    predictions_df = pd.DataFrame({
        'Date': future_dates,
        'ARIMA_Prediction': arima_pred,
        'LSTM_Prediction': lstm_pred,
        'Average_Prediction': (arima_pred + lstm_pred) / 2
    })
    return historical_data, predictions_df
# Modified predict function with fallback to demo mode
def predict_stock_price_safe(ticker, num_days):
    """
    Safe prediction function with demo fallback.

    Validates the inputs, then delegates to ``predict_stock_price`` when all
    models are loaded; otherwise returns randomly generated demo predictions
    with an explanatory banner.

    Args:
        ticker: Stock ticker symbol. Case-insensitive; surrounding whitespace
            is ignored.
        num_days: Number of days to forecast, between 1 and 90 inclusive.

    Returns:
        ``(figure, summary_markdown, predictions_table)``. On error the figure
        is None, the summary is the error message and the table is empty.
    """
    empty_df = pd.DataFrame()

    ticker = (ticker or '').strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", empty_df

    # The UI slider may deliver a float, or None if the field is cleared.
    try:
        days = float(num_days)
    except (TypeError, ValueError):
        days = float('nan')
    if not 1 <= days <= 90:  # also rejects NaN
        return None, "Please enter a number of days between 1 and 90", empty_df
    num_days = int(days)

    # Normal prediction flow when every model object is present.
    models = (predictor.arima_model, predictor.lstm_model, predictor.scaler)
    if all(model is not None for model in models):
        return predict_stock_price(ticker, num_days)

    # Demo mode
    demo_msg = "\n".join([
        "### ⚠️ Demo Mode Active",
        "",
        "**Note**: Pre-trained models are not available. "
        "Showing demo predictions with random data.",
        "",
        "To use real predictions, ensure you have:",
        "1. `arima_model.pkl` - ARIMA model file",
        "2. `lstm_model.h5` - LSTM model file",
        "3. `scaler.pkl` - Data scaler file",
        "",
        "Place these files in the same directory as the app.",
    ])
    try:
        historical_data, predictions_df = create_demo_predictions(ticker, num_days)
        fig = predictor.create_plot(historical_data, predictions_df, f"{ticker} (DEMO)")
        predictions_display = predictions_df.copy()
        predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
        predictions_display = predictions_display.round(2)
        return fig, demo_msg, predictions_display
    except Exception as e:
        error_msg = f"Error creating demo predictions: {str(e)}"
        return None, error_msg, empty_df
# Create Gradio interface
# Emoji in the user-visible text were mis-decoded (shown as "π…") and are
# restored here; colour markers match the plot trace colours.
with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # 📈 Stock Price Forecaster

        This app uses pre-trained ARIMA and LSTM models to predict stock prices.
        Enter a stock ticker symbol and the number of days to forecast.

        **Models:**
        - 🔵 ARIMA: Statistical time series model
        - 🔴 LSTM: Deep learning sequential model
        - 🟢 Ensemble: Average of both models
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(
                label="Stock Ticker Symbol",
                placeholder="AAPL",
                value="AAPL"
            )
            days_input = gr.Slider(
                minimum=1,
                maximum=30,
                value=7,
                step=1,
                label="Number of Days to Forecast"
            )
            predict_button = gr.Button("📊 Generate Forecast", variant="primary")

    with gr.Row():
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Price Forecast Chart")

    with gr.Row():
        summary_output = gr.Markdown(label="Forecast Summary")

    with gr.Row():
        predictions_table = gr.Dataframe(
            label="Detailed Predictions",
            headers=["Date", "ARIMA_Prediction", "LSTM_Prediction", "Average_Prediction"],
            datatype=["str", "number", "number", "number"]
        )

    # Add examples (use safe function)
    gr.Examples(
        examples=[
            ["AAPL", 7],
            ["AAPL", 15]
        ],
        inputs=[ticker_input, days_input],
        outputs=[plot_output, summary_output, predictions_table],
        fn=predict_stock_price_safe,
        cache_examples=False
    )

    # Connect the safe prediction function
    predict_button.click(
        fn=predict_stock_price_safe,
        inputs=[ticker_input, days_input],
        outputs=[plot_output, summary_output, predictions_table]
    )

    gr.Markdown(
        """
        ---
        ### 📚 About the Models

        - **ARIMA**: Auto-Regressive Integrated Moving Average model trained on historical price data
        - **LSTM**: Long Short-Term Memory neural network with 3 layers and dropout regularization
        - **Training Data**: Historical stock prices from Yahoo Finance
        """
    )
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    launch_options = {"share": True}
    app.launch(**launch_options)