import streamlit as st
import yfinance as yf
import requests
import os
from dotenv import load_dotenv
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.schema import AgentAction, AgentFinish, HumanMessage
from langchain.prompts import BaseChatPromptTemplate
from langchain_huggingface import HuggingFacePipeline
from langchain.chains import LLMChain
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.memory import ConversationBufferWindowMemory
from statsmodels.tsa.arima.model import ARIMA
import torch
import re
from typing import List, Union
import plotly.graph_objects as go
import pandas as pd
from datetime import datetime, timedelta
# Load environment variables from .env
load_dotenv()
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
access_token = os.getenv("API_KEY")

# Check that both the NewsAPI key and the Hugging Face access token are present
if not NEWSAPI_KEY or not access_token:
    raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")
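# Expected .env layout (illustrative sketch; the values below are hypothetical placeholders):
# NEWSAPI_KEY=your_newsapi_key_here
# API_KEY=your_huggingface_access_token_here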
# Initialize the model and tokenizer for the HuggingFace pipeline
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
# Define functions for fetching stock data, news, and moving averages
def validate_ticker(ticker):
    return ticker.strip().upper()

def fetch_stock_data(ticker):
    try:
        ticker = ticker.strip().upper()
        stock = yf.Ticker(ticker)
        hist = stock.history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {ticker}"}
        return hist.tail(5).to_dict()
    except Exception as e:
        return {"error": str(e)}
def fetch_stock_news(ticker, NEWSAPI_KEY):
    api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
    response = requests.get(api_url)
    if response.status_code == 200:
        articles = response.json().get('articles', [])
        return [{"title": article['title'], "description": article['description']} for article in articles[:5]]
    else:
        return [{"error": "Unable to fetch news."}]
def calculate_moving_average(ticker, window=5):
    stock = yf.Ticker(ticker)
    hist = stock.history(period="1mo")
    hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
    return hist[["Close", f"{window}-day MA"]].tail(5)
def analyze_sentiment(news_articles):
    # Note: this loads the default sentiment-analysis pipeline on every call
    sentiment_pipeline = pipeline("sentiment-analysis")
    results = [{"title": article["title"],
                "sentiment": sentiment_pipeline(article["description"] or article["title"])[0]}
               for article in news_articles]
    return results
def predict_stock_price(ticker, days=5):
    stock = yf.Ticker(ticker)
    hist = stock.history(period="6mo")
    if hist.empty:
        return {"error": f"No data found for ticker {ticker}"}
    # Fit a simple ARIMA(5, 1, 0) model on the closing prices
    model = ARIMA(hist["Close"], order=(5, 1, 0))
    model_fit = model.fit()
    forecast = model_fit.forecast(steps=days)
    # Create future dates for the forecast (business days only)
    last_date = hist.index[-1]
    future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=days, freq='B')
    return {
        "historical_data": hist[["Close"]].to_dict(),
        "forecast_dates": future_dates.strftime('%Y-%m-%d').tolist(),
        "forecast_values": forecast.tolist()
    }
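# Hypothetical usage sketch (numbers are illustrative only):
# predict_stock_price("AAPL", days=3) returns a dict shaped like
#   {"historical_data": {"Close": {Timestamp(...): 185.1, ...}},
#    "forecast_dates": ["2025-01-02", "2025-01-03", "2025-01-06"],
#    "forecast_values": [186.0, 186.4, 186.1]}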
def compare_stocks(ticker1, ticker2):
    data1 = fetch_stock_data(ticker1)
    data2 = fetch_stock_data(ticker2)
    if "error" in data1 or "error" in data2:
        return {"error": "Could not fetch stock data for comparison."}
    # fetch_stock_data returns a dict keyed by column, then by timestamp,
    # so take the last value of each "Close" series
    comparison = {
        ticker1: {"recent_close": list(data1["Close"].values())[-1]},
        ticker2: {"recent_close": list(data2["Close"].values())[-1]},
    }
    return comparison
# Define LangChain tools
stock_data_tool = Tool(
    name="Stock Data Fetcher",
    func=fetch_stock_data,
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
)

stock_news_tool = Tool(
    name="Stock News Fetcher",
    func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
    description="Fetch recent news articles about a stock ticker."
)

moving_average_tool = Tool(
    name="Moving Average Calculator",
    func=calculate_moving_average,
    description="Calculate the moving average of a stock over a 5-day window."
)

sentiment_tool = Tool(
    name="News Sentiment Analyzer",
    func=lambda ticker: analyze_sentiment(fetch_stock_news(ticker, NEWSAPI_KEY)),
    description="Analyze the sentiment of recent news articles about a stock ticker."
)

stock_prediction_tool = Tool(
    name="Stock Price Predictor",
    func=predict_stock_price,
    description="Predict future stock prices for a given ticker based on historical data."
)

stock_comparator_tool = Tool(
    name="Stock Comparator",
    func=lambda tickers: compare_stocks(*tickers.split(',')),
    description="Compare the recent performance of two stocks given their tickers, e.g., 'AAPL,MSFT'."
)

tools = [
    stock_data_tool,
    stock_news_tool,
    moving_average_tool,
    sentiment_tool,
    stock_prediction_tool,
    stock_comparator_tool
]
# Set up a prompt template with history
template_with_history = """You are SearchGPT, a professional search engine who provides informative answers to users. Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
(this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! Remember to give detailed, informative answers

Previous conversation history:
{history}

New question: {input}
{agent_scratchpad}"""
# Set up the prompt template
class CustomPromptTemplate(BaseChatPromptTemplate):
    template: str
    tools: List[Tool]

    def format_messages(self, **kwargs) -> List[HumanMessage]:
        # Build the agent scratchpad from the intermediate (action, observation) steps
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        kwargs["agent_scratchpad"] = thoughts
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        formatted = self.template.format(**kwargs)
        return [HumanMessage(content=formatted)]

prompt_with_history = CustomPromptTemplate(
    template=template_with_history,
    tools=tools,
    input_variables=["input", "intermediate_steps", "history"]
)
# Custom output parser
class CustomOutputParser(AgentOutputParser):
    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        if "Final Answer:" in llm_output:
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)

output_parser = CustomOutputParser()
# Wrap the HuggingFace pipeline as a LangChain LLM
llm = HuggingFacePipeline(pipeline=pipe)

# LLM chain
llm_chain = LLMChain(llm=llm, prompt=prompt_with_history)

tool_names = [tool.name for tool in tools]
agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["\nObservation:"],
    allowed_tools=tool_names
)

memory = ConversationBufferWindowMemory(k=2)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)
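# Hypothetical example query (not part of the app flow; the Streamlit input below drives the agent):
# agent_executor.run("What is the 5-day moving average for AAPL?")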
# Streamlit app
st.title("Trading Helper Agent")

query = st.text_input("Enter your query:")

if st.button("Submit"):
    if query:
        st.write("Debug: User Query ->", query)
        with st.spinner("Processing..."):
            try:
                response = agent_executor.run(query)
                st.success("Response:")
                st.write(response)

                # Extract ticker from query (basic extraction, you might want to make this more robust)
                possible_tickers = re.findall(r'[A-Z]{1,5}', query.upper())
                if possible_tickers:
                    ticker = possible_tickers[0]

                    # Create tabs for different visualizations
                    tab1, tab2, tab3 = st.tabs(["Price History", "Price Prediction", "Technical Indicators"])

                    with tab1:
                        st.subheader(f"{ticker} Price History")
                        stock = yf.Ticker(ticker)
                        hist = stock.history(period="1y")
                        fig = go.Figure()
                        fig.add_trace(go.Candlestick(
                            x=hist.index,
                            open=hist['Open'],
                            high=hist['High'],
                            low=hist['Low'],
                            close=hist['Close'],
                            name='OHLC'
                        ))
                        fig.update_layout(title=f"{ticker} Stock Price", xaxis_title="Date", yaxis_title="Price")
                        st.plotly_chart(fig)

                    with tab2:
                        st.subheader(f"{ticker} Price Prediction")
                        prediction_data = predict_stock_price(ticker)
                        if "error" not in prediction_data:
                            hist_df = pd.DataFrame(prediction_data["historical_data"])
                            fig = go.Figure()
                            # Plot historical data
                            fig.add_trace(go.Scatter(
                                x=hist_df.index,
                                y=hist_df['Close'],
                                name='Historical',
                                line=dict(color='blue')
                            ))
                            # Plot predicted data
                            fig.add_trace(go.Scatter(
                                x=prediction_data["forecast_dates"],
                                y=prediction_data["forecast_values"],
                                name='Predicted',
                                line=dict(color='red', dash='dash')
                            ))
                            fig.update_layout(title=f"{ticker} Price Prediction", xaxis_title="Date", yaxis_title="Price")
                            st.plotly_chart(fig)

                    with tab3:
                        st.subheader(f"{ticker} Technical Indicators")
                        # Calculate and plot moving averages on the 1-year history loaded in tab 1
                        hist['MA5'] = hist['Close'].rolling(window=5).mean()
                        hist['MA20'] = hist['Close'].rolling(window=20).mean()
                        hist['MA50'] = hist['Close'].rolling(window=50).mean()
                        fig = go.Figure()
                        fig.add_trace(go.Scatter(x=hist.index, y=hist['Close'], name='Price'))
                        fig.add_trace(go.Scatter(x=hist.index, y=hist['MA5'], name='5-day MA'))
                        fig.add_trace(go.Scatter(x=hist.index, y=hist['MA20'], name='20-day MA'))
                        fig.add_trace(go.Scatter(x=hist.index, y=hist['MA50'], name='50-day MA'))
                        fig.update_layout(title=f"{ticker} Technical Indicators", xaxis_title="Date", yaxis_title="Price")
                        st.plotly_chart(fig)
            except Exception as e:
                st.error(f"An error occurred: {e}")
                if hasattr(e, "output"):
                    st.write("Raw Output:", e.output)
    else:
        st.warning("Please enter a query.")