# Spaces: Running  (page-status header captured along with the file; not part of the program)
import logging
import os
import time
from typing import Dict, Any, Optional

from dotenv import load_dotenv
class GroqProvider:
    """Small wrapper around the Groq chat-completions client with retries.

    Attributes:
        api_key: API key passed to the ``Groq`` client.
        model: Model identifier sent with every completion request.
        max_retries: Total number of attempts made per request (>= 1).
        retry_delay: Base delay in seconds; the wait before retry *k* is
            ``retry_delay * k`` (linear back-off).
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        api_key: str,
        model: str = "deepseek-ai/deepseek-r1-distill-llama-70b",
        max_retries: int = 3,
        retry_delay: float = 1,
    ):
        """Store the credentials, model name and retry policy.

        Raises:
            ValueError: If ``max_retries`` is below 1 or ``retry_delay`` is
                negative.
        """
        # With zero attempts the retry loop would never run and the caller
        # would silently receive ``None`` instead of a completion.
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        if retry_delay < 0:
            raise ValueError("retry_delay must not be negative")
        self.api_key = api_key
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Text placed in the one ``user`` message of the request.
            **kwargs: Optional sampling overrides. Only ``temperature``
                (default 0.7), ``max_tokens`` (default 4000), ``top_p``
                (default 1.0) and ``stop`` (default ``None``) are read;
                any other keyword is ignored.

        Returns:
            The ``message.content`` of the first choice in the response.

        Raises:
            Exception: Whatever the final attempt raised, once all
                ``max_retries`` attempts have failed.
        """
        # Imported lazily so this module can be imported without the
        # ``groq`` package installed.
        from groq import Groq

        client = Groq(api_key=self.api_key)

        # Configure default parameters
        params = {
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4000),
            "top_p": kwargs.get("top_p", 1.0),
            "stop": kwargs.get("stop", None),
        }

        def request():
            completion = client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                **params
            )
            return completion.choices[0].message.content

        return self._with_retries(request)

    def _with_retries(self, request):
        """Call the zero-argument ``request`` until it succeeds or attempts run out.

        Returns the first successful result. The exception from the last
        attempt is re-raised unchanged; earlier failures are logged as
        warnings and followed by a linearly growing sleep.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                return request()
            except Exception as exc:
                if attempt == self.max_retries:
                    # Bare raise keeps the original traceback intact.
                    raise
                self._logger.warning(
                    "Groq request failed (attempt %d/%d): %s; retrying",
                    attempt,
                    self.max_retries,
                    exc,
                )
                time.sleep(self.retry_delay * attempt)

    def get_config(self) -> Dict[str, Any]:
        """Return a ``config_list``-style dict describing this provider.

        The ``temperature`` and ``max_tokens`` entries are fixed values
        (0.7 and 4000); they are not tied to overrides passed to
        ``get_completion``.
        """
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": 0.7,
                "max_tokens": 4000,
            }]
        }
def create_provider(api_key: Optional[str] = None, model: Optional[str] = None) -> GroqProvider:
    """Build a ``GroqProvider``, falling back to the environment for the key.

    Args:
        api_key: Key to use. When missing or empty, ``load_dotenv()`` is
            called and the ``GROQ_API_KEY`` environment variable is read.
        model: Model identifier. When missing or empty, the default model
            of ``GroqProvider`` is used.

    Returns:
        A provider configured with the resolved key and model.

    Raises:
        ValueError: If no key was passed and ``GROQ_API_KEY`` is unset or
            empty.
    """
    if not api_key:
        load_dotenv()
        api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY must be set in environment variables")

    # Defer to GroqProvider's own default model rather than repeating the
    # literal here, so the two defaults cannot drift apart.
    if model:
        return GroqProvider(api_key=api_key, model=model)
    return GroqProvider(api_key=api_key)