Dhahlan2000 commited on
Commit
b701198
·
1 Parent(s): dc4a988

fixed by chatgpt

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -7,27 +7,27 @@ from langchain.tools import Tool
7
  from langchain_huggingface import HuggingFacePipeline
8
  import os
9
  from dotenv import load_dotenv
10
- from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
11
  import torch
12
 
13
-
14
  load_dotenv()
15
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
16
  access_token = os.getenv("API_KEY")
17
 
18
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token = access_token)
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  "google/gemma-2b-it",
21
  torch_dtype=torch.bfloat16,
22
- token = access_token
23
  )
24
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=2048)
25
 
 
26
  def validate_ticker(ticker):
27
- # Ensure ticker is uppercase and length is reasonable (1-5 characters)
28
  return ticker.strip().upper()
29
 
30
-
31
  def fetch_stock_data(ticker):
32
  try:
33
  ticker = ticker.strip().upper()
@@ -39,7 +39,6 @@ def fetch_stock_data(ticker):
39
  except Exception as e:
40
  return {"error": str(e)}
41
 
42
-
43
  def fetch_stock_news(ticker, NEWSAPI_KEY):
44
  api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
45
  response = requests.get(api_url)
@@ -55,15 +54,16 @@ def calculate_moving_average(ticker, window=5):
55
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
56
  return hist[["Close", f"{window}-day MA"]].tail(5)
57
 
 
58
  llm = HuggingFacePipeline(pipeline=pipe)
59
 
 
60
  stock_data_tool = Tool(
61
  name="Stock Data Fetcher",
62
  func=fetch_stock_data,
63
  description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
64
  )
65
 
66
-
67
  stock_news_tool = Tool(
68
  name="Stock News Fetcher",
69
  func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
@@ -78,6 +78,7 @@ moving_average_tool = Tool(
78
 
79
  tools = [stock_data_tool, stock_news_tool, moving_average_tool]
80
 
 
81
  agent = initialize_agent(
82
  tools=tools,
83
  llm=llm,
@@ -86,12 +87,7 @@ agent = initialize_agent(
86
  handle_parsing_errors=True
87
  )
88
 
89
-
90
- print(fetch_stock_data("AAPL"))
91
- print(fetch_stock_news("AAPL", NEWSAPI_KEY))
92
- print(calculate_moving_average("AAPL"))
93
-
94
-
95
  st.title("Trading Helper Agent")
96
 
97
  query = st.text_input("Enter your query:")
@@ -101,8 +97,8 @@ if st.button("Submit"):
101
  st.write("Debug: User Query ->", query)
102
  with st.spinner("Processing..."):
103
  try:
104
- response = agent.invoke(query)
105
- print(f'Response: {response}')
106
  st.success("Response:")
107
  st.write(response)
108
  except Exception as e:
 
7
  from langchain_huggingface import HuggingFacePipeline
8
  import os
9
  from dotenv import load_dotenv
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
  import torch
12
 
13
+ # Load environment variables from .env
14
  load_dotenv()
15
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
16
  access_token = os.getenv("API_KEY")
17
 
18
+ # Initialize the model and tokenizer for the HuggingFace pipeline
19
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
20
  model = AutoModelForCausalLM.from_pretrained(
21
  "google/gemma-2b-it",
22
  torch_dtype=torch.bfloat16,
23
+ token=access_token
24
  )
25
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=2048)
26
 
27
+ # Define functions for fetching stock data, news, and moving averages
28
  def validate_ticker(ticker):
 
29
  return ticker.strip().upper()
30
 
 
31
  def fetch_stock_data(ticker):
32
  try:
33
  ticker = ticker.strip().upper()
 
39
  except Exception as e:
40
  return {"error": str(e)}
41
 
 
42
  def fetch_stock_news(ticker, NEWSAPI_KEY):
43
  api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
44
  response = requests.get(api_url)
 
54
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
55
  return hist[["Close", f"{window}-day MA"]].tail(5)
56
 
57
+ # Initialize HuggingFace pipeline
58
  llm = HuggingFacePipeline(pipeline=pipe)
59
 
60
+ # Define LangChain tools
61
  stock_data_tool = Tool(
62
  name="Stock Data Fetcher",
63
  func=fetch_stock_data,
64
  description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
65
  )
66
 
 
67
  stock_news_tool = Tool(
68
  name="Stock News Fetcher",
69
  func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
 
78
 
79
  tools = [stock_data_tool, stock_news_tool, moving_average_tool]
80
 
81
+ # Initialize the LangChain agent
82
  agent = initialize_agent(
83
  tools=tools,
84
  llm=llm,
 
87
  handle_parsing_errors=True
88
  )
89
 
90
+ # Streamlit app
 
 
 
 
 
91
  st.title("Trading Helper Agent")
92
 
93
  query = st.text_input("Enter your query:")
 
97
  st.write("Debug: User Query ->", query)
98
  with st.spinner("Processing..."):
99
  try:
100
+ # Run the agent and get the response
101
+ response = agent.run(query) # Correct method is `run()`
102
  st.success("Response:")
103
  st.write(response)
104
  except Exception as e: