"""Streamlit "Trading Helper" app.

A LangChain zero-shot ReAct agent, backed by a locally loaded
``google/gemma-2b-it`` text-generation pipeline, answers user questions using
three tools: recent stock prices (yfinance), recent news (NewsAPI) and a
moving-average calculator.
"""

import os
from typing import Optional

import pandas as pd
import requests
import streamlit as st
import torch
import yfinance as yf
from dotenv import load_dotenv
from langchain.agents import AgentType, initialize_agent
from langchain.prompts import PromptTemplate
from langchain.tools import Tool
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables from a local .env file (if present).
load_dotenv()
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
access_token = os.getenv("API_KEY")  # passed as `token=` when downloading the model

MODEL_ID = "google/gemma-2b-it"
NEWSAPI_URL = "https://newsapi.org/v2/everything"
REQUEST_TIMEOUT_SECONDS = 10


@st.cache_resource
def _load_text_generation_pipeline(model_id, token):
    """Load the tokenizer, model and text-generation pipeline for *model_id*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded copy per server process instead of
    reloading the weights on each rerun.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, token=token
    )
    text_pipe = pipeline(
        "text-generation",
        model=mdl,
        tokenizer=tok,
        # Explicit budget for the completion; otherwise generation may fall back
        # to transformers' default max_length (20 tokens, prompt included)
        # unless the model's generation config overrides it.
        max_new_tokens=256,
        # Return only newly generated text so the ReAct output parser never sees
        # the prompt's own "Action: ... / Final Answer: ..." template lines.
        return_full_text=False,
    )
    return tok, mdl, text_pipe


tokenizer, model, pipe = _load_text_generation_pipeline(MODEL_ID, access_token)

# ReAct prompt pieces. For ZERO_SHOT_REACT_DESCRIPTION the agent builds its own
# prompt as prefix + tool descriptions + format_instructions + suffix; a
# ``prompt=`` keyword to initialize_agent is not used for that, so the custom
# wording is supplied through ``agent_kwargs`` below. ``{tool_names}`` is filled
# in by the agent; the suffix must keep ``{input}`` and ``{agent_scratchpad}``.
REACT_PREFIX = (
    "Answer the following question as best you can. "
    "You have access to the following tools:"
)
REACT_FORMAT_INSTRUCTIONS = """Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Strictly follow this format. Do not provide a Final Answer until all Observations are collected."""
REACT_SUFFIX = """Begin!

Question: {input}
Thought:{agent_scratchpad}"""


# Helper functions
def _clean_ticker(raw) -> str:
    """Normalize an agent-supplied ticker string.

    Keeps only the first line and strips surrounding whitespace and quote
    characters, since LLM "Action Input" text often carries those.
    """
    text = str(raw).strip()
    if not text:
        return ""
    return text.splitlines()[0].strip().strip("'\"").strip()


def fetch_stock_data(ticker: str) -> dict:
    """Return the last 5 rows of one month of price history for *ticker*.

    Returns:
        ``DataFrame.to_dict()`` of the last 5 rows, or ``{"error": message}``
        when no data is found or the lookup raises.
    """
    symbol = _clean_ticker(ticker)
    try:
        hist = yf.Ticker(symbol).history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {symbol}"}
        return hist.tail(5).to_dict()
    except Exception as e:  # tool boundary: report the failure to the agent
        return {"error": str(e)}


def fetch_stock_news(ticker: str, NEWSAPI_KEY: Optional[str]) -> list:
    """Return up to 5 NewsAPI articles matching *ticker*.

    Args:
        ticker: Search term (normally a ticker symbol).
        NEWSAPI_KEY: NewsAPI key; sent as the ``X-Api-Key`` header so it does
            not appear in URLs or in exception messages.

    Returns:
        A list of ``{"title", "description"}`` dicts, or a one-element list
        ``[{"error": message}]`` on any failure.
    """
    if not NEWSAPI_KEY:
        return [{"error": "NEWSAPI_KEY is not configured."}]
    try:
        response = requests.get(
            NEWSAPI_URL,
            params={"q": _clean_ticker(ticker), "pageSize": 5},  # URL-encoded by requests
            headers={"X-Api-Key": NEWSAPI_KEY},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as e:
        return [{"error": f"Unable to fetch news: {e}"}]
    if response.status_code != 200:
        return [{"error": "Unable to fetch news."}]
    try:
        articles = response.json().get("articles", [])
    except ValueError:  # body was not valid JSON
        return [{"error": "Unable to fetch news."}]
    return [
        {"title": article.get("title"), "description": article.get("description")}
        for article in articles[:5]
    ]


def calculate_moving_average(ticker: str, window: int = 5):
    """Return the last 5 closes of *ticker* with their ``window``-day moving average.

    Returns:
        A DataFrame with columns ``Close`` and ``"{window}-day MA"`` (last 5
        rows of one month of history), or ``{"error": message}`` when no data
        is found or the computation raises.
    """
    symbol = _clean_ticker(ticker)
    try:
        hist = yf.Ticker(symbol).history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {symbol}"}
        ma_column = f"{window}-day MA"
        hist[ma_column] = hist["Close"].rolling(window=window).mean()
        return hist[["Close", ma_column]].tail(5)
    except Exception as e:  # tool boundary: report the failure to the agent
        return {"error": str(e)}


# Tools
stock_data_tool = Tool(
    name="Stock Data Fetcher",
    func=fetch_stock_data,
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).",
)
stock_news_tool = Tool(
    name="Stock News Fetcher",
    func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
    description="Fetch recent news articles about a stock ticker.",
)
moving_average_tool = Tool(
    name="Moving Average Calculator",
    func=calculate_moving_average,
    description="Calculate the moving average of a stock over a 5-day window.",
)

# Initialize agent
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
llm = HuggingFacePipeline(pipeline=pipe)
agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    agent_kwargs={
        "prefix": REACT_PREFIX,
        "format_instructions": REACT_FORMAT_INSTRUCTIONS,
        "suffix": REACT_SUFFIX,
    },
    verbose=True,
    handle_parsing_errors=True,
)

# Streamlit interface
st.title("Trading Helper Agent")
query = st.text_input("Enter your query:")

if st.button("Submit"):
    if query.strip():
        with st.spinner("Processing..."):
            try:
                response = agent.invoke({"input": query})["output"]
                st.success("Response:")
                st.write(response)
            except Exception as e:  # UI boundary: show any agent failure to the user
                st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a query.")