Hassankhwileh committed on
Commit 892edcf · verified · 1 Parent(s): 5f69491

Update ai_providers.py

Files changed (1)
  1. ai_providers.py +77 -36
ai_providers.py CHANGED
@@ -1,74 +1,115 @@
  # ai_providers.py

- from abc import ABC, abstractmethod
- from typing import Dict, Any, Optional
  import os
  from dotenv import load_dotenv
  import time

- class AIProvider(ABC):
-     @abstractmethod
-     def get_completion(self, prompt: str, **kwargs) -> str:
-         pass

-     @abstractmethod
-     def get_config(self) -> Dict[str, Any]:
-         pass

- class GroqProvider(AIProvider):
-     def __init__(self, api_key: str):
-         self.api_key = api_key
-         self.model = "deepseek-ai/deepseek-r1-distill-llama-70b"  # Fixed model for Deepseek R1
-
-     def get_completion(self, prompt: str, **kwargs) -> str:
-         from groq import Groq
-
-         client = Groq(api_key=self.api_key)
-
-         # Configure default parameters
-         params = {
-             "temperature": kwargs.get("temperature", 0.7),
-             "max_tokens": kwargs.get("max_tokens", 4000),
-             "top_p": kwargs.get("top_p", 1.0),
-             "stop": kwargs.get("stop", None)
-         }
-
-         # Add retry logic for robustness
          max_retries = 3
          retry_delay = 1
-
          for attempt in range(max_retries):
              try:
-                 completion = client.chat.completions.create(
                      model=self.model,
                      messages=[{"role": "user", "content": prompt}],
-                     **params
                  )
                  return completion.choices[0].message.content
-
              except Exception as e:
                  if attempt == max_retries - 1:
                      raise e
                  time.sleep(retry_delay * (attempt + 1))
-
      def get_config(self) -> Dict[str, Any]:
          return {
              "config_list": [{
                  "model": self.model,
                  "api_key": self.api_key,
                  "temperature": 0.7,
-                 "max_tokens": 4000
              }]
          }

  class AIProviderFactory:
      @staticmethod
      def create_provider(api_key: Optional[str] = None) -> AIProvider:
          if not api_key:
              load_dotenv()
              api_key = os.getenv("GROQ_API_KEY")

-         if not api_key:
              raise ValueError("GROQ_API_KEY must be provided either directly or through environment variables")

-         return GroqProvider(api_key=api_key)
  # ai_providers.py

+ from typing import Dict, Any, Optional, List
  import os
  from dotenv import load_dotenv
  import time
+ from groq import Groq
+ from groq.types.chat import ChatCompletion
+ from langchain.llms.base import LLM
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from pydantic import Field, BaseModel

+ class GroqLLM(LLM, BaseModel):
+     api_key: str = Field(..., description="Groq API key")
+     model: str = Field(default="deepseek-r1-distill-llama-70b", description="Model name")
+     temperature: float = Field(default=0.7, description="Sampling temperature")
+     max_tokens: int = Field(default=4000, description="Maximum number of tokens to generate")
+     top_p: float = Field(default=1.0, description="Top p sampling parameter")
+     client: Any = Field(default=None, description="Groq client instance")

+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.client = Groq(api_key=self.api_key)
+
+     @property
+     def _llm_type(self) -> str:
+         return "groq"
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs,
+     ) -> str:
          max_retries = 3
          retry_delay = 1
+
          for attempt in range(max_retries):
              try:
+                 completion: ChatCompletion = self.client.chat.completions.create(
                      model=self.model,
                      messages=[{"role": "user", "content": prompt}],
+                     temperature=self.temperature,
+                     max_tokens=self.max_tokens,
+                     top_p=self.top_p,
+                     stop=stop,
                  )
                  return completion.choices[0].message.content
+
              except Exception as e:
                  if attempt == max_retries - 1:
                      raise e
                  time.sleep(retry_delay * (attempt + 1))
+
+ class AIProvider:
+     def get_completion(self, prompt: str, **kwargs) -> str:
+         pass
+
+     def get_config(self) -> Dict[str, Any]:
+         pass
+
+ class GroqProvider(AIProvider):
+     def __init__(self, api_key: str):
+         if not api_key:
+             raise ValueError("API key cannot be None or empty")
+         self.api_key = api_key
+         self.model = "deepseek-r1-distill-llama-70b"
+         self.llm = GroqLLM(
+             api_key=api_key,
+             model=self.model,
+             temperature=0.7,
+             max_tokens=4000,
+             top_p=1.0
+         )
+
+     def get_completion(self, prompt: str, **kwargs) -> str:
+         return self.llm(prompt)
+
+     def get_llm(self) -> GroqLLM:
+         return self.llm
+
      def get_config(self) -> Dict[str, Any]:
          return {
+             "llm": self.llm,
              "config_list": [{
                  "model": self.model,
                  "api_key": self.api_key,
                  "temperature": 0.7,
+                 "max_tokens": 4000,
+                 "api_base": "https://api.groq.com/openai/v1"
              }]
          }

  class AIProviderFactory:
      @staticmethod
      def create_provider(api_key: Optional[str] = None) -> AIProvider:
+         # Try to get API key from parameter first
          if not api_key:
+             # If not provided, try to load from environment
              load_dotenv()
              api_key = os.getenv("GROQ_API_KEY")

+         # Validate API key
+         if not api_key or not isinstance(api_key, str) or not api_key.strip():
              raise ValueError("GROQ_API_KEY must be provided either directly or through environment variables")

+         # Clean up API key
+         api_key = api_key.strip()
+
+         try:
+             return GroqProvider(api_key=api_key)
+         except Exception as e:
+             raise ValueError(f"Failed to create Groq provider: {str(e)}")
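For reference, a minimal usage sketch of the module as it stands after this commit, assuming the groq, langchain, pydantic, and python-dotenv packages are installed and GROQ_API_KEY is available in the environment or a .env file; the script name, prompt string, and printouts below are illustrative and not part of the commit.

# usage_sketch.py (illustrative; not part of this commit)
from ai_providers import AIProviderFactory

# With no argument, the factory falls back to GROQ_API_KEY from the environment / .env file.
provider = AIProviderFactory.create_provider()

# Single completion routed through the retry-wrapped Groq client.
print(provider.get_completion("Summarize this change in one sentence."))

# The LangChain-compatible wrapper and the config dict exposed by the provider.
llm = provider.get_llm()
config = provider.get_config()
print(config["config_list"][0]["model"])  # "deepseek-r1-distill-llama-70b"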