# LegalAlly / src/embeddings.py
# Last commit: 7a7b50b ("New structure") by Rohil Bansal
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
import os
import openai
from src.logger import setup_logger
# Module-level logger for this file, created by the project's setup_logger helper.
logger = setup_logger(__name__)
def get_embeddings(
    key,
    *,
    openai_model="text-embedding-ada-002",
    hf_model="sentence-transformers/all-MiniLM-L6-v2",
):
    """Return an embeddings backend appropriate for *key*.

    If test_openai_key accepts *key*, an OpenAIEmbeddings instance for
    *openai_model* is returned. Otherwise a HuggingFaceEmbeddings instance
    for *hf_model* is returned. The chosen backend is logged.

    Args:
        key: OpenAI API key to try first.
        openai_model: Embedding model name passed to OpenAIEmbeddings.
        hf_model: Model name passed to HuggingFaceEmbeddings for the fallback.

    Returns:
        An OpenAIEmbeddings or HuggingFaceEmbeddings instance.
    """
    if test_openai_key(key):
        logger.info("Using OpenAI embeddings")
        return OpenAIEmbeddings(model=openai_model, api_key=key)
    # Log the model that is actually loaded. The fallback is a
    # sentence-transformers model, not the Mistral model get_model picks.
    logger.info(f"Using HuggingFace embeddings ({hf_model})")
    return HuggingFaceEmbeddings(model_name=hf_model)
def test_openai_key(key):
    """Check whether *key* can be used for OpenAI API calls.

    Two live requests are made with the key:

    1. list the available models, which fails for an invalid key;
    2. create one small embedding with ``text-embedding-ada-002`` (the model
       get_embeddings uses), which fails when the account has no quota.

    Works with both the legacy ``openai<1.0`` module-level API and the
    ``openai>=1.0`` client API.

    Args:
        key: The OpenAI API key to test. As a side effect it is also
            assigned to ``openai.api_key``.

    Returns:
        True if both requests succeed, False otherwise. Request failures
        are logged and reported as False rather than raised.
    """
    logger.info("Testing OpenAI API key")
    # openai>=1.0 exposes a client class and top-level exception types;
    # openai<1.0 exposes module-level resources and the openai.error module.
    modern_sdk = hasattr(openai, "OpenAI")
    # Resolve the exception classes up front: an attribute lookup that fails
    # inside an `except` clause would raise out of the function instead of
    # being handled.
    if modern_sdk:
        key_errors = (openai.AuthenticationError, openai.RateLimitError)
    else:
        key_errors = (openai.error.AuthenticationError, openai.error.RateLimitError)
    try:
        openai.api_key = key
        # The previous credit probe completed against text-davinci-002. That
        # model is retired, so the probe failed even for good keys. Probe with
        # the embeddings model this module actually uses.
        if modern_sdk:
            client = openai.OpenAI(api_key=key)
            client.models.list()
            client.embeddings.create(
                model="text-embedding-ada-002", input="This is a test."
            )
        else:
            openai.Model.list()
            openai.Embedding.create(
                model="text-embedding-ada-002", input="This is a test."
            )
        logger.info("OpenAI API key is valid and has available credits")
        return True
    except key_errors:
        logger.error("OpenAI API key is invalid or has no available credits")
        return False
    except Exception as e:
        logger.error(f"An error occurred while testing the OpenAI API key: {str(e)}")
        return False
def get_model(key):
    """Return the name of the language model to use with *key*.

    Returns "gpt-4o-mini" when test_openai_key accepts the key, and
    "mistralai/Mistral-7B-v0.1" otherwise. The choice is logged.
    """
    openai_usable = test_openai_key(key)
    if not openai_usable:
        logger.info("Using Mistral model")
        return "mistralai/Mistral-7B-v0.1"
    logger.info("Using OpenAI model")
    return "gpt-4o-mini"