Hassankhwileh commited on
Commit
2367c1f
·
verified ·
1 Parent(s): 8fdbdd0

Update ai_providers.py

Browse files
Files changed (1) hide show
  1. ai_providers.py +106 -22
ai_providers.py CHANGED
@@ -1,32 +1,48 @@
1
  # ai_providers.py
2
 
3
- from typing import Dict, Any
 
 
 
4
  import time
5
- from groq import Groq
6
- from config import GROQ_API_KEY, AI_CONFIG
7
 
8
- class GroqProvider:
9
- def __init__(self):
10
- if not GROQ_API_KEY:
11
- raise ValueError("GROQ_API_KEY must be set in environment variables")
12
- self.client = Groq(api_key=GROQ_API_KEY)
13
- self.model = AI_CONFIG["model"]
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_completion(self, prompt: str, **kwargs) -> str:
16
- """Get completion from Groq API with retry logic"""
 
 
 
 
17
  params = {
18
- "temperature": kwargs.get("temperature", AI_CONFIG["temperature"]),
19
- "max_tokens": kwargs.get("max_tokens", AI_CONFIG["max_tokens"]),
20
- "top_p": kwargs.get("top_p", AI_CONFIG["top_p"]),
21
  "stop": kwargs.get("stop", None)
22
  }
23
 
 
24
  max_retries = 3
25
  retry_delay = 1
26
 
27
  for attempt in range(max_retries):
28
  try:
29
- completion = self.client.chat.completions.create(
30
  model=self.model,
31
  messages=[{"role": "user", "content": prompt}],
32
  **params
@@ -35,20 +51,88 @@ class GroqProvider:
35
 
36
  except Exception as e:
37
  if attempt == max_retries - 1:
38
- raise Exception(f"Failed to get completion from Groq after {max_retries} attempts: {str(e)}")
39
  time.sleep(retry_delay * (attempt + 1))
40
 
41
  def get_config(self) -> Dict[str, Any]:
42
- """Get Groq configuration"""
43
  return {
44
  "config_list": [{
45
  "model": self.model,
46
- "api_key": GROQ_API_KEY,
47
- "temperature": AI_CONFIG["temperature"],
48
- "max_tokens": AI_CONFIG["max_tokens"],
49
- "top_p": AI_CONFIG["top_p"]
50
  }]
51
  }
52
 
53
- # Global instance
54
- groq_provider = GroqProvider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ai_providers.py
2
 
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, Any, Optional
5
+ import os
6
+ from dotenv import load_dotenv
7
  import time
 
 
8
 
9
+ # Corrected import statement
10
+ from config import get_ai_config # Import the function instead of AI_CONFIG
11
+
12
+ class AIProvider(ABC):
13
+ @abstractmethod
14
+ def get_completion(self, prompt: str, **kwargs) -> str:
15
+ pass
16
+
17
+ @abstractmethod
18
+ def get_config(self) -> Dict[str, Any]:
19
+ pass
20
+
21
+ class GroqProvider(AIProvider):
22
+ def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-math-7b-instruct"):
23
+ self.api_key = api_key
24
+ self.model = model
25
 
26
  def get_completion(self, prompt: str, **kwargs) -> str:
27
+ from groq import Groq
28
+
29
+ client = Groq(api_key=self.api_key)
30
+
31
+ # Configure default parameters
32
  params = {
33
+ "temperature": kwargs.get("temperature", 0.7),
34
+ "max_tokens": kwargs.get("max_tokens", 4000),
35
+ "top_p": kwargs.get("top_p", 1.0),
36
  "stop": kwargs.get("stop", None)
37
  }
38
 
39
+ # Add retry logic for robustness
40
  max_retries = 3
41
  retry_delay = 1
42
 
43
  for attempt in range(max_retries):
44
  try:
45
+ completion = client.chat.completions.create(
46
  model=self.model,
47
  messages=[{"role": "user", "content": prompt}],
48
  **params
 
51
 
52
  except Exception as e:
53
  if attempt == max_retries - 1:
54
+ raise e
55
  time.sleep(retry_delay * (attempt + 1))
56
 
57
  def get_config(self) -> Dict[str, Any]:
 
58
  return {
59
  "config_list": [{
60
  "model": self.model,
61
+ "api_key": self.api_key,
62
+ "temperature": 0.7,
63
+ "max_tokens": 4000
 
64
  }]
65
  }
66
 
67
+ class OpenAIProvider(AIProvider):
68
+ def __init__(self, api_key: str, model: str = "gpt-4"):
69
+ self.api_key = api_key
70
+ self.model = model
71
+
72
+ def get_completion(self, prompt: str, **kwargs) -> str:
73
+ import openai
74
+ openai.api_key = self.api_key
75
+
76
+ response = openai.ChatCompletion.create(
77
+ model=self.model,
78
+ messages=[{"role": "user", "content": prompt}],
79
+ **kwargs
80
+ )
81
+ return response.choices[0].message.content
82
+
83
+ def get_config(self) -> Dict[str, Any]:
84
+ return {
85
+ "config_list": [{
86
+ "model": self.model,
87
+ "api_key": self.api_key,
88
+ "temperature": 0.7,
89
+ "max_tokens": 4000
90
+ }]
91
+ }
92
+
93
+ class AnthropicProvider(AIProvider):
94
+ def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
95
+ self.api_key = api_key
96
+ self.model = model
97
+
98
+ def get_completion(self, prompt: str, **kwargs) -> str:
99
+ import anthropic
100
+
101
+ client = anthropic.Client(api_key=self.api_key)
102
+ response = client.messages.create(
103
+ model=self.model,
104
+ messages=[{"role": "user", "content": prompt}],
105
+ **kwargs
106
+ )
107
+ return response.content[0].text
108
+
109
+ def get_config(self) -> Dict[str, Any]:
110
+ return {
111
+ "config_list": [{
112
+ "model": self.model,
113
+ "api_key": self.api_key,
114
+ "temperature": 0.7,
115
+ "max_tokens": 4000
116
+ }]
117
+ }
118
+
119
+ class AIProviderFactory:
120
+ @staticmethod
121
+ def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
122
+ if not api_key:
123
+ load_dotenv()
124
+
125
+ providers = {
126
+ "openai": (OpenAIProvider, os.getenv("OPENAI_API_KEY"), "gpt-4"),
127
+ "groq": (GroqProvider, os.getenv("GROQ_API_KEY"), "deepseek-ai/deepseek-r1-distill-llama-70b"),
128
+ "anthropic": (AnthropicProvider, os.getenv("ANTHROPIC_API_KEY"), "claude-3-opus-20240229")
129
+ }
130
+
131
+ if provider_type not in providers:
132
+ raise ValueError(f"Unsupported provider: {provider_type}")
133
+
134
+ ProviderClass, default_api_key, default_model = providers[provider_type]
135
+ return ProviderClass(
136
+ api_key=api_key or default_api_key,
137
+ model=model or default_model
138
+ )