Spaces:
Running
Running
Commit
·
a7916e8
1
Parent(s):
12e25ff
Add Alpaca trading strategy and backtesting functionality to app.py
Browse files- Introduced MLTrader class for executing trading strategies based on sentiment analysis.
- Integrated Alpaca API for trading operations, including position sizing and order management.
- Added execute_alpaca_trading function to run and backtest the trading strategy using YahooDataBacktesting.
- Updated tools list to include Alpaca Trading Executor for enhanced trading capabilities.
app.py
CHANGED
@@ -15,6 +15,13 @@ from statsmodels.tsa.arima.model import ARIMA
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Load environment variables from .env
|
20 |
load_dotenv()
|
@@ -26,6 +33,17 @@ access_token = os.getenv("API_KEY")
|
|
26 |
if not NEWSAPI_KEY or not access_token:
|
27 |
raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
# Initialize the model and tokenizer for the HuggingFace pipeline
|
30 |
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
|
31 |
model = AutoModelForCausalLM.from_pretrained(
|
@@ -78,7 +96,6 @@ def predict_stock_price(ticker, days=5):
|
|
78 |
if hist.empty:
|
79 |
return {"error": f"No data found for ticker {ticker}"}
|
80 |
|
81 |
-
|
82 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
83 |
model_fit = model.fit()
|
84 |
forecast = model_fit.forecast(steps=days)
|
@@ -95,6 +112,80 @@ def compare_stocks(ticker1, ticker2):
|
|
95 |
}
|
96 |
return comparison
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
# Define LangChain tools
|
100 |
stock_data_tool = Tool(
|
@@ -133,13 +224,20 @@ stock_comparator_tool = Tool(
|
|
133 |
description="Compare the recent performance of two stocks given their tickers, e.g., 'AAPL,MSFT'."
|
134 |
)
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
tools = [
|
137 |
stock_data_tool,
|
138 |
stock_news_tool,
|
139 |
moving_average_tool,
|
140 |
sentiment_tool,
|
141 |
stock_prediction_tool,
|
142 |
-
stock_comparator_tool
|
|
|
143 |
]
|
144 |
|
145 |
# Set up a prompt template with history
|
@@ -170,20 +268,20 @@ New question: {input}
|
|
170 |
class CustomPromptTemplate(BaseChatPromptTemplate):
|
171 |
template: str
|
172 |
tools: List[Tool]
|
173 |
-
|
174 |
def format_messages(self, **kwargs) -> str:
|
175 |
intermediate_steps = kwargs.pop("intermediate_steps")
|
176 |
thoughts = ""
|
177 |
for action, observation in intermediate_steps:
|
178 |
thoughts += action.log
|
179 |
thoughts += f"\nObservation: {observation}\nThought: "
|
180 |
-
|
181 |
kwargs["agent_scratchpad"] = thoughts
|
182 |
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
|
183 |
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
184 |
formatted = self.template.format(**kwargs)
|
185 |
return [HumanMessage(content=formatted)]
|
186 |
-
|
187 |
prompt_with_history = CustomPromptTemplate(
|
188 |
template=template_with_history,
|
189 |
tools=tools,
|
@@ -205,7 +303,7 @@ class CustomOutputParser(AgentOutputParser):
|
|
205 |
action = match.group(1).strip()
|
206 |
action_input = match.group(2)
|
207 |
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
|
208 |
-
|
209 |
output_parser = CustomOutputParser()
|
210 |
|
211 |
# Initialize HuggingFace pipeline
|
@@ -240,8 +338,4 @@ if st.button("Submit"):
|
|
240 |
st.write(response)
|
241 |
except Exception as e:
|
242 |
st.error(f"An error occurred: {e}")
|
243 |
-
# Log the full LLM
|
244 |
-
if hasattr(e, "output"):
|
245 |
-
st.write("Raw Output:", e.output)
|
246 |
-
else:
|
247 |
-
st.warning("Please enter a query.")
|
|
|
15 |
import torch
|
16 |
import re
|
17 |
from typing import List, Union
|
18 |
+
from datetime import datetime
|
19 |
+
from lumibot.brokers import Alpaca
|
20 |
+
from lumibot.backtesting import YahooDataBacktesting
|
21 |
+
from lumibot.strategies.strategy import Strategy
|
22 |
+
from alpaca_trade_api import REST
|
23 |
+
from timedelta import Timedelta
|
24 |
+
from finbert_utils import estimate_sentiment
|
25 |
|
26 |
# Load environment variables from .env
|
27 |
load_dotenv()
|
|
|
33 |
if not NEWSAPI_KEY or not access_token:
|
34 |
raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")
|
35 |
|
36 |
+
# Alpaca credentials
|
37 |
+
API_KEY = "PKWJW14IWRJMLJ4CSZ6V"
|
38 |
+
API_SECRET = "zJOGwUvhYBfYJQRz6jc309PLNfTQ4VcxuygFxxfh"
|
39 |
+
BASE_URL = "https://paper-api.alpaca.markets/v2"
|
40 |
+
|
41 |
+
ALPACA_CREDS = {
|
42 |
+
"API_KEY": API_KEY,
|
43 |
+
"API_SECRET": API_SECRET,
|
44 |
+
"PAPER": True
|
45 |
+
}
|
46 |
+
|
47 |
# Initialize the model and tokenizer for the HuggingFace pipeline
|
48 |
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
|
49 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
96 |
if hist.empty:
|
97 |
return {"error": f"No data found for ticker {ticker}"}
|
98 |
|
|
|
99 |
model = ARIMA(hist["Close"], order=(5, 1, 0))
|
100 |
model_fit = model.fit()
|
101 |
forecast = model_fit.forecast(steps=days)
|
|
|
112 |
}
|
113 |
return comparison
|
114 |
|
115 |
+
def execute_alpaca_trading():
|
116 |
+
class MLTrader(Strategy):
|
117 |
+
def initialize(self, symbol: str = "SPY", cash_at_risk: float = .5):
|
118 |
+
self.symbol = symbol
|
119 |
+
self.sleeptime = "24H"
|
120 |
+
self.last_trade = None
|
121 |
+
self.cash_at_risk = cash_at_risk
|
122 |
+
self.api = REST(base_url=BASE_URL, key_id=API_KEY, secret_key=API_SECRET)
|
123 |
+
|
124 |
+
def position_sizing(self):
|
125 |
+
cash = self.get_cash()
|
126 |
+
last_price = self.get_last_price(self.symbol)
|
127 |
+
quantity = round(cash * self.cash_at_risk / last_price, 0)
|
128 |
+
return cash, last_price, quantity
|
129 |
+
|
130 |
+
def get_dates(self):
|
131 |
+
today = self.get_datetime()
|
132 |
+
three_days_prior = today - Timedelta(days=3)
|
133 |
+
return today.strftime('%Y-%m-%d'), three_days_prior.strftime('%Y-%m-%d')
|
134 |
+
|
135 |
+
def get_sentiment(self):
|
136 |
+
today, three_days_prior = self.get_dates()
|
137 |
+
news = self.api.get_news(symbol=self.symbol,
|
138 |
+
start=three_days_prior,
|
139 |
+
end=today)
|
140 |
+
news = [ev.__dict__["_raw"]["headline"] for ev in news]
|
141 |
+
probability, sentiment = estimate_sentiment(news)
|
142 |
+
return probability, sentiment
|
143 |
+
|
144 |
+
def on_trading_iteration(self):
|
145 |
+
cash, last_price, quantity = self.position_sizing()
|
146 |
+
probability, sentiment = self.get_sentiment()
|
147 |
+
|
148 |
+
if cash > last_price:
|
149 |
+
if sentiment == "positive" and probability > .999:
|
150 |
+
if self.last_trade == "sell":
|
151 |
+
self.sell_all()
|
152 |
+
order = self.create_order(
|
153 |
+
self.symbol,
|
154 |
+
quantity,
|
155 |
+
"buy",
|
156 |
+
type="bracket",
|
157 |
+
take_profit_price=last_price * 1.20,
|
158 |
+
stop_loss_price=last_price * .95
|
159 |
+
)
|
160 |
+
self.submit_order(order)
|
161 |
+
self.last_trade = "buy"
|
162 |
+
elif sentiment == "negative" and probability > .999:
|
163 |
+
if self.last_trade == "buy":
|
164 |
+
self.sell_all()
|
165 |
+
order = self.create_order(
|
166 |
+
self.symbol,
|
167 |
+
quantity,
|
168 |
+
"sell",
|
169 |
+
type="bracket",
|
170 |
+
take_profit_price=last_price * .8,
|
171 |
+
stop_loss_price=last_price * 1.05
|
172 |
+
)
|
173 |
+
self.submit_order(order)
|
174 |
+
self.last_trade = "sell"
|
175 |
+
|
176 |
+
start_date = datetime(2021, 1, 1)
|
177 |
+
end_date = datetime(2024, 10, 1)
|
178 |
+
broker = Alpaca(ALPACA_CREDS)
|
179 |
+
strategy = MLTrader(name='mlstrat', broker=broker,
|
180 |
+
parameters={"symbol": "SPY",
|
181 |
+
"cash_at_risk": .5})
|
182 |
+
strategy.backtest(
|
183 |
+
YahooDataBacktesting,
|
184 |
+
start_date,
|
185 |
+
end_date,
|
186 |
+
parameters={"symbol": "SPY", "cash_at_risk": .5}
|
187 |
+
)
|
188 |
+
return "Alpaca trading strategy executed and backtested."
|
189 |
|
190 |
# Define LangChain tools
|
191 |
stock_data_tool = Tool(
|
|
|
224 |
description="Compare the recent performance of two stocks given their tickers, e.g., 'AAPL,MSFT'."
|
225 |
)
|
226 |
|
227 |
+
alpaca_trading_tool = Tool(
|
228 |
+
name="Alpaca Trading Executor",
|
229 |
+
func=execute_alpaca_trading,
|
230 |
+
description="Run a trading strategy using Alpaca API and backtest results."
|
231 |
+
)
|
232 |
+
|
233 |
tools = [
|
234 |
stock_data_tool,
|
235 |
stock_news_tool,
|
236 |
moving_average_tool,
|
237 |
sentiment_tool,
|
238 |
stock_prediction_tool,
|
239 |
+
stock_comparator_tool,
|
240 |
+
alpaca_trading_tool
|
241 |
]
|
242 |
|
243 |
# Set up a prompt template with history
|
|
|
268 |
class CustomPromptTemplate(BaseChatPromptTemplate):
|
269 |
template: str
|
270 |
tools: List[Tool]
|
271 |
+
|
272 |
def format_messages(self, **kwargs) -> str:
|
273 |
intermediate_steps = kwargs.pop("intermediate_steps")
|
274 |
thoughts = ""
|
275 |
for action, observation in intermediate_steps:
|
276 |
thoughts += action.log
|
277 |
thoughts += f"\nObservation: {observation}\nThought: "
|
278 |
+
|
279 |
kwargs["agent_scratchpad"] = thoughts
|
280 |
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
|
281 |
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
282 |
formatted = self.template.format(**kwargs)
|
283 |
return [HumanMessage(content=formatted)]
|
284 |
+
|
285 |
prompt_with_history = CustomPromptTemplate(
|
286 |
template=template_with_history,
|
287 |
tools=tools,
|
|
|
303 |
action = match.group(1).strip()
|
304 |
action_input = match.group(2)
|
305 |
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
|
306 |
+
|
307 |
output_parser = CustomOutputParser()
|
308 |
|
309 |
# Initialize HuggingFace pipeline
|
|
|
338 |
st.write(response)
|
339 |
except Exception as e:
|
340 |
st.error(f"An error occurred: {e}")
|
341 |
+
# Log the full LLM
|
|
|
|
|
|
|
|