Spaces:
Running
Running
Update ai_providers.py
Browse files- ai_providers.py +26 -11
ai_providers.py
CHANGED
@@ -1,9 +1,21 @@
|
|
|
|
|
|
|
|
1 |
from typing import Dict, Any, Optional
|
2 |
import os
|
3 |
from dotenv import load_dotenv
|
4 |
import time
|
5 |
|
6 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-r1-distill-llama-70b"):
|
8 |
self.api_key = api_key
|
9 |
self.model = model
|
@@ -49,13 +61,16 @@ class GroqProvider:
|
|
49 |
}]
|
50 |
}
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
1 |
+
# ai_providers.py
|
2 |
+
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
from typing import Dict, Any, Optional
|
5 |
import os
|
6 |
from dotenv import load_dotenv
|
7 |
import time
|
8 |
|
9 |
+
class AIProvider(ABC):
|
10 |
+
@abstractmethod
|
11 |
+
def get_completion(self, prompt: str, **kwargs) -> str:
|
12 |
+
pass
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def get_config(self) -> Dict[str, Any]:
|
16 |
+
pass
|
17 |
+
|
18 |
+
class GroqProvider(AIProvider):
|
19 |
def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-r1-distill-llama-70b"):
|
20 |
self.api_key = api_key
|
21 |
self.model = model
|
|
|
61 |
}]
|
62 |
}
|
63 |
|
64 |
+
class AIProviderFactory:
|
65 |
+
@staticmethod
|
66 |
+
def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
|
67 |
+
if not api_key:
|
68 |
+
load_dotenv()
|
69 |
+
|
70 |
+
if provider_type != "groq":
|
71 |
+
raise ValueError("Only Groq provider is supported.")
|
72 |
+
|
73 |
+
return GroqProvider(
|
74 |
+
api_key=api_key or os.getenv("GROQ_API_KEY"),
|
75 |
+
model=model or "deepseek-ai/deepseek-r1-distill-llama-70b"
|
76 |
+
)
|