# Author: shagatoo
# Last commit: "Update app.py" (4010e76, verified)
import gradio as gr
import numpy as np
import pandas as pd
import yfinance as yf
import joblib
from tensorflow.keras.models import load_model
import plotly.graph_objects as go
from datetime import datetime, timedelta
import warnings
# Silence all Python warnings process-wide; `simplefilter` installs the same
# catch-all "ignore" filter as `filterwarnings` with default arguments.
warnings.simplefilter('ignore')
class StockPredictorApp:
    """Forecast stock prices with a pre-trained ARIMA model and LSTM network.

    Attributes:
        arima_model: object loaded with ``joblib`` from ``arima_path``; it must
            provide ``forecast(steps=...)``. ``None`` if loading failed.
        lstm_model: Keras model loaded from ``lstm_path``. ``None`` if loading
            failed.
        scaler: object loaded with ``joblib`` from ``scaler_path``; it must
            provide ``transform`` / ``inverse_transform``. ``None`` if loading
            failed.
        lookback: number of most recent prices in each LSTM input window.
    """

    def __init__(self, arima_path='arima_model.pkl',
                 lstm_path='lstm_model.h5',
                 scaler_path='scaler.pkl',
                 lookback=60):
        """Load the pre-trained models.

        Args:
            arima_path: path of the joblib-pickled ARIMA model.
            lstm_path: path of the saved Keras LSTM model.
            scaler_path: path of the joblib-pickled scaler used by the LSTM.
            lookback: LSTM input window length (default 60).

        Loading errors are printed, not raised; all three model attributes are
        then ``None`` so callers can detect the failure.
        """
        # Assigned before loading so the attribute exists even if loading fails.
        self.lookback = lookback
        try:
            # joblib.load unpickles arbitrary objects: only load trusted files.
            self.arima_model = joblib.load(arima_path)
            self.lstm_model = load_model(lstm_path)
            self.scaler = joblib.load(scaler_path)
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            self.arima_model = None
            self.lstm_model = None
            self.scaler = None

    def _models_loaded(self):
        """Return True when the ARIMA model, LSTM model and scaler are all set."""
        # `is not None` instead of truthiness: model objects may define
        # __len__/__bool__ with unrelated meaning.
        return all(component is not None
                   for component in (self.arima_model, self.lstm_model, self.scaler))

    def fetch_stock_data(self, ticker, days_back=365):
        """Download recent daily closing prices for ``ticker`` via yfinance.

        Args:
            ticker: ticker symbol passed to ``yf.download``.
            days_back: number of calendar days of history to request.

        Returns:
            ``(prices, None)`` on success, where ``prices`` is a DataFrame with
            a single ``'price'`` column indexed by date, or
            ``(None, message)`` on failure.
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(days=days_back)
            stock_data = yf.download(ticker,
                                     start=start_date.strftime('%Y-%m-%d'),
                                     end=end_date.strftime('%Y-%m-%d'),
                                     progress=False)
            if stock_data.empty:
                return None, "No data found for this ticker"
            prices = stock_data[['Close']].copy()
            prices.columns = ['price']
            # A missing close would propagate NaN through the scaler and LSTM.
            prices = prices.dropna()
            if prices.empty:
                return None, "No data found for this ticker"
            return prices, None
        except Exception as e:
            return None, f"Error fetching data: {str(e)}"

    def prepare_lstm_input(self, data):
        """Build one LSTM input window from the ``'price'`` column of ``data``.

        Prices are scaled with ``self.scaler``. When fewer than
        ``self.lookback`` rows exist, the window is left-padded with the first
        scaled value.

        Returns:
            ndarray of shape ``(1, self.lookback, 1)``.

        Raises:
            ValueError: if ``data`` has no rows.
        """
        if len(data) == 0:
            raise ValueError("Cannot build an LSTM input from empty data")
        scaled_data = np.asarray(self.scaler.transform(data[['price']]), dtype=float)
        if len(scaled_data) < self.lookback:
            padding = np.tile(scaled_data[0], (self.lookback - len(scaled_data), 1))
            scaled_data = np.vstack([padding, scaled_data])
        return scaled_data[-self.lookback:].reshape(1, self.lookback, 1)

    def _forecast_arima(self, prices, num_days):
        """Return ``num_days`` ARIMA forecasts as a 1-D float array.

        If the loaded model has an ``apply`` method, its fitted parameters are
        re-applied to ``prices`` so the forecast continues from the fetched
        ticker's latest observations rather than from the end of the training
        sample. NOTE(review): this presumes a statsmodels results object
        (``apply(endog)``); confirm against how the model was saved. If
        ``apply`` is missing or fails, the model's own forecast is used.
        """
        model = self.arima_model
        apply_to = getattr(model, 'apply', None)
        if callable(apply_to):
            try:
                model = apply_to(np.asarray(prices, dtype=float))
            except Exception:
                # Best effort (e.g. the model needs exogenous regressors we
                # cannot supply): fall back to the model as loaded.
                model = self.arima_model
        forecast = model.forecast(steps=num_days)
        # forecast() may return a Series or a plain ndarray; normalize both.
        return np.asarray(forecast, dtype=float).reshape(-1)

    def _forecast_lstm(self, prices_df, num_days):
        """Recursively forecast ``num_days`` prices with the LSTM.

        Each step's scaled prediction is appended to the input window so the
        next step conditions on it; predictions are returned unscaled.
        """
        window = self.prepare_lstm_input(prices_df)
        predictions = []
        for _ in range(num_days):
            scaled_pred = np.asarray(self.lstm_model.predict(window, verbose=0), dtype=float)
            predictions.append(float(self.scaler.inverse_transform(scaled_pred)[0, 0]))
            # Slide the window one step: drop the oldest value, append the new one.
            next_step = np.full((1, 1, 1), scaled_pred[0, 0])
            window = np.concatenate([window[:, 1:, :], next_step], axis=1)
        return predictions

    @staticmethod
    def _future_dates(history_index, num_days):
        """Return ``num_days`` dates following the last entry of ``history_index``.

        Each forecast step corresponds to one row of history. If the history
        has no weekend dates, steps are business days; otherwise (a ticker
        that also trades at weekends) they are calendar days.
        """
        last_date = history_index[-1]
        has_weekends = bool((pd.DatetimeIndex(history_index).dayofweek >= 5).any())
        freq = 'D' if has_weekends else 'B'
        # The range starts at last_date (on-offset for both freqs); drop it.
        return pd.date_range(start=last_date, periods=num_days + 1, freq=freq)[1:]

    def predict_next_days(self, ticker, num_days):
        """Predict prices for the next ``num_days`` observations of ``ticker``.

        Returns:
            ``(historical_data, predictions_df, None)`` on success, where
            ``predictions_df`` has columns ``Date``, ``ARIMA_Prediction``,
            ``LSTM_Prediction`` and ``Average_Prediction``; otherwise
            ``(None, None, message)``.
        """
        if not self._models_loaded():
            return None, None, "Models not loaded properly. Please check model files."
        historical_data, error = self.fetch_stock_data(ticker, days_back=365)
        if error:
            return None, None, error
        try:
            # UI sliders may deliver floats such as 7.0.
            num_days = int(num_days)
            arima_forecast = self._forecast_arima(historical_data['price'], num_days)
            lstm_predictions = self._forecast_lstm(historical_data, num_days)
            predictions_df = pd.DataFrame({
                'Date': self._future_dates(historical_data.index, num_days),
                'ARIMA_Prediction': arima_forecast,
                'LSTM_Prediction': lstm_predictions,
            })
            predictions_df['Average_Prediction'] = (predictions_df['ARIMA_Prediction'] +
                                                    predictions_df['LSTM_Prediction']) / 2
            return historical_data, predictions_df, None
        except Exception as e:
            return None, None, f"Prediction error: {str(e)}"

    def create_plot(self, historical_data, predictions_df, ticker):
        """Build an interactive Plotly chart of the history and the forecasts.

        Args:
            historical_data: DataFrame with a ``'price'`` column indexed by date.
            predictions_df: DataFrame with ``Date``, ``ARIMA_Prediction``,
                ``LSTM_Prediction`` and ``Average_Prediction`` columns.
            ticker: label used in the chart title.

        Returns:
            ``plotly.graph_objects.Figure``.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=historical_data.index,
            y=historical_data['price'],
            mode='lines',
            name='Historical Price',
            line=dict(color='black', width=2),
        ))
        # (column, legend name, colour, dash style, marker size)
        forecast_series = [
            ('ARIMA_Prediction', 'ARIMA Forecast', 'blue', 'dash', 6),
            ('LSTM_Prediction', 'LSTM Forecast', 'red', 'dash', 6),
            ('Average_Prediction', 'Ensemble (Average)', 'green', 'dot', 8),
        ]
        for column, name, color, dash, marker_size in forecast_series:
            fig.add_trace(go.Scatter(
                x=predictions_df['Date'],
                y=predictions_df[column],
                mode='lines+markers',
                name=name,
                line=dict(color=color, width=2, dash=dash),
                marker=dict(size=marker_size),
            ))
        fig.update_layout(
            title=f'{ticker} Stock Price Forecast',
            xaxis_title='Date',
            yaxis_title='Price ($)',
            hovermode='x unified',
            showlegend=True,
            template='plotly_white',
            height=600
        )
        last_date = historical_data.index[-1]
        # Plotly positions an `annotation_text` passed to add_vline by averaging
        # the line's x coordinates, which fails for date (and date-string)
        # values, so the separator line and its label are added separately.
        fig.add_vline(x=last_date, line_dash="solid", line_color="gray")
        fig.add_annotation(x=last_date, y=1, xref="x", yref="paper",
                           text="Forecast Start", showarrow=False,
                           xanchor="left", yanchor="top")
        return fig
# Build the single shared predictor at import time. The model paths are
# relative, so they resolve against the current working directory; if any
# file fails to load, the model attributes are None and the UI uses demo mode.
predictor = StockPredictorApp(
    arima_path='arima_model.pkl',
    lstm_path='lstm_model.h5',
    scaler_path='scaler.pkl',
)
def predict_stock_price(ticker, num_days):
    """Run the real ARIMA/LSTM forecast for the Gradio interface.

    Args:
        ticker: ticker symbol typed by the user; surrounding whitespace is
            ignored and it is upper-cased.
        num_days: number of future steps to forecast, 1 to 90 inclusive.
            Floats (as a slider may send) are truncated to int.

    Returns:
        ``(figure, markdown_summary, table)``. On a validation or prediction
        error the figure is ``None``, the summary is the error message and the
        table is an empty DataFrame.
    """
    empty_df = pd.DataFrame()

    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", empty_df
    # Check for None first: comparing None with ints raises TypeError.
    if num_days is None or not 1 <= num_days <= 90:
        return None, "Please enter a number of days between 1 and 90", empty_df
    num_days = int(num_days)

    historical_data, predictions_df, error = predictor.predict_next_days(ticker, num_days)
    if error:
        return None, error, empty_df

    fig = predictor.create_plot(historical_data, predictions_df, ticker)

    predictions_display = predictions_df.copy()
    predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
    predictions_display = predictions_display.round(2)

    def model_lines(column, with_trend=True):
        """Markdown bullet lines summarizing one prediction column."""
        values = predictions_df[column]
        lines = [
            f"- First Day: ${values.iloc[0]:.2f}",
            f"- Last Day: ${values.iloc[-1]:.2f}",
            f"- Average: ${values.mean():.2f}",
        ]
        if with_trend:
            trend = '📈 Upward' if values.iloc[-1] > values.iloc[0] else '📉 Downward'
            lines.append(f"- Trend: {trend}")
        return lines

    current_price = historical_data['price'].iloc[-1]
    final_average = predictions_df['Average_Prediction'].iloc[-1]
    change_pct = (final_average / current_price - 1) * 100
    # Negative values already carry their own '-' sign.
    sign = '+' if final_average > current_price else ''

    # Blank lines separate the sections so Markdown renders them as separate
    # paragraphs instead of merging them into one line.
    summary = "\n".join([
        f"### Prediction Summary for {ticker}",
        "",
        f"**Forecast Period**: {num_days} days",
        "",
        "**ARIMA Model**:",
        "",
        *model_lines('ARIMA_Prediction'),
        "",
        "**LSTM Model**:",
        "",
        *model_lines('LSTM_Prediction'),
        "",
        "**Ensemble (Average)**:",
        "",
        *model_lines('Average_Prediction', with_trend=False),
        "",
        f"**Current Price**: ${current_price:.2f}",
        "",
        f"**Expected Change**: {sign}{change_pct:.2f}%",
    ])
    return fig, summary, predictions_display
# Create demo mode for when models aren't available
def create_demo_predictions(ticker, num_days):
    """Fabricate random-walk data for demo mode (no trained models needed).

    ``ticker`` is accepted for interface symmetry but is not used.

    Returns:
        ``(historical_data, predictions_df)``: 100 daily prices forming a
        random walk offset from 150.0 and ending now, plus ``num_days`` daily
        random-walk "forecasts" per model starting from the last price, and
        their average.
    """
    history_len = 100
    history_dates = pd.date_range(end=datetime.now(), periods=history_len, freq='D')
    random_walk = np.cumsum(np.random.randn(history_len) * 2)
    historical_data = pd.DataFrame({'price': 150.0 + random_walk}, index=history_dates)

    forecast_dates = pd.date_range(start=history_dates[-1] + timedelta(days=1),
                                   periods=num_days, freq='D')
    start_price = historical_data['price'].iloc[-1]
    # Draw order (ARIMA walk first, then LSTM walk) is kept stable.
    arima_walk = start_price + np.cumsum(np.random.randn(num_days) * 1.5)
    lstm_walk = start_price + np.cumsum(np.random.randn(num_days) * 1.5)

    return historical_data, pd.DataFrame({
        'Date': forecast_dates,
        'ARIMA_Prediction': arima_walk,
        'LSTM_Prediction': lstm_walk,
        'Average_Prediction': (arima_walk + lstm_walk) / 2,
    })
