# ai_providers.py
"""Provider abstractions for Groq, OpenAI, and Anthropic chat-completion backends."""

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
import time

from dotenv import load_dotenv

# Corrected import: use the get_ai_config function instead of AI_CONFIG
from config import get_ai_config
class AIProvider(ABC):
    """Abstract interface that every concrete AI provider must implement."""

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the model's text completion for a single user prompt."""
        pass

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return this provider's settings as a config_list-style dict."""
        pass
class GroqProvider(AIProvider):
    """Provider backed by the Groq chat completions API."""

    def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-math-7b-instruct"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        from groq import Groq

        client = Groq(api_key=self.api_key)
        # Configure default parameters, letting callers override them via kwargs
        params = {
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4000),
            "top_p": kwargs.get("top_p", 1.0),
            "stop": kwargs.get("stop", None)
        }
        # Retry with a linearly increasing delay for robustness against transient errors
        max_retries = 3
        retry_delay = 1
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params
                )
                return completion.choices[0].message.content
            except Exception as e:
                if attempt == max_retries - 1:
                    raise e
                time.sleep(retry_delay * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }
class OpenAIProvider(AIProvider):
    """Provider backed by the OpenAI chat completions API (openai>=1.0 client)."""

    def __init__(self, api_key: str, model: str = "gpt-4"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        # The legacy openai.ChatCompletion interface was removed in openai>=1.0;
        # use the OpenAI client object instead.
        from openai import OpenAI

        client = OpenAI(api_key=self.api_key)
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs
        )
        return response.choices[0].message.content

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }
class AnthropicProvider(AIProvider):
    """Provider backed by the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        import anthropic

        client = anthropic.Anthropic(api_key=self.api_key)
        # The Messages API requires max_tokens, so supply a default if the caller omits it
        kwargs.setdefault("max_tokens", 4000)
        response = client.messages.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **kwargs
        )
        return response.content[0].text

    def get_config(self) -> Dict[str, Any]:
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000
            }]
        }
class AIProviderFactory:
    """Builds a concrete AIProvider from a provider name, falling back to environment variables."""

    @staticmethod
    def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
        if not api_key:
            load_dotenv()
        providers = {
            "openai": (OpenAIProvider, os.getenv("OPENAI_API_KEY"), "gpt-4"),
            "groq": (GroqProvider, os.getenv("GROQ_API_KEY"), "deepseek-ai/deepseek-r1-distill-llama-70b"),
            "anthropic": (AnthropicProvider, os.getenv("ANTHROPIC_API_KEY"), "claude-3-opus-20240229")
        }
        if provider_type not in providers:
            raise ValueError(f"Unsupported provider: {provider_type}")
        ProviderClass, default_api_key, default_model = providers[provider_type]
        resolved_key = api_key or default_api_key
        if not resolved_key:
            raise ValueError(f"No API key supplied or found in the environment for provider: {provider_type}")
        return ProviderClass(
            api_key=resolved_key,
            model=model or default_model
        )
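

# --- Example usage -----------------------------------------------------------
# Illustrative sketch only, not part of the original module: it assumes that
# GROQ_API_KEY (or the key for whichever provider you choose) is available in
# the environment or a .env file, and that the corresponding SDK is installed.
if __name__ == "__main__":
    provider = AIProviderFactory.create_provider("groq")
    print(provider.get_config())  # config_list-style settings for this provider
    print(provider.get_completion("Summarize the quadratic formula in one sentence."))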