Dhahlan2000 commited on
Commit
ee1c031
·
1 Parent(s): 585716d

added gemma

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -7,9 +7,21 @@ from langchain.tools import Tool
7
  from langchain_huggingface import HuggingFacePipeline
8
  import os
9
  from dotenv import load_dotenv
 
 
10
 
11
  load_dotenv()
12
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
 
 
 
 
 
 
 
 
 
 
13
 
14
  def fetch_stock_data(ticker):
15
  print("fetching stock data")
@@ -34,15 +46,7 @@ def calculate_moving_average(ticker, window=5):
34
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
35
  return hist[["Close", f"{window}-day MA"]].tail(5)
36
 
37
- llm = HuggingFacePipeline.from_model_id(
38
- model_id="microsoft/Phi-3-mini-4k-instruct",
39
- task="text-generation",
40
- pipeline_kwargs={
41
- "max_new_tokens": 100,
42
- "top_k": 50,
43
- "temperature": 0.1,
44
- },
45
- )
46
 
47
  stock_data_tool = Tool(
48
  name="Stock Data Fetcher",
@@ -80,6 +84,7 @@ if st.button("Submit"):
80
  with st.spinner("Processing..."):
81
  try:
82
  response = agent.run(query)
 
83
  st.success("Response:")
84
  st.write(response)
85
  except Exception as e:
 
7
  from langchain_huggingface import HuggingFacePipeline
8
  import os
9
  from dotenv import load_dotenv
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
11
+ import torch
12
 
13
  load_dotenv()
14
  NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
15
+ access_token = os.getenv("API_KEY")
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token = access_token)
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ "google/gemma-2b-it",
20
+ torch_dtype=torch.bfloat16,
21
+ token = access_token
22
+ )
23
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=100, top_k=50, temperature=0.1)
24
+
25
 
26
  def fetch_stock_data(ticker):
27
  print("fetching stock data")
 
46
  hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
47
  return hist[["Close", f"{window}-day MA"]].tail(5)
48
 
49
+ llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
 
 
 
 
50
 
51
  stock_data_tool = Tool(
52
  name="Stock Data Fetcher",
 
84
  with st.spinner("Processing..."):
85
  try:
86
  response = agent.run(query)
87
+ print(f"Response: {response}")
88
  st.success("Response:")
89
  st.write(response)
90
  except Exception as e: