# ai_providers.py
"""Chat-completion providers (Groq, OpenAI, Anthropic) behind a common interface.

Each provider's SDK is imported lazily inside ``get_completion`` so that only
the SDK for the provider actually used has to be installed.
"""
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from dotenv import load_dotenv

# NOTE(review): not referenced in this module; kept because other code may
# rely on it being importable from here — confirm before removing.
from config import get_ai_config  # noqa: F401

# Generation defaults advertised by every provider's ``get_config`` and used
# as fallbacks when callers do not pass their own values.
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 4000


class AIProvider(ABC):
    """Abstract interface for a chat-completion backend."""

    # Set by every concrete provider's __init__; read by _default_config.
    api_key: Optional[str]
    model: str

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Keyword arguments are provider-specific generation options.
        """

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return a ``{"config_list": [...]}`` description of this provider."""

    def _default_config(self) -> Dict[str, Any]:
        """Build the ``config_list`` dict shared by all built-in providers."""
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": DEFAULT_TEMPERATURE,
                "max_tokens": DEFAULT_MAX_TOKENS,
            }]
        }


class GroqProvider(AIProvider):
    """Chat completions through the Groq SDK, retried on failure."""

    def __init__(self, api_key: str,
                 model: str = "deepseek-ai/deepseek-math-7b-instruct"):
        # NOTE(review): this default differs from the Groq default used by
        # AIProviderFactory ("deepseek-ai/deepseek-r1-distill-llama-70b");
        # confirm which model is intended.
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return Groq's reply to ``prompt``.

        Only ``temperature``, ``max_tokens``, ``top_p`` and ``stop`` are read
        from ``kwargs``; any other keyword is ignored.

        Raises:
            Exception: whatever the final attempt raised, once all attempts
                have failed.
        """
        from groq import Groq

        client = Groq(api_key=self.api_key)

        params = {
            "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
            "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS),
            "top_p": kwargs.get("top_p", 1.0),
            "stop": kwargs.get("stop", None),
        }

        max_retries = 3
        retry_delay = 1  # seconds; the wait grows linearly per attempt
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params,
                )
            except Exception:
                # Any failure is retried; on the last attempt re-raise with
                # the original traceback intact.
                if attempt == max_retries - 1:
                    raise
                time.sleep(retry_delay * (attempt + 1))
            else:
                return completion.choices[0].message.content

    def get_config(self) -> Dict[str, Any]:
        """Return this provider's ``config_list`` description."""
        return self._default_config()


class OpenAIProvider(AIProvider):
    """Chat completions through the ``openai`` package (v1 or legacy)."""

    def __init__(self, api_key: str, model: str = "gpt-4"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return OpenAI's reply to ``prompt``; ``kwargs`` go to the API call."""
        import openai

        messages = [{"role": "user", "content": prompt}]
        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed ``openai.ChatCompletion``. A per-call client
            # also avoids writing the key into global module state.
            client = openai.OpenAI(api_key=self.api_key)
            response = client.chat.completions.create(
                model=self.model, messages=messages, **kwargs
            )
        else:
            # Legacy openai<1.0 interface.
            openai.api_key = self.api_key
            response = openai.ChatCompletion.create(
                model=self.model, messages=messages, **kwargs
            )
        return response.choices[0].message.content

    def get_config(self) -> Dict[str, Any]:
        """Return this provider's ``config_list`` description."""
        return self._default_config()


class AnthropicProvider(AIProvider):
    """Chat completions through the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return Claude's reply to ``prompt``; ``kwargs`` go to the API call."""
        import anthropic

        # ``max_tokens`` is mandatory for messages.create; fall back to the
        # same value get_config advertises when the caller omits it.
        kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

        client = anthropic.Client(api_key=self.api_key)
        response = client.messages.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs,
        )
        return response.content[0].text

    def get_config(self) -> Dict[str, Any]:
        """Return this provider's ``config_list`` description."""
        return self._default_config()


class AIProviderFactory:
    """Builds an ``AIProvider`` by name, filling gaps from the environment."""

    @staticmethod
    def create_provider(provider_type: str, api_key: Optional[str] = None,
                        model: Optional[str] = None) -> AIProvider:
        """Instantiate the provider registered under ``provider_type``.

        Args:
            provider_type: "openai", "groq" or "anthropic" (case and
                surrounding whitespace are ignored).
            api_key: explicit key; when falsy, a ``.env`` file is loaded and
                the provider's ``*_API_KEY`` environment variable is used
                (which may itself be unset, giving ``None``).
            model: explicit model name; when falsy the per-provider default
                below is used.

        Raises:
            ValueError: if ``provider_type`` is not a supported provider.
        """
        if not api_key:
            # Make .env values visible to the os.getenv lookups below.
            load_dotenv()

        providers = {
            "openai": (OpenAIProvider, os.getenv("OPENAI_API_KEY"), "gpt-4"),
            "groq": (GroqProvider, os.getenv("GROQ_API_KEY"),
                     "deepseek-ai/deepseek-r1-distill-llama-70b"),
            "anthropic": (AnthropicProvider, os.getenv("ANTHROPIC_API_KEY"),
                          "claude-3-opus-20240229"),
        }

        name = (provider_type.strip().lower()
                if isinstance(provider_type, str) else provider_type)
        if name not in providers:
            raise ValueError(f"Unsupported provider: {provider_type}")

        provider_class, default_api_key, default_model = providers[name]
        return provider_class(
            api_key=api_key or default_api_key,
            model=model or default_model,
        )