shagatoo commited on
Commit
43595e4
·
verified ·
1 Parent(s): 2f520fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -20
app.py CHANGED
@@ -204,21 +204,24 @@ def predict_stock_price(ticker, num_days):
204
  """
205
  Main prediction function for Gradio interface
206
  """
 
 
 
207
  if not ticker:
208
- return None, "Please enter a stock ticker symbol"
209
 
210
  # Convert ticker to uppercase
211
  ticker = ticker.upper()
212
 
213
  # Validate number of days
214
  if num_days < 1 or num_days > 90:
215
- return None, "Please enter a number of days between 1 and 90"
216
 
217
  # Get predictions
218
  historical_data, predictions_df, error = predictor.predict_next_days(ticker, num_days)
219
 
220
  if error:
221
- return None, error
222
 
223
  # Create plot
224
  fig = predictor.create_plot(historical_data, predictions_df, ticker)
@@ -257,6 +260,78 @@ def predict_stock_price(ticker, num_days):
257
 
258
  return fig, summary, predictions_display
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  # Create Gradio interface
261
  with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
262
  gr.Markdown(
@@ -277,7 +352,7 @@ with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
277
  with gr.Column(scale=1):
278
  ticker_input = gr.Textbox(
279
  label="Stock Ticker Symbol",
280
- placeholder="e.g., AAPL, GOOGL, MSFT",
281
  value="AAPL"
282
  )
283
 
@@ -291,15 +366,7 @@ with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
291
 
292
  predict_button = gr.Button("🚀 Generate Forecast", variant="primary")
293
 
294
- gr.Markdown(
295
- """
296
- ### Popular Tickers:
297
- - **Tech**: AAPL, GOOGL, MSFT, AMZN, TSLA
298
- - **Finance**: JPM, BAC, V, MA
299
- - **Healthcare**: JNJ, UNH, PFE
300
- - **Energy**: XOM, CVX
301
- """
302
- )
303
 
304
  with gr.Row():
305
  with gr.Column(scale=2):
@@ -315,23 +382,20 @@ with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
315
  datatype=["str", "number", "number", "number"]
316
  )
317
 
318
- # Add examples
319
  gr.Examples(
320
  examples=[
321
  ["AAPL", 7],
322
- ["GOOGL", 14],
323
- ["TSLA", 5],
324
- ["MSFT", 10]
325
  ],
326
  inputs=[ticker_input, days_input],
327
  outputs=[plot_output, summary_output, predictions_table],
328
- fn=predict_stock_price,
329
  cache_examples=False
330
  )
331
 
332
- # Connect the prediction function
333
  predict_button.click(
334
- fn=predict_stock_price,
335
  inputs=[ticker_input, days_input],
336
  outputs=[plot_output, summary_output, predictions_table]
337
  )
@@ -343,6 +407,9 @@ with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
343
  - **ARIMA**: Auto-Regressive Integrated Moving Average model trained on historical price data
344
  - **LSTM**: Long Short-Term Memory neural network with 3 layers and dropout regularization
345
  - **Training Data**: Historical stock prices from Yahoo Finance
 
 
 
346
  """
347
  )
348
 
 
204
  """
205
  Main prediction function for Gradio interface
206
  """
207
+ # Create empty dataframe for error cases
208
+ empty_df = pd.DataFrame()
209
+
210
  if not ticker:
211
+ return None, "Please enter a stock ticker symbol", empty_df
212
 
213
  # Convert ticker to uppercase
214
  ticker = ticker.upper()
215
 
216
  # Validate number of days
217
  if num_days < 1 or num_days > 90:
218
+ return None, "Please enter a number of days between 1 and 90", empty_df
219
 
220
  # Get predictions
221
  historical_data, predictions_df, error = predictor.predict_next_days(ticker, num_days)
222
 
223
  if error:
224
+ return None, error, empty_df
225
 
226
  # Create plot
227
  fig = predictor.create_plot(historical_data, predictions_df, ticker)
 
260
 
261
  return fig, summary, predictions_display
262
 
263
+ # Create demo mode for when models aren't available
264
+ def create_demo_predictions(ticker, num_days):
265
+ """
266
+ Create demo predictions when models aren't loaded
267
+ """
268
+ # Create fake historical data
269
+ dates = pd.date_range(end=datetime.now(), periods=100, freq='D')
270
+ base_price = 150.0
271
+ historical_data = pd.DataFrame({
272
+ 'price': base_price + np.cumsum(np.random.randn(100) * 2)
273
+ }, index=dates)
274
+
275
+ # Create fake predictions
276
+ future_dates = pd.date_range(start=dates[-1] + timedelta(days=1),
277
+ periods=num_days, freq='D')
278
+
279
+ last_price = historical_data['price'].iloc[-1]
280
+ arima_pred = last_price + np.cumsum(np.random.randn(num_days) * 1.5)
281
+ lstm_pred = last_price + np.cumsum(np.random.randn(num_days) * 1.5)
282
+
283
+ predictions_df = pd.DataFrame({
284
+ 'Date': future_dates,
285
+ 'ARIMA_Prediction': arima_pred,
286
+ 'LSTM_Prediction': lstm_pred,
287
+ 'Average_Prediction': (arima_pred + lstm_pred) / 2
288
+ })
289
+
290
+ return historical_data, predictions_df
291
+
292
+ # Modified predict function with fallback to demo mode
293
+ def predict_stock_price_safe(ticker, num_days):
294
+ """
295
+ Safe prediction function with demo fallback
296
+ """
297
+ empty_df = pd.DataFrame()
298
+
299
+ if not ticker:
300
+ return None, "Please enter a stock ticker symbol", empty_df
301
+
302
+ ticker = ticker.upper()
303
+
304
+ if num_days < 1 or num_days > 90:
305
+ return None, "Please enter a number of days between 1 and 90", empty_df
306
+
307
+ # Check if models are loaded
308
+ if not all([predictor.arima_model, predictor.lstm_model, predictor.scaler]):
309
+ # Use demo mode
310
+ demo_msg = f"""
311
+ ### ⚠️ Demo Mode Active
312
+
313
+ **Note**: Pre-trained models are not available. Showing demo predictions with random data.
314
+
315
+ To use real predictions, ensure you have:
316
+ 1. `arima_model.pkl` - ARIMA model file
317
+ 2. `lstm_model.h5` - LSTM model file
318
+ 3. `scaler.pkl` - Data scaler file
319
+
320
+ Place these files in the same directory as the app.
321
+ """
322
+
323
+ historical_data, predictions_df = create_demo_predictions(ticker, num_days)
324
+ fig = predictor.create_plot(historical_data, predictions_df, f"{ticker} (DEMO)")
325
+
326
+ predictions_display = predictions_df.copy()
327
+ predictions_display['Date'] = predictions_display['Date'].dt.strftime('%Y-%m-%d')
328
+ predictions_display = predictions_display.round(2)
329
+
330
+ return fig, demo_msg, predictions_display
331
+
332
+ # Normal prediction flow
333
+ return predict_stock_price(ticker, num_days)
334
+
335
  # Create Gradio interface
336
  with gr.Blocks(title="Stock Price Forecaster", theme=gr.themes.Soft()) as app:
337
  gr.Markdown(
 
352
  with gr.Column(scale=1):
353
  ticker_input = gr.Textbox(
354
  label="Stock Ticker Symbol",
355
+ placeholder="AAPL",
356
  value="AAPL"
357
  )
358
 
 
366
 
367
  predict_button = gr.Button("🚀 Generate Forecast", variant="primary")
368
 
369
+
 
 
 
 
 
 
 
 
370
 
371
  with gr.Row():
372
  with gr.Column(scale=2):
 
382
  datatype=["str", "number", "number", "number"]
383
  )
384
 
385
+ # Add examples (use safe function)
386
  gr.Examples(
387
  examples=[
388
  ["AAPL", 7],
 
 
 
389
  ],
390
  inputs=[ticker_input, days_input],
391
  outputs=[plot_output, summary_output, predictions_table],
392
+ fn=predict_stock_price_safe,
393
  cache_examples=False
394
  )
395
 
396
+ # Connect the safe prediction function
397
  predict_button.click(
398
+ fn=predict_stock_price_safe,
399
  inputs=[ticker_input, days_input],
400
  outputs=[plot_output, summary_output, predictions_table]
401
  )
 
407
  - **ARIMA**: Auto-Regressive Integrated Moving Average model trained on historical price data
408
  - **LSTM**: Long Short-Term Memory neural network with 3 layers and dropout regularization
409
  - **Training Data**: Historical stock prices from Yahoo Finance
410
+
411
+ ### ⚠️ Disclaimer
412
+ This is for educational purposes only. Stock predictions are inherently uncertain and should not be used as the sole basis for investment decisions.
413
  """
414
  )
415