# Modified predict function with fallback to demo mode
def predict_stock_price_safe(ticker, num_days):
    """Gradio callback: real forecast when models are loaded, demo data otherwise.

    Args:
        ticker: ticker symbol typed by the user; surrounding whitespace is
            ignored and it is upper-cased.
        num_days: number of days to forecast, 1 to 90 inclusive. Floats
            (as a slider may send) are truncated to int.

    Returns:
        ``(figure, markdown, table)``. On error the figure is ``None``, the
        markdown is the error message and the table is an empty DataFrame.
    """
    empty_df = pd.DataFrame()

    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", empty_df
    # Check for None first: comparing None with ints raises TypeError.
    if num_days is None or not 1 <= num_days <= 90:
        return None, "Please enter a number of days between 1 and 90", empty_df
    # np.random.randn and range() need an int, not e.g. 7.0.
    num_days = int(num_days)

    models = (predictor.arima_model, predictor.lstm_model, predictor.scaler)
    if all(model is not None for model in models):
        return predict_stock_price(ticker, num_days)

    # Demo mode: at least one model file failed to load.
    # Blank lines keep the Markdown paragraphs and list separate when rendered.
    demo_msg = "\n".join([
        "### ⚠️ Demo Mode Active",
        "",
        "**Note**: Pre-trained models are not available. "
        "Showing demo predictions with random data.",
        "",
        "To use real predictions, ensure you have:",
        "",
        "1. `arima_model.pkl` - ARIMA model file",
        "2. `lstm_model.h5` - LSTM model file",
        "3. `scaler.pkl` - Data scaler file",
        "",
        "Place these files in the same directory as the app.",
    ])
    try:
        historical_data, predictions_df = create_demo_predictions(ticker, num_days)
        fig = predictor.create_plot(historical_data, predictions_df, f"{ticker} (DEMO)")
        predictions_display = predictions_df.copy()
        predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
        predictions_display = predictions_display.round(2)
        return fig, demo_msg, predictions_display
    except Exception as e:
        return None, f"Error creating demo predictions: {str(e)}", empty_df
# Create Gradio interface.
# NOTE(review): this section's original indentation was lost; the Row/Column
# nesting below is a reconstruction — confirm it matches the intended layout.
with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
# 📈 Stock Price Forecaster

This app uses pre-trained ARIMA and LSTM models to predict stock prices.
Enter a stock ticker symbol and the number of days to forecast.

**Models:**

- 🔵 ARIMA: Statistical time series model
- 🔴 LSTM: Deep learning sequential model
- 🟢 Ensemble: Average of both models
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(
                label="Stock Ticker Symbol",
                placeholder="AAPL",
                value="AAPL"
            )
            days_input = gr.Slider(
                minimum=1,
                maximum=30,
                value=7,
                step=1,
                label="Number of Days to Forecast"
            )
            predict_button = gr.Button("🚀 Generate Forecast", variant="primary")

    with gr.Row():
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Price Forecast Chart")

    with gr.Row():
        summary_output = gr.Markdown(label="Forecast Summary")

    with gr.Row():
        predictions_table = gr.Dataframe(
            label="Detailed Predictions",
            headers=["Date", "ARIMA_Prediction", "LSTM_Prediction", "Average_Prediction"],
            datatype=["str", "number", "number", "number"]
        )

    # Examples go through the same demo-aware callback as the button.
    gr.Examples(
        examples=[
            ["AAPL", 7],
            ["AAPL", 15]
        ],
        inputs=[ticker_input, days_input],
        outputs=[plot_output, summary_output, predictions_table],
        fn=predict_stock_price_safe,
        cache_examples=False
    )

    # Connect the safe prediction function
    predict_button.click(
        fn=predict_stock_price_safe,
        inputs=[ticker_input, days_input],
        outputs=[plot_output, summary_output, predictions_table]
    )

    gr.Markdown(
        """
---
### 📊 About the Models

- **ARIMA**: Auto-Regressive Integrated Moving Average model trained on historical price data
- **LSTM**: Long Short-Term Memory neural network with 3 layers and dropout regularization
- **Training Data**: Historical stock prices from Yahoo Finance
"""
    )
# Start the Gradio server only when this file is executed directly,
# not when it is imported; share=True requests a public share link.
if __name__ == "__main__":
    app.launch(share=True)