Hassankhwileh commited on
Commit
b6737fb
·
verified ·
1 Parent(s): 98b9b80

Update ai_providers.py

Browse files
Files changed (1) hide show
  1. ai_providers.py +26 -11
ai_providers.py CHANGED
@@ -1,9 +1,21 @@
 
 
 
1
  from typing import Dict, Any, Optional
2
  import os
3
  from dotenv import load_dotenv
4
  import time
5
 
6
- class GroqProvider:
 
 
 
 
 
 
 
 
 
7
  def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-r1-distill-llama-70b"):
8
  self.api_key = api_key
9
  self.model = model
@@ -49,13 +61,16 @@ class GroqProvider:
49
  }]
50
  }
51
 
52
- def create_provider(api_key: Optional[str] = None, model: Optional[str] = None) -> GroqProvider:
53
- if not api_key:
54
- load_dotenv()
55
- api_key = os.getenv("GROQ_API_KEY")
56
-
57
- if not api_key:
58
- raise ValueError("GROQ_API_KEY must be set in environment variables")
59
-
60
- default_model = "deepseek-ai/deepseek-r1-distill-llama-70b"
61
- return GroqProvider(api_key=api_key, model=model or default_model)
 
 
 
 
1
+ # ai_providers.py
2
+
3
+ from abc import ABC, abstractmethod
4
  from typing import Dict, Any, Optional
5
  import os
6
  from dotenv import load_dotenv
7
  import time
8
 
9
+ class AIProvider(ABC):
10
+ @abstractmethod
11
+ def get_completion(self, prompt: str, **kwargs) -> str:
12
+ pass
13
+
14
+ @abstractmethod
15
+ def get_config(self) -> Dict[str, Any]:
16
+ pass
17
+
18
+ class GroqProvider(AIProvider):
19
  def __init__(self, api_key: str, model: str = "deepseek-ai/deepseek-r1-distill-llama-70b"):
20
  self.api_key = api_key
21
  self.model = model
 
61
  }]
62
  }
63
 
64
+ class AIProviderFactory:
65
+ @staticmethod
66
+ def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
67
+ if not api_key:
68
+ load_dotenv()
69
+
70
+ if provider_type != "groq":
71
+ raise ValueError("Only Groq provider is supported.")
72
+
73
+ return GroqProvider(
74
+ api_key=api_key or os.getenv("GROQ_API_KEY"),
75
+ model=model or "deepseek-ai/deepseek-r1-distill-llama-70b"
76
+ )