Hassankhwileh commited on
Commit
8200031
·
verified ·
1 Parent(s): 2e1e749

Update ai_providers.py

Browse files
Files changed (1) hide show
  1. ai_providers.py +22 -103
ai_providers.py CHANGED
@@ -1,45 +1,32 @@
1
  # ai_providers.py
2
 
3
- from abc import ABC, abstractmethod
4
- from typing import Dict, Any, Optional
5
- import os
6
- from dotenv import load_dotenv
7
  import time
 
 
8
 
9
- class AIProvider(ABC):
10
- @abstractmethod
11
- def get_completion(self, prompt: str, **kwargs) -> str:
12
- pass
13
-
14
- @abstractmethod
15
- def get_config(self) -> Dict[str, Any]:
16
- pass
17
-
18
- class GroqProvider(AIProvider):
19
- def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-math-7b-instruct"):
20
- self.api_key = api_key
21
- self.model = model
22
 
23
  def get_completion(self, prompt: str, **kwargs) -> str:
24
- from groq import Groq
25
-
26
- client = Groq(api_key=self.api_key)
27
-
28
- # Configure default parameters
29
  params = {
30
- "temperature": kwargs.get("temperature", 0.7),
31
- "max_tokens": kwargs.get("max_tokens", 4000),
32
- "top_p": kwargs.get("top_p", 1.0),
33
  "stop": kwargs.get("stop", None)
34
  }
35
 
36
- # Add retry logic for robustness
37
  max_retries = 3
38
  retry_delay = 1
39
 
40
  for attempt in range(max_retries):
41
  try:
42
- completion = client.chat.completions.create(
43
  model=self.model,
44
  messages=[{"role": "user", "content": prompt}],
45
  **params
@@ -48,88 +35,20 @@ class GroqProvider(AIProvider):
48
 
49
  except Exception as e:
50
  if attempt == max_retries - 1:
51
- raise e
52
  time.sleep(retry_delay * (attempt + 1))
53
 
54
  def get_config(self) -> Dict[str, Any]:
 
55
  return {
56
  "config_list": [{
57
  "model": self.model,
58
- "api_key": self.api_key,
59
- "temperature": 0.7,
60
- "max_tokens": 4000
61
- }]
62
- }
63
-
64
- class OpenAIProvider(AIProvider):
65
- def __init__(self, api_key: str, model: str = "gpt-4"):
66
- self.api_key = api_key
67
- self.model = model
68
-
69
- def get_completion(self, prompt: str, **kwargs) -> str:
70
- import openai
71
- openai.api_key = self.api_key
72
-
73
- response = openai.ChatCompletion.create(
74
- model=self.model,
75
- messages=[{"role": "user", "content": prompt}],
76
- **kwargs
77
- )
78
- return response.choices[0].message.content
79
-
80
- def get_config(self) -> Dict[str, Any]:
81
- return {
82
- "config_list": [{
83
- "model": self.model,
84
- "api_key": self.api_key,
85
- "temperature": 0.7,
86
- "max_tokens": 4000
87
  }]
88
  }
89
 
90
- class AnthropicProvider(AIProvider):
91
- def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
92
- self.api_key = api_key
93
- self.model = model
94
-
95
- def get_completion(self, prompt: str, **kwargs) -> str:
96
- import anthropic
97
-
98
- client = anthropic.Client(api_key=self.api_key)
99
- response = client.messages.create(
100
- model=self.model,
101
- messages=[{"role": "user", "content": prompt}],
102
- **kwargs
103
- )
104
- return response.content[0].text
105
-
106
- def get_config(self) -> Dict[str, Any]:
107
- return {
108
- "config_list": [{
109
- "model": self.model,
110
- "api_key": self.api_key,
111
- "temperature": 0.7,
112
- "max_tokens": 4000
113
- }]
114
- }
115
-
116
- class AIProviderFactory:
117
- @staticmethod
118
- def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
119
- if not api_key:
120
- load_dotenv()
121
-
122
- providers = {
123
- "openai": (OpenAIProvider, os.getenv("OPENAI_API_KEY"), "gpt-4"),
124
- "groq": (GroqProvider, os.getenv("GROQ_API_KEY"), "deepseek-ai/deepseek-r1-distill-llama-70b"),
125
- "anthropic": (AnthropicProvider, os.getenv("ANTHROPIC_API_KEY"), "claude-3-opus-20240229")
126
- }
127
-
128
- if provider_type not in providers:
129
- raise ValueError(f"Unsupported provider: {provider_type}")
130
-
131
- ProviderClass, default_api_key, default_model = providers[provider_type]
132
- return ProviderClass(
133
- api_key=api_key or default_api_key,
134
- model=model or default_model
135
- )
 
1
  # ai_providers.py
2
 
3
+ from typing import Dict, Any
 
 
 
4
  import time
5
+ from groq import Groq
6
+ from config import GROQ_API_KEY, AI_CONFIG
7
 
8
+ class GroqProvider:
9
+ def __init__(self):
10
+ if not GROQ_API_KEY:
11
+ raise ValueError("GROQ_API_KEY must be set in environment variables")
12
+ self.client = Groq(api_key=GROQ_API_KEY)
13
+ self.model = AI_CONFIG["model"]
 
 
 
 
 
 
 
14
 
15
  def get_completion(self, prompt: str, **kwargs) -> str:
16
+ """Get completion from Groq API with retry logic"""
 
 
 
 
17
  params = {
18
+ "temperature": kwargs.get("temperature", AI_CONFIG["temperature"]),
19
+ "max_tokens": kwargs.get("max_tokens", AI_CONFIG["max_tokens"]),
20
+ "top_p": kwargs.get("top_p", AI_CONFIG["top_p"]),
21
  "stop": kwargs.get("stop", None)
22
  }
23
 
 
24
  max_retries = 3
25
  retry_delay = 1
26
 
27
  for attempt in range(max_retries):
28
  try:
29
+ completion = self.client.chat.completions.create(
30
  model=self.model,
31
  messages=[{"role": "user", "content": prompt}],
32
  **params
 
35
 
36
  except Exception as e:
37
  if attempt == max_retries - 1:
38
+ raise Exception(f"Failed to get completion from Groq after {max_retries} attempts: {str(e)}")
39
  time.sleep(retry_delay * (attempt + 1))
40
 
41
  def get_config(self) -> Dict[str, Any]:
42
+ """Get Groq configuration"""
43
  return {
44
  "config_list": [{
45
  "model": self.model,
46
+ "api_key": GROQ_API_KEY,
47
+ "temperature": AI_CONFIG["temperature"],
48
+ "max_tokens": AI_CONFIG["max_tokens"],
49
+ "top_p": AI_CONFIG["top_p"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  }]
51
  }
52
 
53
+ # Global instance
54
+ groq_provider = GroqProvider()