# Copyright (c) 2025, Anemll All rights reserved.
#
# Use of this source code is governed by a MIT license that can be
# found in the LICENSE.txt file or at https://opensource.org/license/mit

import coremltools as ct
import numpy as np
import torch
from transformers import AutoTokenizer
import os
import time, sys
import signal
import traceback
import torch.nn.functional as F
import queue
import threading
import re

# Configuration
CONTEXT_LENGTH = 1024  # Changed default from 512 to 1024
PREFILL_BATCH_SIZE = 64
MODEL_PATH = os.path.expanduser("../DeepSeekR1-8B")

ENABLE_VACAB_SPLIT8 = True   # Enable 8-way vocab split
ENABLE_LOGITS2 = False       # Enable 2-way vocab split
ENABLE_DEBUG = bool(0)
ENABLE_ARGMAX = bool(0)
ENABLE_PREFILL_BATCH = bool(1)
ENABLE_CHAT_DEBUG = bool(0)  # Debug flag for chat loop

# ANSI color codes
LIGHT_BLUE = "\033[94m"
DARK_BLUE = "\033[34m"
LIGHT_GREEN = "\033[92m"
RESET_COLOR = "\033[0m"

if ENABLE_LOGITS2:
    assert not ENABLE_ARGMAX, "ENABLE_ARGMAX must be False when ENABLE_LOGITS2 is True"


def load_model(path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name=None):
    """Load either compiled or uncompiled CoreML model.

    Args:
        path: Path to the model file (.mlmodelc or .mlpackage)
        compute_unit: CoreML compute unit to use
        function_name: Optional function name to select from multi-function models
    """
    DebugLog(f"Attempting to load model: {path}")
    DebugLog(f"File exists: {os.path.exists(path)}")
    DebugLog(f"Is directory (for mlmodelc): {os.path.isdir(path)}")

    try:
        if path.endswith('.mlmodelc'):
            DebugLog(f"Loading compiled model: {path}")
            if function_name is None:
                DebugLog("Loading without function name")
                model = ct.models.CompiledMLModel(path, compute_unit)
            else:
                DebugLog(f"Loading with function name: {function_name}")
                model = ct.models.CompiledMLModel(path, compute_unit, function_name=function_name)
        else:
            DebugLog(f"Loading uncompiled model: {path}")
            if function_name is None:
                DebugLog("Loading without function name")
                model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False)
            else:
                DebugLog(f"Loading with function name: {function_name}")
                model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False,
                                          function_name=function_name)
        DebugLog("Model loaded successfully")
        return model
    except Exception as e:
        DebugLog(f"Error loading model: {str(e)}")
        DebugLog(f"Error type: {type(e)}")
        raise


class SplitModelInference:
    def __init__(self, model_parts, model_dir="."):
        """Initialize split model inference.
Args: model_parts (list): List of model part numbers to load Special cases: - 'C123' for combined part2 with prefill/infer functions - 'S123' for split model with prefill/infer functions - 'Q123' for quad split (2Q1-2Q4) - 'Q123S' for quad split with combined prefill/infer (2Q1S-2Q4S) - '123D' for dual split without prefill/infer (2D1-2D2) model_dir (str): Directory containing the model files (default: current directory) """ self.context_size = CONTEXT_LENGTH self.model_dir = model_dir DebugLog(f"Loading models from directory: {self.model_dir}") # Parse configuration self.quant_configs = {} global_lut = None if model_parts and model_parts[-1].startswith('lut'): global_lut = model_parts[-1] model_parts = model_parts[:-1] # Special handling for different split modes if len(model_parts) == 1: if model_parts[0] == '123D': # Dual split without prefill/infer self.use_combined_part2 = False self.use_split_model = True self.use_split_functions = False self.use_quad_split = False self.use_quad_split_combined = False self.model_parts = ['1', '2D1', '2D2', '3'] if global_lut: self.quant_configs = {part: global_lut for part in self.model_parts} DebugLog(f"Using dual split model with parts: {self.model_parts}") elif model_parts[0].startswith('C123'): # Combined part2 self.use_combined_part2 = True self.use_split_model = False self.use_split_functions = False self.use_quad_split = False self.use_quad_split_combined = False self.model_parts = ['1', '2', '3'] if global_lut: self.quant_configs = {part: global_lut for part in self.model_parts} DebugLog(f"Using combined part2 model with parts: {self.model_parts}") elif model_parts[0].startswith('S123'): # Split model with prefill/infer functions self.use_combined_part2 = False self.use_split_model = True self.use_split_functions = True self.use_quad_split = False self.use_quad_split_combined = False self.model_parts = ['1', '2D1S', '2D2S', '3'] elif model_parts[0].startswith('Q123S'): # Quad split with combined prefill/infer self.use_combined_part2 = False self.use_split_model = True self.use_split_functions = False self.use_quad_split = False self.use_quad_split_combined = True self.model_parts = ['1', '2Q1S', '2Q2S', '2Q3S', '2Q4S', '3'] elif model_parts[0].startswith('Q123'): # Regular quad split self.use_combined_part2 = False self.use_split_model = True self.use_split_functions = False self.use_quad_split = True self.use_quad_split_combined = False self.model_parts = ['1', '2Q1', '2Q2', '2Q3', '2Q4', '3'] else: self.use_combined_part2 = False self.use_split_model = False self.use_split_functions = False self.use_quad_split = False self.use_quad_split_combined = False self.model_parts = model_parts else: self.use_combined_part2 = False self.use_split_model = False self.use_split_functions = False self.use_quad_split = False self.use_quad_split_combined = False self.model_parts = model_parts # Apply global quantization if specified if global_lut and not self.use_combined_part2: # Skip if already applied for C123 self.quant_configs = {part: global_lut for part in self.model_parts} DebugLog(f"Using model parts: {self.model_parts}") if global_lut: DebugLog(f"With global quantization: {global_lut}") if self.use_combined_part2: DebugLog("Using combined part2 model with prefill/infer functions") elif self.use_split_functions: DebugLog("Using split model with prefill/infer functions") elif self.use_quad_split: DebugLog("Using quad split transformer model (2Q1-2Q4)") elif self.use_quad_split_combined: DebugLog("Using combined quad split transformer model 
(2Q1S-2Q4S)") self.models = {} self.states = {} self.load_models() def find_model_path(self, base_name, description="model"): """Find model path, checking mlmodelc first then mlpackage. Also tries both with and without lut suffix. Args: base_name: Base name of the model without extension description: Description for error message (e.g., "Split model part 2D1S") Returns: str: Path to the found model file Raises: FileNotFoundError: If neither mlmodelc nor mlpackage exists """ # For quad split parts, only try mlmodelc if any(part in base_name for part in ['2Q1S', '2Q2S', '2Q3S', '2Q4S', '2Q1', '2Q2', '2Q3', '2Q4']): model_path = os.path.join(self.model_dir, f"{base_name}.mlmodelc") if os.path.exists(model_path): return model_path # If not found, try without lut suffix if '_lut' in base_name: base_without_lut = base_name.split('_lut')[0] model_path = os.path.join(self.model_dir, f"{base_without_lut}.mlmodelc") if os.path.exists(model_path): return model_path # Neither exists raise FileNotFoundError(f"{description} not found: {base_name}.mlmodelc does not exist" + (f" (also tried {base_name.split('_lut')[0]}.mlmodelc)" if '_lut' in base_name else "")) # For other parts, try both mlmodelc and mlpackage for ext in ['.mlmodelc', '.mlpackage']: model_path = os.path.join(self.model_dir, f"{base_name}{ext}") if os.path.exists(model_path): return model_path # If not found, try without lut suffix if '_lut' in base_name: base_without_lut = base_name.split('_lut')[0] for ext in ['.mlmodelc', '.mlpackage']: model_path = os.path.join(self.model_dir, f"{base_without_lut}{ext}") if os.path.exists(model_path): return model_path # Neither exists raise FileNotFoundError(f"{description} not found: neither {base_name}.mlmodelc nor {base_name}.mlpackage exist in {self.model_dir}" + (f" (also tried {base_name.split('_lut')[0]}.mlmodelc/mlpackage)" if '_lut' in base_name else "")) def load_models(self): """Load each model part.""" DebugLog("Loading model parts...") for part in self.model_parts: quant_suffix = f"_{self.quant_configs[part]}" if part in self.quant_configs else "" model_key = f"{part}{quant_suffix}" # Use this as the key in self.models try: if part == '2' and self.use_combined_part2: # Load combined part2 with multiple functions base_name = f"llama32_part2_combined{quant_suffix}" model_path = self.find_model_path(base_name, "Combined part2 model") DebugLog(f"Loading combined part2 model: {model_path}") # Load prefill function self.models['2_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') # Load infer function self.models['2_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') # Create shared state self.states['transformer'] = self.models['2_prefill'].make_state() DebugLog("Combined part2 model loaded successfully") elif part == '2' and not self.use_combined_part2: # Load regular part2 model base_name = f"llama32_part2{quant_suffix}" model_path = self.find_model_path(base_name, "Regular part2 model") DebugLog(f"Loading regular part2 model: {model_path}") self.models[model_key] = load_model(model_path) self.states['transformer'] = self.models[model_key].make_state() DebugLog("Regular part2 model loaded successfully") elif part in ['2D1S', '2D2S'] and self.use_split_functions: # Load split model with prefill/infer functions base_name = f"llama32_part{part}{quant_suffix}" model_path = self.find_model_path(base_name, f"Split model part {part}") DebugLog(f"Loading split model part {part}: {model_path}") # Load 
prefill function self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') # Load infer function self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') # Create shared state for first part only if part == '2D1S': self.states['transformer'] = self.models[f'{part}_infer'].make_state() DebugLog(f"Split model part {part} loaded successfully") elif part.endswith('S') and self.use_quad_split_combined: # Load combined quad split model with prefill/infer functions base_name = f"llama32_part{part}{quant_suffix}" model_path = self.find_model_path(base_name, f"Combined quad split part {part}") DebugLog(f"Loading combined quad split part {part}: {model_path}") # Load prefill function self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') # Load infer function self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') # Create shared state for first part only if part == '2Q1S': self.states['transformer'] = self.models[f'{part}_infer'].make_state() DebugLog(f"Created shared transformer state for all quad split parts") DebugLog(f"Combined quad split part {part} loaded successfully") elif part.startswith('2Q') and self.use_quad_split: # Load quad split model with prefill/infer functions # Append 'S' to part name for file lookup base_name = f"llama32_part{part}S{quant_suffix}" model_path = self.find_model_path(base_name, f"Quad split part {part}") DebugLog(f"Loading quad split part {part}: {model_path}") # Load prefill function self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') # Load infer function self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') # Create shared state for first part only if part == '2Q1': self.states['transformer'] = self.models[f'{part}_infer'].make_state() DebugLog(f"Created shared transformer state for all quad split parts") print(f"Created shared transformer state for all quad split parts") print(f"Quad split part {part} loaded successfully") else: # Load regular models (part 1 and part3) base_name = f"llama32_part{part}{quant_suffix}" model_path = self.find_model_path(base_name, f"Regular part {part}") print(f"[MODEL LOAD] Regular part {part}:") print(f" - File: {model_path}") print(f" - Loading as: '{model_key}'") # Try loading with CPU first, then fall back to CPU_AND_NE if needed try: self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE) print(f" - Loaded with CPU_AND_NE compute unit") except Exception as cpu_error: print(f" - CPU load failed, trying CPU_AND_NE: {str(cpu_error)}") self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU) print(f" - Loaded with CPU compute unit") print(f"[MODEL LOAD] Current model_parts keys: {list(self.models.keys())}") except Exception as e: print(f"Error loading model part {part}: {str(e)}") raise def run_transformer_prefill(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): """Run the transformer model in prefill mode.""" if self.use_split_functions: # Use prefill variants for split model for part in ['2D1S', '2D2S']: inputs = { 'hidden_states': hidden_states.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': causal_mask.numpy(), 'start_pos': current_pos.numpy() } output = 
self.models[f'{part}_prefill'].predict(inputs, self.states['transformer']) hidden_states = torch.from_numpy(output['dummy_output']) return hidden_states else: # Use existing prefill implementation return super().run_transformer_prefill(hidden_states, update_mask, position_ids, causal_mask, current_pos) def run_transformer_infer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): """Run the transformer model in infer mode.""" if self.use_split_functions: # Use infer variants for split model for part in ['2D1S', '2D2S']: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': causal_mask.numpy(), 'current_pos': current_pos.numpy() } output = self.models[f'{part}_infer'].predict(inputs, self.states['transformer']) hidden_states = torch.from_numpy(output['transformer_output']) return hidden_states else: # Use existing infer implementation return super().run_transformer_infer(hidden_states, update_mask, position_ids, causal_mask, current_pos) def get_state(self, part): """Get the appropriate state for a model part.""" return self.states['transformer'] def run_embeddings(self, input_ids): """Run the embeddings model (part 1).""" if '1' not in self.models: raise ValueError("Embeddings model (part 1) not loaded") output_dict = self.models['1'].predict({ 'input_ids': input_ids.numpy() }) return torch.from_numpy(output_dict['hidden_states']) def run_transformer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos, part='2'): """Run the transformer model.""" if part not in self.models: raise ValueError(f"Transformer model (part {part}) not loaded") inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': causal_mask.numpy(), 'current_pos': current_pos.numpy() } output_dict = self.models[part].predict(inputs, self.get_state(part)) return torch.from_numpy(output_dict['transformer_output']) def run_transformer_splits(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): """Run through transformer splits based on model configuration.""" if not self.use_split_model: return self.run_transformer(hidden_states, update_mask, position_ids, causal_mask, current_pos) # Handle different split configurations if any(part.startswith('2Q') for part in self.model_parts): # Quad split for i in range(1, 5): part = f'2Q{i}' hidden_states = self.run_transformer( hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part ) elif any(part.startswith('2O') for part in self.model_parts): # Octa split for i in range(1, 9): part = f'2O{i}' hidden_states = self.run_transformer( hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part ) elif any(part.startswith('2D') for part in self.model_parts): # Dual split # Run through both parts of the dual split for base_part in ['2D1', '2D2']: # Find the correct model key (with lut suffix if present) part_key = next(key for key in self.models.keys() if key.startswith(f'{base_part}_') or key == base_part) # Use the shared transformer state if 'transformer' not in self.states: raise ValueError("Transformer state not initialized. 
Make sure 2D1 is loaded first.") inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': causal_mask.numpy(), 'current_pos': current_pos.numpy() } output_dict = self.models[part_key].predict(inputs, self.states['transformer']) hidden_states = torch.from_numpy(output_dict['transformer_output']) return hidden_states def run_lm_head(self, hidden_states): """Run the LM head model (part 3).""" if '3' not in self.models: raise ValueError("LM head model (part 3) not loaded") output_dict = self.models['3'].predict({ 'hidden_states': hidden_states.numpy() }) # Handle split logits logits_parts = [] for i in range(1, 9): # logits1 through logits8 logits_key = f'logits{i}' if logits_key in output_dict: logits_part = torch.from_numpy(output_dict[logits_key]) logits_parts.append(logits_part) # Concatenate along the vocabulary dimension return torch.cat(logits_parts, dim=-1) def run_full_model(self, input_ids, update_mask, position_ids, causal_mask, current_pos): """Run the full model.""" if 'full' not in self.models: raise ValueError("Full model not loaded") # Update context size from global self.context_size = CONTEXT_LENGTH #kv_ was removed from the input names inputs = { 'input_ids': input_ids.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': causal_mask.numpy(), 'current_pos': current_pos.numpy() } # Print shapes of all inputs if False: print("[DEBUG] Input shapes:") for key, value in inputs.items(): print(f" {key}: {value.shape}") output_dict = self.models['full'].predict(inputs, self.states['transformer']) # Handle split logits if necessary if ENABLE_VACAB_SPLIT8: logits_parts = [] for i in range(1, 9): logits_parts.append(output_dict[f'logits{i}']) logits = np.concatenate(logits_parts, axis=-1) else: logits = output_dict['logits'] return torch.from_numpy(logits) def make_causal_mask(length, start): # Initialize the mask with -inf mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16) # Create row and column indices row_indices = np.arange(length).reshape(length, 1) # Column vector col_indices = np.arange(length).reshape(1, length) # Row vector # Set allowed positions to 0 where col_index is within the allowed range of row_index mask[:, :, col_indices <= (row_indices + start)] = 0 return mask def initialize_tokenizer(model_path): """Initialize and configure the tokenizer.""" try: print(f"[DEBUG] Loading tokenizer from model path: {model_path}") tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) print("\n[DEBUG] Tokenizer Configuration:") print(f"Tokenizer type: {type(tokenizer)}") print(f"Tokenizer name: {tokenizer.__class__.__name__}") print(f"Vocabulary size: {len(tokenizer)}") print(f"Model max length: {tokenizer.model_max_length}") #print(f"Chat template: {tokenizer.chat_template if hasattr(tokenizer, 'chat_template') else 'None'}") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id print("[DEBUG] Set PAD token to EOS token") print(f"\n[DEBUG] Special Tokens:") print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})") print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})") print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})") print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})") return tokenizer except Exception as e: print(f"[ERROR] Failed to load tokenizer from {model_path}") return 
None


class TokenPrinter:
    """Handles background printing of generated tokens."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.thread = None
        self.buffer = ""
        self.lock = threading.Lock()
        self.thinking = True  # Track if we're still in thinking mode
        self.decoding_buffer = []  # <-- Buffer for token IDs
        self.start()

    def start(self):
        """Start the printer thread."""
        if self.thread is None:
            self.thread = threading.Thread(target=self._print_worker)
            self.thread.daemon = True
            self.thread.start()

    def add_token(self, token_id):
        """Add a token to the print queue."""
        if not self.stop_event.is_set():
            self.token_queue.put(token_id)

    def drain_buffer(self):
        """
        Decode token IDs from self.decoding_buffer in the main thread,
        then print them with the correct color logic.
        """
        if not self.decoding_buffer:
            return

        # Decode all tokens at once in the main thread.
        token_str = self.tokenizer.decode(self.decoding_buffer)
        self.decoding_buffer.clear()

        # Color-handling logic. Check for "</think>" and handle self.thinking.
        if self.thinking and "</think>" in token_str:
            self.thinking = False
            parts = token_str.split("</think>")
            if len(parts) > 0:
                print(parts[0] + "</think>", end='', flush=True)
            if len(parts) > 1:
                print(LIGHT_BLUE + parts[1], end='', flush=True)
        else:
            if not self.thinking:
                print(LIGHT_BLUE + token_str, end='', flush=True)
            else:
                print(token_str, end='', flush=True)

    def _print_worker(self):
        """Worker thread that takes token_ids from the queue but doesn't decode."""
        while not self.stop_event.is_set():
            try:
                token_id = self.token_queue.get(timeout=0.01)
                with self.lock:
                    # Just store the token_id, decode later on the main thread
                    self.decoding_buffer.append(token_id)
                self.token_queue.task_done()
            except queue.Empty:
                continue
            except Exception as e:
                print(f"\n[ERROR] Token printer error: {str(e)}")
                break

    def stop(self):
        """Stop the printer thread."""
        if self.thread and self.thread.is_alive():
            self.stop_event.set()
            try:
                self.thread.join(timeout=1.0)
            except Exception:
                pass
            print(RESET_COLOR)  # Reset color at the end
        return self.buffer


def parse_coreml_error(error_str):
    """Parse CoreML error message to extract shape information.

    Args:
        error_str: The error message string from CoreML

    Returns:
        tuple: (got_shape, expected_shape) or None if parsing fails
    """
    try:
        # Extract shapes from error message using regex
        pattern = r"shape \(([\d\s x]+)\) does not match the shape \(([\d\s x]+)\)"
        match = re.search(pattern, str(error_str))
        if match:
            got_shape = tuple(int(x) for x in match.group(1).split('x'))
            expected_shape = tuple(int(x) for x in match.group(2).split('x'))
            return got_shape, expected_shape
        return None
    except Exception as e:
        print(f"Error parsing CoreML error message: {e}")
        return None


def handle_coreml_shape_error(e, model_name=""):
    """Handle CoreML shape mismatch errors with detailed information.
Args: e: The exception object model_name: Name of the model for better error reporting """ error_str = str(e) if "MultiArray shape" in error_str: shape_info = parse_coreml_error(error_str) if shape_info: got_shape, expected_shape = shape_info print(f"\n[ERROR] Shape mismatch in {model_name}:") print(f" Got shape: {' x '.join(str(x) for x in got_shape)}") print(f" Expected shape: {' x '.join(str(x) for x in expected_shape)}") print("This usually indicates a mismatch between the model's expected context length") print("and the actual input being provided.") else: print(f"\n[ERROR] Shape mismatch error in {model_name}:") print(f" {error_str}") else: print(f"\n[ERROR] CoreML error in {model_name}:") print(f" {error_str}") def PreFillChunk(model_parts, input_ids, current_pos, context_size, causal_mask, batch_size=64): tokens_to_process = current_pos batch_pos = 0 while batch_pos < tokens_to_process: batch_end = min(batch_pos + batch_size, tokens_to_process) current_batch_size = batch_end - batch_pos try: # Get current batch of tokens batch_input = input_ids[:, batch_pos:batch_end] # Pad if needed if current_batch_size < batch_size: batch_input = F.pad( batch_input, (0, batch_size - current_batch_size), value=0 ) # Generate position IDs for this batch position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32) # Prepare causal mask for this batch multiple_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :] # Find the correct model key for part 1 (with lut suffix if present) part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') try: # Run embeddings (part 1) hidden_states = model_parts[part1_key].predict({'input_ids': batch_input.numpy()})['hidden_states'] hidden_states = torch.from_numpy(hidden_states) except Exception as e: handle_coreml_shape_error(e, f"embeddings model (part {part1_key})") raise # Get shared transformer state shared_state = model_parts['states']['transformer'] # Handle different model configurations if any(f'{part}_prefill' in model_parts for part in ['2D1S', '2D2S']): # S123 mode with prefill/infer functions for part in ['2D1S', '2D2S']: try: inputs = { 'hidden_states': hidden_states.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': multiple_causal_mask.numpy(), 'start_pos': np.array([batch_pos], dtype=np.int32) } output = model_parts[f'{part}_prefill'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['dummy_output']) except Exception as e: handle_coreml_shape_error(e, f"transformer model (part {part})") raise elif any(part.endswith('S') for part in model_parts if part.startswith('2Q')): # Q123S mode with combined quad split for i in range(1, 5): part = f'2Q{i}S' try: inputs = { 'hidden_states': hidden_states.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': multiple_causal_mask.numpy(), 'start_pos': np.array([batch_pos], dtype=np.int32) } output = model_parts[f'{part}_prefill'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['dummy_output']) except Exception as e: handle_coreml_shape_error(e, f"transformer model (part {part})") raise elif any(part.startswith('2Q') for part in model_parts): # Q123 mode with quad split for i in range(1, 5): part = f'2Q{i}' if f'{part}_prefill' in model_parts: # Use prefill function if available try: inputs = { 'hidden_states': hidden_states.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': multiple_causal_mask.numpy(), 'start_pos': np.array([batch_pos], dtype=np.int32) } output = 
model_parts[f'{part}_prefill'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['dummy_output']) except Exception as e: handle_coreml_shape_error(e, f"transformer model (part {part})") raise else: # Use regular predict if no prefill function try: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': multiple_causal_mask.numpy(), 'current_pos': position_ids[0].numpy() } output = model_parts[part].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) except Exception as e: handle_coreml_shape_error(e, f"transformer model (part {part})") raise elif any(key.startswith('2D') for key in model_parts.keys()): # 123D mode with dual split (no prefill functions) for base_part in ['2D1', '2D2']: # Find the correct model key (with lut suffix if present) part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) try: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': multiple_causal_mask.numpy(), 'current_pos': position_ids[0].numpy() } output = model_parts[part_key].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) except Exception as e: handle_coreml_shape_error(e, f"transformer model (part {part_key})") raise batch_pos = batch_end except Exception as e: print(f"\n[ERROR] Failed processing batch {batch_pos}-{batch_end}:") print(f" {str(e)}") raise return torch.tensor([current_pos], dtype=torch.int32) def PreFillChunkOneByOne(model_parts, input_ids, current_pos, context_size, causal_mask): """Process prefill tokens one at a time using infer function.""" #print(f"[DEBUG] Starting one-by-one prefill for {current_pos} tokens") for pos in range(current_pos): # Get current token current_token = input_ids[:, pos:pos+1] single_causal_mask = causal_mask[:, :, pos:pos+1, :] current_pos_tensor = torch.tensor([pos], dtype=torch.int32) # Find the correct model key for part 1 (with lut suffix if present) part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') # Run embeddings (part 1) hidden_states = torch.from_numpy(model_parts[part1_key].predict({ 'input_ids': current_token.numpy() })['hidden_states']) #print(f"[DEBUG] pos: {pos} token: {current_token.item()} states: {hidden_states.shape}") # Get shared transformer state shared_state = model_parts['states']['transformer'] # Handle different model configurations if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): # S123 mode with prefill/infer functions for part in ['2D1S', '2D2S']: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16), 'position_ids': current_pos_tensor.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': current_pos_tensor.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(key.startswith('2D') for key in model_parts.keys()): # 123D mode or individual parts mode for base_part in ['2D1', '2D2']: # Find the correct model key (with lut suffix if present) part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) inputs = { 'hidden_states': hidden_states.numpy(), 
'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16),
                    'position_ids': current_pos_tensor.numpy(),
                    'causal_mask': single_causal_mask.numpy(),
                    'current_pos': current_pos_tensor.numpy()
                }
                output = model_parts[part_key].predict(inputs, shared_state)
                hidden_states = torch.from_numpy(output['transformer_output'])

    return torch.tensor([current_pos], dtype=torch.int32)


def run_inference(model_parts, tokenizer, prompt, context_size=CONTEXT_LENGTH, num_iterations=5, temperature=0.0):
    """Run inference using model parts."""
    DebugLog(f"\nPrompt: {prompt}")
    if temperature > 0:
        DebugLog(f"Using temperature: {temperature}")

    # Prepare the prompt
    messages = [{"role": "user", "content": prompt}]
    formatted_input = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=False
    )
    decoded_input = tokenizer.decode(formatted_input[0])

    DebugLog(f"Decoded input: {decoded_input}")
    DebugLog(f"prompt: {prompt}")
    DebugLog(f"formatted_input size: {formatted_input.size()}")
    DebugLog(f"formatted_input: {formatted_input}")

    base_input_ids = formatted_input.to(torch.int32)
    context_pos = base_input_ids.size(1)
    prompt_tokens = context_pos - 1

    # Pad sequence to context_size
    input_ids = F.pad(
        base_input_ids,
        (0, context_size - context_pos),
        value=0
    )

    DebugLog(f"context_pos (prompt length) = {context_pos}")

    # Create causal mask
    causal_mask = make_causal_mask(context_size, 0)
    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)

    # Prefill phase
    DebugLog("\nStarting prefill...")
    start_time = time.time()

    # Check if we're using 123D mode or individual parts
    use_single_token = any('2D' in key for key in model_parts.keys()) or any('2D' in part for part in model_parts)

    if False:  # use_single_token:
        print("\nRunning ST prefill...")
        current_pos = PreFillChunkOneByOne(
            model_parts,
            input_ids,
            context_pos - 1,
            context_size,
            causal_mask
        )
        sequential_prefill_time = time.time() - start_time
        batch_prefill_time = 0.0
    else:
        print("\nRunning batch prefill...")
        current_pos = PreFillChunk(
            model_parts,
            input_ids,
            context_pos - 1,
            context_size,
            causal_mask,
            batch_size=PREFILL_BATCH_SIZE
        )
        batch_prefill_time = time.time() - start_time
        sequential_prefill_time = 0.0

    # Initialize token printer
    token_printer = TokenPrinter(tokenizer)
    print("\nGenerated response:", end=' ', flush=True)

    # Generation loop
    start_gen_time = time.time()
    pos = context_pos - 1
    tokens_generated = 0

    try:
        DebugLog(f"\nStarting inference... 
context_pos: {context_pos}") pos = context_pos for step in range(num_iterations): with torch.no_grad(): # Check if we need to shift cache if pos >= context_size - 2: shift_size = context_size // 4 new_size = context_size - shift_size # Create shifted input_ids and preserve the most recent context # Don't add BOS token since this is a continuation tmp = torch.zeros((1, context_size), dtype=torch.int32) tmp[:,0:new_size] = input_ids[:,shift_size:context_size] input_ids = tmp # Adjust position after shift pos = new_size # Create update mask for current position update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) update_mask[0, 0, pos-1, 0] = 1.0 #print(f"\n[DEBUG] Shifted cache by {shift_size} tokens, maintaining context window of {new_size} tokens, new pos: {pos}") # For Q123 mode, we need to run prefill on the shifted sequence if any(part.startswith('2Q') for part in model_parts): # Run prefill using PreFillChunk with proper batch size # No need to adjust position since we're not adding BOS current_pos = PreFillChunk( model_parts, input_ids, pos-1, # how much ob context_size, # Use full context size causal_mask, batch_size=PREFILL_BATCH_SIZE ) #print(f"[DEBUG] Ran prefill after shift for position {pos} with batch_size={PREFILL_BATCH_SIZE}") # Position should already be correct since we didn't add BOS pos = current_pos # Get current token current_token = input_ids[:, pos-1:pos] # Find the correct model key for part 1 (with lut suffix if present) part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') # Run embeddings (part 1) hidden_states = model_parts[part1_key].predict({ 'input_ids': current_token.numpy() })['hidden_states'] hidden_states = torch.from_numpy(hidden_states) # Get shared transformer state shared_state = model_parts['states']['transformer'] # Create update mask for current position update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) update_mask[0, 0, pos-1, 0] = 1.0 # Create position IDs tensor position_ids = torch.tensor([pos-1], dtype=torch.int32) # Create causal mask for current position single_causal_mask = causal_mask[:, :, pos-1:pos, :] # Run transformer layers based on model type if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): # S123 mode with prefill/infer functions for part in ['2D1S', '2D2S']: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(part.startswith('2Q') for part in model_parts.keys()): # Q123S mode with combined quad split for i in range(1, 5): part = f'2Q{i}S' inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(part.startswith('2Q') for part in model_parts): # Q123 mode with quad split #print(f"[DEBUG] Running quad split inference at position {pos}") for i in range(1, 5): part = f'2Q{i}' if f'{part}_infer' in model_parts: # Use infer function if available inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 
'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) else: # Use regular predict if no infer function inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[part].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(key.startswith('2D') for key in model_parts.keys()): # 123D mode or individual parts mode for base_part in ['2D1', '2D2']: # Find the correct model key (with lut suffix if present) part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[part_key].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) else: print("\n[ERROR] No transformer model parts found!") break try: # Run final layer norm and get logits # Find the correct model key for part 3 (with lut suffix if present) part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3') output_dict = model_parts[part3_key].predict({ 'hidden_states': hidden_states.numpy() }) if ENABLE_VACAB_SPLIT8: # Get all logits parts in a single call logits_parts = [] for i in range(1, 9): logits_parts.append(output_dict[f'logits{i}']) logits = np.concatenate(logits_parts, axis=-1) elif ENABLE_LOGITS2: # Get both logits parts in a single call logits = np.concatenate([ output_dict['logits1'], output_dict['logits2'] ], axis=-1) else: logits = output_dict['logits'] # Convert to tensor and get next token logits = torch.from_numpy(logits) # Apply temperature if specified if temperature > 0: # Scale logits by temperature logits = logits / temperature # Apply softmax to get probabilities probs = F.softmax(logits[0, -1, :], dim=-1) # Sample from the distribution next_token = torch.multinomial(probs, num_samples=1).item() else: # Use argmax if no temperature next_token = torch.argmax(logits[0, -1, :]).item() # Add token to input sequence input_ids[0, pos] = next_token token_printer.add_token(next_token) # Safely decode tokens in the main thread token_printer.drain_buffer() # Update position and count pos += 1 tokens_generated += 1 if next_token == tokenizer.eos_token_id: print("\n[DEBUG] Generated EOS token, stopping...") break except Exception as e: print(f"\n[ERROR] Error in final layer or token generation: {str(e)}") break except KeyboardInterrupt: print("\n[DEBUG] Interrupted by user") except Exception as e: print(f"\n[ERROR] Exception during inference: {str(e)}") print(traceback.format_exc()) # Print timing statistics end_time = time.time() total_time = end_time - start_gen_time print(f"\n\nTotal time: {total_time:.2f} seconds") print(f"Generation tokens: {tokens_generated}") print(f"Prefill tokens: {prompt_tokens}") print(f"Total tokens (prefill + generation): {prompt_tokens + tokens_generated}") if prompt_tokens > 0: if batch_prefill_time > 0: # If using batch prefill prefill_tokens_per_second = prompt_tokens / batch_prefill_time effective_prefill_tokens_per_second = prompt_tokens / batch_prefill_time # Don't multiply by batch size print(f"Actual prefill tokens per second: 
{prefill_tokens_per_second:.2f}") print(f"Effective prefill tokens per second (batch={PREFILL_BATCH_SIZE}): {effective_prefill_tokens_per_second:.2f}") elif sequential_prefill_time > 0: # If using sequential prefill prefill_tokens_per_second = prompt_tokens / sequential_prefill_time print(f"Sequential prefill tokens per second: {prefill_tokens_per_second:.2f}") if tokens_generated > 0: total_processing_time = total_time + (batch_prefill_time if batch_prefill_time > 0 else sequential_prefill_time) overall_tokens_per_second = (prompt_tokens + tokens_generated) / total_processing_time generation_tokens_per_second = tokens_generated / total_time print(f"Overall tokens processed per second (including prefill): {overall_tokens_per_second:.2f}") print(f"Generation-only tokens per second: {generation_tokens_per_second:.2f}") return token_printer.stop(), { 'total_time': total_time, 'batch_prefill_time': batch_prefill_time, 'sequential_prefill_time': sequential_prefill_time, 'tokens_generated': tokens_generated, 'prompt_tokens': prompt_tokens } def DebugLog(message, always_print=False): """Print debug message if ENABLE_CHAT_DEBUG is True or always_print is True. Args: message: Message to print always_print: If True, print regardless of ENABLE_CHAT_DEBUG setting """ if ENABLE_CHAT_DEBUG or always_print: print(f"[DEBUG] {message}") def chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=0.0): """Interactive chat loop that maintains conversation history.""" print("\nStarting chat session. Press Ctrl+D to exit.") print("Type your message and press Enter to chat.") DebugLog(f"Using context size: {context_size}") DebugLog(f"Temperature: {temperature}") DebugLog(f"Model parts loaded: {list(model_parts.keys())}") # Initialize conversation history conversation = [] input_ids = None current_pos = 0 try: while True: try: print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True) user_input = input().strip() except EOFError: print("\nExiting chat...") break if not user_input: continue # Add user message to conversation conversation.append({"role": "user", "content": user_input}) DebugLog("\nFormatting conversation:") for msg in conversation: DebugLog(f" {msg['role']}: {msg['content'][:50]}...") # Format entire conversation formatted_input = tokenizer.apply_chat_template( conversation, return_tensors="pt", add_generation_prompt=True ) DebugLog("\nTokenization:") DebugLog(f"Input token IDs: {formatted_input[0][:50]}...") DebugLog(f"Decoded tokens: {tokenizer.decode(formatted_input[0][:50])}...") DebugLog(f"Total tokens: {formatted_input.size(1)}") # Convert to int32 tensor base_input_ids = formatted_input.to(torch.int32) context_pos = base_input_ids.size(1) DebugLog(f"Context position: {context_pos}") # Check if we need to truncate history if context_pos >= context_size - 100: DebugLog(f"\nNeed to truncate: {context_pos} tokens > {context_size-100} limit") while context_pos >= context_size - 100 and len(conversation) > 2: removed = conversation.pop(0) DebugLog(f"Removed message: {removed['role']}: {removed['content'][:30]}...") formatted_input = tokenizer.apply_chat_template( conversation, return_tensors="pt", add_generation_prompt=True ) base_input_ids = formatted_input.to(torch.int32) context_pos = base_input_ids.size(1) DebugLog(f"New context size: {context_pos}") # Pad sequence to context_size input_ids = F.pad( base_input_ids, (0, context_size - context_pos), value=0 ) # Create causal mask for the entire context causal_mask = make_causal_mask(context_size, 0) causal_mask = 
torch.tensor(causal_mask, dtype=torch.float16) DebugLog(f"Created causal mask with shape: {causal_mask.shape}") print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True) # Run prefill on entire context if False: #any(key.contains('2D') for key in model_parts.keys()): DebugLog("Using sequential prefill") current_pos = PreFillChunkOneByOne( model_parts, input_ids, context_pos, context_size, causal_mask ) elif any(part.startswith('2Q') for part in model_parts.keys()): DebugLog(f"Using quad split prefill (size={PREFILL_BATCH_SIZE})") current_pos = PreFillChunk( model_parts, input_ids, context_pos, context_size, causal_mask, batch_size=PREFILL_BATCH_SIZE ) else: DebugLog(f"Using standard batch prefill (size={PREFILL_BATCH_SIZE})") current_pos = PreFillChunk( model_parts, input_ids, context_pos, context_size, causal_mask, batch_size=PREFILL_BATCH_SIZE ) # Initialize token printer token_printer = TokenPrinter(tokenizer) # Generation loop pos = context_pos response_tokens = [] generation_start_time = time.time() # Add timing try: while True: # Changed from context_size - 1 to True for continuous generation # Check if we need to shift window if pos >= context_size - 2: DebugLog("\nShifting context window...") shift_size = context_size // 4 # Shift by 1/4 of context new_size = context_size - shift_size # Create shifted input_ids and preserve the most recent context tmp = torch.zeros((1, context_size), dtype=torch.int32) tmp[:,0:new_size] = input_ids[:,shift_size:context_size] input_ids = tmp # Adjust position after shift pos = new_size DebugLog(f"Shifted window by {shift_size} tokens, new position: {pos}") # Run prefill on the shifted sequence if False: #if any(key.contains('2D') for key in model_parts.keys()): DebugLog("Running sequential prefill after shift") current_pos = PreFillChunkOneByOne( model_parts, input_ids, pos, context_size, causal_mask ) else: DebugLog("Running batch prefill after shift (size={PREFILL_BATCH_SIZE})") current_pos = PreFillChunk( model_parts, input_ids, pos, context_size, causal_mask, batch_size=PREFILL_BATCH_SIZE ) # Get current token current_token = input_ids[:, pos-1:pos] # Find the correct model key for part 1 part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') # Run embeddings (part 1) hidden_states = model_parts[part1_key].predict({ 'input_ids': current_token.numpy() })['hidden_states'] hidden_states = torch.from_numpy(hidden_states) # Get shared transformer state shared_state = model_parts['states']['transformer'] # Create update mask for current position update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) update_mask[0, 0, pos-1, 0] = 1.0 # Create position IDs tensor position_ids = torch.tensor([pos-1], dtype=torch.int32) # Create causal mask for current position single_causal_mask = causal_mask[:, :, pos-1:pos, :] # Run transformer layers based on model type if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): for part in ['2D1S', '2D2S']: inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(part.startswith('2Q') for part in model_parts.keys()): DebugLog(f"Running quad split inference at position {pos}") for i in range(1, 5): part = f'2Q{i}' if f'{part}_infer' in model_parts: # Use infer 
function if available inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[f'{part}_infer'].predict(inputs, shared_state) else: # Use regular predict if no infer function inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[part].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) elif any(key.startswith('2D') for key in model_parts.keys()): for base_part in ['2D1', '2D2']: part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) inputs = { 'hidden_states': hidden_states.numpy(), 'update_mask': update_mask.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': single_causal_mask.numpy(), 'current_pos': position_ids.numpy() } output = model_parts[part_key].predict(inputs, shared_state) hidden_states = torch.from_numpy(output['transformer_output']) # Run final layer norm and get logits part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3') output_dict = model_parts[part3_key].predict({ 'hidden_states': hidden_states.numpy() }) if ENABLE_VACAB_SPLIT8: logits_parts = [] for i in range(1, 9): logits_parts.append(output_dict[f'logits{i}']) logits = np.concatenate(logits_parts, axis=-1) else: logits = output_dict['logits'] # Convert to tensor and get next token logits = torch.from_numpy(logits) # Apply temperature if specified if temperature > 0: logits = logits / temperature probs = F.softmax(logits[0, -1, :], dim=-1) next_token = torch.multinomial(probs, num_samples=1).item() else: next_token = torch.argmax(logits[0, -1, :]).item() # Add token to input sequence and response input_ids[0, pos] = next_token response_tokens.append(next_token) token_printer.add_token(next_token) # Safely decode tokens in the main thread token_printer.drain_buffer() pos += 1 # Add debug output for generated tokens if ENABLE_CHAT_DEBUG and len(response_tokens) > 0 and len(response_tokens) % 10 == 0: DebugLog(f"\nGenerated {len(response_tokens)} tokens") DebugLog(f"Last token: {next_token} -> '{tokenizer.decode([next_token])}'") if next_token == tokenizer.eos_token_id: DebugLog("\nGenerated EOS token") break # Get the complete response text and calculate stats response_text = token_printer.stop() generation_time = time.time() - generation_start_time tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0 DebugLog(f"\nFinal response length: {len(response_tokens)} tokens") # Print generation stats in dark blue print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}") except KeyboardInterrupt: DebugLog("\nGeneration interrupted by user") response_text = token_printer.stop() generation_time = time.time() - generation_start_time tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0 print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}") # Add assistant's response to conversation history conversation.append({"role": "assistant", "content": response_text}) except Exception as e: print(f"\n[ERROR] Chat loop error: {str(e)}") print(traceback.format_exc()) def main(): global CONTEXT_LENGTH, 
PREFILL_BATCH_SIZE, MODEL_PATH
    print("ANEMLL Chat. Pre-release alpha version, 2025-01-31")
    print("Copyright (c) 2025, Anemll All rights reserved.")

    # Set default parameters
    model_type = "Q123"   # Default model type
    lut_suffix = "lut4"   # Default LUT suffix
    temperature = 0.0
    model_parts = {}
    model_path = "."      # Default to current directory

    if len(sys.argv) < 2:
        print("Usage: python chat.py [model_parts] [options]")
        print("\nOptions:")
        print("  -d PATH      # Model directory path (for both tokenizer and CoreML models)")
        print("  S123         # Combined split model (2D1S+2D2S)")
        print("  C123         # Combined part2 model with prefill/infer")
        print("  Q123         # Quad split model (2Q1-2Q4) [default]")
        print("  Q123S        # Combined quad split model (2Q1S-2Q4S)")
        print("  1 2D1 2D2 3  # Individual split parts")
        print("  pfN          # Prefill batch size (e.g., pf128)")
        print("  ctx=N        # Context length (e.g., ctx=2048) [default: 1024]")
        print("  temp=X       # Temperature for sampling (e.g., temp=0.01)")
        print("  lut4         # LUT suffix [default]")
        print("\nDefault configuration: Q123 lut4 ctx=1024")
        print("  python chat.py Q123 -d ../anemll-DeepSeek-8B-ctx1024")
        # Use defaults instead of exiting
        print("\nUsing default configuration...")
    else:
        # Process command line arguments
        i = 1
        while i < len(sys.argv):
            if sys.argv[i] == '-d' and i + 1 < len(sys.argv):
                model_path = sys.argv[i + 1]
                i += 2
                # Extract context length from model path if present
                ctx_match = re.search(r'ctx(\d+)', model_path)
                if ctx_match:
                    ctx_value = int(ctx_match.group(1))
                    if 512 <= ctx_value <= 4096*2:
                        CONTEXT_LENGTH = ctx_value
                        print(f"Setting context length to {CONTEXT_LENGTH} from model path")
                continue
            elif sys.argv[i].startswith('lut'):
                lut_suffix = sys.argv[i]
            elif sys.argv[i] in ['S123', 'Q123', 'Q123S', 'C123', '123D']:
                model_type = sys.argv[i]
            i += 1

    # Initialize tokenizer using the same path
    tokenizer = initialize_tokenizer(model_path)
    if tokenizer is None:
        print("[ERROR] Failed to initialize tokenizer. Exiting.")
        return

    # Process model parts
    parts = [model_type]
    if lut_suffix:
        parts.append(lut_suffix)

    try:
        split_model = SplitModelInference(parts, model_dir=model_path)
        model_parts.update(split_model.models)
        model_parts['states'] = {'transformer': split_model.states['transformer']}
    except Exception as e:
        print(f"Error loading model parts: {str(e)}")
        return

    # Process remaining arguments
    i = 1
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg.startswith('pf') and arg[2:].isdigit():
            PREFILL_BATCH_SIZE = int(arg[2:])
        elif arg.startswith('ctx='):
            try:
                CONTEXT_LENGTH = int(arg.split('=')[1])
            except (IndexError, ValueError):
                print(f"[WARNING] Invalid context length format. Using default: {CONTEXT_LENGTH}")
        elif arg.startswith('temp='):
            try:
                temperature = float(arg.split('=')[1])
                if temperature < 0:
                    print(f"[WARNING] Temperature must be non-negative. Using default: 0.0")
                    temperature = 0.0
            except (IndexError, ValueError):
                print(f"[WARNING] Invalid temperature format. Using default: 0.0")
        i += 1

    try:
        # Start interactive chat loop
        chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=temperature)
    except Exception as e:
        print("An error occurred:")
        print(traceback.format_exc())


if __name__ == "__main__":
    main()
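
# ---------------------------------------------------------------------------
# Example invocations (illustrative only, assembled from the usage text that
# main() prints above; the model directory is an example path and must point
# at an actual converted model directory):
#
#   python chat.py                                          # defaults: Q123 lut4 ctx=1024
#   python chat.py Q123 -d ../anemll-DeepSeek-8B-ctx1024    # ctx picked up from the dir name
#   python chat.py S123 lut4 -d ../anemll-DeepSeek-8B-ctx1024 pf128 temp=0.01
# ---------------------------------------------------------------------------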