Dhahlan2000 commited on
Commit
a7f8412
·
1 Parent(s): 96b4247

fixed and improved with github code help

Browse files
Files changed (1) hide show
  1. app.py +94 -20
app.py CHANGED
@@ -1,27 +1,31 @@
1
  import streamlit as st
2
  import yfinance as yf
3
  import requests
4
- import pandas as pd
5
- from langchain.agents import initialize_agent, AgentType
6
- from langchain.tools import Tool
7
- from langchain_huggingface import HuggingFacePipeline
8
  import os
9
  from dotenv import load_dotenv
 
 
 
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
11
  import torch
 
 
12
 
13
  # Load environment variables from .env
14
  load_dotenv()
 
15
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
16
  access_token = os.getenv("API_KEY")
17
 
 
 
 
 
18
  # Initialize the model and tokenizer for the HuggingFace pipeline
19
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- "google/gemma-2b-it",
22
- torch_dtype=torch.bfloat16,
23
- token=access_token
24
- )
25
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
26
 
27
  # Define functions for fetching stock data, news, and moving averages
@@ -54,9 +58,6 @@ def calculate_moving_average(ticker, window=5):
54
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
55
  return hist[["Close", f"{window}-day MA"]].tail(5)
56
 
57
- # Initialize HuggingFace pipeline
58
- llm = HuggingFacePipeline(pipeline=pipe)
59
-
60
  # Define LangChain tools
61
  stock_data_tool = Tool(
62
  name="Stock Data Fetcher",
@@ -78,15 +79,88 @@ moving_average_tool = Tool(
78
 
79
  tools = [stock_data_tool, stock_news_tool, moving_average_tool]
80
 
81
- # Initialize the LangChain agent
82
- agent = initialize_agent(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  tools=tools,
84
- llm=llm,
85
- agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
86
- verbose=True,
87
- handle_parsing_errors=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  )
89
 
 
 
 
90
  # Streamlit app
91
  st.title("Trading Helper Agent")
92
 
@@ -98,7 +172,7 @@ if st.button("Submit"):
98
  with st.spinner("Processing..."):
99
  try:
100
  # Run the agent and get the response
101
- response = agent.run(query) # Correct method is `run()`
102
  st.success("Response:")
103
  st.write(response)
104
  except Exception as e:
 
1
  import streamlit as st
2
  import yfinance as yf
3
  import requests
 
 
 
 
4
  import os
5
  from dotenv import load_dotenv
6
+ from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
7
+ from langchain.prompts import BaseChatPromptTemplate
8
+ from langchain.tools import Tool
9
+ from langchain_huggingface import HuggingFacePipeline
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain.memory import ConversationBufferWindowMemory
12
  import torch
13
+ import re
14
+ from typing import List, Union
15
 
16
  # Load environment variables from .env
17
  load_dotenv()
18
+
19
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
20
  access_token = os.getenv("API_KEY")
21
 
22
+ # Check if the access token and API key are present
23
+ if not NEWSAPI_KEY or not access_token:
24
+ raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")
25
+
26
  # Initialize the model and tokenizer for the HuggingFace pipeline
27
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
28
+ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
 
 
 
 
29
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
30
 
31
  # Define functions for fetching stock data, news, and moving averages
 
58
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
59
  return hist[["Close", f"{window}-day MA"]].tail(5)
60
 
 
 
 
61
  # Define LangChain tools
62
  stock_data_tool = Tool(
63
  name="Stock Data Fetcher",
 
79
 
80
  tools = [stock_data_tool, stock_news_tool, moving_average_tool]
81
 
82
+ # Set up a prompt template with history
83
+ template_with_history = """You are SearchGPT, a professional search engine who provides informative answers to users. Answer the following questions as best you can. You have access to the following tools:
84
+
85
+ {tools}
86
+
87
+ Use the following format:
88
+
89
+ Question: the input question you must answer
90
+ Thought: you should always think about what to do
91
+ Action: the action to take, should be one of [{tool_names}]
92
+ Action Input: the input to the action
93
+ Observation: the result of the action
94
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
95
+ Thought: I now know the final answer
96
+ Final Answer: the final answer to the original input question
97
+
98
+ Begin! Remember to give detailed, informative answers
99
+
100
+ Previous conversation history:
101
+ {history}
102
+
103
+ New question: {input}
104
+ {agent_scratchpad}"""
105
+
106
+ # Set up the prompt template
107
+ class CustomPromptTemplate(BaseChatPromptTemplate):
108
+ template: str
109
+ tools: List[Tool]
110
+
111
+ def format_messages(self, **kwargs) -> str:
112
+ intermediate_steps = kwargs.pop("intermediate_steps")
113
+ thoughts = ""
114
+ for action, observation in intermediate_steps:
115
+ thoughts += action.log
116
+ thoughts += f"\nObservation: {observation}\nThought: "
117
+
118
+ kwargs["agent_scratchpad"] = thoughts
119
+ kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
120
+ kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
121
+ formatted = self.template.format(**kwargs)
122
+ return [HumanMessage(content=formatted)]
123
+
124
+ prompt_with_history = CustomPromptTemplate(
125
+ template=template_with_history,
126
  tools=tools,
127
+ input_variables=["input", "intermediate_steps", "history"]
128
+ )
129
+
130
+ # Custom output parser
131
+ class CustomOutputParser(AgentOutputParser):
132
+ def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
133
+ if "Final Answer:" in llm_output:
134
+ return AgentFinish(
135
+ return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
136
+ log=llm_output,
137
+ )
138
+ regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
139
+ match = re.search(regex, llm_output, re.DOTALL)
140
+ if not match:
141
+ raise ValueError(f"Could not parse LLM output: `{llm_output}`")
142
+ action = match.group(1).strip()
143
+ action_input = match.group(2)
144
+ return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
145
+
146
+ output_parser = CustomOutputParser()
147
+
148
+ # Initialize HuggingFace pipeline
149
+ llm = HuggingFacePipeline(pipeline=pipe)
150
+
151
+ # LLM chain
152
+ llm_chain = LLMChain(llm=llm, prompt=prompt_with_history)
153
+ tool_names = [tool.name for tool in tools]
154
+ agent = LLMSingleActionAgent(
155
+ llm_chain=llm_chain,
156
+ output_parser=output_parser,
157
+ stop=["\nObservation:"],
158
+ allowed_tools=tool_names
159
  )
160
 
161
+ memory = ConversationBufferWindowMemory(k=2)
162
+ agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)
163
+
164
  # Streamlit app
165
  st.title("Trading Helper Agent")
166
 
 
172
  with st.spinner("Processing..."):
173
  try:
174
  # Run the agent and get the response
175
+ response = agent_executor.run(query) # Correct method is `run()`
176
  st.success("Response:")
177
  st.write(response)
178
  except Exception as e: