|
|
|
|
|
|
|
|
|
|
|
import coremltools as ct |
|
import numpy as np |
|
import torch |
|
from transformers import AutoTokenizer |
|
import os |
|
import time, sys |
|
import signal |
|
import traceback |
|
import torch.nn.functional as F |
|
import queue |
|
import threading |
|
import re |
|
|
|
|
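# Configuration flags used throughout this script:
# - CONTEXT_LENGTH / PREFILL_BATCH_SIZE set the causal-mask size and the chunk
#   size for batched prefill; both can be overridden on the command line
#   (ctx=N / pfN).
# - ENABLE_VACAB_SPLIT8 expects the LM head to emit its vocabulary logits in
#   eight chunks (logits1..logits8) that are concatenated before sampling;
#   ENABLE_LOGITS2 is the analogous two-way split.
# - ENABLE_CHAT_DEBUG gates DebugLog() output. ENABLE_ARGMAX is only checked
#   by the assertion below; ENABLE_DEBUG and ENABLE_PREFILL_BATCH are not
#   consulted elsewhere in this file.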
|
CONTEXT_LENGTH = 1024 |
|
PREFILL_BATCH_SIZE = 64 |
|
MODEL_PATH = os.path.expanduser("../DeepSeekR1-8B") |
|
ENABLE_VACAB_SPLIT8 = True |
|
ENABLE_LOGITS2 = False |
|
ENABLE_DEBUG = False

ENABLE_ARGMAX = False

ENABLE_PREFILL_BATCH = True

ENABLE_CHAT_DEBUG = False
|
|
|
|
|
LIGHT_BLUE = "\033[94m" |
|
DARK_BLUE = "\033[34m" |
|
LIGHT_GREEN = "\033[92m" |
|
RESET_COLOR = "\033[0m" |
|
|
|
if ENABLE_LOGITS2: |
|
assert not ENABLE_ARGMAX, "ENABLE_ARGMAX must be False when ENABLE_LOGITS2 is True" |
|
|
|
|
|
def load_model(path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name=None): |
|
"""Load either compiled or uncompiled CoreML model. |
|
|
|
Args: |
|
path: Path to the model file (.mlmodelc or .mlpackage) |
|
compute_unit: CoreML compute unit to use |
|
function_name: Optional function name to select from multi-function models |
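
Example (illustrative file name; actual part names depend on how the model was converted):
    model = load_model("llama32_part1.mlmodelc", compute_unit=ct.ComputeUnit.CPU_AND_NE)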
|
""" |
|
DebugLog(f"Attempting to load model: {path}") |
|
DebugLog(f"File exists: {os.path.exists(path)}") |
|
DebugLog(f"Is directory (for mlmodelc): {os.path.isdir(path)}") |
|
|
|
try: |
|
if path.endswith('.mlmodelc'): |
|
DebugLog(f"Loading compiled model: {path}") |
|
if function_name is None: |
|
DebugLog("Loading without function name") |
|
model = ct.models.CompiledMLModel(path, compute_unit) |
|
else: |
|
DebugLog(f"Loading with function name: {function_name}") |
|
model = ct.models.CompiledMLModel(path, compute_unit, function_name=function_name) |
|
else: |
|
DebugLog(f"Loading uncompiled model: {path}") |
|
if function_name is None: |
|
DebugLog("Loading without function name") |
|
model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False) |
|
else: |
|
DebugLog(f"Loading with function name: {function_name}") |
|
model = ct.models.MLModel(model=path, compute_units=compute_unit, is_temp_package=False, function_name=function_name) |
|
DebugLog("Model loaded successfully") |
|
|
|
return model |
|
|
|
except Exception as e: |
|
DebugLog(f"Error loading model: {str(e)}") |
|
DebugLog(f"Error type: {type(e)}") |
|
raise |
|
|
|
class SplitModelInference: |
|
def __init__(self, model_parts, model_dir="."): |
|
"""Initialize split model inference. |
|
|
|
Args: |
|
model_parts (list): List of model part numbers to load |
|
Special cases: |
|
- 'C123' for combined part2 with prefill/infer functions |
|
- 'S123' for split model with prefill/infer functions |
|
- 'Q123' for quad split (2Q1-2Q4) |
|
- 'Q123S' for quad split with combined prefill/infer (2Q1S-2Q4S) |
|
- '123D' for dual split without prefill/infer (2D1-2D2) |
|
model_dir (str): Directory containing the model files (default: current directory) |
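
Example (mirrors how main() builds the argument list; the directory is illustrative):
    runner = SplitModelInference(['Q123', 'lut4'], model_dir='../anemll-DeepSeek-8B-ctx1024')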
|
""" |
|
self.context_size = CONTEXT_LENGTH |
|
self.model_dir = model_dir |
|
DebugLog(f"Loading models from directory: {self.model_dir}") |
|
|
|
|
|
self.quant_configs = {} |
|
global_lut = None |
|
if model_parts and model_parts[-1].startswith('lut'): |
|
global_lut = model_parts[-1] |
|
model_parts = model_parts[:-1] |
|
|
|
|
|
if len(model_parts) == 1: |
|
if model_parts[0] == '123D': |
|
self.use_combined_part2 = False |
|
self.use_split_model = True |
|
self.use_split_functions = False |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = False |
|
self.model_parts = ['1', '2D1', '2D2', '3'] |
|
if global_lut: |
|
self.quant_configs = {part: global_lut for part in self.model_parts} |
|
DebugLog(f"Using dual split model with parts: {self.model_parts}") |
|
elif model_parts[0].startswith('C123'): |
|
self.use_combined_part2 = True |
|
self.use_split_model = False |
|
self.use_split_functions = False |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = False |
|
self.model_parts = ['1', '2', '3'] |
|
if global_lut: |
|
self.quant_configs = {part: global_lut for part in self.model_parts} |
|
DebugLog(f"Using combined part2 model with parts: {self.model_parts}") |
|
elif model_parts[0].startswith('S123'): |
|
self.use_combined_part2 = False |
|
self.use_split_model = True |
|
self.use_split_functions = True |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = False |
|
self.model_parts = ['1', '2D1S', '2D2S', '3'] |
|
elif model_parts[0].startswith('Q123S'): |
|
self.use_combined_part2 = False |
|
self.use_split_model = True |
|
self.use_split_functions = False |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = True |
|
self.model_parts = ['1', '2Q1S', '2Q2S', '2Q3S', '2Q4S', '3'] |
|
elif model_parts[0].startswith('Q123'): |
|
self.use_combined_part2 = False |
|
self.use_split_model = True |
|
self.use_split_functions = False |
|
self.use_quad_split = True |
|
self.use_quad_split_combined = False |
|
self.model_parts = ['1', '2Q1', '2Q2', '2Q3', '2Q4', '3'] |
|
else: |
|
self.use_combined_part2 = False |
|
self.use_split_model = False |
|
self.use_split_functions = False |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = False |
|
self.model_parts = model_parts |
|
else: |
|
self.use_combined_part2 = False |
|
self.use_split_model = False |
|
self.use_split_functions = False |
|
self.use_quad_split = False |
|
self.use_quad_split_combined = False |
|
self.model_parts = model_parts |
|
|
|
|
|
if global_lut and not self.use_combined_part2: |
|
self.quant_configs = {part: global_lut for part in self.model_parts} |
|
|
|
DebugLog(f"Using model parts: {self.model_parts}") |
|
if global_lut: |
|
DebugLog(f"With global quantization: {global_lut}") |
|
if self.use_combined_part2: |
|
DebugLog("Using combined part2 model with prefill/infer functions") |
|
elif self.use_split_functions: |
|
DebugLog("Using split model with prefill/infer functions") |
|
elif self.use_quad_split: |
|
DebugLog("Using quad split transformer model (2Q1-2Q4)") |
|
elif self.use_quad_split_combined: |
|
DebugLog("Using combined quad split transformer model (2Q1S-2Q4S)") |
|
|
|
self.models = {} |
|
self.states = {} |
|
self.load_models() |
|
|
|
def find_model_path(self, base_name, description="model"): |
|
"""Find model path, checking mlmodelc first then mlpackage. |
|
Also tries both with and without lut suffix. |
|
|
|
Args: |
|
base_name: Base name of the model without extension |
|
description: Description for error message (e.g., "Split model part 2D1S") |
|
|
|
Returns: |
|
str: Path to the found model file |
|
|
|
Raises: |
|
FileNotFoundError: If neither mlmodelc nor mlpackage exists |
|
""" |
|
|
|
if any(part in base_name for part in ['2Q1S', '2Q2S', '2Q3S', '2Q4S', '2Q1', '2Q2', '2Q3', '2Q4']): |
|
model_path = os.path.join(self.model_dir, f"{base_name}.mlmodelc") |
|
if os.path.exists(model_path): |
|
return model_path |
|
|
|
if '_lut' in base_name: |
|
base_without_lut = base_name.split('_lut')[0] |
|
model_path = os.path.join(self.model_dir, f"{base_without_lut}.mlmodelc") |
|
if os.path.exists(model_path): |
|
return model_path |
|
|
|
raise FileNotFoundError(f"{description} not found: {base_name}.mlmodelc does not exist" + |
|
(f" (also tried {base_name.split('_lut')[0]}.mlmodelc)" if '_lut' in base_name else "")) |
|
|
|
|
|
for ext in ['.mlmodelc', '.mlpackage']: |
|
model_path = os.path.join(self.model_dir, f"{base_name}{ext}") |
|
if os.path.exists(model_path): |
|
return model_path |
|
|
|
|
|
if '_lut' in base_name: |
|
base_without_lut = base_name.split('_lut')[0] |
|
for ext in ['.mlmodelc', '.mlpackage']: |
|
model_path = os.path.join(self.model_dir, f"{base_without_lut}{ext}") |
|
if os.path.exists(model_path): |
|
return model_path |
|
|
|
|
|
raise FileNotFoundError(f"{description} not found: neither {base_name}.mlmodelc nor {base_name}.mlpackage exist in {self.model_dir}" + |
|
(f" (also tried {base_name.split('_lut')[0]}.mlmodelc/mlpackage)" if '_lut' in base_name else "")) |
|
|
|
def load_models(self): |
|
"""Load each model part.""" |
|
DebugLog("Loading model parts...") |
|
|
|
for part in self.model_parts: |
|
quant_suffix = f"_{self.quant_configs[part]}" if part in self.quant_configs else "" |
|
model_key = f"{part}{quant_suffix}" |
|
|
|
try: |
|
if part == '2' and self.use_combined_part2: |
|
|
|
base_name = f"llama32_part2_combined{quant_suffix}" |
|
model_path = self.find_model_path(base_name, "Combined part2 model") |
|
|
|
DebugLog(f"Loading combined part2 model: {model_path}") |
|
|
|
self.models['2_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') |
|
|
|
self.models['2_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') |
|
|
|
self.states['transformer'] = self.models['2_prefill'].make_state() |
|
DebugLog("Combined part2 model loaded successfully") |
|
elif part == '2' and not self.use_combined_part2: |
|
|
|
base_name = f"llama32_part2{quant_suffix}" |
|
model_path = self.find_model_path(base_name, "Regular part2 model") |
|
|
|
DebugLog(f"Loading regular part2 model: {model_path}") |
|
self.models[model_key] = load_model(model_path) |
|
self.states['transformer'] = self.models[model_key].make_state() |
|
DebugLog("Regular part2 model loaded successfully") |
|
elif part in ['2D1S', '2D2S'] and self.use_split_functions: |
|
|
|
base_name = f"llama32_part{part}{quant_suffix}" |
|
model_path = self.find_model_path(base_name, f"Split model part {part}") |
|
|
|
DebugLog(f"Loading split model part {part}: {model_path}") |
|
|
|
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') |
|
|
|
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') |
|
|
|
if part == '2D1S': |
|
self.states['transformer'] = self.models[f'{part}_infer'].make_state() |
|
DebugLog(f"Split model part {part} loaded successfully") |
|
elif part.endswith('S') and self.use_quad_split_combined: |
|
|
|
base_name = f"llama32_part{part}{quant_suffix}" |
|
model_path = self.find_model_path(base_name, f"Combined quad split part {part}") |
|
|
|
DebugLog(f"Loading combined quad split part {part}: {model_path}") |
|
|
|
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') |
|
|
|
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') |
|
|
|
if part == '2Q1S': |
|
self.states['transformer'] = self.models[f'{part}_infer'].make_state() |
|
DebugLog(f"Created shared transformer state for all quad split parts") |
|
DebugLog(f"Combined quad split part {part} loaded successfully") |
|
elif part.startswith('2Q') and self.use_quad_split: |
|
|
|
|
|
base_name = f"llama32_part{part}S{quant_suffix}" |
|
model_path = self.find_model_path(base_name, f"Quad split part {part}") |
|
|
|
DebugLog(f"Loading quad split part {part}: {model_path}") |
|
|
|
self.models[f'{part}_prefill'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='prefill') |
|
|
|
self.models[f'{part}_infer'] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE, function_name='infer') |
|
|
|
if part == '2Q1': |
|
self.states['transformer'] = self.models[f'{part}_infer'].make_state() |
|
DebugLog(f"Created shared transformer state for all quad split parts") |
|
print(f"Created shared transformer state for all quad split parts") |
|
print(f"Quad split part {part} loaded successfully") |
|
else: |
|
|
|
base_name = f"llama32_part{part}{quant_suffix}" |
|
model_path = self.find_model_path(base_name, f"Regular part {part}") |
|
|
|
print(f"[MODEL LOAD] Regular part {part}:") |
|
print(f" - File: {model_path}") |
|
print(f" - Loading as: '{model_key}'") |
|
|
|
|
|
try: |
|
self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU_AND_NE) |
|
print(f" - Loaded with CPU_AND_NE compute unit") |
|
except Exception as cpu_error: |
|
print(f" - CPU load failed, trying CPU_AND_NE: {str(cpu_error)}") |
|
self.models[model_key] = load_model(model_path, compute_unit=ct.ComputeUnit.CPU) |
|
print(f" - Loaded with CPU compute unit") |
|
|
|
print(f"[MODEL LOAD] Current model_parts keys: {list(self.models.keys())}") |
|
|
|
except Exception as e: |
|
print(f"Error loading model part {part}: {str(e)}") |
|
raise |
|
|
|
def run_transformer_prefill(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): |
|
"""Run the transformer model in prefill mode.""" |
|
if self.use_split_functions: |
|
|
|
for part in ['2D1S', '2D2S']: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': causal_mask.numpy(), |
|
'start_pos': current_pos.numpy() |
|
} |
|
output = self.models[f'{part}_prefill'].predict(inputs, self.states['transformer']) |
|
hidden_states = torch.from_numpy(output['dummy_output']) |
|
return hidden_states |
|
else: |
|
|
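# Note: SplitModelInference has no parent class providing
# run_transformer_prefill, so reaching this fallback would raise
# AttributeError (the same applies to run_transformer_infer below).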
|
return super().run_transformer_prefill(hidden_states, update_mask, position_ids, causal_mask, current_pos) |
|
|
|
def run_transformer_infer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): |
|
"""Run the transformer model in infer mode.""" |
|
if self.use_split_functions: |
|
|
|
for part in ['2D1S', '2D2S']: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': causal_mask.numpy(), |
|
'current_pos': current_pos.numpy() |
|
} |
|
output = self.models[f'{part}_infer'].predict(inputs, self.states['transformer']) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
return hidden_states |
|
else: |
|
|
|
return super().run_transformer_infer(hidden_states, update_mask, position_ids, causal_mask, current_pos) |
|
|
|
def get_state(self, part): |
|
"""Get the appropriate state for a model part.""" |
|
return self.states['transformer'] |
|
|
|
def run_embeddings(self, input_ids): |
|
"""Run the embeddings model (part 1).""" |
|
if '1' not in self.models: |
|
raise ValueError("Embeddings model (part 1) not loaded") |
|
|
|
output_dict = self.models['1'].predict({ |
|
'input_ids': input_ids.numpy() |
|
}) |
|
return torch.from_numpy(output_dict['hidden_states']) |
|
|
|
def run_transformer(self, hidden_states, update_mask, position_ids, causal_mask, current_pos, part='2'): |
|
"""Run the transformer model.""" |
|
if part not in self.models: |
|
raise ValueError(f"Transformer model (part {part}) not loaded") |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': causal_mask.numpy(), |
|
'current_pos': current_pos.numpy() |
|
} |
|
|
|
output_dict = self.models[part].predict(inputs, self.get_state(part)) |
|
return torch.from_numpy(output_dict['transformer_output']) |
|
|
|
def run_transformer_splits(self, hidden_states, update_mask, position_ids, causal_mask, current_pos): |
|
"""Run through transformer splits based on model configuration.""" |
|
if not self.use_split_model: |
|
return self.run_transformer(hidden_states, update_mask, position_ids, causal_mask, current_pos) |
|
|
|
|
|
if any(part.startswith('2Q') for part in self.model_parts): |
|
for i in range(1, 5): |
|
part = f'2Q{i}' |
|
hidden_states = self.run_transformer( |
|
hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part |
|
) |
|
elif any(part.startswith('2O') for part in self.model_parts): |
|
for i in range(1, 9): |
|
part = f'2O{i}' |
|
hidden_states = self.run_transformer( |
|
hidden_states, update_mask, position_ids, causal_mask, current_pos, part=part |
|
) |
|
elif any(part.startswith('2D') for part in self.model_parts): |
|
|
|
for base_part in ['2D1', '2D2']: |
|
|
|
part_key = next(key for key in self.models.keys() if key.startswith(f'{base_part}_') or key == base_part) |
|
|
|
|
|
if 'transformer' not in self.states: |
|
raise ValueError("Transformer state not initialized. Make sure 2D1 is loaded first.") |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': causal_mask.numpy(), |
|
'current_pos': current_pos.numpy() |
|
} |
|
output_dict = self.models[part_key].predict(inputs, self.states['transformer']) |
|
hidden_states = torch.from_numpy(output_dict['transformer_output']) |
|
|
|
return hidden_states |
|
|
|
def run_lm_head(self, hidden_states): |
|
"""Run the LM head model (part 3).""" |
|
if '3' not in self.models: |
|
raise ValueError("LM head model (part 3) not loaded") |
|
|
|
output_dict = self.models['3'].predict({ |
|
'hidden_states': hidden_states.numpy() |
|
}) |
|
|
|
|
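# The LM head may emit its vocabulary logits in up to eight chunks
# (logits1..logits8); collect whichever chunks are present and concatenate
# them along the vocabulary (last) dimension.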
|
logits_parts = [] |
|
for i in range(1, 9): |
|
logits_key = f'logits{i}' |
|
if logits_key in output_dict: |
|
logits_part = torch.from_numpy(output_dict[logits_key]) |
|
logits_parts.append(logits_part) |
|
|
|
|
|
return torch.cat(logits_parts, dim=-1) |
|
|
|
def run_full_model(self, input_ids, update_mask, position_ids, causal_mask, current_pos): |
|
"""Run the full model.""" |
|
if 'full' not in self.models: |
|
raise ValueError("Full model not loaded") |
|
|
|
|
|
self.context_size = CONTEXT_LENGTH |
|
|
|
|
|
inputs = { |
|
'input_ids': input_ids.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': causal_mask.numpy(), |
|
'current_pos': current_pos.numpy() |
|
} |
|
|
|
|
|
if False: |
|
print("[DEBUG] Input shapes:") |
|
for key, value in inputs.items(): |
|
print(f" {key}: {value.shape}") |
|
|
|
output_dict = self.models['full'].predict(inputs, self.states['transformer']) |
|
|
|
|
|
if ENABLE_VACAB_SPLIT8: |
|
logits_parts = [] |
|
for i in range(1, 9): |
|
logits_parts.append(output_dict[f'logits{i}']) |
|
logits = np.concatenate(logits_parts, axis=-1) |
|
else: |
|
logits = output_dict['logits'] |
|
|
|
return torch.from_numpy(logits) |
|
|
|
def make_causal_mask(length, start): |
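"""Build a float16 additive causal mask of shape (1, 1, length, length).

Allowed positions hold 0 and masked positions hold -inf, so the mask can be
added directly to attention scores. `start` shifts the allowed region to the
right; e.g. make_causal_mask(3, 0)[0, 0] is:
    [[0, -inf, -inf],
     [0,    0, -inf],
     [0,    0,    0]]
"""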
|
|
|
|
|
mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16) |
|
|
|
|
|
row_indices = np.arange(length).reshape(length, 1) |
|
col_indices = np.arange(length).reshape(1, length) |
|
|
|
|
|
mask[:, :, col_indices <= (row_indices + start)] = 0 |
|
return mask |
|
|
|
def initialize_tokenizer(model_path): |
|
"""Initialize and configure the tokenizer.""" |
|
try: |
|
print(f"[DEBUG] Loading tokenizer from model path: {model_path}") |
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) |
|
|
|
print("\n[DEBUG] Tokenizer Configuration:") |
|
print(f"Tokenizer type: {type(tokenizer)}") |
|
print(f"Tokenizer name: {tokenizer.__class__.__name__}") |
|
print(f"Vocabulary size: {len(tokenizer)}") |
|
print(f"Model max length: {tokenizer.model_max_length}") |
|
|
|
|
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
print("[DEBUG] Set PAD token to EOS token") |
|
|
|
print(f"\n[DEBUG] Special Tokens:") |
|
print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})") |
|
print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})") |
|
print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})") |
|
print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})") |
|
|
|
return tokenizer |
|
|
|
except Exception as e: |
|
print(f"[ERROR] Failed to load tokenizer from {model_path}") |
|
return None |
|
|
|
class TokenPrinter: |
|
"""Handles background printing of generated tokens.""" |
|
def __init__(self, tokenizer): |
|
self.tokenizer = tokenizer |
|
self.token_queue = queue.Queue() |
|
self.stop_event = threading.Event() |
|
self.thread = None |
|
self.buffer = "" |
|
self.lock = threading.Lock() |
|
self.thinking = True |
|
self.decoding_buffer = [] |
|
self.start() |
|
|
|
def start(self): |
|
"""Start the printer thread.""" |
|
if self.thread is None: |
|
self.thread = threading.Thread(target=self._print_worker) |
|
self.thread.daemon = True |
|
self.thread.start() |
|
|
|
def add_token(self, token_id): |
|
"""Add a token to the print queue.""" |
|
if not self.stop_event.is_set(): |
|
self.token_queue.put(token_id) |
|
|
|
def drain_buffer(self): |
|
""" |
|
Decode token IDs from self.decoding_buffer in the main thread, |
|
then print them with the correct color logic. |
|
""" |
|
if not self.decoding_buffer: |
|
return |
|
|
|
|
|
token_str = self.tokenizer.decode(self.decoding_buffer) |
|
self.decoding_buffer.clear() |
|
|
|
|
|
if self.thinking and "</think>" in token_str: |
|
self.thinking = False |
|
parts = token_str.split("</think>") |
|
if len(parts) > 0: |
|
print(parts[0] + "</think>", end='', flush=True) |
|
if len(parts) > 1: |
|
print(LIGHT_BLUE + parts[1], end='', flush=True) |
|
else: |
|
if not self.thinking: |
|
print(LIGHT_BLUE + token_str, end='', flush=True) |
|
else: |
|
print(token_str, end='', flush=True) |
|
|
|
def _print_worker(self): |
|
"""Worker thread that takes token_ids from the queue but doesn't decode.""" |
|
while not self.stop_event.is_set(): |
|
try: |
|
token_id = self.token_queue.get(timeout=0.01) |
|
with self.lock: |
|
|
|
self.decoding_buffer.append(token_id) |
|
self.token_queue.task_done() |
|
except queue.Empty: |
|
continue |
|
except Exception as e: |
|
print(f"\n[ERROR] Token printer error: {str(e)}") |
|
break |
|
|
|
def stop(self): |
|
"""Stop the printer thread.""" |
|
if self.thread and self.thread.is_alive(): |
|
self.stop_event.set() |
|
try: |
|
self.thread.join(timeout=1.0) |
|
except Exception: |
|
pass |
|
print(RESET_COLOR) |
|
return self.buffer |
|
|
|
def parse_coreml_error(error_str): |
|
"""Parse CoreML error message to extract shape information. |
|
|
|
Args: |
|
error_str: The error message string from CoreML |
|
|
|
Returns: |
|
tuple: (got_shape, expected_shape) or None if parsing fails |
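
Example (illustrative message in the format matched by the regex below):
    parse_coreml_error("MultiArray shape (1 x 64 x 4096) does not match the shape (1 x 1 x 4096)")
    # -> ((1, 64, 4096), (1, 1, 4096))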
|
""" |
|
try: |
|
|
|
pattern = r"shape \(([\d\s x]+)\) does not match the shape \(([\d\s x]+)\)" |
|
match = re.search(pattern, str(error_str)) |
|
if match: |
|
got_shape = tuple(int(x) for x in match.group(1).split('x')) |
|
expected_shape = tuple(int(x) for x in match.group(2).split('x')) |
|
return got_shape, expected_shape |
|
return None |
|
except Exception as e: |
|
print(f"Error parsing CoreML error message: {e}") |
|
return None |
|
|
|
def handle_coreml_shape_error(e, model_name=""): |
|
"""Handle CoreML shape mismatch errors with detailed information. |
|
|
|
Args: |
|
e: The exception object |
|
model_name: Name of the model for better error reporting |
|
""" |
|
error_str = str(e) |
|
if "MultiArray shape" in error_str: |
|
shape_info = parse_coreml_error(error_str) |
|
if shape_info: |
|
got_shape, expected_shape = shape_info |
|
print(f"\n[ERROR] Shape mismatch in {model_name}:") |
|
print(f" Got shape: {' x '.join(str(x) for x in got_shape)}") |
|
print(f" Expected shape: {' x '.join(str(x) for x in expected_shape)}") |
|
print("This usually indicates a mismatch between the model's expected context length") |
|
print("and the actual input being provided.") |
|
else: |
|
print(f"\n[ERROR] Shape mismatch error in {model_name}:") |
|
print(f" {error_str}") |
|
else: |
|
print(f"\n[ERROR] CoreML error in {model_name}:") |
|
print(f" {error_str}") |
|
|
|
def PreFillChunk(model_parts, input_ids, current_pos, context_size, causal_mask, batch_size=64): |
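"""Run batched prefill over the first `current_pos` tokens of `input_ids`.

Tokens are processed in chunks of `batch_size` (the last chunk is zero-padded).
Each chunk goes through the embeddings model (part 1) and then through
whichever transformer prefill/infer parts are present in `model_parts`,
updating the shared KV-cache state stored under
model_parts['states']['transformer']. Returns `current_pos` as an int32
tensor so callers can continue generation from that position.
"""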
|
tokens_to_process = current_pos |
|
batch_pos = 0 |
|
|
|
while batch_pos < tokens_to_process: |
|
batch_end = min(batch_pos + batch_size, tokens_to_process) |
|
current_batch_size = batch_end - batch_pos |
|
|
|
try: |
|
|
|
batch_input = input_ids[:, batch_pos:batch_end] |
|
|
|
|
|
if current_batch_size < batch_size: |
|
batch_input = F.pad( |
|
batch_input, |
|
(0, batch_size - current_batch_size), |
|
value=0 |
|
) |
|
|
|
|
|
position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32) |
|
|
|
|
|
multiple_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :] |
|
|
|
|
|
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') |
|
|
|
try: |
|
|
|
hidden_states = model_parts[part1_key].predict({'input_ids': batch_input.numpy()})['hidden_states'] |
|
hidden_states = torch.from_numpy(hidden_states) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"embeddings model (part {part1_key})") |
|
raise |
|
|
|
|
|
shared_state = model_parts['states']['transformer'] |
|
|
|
|
|
if any(f'{part}_prefill' in model_parts for part in ['2D1S', '2D2S']): |
|
|
|
for part in ['2D1S', '2D2S']: |
|
try: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': multiple_causal_mask.numpy(), |
|
'start_pos': np.array([batch_pos], dtype=np.int32) |
|
} |
|
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['dummy_output']) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"transformer model (part {part})") |
|
raise |
|
elif any(f'2Q{i}S_prefill' in model_parts for i in range(1, 5)):  # combined quad split (2Q1S-2Q4S)
|
|
|
for i in range(1, 5): |
|
part = f'2Q{i}S' |
|
try: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': multiple_causal_mask.numpy(), |
|
'start_pos': np.array([batch_pos], dtype=np.int32) |
|
} |
|
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['dummy_output']) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"transformer model (part {part})") |
|
raise |
|
elif any(part.startswith('2Q') for part in model_parts): |
|
|
|
for i in range(1, 5): |
|
part = f'2Q{i}' |
|
if f'{part}_prefill' in model_parts: |
|
|
|
try: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': multiple_causal_mask.numpy(), |
|
'start_pos': np.array([batch_pos], dtype=np.int32) |
|
} |
|
output = model_parts[f'{part}_prefill'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['dummy_output']) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"transformer model (part {part})") |
|
raise |
|
else: |
|
|
|
try: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': multiple_causal_mask.numpy(), |
|
'current_pos': position_ids[0].numpy() |
|
} |
|
output = model_parts[part].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"transformer model (part {part})") |
|
raise |
|
elif any(key.startswith('2D') for key in model_parts.keys()): |
|
|
|
for base_part in ['2D1', '2D2']: |
|
|
|
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) |
|
try: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': torch.zeros((1, 1, context_size, 1), dtype=torch.float16).numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': multiple_causal_mask.numpy(), |
|
'current_pos': position_ids[0].numpy() |
|
} |
|
output = model_parts[part_key].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
except Exception as e: |
|
handle_coreml_shape_error(e, f"transformer model (part {part_key})") |
|
raise |
|
|
|
batch_pos = batch_end |
|
|
|
except Exception as e: |
|
print(f"\n[ERROR] Failed processing batch {batch_pos}-{batch_end}:") |
|
print(f" {str(e)}") |
|
raise |
|
|
|
return torch.tensor([current_pos], dtype=torch.int32) |
|
|
|
def PreFillChunkOneByOne(model_parts, input_ids, current_pos, context_size, causal_mask): |
|
"""Process prefill tokens one at a time using infer function.""" |
|
|
|
|
|
for pos in range(current_pos): |
|
|
|
current_token = input_ids[:, pos:pos+1] |
|
single_causal_mask = causal_mask[:, :, pos:pos+1, :] |
|
current_pos_tensor = torch.tensor([pos], dtype=torch.int32) |
|
|
|
|
|
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') |
|
|
|
|
|
hidden_states = torch.from_numpy(model_parts[part1_key].predict({ |
|
'input_ids': current_token.numpy() |
|
})['hidden_states']) |
|
|
|
|
|
|
|
|
|
shared_state = model_parts['states']['transformer'] |
|
|
|
|
|
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): |
|
|
|
for part in ['2D1S', '2D2S']: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16), |
|
'position_ids': current_pos_tensor.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': current_pos_tensor.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(key.startswith('2D') for key in model_parts.keys()): |
|
|
|
for base_part in ['2D1', '2D2']: |
|
|
|
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': np.zeros((1, 1, context_size, 1), dtype=np.float16), |
|
'position_ids': current_pos_tensor.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': current_pos_tensor.numpy() |
|
} |
|
output = model_parts[part_key].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
|
|
return torch.tensor([current_pos], dtype=torch.int32) |
|
|
|
def run_inference(model_parts, tokenizer, prompt, context_size=CONTEXT_LENGTH, num_iterations=5, temperature=0.0): |
|
"""Run inference using model parts.""" |
|
DebugLog(f"\nPrompt: {prompt}") |
|
if temperature > 0: |
|
DebugLog(f"Using temperature: {temperature}") |
|
|
|
|
|
messages = [{"role": "user", "content": prompt}] |
|
formatted_input = tokenizer.apply_chat_template( |
|
messages, |
|
return_tensors="pt", |
|
add_generation_prompt=False |
|
) |
|
decoded_input = tokenizer.decode(formatted_input[0]) |
|
DebugLog(f"Decoded input: {decoded_input}") |
|
DebugLog(f"prompt: {prompt}") |
|
DebugLog(f"formatted_input size: {formatted_input.size()}") |
|
DebugLog(f"formatted_input: {formatted_input}") |
|
|
|
base_input_ids = formatted_input.to(torch.int32) |
|
context_pos = base_input_ids.size(1) |
|
prompt_tokens = context_pos - 1 |
|
|
|
|
|
input_ids = F.pad( |
|
base_input_ids, |
|
(0, context_size - context_pos), |
|
value=0 |
|
) |
|
|
|
DebugLog(f"context_pos (prompt length) = {context_pos}") |
|
|
|
|
|
causal_mask = make_causal_mask(context_size, 0) |
|
causal_mask = torch.tensor(causal_mask, dtype=torch.float16) |
|
|
|
|
|
DebugLog("\nStarting prefill...") |
|
start_time = time.time() |
|
|
|
|
|
use_single_token = any('2D' in key for key in model_parts)
|
|
|
if False: |
|
print("\nRunning ST prefill...") |
|
current_pos = PreFillChunkOneByOne( |
|
model_parts, |
|
input_ids, |
|
context_pos - 1, |
|
context_size, |
|
causal_mask |
|
) |
|
sequential_prefill_time = time.time() - start_time |
|
batch_prefill_time = 0.0
|
else: |
|
print("\nRunning batch prefill...") |
|
current_pos = PreFillChunk( |
|
model_parts, |
|
input_ids, |
|
context_pos - 1, |
|
context_size, |
|
causal_mask, |
|
batch_size=PREFILL_BATCH_SIZE |
|
) |
|
batch_prefill_time = time.time() - start_time |
|
sequential_prefill_time = 0.0 |
|
|
|
|
|
token_printer = TokenPrinter(tokenizer) |
|
print("\nGenerated response:", end=' ', flush=True) |
|
|
|
|
|
start_gen_time = time.time() |
|
pos = context_pos - 1 |
|
|
|
tokens_generated = 0 |
|
try: |
|
DebugLog(f"\nStarting inference... context_pos: {context_pos}") |
|
pos = context_pos |
|
for step in range(num_iterations): |
|
with torch.no_grad(): |
|
|
|
if pos >= context_size - 2: |
|
shift_size = context_size // 4 |
|
new_size = context_size - shift_size |
|
|
|
|
|
|
|
tmp = torch.zeros((1, context_size), dtype=torch.int32) |
|
tmp[:,0:new_size] = input_ids[:,shift_size:context_size] |
|
input_ids = tmp |
|
|
|
|
|
pos = new_size |
|
|
|
|
|
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) |
|
update_mask[0, 0, pos-1, 0] = 1.0 |
|
|
|
|
|
|
|
|
|
if any(part.startswith('2Q') for part in model_parts): |
|
|
|
|
|
current_pos = PreFillChunk( |
|
model_parts, |
|
input_ids, |
|
pos-1, |
|
context_size, |
|
causal_mask, |
|
batch_size=PREFILL_BATCH_SIZE |
|
) |
|
|
|
|
|
pos = current_pos |
|
|
|
|
|
current_token = input_ids[:, pos-1:pos] |
|
|
|
|
|
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') |
|
|
|
|
|
hidden_states = model_parts[part1_key].predict({ |
|
'input_ids': current_token.numpy() |
|
})['hidden_states'] |
|
hidden_states = torch.from_numpy(hidden_states) |
|
|
|
|
|
shared_state = model_parts['states']['transformer'] |
|
|
|
|
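# update_mask presumably tells the stateful transformer which single KV-cache
# slot to write on this step: only position pos-1 is set to 1.0.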
|
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) |
|
update_mask[0, 0, pos-1, 0] = 1.0 |
|
|
|
|
|
position_ids = torch.tensor([pos-1], dtype=torch.int32) |
|
|
|
|
|
single_causal_mask = causal_mask[:, :, pos-1:pos, :] |
|
|
|
|
|
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): |
|
|
|
for part in ['2D1S', '2D2S']: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(f'2Q{i}S_infer' in model_parts for i in range(1, 5)):  # combined quad split (2Q1S-2Q4S); plain 2Q1-2Q4 is handled below
|
|
|
for i in range(1, 5): |
|
part = f'2Q{i}S' |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(part.startswith('2Q') for part in model_parts): |
|
|
|
|
|
for i in range(1, 5): |
|
part = f'2Q{i}' |
|
if f'{part}_infer' in model_parts: |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
else: |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[part].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(key.startswith('2D') for key in model_parts.keys()): |
|
|
|
for base_part in ['2D1', '2D2']: |
|
|
|
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[part_key].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
else: |
|
print("\n[ERROR] No transformer model parts found!") |
|
break |
|
|
|
try: |
|
|
|
|
|
part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3') |
|
output_dict = model_parts[part3_key].predict({ |
|
'hidden_states': hidden_states.numpy() |
|
}) |
|
|
|
if ENABLE_VACAB_SPLIT8: |
|
|
|
logits_parts = [] |
|
for i in range(1, 9): |
|
logits_parts.append(output_dict[f'logits{i}']) |
|
logits = np.concatenate(logits_parts, axis=-1) |
|
elif ENABLE_LOGITS2: |
|
|
|
logits = np.concatenate([ |
|
output_dict['logits1'], |
|
output_dict['logits2'] |
|
], axis=-1) |
|
else: |
|
logits = output_dict['logits'] |
|
|
|
|
|
logits = torch.from_numpy(logits) |
|
|
|
|
|
if temperature > 0: |
|
|
|
logits = logits / temperature |
|
|
|
probs = F.softmax(logits[0, -1, :], dim=-1) |
|
|
|
next_token = torch.multinomial(probs, num_samples=1).item() |
|
else: |
|
|
|
next_token = torch.argmax(logits[0, -1, :]).item() |
|
|
|
|
|
input_ids[0, pos] = next_token |
|
token_printer.add_token(next_token) |
|
|
|
|
|
token_printer.drain_buffer() |
|
|
|
|
|
pos += 1 |
|
tokens_generated += 1 |
|
|
|
if next_token == tokenizer.eos_token_id: |
|
print("\n[DEBUG] Generated EOS token, stopping...") |
|
break |
|
except Exception as e: |
|
print(f"\n[ERROR] Error in final layer or token generation: {str(e)}") |
|
break |
|
|
|
except KeyboardInterrupt: |
|
print("\n[DEBUG] Interrupted by user") |
|
except Exception as e: |
|
print(f"\n[ERROR] Exception during inference: {str(e)}") |
|
print(traceback.format_exc()) |
|
|
|
|
|
end_time = time.time() |
|
total_time = end_time - start_gen_time |
|
|
|
print(f"\n\nTotal time: {total_time:.2f} seconds") |
|
print(f"Generation tokens: {tokens_generated}") |
|
print(f"Prefill tokens: {prompt_tokens}") |
|
print(f"Total tokens (prefill + generation): {prompt_tokens + tokens_generated}") |
|
|
|
if prompt_tokens > 0: |
|
if batch_prefill_time > 0: |
|
prefill_tokens_per_second = prompt_tokens / batch_prefill_time |
|
effective_prefill_tokens_per_second = prompt_tokens / batch_prefill_time |
|
print(f"Actual prefill tokens per second: {prefill_tokens_per_second:.2f}") |
|
print(f"Effective prefill tokens per second (batch={PREFILL_BATCH_SIZE}): {effective_prefill_tokens_per_second:.2f}") |
|
elif sequential_prefill_time > 0: |
|
prefill_tokens_per_second = prompt_tokens / sequential_prefill_time |
|
print(f"Sequential prefill tokens per second: {prefill_tokens_per_second:.2f}") |
|
|
|
if tokens_generated > 0: |
|
total_processing_time = total_time + (batch_prefill_time if batch_prefill_time > 0 else sequential_prefill_time) |
|
overall_tokens_per_second = (prompt_tokens + tokens_generated) / total_processing_time |
|
generation_tokens_per_second = tokens_generated / total_time |
|
print(f"Overall tokens processed per second (including prefill): {overall_tokens_per_second:.2f}") |
|
print(f"Generation-only tokens per second: {generation_tokens_per_second:.2f}") |
|
|
|
return token_printer.stop(), { |
|
'total_time': total_time, |
|
'batch_prefill_time': batch_prefill_time, |
|
'sequential_prefill_time': sequential_prefill_time, |
|
'tokens_generated': tokens_generated, |
|
'prompt_tokens': prompt_tokens |
|
} |
|
|
|
def DebugLog(message, always_print=False): |
|
"""Print debug message if ENABLE_CHAT_DEBUG is True or always_print is True. |
|
|
|
Args: |
|
message: Message to print |
|
always_print: If True, print regardless of ENABLE_CHAT_DEBUG setting |
|
""" |
|
if ENABLE_CHAT_DEBUG or always_print: |
|
print(f"[DEBUG] {message}") |
|
|
|
def chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=0.0): |
|
"""Interactive chat loop that maintains conversation history.""" |
|
print("\nStarting chat session. Press Ctrl+D to exit.") |
|
print("Type your message and press Enter to chat.") |
|
|
|
DebugLog(f"Using context size: {context_size}") |
|
DebugLog(f"Temperature: {temperature}") |
|
DebugLog(f"Model parts loaded: {list(model_parts.keys())}") |
|
|
|
|
|
conversation = [] |
|
input_ids = None |
|
current_pos = 0 |
|
|
|
try: |
|
while True: |
|
try: |
|
print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True) |
|
user_input = input().strip() |
|
except EOFError: |
|
print("\nExiting chat...") |
|
break |
|
|
|
if not user_input: |
|
continue |
|
|
|
|
|
conversation.append({"role": "user", "content": user_input}) |
|
|
|
DebugLog("\nFormatting conversation:") |
|
for msg in conversation: |
|
DebugLog(f" {msg['role']}: {msg['content'][:50]}...") |
|
|
|
|
|
formatted_input = tokenizer.apply_chat_template( |
|
conversation, |
|
return_tensors="pt", |
|
add_generation_prompt=True |
|
) |
|
|
|
DebugLog("\nTokenization:") |
|
DebugLog(f"Input token IDs: {formatted_input[0][:50]}...") |
|
DebugLog(f"Decoded tokens: {tokenizer.decode(formatted_input[0][:50])}...") |
|
DebugLog(f"Total tokens: {formatted_input.size(1)}") |
|
|
|
|
|
base_input_ids = formatted_input.to(torch.int32) |
|
context_pos = base_input_ids.size(1) |
|
|
|
DebugLog(f"Context position: {context_pos}") |
|
|
|
|
|
if context_pos >= context_size - 100: |
|
DebugLog(f"\nNeed to truncate: {context_pos} tokens > {context_size-100} limit") |
|
while context_pos >= context_size - 100 and len(conversation) > 2: |
|
removed = conversation.pop(0) |
|
DebugLog(f"Removed message: {removed['role']}: {removed['content'][:30]}...") |
|
formatted_input = tokenizer.apply_chat_template( |
|
conversation, |
|
return_tensors="pt", |
|
add_generation_prompt=True |
|
) |
|
base_input_ids = formatted_input.to(torch.int32) |
|
context_pos = base_input_ids.size(1) |
|
DebugLog(f"New context size: {context_pos}") |
|
|
|
|
|
input_ids = F.pad( |
|
base_input_ids, |
|
(0, context_size - context_pos), |
|
value=0 |
|
) |
|
|
|
|
|
causal_mask = make_causal_mask(context_size, 0) |
|
causal_mask = torch.tensor(causal_mask, dtype=torch.float16) |
|
DebugLog(f"Created causal mask with shape: {causal_mask.shape}") |
|
|
|
print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True) |
|
|
|
|
|
if False: |
|
DebugLog("Using sequential prefill") |
|
current_pos = PreFillChunkOneByOne( |
|
model_parts, |
|
input_ids, |
|
context_pos, |
|
context_size, |
|
causal_mask |
|
) |
|
elif any(part.startswith('2Q') for part in model_parts.keys()): |
|
DebugLog(f"Using quad split prefill (size={PREFILL_BATCH_SIZE})") |
|
current_pos = PreFillChunk( |
|
model_parts, |
|
input_ids, |
|
context_pos, |
|
context_size, |
|
causal_mask, |
|
batch_size=PREFILL_BATCH_SIZE |
|
) |
|
else: |
|
DebugLog(f"Using standard batch prefill (size={PREFILL_BATCH_SIZE})") |
|
current_pos = PreFillChunk( |
|
model_parts, |
|
input_ids, |
|
context_pos, |
|
context_size, |
|
causal_mask, |
|
batch_size=PREFILL_BATCH_SIZE |
|
) |
|
|
|
|
|
token_printer = TokenPrinter(tokenizer) |
|
|
|
|
|
pos = context_pos |
|
response_tokens = [] |
|
generation_start_time = time.time() |
|
|
|
try: |
|
while True: |
|
|
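# When the window is nearly full, drop the oldest quarter of the context,
# slide the remaining tokens to the front, and re-run prefill so the shared
# KV-cache state matches the shifted token layout.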
|
if pos >= context_size - 2: |
|
DebugLog("\nShifting context window...") |
|
|
|
shift_size = context_size // 4 |
|
new_size = context_size - shift_size |
|
|
|
|
|
tmp = torch.zeros((1, context_size), dtype=torch.int32) |
|
tmp[:,0:new_size] = input_ids[:,shift_size:context_size] |
|
input_ids = tmp |
|
|
|
|
|
pos = new_size |
|
|
|
DebugLog(f"Shifted window by {shift_size} tokens, new position: {pos}") |
|
|
|
|
|
if False: |
|
DebugLog("Running sequential prefill after shift") |
|
current_pos = PreFillChunkOneByOne( |
|
model_parts, |
|
input_ids, |
|
pos, |
|
context_size, |
|
causal_mask |
|
) |
|
else: |
|
DebugLog("Running batch prefill after shift (size={PREFILL_BATCH_SIZE})") |
|
current_pos = PreFillChunk( |
|
model_parts, |
|
input_ids, |
|
pos, |
|
context_size, |
|
causal_mask, |
|
batch_size=PREFILL_BATCH_SIZE |
|
) |
|
|
|
|
|
current_token = input_ids[:, pos-1:pos] |
|
|
|
|
|
part1_key = next(key for key in model_parts.keys() if key.startswith('1_') or key == '1') |
|
|
|
|
|
hidden_states = model_parts[part1_key].predict({ |
|
'input_ids': current_token.numpy() |
|
})['hidden_states'] |
|
hidden_states = torch.from_numpy(hidden_states) |
|
|
|
|
|
shared_state = model_parts['states']['transformer'] |
|
|
|
|
|
update_mask = torch.zeros((1, 1, context_size, 1), dtype=torch.float16) |
|
update_mask[0, 0, pos-1, 0] = 1.0 |
|
|
|
|
|
position_ids = torch.tensor([pos-1], dtype=torch.int32) |
|
|
|
|
|
single_causal_mask = causal_mask[:, :, pos-1:pos, :] |
|
|
|
|
|
if any(f'{part}_infer' in model_parts for part in ['2D1S', '2D2S']): |
|
for part in ['2D1S', '2D2S']: |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(part.startswith('2Q') for part in model_parts.keys()): |
|
DebugLog(f"Running quad split inference at position {pos}") |
|
for i in range(1, 5): |
|
part = f'2Q{i}S' if f'2Q{i}S_infer' in model_parts else f'2Q{i}'  # handle both 2Q1-2Q4 and combined 2Q1S-2Q4S key layouts
|
if f'{part}_infer' in model_parts: |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[f'{part}_infer'].predict(inputs, shared_state) |
|
else: |
|
|
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[part].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
elif any(key.startswith('2D') for key in model_parts.keys()): |
|
for base_part in ['2D1', '2D2']: |
|
part_key = next(key for key in model_parts.keys() if key.startswith(f'{base_part}_') or key == base_part) |
|
inputs = { |
|
'hidden_states': hidden_states.numpy(), |
|
'update_mask': update_mask.numpy(), |
|
'position_ids': position_ids.numpy(), |
|
'causal_mask': single_causal_mask.numpy(), |
|
'current_pos': position_ids.numpy() |
|
} |
|
output = model_parts[part_key].predict(inputs, shared_state) |
|
hidden_states = torch.from_numpy(output['transformer_output']) |
|
|
|
|
|
part3_key = next(key for key in model_parts.keys() if key.startswith('3_') or key == '3') |
|
output_dict = model_parts[part3_key].predict({ |
|
'hidden_states': hidden_states.numpy() |
|
}) |
|
|
|
if ENABLE_VACAB_SPLIT8: |
|
logits_parts = [] |
|
for i in range(1, 9): |
|
logits_parts.append(output_dict[f'logits{i}']) |
|
logits = np.concatenate(logits_parts, axis=-1) |
|
else: |
|
logits = output_dict['logits'] |
|
|
|
|
|
logits = torch.from_numpy(logits) |
|
|
|
|
|
if temperature > 0: |
|
logits = logits / temperature |
|
probs = F.softmax(logits[0, -1, :], dim=-1) |
|
next_token = torch.multinomial(probs, num_samples=1).item() |
|
else: |
|
next_token = torch.argmax(logits[0, -1, :]).item() |
|
|
|
|
|
input_ids[0, pos] = next_token |
|
response_tokens.append(next_token) |
|
token_printer.add_token(next_token) |
|
|
|
|
|
token_printer.drain_buffer() |
|
|
|
pos += 1 |
|
|
|
|
|
if ENABLE_CHAT_DEBUG and len(response_tokens) > 0 and len(response_tokens) % 10 == 0: |
|
DebugLog(f"\nGenerated {len(response_tokens)} tokens") |
|
DebugLog(f"Last token: {next_token} -> '{tokenizer.decode([next_token])}'") |
|
|
|
if next_token == tokenizer.eos_token_id: |
|
DebugLog("\nGenerated EOS token") |
|
break |
|
|
|
|
|
response_text = token_printer.stop() |
|
generation_time = time.time() - generation_start_time |
|
tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0 |
|
|
|
DebugLog(f"\nFinal response length: {len(response_tokens)} tokens") |
|
|
|
|
|
print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}") |
|
|
|
except KeyboardInterrupt: |
|
DebugLog("\nGeneration interrupted by user") |
|
response_text = token_printer.stop() |
|
generation_time = time.time() - generation_start_time |
|
tokens_per_second = len(response_tokens) / generation_time if generation_time > 0 else 0 |
|
print(f"\n{DARK_BLUE}[{len(response_tokens)} tokens, {tokens_per_second:.1f} tokens/s]{RESET_COLOR}") |
|
|
|
|
|
conversation.append({"role": "assistant", "content": response_text}) |
|
|
|
except Exception as e: |
|
print(f"\n[ERROR] Chat loop error: {str(e)}") |
|
print(traceback.format_exc()) |
|
|
|
def main(): |
|
global CONTEXT_LENGTH, PREFILL_BATCH_SIZE, MODEL_PATH |
|
|
|
print("ANEMLL Chat. Pre-relase alpha version, 2025-01-31") |
|
print("Copyright (c) 2025, Anemll All rights reserved.") |
|
|
|
model_type = "Q123" |
|
lut_suffix = "lut4" |
|
temperature = 0.0 |
|
model_parts = {} |
|
model_path = "." |
|
|
|
if len(sys.argv) < 2: |
|
print("Usage: python chat.py [model_parts] [options]") |
|
print("Usage: python chat.py [model_parts] [options]") |
|
print("\nOptions:") |
|
print(" -d PATH # Model directory path (for both tokenizer and CoreML models)") |
|
print(" S123 # Combined split model (2D1S+2D2S)") |
|
print(" C123 # Combined part2 model with prefill/infer") |
|
print(" Q123 # Quad split model (2Q1-2Q4) [default]") |
|
print(" Q123S # Combined quad split model (2Q1S-2Q4S)") |
|
print(" 1 2D1 2D2 3 # Individual split parts") |
|
print(" pfN # Prefill batch size (e.g., pf128)") |
|
print(" ctx=N # Context length (e.g., ctx=2048) [default: 1024]") |
|
print(" temp=X # Temperature for sampling (e.g., temp=0.01)") |
|
print(" lut4 # LUT suffix [default]") |
|
print("\nDefault configuration: Q123 lut4 ctx=1024") |
|
print(" python chat.py Q123 -d ../anemll-DeepSeek-8B-ctx1024") |
|
|
|
print("\nUsing default configuration...") |
|
else: |
|
|
|
i = 1 |
|
while i < len(sys.argv): |
|
if sys.argv[i] == '-d' and i + 1 < len(sys.argv): |
|
model_path = sys.argv[i + 1] |
|
i += 2 |
|
|
|
ctx_match = re.search(r'ctx(\d+)', model_path) |
|
if ctx_match: |
|
ctx_value = int(ctx_match.group(1)) |
|
if 512 <= ctx_value <= 4096*2: |
|
CONTEXT_LENGTH = ctx_value |
|
print(f"Setting context length to {CONTEXT_LENGTH} from model path") |
|
continue |
|
elif sys.argv[i].startswith('lut'): |
|
lut_suffix = sys.argv[i] |
|
elif sys.argv[i] in ['S123', 'Q123', 'Q123S', 'C123', '123D']: |
|
model_type = sys.argv[i] |
|
i += 1 |
|
|
|
|
|
tokenizer = initialize_tokenizer(model_path) |
|
if tokenizer is None: |
|
print("[ERROR] Failed to initialize tokenizer. Exiting.") |
|
return |
|
|
|
|
|
parts = [model_type] |
|
if lut_suffix: |
|
parts.append(lut_suffix) |
|
|
|
try: |
|
split_model = SplitModelInference(parts, model_dir=model_path) |
|
model_parts.update(split_model.models) |
|
model_parts['states'] = {'transformer': split_model.states['transformer']} |
|
except Exception as e: |
|
print(f"Error loading model parts: {str(e)}") |
|
return |
|
|
|
|
|
i = 1 |
|
while i < len(sys.argv): |
|
arg = sys.argv[i] |
|
if arg.startswith('pf') and arg[2:].isdigit(): |
|
PREFILL_BATCH_SIZE = int(arg[2:]) |
|
elif arg.startswith('ctx='): |
|
try: |
|
CONTEXT_LENGTH = int(arg.split('=')[1]) |
|
except (IndexError, ValueError): |
|
print(f"[WARNING] Invalid context length format. Using default: {CONTEXT_LENGTH}") |
|
elif arg.startswith('temp='): |
|
try: |
|
temperature = float(arg.split('=')[1]) |
|
if temperature < 0: |
|
print(f"[WARNING] Temperature must be non-negative. Using default: 0.0") |
|
temperature = 0.0 |
|
except (IndexError, ValueError): |
|
print(f"[WARNING] Invalid temperature format. Using default: 0.0") |
|
i += 1 |
|
|
|
try: |
|
|
|
chat_loop(model_parts, tokenizer, context_size=CONTEXT_LENGTH, temperature=temperature) |
|
except Exception as e: |
|
print("An error occurred:") |
|
print(traceback.format_exc()) |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|