Dhahlan2000 commited on
Commit
12e25ff
·
1 Parent(s): 83a2dd2

Refactor app.py to simplify stock price prediction logic and remove unused imports. Update requirements.txt to maintain statsmodels dependency. Streamline the return value of predict_stock_price to only include forecast values.

Browse files
Files changed (2) hide show
  1. app.py +5 -80
  2. requirements.txt +1 -3
app.py CHANGED
@@ -15,9 +15,6 @@ from statsmodels.tsa.arima.model import ARIMA
15
  import torch
16
  import re
17
  from typing import List, Union
18
- import plotly.graph_objects as go
19
- import pandas as pd
20
- from datetime import datetime, timedelta
21
 
22
  # Load environment variables from .env
23
  load_dotenv()
@@ -80,20 +77,12 @@ def predict_stock_price(ticker, days=5):
80
  hist = stock.history(period="6mo")
81
  if hist.empty:
82
  return {"error": f"No data found for ticker {ticker}"}
 
83
 
84
  model = ARIMA(hist["Close"], order=(5, 1, 0))
85
  model_fit = model.fit()
86
  forecast = model_fit.forecast(steps=days)
87
-
88
- # Create future dates for the forecast
89
- last_date = hist.index[-1]
90
- future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=days, freq='B')
91
-
92
- return {
93
- "historical_data": hist[["Close"]].to_dict(),
94
- "forecast_dates": future_dates.strftime('%Y-%m-%d').tolist(),
95
- "forecast_values": forecast.tolist()
96
- }
97
 
98
  def compare_stocks(ticker1, ticker2):
99
  data1 = fetch_stock_data(ticker1)
@@ -245,77 +234,13 @@ if st.button("Submit"):
245
  st.write("Debug: User Query ->", query)
246
  with st.spinner("Processing..."):
247
  try:
248
- response = agent_executor.run(query)
 
249
  st.success("Response:")
250
  st.write(response)
251
-
252
- # Extract ticker from query (basic extraction, you might want to make this more robust)
253
- possible_tickers = re.findall(r'[A-Z]{1,5}', query.upper())
254
- if possible_tickers:
255
- ticker = possible_tickers[0]
256
-
257
- # Create tabs for different visualizations
258
- tab1, tab2, tab3 = st.tabs(["Price History", "Price Prediction", "Technical Indicators"])
259
-
260
- with tab1:
261
- st.subheader(f"{ticker} Price History")
262
- stock = yf.Ticker(ticker)
263
- hist = stock.history(period="1y")
264
-
265
- fig = go.Figure()
266
- fig.add_trace(go.Candlestick(
267
- x=hist.index,
268
- open=hist['Open'],
269
- high=hist['High'],
270
- low=hist['Low'],
271
- close=hist['Close'],
272
- name='OHLC'
273
- ))
274
- fig.update_layout(title=f"{ticker} Stock Price", xaxis_title="Date", yaxis_title="Price")
275
- st.plotly_chart(fig)
276
-
277
- with tab2:
278
- st.subheader(f"{ticker} Price Prediction")
279
- prediction_data = predict_stock_price(ticker)
280
-
281
- if "error" not in prediction_data:
282
- hist_df = pd.DataFrame(prediction_data["historical_data"])
283
-
284
- fig = go.Figure()
285
- # Plot historical data
286
- fig.add_trace(go.Scatter(
287
- x=hist_df.index,
288
- y=hist_df['Close'],
289
- name='Historical',
290
- line=dict(color='blue')
291
- ))
292
- # Plot predicted data
293
- fig.add_trace(go.Scatter(
294
- x=prediction_data["forecast_dates"],
295
- y=prediction_data["forecast_values"],
296
- name='Predicted',
297
- line=dict(color='red', dash='dash')
298
- ))
299
- fig.update_layout(title=f"{ticker} Price Prediction", xaxis_title="Date", yaxis_title="Price")
300
- st.plotly_chart(fig)
301
-
302
- with tab3:
303
- st.subheader(f"{ticker} Technical Indicators")
304
- # Calculate and plot moving averages
305
- hist['MA5'] = hist['Close'].rolling(window=5).mean()
306
- hist['MA20'] = hist['Close'].rolling(window=20).mean()
307
- hist['MA50'] = hist['Close'].rolling(window=50).mean()
308
-
309
- fig = go.Figure()
310
- fig.add_trace(go.Scatter(x=hist.index, y=hist['Close'], name='Price'))
311
- fig.add_trace(go.Scatter(x=hist.index, y=hist['MA5'], name='5-day MA'))
312
- fig.add_trace(go.Scatter(x=hist.index, y=hist['MA20'], name='20-day MA'))
313
- fig.add_trace(go.Scatter(x=hist.index, y=hist['MA50'], name='50-day MA'))
314
- fig.update_layout(title=f"{ticker} Technical Indicators", xaxis_title="Date", yaxis_title="Price")
315
- st.plotly_chart(fig)
316
-
317
  except Exception as e:
318
  st.error(f"An error occurred: {e}")
 
319
  if hasattr(e, "output"):
320
  st.write("Raw Output:", e.output)
321
  else:
 
15
  import torch
16
  import re
17
  from typing import List, Union
 
 
 
18
 
19
  # Load environment variables from .env
20
  load_dotenv()
 
77
  hist = stock.history(period="6mo")
78
  if hist.empty:
79
  return {"error": f"No data found for ticker {ticker}"}
80
+
81
 
82
  model = ARIMA(hist["Close"], order=(5, 1, 0))
83
  model_fit = model.fit()
84
  forecast = model_fit.forecast(steps=days)
85
+ return forecast.tolist()
 
 
 
 
 
 
 
 
 
86
 
87
  def compare_stocks(ticker1, ticker2):
88
  data1 = fetch_stock_data(ticker1)
 
234
  st.write("Debug: User Query ->", query)
235
  with st.spinner("Processing..."):
236
  try:
237
+ # Run the agent and get the response
238
+ response = agent_executor.run(query) # Correct method is `run()`
239
  st.success("Response:")
240
  st.write(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  except Exception as e:
242
  st.error(f"An error occurred: {e}")
243
+ # Log the full LLM output for debugging
244
  if hasattr(e, "output"):
245
  st.write("Raw Output:", e.output)
246
  else:
requirements.txt CHANGED
@@ -5,6 +5,4 @@ pandas
5
  langchain
6
  langchain_huggingface
7
  python-dotenv
8
- statsmodels
9
- plotly
10
- datetime
 
5
  langchain
6
  langchain_huggingface
7
  python-dotenv
8
+ statsmodels