import math from contextlib import contextmanager from dataclasses import dataclass from typing import List, Callable import safetensors import torch from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights @dataclass class VisionBlock: ln1: LayerNormWeights attn: AttentionWeights ln2: LayerNormWeights mlp: MLPWeights @dataclass class VisionModel: patch_size: int patch_emb: LinearWeights pos_emb: torch.Tensor blocks: List[VisionBlock] post_ln: LayerNormWeights proj_mlp: MLPWeights @dataclass class TextBlock: ln: LayerNormWeights attn: AttentionWeights mlp: MLPWeights @dataclass class TextModel: wte: torch.Tensor blocks: List[TextBlock] post_ln: LayerNormWeights lm_head: LinearWeights @dataclass class MoondreamModel: vision: VisionModel text: TextModel @contextmanager def safetensors_open(safetensors_file: str): """ Simplify interfacing with safetensors files. Eliminates the need to ignore type errors when using the `safe_open` function. """ with safetensors.safe_open( safetensors_file, framework="pt" ) as st: # pyright: ignore def get_tensor(name: str) -> torch.Tensor: return st.get_tensor(name) yield get_tensor def load_model( get_tensor: Callable[[str], torch.Tensor], vision_blocks: int = 27, text_blocks: int = 24, vision_n_heads: int = 16, text_n_heads: int = 32, ) -> MoondreamModel: ## Vision encoder prefix = "vision_encoder.encoder.model.visual.patch_embed.linear" patch_emb = LinearWeights( weight=get_tensor(f"{prefix}.weight"), bias=get_tensor(f"{prefix}.bias") ) patch_size = int(math.sqrt(patch_emb.weight.shape[1] // 3)) pos_emb = get_tensor("vision_encoder.encoder.model.visual.pos_embed") post_ln = LayerNormWeights( weight=get_tensor("vision_encoder.encoder.model.visual.norm.weight"), bias=get_tensor("vision_encoder.encoder.model.visual.norm.bias"), ) blocks = [] for i in range(vision_blocks): prefix = f"vision_encoder.encoder.model.visual.blocks.{i}" blocks.append( VisionBlock( ln1=LayerNormWeights( weight=get_tensor(f"{prefix}.norm1.weight"), bias=get_tensor(f"{prefix}.norm1.bias"), ), attn=AttentionWeights( qkv=LinearWeights( weight=get_tensor(f"{prefix}.attn.qkv.weight"), bias=get_tensor(f"{prefix}.attn.qkv.bias"), ), proj=LinearWeights( weight=get_tensor(f"{prefix}.attn.proj.weight"), bias=get_tensor(f"{prefix}.attn.proj.bias"), ), n_heads=vision_n_heads, ), ln2=LayerNormWeights( weight=get_tensor(f"{prefix}.norm2.weight"), bias=get_tensor(f"{prefix}.norm2.bias"), ), mlp=MLPWeights( fc1=LinearWeights( weight=get_tensor(f"{prefix}.mlp.fc1.weight"), bias=get_tensor(f"{prefix}.mlp.fc1.bias"), ), fc2=LinearWeights( weight=get_tensor(f"{prefix}.mlp.fc2.weight"), bias=get_tensor(f"{prefix}.mlp.fc2.bias"), ), ), ) ) proj_mlp = MLPWeights( fc1=LinearWeights( weight=get_tensor("vision_encoder.projection.mlp.fc1.weight"), bias=get_tensor("vision_encoder.projection.mlp.fc1.bias"), ), fc2=LinearWeights( weight=get_tensor("vision_encoder.projection.mlp.fc2.weight"), bias=get_tensor("vision_encoder.projection.mlp.fc2.bias"), ), act="gelu_approx", ) vision = VisionModel( patch_size=patch_size, patch_emb=patch_emb, pos_emb=pos_emb, blocks=blocks, post_ln=post_ln, proj_mlp=proj_mlp, ) ## Text decoder model wte = get_tensor("text_model.transformer.embd.wte.weight") post_ln = LayerNormWeights( weight=get_tensor("text_model.lm_head.ln.weight"), bias=get_tensor("text_model.lm_head.ln.bias"), ) lm_head = LinearWeights( weight=get_tensor("text_model.lm_head.linear.weight"), bias=get_tensor("text_model.lm_head.linear.bias"), ) blocks = [] for i in range(text_blocks): prefix = f"text_model.transformer.h.{i}" blocks.append( TextBlock( ln=LayerNormWeights( weight=get_tensor(f"{prefix}.ln.weight"), bias=get_tensor(f"{prefix}.ln.bias"), ), attn=AttentionWeights( qkv=LinearWeights( weight=get_tensor(f"{prefix}.mixer.Wqkv.weight"), bias=get_tensor(f"{prefix}.mixer.Wqkv.bias"), ), proj=LinearWeights( weight=get_tensor(f"{prefix}.mixer.out_proj.weight"), bias=get_tensor(f"{prefix}.mixer.out_proj.bias"), ), n_heads=text_n_heads, ), mlp=MLPWeights( fc1=LinearWeights( weight=get_tensor(f"{prefix}.mlp.fc1.weight"), bias=get_tensor(f"{prefix}.mlp.fc1.bias"), ), fc2=LinearWeights( weight=get_tensor(f"{prefix}.mlp.fc2.weight"), bias=get_tensor(f"{prefix}.mlp.fc2.bias"), ), act="gelu_approx", ), ) ) text = TextModel(wte=wte, blocks=blocks, post_ln=post_ln, lm_head=lm_head) return MoondreamModel(vision=vision, text=text) def load_from_safetensors( safetensors_file: str, vision_blocks: int = 27, text_blocks: int = 24, **kwargs, ) -> MoondreamModel: with safetensors_open(safetensors_file) as get_tensor: return load_model(get_tensor, vision_blocks, text_blocks, **kwargs) def load_from_pt( pt_file: str, vision_blocks: int = 27, text_blocks: int = 24, **kwargs, ) -> MoondreamModel: device = str(torch.empty(0).device) tensors = torch.load(pt_file, map_location=device, weights_only=True) tensors = { k.replace("._orig_mod", ""): v.to(dtype=torch.float16) for k, v in tensors.items() } return load_model(lambda x: tensors[x], vision_blocks, text_blocks, **kwargs) if __name__ == "__main__": weights = load_from_safetensors("model.safetensors") print(weights)