Spaces:
Sleeping
Sleeping
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings | |
import os | |
import openai | |
from src.logger import setup_logger | |
# Module-level logger obtained from the project's src.logger.setup_logger helper,
# named after this module; used by every function below.
logger = setup_logger(__name__)
def get_embeddings(key):
    """Return an embeddings backend, preferring OpenAI when *key* is usable.

    The key is probed with ``test_openai_key``, which issues live OpenAI API
    requests. If the probe succeeds, OpenAI ``text-embedding-ada-002``
    embeddings authenticated with *key* are returned. Otherwise the function
    falls back to the HuggingFace ``sentence-transformers/all-MiniLM-L6-v2``
    model.

    NOTE(review): the two backends produce vectors of different sizes, so an
    index built with one cannot be queried with the other. Confirm callers do
    not persist indexes across a backend switch.

    Args:
        key: OpenAI API key to try.

    Returns:
        An ``OpenAIEmbeddings`` or ``HuggingFaceEmbeddings`` instance.
    """
    if test_openai_key(key):
        logger.info("Using OpenAI embeddings")
        # ``openai_api_key`` is the canonical LangChain field name. Older
        # releases do not recognise ``api_key`` as a constructor field.
        return OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=key)

    fallback_model = "sentence-transformers/all-MiniLM-L6-v2"
    # The fallback is a sentence-transformers model, not a Mistral one.
    logger.info("Using HuggingFace embeddings (%s)", fallback_model)
    return HuggingFaceEmbeddings(model_name=fallback_model)
def test_openai_key(key, model="gpt-4o-mini"):
    """Check whether *key* is a working OpenAI API key with available quota.

    Side effect: for a non-empty key, assigns the module-global
    ``openai.api_key = key``, so later calls through the ``openai`` module use it.

    Two live requests are made:
      1. a model listing, which rejects unknown or revoked keys;
      2. a one-token chat completion against *model*, which surfaces
         quota/billing errors that the listing alone does not.

    NOTE(review): this uses the pre-1.0 ``openai`` API surface
    (``openai.Model``, ``openai.ChatCompletion``, ``openai.error``), which
    the 1.x SDK removed. Confirm the pinned ``openai`` version.

    Args:
        key: OpenAI API key to test. An empty or None key returns False
            immediately, without any network request.
        model: Chat model used for the quota probe. The default is the model
            this module uses when OpenAI is available.

    Returns:
        True if both requests succeed, otherwise False. Errors are logged,
        never raised.
    """
    if not key:
        # No point hitting the network, and don't clobber a configured key.
        logger.error("No OpenAI API key provided")
        return False
    try:
        logger.info("Testing OpenAI API key")
        openai.api_key = key
        # Check that the key is recognised.
        openai.Model.list()
        # Check for available credits. The previous probe targeted
        # text-davinci-002, which OpenAI retired in Jan 2024, so it failed
        # even for valid keys. Probe the chat model that is actually used.
        openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": "This is a test."}],
            max_tokens=1,
        )
    except (openai.error.AuthenticationError, openai.error.RateLimitError):
        logger.error("OpenAI API key is invalid or has no available credits")
        return False
    except Exception as e:
        logger.error("An error occurred while testing the OpenAI API key: %s", e)
        return False
    else:
        logger.info("OpenAI API key is valid and has available credits")
        return True
def get_model(key):
    """Choose the LLM identifier to use for the given OpenAI key.

    Delegates to ``test_openai_key``, so every call re-probes the key with
    live OpenAI requests and may set ``openai.api_key`` as a side effect.

    Args:
        key: OpenAI API key to try.

    Returns:
        ``"gpt-4o-mini"`` when the key passes the probe, otherwise the
        HuggingFace identifier ``"mistralai/Mistral-7B-v0.1"``.
    """
    openai_ok = test_openai_key(key)
    logger.info("Using OpenAI model" if openai_ok else "Using Mistral model")
    return "gpt-4o-mini" if openai_ok else "mistralai/Mistral-7B-v0.1"