File size: 1,912 Bytes
0d24772
 
8200031
0d24772
8200031
 
0d24772
8200031
 
 
 
 
 
0d24772
 
8200031
0d24772
8200031
 
 
0d24772
 
 
 
 
 
 
 
8200031
0d24772
 
 
 
 
 
 
 
8200031
0d24772
 
 
8200031
0d24772
 
 
8200031
 
 
 
0d24772
 
 
8200031
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# ai_providers.py

from typing import Dict, Any
import time
from groq import Groq
from config import GROQ_API_KEY, AI_CONFIG

class GroqProvider:
    """Thin wrapper around the Groq chat-completions client.

    The model name and default sampling parameters are read from
    ``AI_CONFIG``; the API key comes from ``GROQ_API_KEY`` (both imported
    from ``config``).
    """

    def __init__(self):
        """Create the Groq client.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is empty or unset.
        """
        # Fail fast at construction time rather than on the first request.
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY must be set in environment variables")
        self.client = Groq(api_key=GROQ_API_KEY)
        self.model = AI_CONFIG["model"]

    def get_completion(
        self,
        prompt: str,
        *,
        max_retries: int = 3,
        retry_delay: float = 1,
        **kwargs,
    ) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Content of the single ``user`` message sent to the model.
            max_retries: Total number of attempts; must be at least 1.
            retry_delay: Base back-off in seconds. After failed attempt ``k``
                (0-based) the call sleeps ``retry_delay * (k + 1)`` seconds,
                except after the final attempt.
            **kwargs: Optional per-call overrides for ``temperature``,
                ``max_tokens``, ``top_p`` and ``stop``; any other keys are
                ignored. Missing sampling values fall back to ``AI_CONFIG``.

        Returns:
            ``choices[0].message.content`` of the first successful response.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            Exception: When every attempt fails. The last underlying error is
                chained as ``__cause__`` so its type and traceback survive.
        """
        if max_retries < 1:
            # Without this guard the loop body never runs and the method
            # would silently return None.
            raise ValueError("max_retries must be at least 1")

        params = {
            "temperature": kwargs.get("temperature", AI_CONFIG["temperature"]),
            "max_tokens": kwargs.get("max_tokens", AI_CONFIG["max_tokens"]),
            "top_p": kwargs.get("top_p", AI_CONFIG["top_p"]),
            "stop": kwargs.get("stop"),
        }
        messages = [{"role": "user", "content": prompt}]

        for attempt in range(max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    **params,
                )
                # NOTE(review): the SDK may return None content (e.g. for
                # tool-call responses); confirm callers handle that.
                return completion.choices[0].message.content
            except Exception as e:
                # NOTE(review): every error is retried, including ones that
                # cannot succeed on retry (auth failures, bad requests).
                # Consider narrowing this to the SDK's transient error types.
                if attempt == max_retries - 1:
                    raise Exception(
                        f"Failed to get completion from Groq after {max_retries} attempts: {str(e)}"
                    ) from e
                # Linear back-off: 1x, 2x, 3x ... retry_delay.
                time.sleep(retry_delay * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        """Return the provider settings as a ``{"config_list": [...]}`` dict.

        The single entry holds the model name, the sampling defaults from
        ``AI_CONFIG`` and the raw ``GROQ_API_KEY``. Treat the result as a
        secret and do not log it.
        """
        return {
            "config_list": [{
                "model": self.model,
                "api_key": GROQ_API_KEY,
                "temperature": AI_CONFIG["temperature"],
                "max_tokens": AI_CONFIG["max_tokens"],
                "top_p": AI_CONFIG["top_p"]
            }]
        }

# Global instance shared by every module that imports it.
# NOTE: this is constructed at import time, so importing this module raises
# ValueError when GROQ_API_KEY is empty or unset (GroqProvider.__init__
# checks it before creating the client).
groq_provider = GroqProvider()