"""Environment-driven configuration for synthetic data generation.

Reads all settings from environment variables at import time, validates
that the model / base-URL combination is coherent, and (when credentials
are available) builds a shared Argilla client.

Module-level side effects:
    * Raises ``ValueError`` on an invalid model/base-URL configuration or
      a missing ``HF_TOKEN``.
    * Emits a ``UserWarning`` when SFT (chat) generation is unavailable or
      when the Argilla connection is not configured.
"""

import os
import warnings

import argilla as rg

# Inference limits — hard caps on generation size, overridable via env.
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))

# Directory to locally save the generated data (None disables local saving).
SAVE_LOCAL_DIR = os.getenv("SAVE_LOCAL_DIR")

# Models. Exactly one inference backend base URL may be set at a time;
# validate_configuration() below enforces this.
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")

# Completion-specific overrides. Each one falls back to the generic
# setting above, so they only matter when a different model/backend is
# wanted for completions.
MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
)
VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)

# NOTE: validate_configuration() indexes position 2 as the Hugging Face
# URL, so the ordering of these lists is significant.
base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
base_urls_completion = [
    OPENAI_BASE_URL_COMPLETION,
    OLLAMA_BASE_URL_COMPLETION,
    HUGGINGFACE_BASE_URL_COMPLETION,
    VLLM_BASE_URL_COMPLETION,
]


def validate_configuration(base_urls, model, env_context=""):
    """Validate that the model and base-URL environment settings are coherent.

    Args:
        base_urls: The four backend base URLs, ordered
            ``[openai, ollama, huggingface, vllm]`` (huggingface must be
            at index 2).
        model: The configured model id (may be ``None``/empty).
        env_context: Suffix used in error messages to distinguish the
            completion-specific variables (e.g. ``"_COMPLETION"``).

    Raises:
        ValueError: If a Hugging Face base URL and a model id are both set,
            if a base URL is set without a model id, or if more than one
            base URL is set.
    """
    huggingface_url = base_urls[2]
    if huggingface_url and model:
        raise ValueError(
            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
        )
    if not model and any(base_urls):
        raise ValueError(
            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
        )
    active_urls = [url for url in base_urls if url]
    if len(active_urls) > 1:
        raise ValueError(
            f"Multiple base URLs are provided: {', '.join(active_urls)}. "
            "Only one base URL can be set at a time."
        )


validate_configuration(base_urls, MODEL)
validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")

# After validation at most one URL in each group is set, so `or`-chaining
# picks the single active backend (or None).
BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
BASE_URL_COMPLETION = (
    OPENAI_BASE_URL_COMPLETION
    or OLLAMA_BASE_URL_COMPLETION
    or HUGGINGFACE_BASE_URL_COMPLETION
    or VLLM_BASE_URL_COMPLETION
)

# API Keys — HF_TOKEN is mandatory; fail fast at import time if missing.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )

# A single explicit API_KEY takes precedence; otherwise pool HF_TOKEN with
# any extra HF_TOKEN_1..HF_TOKEN_9 tokens (unset ones are filtered out).
_API_KEY = os.getenv("API_KEY")
API_KEYS = (
    [_API_KEY]
    if _API_KEY
    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
)
API_KEYS = [token for token in API_KEYS if token]

# Determine if SFT (chat data generation via Magpie) is available.
SFT_AVAILABLE = False
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]

if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    # An explicit template wins; normalize known aliases to canonical names.
    SFT_AVAILABLE = True
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
elif any(option in MODEL.lower() for option in llama_options):
    # Substring match on the model id (e.g. "meta-llama/Meta-Llama-3.1-...").
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif any(option in MODEL.lower() for option in qwen_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

# Magpie-style generation is not supported through the OpenAI API.
if OPENAI_BASE_URL:
    SFT_AVAILABLE = False

if not SFT_AVAILABLE:
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vllm."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None

# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

# Argilla — the *_SDG_REVIEWER variables act as fallbacks for deployments
# that use reviewer-specific credentials.
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)

if not ARGILLA_API_URL or not ARGILLA_API_KEY:
    # Missing Argilla credentials are non-fatal: the client is simply absent.
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )