import gradio as gr
from llama_cpp import Llama
import torch
import os
from accelerate import Accelerator
import tensorflow as tf  # Import TensorFlow
import numpy as np  # For handling input data

# Set device for PyTorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device set to:", device)

# Initialize the accelerator
accelerator = Accelerator()


class LocalInferenceClient:
    def __init__(self, model_name: str, model_path: str):
        """
        Initialize the inference client with the model.

        Args:
            model_name (str): The name of the model.
            model_path (str): The path to the model file or directory.
        """
        self.model_name = model_name
        self.model_path = model_path

        # Initialize the Llama model specifically for gguf;
        # GPU offloading is controlled by n_gpu_layers
        self.model = Llama(model_path=model_path, n_ctx=2048, n_threads=8, n_gpu_layers=5)

        # accelerator.prepare() leaves non-PyTorch objects unchanged,
        # so this is effectively a no-op for the llama_cpp model
        self.model = accelerator.prepare(self.model)

        # Load the TensorFlow Lite model
        self.tflite_interpreter = tf.lite.Interpreter(model_path='model.tflite')
        self.tflite_interpreter.allocate_tensors()

        # Get input and output tensor details
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()

    def text_generation(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
        """
        Generate text based on the provided prompt.

        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of tokens to generate.
            temperature (float): Sampling temperature.
            top_p (float): Nucleus sampling probability.

        Returns:
            str: The generated text.
        """
        # Use the Llama model for chat-style text generation
        response = self.model.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p
        )

        # Print the response to understand its structure
        print("Response from model:", response)

        # Access the content based on the response structure
        if 'choices' in response and len(response['choices']) > 0:
            return response['choices'][0]['message']['content']  # Access the content key
        else:
            return "⚠️ Error: Unexpected response format."

    def run_tflite_model(self, input_data: np.ndarray) -> np.ndarray:
        """
        Run inference using the TensorFlow Lite model.

        Args:
            input_data (np.ndarray): Input data for the model.

        Returns:
            np.ndarray: Output data from the model.
        """
        # Set the input tensor
        self.tflite_interpreter.set_tensor(self.input_details[0]['index'], input_data)

        # Run the model
        self.tflite_interpreter.invoke()

        # Get the output tensor
        output_data = self.tflite_interpreter.get_tensor(self.output_details[0]['index'])
        return output_data
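

# Illustrative usage sketch (the helper below is not part of the app flow and its
# name is hypothetical): it shows how run_tflite_model could be exercised with a
# zero-filled dummy tensor shaped from the interpreter's reported input details.
# It assumes the same 'model.tflite' file that LocalInferenceClient.__init__ loads.
def tflite_smoke_test(client: LocalInferenceClient) -> np.ndarray:
    input_spec = client.input_details[0]
    # Build dummy input matching the model's expected shape and dtype
    dummy_input = np.zeros(input_spec['shape'], dtype=input_spec['dtype'])
    output = client.run_tflite_model(dummy_input)
    print("TFLite output shape:", output.shape)
    return output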


# Specify the model paths for gguf models
model_configs = {
    "Test": {
        "path": r"./test-model.gguf",
        "specs": """
## Lake 1 Chat Specifications
- **Architecture**: Test
- **Parameters**: IDK
- **Capabilities**: test
- **Intended Use**: test
"""
    }
}

# Set up a dictionary mapping model names to their clients
clients = {name: LocalInferenceClient(name, config['path']) for name, config in model_configs.items()}

# Presets for performance/quality tradeoffs
presets = {
    "Test": {
        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
    }
}

# A system prompt for each model
system_messages = {
    "Test": "You are Lake 1 Chat, a powerful open-source reasoning model. Think carefully and answer step by step.",
}


def generate_response(message: str, model_name: str, preset: str) -> str:
    """
    Generate a response based on the user's message.

    Args:
        message (str): The user's message.
        model_name (str): The name of the model to use.
        preset (str): The performance preset to apply.

    Returns:
        str: The generated response.
    """
    client = clients[model_name]
    params = presets[model_name][preset]
    system_msg = system_messages[model_name]
    prompt = f"{system_msg}\n\nUser: {message}\nAssistant:"
    return client.text_generation(
        prompt,
        max_new_tokens=params["max_new_tokens"],
        temperature=params["temperature"],
        top_p=params["top_p"]
    )


def handle_chat(message: str, history: list, model: str, preset: str) -> str:
    """
    Handle the chat interaction.

    Args:
        message (str): The user's message.
        history (list): The conversation history.
        model (str): The model to use.
        preset (str): The performance preset.

    Returns:
        str: The generated response.
    """
    try:
        return generate_response(message, model, preset)
    except Exception as e:
        return f"⚠️ Error: {str(e)}"


with gr.Blocks(title="BI CORP AI Assistant", theme="soft") as demo:
    gr.Markdown("#