LegalAI-DS / ai_providers.py
Hassankhwileh's picture
Update ai_providers.py
2367c1f verified
raw
history blame
4.46 kB
# ai_providers.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
from dotenv import load_dotenv
import time
# Corrected import statement
from config import get_ai_config # Import the function instead of AI_CONFIG
class AIProvider(ABC):
    """Common interface implemented by every LLM backend in this module.

    Concrete providers must supply both a text-completion call and a
    configuration dictionary describing their settings.
    """

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the backend and return the generated text.

        Args:
            prompt: The user message to send.
            **kwargs: Backend-specific request options supplied by the caller.

        Returns:
            The model's reply as a string.
        """

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return this provider's settings as a plain dictionary.

        Implementations in this module wrap the settings in a
        ``{"config_list": [...]}`` mapping.
        """
class GroqProvider(AIProvider):
    """AIProvider backed by the Groq chat-completions API, with simple retries."""

    # Total number of attempts per get_completion call; values below 1 are
    # treated as 1 so the method always makes at least one request.
    max_retries = 3
    # Base back-off in seconds: after failed attempt N (1-based) the call
    # sleeps retry_delay * N before trying again.
    retry_delay = 1
    # 4xx statuses that may still be transient (timeout, conflict, rate limit)
    # and are therefore worth retrying; other 4xx errors fail immediately.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 409, 429})

    def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-math-7b-instruct"):
        """Store the credentials and model id used for every request.

        Args:
            api_key: Groq API key passed to the ``groq.Groq`` client.
            model: Model identifier sent with each completion request.
        """
        self.api_key = api_key
        self.model = model

    @staticmethod
    def _is_retryable(err: Exception) -> bool:
        """Return False for client errors (4xx) that retrying cannot fix.

        Exceptions without an integer ``status_code`` attribute (connection
        problems, unexpected errors) are treated as retryable, as before.
        """
        status = getattr(err, "status_code", None)
        if isinstance(status, int) and 400 <= status < 500:
            return status in GroqProvider._RETRYABLE_CLIENT_STATUSES
        return True

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Only ``temperature`` (default 0.7), ``max_tokens`` (default 4000),
        ``top_p`` (default 1.0) and ``stop`` (default None) are read from
        ``kwargs``; any other keyword arguments are ignored.

        Raises:
            Exception: whatever the final failed attempt raised, or the first
                non-retryable client error.
        """
        # Imported lazily so the module loads even when groq is not installed.
        from groq import Groq

        client = Groq(api_key=self.api_key)
        params = {
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4000),
            "top_p": kwargs.get("top_p", 1.0),
            "stop": kwargs.get("stop", None),
        }
        # NOTE(review): the groq SDK client may also retry some failures
        # internally; confirm the combined retry budget is what you want.
        attempts = max(1, int(self.max_retries))
        for attempt in range(attempts):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params,
                )
                return completion.choices[0].message.content
            except Exception as err:
                if attempt == attempts - 1 or not self._is_retryable(err):
                    raise
                # Linear back-off between attempts.
                time.sleep(self.retry_delay * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        """Return the provider settings wrapped in a single ``config_list`` entry.

        The temperature and max_tokens values mirror get_completion's defaults.
        """
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000,
            }]
        }
class OpenAIProvider(AIProvider):
    """AIProvider backed by the OpenAI chat-completions API."""

    def __init__(self, api_key: str, model: str = "gpt-4"):
        """Store the credentials and model id used for every request.

        Args:
            api_key: OpenAI API key.
            model: Model identifier sent with each completion request.
        """
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        All ``kwargs`` are forwarded unchanged to the chat-completions call.
        """
        # Imported lazily so the module loads even when openai is not installed.
        import openai

        messages = [{"role": "user", "content": prompt}]
        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed the module-level ChatCompletion API; use a
            # per-instance client so the key does not leak into global state.
            client = openai.OpenAI(api_key=self.api_key)
            response = client.chat.completions.create(
                model=self.model,
                messages=messages,
                **kwargs,
            )
        else:
            # Legacy openai<1.0 fallback: this SDK only supports a
            # process-wide API key.
            openai.api_key = self.api_key
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                **kwargs,
            )
        return response.choices[0].message.content

    def get_config(self) -> Dict[str, Any]:
        """Return the provider settings wrapped in a single ``config_list`` entry."""
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000,
            }]
        }
class AnthropicProvider(AIProvider):
    """AIProvider backed by the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        """Store the credentials and model id used for every request.

        Args:
            api_key: Anthropic API key.
            model: Model identifier sent with each request.
        """
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        ``kwargs`` are forwarded to ``messages.create``. ``max_tokens``
        defaults to 4000 (the value get_config advertises) when not given.

        Returns:
            The concatenated text of all text content blocks in the reply.
        """
        # Imported lazily so the module loads even when anthropic is not installed.
        import anthropic

        # The Messages API requires max_tokens; without a default every call
        # that omits it fails.
        kwargs.setdefault("max_tokens", 4000)
        client = anthropic.Anthropic(api_key=self.api_key)
        response = client.messages.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs,
        )
        # A reply may contain several content blocks, not all of them text;
        # join the text ones instead of assuming content[0] is text.
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the provider settings wrapped in a single ``config_list`` entry."""
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000,
            }]
        }
class AIProviderFactory:
    """Builds AIProvider instances from a short provider name."""

    @staticmethod
    def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
        """Instantiate the provider named ``provider_type``.

        Args:
            provider_type: One of ``"openai"``, ``"groq"`` or ``"anthropic"``;
                surrounding whitespace and letter case are ignored.
            api_key: Explicit API key. When falsy, a ``.env`` file is loaded and
                the provider's environment variable is used instead (which may
                itself be unset, yielding ``None``).
            model: Model id override; when falsy the factory default is used.

        Returns:
            A configured provider instance.

        Raises:
            ValueError: if ``provider_type`` is not a supported provider name.
        """
        if not api_key:
            # Populate os.environ from .env so the fallback lookup can see it.
            load_dotenv()
        # name -> (provider class, env var holding its key, default model)
        # NOTE(review): the Groq default here differs from GroqProvider's own
        # constructor default; confirm which model id is intended.
        providers = {
            "openai": (OpenAIProvider, "OPENAI_API_KEY", "gpt-4"),
            "groq": (GroqProvider, "GROQ_API_KEY", "deepseek-ai/deepseek-r1-distill-llama-70b"),
            "anthropic": (AnthropicProvider, "ANTHROPIC_API_KEY", "claude-3-opus-20240229"),
        }
        key = provider_type.strip().lower() if isinstance(provider_type, str) else provider_type
        if key not in providers:
            raise ValueError(f"Unsupported provider: {provider_type}")
        provider_class, env_var, default_model = providers[key]
        return provider_class(
            # Only the selected provider's variable is read, and only if needed.
            api_key=api_key or os.getenv(env_var),
            model=model or default_model,
        )