LegalAI-DS / ai_providers.py
Hassankhwileh's picture
Update ai_providers.py
8200031 verified
raw
history blame
1.91 kB
# ai_providers.py
from typing import Dict, Any
import time
from groq import Groq
from config import GROQ_API_KEY, AI_CONFIG
class GroqProvider:
    """Wrapper around the Groq chat-completions client.

    The model name and the default sampling settings are read from
    ``AI_CONFIG``; the API key is read from ``GROQ_API_KEY``.
    """

    # Sampling settings whose defaults come from AI_CONFIG when the caller
    # does not override them.
    _CONFIG_KEYS = ("temperature", "max_tokens", "top_p")

    def __init__(self):
        """Create the Groq client and remember the configured model.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is empty or unset.
        """
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY must be set in environment variables")

        self.client = Groq(api_key=GROQ_API_KEY)
        self.model = AI_CONFIG["model"]

    def get_completion(
        self,
        prompt: str,
        *,
        max_retries: int = 3,
        retry_delay: float = 1,
        **kwargs,
    ) -> str:
        """Get completion from Groq API with retry logic.

        The prompt is sent as a single ``user`` message. A failed attempt is
        followed by a sleep of ``retry_delay * attempt_number`` seconds
        before the next attempt.

        Args:
            prompt: Text sent as the content of the user message.
            max_retries: Total number of attempts before giving up (>= 1).
            retry_delay: Base delay in seconds; multiplied by the 1-based
                attempt number after each failure.
            **kwargs: Optional overrides for ``temperature``, ``max_tokens``,
                ``top_p`` and ``stop``. Other keys are ignored.

        Returns:
            The ``message.content`` of the first choice in the response.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
            Exception: If every attempt fails; the last underlying error is
                attached as ``__cause__``.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        # Only consult AI_CONFIG for settings the caller did not supply.
        params = {
            key: kwargs[key] if key in kwargs else AI_CONFIG[key]
            for key in self._CONFIG_KEYS
        }
        params["stop"] = kwargs.get("stop")

        for attempt in range(1, max_retries + 1):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params,
                )
                return completion.choices[0].message.content
            except Exception as e:
                if attempt == max_retries:
                    # Chain the cause so the original error stays visible
                    # in the traceback.
                    raise Exception(
                        f"Failed to get completion from Groq after {max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(retry_delay * attempt)

    def get_config(self) -> Dict[str, Any]:
        """Get Groq configuration.

        Returns:
            A dict with a single ``"config_list"`` entry holding the model,
            the API key and the ``AI_CONFIG`` sampling defaults.
        """
        return {
            "config_list": [{
                "model": self.model,
                "api_key": GROQ_API_KEY,
                "temperature": AI_CONFIG["temperature"],
                "max_tokens": AI_CONFIG["max_tokens"],
                "top_p": AI_CONFIG["top_p"],
            }]
        }
# Global instance: a shared GroqProvider created when this module is imported.
# Construction raises ValueError if GROQ_API_KEY is falsy, so importing this
# module fails in that case.
groq_provider = GroqProvider()