# Copyright (c) 2025, Anemll All rights reserved.
#
# Use of this source code is governed by a MIT license that can be
# found in the LICENSE.txt file or at https://opensource.org/license/mit
import coremltools as ct
import numpy as np
import torch
from transformers import AutoTokenizer
import os
import time, sys
import signal
import traceback
import torch.nn.functional as F
import queue
import threading
import re
# Configuration
CONTEXT_LENGTH = 1024 # Changed default from 512 to 1024
PREFILL_BATCH_SIZE = 64
MODEL_PATH = os.path.expanduser("../DeepSeekR1-8B")
ENABLE_VOCAB_SPLIT8 = True # Enable 8-way vocab split
ENABLE_LOGITS2 = False # Enable 2-way vocab split
ENABLE_DEBUG = False
ENABLE_ARGMAX = False
ENABLE_PREFILL_BATCH = True
ENABLE_CHAT_DEBUG = False # Debug flag for chat loop
# ANSI color codes
LIGHT_BLUE = "\033[94m"
DARK_BLUE = "\033[34m"
LIGHT_GREEN = "\033[92m"
RESET_COLOR = "\033[0m"
if ENABLE_LOGITS2:
assert not ENABLE_ARGMAX, "ENABLE_ARGMAX must be False when ENABLE_LOGITS2 is True"
def load_model(path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name=None):
"""Load either compiled or uncompiled CoreML model.
Args:
path: Path to the model file (.mlmodelc or .mlpackage)
compute_unit: CoreML compute unit to use
function_name: Optional function name to select from multi-function models
"""
DebugLog(f"Attempting to load model: {path}")
DebugLog(f"File exists: {os.path.exists(path)}")
DebugLog(f"Is directory (for mlmodelc): {os.path.isdir(path)}")
try:
if path.endswith('.mlmodelc'):
DebugLog(f"Loading compiled model: {path}")
if function_name is None:
DebugLog("Loading without function name")
model = ct.models.CompiledMLModel(path, compute_unit)
else:
DebugLog(f"Loading with function name: {function_name}")
model = ct.models.CompiledMLModel(path, compute_unit, function_name=function_name)
else:
DebugLog(f"Loading uncompiled model: {path}")
if function_name is None:
DebugLog("Loading without function name")
model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False)
else:
DebugLog(f"Loading with function name: {function_name}")
model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False, function_name=function_name)
DebugLog("Model loaded successfully")
return model
except Exception as e:
DebugLog(f"Error loading model: {str(e)}")
DebugLog(f"Error type: {type(e)}")
raise
class SplitModelInference:
def __init__(self, model_parts, model_dir="."):
"""Initialize split model inference.
Args:
model_parts (list): List of model part numbers to load
Special cases:
- 'C123' for combined part2 with prefill/infer functions
- 'S123' for split model with prefill/infer functions
- 'Q123' for quad split (2Q1-2Q4)
- 'Q123S' for quad split with combined prefill/infer (2Q1S-2Q4S)
- '123D' for dual split without prefill/infer (2D1-2D2)
model_dir (str): Directory containing the model files (default: current directory)
"""
self.context_size = CONTEXT_LENGTH
self.model_dir = model_dir
DebugLog(f"Loading models from directory: {self.model_dir}")
# Parse configuration
self.quant_configs = {}
global_lut = None
if model_parts and model_parts[-1].startswith('lut'):
global_lut = model_parts[-1]
model_parts = model_parts[:-1]
# Special handling for different split modes
if len(model_parts) == 1:
if model_parts[0] == '123D': # Dual split without prefill/infer
self.use_combined_part2 = False
self.use_split_model = True
self.use_split_functions = False
self.use_quad_split = False
self.use_quad_split_combined = False
self.model_parts = ['1', '2D1', '2D2', '3']
if global_lut:
self.quant_configs = {part: global_lut for part in self.model_parts}
DebugLog(f"Using dual split model with parts: {self.model_parts}")
elif model_parts[0].startswith('C123'): # Combined part2
self.use_combined_part2 = True
self.use_split_model = False
self.use_split_functions = False
self.use_quad_split = False
self.use_quad_split_combined = False
self.model_parts = ['1', '2', '3']
if global_lut:
self.quant_configs = {part: global_lut for part in self.model_parts}
DebugLog(f"Using combined part2 model with parts: {self.model_parts}")
elif model_parts[0].startswith('S123'): # Split model with prefill/infer functions
self.use_combined_part2 = False
self.use_split_model = True
self.use_split_functions = True
self.use_quad_split = False
self.use_quad_split_combined = False
self.model_parts = ['1', '2D1S', '2D2S', '3']
elif model_parts[0].startswith('Q123S'): # Quad split with combined prefill/infer
self.use_combined_part2 = False
self.use_split_model = True
self.use_split_functions = False
self.use_quad_split = False
self.use_quad_split_combined = True
self.model_parts = ['1', '2Q1S', '2Q2S', '2Q3S', '2Q4S', '3']
elif model_parts[0].startswith('Q123'): # Regular quad split
self.use_combined_part2 = False
self.use_split_model = True
self.use_split_functions = False
self.use_quad_split = True
self.use_quad_split_combined = False
self.model_parts = ['1', '2Q1', '2Q2', '2Q3', '2Q4', '3']
else:
self.use_combined_part2 = False
self.use_split_model = False
self.use_split_functions = False
self.use_quad_split = False
self.use_quad_split_combined = False
self.model_parts = model_parts
else:
self.use_combined_part2 = False
self.use_split_model = False
self.use_split_functions = False
self.use_quad_split = False
self.use_quad_split_combined = False
self.model_parts = model_parts
# Apply global quantization if specified
if global_lut and not self.use_combined_part2: # Skip if already applied for C123
self.quant_configs = {part: global_lut for part in self.model_parts}
DebugLog(f"Using model parts: {self.model_parts}")
if global_lut:
DebugLog(f"With global quantization: {global_lut}")
if self.use_combined_part2:
DebugLog("Using combined part2 model with prefill/infer functions")
elif self.use_split_functions:
DebugLog("Using split model with prefill/infer functions")
elif self.use_quad_split:
DebugLog("Using quad split transformer model (2Q1-2Q4)")
elif self.use_quad_split_combined:
DebugLog("Using combined quad split transformer model (2Q1S-2Q4S)")
self.models = {}
self.states = {}
self.load_models()
def find_model_path(self, base_name, description="model"):
"""Find model path, checking mlmodelc first then mlpackage.
Also tries both with and without lut suffix.
Args:
base_name: Base name of the model without extension
description: Description for error message (e.g., "Split model part 2D1S")
Returns:
str: Path to the found model file
Raises:
FileNotFoundError: If neither mlmodelc nor mlpackage exists
"""
# For quad split parts, only try mlmodelc
if any(part in base_name for part in ['2Q1S', '2Q2S', '2Q3S', '2Q4S', '2Q1', '2Q2', '2Q3', '2Q4']):
model_path = os.path.join(self.model_dir, f"{base_name}.mlmodelc")
if os.path.exists(model_path):
return model_path
# If not found, try without lut suffix
if '_lut' in base_name:
base_without_lut = base_name.split('_lut')[0]
model_path = os.path.join(self.model_dir, f"{base_without_lut}.mlmodelc")
if os.path.exists(model_path):
return model_path
# Neither exists
raise FileNotFoundError(f"{description} not found: {base_name}.mlmodelc does not exist" +
(f" (also tried {base_name.split('_lut')[0]}.mlmodelc)" if '_lut' in base_name else ""))
# For other parts, try both mlmodelc and mlpackage
for ext in ['.mlmodelc', '.mlpackage']:
model_path = os.path.join(self.model_dir, f"{base_name}{ext}")
if os.path.exists(model_path):
return model_path
# If not found, try without lut suffix
if '_lut' in base_name:
base_without_lut = base_name.split('_lut')[0]
for ext in ['.mlmodelc', '.mlpackage']:
model_path = os.path.join(self.model_dir, f"{base_without_lut}{ext}")
if os.path.exists(model_path):
return model_path
# Neither exists
raise FileNotFoundError(f"{description} not found: neither {base_name}.mlmodelc nor {base_name}.mlpackage exist in {self.model_dir}" +
(f" (also tried {base_name.split('_lut')[0]}.mlmodelc/mlpackage)" if '_lut' in base_name else ""))
def load_models(self):
"""Load each model part."""
DebugLog("Loading model parts...")
for part in self.model_parts:
quant_suffix = f"_{self.quant_configs[part]}" if part in self.quant_configs else ""
model_key = f"{part}{quant_suffix}" # Use this as the key in self.models
try:
if part == '2' and self.use_combined_part2:
# Load combined part2 with multiple functions
base_name = f"llama32_part2_combined{quant_suffix}"
model_path = self.find_model_path(base_name, "Combined part2 model")
DebugLog(f"Loading combined part2 model: {model_path}")
# Load prefill function
self.models['2_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill')
# Load infer function
self.models['2_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer')
# Create shared state
self.states['transformer'] = self.models['2_prefill'].make_state()
DebugLog("Combined part2 model loaded successfully")
elif part == '2' and not self.use_combined_part2:
# Load regular part2 model
base_name = f"llama32_part2{quant_suffix}"
model_path = self.find_model_path(base_name, "Regular part2 model")
DebugLog(f"Loading regular part2 model: {model_path}")
self.models[model_key] = load_model(model_path)
self.states['transformer'] = self.models[model_key].make_state()
DebugLog("Regular part2 model loaded successfully")
elif part in ['2D1S', '2D2S'] and self.use_split_functions:
# Load split model with prefill/infer functions
base_name = f"llama32_part{part}{quant_suffix}"
model_path = self.find_model_path(base_name, f"Split model part {part}")
DebugLog(f"Loading split model part {part}: {model_path}")
# Load prefill function
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill')
# Load infer function
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer')
# Create shared state for first part only
if part == '2D1S':
self.states['transformer'] = self.models[f'{part}_infer'].make_state()
DebugLog(f"Split model part {part} loaded successfully")
elif part.endswith('S') and self.use_quad_split_combined:
# Load combined quad split model with prefill/infer functions
base_name = f"llama32_part{part}{quant_suffix}"
model_path = self.find_model_path(base_name, f"Combined quad split part {part}")
DebugLog(f"Loading combined quad split part {part}: {model_path}")
# Load prefill function
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill')
# Load infer function
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer')
# Create shared state for first part only
if part == '2Q1S':
self.states['transformer'] = self.models[f'{part}_infer'].make_state()
DebugLog(f"Created shared transformer state for all quad split parts")
DebugLog(f"Combined quad split part {part} loaded successfully")
elif part.startswith('2Q') and self.use_quad_split:
# Load quad split model with prefill/infer functions
# Append 'S' to part name for file lookup
base_name = f"llama32_part{part}S{quant_suffix}"
model_path = self.find_model_path(base_name, f"Quad split part {part}")
DebugLog(f"Loading quad split part {part}: {model_path}")
# Load prefill function
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill')
# Load infer function
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer')
# Create shared state for first part only
if part == '2Q1':
self.states['transformer'] = self.models[f'{part}_infer'].make_state()
DebugLog(f"Created shared transformer state for all quad split parts")
print(f"Created shared transformer state for all quad split parts")
print(f"Quad split part {part} loaded successfully")
else:
# Load regular models (part 1 and part3)
base_name = f"llama32_part{part}{quant_suffix}"
model_path = self.find_model_path(base_name, f"Regular part {part}")
print(f"[MODEL LOAD] Regular part {part}:")
print(f" - File: {model_path}")
print(f" - Loading as: '{model_key}'")
# Try loading with CPU_AND_NE first, then fall back to CPU-only if needed
try:
self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE)
print(f" - Loaded with CPU_AND_NE compute unit")
except Exception as ane_error:
print(f" - CPU_AND_NE load failed, falling back to CPU: {str(ane_error)}")
self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU)
print(f" - Loaded with CPU compute unit")
print(f"[MODEL LOAD] Current model_parts keys: {list(self.models.keys())}")
except Exception as e:
print(f"Error loading model part {part}: {str(e)}")
raise
def run_transformer_prefill(self, hidden_states, update_mask, position_ids, causal_mask, current_pos):
"""Run the transformer model in prefill mode."""
if self.use_split_functions:
# Use prefill variants for split model
for part in ['2D1S', '2D2S']:
inputs = {
'hidden_states': hidden_states.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': causal_mask.numpy(),
'start_pos': current_pos.numpy()
}
output = self.models[f'{part}_prefill'].predict(inputs, self.states['transformer'])
hidden_states = torch.from_numpy(output['dummy_output'])
return hidden_states
else:
# Use existing prefill implementation
return super().run_transformer_prefill(hidden_states, update_mask, position_ids, causal_mask, current_pos)
def run_transformer_infer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos):
"""Run the transformer model in infer mode."""
if self.use_split_functions:
# Use infer variants for split model
for part in ['2D1S', '2D2S']:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': causal_mask.numpy(),
'current_pos': current_pos.numpy()
}
output = self.models[f'{part}_infer'].predict(inputs, self.states['transformer'])
hidden_states = torch.from_numpy(output['transformer_output'])
return hidden_states
else:
# Use existing infer implementation
return super().run_transformer_infer(hidden_states, update_mask, position_ids, causal_mask, current_pos)
def get_state(self, part):
"""Get the appropriate state for a model part."""
return self.states['transformer']
def run_embeddings(self, input_ids):
"""Run the embeddings model (part 1)."""
if '1' not in self.models:
raise ValueError("Embeddings model (part 1) not loaded")
output_dict = self.models['1'].predict({
'input_ids': input_ids.numpy()
})
return torch.from_numpy(output_dict['hidden_states'])
def run_transformer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos, part='2'):
"""Run the transformer model."""
if part not in self.models:
raise ValueError(f"Transformer model (part {part}) not loaded")
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': causal_mask.numpy(),
'current_pos': current_pos.numpy()
}
output_dict = self.models[part].predict(inputs, self.get_state(part))
return torch.from_numpy(output_dict['transformer_output'])
def run_transformer_splits(self, hidden_states, update_mask, position_ids, causal_mask, current_pos):
"""Run through transformer splits based on model configuration."""
if not self.use_split_model:
return self.run_transformer(hidden_states, update_mask, position_ids, causal_mask, current_pos)
# Handle different split configurations
if any(part.startswith('2Q') for part in self.model_parts): # Quad split
for i in range(1, 5):
part = f'2Q{i}'
hidden_states = self.run_transformer(
hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part
)
elif any(part.startswith('2O') for part in self.model_parts): # Octa split
for i in range(1, 9):
part = f'2O{i}'
hidden_states = self.run_transformer(
hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part
)
elif any(part.startswith('2D') for part in self.model_parts): # Dual split
# Run through both parts of the dual split
for base_part in ['2D1', '2D2']:
# Find the correct model key (with lut suffix if present)
part_key = next(key for key in self.models.keys() if key.startswith(f'{base_part}_') or key == base_part)
# Use the shared transformer state
if 'transformer' not in self.states:
raise ValueError("Transformer state not initialized. Make sure 2D1 is loaded first.")
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': causal_mask.numpy(),
'current_pos': current_pos.numpy()
}
output_dict = self.models[part_key].predict(inputs, self.states['transformer'])
hidden_states = torch.from_numpy(output_dict['transformer_output'])
return hidden_states
def run_lm_head(self, hidden_states):
"""Run the LM head model (part 3)."""
if '3' not in self.models:
raise ValueError("LM head model (part 3) not loaded")
output_dict = self.models['3'].predict({
'hidden_states': hidden_states.numpy()
})
# Handle split logits
logits_parts = []
for i in range(1, 9): # logits1 through logits8
logits_key = f'logits{i}'
if logits_key in output_dict:
logits_part = torch.from_numpy(output_dict[logits_key])
logits_parts.append(logits_part)
# Concatenate along the vocabulary dimension
return torch.cat(logits_parts, dim=-1)
def run_full_model(self, input_ids, update_mask, position_ids, causal_mask, current_pos):
"""Run the full model."""
if 'full' not in self.models:
raise ValueError("Full model not loaded")
# Update context size from global
self.context_size = CONTEXT_LENGTH
#kv_ was removed from the input names
inputs = {
'input_ids': input_ids.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': causal_mask.numpy(),
'current_pos': current_pos.numpy()
}
# Print shapes of all inputs
if False:
print("[DEBUG] Input shapes:")
for key, value in inputs.items():
print(f" {key}: {value.shape}")
output_dict = self.models['full'].predict(inputs, self.states['transformer'])
# Handle split logits if necessary
if ENABLE_VOCAB_SPLIT8:
logits_parts = []
for i in range(1, 9):
logits_parts.append(output_dict[f'logits{i}'])
logits = np.concatenate(logits_parts, axis=-1)
else:
logits = output_dict['logits']
return torch.from_numpy(logits)
def make_causal_mask(length, start):
# Initialize the mask with -inf
mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
# Create row and column indices
row_indices = np.arange(length).reshape(length, 1) # Column vector
col_indices = np.arange(length).reshape(1, length) # Row vector
# Set allowed positions to 0 where col_index is within the allowed range of row_index
mask[:, :, col_indices <= (row_indices + start)] = 0
return mask
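# Worked example (length=4, start=0): make_causal_mask yields, per (row, col),
#   [[  0, -inf, -inf, -inf],
#    [  0,    0, -inf, -inf],
#    [  0,    0,    0, -inf],
#    [  0,    0,    0,    0]]
# i.e. position i may attend to positions j <= i + start; the returned shape is (1, 1, length, length).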
def initialize_tokenizer(model_path):
"""Initialize and configure the tokenizer."""
try:
print(f"[DEBUG] Loading tokenizer from model path: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
print("\n[DEBUG] Tokenizer Configuration:")
print(f"Tokenizer type: {type(tokenizer)}")
print(f"Tokenizer name: {tokenizer.__class__.__name__}")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"Model max length: {tokenizer.model_max_length}")
#print(f"Chat template: {tokenizer.chat_template if hasattr(tokenizer, 'chat_template') else 'None'}")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
print("[DEBUG] Set PAD token to EOS token")
print(f"\n[DEBUG] Special Tokens:")
print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})")
print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})")
return tokenizer
except Exception as e:
print(f"[ERROR] Failed to load tokenizer from {model_path}")
return None
class TokenPrinter:
"""Handles background printing of generated tokens."""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.token_queue = queue.Queue()
self.stop_event = threading.Event()
self.thread = None
self.buffer = ""
self.lock = threading.Lock()
self.thinking = True # Track if we're still in thinking mode
self.decoding_buffer = [] # <-- Buffer for token IDs
self.start()
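# Usage sketch: the generation loop calls add_token(token_id) as tokens are sampled,
# periodically calls drain_buffer() on the main thread to decode and print them, and
# calls stop() once at the end to join the worker thread and reset terminal colors.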
def start(self):
"""Start the printer thread."""
if self.thread is None:
self.thread = threading.Thread(target=self._print_worker)
self.thread.daemon = True
self.thread.start()
def add_token(self, token_id):
"""Add a token to the print queue."""
if not self.stop_event.is_set():
self.token_queue.put(token_id)
def drain_buffer(self):
"""
Decode token IDs from self.decoding_buffer in the main thread,
then print them with the correct color logic.
"""
if not self.decoding_buffer:
return
# Decode all tokens at once in the main thread.
token_str = self.tokenizer.decode(self.decoding_buffer)
self.decoding_buffer.clear()
# Color-handling logic. Check for "</think>" and handle self.thinking.
if self.thinking and "</think>" in token_str:
self.thinking = False
parts = token_str.split("</think>")
if len(parts) > 0:
print(parts[0] + "</think>", end='', flush=True)
if len(parts) > 1:
print(LIGHT_BLUE + parts[1], end='', flush=True)
else:
if not self.thinking:
print(LIGHT_BLUE + token_str, end='', flush=True)
else:
print(token_str, end='', flush=True)
def _print_worker(self):
"""Worker thread that takes token_ids from the queue but doesn't decode."""
while not self.stop_event.is_set():
try:
token_id = self.token_queue.get(timeout=0.01)
with self.lock:
# Just store the token_id, decode later on the main thread
self.decoding_buffer.append(token_id)
self.token_queue.task_done()
except queue.Empty:
continue
except Exception as e:
print(f"\n[ERROR] Token printer error: {str(e)}")
break
def stop(self):
"""Stop the printer thread."""
if self.thread and self.thread.is_alive():
self.stop_event.set()
try:
self.thread.join(timeout=1.0)
except Exception:
pass
print(RESET_COLOR) # Reset color at the end
return self.buffer
def parse_coreml_error(error_str):
"""Parse CoreML error message to extract shape information.
Args:
error_str: The error message string from CoreML
Returns:
tuple: (got_shape, expected_shape) or None if parsing fails
"""
try:
# Extract shapes from error message using regex
pattern = r"shape \(([\d\s x]+)\) does not match the shape \(([\d\s x]+)\)"
match = re.search(pattern, str(error_str))
if match:
got_shape = tuple(int(x) for x in match.group(1).split('x'))
expected_shape = tuple(int(x) for x in match.group(2).split('x'))
return got_shape, expected_shape
return None
except Exception as e:
print(f"Error parsing CoreML error message: {e}")
return None
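# Illustrative input (assumed error wording): a CoreML message such as
#   "MultiArray shape (1 x 64 x 1024) does not match the shape (1 x 64 x 512) ..."
# is parsed into ((1, 64, 1024), (1, 64, 512)); anything that does not match the
# pattern returns None.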
def handle_coreml_shape_error(e, model_name=""):
"""Handle CoreML shape mismatch errors with detailed information.
Args:
e: The exception object
model_name: Name of the model for better error reporting
"""
error_str = str(e)
if "MultiArray shape" in error_str:
shape_info = parse_coreml_error(error_str)
if shape_info:
got_shape, expected_shape = shape_info
print(f"\n[ERROR] Shape mismatch in {model_name}:")
print(f" Got shape: {' x '.join(str(x) for x in got_shape)}")
print(f" Expected shape: {' x '.join(str(x) for x in expected_shape)}")
print("This usually indicates a mismatch between the model's expected context length")
print("and the actual input being provided.")
else:
print(f"\n[ERROR] Shape mismatch error in {model_name}:")
print(f" {error_str}")
else:
print(f"\n[ERROR] CoreML error in {model_name}:")
print(f" {error_str}")
def PreFillChunk(model_parts, input_ids, current_pos, context_size, causal_mask, batch_size=64):
tokens_to_process = current_pos
batch_pos = 0
while batch_pos < tokens_to_process:
batch_end = min(batch_pos + batch_size, tokens_to_process)
current_batch_size = batch_end - batch_pos
try:
# Get current batch of tokens
batch_input = input_ids[:, batch_pos:batch_end]
# Pad if needed
if current_batch_size < batch_size:
batch_input = F.pad(
batch_input,
(0, batch_size - current_batch_size),
value=0
)
# Generate position IDs for this batch
position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
# Prepare causal mask for this batch
multiple_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
# Find the correct model key for part 1 (with lut suffix if present)
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1')
try:
# Run embeddings (part 1)
hidden_states = model_parts[part1_key].predict({'input_ids': batch_input.numpy()})['hidden_states']
hidden_states = torch.from_numpy(hidden_states)
except Exception as e:
handle_coreml_shape_error(e, f"embeddings model (part {part1_key})")
raise
# Get shared transformer state
shared_state = model_parts['states']['transformer']
# Handle different model configurations
if any(f'{part}_prefill' in model_parts for part in ['2D1S', '2D2S']):
# S123 mode with prefill/infer functions
for part in ['2D1S', '2D2S']:
try:
inputs = {
'hidden_states': hidden_states.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': multiple_causal_mask.numpy(),
'start_pos': np.array([batch_pos], dtype=np.int32)
}
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['dummy_output'])
except Exception as e:
handle_coreml_shape_error(e, f"transformer model (part {part})")
raise
elif any(f'2Q{i}S_prefill' in model_parts for i in range(1, 5)):
# Q123S mode with combined quad split
for i in range(1, 5):
part = f'2Q{i}S'
try:
inputs = {
'hidden_states': hidden_states.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': multiple_causal_mask.numpy(),
'start_pos': np.array([batch_pos], dtype=np.int32)
}
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['dummy_output'])
except Exception as e:
handle_coreml_shape_error(e, f"transformer model (part {part})")
raise
elif any(part.startswith('2Q') for part in model_parts):
# Q123 mode with quad split
for i in range(1, 5):
part = f'2Q{i}'
if f'{part}_prefill' in model_parts:
# Use prefill function if available
try:
inputs = {
'hidden_states': hidden_states.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': multiple_causal_mask.numpy(),
'start_pos': np.array([batch_pos], dtype=np.int32)
}
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['dummy_output'])
except Exception as e:
handle_coreml_shape_error(e, f"transformer model (part {part})")
raise
else:
# Use regular predict if no prefill function
try:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': multiple_causal_mask.numpy(),
'current_pos': position_ids[0].numpy()
}
output = model_parts[part].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
except Exception as e:
handle_coreml_shape_error(e, f"transformer model (part {part})")
raise
elif any(key.startswith('2D') for key in model_parts.keys()):
# 123D mode with dual split (no prefill functions)
for base_part in ['2D1', '2D2']:
# Find the correct model key (with lut suffix if present)
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part)
try:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': multiple_causal_mask.numpy(),
'current_pos': position_ids[0].numpy()
}
output = model_parts[part_key].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
except Exception as e:
handle_coreml_shape_error(e, f"transformer model (part {part_key})")
raise
batch_pos = batch_end
except Exception as e:
print(f"\n[ERROR] Failed processing batch {batch_pos}-{batch_end}:")
print(f" {str(e)}")
raise
return torch.tensor([current_pos], dtype=torch.int32)
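# Batching sketch: with current_pos=150 and batch_size=64, PreFillChunk processes
# token windows [0, 64), [64, 128) and [128, 150); the final 22-token window is
# zero-padded to the fixed batch width of 64 before being fed to the CoreML models.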
def PreFillChunkOneByOne(model_parts, input_ids, current_pos, context_size, causal_mask):
"""Process prefill tokens one at a time using infer function."""
#print(f"[DEBUG] Starting one-by-one prefill for {current_pos} tokens")
for pos in range(current_pos):
# Get current token
current_token = input_ids[:, pos:pos+1]
single_causal_mask = causal_mask[:, :, pos:pos+1, :]
current_pos_tensor = torch.tensor([pos], dtype=torch.int32)
# Find the correct model key for part 1 (with lut suffix if present)
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1')
# Run embeddings (part 1)
hidden_states = torch.from_numpy(model_parts[part1_key].predict({
'input_ids': current_token.numpy()
})['hidden_states'])
#print(f"[DEBUG] pos: {pos} token: {current_token.item()} states: {hidden_states.shape}")
# Get shared transformer state
shared_state = model_parts['states']['transformer']
# Handle different model configurations
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']):
# S123 mode with prefill/infer functions
for part in ['2D1S', '2D2S']:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16),
'position_ids': current_pos_tensor.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': current_pos_tensor.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(key.startswith('2D') for key in model_parts.keys()):
# 123D mode or individual parts mode
for base_part in ['2D1', '2D2']:
# Find the correct model key (with lut suffix if present)
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part)
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16),
'position_ids': current_pos_tensor.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': current_pos_tensor.numpy()
}
output = model_parts[part_key].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
return torch.tensor([current_pos], dtype=torch.int32)
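# Note: unlike PreFillChunk above, this path feeds one token per predict() call via the
# infer-style inputs, so it avoids batch padding at the cost of one model call per token.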
def run_inference(model_parts, tokenizer, prompt, context_size=CONTEXT_LENGTH, num_iterations=5, temperature=0.0):
"""Run inference using model parts."""
DebugLog(f"\nPrompt: {prompt}")
if temperature > 0:
DebugLog(f"Using temperature: {temperature}")
# Prepare the prompt
messages = [{"role": "user", "content": prompt}]
formatted_input = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=False
)
decoded_input = tokenizer.decode(formatted_input[0])
DebugLog(f"Decoded input: {decoded_input}")
DebugLog(f"prompt: {prompt}")
DebugLog(f"formatted_input size: {formatted_input.size()}")
DebugLog(f"formatted_input: {formatted_input}")
base_input_ids = formatted_input.to(torch.int32)
context_pos = base_input_ids.size(1)
prompt_tokens = context_pos - 1
# Pad sequence to context_size
input_ids = F.pad(
base_input_ids,
(0, context_size - context_pos),
value=0
)
DebugLog(f"context_pos (prompt length) = {context_pos}")
# Create causal mask
causal_mask = make_causal_mask(context_size, 0)
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
# Prefill phase
DebugLog("\nStarting prefill...")
start_time = time.time()
# Check if we're using 123D mode or individual parts
use_single_token = any('2D' in key for key in model_parts.keys())
if False: #use_single_token:
print("\nRunning ST prefill...")
current_pos = PreFillChunkOneByOne(
model_parts,
input_ids,
context_pos - 1,
context_size,
causal_mask
)
sequential_prefill_time = time.time() - start_time
batch_prefill_time = 0.0
else:
print("\nRunning batch prefill...")
current_pos = PreFillChunk(
model_parts,
input_ids,
context_pos - 1,
context_size,
causal_mask,
batch_size=PREFILL_BATCH_SIZE
)
batch_prefill_time = time.time() - start_time
sequential_prefill_time = 0.0
# Initialize token printer
token_printer = TokenPrinter(tokenizer)
print("\nGenerated response:", end=' ', flush=True)
# Generation loop
start_gen_time = time.time()
pos = context_pos - 1
tokens_generated = 0
try:
DebugLog(f"\nStarting inference... context_pos: {context_pos}")
pos = context_pos
for step in range(num_iterations):
with torch.no_grad():
# Check if we need to shift cache
if pos >= context_size - 2:
shift_size = context_size // 4
new_size = context_size - shift_size
# Create shifted input_ids and preserve the most recent context
# Don't add BOS token since this is a continuation
tmp = torch.zeros((1, context_size), dtype=torch.int32)
tmp[:,0:new_size] = input_ids[:,shift_size:context_size]
input_ids = tmp
# Adjust position after shift
pos = new_size
# Create update mask for current position
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16)
update_mask[0, 0, pos-1, 0] = 1.0
#print(f"\n[DEBUG] Shifted cache by {shift_size} tokens, maintaining context window of {new_size} tokens, new pos: {pos}")
# For Q123 mode, we need to run prefill on the shifted sequence
if any(part.startswith('2Q') for part in model_parts):
# Run prefill using PreFillChunk with proper batch size
# No need to adjust position since we're not adding BOS
current_pos = PreFillChunk(
model_parts,
input_ids,
pos-1, # number of tokens to re-prefill after the shift
context_size, # Use full context size
causal_mask,
batch_size=PREFILL_BATCH_SIZE
)
#print(f"[DEBUG] Ran prefill after shift for position {pos} with batch_size={PREFILL_BATCH_SIZE}")
# Position should already be correct since we didn't add BOS
pos = current_pos
# Get current token
current_token = input_ids[:, pos-1:pos]
# Find the correct model key for part 1 (with lut suffix if present)
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1')
# Run embeddings (part 1)
hidden_states = model_parts[part1_key].predict({
'input_ids': current_token.numpy()
})['hidden_states']
hidden_states = torch.from_numpy(hidden_states)
# Get shared transformer state
shared_state = model_parts['states']['transformer']
# Create update mask for current position
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16)
update_mask[0, 0, pos-1, 0] = 1.0
# Create position IDs tensor
position_ids = torch.tensor([pos-1], dtype=torch.int32)
# Create causal mask for current position
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
# Run transformer layers based on model type
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']):
# S123 mode with prefill/infer functions
for part in ['2D1S', '2D2S']:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(f'2Q{i}S_infer' in model_parts for i in range(1, 5)):
# Q123S mode with combined quad split
for i in range(1, 5):
part = f'2Q{i}S'
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(part.startswith('2Q') for part in model_parts):
# Q123 mode with quad split
#print(f"[DEBUG] Running quad split inference at position {pos}")
for i in range(1, 5):
part = f'2Q{i}'
if f'{part}_infer' in model_parts:
# Use infer function if available
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
else:
# Use regular predict if no infer function
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[part].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(key.startswith('2D') for key in model_parts.keys()):
# 123D mode or individual parts mode
for base_part in ['2D1', '2D2']:
# Find the correct model key (with lut suffix if present)
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part)
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[part_key].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
else:
print("\n[ERROR] No transformer model parts found!")
break
try:
# Run final layer norm and get logits
# Find the correct model key for part 3 (with lut suffix if present)
part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3')
output_dict = model_parts[part3_key].predict({
'hidden_states': hidden_states.numpy()
})
if ENABLE_VOCAB_SPLIT8:
# Get all logits parts in a single call
logits_parts = []
for i in range(1, 9):
logits_parts.append(output_dict[f'logits{i}'])
logits = np.concatenate(logits_parts, axis=-1)
elif ENABLE_LOGITS2:
# Get both logits parts in a single call
logits = np.concatenate([
output_dict['logits1'],
output_dict['logits2']
], axis=-1)
else:
logits = output_dict['logits']
# Convert to tensor and get next token
logits = torch.from_numpy(logits)
# Apply temperature if specified
if temperature > 0:
# Scale logits by temperature
logits = logits / temperature
# Apply softmax to get probabilities
probs = F.softmax(logits[0, -1, :], dim=-1)
# Sample from the distribution
next_token = torch.multinomial(probs, num_samples=1).item()
else:
# Use argmax if no temperature
next_token = torch.argmax(logits[0, -1, :]).item()
# Add token to input sequence
input_ids[0, pos] = next_token
token_printer.add_token(next_token)
# Safely decode tokens in the main thread
token_printer.drain_buffer()
# Update position and count
pos += 1
tokens_generated += 1
if next_token == tokenizer.eos_token_id:
print("\n[DEBUG] Generated EOS token, stopping...")
break
except Exception as e:
print(f"\n[ERROR] Error in final layer or token generation: {str(e)}")
break
except KeyboardInterrupt:
print("\n[DEBUG] Interrupted by user")
except Exception as e:
print(f"\n[ERROR] Exception during inference: {str(e)}")
print(traceback.format_exc())
# Print timing statistics
end_time = time.time()
total_time = end_time - start_gen_time
print(f"\n\nTotal time: {total_time:.2f} seconds")
print(f"Generation tokens: {tokens_generated}")
print(f"Prefill tokens: {prompt_tokens}")
print(f"Total tokens (prefill + generation): {prompt_tokens + tokens_generated}")
if prompt_tokens > 0:
if batch_prefill_time > 0: # If using batch prefill
prefill_tokens_per_second = prompt_tokens / batch_prefill_time
effective_prefill_tokens_per_second = prompt_tokens / batch_prefill_time # Don't multiply by batch size
print(f"Actual prefill tokens per second: {prefill_tokens_per_second:.2f}")
print(f"Effective prefill tokens per second (batch={PREFILL_BATCH_SIZE}): {effective_prefill_tokens_per_second:.2f}")
elif sequential_prefill_time > 0: # If using sequential prefill
prefill_tokens_per_second = prompt_tokens / sequential_prefill_time
print(f"Sequential prefill tokens per second: {prefill_tokens_per_second:.2f}")
if tokens_generated > 0:
total_processing_time = total_time + (batch_prefill_time if batch_prefill_time > 0 else sequential_prefill_time)
overall_tokens_per_second = (prompt_tokens + tokens_generated) / total_processing_time
generation_tokens_per_second = tokens_generated / total_time
print(f"Overall tokens processed per second (including prefill): {overall_tokens_per_second:.2f}")
print(f"Generation-only tokens per second: {generation_tokens_per_second:.2f}")
return token_printer.stop(), {
'total_time': total_time,
'batch_prefill_time': batch_prefill_time,
'sequential_prefill_time': sequential_prefill_time,
'tokens_generated': tokens_generated,
'prompt_tokens': prompt_tokens
}
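# Usage sketch (assumed model_parts dict as built in main()):
#   response, stats = run_inference(model_parts, tokenizer, "What is the Apple Neural Engine?",
#                                   context_size=CONTEXT_LENGTH, num_iterations=100)
#   print(stats['tokens_generated'], stats['total_time'])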
def DebugLog(message, always_print=False):
"""Print debug message if ENABLE_CHAT_DEBUG is True or always_print is True.
Args:
message: Message to print
always_print: If True, print regardless of ENABLE_CHAT_DEBUG setting
"""
if ENABLE_CHAT_DEBUG or always_print:
print(f"[DEBUG] {message}")
def chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=0.0):
"""Interactive chat loop that maintains conversation history."""
print("\nStarting chat session. Press Ctrl+D to exit.")
print("Type your message and press Enter to chat.")
DebugLog(f"Using context size: {context_size}")
DebugLog(f"Temperature: {temperature}")
DebugLog(f"Model parts loaded: {list(model_parts.keys())}")
# Initialize conversation history
conversation = []
input_ids = None
current_pos = 0
try:
while True:
try:
print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
user_input = input().strip()
except EOFError:
print("\nExiting chat...")
break
if not user_input:
continue
# Add user message to conversation
conversation.append({"role": "user", "content": user_input})
DebugLog("\nFormatting conversation:")
for msg in conversation:
DebugLog(f" {msg['role']}: {msg['content'][:50]}...")
# Format entire conversation
formatted_input = tokenizer.apply_chat_template(
conversation,
return_tensors="pt",
add_generation_prompt=True
)
DebugLog("\nTokenization:")
DebugLog(f"Input token IDs: {formatted_input[0][:50]}...")
DebugLog(f"Decoded tokens: {tokenizer.decode(formatted_input[0][:50])}...")
DebugLog(f"Total tokens: {formatted_input.size(1)}")
# Convert to int32 tensor
base_input_ids = formatted_input.to(torch.int32)
context_pos = base_input_ids.size(1)
DebugLog(f"Context position: {context_pos}")
# Check if we need to truncate history
if context_pos >= context_size - 100:
DebugLog(f"\nNeed to truncate: {context_pos} tokens > {context_size-100} limit")
while context_pos >= context_size - 100 and len(conversation) > 2:
removed = conversation.pop(0)
DebugLog(f"Removed message: {removed['role']}: {removed['content'][:30]}...")
formatted_input = tokenizer.apply_chat_template(
conversation,
return_tensors="pt",
add_generation_prompt=True
)
base_input_ids = formatted_input.to(torch.int32)
context_pos = base_input_ids.size(1)
DebugLog(f"New context size: {context_pos}")
# Pad sequence to context_size
input_ids = F.pad(
base_input_ids,
(0, context_size - context_pos),
value=0
)
# Create causal mask for the entire context
causal_mask = make_causal_mask(context_size, 0)
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
DebugLog(f"Created causal mask with shape: {causal_mask.shape}")
print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
# Run prefill on entire context
if False: #any('2D' in key for key in model_parts.keys()):
DebugLog("Using sequential prefill")
current_pos = PreFillChunkOneByOne(
model_parts,
input_ids,
context_pos,
context_size,
causal_mask
)
elif any(part.startswith('2Q') for part in model_parts.keys()):
DebugLog(f"Using quad split prefill (size={PREFILL_BATCH_SIZE})")
current_pos = PreFillChunk(
model_parts,
input_ids,
context_pos,
context_size,
causal_mask,
batch_size=PREFILL_BATCH_SIZE
)
else:
DebugLog(f"Using standard batch prefill (size={PREFILL_BATCH_SIZE})")
current_pos = PreFillChunk(
model_parts,
input_ids,
context_pos,
context_size,
causal_mask,
batch_size=PREFILL_BATCH_SIZE
)
# Initialize token printer
token_printer = TokenPrinter(tokenizer)
# Generation loop
pos = context_pos
response_tokens = []
generation_start_time = time.time() # Add timing
try:
while True: # Changed from context_size - 1 to True for continuous generation
# Check if we need to shift window
if pos >= context_size - 2:
DebugLog("\nShifting context window...")
shift_size = context_size // 4 # Shift by 1/4 of context
new_size = context_size - shift_size
# Create shifted input_ids and preserve the most recent context
tmp = torch.zeros((1, context_size), dtype=torch.int32)
tmp[:,0:new_size] = input_ids[:,shift_size:context_size]
input_ids = tmp
# Adjust position after shift
pos = new_size
DebugLog(f"Shifted window by {shift_size} tokens, new position: {pos}")
# Run prefill on the shifted sequence
if False: #if any('2D' in key for key in model_parts.keys()):
DebugLog("Running sequential prefill after shift")
current_pos = PreFillChunkOneByOne(
model_parts,
input_ids,
pos,
context_size,
causal_mask
)
else:
DebugLog("Running batch prefill after shift (size={PREFILL_BATCH_SIZE})")
current_pos = PreFillChunk(
model_parts,
input_ids,
pos,
context_size,
causal_mask,
batch_size=PREFILL_BATCH_SIZE
)
# Get current token
current_token = input_ids[:, pos-1:pos]
# Find the correct model key for part 1
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1')
# Run embeddings (part 1)
hidden_states = model_parts[part1_key].predict({
'input_ids': current_token.numpy()
})['hidden_states']
hidden_states = torch.from_numpy(hidden_states)
# Get shared transformer state
shared_state = model_parts['states']['transformer']
# Create update mask for current position
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16)
update_mask[0, 0, pos-1, 0] = 1.0
# Create position IDs tensor
position_ids = torch.tensor([pos-1], dtype=torch.int32)
# Create causal mask for current position
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
# Run transformer layers based on model type
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']):
for part in ['2D1S', '2D2S']:
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(part.startswith('2Q') for part in model_parts.keys()):
DebugLog(f"Running quad split inference at position {pos}")
for i in range(1, 5):
part = f'2Q{i}S' if f'2Q{i}S_infer' in model_parts else f'2Q{i}'  # prefer combined (S) variant when loaded
if f'{part}_infer' in model_parts:
# Use infer function if available
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[f'{part}_infer'].predict(inputs, shared_state)
else:
# Use regular predict if no infer function
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[part].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
elif any(key.startswith('2D') for key in model_parts.keys()):
for base_part in ['2D1', '2D2']:
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part)
inputs = {
'hidden_states': hidden_states.numpy(),
'update_mask': update_mask.numpy(),
'position_ids': position_ids.numpy(),
'causal_mask': single_causal_mask.numpy(),
'current_pos': position_ids.numpy()
}
output = model_parts[part_key].predict(inputs, shared_state)
hidden_states = torch.from_numpy(output['transformer_output'])
# Run final layer norm and get logits
part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3')
output_dict = model_parts[part3_key].predict({
'hidden_states': hidden_states.numpy()
})
if ENABLE_VOCAB_SPLIT8:
logits_parts = []
for i in range(1, 9):
logits_parts.append(output_dict[f'logits{i}'])
logits = np.concatenate(logits_parts, axis=-1)
else:
logits = output_dict['logits']
# Convert to tensor and get next token
logits = torch.from_numpy(logits)
# Apply temperature if specified
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits[0, -1, :], dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
else:
next_token = torch.argmax(logits[0, -1, :]).item()
# Add token to input sequence and response
input_ids[0, pos] = next_token
response_tokens.append(next_token)
token_printer.add_token(next_token)
# Safely decode tokens in the main thread
token_printer.drain_buffer()
pos += 1
# Add debug output for generated tokens
if ENABLE_CHAT_DEBUG and len(response_tokens) > 0 and len(response_tokens) % 10 == 0:
DebugLog(f"\nGenerated {len(response_tokens)} tokens")
DebugLog(f"Last token: {next_token} -> '{tokenizer.decode([next_token])}'")
if next_token == tokenizer.eos_token_id:
DebugLog("\nGenerated EOS token")
break
# Get the complete response text and calculate stats
response_text = token_printer.stop()
generation_time = time.time() - generation_start_time
tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0
DebugLog(f"\nFinal response length: {len(response_tokens)} tokens")
# Print generation stats in dark blue
print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}")
except KeyboardInterrupt:
DebugLog("\nGeneration interrupted by user")
response_text = token_printer.stop()
generation_time = time.time() - generation_start_time
tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0
print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}")
# Add assistant's response to conversation history
conversation.append({"role": "assistant", "content": response_text})
except Exception as e:
print(f"\n[ERROR] Chat loop error: {str(e)}")
print(traceback.format_exc())
def main():
global CONTEXT_LENGTH, PREFILL_BATCH_SIZE, MODEL_PATH
print("ANEMLL Chat. Pre-relase alpha version, 2025-01-31")
print("Copyright (c) 2025, Anemll All rights reserved.")
# Set default parameters
model_type = "Q123" # Default model type
lut_suffix = "lut4" # Default LUT suffix
temperature = 0.0
model_parts = {}
model_path = "." # Default to current directory
if len(sys.argv) < 2:
print("Usage: python chat.py [model_parts] [options]")
print("Usage: python chat.py [model_parts] [options]")
print("\nOptions:")
print(" -d PATH # Model directory path (for both tokenizer and CoreML models)")
print(" S123 # Combined split model (2D1S+2D2S)")
print(" C123 # Combined part2 model with prefill/infer")
print(" Q123 # Quad split model (2Q1-2Q4) [default]")
print(" Q123S # Combined quad split model (2Q1S-2Q4S)")
print(" 1 2D1 2D2 3 # Individual split parts")
print(" pfN # Prefill batch size (e.g., pf128)")
print(" ctx=N # Context length (e.g., ctx=2048) [default: 1024]")
print(" temp=X # Temperature for sampling (e.g., temp=0.01)")
print(" lut4 # LUT suffix [default]")
print("\nDefault configuration: Q123 lut4 ctx=1024")
print(" python chat.py Q123 -d ../anemll-DeepSeek-8B-ctx1024")
# Use defaults instead of exiting
print("\nUsing default configuration...")
else:
# Process command line arguments
i = 1
while i < len(sys.argv):
if sys.argv[i] == '-d' and i + 1 < len(sys.argv):
model_path = sys.argv[i + 1]
i += 2
# Extract context length from model path if present
ctx_match = re.search(r'ctx(\d+)', model_path)
if ctx_match:
ctx_value = int(ctx_match.group(1))
if 512 <= ctx_value <= 4096*2:
CONTEXT_LENGTH = ctx_value
print(f"Setting context length to {CONTEXT_LENGTH} from model path")
continue
elif sys.argv[i].startswith('lut'):
lut_suffix = sys.argv[i]
elif sys.argv[i] in ['S123', 'Q123', 'Q123S', 'C123', '123D']:
model_type = sys.argv[i]
i += 1
# Initialize tokenizer using the same path
tokenizer = initialize_tokenizer(model_path)
if tokenizer is None:
print("[ERROR] Failed to initialize tokenizer. Exiting.")
return
# Process model parts
parts = [model_type]
if lut_suffix:
parts.append(lut_suffix)
try:
split_model = SplitModelInference(parts, model_dir=model_path)
model_parts.update(split_model.models)
model_parts['states'] = {'transformer': split_model.states['transformer']}
except Exception as e:
print(f"Error loading model parts: {str(e)}")
return
# Process remaining arguments
i = 1
while i < len(sys.argv):
arg = sys.argv[i]
if arg.startswith('pf') and arg[2:].isdigit():
PREFILL_BATCH_SIZE = int(arg[2:])
elif arg.startswith('ctx='):
try:
CONTEXT_LENGTH = int(arg.split('=')[1])
except (IndexError, ValueError):
print(f"[WARNING] Invalid context length format. Using default: {CONTEXT_LENGTH}")
elif arg.startswith('temp='):
try:
temperature = float(arg.split('=')[1])
if temperature < 0:
print(f"[WARNING] Temperature must be non-negative. Using default: 0.0")
temperature = 0.0
except (IndexError, ValueError):
print(f"[WARNING] Invalid temperature format. Using default: 0.0")
i += 1
try:
# Start interactive chat loop
chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=temperature)
except Exception as e:
print("An error occurred:")
print(traceback.format_exc())
if __name__ == "__main__":
main()