Spaces:
Running
Running
import streamlit as st | |
import yfinance as yf | |
import requests | |
import pandas as pd | |
from langchain.agents import initialize_agent, AgentType | |
from langchain.tools import Tool | |
from langchain_huggingface import HuggingFacePipeline | |
import os | |
from dotenv import load_dotenv | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
import torch | |
from langchain.prompts import PromptTemplate | |
# Pull secrets from a local .env file (when one exists) into os.environ.
load_dotenv()
# NewsAPI key for the news tool; None when the variable is not set.
NEWSAPI_KEY = os.environ.get("NEWSAPI_KEY")
# Hugging Face Hub token used when downloading the model; None when not set.
access_token = os.environ.get("API_KEY")
# Initialize model and pipeline.
# Streamlit re-executes this whole script on every widget interaction, so the
# weights are loaded once per server process through st.cache_resource rather
# than on every rerun.
MODEL_ID = "google/gemma-2b-it"


@st.cache_resource
def _load_text_generation(model_id=MODEL_ID, token=None):
    """Load the tokenizer, causal-LM weights and a text-generation pipeline.

    Args:
        model_id: Hugging Face Hub model identifier.
        token: Hub access token, or None.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    text_pipe = pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        # Return only the continuation. By default the pipeline echoes the
        # prompt, and the echoed ReAct format text ("Action: ...",
        # "Final Answer: ...") confuses LangChain's output parser.
        return_full_text=False,
        # Without an explicit budget, generation may fall back to the short
        # default max_length of transformers' GenerationConfig, which
        # truncates a ReAct step.
        max_new_tokens=512,
    )
    return tok, lm, text_pipe


tokenizer, model, pipe = _load_text_generation(MODEL_ID, access_token)
# Define improved prompt template.
# A ReAct prompt must expose {agent_scratchpad}. Without it the model never
# sees the Thought/Action/Observation history the format tells it to repeat.
# The trailing "Thought:" primes the model to start the next reasoning step.
prompt_template = PromptTemplate(
    input_variables=["input", "agent_scratchpad"],
    template="""Answer the following question as best you can. You have access to the following tools:
