Spaces:
Running
Running
Commit
·
12e25ff
1
Parent(s):
83a2dd2
Refactor app.py to simplify stock price prediction logic and remove unused imports. Update requirements.txt to maintain statsmodels dependency. Streamline the return value of predict_stock_price to only include forecast values.
Browse files- app.py +5 -80
- requirements.txt +1 -3
app.py
CHANGED
@@ -15,9 +15,6 @@ from statsmodels.tsa.arima.model import ARIMA
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
18 |
-
import plotly.graph_objects as go
|
19 |
-
import pandas as pd
|
20 |
-
from datetime import datetime, timedelta
|
21 |
|
22 |
# Load environment variables from .env
|
23 |
load_dotenv()
|
@@ -80,20 +77,12 @@ def predict_stock_price(ticker, days=5):
|
|
80 |
hist = stock.history(period="6mo")
|
81 |
if hist.empty:
|
82 |
return {"error": f"No data found for ticker {ticker}"}
|
|
|
83 |
|
84 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
85 |
model_fit = model.fit()
|
86 |
forecast = model_fit.forecast(steps=days)
|
87 |
-
|
88 |
-
# Create future dates for the forecast
|
89 |
-
last_date = hist.index[-1]
|
90 |
-
future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=days, freq='B')
|
91 |
-
|
92 |
-
return {
|
93 |
-
"historical_data": hist[["Close"]].to_dict(),
|
94 |
-
"forecast_dates": future_dates.strftime('%Y-%m-%d').tolist(),
|
95 |
-
"forecast_values": forecast.tolist()
|
96 |
-
}
|
97 |
|
98 |
def compare_stocks(ticker1, ticker2):
|
99 |
data1 = fetch_stock_data(ticker1)
|
@@ -245,77 +234,13 @@ if st.button("Submit"):
|
|
245 |
st.write("Debug: User Query ->", query)
|
246 |
with st.spinner("Processing..."):
|
247 |
try:
|
248 |
-
|
|
|
249 |
st.success("Response:")
|
250 |
st.write(response)
|
251 |
-
|
252 |
-
# Extract ticker from query (basic extraction, you might want to make this more robust)
|
253 |
-
possible_tickers = re.findall(r'[A-Z]{1,5}', query.upper())
|
254 |
-
if possible_tickers:
|
255 |
-
ticker = possible_tickers[0]
|
256 |
-
|
257 |
-
# Create tabs for different visualizations
|
258 |
-
tab1, tab2, tab3 = st.tabs(["Price History", "Price Prediction", "Technical Indicators"])
|
259 |
-
|
260 |
-
with tab1:
|
261 |
-
st.subheader(f"{ticker} Price History")
|
262 |
-
stock = yf.Ticker(ticker)
|
263 |
-
hist = stock.history(period="1y")
|
264 |
-
|
265 |
-
fig = go.Figure()
|
266 |
-
fig.add_trace(go.Candlestick(
|
267 |
-
x=hist.index,
|
268 |
-
open=hist['Open'],
|
269 |
-
high=hist['High'],
|
270 |
-
low=hist['Low'],
|
271 |
-
close=hist['Close'],
|
272 |
-
name='OHLC'
|
273 |
-
))
|
274 |
-
fig.update_layout(title=f"{ticker} Stock Price", xaxis_title="Date", yaxis_title="Price")
|
275 |
-
st.plotly_chart(fig)
|
276 |
-
|
277 |
-
with tab2:
|
278 |
-
st.subheader(f"{ticker} Price Prediction")
|
279 |
-
prediction_data = predict_stock_price(ticker)
|
280 |
-
|
281 |
-
if "error" not in prediction_data:
|
282 |
-
hist_df = pd.DataFrame(prediction_data["historical_data"])
|
283 |
-
|
284 |
-
fig = go.Figure()
|
285 |
-
# Plot historical data
|
286 |
-
fig.add_trace(go.Scatter(
|
287 |
-
x=hist_df.index,
|
288 |
-
y=hist_df['Close'],
|
289 |
-
name='Historical',
|
290 |
-
line=dict(color='blue')
|
291 |
-
))
|
292 |
-
# Plot predicted data
|
293 |
-
fig.add_trace(go.Scatter(
|
294 |
-
x=prediction_data["forecast_dates"],
|
295 |
-
y=prediction_data["forecast_values"],
|
296 |
-
name='Predicted',
|
297 |
-
line=dict(color='red', dash='dash')
|
298 |
-
))
|
299 |
-
fig.update_layout(title=f"{ticker} Price Prediction", xaxis_title="Date", yaxis_title="Price")
|
300 |
-
st.plotly_chart(fig)
|
301 |
-
|
302 |
-
with tab3:
|
303 |
-
st.subheader(f"{ticker} Technical Indicators")
|
304 |
-
# Calculate and plot moving averages
|
305 |
-
hist['MA5'] = hist['Close'].rolling(window=5).mean()
|
306 |
-
hist['MA20'] = hist['Close'].rolling(window=20).mean()
|
307 |
-
hist['MA50'] = hist['Close'].rolling(window=50).mean()
|
308 |
-
|
309 |
-
fig = go.Figure()
|
310 |
-
fig.add_trace(go.Scatter(x=hist.index, y=hist['Close'], name='Price'))
|
311 |
-
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA5'], name='5-day MA'))
|
312 |
-
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA20'], name='20-day MA'))
|
313 |
-
fig.add_trace(go.Scatter(x=hist.index, y=hist['MA50'], name='50-day MA'))
|
314 |
-
fig.update_layout(title=f"{ticker} Technical Indicators", xaxis_title="Date", yaxis_title="Price")
|
315 |
-
st.plotly_chart(fig)
|
316 |
-
|
317 |
except Exception as e:
|
318 |
st.error(f"An error occurred: {e}")
|
|
|
319 |
if hasattr(e, "output"):
|
320 |
st.write("Raw Output:", e.output)
|
321 |
else:
|
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
|
|
|
|
|
|
18 |
|
19 |
# Load environment variables from .env
|
20 |
load_dotenv()
|
|
|
77 |
hist = stock.history(period="6mo")
|
78 |
if hist.empty:
|
79 |
return {"error": f"No data found for ticker {ticker}"}
|
80 |
+
|
81 |
|
82 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
83 |
model_fit = model.fit()
|
84 |
forecast = model_fit.forecast(steps=days)
|
85 |
+
return forecast.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
def compare_stocks(ticker1, ticker2):
|
88 |
data1 = fetch_stock_data(ticker1)
|
|
|
234 |
st.write("Debug: User Query ->", query)
|
235 |
with st.spinner("Processing..."):
|
236 |
try:
|
237 |
+
# Run the agent and get the response
|
238 |
+
response = agent_executor.run(query) # Correct method is `run()`
|
239 |
st.success("Response:")
|
240 |
st.write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
except Exception as e:
|
242 |
st.error(f"An error occurred: {e}")
|
243 |
+
# Log the full LLM output for debugging
|
244 |
if hasattr(e, "output"):
|
245 |
st.write("Raw Output:", e.output)
|
246 |
else:
|
requirements.txt
CHANGED
@@ -5,6 +5,4 @@ pandas
|
|
5 |
langchain
|
6 |
langchain_huggingface
|
7 |
python-dotenv
|
8 |
-
statsmodels
|
9 |
-
plotly
|
10 |
-
datetime
|
|
|
5 |
langchain
|
6 |
langchain_huggingface
|
7 |
python-dotenv
|
8 |
+
statsmodels
|
|
|
|