Dhahlan2000 commited on
Commit
20d39a9
·
1 Parent(s): 2def788

fixed error

Browse files
Files changed (1) hide show
  1. app.py +42 -6
app.py CHANGED
@@ -9,6 +9,33 @@ import os
9
  from dotenv import load_dotenv
10
  from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
11
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  load_dotenv()
14
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
@@ -28,10 +55,16 @@ def validate_ticker(ticker):
28
 
29
 
30
  def fetch_stock_data(ticker):
31
- ticker = validate_ticker(ticker) # Validate and clean input
32
- stock = yf.Ticker(ticker)
33
- hist = stock.history(period="1mo")
34
- return hist.tail(5)
 
 
 
 
 
 
35
 
36
  def fetch_stock_news(ticker, NEWSAPI_KEY):
37
  api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
@@ -53,9 +86,10 @@ llm = HuggingFacePipeline(pipeline=pipe)
53
  stock_data_tool = Tool(
54
  name="Stock Data Fetcher",
55
  func=fetch_stock_data,
56
- description="Fetch recent stock data for a valid stock ticker (e.g., AAPL for Apple)."
57
  )
58
 
 
59
  stock_news_tool = Tool(
60
  name="Stock News Fetcher",
61
  func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
@@ -74,10 +108,12 @@ agent = initialize_agent(
74
  tools=tools,
75
  llm=llm,
76
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
 
77
  verbose=True,
78
- handle_parsing_errors=True # Enables automatic handling of parsing issues
79
  )
80
 
 
81
  print(fetch_stock_data("AAPL"))
82
  print(fetch_stock_news("AAPL", NEWSAPI_KEY))
83
  print(calculate_moving_average("AAPL"))
 
9
  from dotenv import load_dotenv
10
  from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
11
  import torch
12
+ from langchain.prompts import PromptTemplate
13
+
14
+ prompt_template = PromptTemplate(
15
+ input_variables=["input"],
16
+ template="""Answer the following question as best you can. You have access to the following tools:
17
+
18
+ Stock Data Fetcher(ticker) - Fetch recent stock data for a valid stock ticker (e.g., AAPL for Apple).
19
+ Stock News Fetcher(ticker) - Fetch recent news articles about a stock ticker.
20
+ Moving Average Calculator(ticker, window=5) - Calculate the moving average of a stock over a 5-day window.
21
+
22
+ Use the following format:
23
+
24
+ Question: the input question you must answer
25
+ Thought: you should always think about what to do
26
+ Action: the action to take, should be one of [Stock Data Fetcher, Stock News Fetcher, Moving Average Calculator]
27
+ Action Input: the input to the action
28
+ Observation: the result of the action
29
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
30
+ Thought: I now know the final answer
31
+ Final Answer: the final answer to the original input question
32
+
33
+ Begin!
34
+
35
+ Question: {input}
36
+ """
37
+ )
38
+
39
 
40
  load_dotenv()
41
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
 
55
 
56
 
57
  def fetch_stock_data(ticker):
58
+ try:
59
+ ticker = ticker.strip().upper()
60
+ stock = yf.Ticker(ticker)
61
+ hist = stock.history(period="1mo")
62
+ if hist.empty:
63
+ return {"error": f"No data found for ticker {ticker}"}
64
+ return hist.tail(5).to_dict()
65
+ except Exception as e:
66
+ return {"error": str(e)}
67
+
68
 
69
  def fetch_stock_news(ticker, NEWSAPI_KEY):
70
  api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
 
86
  stock_data_tool = Tool(
87
  name="Stock Data Fetcher",
88
  func=fetch_stock_data,
89
+ description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
90
  )
91
 
92
+
93
  stock_news_tool = Tool(
94
  name="Stock News Fetcher",
95
  func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
 
108
  tools=tools,
109
  llm=llm,
110
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
111
+ prompt=prompt_template,
112
  verbose=True,
113
+ handle_parsing_errors=True
114
  )
115
 
116
+
117
  print(fetch_stock_data("AAPL"))
118
  print(fetch_stock_news("AAPL", NEWSAPI_KEY))
119
  print(calculate_moving_average("AAPL"))