import argparse
import json
import os

import torch
from PIL import Image
from transformers import AutoTokenizer

from .rope import precompute_freqs_cis
from .text import lm_head, text_decoder, text_encoder
from .vision import encode_image
from .weights import load_from_pt, load_from_safetensors

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", "-i", type=str, required=True)
    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--model", "-m", type=str, required=True)
    parser.add_argument("--config", "-c", type=str, default="{}")
    parser.add_argument("--max-tokens", "-t", type=int, default=200)
    parser.add_argument("--sampler", "-s", type=str, default="greedy")
    args = parser.parse_args()

    # Prefer CUDA, then MPS, falling back to the CPU default.
    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    # Load config.
    config = json.loads(args.config)
    text_n_heads = config.get("text_n_heads", 32)

    # Load model.
    model_path = args.model
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if model_path.endswith(".pt"):
        model = load_from_pt(model_path, **config)
    elif model_path.endswith(".safetensors"):
        model = load_from_safetensors(model_path, **config)
    else:
        raise ValueError(f"Invalid model format: {model_path}")

    # Encode image.
    image_path = args.image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    image = Image.open(image_path)
    image = image.resize((378, 378))
    image_tensor = encode_image(image, model.vision)

    # Encode text, and create inputs_embeds.
    tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
    prompt = f"\n\nQuestion: {args.prompt}\n\nAnswer:"
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # Prepend the EOS token, which doubles as the BOS marker here.
    input_ids = torch.cat(
        [torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1
    )
    inputs_embeds = text_encoder(input_ids, model.text)
    # Splice the image embeddings in directly after the BOS embedding.
    inputs_embeds = torch.cat(
        [
            inputs_embeds[:, 0:1, :],
            image_tensor.unsqueeze(0),
            inputs_embeds[:, 1:, :],
        ],
        dim=1,
    )

    # Preallocate the KV cache: (n_layers, k/v, batch, n_heads, seq_len, head_dim).
    kv_cache = torch.empty(24, 2, 1, text_n_heads, 2048, 64, dtype=torch.float16)
    freqs_cis = precompute_freqs_cis(32, 2048)  # Rotary embedding tables.
    pos = 0

    for _ in range(args.max_tokens):
        with torch.no_grad():
            # Decode against the populated prefix of the KV cache, then
            # write the returned keys/values back for the new positions.
            hidden, kv_cache_update = text_decoder(
                inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis
            )
            logits = lm_head(hidden, model.text)
            kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
                kv_cache_update
            )
            pos += kv_cache_update.size(-2)

            if args.sampler == "multinomial":
                next_token = torch.multinomial(
                    torch.softmax(logits, dim=-1), num_samples=1
                ).squeeze(0)
            elif args.sampler == "greedy":
                next_token = torch.argmax(logits, dim=-1)
            else:
                raise ValueError(f"Invalid sampler: {args.sampler}")

            if next_token == tokenizer.eos_token_id:
                print()
                break

            # Feed only the newly sampled token back in; earlier positions
            # are already covered by the KV cache.
            input_ids = next_token.unsqueeze(0)
            inputs_embeds = text_encoder(input_ids, model.text)

            output_text = tokenizer.batch_decode(input_ids)[0]
            print(output_text, end="", flush=True)
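
# Example invocation (a sketch: the relative imports above mean this file must
# be run as a module, and the package path `moondream.torch.sample` is an
# assumption about the repo layout, as are the image and weight file names):
#
#   python -m moondream.torch.sample \
#       --image demo.jpg \
#       --prompt "What is in this image?" \
#       --model model.safetensors \
#       --config '{"text_n_heads": 32}' \
#       --sampler greedy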