Stock Data Fetcher(ticker) - Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).
Stock News Fetcher(ticker) - Fetch recent news articles about a stock ticker.
Moving Average Calculator(ticker, window=5) - Calculate the moving average of a stock over a 5-day window.
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [Stock Data Fetcher, Stock News Fetcher, Moving Average Calculator]
Action Input: the input to the action
Observation: the result of the action
(this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Strictly follow this format. Do not provide a Final Answer until all Observations are collected.
Begin!
Question: {input}
Thought:{agent_scratchpad}""",
)
# Helper functions
def fetch_stock_data(ticker):
    """Return the last five daily rows of the past month of history for *ticker*.

    Args:
        ticker: Stock symbol such as ``"AAPL"``. The agent passes its raw
            Action Input string, so surrounding whitespace and quotes are
            stripped before the lookup.

    Returns:
        ``hist.tail(5).to_dict()`` (a ``{column: {index: value}}`` mapping),
        or ``{"error": message}`` when the ticker is empty, no data is found,
        or the download raises.
    """
    symbol = str(ticker).strip().strip("'\"").strip()
    if not symbol:
        return {"error": "No ticker symbol provided"}
    try:
        hist = yf.Ticker(symbol).history(period="1mo")
    except Exception as e:  # report download failures to the agent as text
        return {"error": str(e)}
    if hist.empty:
        return {"error": f"No data found for ticker {symbol}"}
    return hist.tail(5).to_dict()
def fetch_stock_news(ticker, NEWSAPI_KEY):
    """Return up to five recent NewsAPI articles matching *ticker*.

    Args:
        ticker: Search term, usually a stock symbol. Surrounding whitespace
            and quotes from the agent's Action Input are stripped.
        NEWSAPI_KEY: NewsAPI key (parameter name kept for compatibility).

    Returns:
        A list of ``{"title": ..., "description": ...}`` dicts (missing fields
        become None), or ``[{"error": message}]`` when the key or ticker is
        missing, the request fails, or the response status is not 200.
    """
    query = str(ticker).strip().strip("'\"").strip()
    if not NEWSAPI_KEY:
        return [{"error": "NewsAPI key is not configured."}]
    if not query:
        return [{"error": "No ticker symbol provided."}]
    try:
        # Let requests URL-encode the query. The ticker comes from LLM output
        # and, interpolated raw, could inject extra query parameters.
        response = requests.get(
            "https://newsapi.org/v2/everything",
            params={"q": query, "apiKey": NEWSAPI_KEY},
            timeout=10,
        )
        if response.status_code != 200:
            return [{"error": "Unable to fetch news."}]
        articles = response.json().get("articles", [])
    except (requests.RequestException, ValueError):
        # Network failure, timeout, or a non-JSON body.
        return [{"error": "Unable to fetch news."}]
    return [
        {"title": article.get("title"), "description": article.get("description")}
        for article in articles[:5]
    ]
def calculate_moving_average(ticker, window=5):
    """Return the last five closes of the past month with a rolling mean.

    Args:
        ticker: Stock symbol. The agent Tool passes a single string, so an
            input such as ``"AAPL, 10"`` or ``"AAPL, window=10"`` is also
            accepted. Its second part is then used as *window*.
        window: Rolling-window length in rows (must be a positive integer).

    Returns:
        A DataFrame with ``Close`` and ``"<window>-day MA"`` columns for the
        last five rows, or ``{"error": message}`` (like ``fetch_stock_data``)
        for an invalid window, an empty ticker, or missing data.
    """
    symbol = str(ticker)
    if "," in symbol:
        symbol, _, window_text = symbol.partition(",")
        window_text = window_text.strip().strip("'\"")
        if window_text.lower().startswith("window"):
            window_text = window_text.split("=", 1)[-1]
        # An empty second part (e.g. "AAPL,") keeps the caller's window.
        window = window_text.strip() or window
    symbol = symbol.strip().strip("'\"").strip()

    try:
        window = int(window)
    except (TypeError, ValueError):
        return {"error": f"Invalid moving-average window: {window!r}"}
    if window < 1:
        # pandas' rolling() rejects non-positive windows.
        return {"error": f"Invalid moving-average window: {window!r}"}
    if not symbol:
        return {"error": "No ticker symbol provided"}

    try:
        hist = yf.Ticker(symbol).history(period="1mo")
    except Exception as e:  # report download failures to the agent as text
        return {"error": str(e)}
    if hist.empty or "Close" not in hist:
        return {"error": f"No data found for ticker {symbol}"}

    column = f"{window}-day MA"
    hist[column] = hist["Close"].rolling(window=window).mean()
    return hist[["Close", column]].tail(5)
# Tools exposed to the agent. Each one wraps a single-argument callable.
def _fetch_news_with_configured_key(ticker):
    """Adapter: call fetch_stock_news with the module-level NewsAPI key."""
    return fetch_stock_news(ticker, NEWSAPI_KEY)


stock_data_tool = Tool(
    func=fetch_stock_data,
    name="Stock Data Fetcher",
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).",
)
stock_news_tool = Tool(
    func=_fetch_news_with_configured_key,
    name="Stock News Fetcher",
    description="Fetch recent news articles about a stock ticker.",
)
moving_average_tool = Tool(
    func=calculate_moving_average,
    name="Moving Average Calculator",
    description="Calculate the moving average of a stock over a 5-day window.",
)
# Initialize agent
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
llm = HuggingFacePipeline(pipeline=pipe)
agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    # initialize_agent forwards unknown keyword arguments to AgentExecutor,
    # not to the agent, so a `prompt=` argument never reaches the LLM.
    # ZeroShotAgent builds its own prompt from the tool descriptions. Extra
    # instructions go in through agent_kwargs, and the suffix must keep the
    # {input} and {agent_scratchpad} placeholders.
    agent_kwargs={
        "suffix": (
            "Strictly follow this format. Do not provide a Final Answer "
            "until all Observations are collected.\n\n"
            "Begin!\n\n"
            "Question: {input}\n"
            "Thought:{agent_scratchpad}"
        ),
    },
    verbose=True,
    handle_parsing_errors=True,
)
# Streamlit interface
st.title("Trading Helper Agent")
query = st.text_input("Enter your query:")
if st.button("Submit"):
    # Treat a whitespace-only query as empty instead of sending it to the agent.
    question = query.strip()
    if not question:
        st.warning("Please enter a query.")
    else:
        with st.spinner("Processing..."):
            try:
                # AgentExecutor.run is deprecated. invoke returns a dict whose
                # "output" key holds the final answer.
                result = agent.invoke({"input": question})
            except Exception as e:  # show any agent/tool failure in the UI
                st.error(f"An error occurred: {e}")
            else:
                response = result["output"]
                st.success("Response:")
                st.write(response)