# ai_providers.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
from dotenv import load_dotenv
import time
# Corrected import statement
from config import get_ai_config  # Import the function instead of AI_CONFIG

class AIProvider(ABC):
    """Abstract interface that every concrete AI provider must implement."""

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        pass

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        pass

class GroqProvider(AIProvider):
    def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-math-7b-instruct"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        from groq import Groq
        client = Groq(api_key=self.api_key)
        # Configure default parameters
        params = {
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4000),
            "top_p": kwargs.get("top_p", 1.0),
            "stop": kwargs.get("stop", None)
        }
        # Add retry logic for robustness
        max_retries = 3
        retry_delay = 1
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params
                )
                return completion.choices[0].message.content
            except Exception as e:
                if attempt == max_retries - 1:
                    raise e
                time.sleep(retry_delay * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }

class OpenAIProvider(AIProvider):
    def __init__(self, api_key: str, model: str = "gpt-4"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        # Uses the openai>=1.0 client interface (ChatCompletion.create was removed in v1)
        from openai import OpenAI
        client = OpenAI(api_key=self.api_key)
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs
        )
        return response.choices[0].message.content

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }

class AnthropicProvider(AIProvider):
    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        import anthropic
        client = anthropic.Anthropic(api_key=self.api_key)
        # The Messages API requires max_tokens, so supply a default if the caller omits it
        kwargs.setdefault("max_tokens", 4000)
        response = client.messages.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs
        )
        return response.content[0].text

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }

class AIProviderFactory:
    @staticmethod
    def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
        # Fall back to environment variables (optionally from a .env file) when no key is passed in
        if not api_key:
            load_dotenv()
        providers = {
            "openai": (OpenAIProvider, os.getenv("OPENAI_API_KEY"), "gpt-4"),
            "groq": (GroqProvider, os.getenv("GROQ_API_KEY"), "deepseek-ai/deepseek-r1-distill-llama-70b"),
            "anthropic": (AnthropicProvider, os.getenv("ANTHROPIC_API_KEY"), "claude-3-opus-20240229")
        }
        if provider_type not in providers:
            raise ValueError(f"Unsupported provider: {provider_type}")
        ProviderClass, default_api_key, default_model = providers[provider_type]
        return ProviderClass(
            api_key=api_key or default_api_key,
            model=model or default_model
        )
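

# Minimal usage sketch (illustrative only, not part of the module's public API): it assumes
# the matching key (e.g. GROQ_API_KEY) is set in the environment or a .env file, and that the
# provider's default model is available on your account.
if __name__ == "__main__":
    provider = AIProviderFactory.create_provider("groq")
    print(provider.get_completion("Say hello in one short sentence.", temperature=0.2))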