Spaces:
Running
Running
Commit
·
b701198
1
Parent(s):
dc4a988
fixed by chatgpt
Browse files
app.py
CHANGED
@@ -7,27 +7,27 @@ from langchain.tools import Tool
|
|
7 |
from langchain_huggingface import HuggingFacePipeline
|
8 |
import os
|
9 |
from dotenv import load_dotenv
|
10 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
|
11 |
import torch
|
12 |
|
13 |
-
|
14 |
load_dotenv()
|
15 |
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
|
16 |
access_token = os.getenv("API_KEY")
|
17 |
|
18 |
-
tokenizer
|
|
|
19 |
model = AutoModelForCausalLM.from_pretrained(
|
20 |
"google/gemma-2b-it",
|
21 |
torch_dtype=torch.bfloat16,
|
22 |
-
token
|
23 |
)
|
24 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=2048)
|
25 |
|
|
|
26 |
def validate_ticker(ticker):
|
27 |
-
# Ensure ticker is uppercase and length is reasonable (1-5 characters)
|
28 |
return ticker.strip().upper()
|
29 |
|
30 |
-
|
31 |
def fetch_stock_data(ticker):
|
32 |
try:
|
33 |
ticker = ticker.strip().upper()
|
@@ -39,7 +39,6 @@ def fetch_stock_data(ticker):
|
|
39 |
except Exception as e:
|
40 |
return {"error": str(e)}
|
41 |
|
42 |
-
|
43 |
def fetch_stock_news(ticker, NEWSAPI_KEY):
|
44 |
api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
|
45 |
response = requests.get(api_url)
|
@@ -55,15 +54,16 @@ def calculate_moving_average(ticker, window=5):
|
|
55 |
hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
|
56 |
return hist[["Close", f"{window}-day MA"]].tail(5)
|
57 |
|
|
|
58 |
llm = HuggingFacePipeline(pipeline=pipe)
|
59 |
|
|
|
60 |
stock_data_tool = Tool(
|
61 |
name="Stock Data Fetcher",
|
62 |
func=fetch_stock_data,
|
63 |
description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
|
64 |
)
|
65 |
|
66 |
-
|
67 |
stock_news_tool = Tool(
|
68 |
name="Stock News Fetcher",
|
69 |
func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
|
@@ -78,6 +78,7 @@ moving_average_tool = Tool(
|
|
78 |
|
79 |
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
|
80 |
|
|
|
81 |
agent = initialize_agent(
|
82 |
tools=tools,
|
83 |
llm=llm,
|
@@ -86,12 +87,7 @@ agent = initialize_agent(
|
|
86 |
handle_parsing_errors=True
|
87 |
)
|
88 |
|
89 |
-
|
90 |
-
print(fetch_stock_data("AAPL"))
|
91 |
-
print(fetch_stock_news("AAPL", NEWSAPI_KEY))
|
92 |
-
print(calculate_moving_average("AAPL"))
|
93 |
-
|
94 |
-
|
95 |
st.title("Trading Helper Agent")
|
96 |
|
97 |
query = st.text_input("Enter your query:")
|
@@ -101,8 +97,8 @@ if st.button("Submit"):
|
|
101 |
st.write("Debug: User Query ->", query)
|
102 |
with st.spinner("Processing..."):
|
103 |
try:
|
104 |
-
|
105 |
-
|
106 |
st.success("Response:")
|
107 |
st.write(response)
|
108 |
except Exception as e:
|
|
|
7 |
from langchain_huggingface import HuggingFacePipeline
|
8 |
import os
|
9 |
from dotenv import load_dotenv
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
11 |
import torch
|
12 |
|
13 |
+
# Load environment variables from .env
|
14 |
load_dotenv()
|
15 |
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
|
16 |
access_token = os.getenv("API_KEY")
|
17 |
|
18 |
+
# Initialize the model and tokenizer for the HuggingFace pipeline
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
|
20 |
model = AutoModelForCausalLM.from_pretrained(
|
21 |
"google/gemma-2b-it",
|
22 |
torch_dtype=torch.bfloat16,
|
23 |
+
token=access_token
|
24 |
)
|
25 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=2048)
|
26 |
|
27 |
+
# Define functions for fetching stock data, news, and moving averages
|
28 |
def validate_ticker(ticker):
|
|
|
29 |
return ticker.strip().upper()
|
30 |
|
|
|
31 |
def fetch_stock_data(ticker):
|
32 |
try:
|
33 |
ticker = ticker.strip().upper()
|
|
|
39 |
except Exception as e:
|
40 |
return {"error": str(e)}
|
41 |
|
|
|
42 |
def fetch_stock_news(ticker, NEWSAPI_KEY):
|
43 |
api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
|
44 |
response = requests.get(api_url)
|
|
|
54 |
hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
|
55 |
return hist[["Close", f"{window}-day MA"]].tail(5)
|
56 |
|
57 |
+
# Initialize HuggingFace pipeline
|
58 |
llm = HuggingFacePipeline(pipeline=pipe)
|
59 |
|
60 |
+
# Define LangChain tools
|
61 |
stock_data_tool = Tool(
|
62 |
name="Stock Data Fetcher",
|
63 |
func=fetch_stock_data,
|
64 |
description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
|
65 |
)
|
66 |
|
|
|
67 |
stock_news_tool = Tool(
|
68 |
name="Stock News Fetcher",
|
69 |
func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
|
|
|
78 |
|
79 |
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
|
80 |
|
81 |
+
# Initialize the LangChain agent
|
82 |
agent = initialize_agent(
|
83 |
tools=tools,
|
84 |
llm=llm,
|
|
|
87 |
handle_parsing_errors=True
|
88 |
)
|
89 |
|
90 |
+
# Streamlit app
|
|
|
|
|
|
|
|
|
|
|
91 |
st.title("Trading Helper Agent")
|
92 |
|
93 |
query = st.text_input("Enter your query:")
|
|
|
97 |
st.write("Debug: User Query ->", query)
|
98 |
with st.spinner("Processing..."):
|
99 |
try:
|
100 |
+
# Run the agent and get the response
|
101 |
+
response = agent.run(query) # Correct method is `run()`
|
102 |
st.success("Response:")
|
103 |
st.write(response)
|
104 |
except Exception as e:
|