Spaces:
Running
Running
Commit
·
b607f48
1
Parent(s):
19bc86f
Enhance stock prediction and visualization features in app.py:
Browse files- Add Plotly for interactive charts and pandas for data manipulation.
- Update predict_stock_price to return historical data, forecast dates, and values.
- Implement tabs for displaying price history, predictions, and technical indicators (moving averages).
- Improve error handling and user feedback during stock analysis.
app.py
CHANGED
@@ -15,6 +15,9 @@ from statsmodels.tsa.arima.model import ARIMA
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
|
|
|
|
|
|
18 |
|
19 |
# Load environment variables from .env
|
20 |
load_dotenv()
|
@@ -77,12 +80,20 @@ def predict_stock_price(ticker, days=5):
|
|
77 |
hist = stock.history(period="6mo")
|
78 |
if hist.empty:
|
79 |
return {"error": f"No data found for ticker {ticker}"}
|
80 |
-
|
81 |
|
82 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
83 |
model_fit = model.fit()
|
84 |
forecast = model_fit.forecast(steps=days)
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
def compare_stocks(ticker1, ticker2):
|
88 |
data1 = fetch_stock_data(ticker1)
|
@@ -234,13 +245,77 @@ if st.button("Submit"):
|
|
234 |
st.write("Debug: User Query ->", query)
|
235 |
with st.spinner("Processing..."):
|
236 |
try:
|
237 |
-
|
238 |
-
response = agent_executor.run(query) # Correct method is `run()`
|
239 |
st.success("Response:")
|
240 |
st.write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
except Exception as e:
|
242 |
st.error(f"An error occurred: {e}")
|
243 |
-
# Log the full LLM output for debugging
|
244 |
if hasattr(e, "output"):
|
245 |
st.write("Raw Output:", e.output)
|
246 |
else:
|
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
18 |
+
import plotly.graph_objects as go
|
19 |
+
import pandas as pd
|
20 |
+
from datetime import datetime, timedelta
|
21 |
|
22 |
# Load environment variables from .env
|
23 |
load_dotenv()
|
|
|
80 |
hist = stock.history(period="6mo")
|
81 |
if hist.empty:
|
82 |
return {"error": f"No data found for ticker {ticker}"}
|
|
|
83 |
|
84 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
85 |
model_fit = model.fit()
|
86 |
forecast = model_fit.forecast(steps=days)
|
87 |
+
|
88 |
+
# Create future dates for the forecast
|
89 |
+
last_date = hist.index[-1]
|
90 |
+
future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=days, freq='B')
|
91 |
+
|
92 |
+
return {
|
93 |
+
"historical_data": hist[["Close"]].to_dict(),
|
94 |
+
"forecast_dates": future_dates.strftime('%Y-%m-%d').tolist(),
|
95 |
+
"forecast_values": forecast.tolist()
|
96 |
+
}
|
97 |
|
98 |
def compare_stocks(ticker1, ticker2):
|
99 |
data1 = fetch_stock_data(ticker1)
|
|
|
245 |
st.write("Debug: User Query ->", query)
|
246 |
with st.spinner("Processing..."):
|
247 |
try:
|
248 |
+
response = agent_executor.run(query)
|
|
|
249 |
st.success("Response:")
|
250 |
st.write(response)
|
251 |
+
|
252 |
+
# Extract ticker from query (basic extraction, you might want to make this more robust)
|
253 |
+
possible_tickers = re.findall(r'[A-Z]{1,5}', query.upper())
|
254 |
+
if possible_tickers:
|
255 |
+
ticker = possible_tickers[0]
|
256 |
+
|
257 |
+
# Create tabs for different visualizations
|
258 |
+
tab1, tab2, tab3 = st.tabs(["Price History", "Price Prediction", "Technical Indicators"])
|
259 |
+
|
260 |
+
with tab1:
|
261 |
+
st.subheader(f"{ticker} Price History")
|
262 |
+
stock = yf.Ticker(ticker)
|
263 |
+
hist = stock.history(period="1y")
|
264 |
+
|
265 |
+
fig = go.Figure()
|
266 |
+
fig.add_trace(go.Candlestick(
|
267 |
+
x=hist.index,
|
268 |
+
open=hist['Open'],
|
269 |
+
high=hist['High'],
|
270 |
+
low=hist['Low'],
|
271 |
+
close=hist['Close'],
|
272 |
+
name='OHLC'
|
273 |
+
))
|
274 |
+
fig.update_layout(title=f"{ticker} Stock Price", xaxis_title="Date", yaxis_title="Price")
|
275 |
+
st.plotly_chart(fig)
|
276 |
+
|
277 |
+
with tab2:
|
278 |
+
st.subheader(f"{ticker} Price Prediction")
|
279 |
+
prediction_data = predict_stock_price(ticker)
|
280 |
+
|
281 |
+
if "error" not in prediction_data:
|
282 |
+
hist_df = pd.DataFrame(prediction_data["historical_data"])
|
283 |
+
|
284 |
+
fig = go.Figure()
|
285 |
+
# Plot historical data
|
286 |
+
fig.add_trace(go.Scatter(
|
287 |
+
x=hist_df.index,
|
288 |
+
y=hist_df['Close'],
|
289 |
+
name='Historical',
|
290 |
+
line=dict(color='blue')
|
291 |
+
))
|
292 |
+
# Plot predicted data
|
293 |
+
fig.add_trace(go.Scatter(
|
294 |
+
x=prediction_data["forecast_dates"],
|
295 |
+
y=prediction_data["forecast_values"],
|
296 |
+
name='Predicted',
|
297 |
+
line=dict(color='red', dash='dash')
|
298 |
+
))
|
299 |
+
fig.update_layout(title=f"{ticker} Price Prediction", xaxis_title="Date", yaxis_title="Price")
|
300 |
+
st.plotly_chart(fig)
|
301 |
+
|
302 |
+
with tab3:
|
303 |
+
st.subheader(f"{ticker} Technical Indicators")
|
304 |
+
# Calculate and plot moving averages
|
305 |
+
hist['MA5'] = hist['Close'].rolling(window=5).mean()
|
306 |
+
hist['MA20'] = hist['Close'].rolling(window=20).mean()
|
307 |
+
hist['MA50'] = hist['Close'].rolling(window=50).mean()
|
308 |
+
|
309 |
+
fig = go.Figure()
|
310 |
+
fig.add_trace(go.Scatter(x=hist.index, y=hist['Close'], name='Price'))
|
311 |
+
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA5'], name='5-day MA'))
|
312 |
+
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA20'], name='20-day MA'))
|
313 |
+
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA50'], name='50-day MA'))
|
314 |
+
fig.update_layout(title=f"{ticker} Technical Indicators", xaxis_title="Date", yaxis_title="Price")
|
315 |
+
st.plotly_chart(fig)
|
316 |
+
|
317 |
except Exception as e:
|
318 |
st.error(f"An error occurred: {e}")
|
|
|
319 |
if hasattr(e, "output"):
|
320 |
st.write("Raw Output:", e.output)
|
321 |
else:
|