Dhahlan2000 commited on
Commit
b607f48
·
1 Parent(s): 19bc86f

Enhance stock prediction and visualization features in app.py:

Browse files

- Add Plotly for interactive charts and pandas for data manipulation.
- Update predict_stock_price to return historical data, forecast dates, and values.
- Implement tabs for displaying price history, predictions, and technical indicators (moving averages).
- Improve error handling and user feedback during stock analysis.

Files changed (1) hide show
  1. app.py +80 -5
app.py CHANGED
@@ -15,6 +15,9 @@ from statsmodels.tsa.arima.model import ARIMA
15
  import torch
16
  import re
17
  from typing import List, Union
 
 
 
18
 
19
  # Load environment variables from .env
20
  load_dotenv()
@@ -77,12 +80,20 @@ def predict_stock_price(ticker, days=5):
77
  hist = stock.history(period="6mo")
78
  if hist.empty:
79
  return {"error": f"No data found for ticker {ticker}"}
80
-
81
 
82
  model = ARIMA(hist["Close"], order=(5, 1, 0))
83
  model_fit = model.fit()
84
  forecast = model_fit.forecast(steps=days)
85
- return forecast.tolist()
 
 
 
 
 
 
 
 
 
86
 
87
  def compare_stocks(ticker1, ticker2):
88
  data1 = fetch_stock_data(ticker1)
@@ -234,13 +245,77 @@ if st.button("Submit"):
234
  st.write("Debug: User Query ->", query)
235
  with st.spinner("Processing..."):
236
  try:
237
- # Run the agent and get the response
238
- response = agent_executor.run(query) # Correct method is `run()`
239
  st.success("Response:")
240
  st.write(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  except Exception as e:
242
  st.error(f"An error occurred: {e}")
243
- # Log the full LLM output for debugging
244
  if hasattr(e, "output"):
245
  st.write("Raw Output:", e.output)
246
  else:
 
15
  import torch
16
  import re
17
  from typing import List, Union
18
+ import plotly.graph_objects as go
19
+ import pandas as pd
20
+ from datetime import datetime, timedelta
21
 
22
  # Load environment variables from .env
23
  load_dotenv()
 
80
  hist = stock.history(period="6mo")
81
  if hist.empty:
82
  return {"error": f"No data found for ticker {ticker}"}
 
83
 
84
  model = ARIMA(hist["Close"], order=(5, 1, 0))
85
  model_fit = model.fit()
86
  forecast = model_fit.forecast(steps=days)
87
+
88
+ # Create future dates for the forecast
89
+ last_date = hist.index[-1]
90
+ future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=days, freq='B')
91
+
92
+ return {
93
+ "historical_data": hist[["Close"]].to_dict(),
94
+ "forecast_dates": future_dates.strftime('%Y-%m-%d').tolist(),
95
+ "forecast_values": forecast.tolist()
96
+ }
97
 
98
  def compare_stocks(ticker1, ticker2):
99
  data1 = fetch_stock_data(ticker1)
 
245
  st.write("Debug: User Query ->", query)
246
  with st.spinner("Processing..."):
247
  try:
248
+ response = agent_executor.run(query)
 
249
  st.success("Response:")
250
  st.write(response)
251
+
252
+ # Extract ticker from query (basic extraction, you might want to make this more robust)
253
+ possible_tickers = re.findall(r'[A-Z]{1,5}', query.upper())
254
+ if possible_tickers:
255
+ ticker = possible_tickers[0]
256
+
257
+ # Create tabs for different visualizations
258
+ tab1, tab2, tab3 = st.tabs(["Price History", "Price Prediction", "Technical Indicators"])
259
+
260
+ with tab1:
261
+ st.subheader(f"{ticker} Price History")
262
+ stock = yf.Ticker(ticker)
263
+ hist = stock.history(period="1y")
264
+
265
+ fig = go.Figure()
266
+ fig.add_trace(go.Candlestick(
267
+ x=hist.index,
268
+ open=hist['Open'],
269
+ high=hist['High'],
270
+ low=hist['Low'],
271
+ close=hist['Close'],
272
+ name='OHLC'
273
+ ))
274
+ fig.update_layout(title=f"{ticker} Stock Price", xaxis_title="Date", yaxis_title="Price")
275
+ st.plotly_chart(fig)
276
+
277
+ with tab2:
278
+ st.subheader(f"{ticker} Price Prediction")
279
+ prediction_data = predict_stock_price(ticker)
280
+
281
+ if "error" not in prediction_data:
282
+ hist_df = pd.DataFrame(prediction_data["historical_data"])
283
+
284
+ fig = go.Figure()
285
+ # Plot historical data
286
+ fig.add_trace(go.Scatter(
287
+ x=hist_df.index,
288
+ y=hist_df['Close'],
289
+ name='Historical',
290
+ line=dict(color='blue')
291
+ ))
292
+ # Plot predicted data
293
+ fig.add_trace(go.Scatter(
294
+ x=prediction_data["forecast_dates"],
295
+ y=prediction_data["forecast_values"],
296
+ name='Predicted',
297
+ line=dict(color='red', dash='dash')
298
+ ))
299
+ fig.update_layout(title=f"{ticker} Price Prediction", xaxis_title="Date", yaxis_title="Price")
300
+ st.plotly_chart(fig)
301
+
302
+ with tab3:
303
+ st.subheader(f"{ticker} Technical Indicators")
304
+ # Calculate and plot moving averages
305
+ hist['MA5'] = hist['Close'].rolling(window=5).mean()
306
+ hist['MA20'] = hist['Close'].rolling(window=20).mean()
307
+ hist['MA50'] = hist['Close'].rolling(window=50).mean()
308
+
309
+ fig = go.Figure()
310
+ fig.add_trace(go.Scatter(x=hist.index, y=hist['Close'], name='Price'))
311
+ fig.add_trace(go.Scatter(x=hist.index, y=hist['MA5'], name='5-day MA'))
312
+ fig.add_trace(go.Scatter(x=hist.index, y=hist['MA20'], name='20-day MA'))
313
+ fig.add_trace(go.Scatter(x=hist.index, y=hist['MA50'], name='50-day MA'))
314
+ fig.update_layout(title=f"{ticker} Technical Indicators", xaxis_title="Date", yaxis_title="Price")
315
+ st.plotly_chart(fig)
316
+
317
  except Exception as e:
318
  st.error(f"An error occurred: {e}")
 
319
  if hasattr(e, "output"):
320
  st.write("Raw Output:", e.output)
321
  else: