# LegalAI-DS / ai_providers.py
# Last commit: "Update ai_providers.py" (92b1cdc, verified) by Hassankhwileh; file size 2.04 kB.
from typing import Dict, Any, Optional
import os
from dotenv import load_dotenv
import time
class GroqProvider:
    """Minimal wrapper around Groq chat completions.

    Holds an API key and model name. Sends single-turn user prompts through
    the ``groq`` SDK with a simple retry loop, and exposes the same settings
    as a ``config_list`` dictionary.
    """

    # NOTE(review): Groq's model IDs are usually unprefixed (e.g.
    # "deepseek-r1-distill-llama-70b"). Confirm this "deepseek-ai/"-prefixed
    # ID is accepted by the Groq API before relying on the default.
    DEFAULT_MODEL = "deepseek-ai/deepseek-r1-distill-llama-70b"

    # Sampling defaults shared by get_completion() and get_config() so the
    # two can never disagree.
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_MAX_TOKENS = 4000
    DEFAULT_TOP_P = 1.0

    # Retry policy for get_completion(): total number of attempts, and the
    # base delay in seconds. The wait before retry k (1-based) is
    # RETRY_DELAY * k.
    MAX_RETRIES = 3
    RETRY_DELAY = 1

    def __init__(self, api_key: str, model: str = DEFAULT_MODEL):
        """Store the credentials and model used for every request.

        Args:
            api_key: Groq API key, passed unchanged to the ``groq.Groq`` client.
            model: Model identifier sent with each chat-completion request.
        """
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Content of the single ``user`` chat message.
            **kwargs: Optional overrides for ``temperature``, ``max_tokens``,
                ``top_p`` and ``stop``. Any other keys are ignored.

        Returns:
            ``choices[0].message.content`` of the completion.
            NOTE(review): the SDK may type this field as optional; confirm it
            cannot be None for this model.

        Raises:
            Exception: the exception raised by the final attempt, once
                ``MAX_RETRIES`` attempts have all failed.
        """
        # Imported lazily, presumably so this module can be imported without
        # the groq package installed.
        from groq import Groq

        params = {
            "temperature": kwargs.get("temperature", self.DEFAULT_TEMPERATURE),
            "max_tokens": kwargs.get("max_tokens", self.DEFAULT_MAX_TOKENS),
            "top_p": kwargs.get("top_p", self.DEFAULT_TOP_P),
            "stop": kwargs.get("stop", None),
        }
        messages = [{"role": "user", "content": prompt}]
        # Guard so that overriding MAX_RETRIES to 0 cannot skip the call and
        # silently return None.
        attempts = max(1, self.MAX_RETRIES)

        # The context manager closes the client's HTTP resources when the call
        # finishes, instead of leaving them for the garbage collector.
        # NOTE(review): the groq SDK client may also retry internally; confirm
        # the two retry layers don't compound more than intended.
        with Groq(api_key=self.api_key) as client:
            for attempt in range(attempts):
                try:
                    completion = client.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        **params,
                    )
                    return completion.choices[0].message.content
                except Exception:
                    # Every failure is retried; only the last one propagates.
                    if attempt == attempts - 1:
                        raise
                    # Linear backoff: RETRY_DELAY, 2 * RETRY_DELAY, ...
                    time.sleep(self.RETRY_DELAY * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        """Return this provider's settings as ``{"config_list": [entry]}``.

        The single entry carries the model, the API key and the default
        temperature and max_tokens values.
        NOTE(review): presumably consumed by an AutoGen-style ``llm_config``;
        confirm against the caller.

        The returned dict contains the raw API key, so do not log it.
        """
        return {
            "config_list": [
                {
                    "model": self.model,
                    "api_key": self.api_key,
                    "temperature": self.DEFAULT_TEMPERATURE,
                    "max_tokens": self.DEFAULT_MAX_TOKENS,
                }
            ]
        }
def create_provider(api_key: Optional[str] = None, model: Optional[str] = None) -> GroqProvider:
    """Build a GroqProvider, reading the API key from the environment if needed.

    Args:
        api_key: Groq API key. When falsy, ``load_dotenv()`` is called and
            ``GROQ_API_KEY`` is read from the environment instead.
        model: Model identifier. When falsy (``None`` or ``""``), the
            GroqProvider constructor's default model is used.

    Returns:
        A GroqProvider configured with the resolved key and model.

    Raises:
        ValueError: if no key was passed and ``GROQ_API_KEY`` is unset or empty.
    """
    if not api_key:
        load_dotenv()
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY must be set in environment variables")
    # Defer to the constructor's default instead of repeating the model ID
    # here, so the default lives in one place.
    if model:
        return GroqProvider(api_key=api_key, model=model)
    return GroqProvider(api_key=api_key)