"""Command-line text generation from a trained ``LlamaModel`` checkpoint.

The model architecture and tokenizer are described by a YAML config file; the
weights come from a checkpoint file containing a ``'model_state_dict'`` entry.
"""
import argparse

import torch
import yaml
from transformers import AutoTokenizer

from SmolLm3 import LlamaModel
from train import generate


def get_config(config_path):
    """Load and return the parsed YAML config stored at ``config_path``."""
    with open(config_path, "r", encoding="utf-8") as f:
        # FullLoader is not the loader recommended for untrusted input
        # (yaml.safe_load is); only point this at config files you trust.
        return yaml.load(f, Loader=yaml.FullLoader)


def load_model_from_checkpoint(config_path, checkpoint_path, device):
    """Build a ``LlamaModel`` from the config and load checkpoint weights into it.

    Args:
        config_path: Path to the YAML config; its ``'model'`` section is passed
            to ``LlamaModel``.
        checkpoint_path: Path to a checkpoint loadable by ``torch.load`` whose
            ``'model_state_dict'`` entry holds the model weights.
        device: Device string (e.g. ``"cuda"`` or ``"cpu"``) used both to map
            the checkpoint tensors and to place the model.

    Returns:
        The model with weights loaded, moved to ``device`` and in eval mode.
    """
    config = get_config(config_path)
    model = LlamaModel(config['model'])
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    state_dict = checkpoint['model_state_dict']
    # A module wrapped by torch.compile() stores the original module as
    # "_orig_mod", so saved keys carry that prefix; strip it so the keys match
    # the plain (uncompiled) model.
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone does not move the model. Move it explicitly so it is
    # on the same device as the encoded prompt in generate_text().
    model.to(device)
    # Inference only: switch training-only layers (e.g. dropout, if any) to
    # their evaluation behaviour.
    model.eval()
    return model


def get_tokenizer(config):
    """Load the tokenizer named in ``config['tokenizer']``.

    The EOS token is reused as the padding token.

    Returns:
        A ``(tokenizer, vocab_size)`` tuple, where ``vocab_size`` is
        ``tokenizer.vocab_size``.
    """
    tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size
    return tokenizer, vocab_size


def generate_text(model, tokenizer, input_text, max_new_tokens, context_length,
                  temperature, top_k, eos_token, device):
    """Encode ``input_text``, sample a continuation and return it decoded.

    All sampling parameters are forwarded unchanged to ``train.generate``.
    The returned string is the decoded output of ``generate``.
    """
    encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Sampling needs no gradients; no_grad avoids building an autograd graph
    # across every generated step.
    with torch.no_grad():
        generated_ids = generate(
            model,
            idx=encoded_text,
            max_new_tokens=max_new_tokens,
            context_length=context_length,
            temperature=temperature,
            top_k=top_k,
            eos_token=eos_token,
            device=device,
        )
    # squeeze(0) drops a leading batch dimension; assumes generate() returns
    # ids shaped (1, seq_len) — TODO confirm against train.generate.
    return tokenizer.decode(generated_ids.squeeze(0))


def main():
    """Parse CLI arguments, load the model and tokenizer, and print a sample."""
    parser = argparse.ArgumentParser(description='Generate text using the SmolLM model')
    parser.add_argument('--config_path', type=str, default="config_smollm2_135M.yaml",
                        help='Path to the config file')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the model checkpoint')
    parser.add_argument('--input_text', type=str, default="Bernoulli principle",
                        help='Input text prompt for generation')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--context_length', type=int, default=256,
                        help='Context length for generation')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Temperature for sampling')
    parser.add_argument('--top_k', type=int, default=5,
                        help='Top-k value for sampling')
    parser.add_argument('--device', type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to run the model on (cuda/cpu)')
    args = parser.parse_args()

    config = get_config(args.config_path)
    model = load_model_from_checkpoint(args.config_path, args.checkpoint_path, args.device)
    print(model)

    tokenizer, vocab_size = get_tokenizer(config)
    print(tokenizer)
    print(vocab_size)

    generated_text = generate_text(
        model,
        tokenizer,
        args.input_text,
        args.max_new_tokens,
        args.context_length,
        args.temperature,
        args.top_k,
        tokenizer.eos_token_id,
        args.device,
    )
    print(generated_text)


if __name__ == "__main__":
    main()