import argparse

import torch
import yaml
from transformers import AutoTokenizer

from SmolLm3 import LlamaModel
from train import generate
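
# Example invocation (script name and checkpoint path are illustrative;
# adjust to the repo layout):
#
#   python generate.py --config_path config_smollm2_135M.yaml \
#       --checkpoint_path checkpoints/step_5000.pt \
#       --input_text "Bernoulli principle" --max_new_tokens 128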

def get_config(config_path):
    # Parse the YAML config; safe_load avoids constructing arbitrary Python
    # objects, and the context manager closes the file handle.
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
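
# The config is assumed to contain at least the two sections read below; the
# exact fields under `model` depend on LlamaModel and are not shown here:
#
#   model:
#     ...  # architecture hyperparameters passed to LlamaModel
#   tokenizer:
#     tokenizer_name_or_path: HuggingFaceTB/SmolLM2-135M  # illustrative value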

def load_model_from_checkpoint(config_path, checkpoint_path, device):
    config = get_config(config_path)
    model = LlamaModel(config['model'])
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    state_dict = checkpoint['model_state_dict']
    # Checkpoints saved from a torch.compile()-wrapped model prefix every key
    # with '_orig_mod.'; strip it so the keys match the plain module.
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    # Move the weights to the target device (the prompt tensor is sent there
    # too in generate_text) and disable dropout for inference.
    model.to(device)
    model.eval()
    return model

def get_tokenizer(config):
    tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # The base tokenizer ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    # Note: vocab_size is the base vocabulary and excludes any added tokens.
    vocab_size = tokenizer.vocab_size
    return tokenizer, vocab_size

def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
    # Tokenize the prompt and move it to the same device as the model.
    encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        generated_tokens = generate(model,
                                    idx=encoded_text,
                                    max_new_tokens=max_new_tokens,
                                    context_length=context_length,
                                    temperature=temperature,
                                    top_k=top_k,
                                    eos_token=eos_token,
                                    device=device)
    # `generate` returns token ids with a batch dimension; decode to a string.
    return tokenizer.decode(generated_tokens.squeeze(0))
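
# For reference, `train.generate` is assumed to implement a standard
# autoregressive top-k sampling loop along these lines. This is a sketch
# under that assumption, not the actual code from train.py:
#
#   def generate(model, idx, max_new_tokens, context_length,
#                temperature, top_k, eos_token, device):
#       for _ in range(max_new_tokens):
#           # Score the last position of the (truncated) context window.
#           logits = model(idx[:, -context_length:])[:, -1, :] / temperature
#           # Keep only the top_k logits; mask the rest before sampling.
#           v, _ = torch.topk(logits, top_k)
#           logits[logits < v[:, [-1]]] = float("-inf")
#           probs = torch.softmax(logits, dim=-1)
#           next_id = torch.multinomial(probs, num_samples=1)
#           idx = torch.cat([idx, next_id], dim=1)
#           if next_id.item() == eos_token:
#               break
#       return idx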

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate text using the SmolLM model')
    parser.add_argument('--config_path', type=str, default="config_smollm2_135M.yaml",
                        help='Path to the config file')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the model checkpoint')
    parser.add_argument('--input_text', type=str, default="Bernoulli principle",
                        help='Input text prompt for generation')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--context_length', type=int, default=256,
                        help='Context length for generation')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Temperature for sampling')
    parser.add_argument('--top_k', type=int, default=5,
                        help='Top-k value for sampling')
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to run the model on (cuda/cpu)')
    
    args = parser.parse_args()

    config = get_config(args.config_path)
    model = load_model_from_checkpoint(args.config_path, args.checkpoint_path, args.device)
    # Sanity check: show the loaded architecture.
    print(model)
    tokenizer, vocab_size = get_tokenizer(config)
    # Sanity check: show the tokenizer and its base vocabulary size.
    print(tokenizer)
    print(vocab_size)

    generated_text = generate_text(
        model,
        tokenizer,
        input_text=args.input_text,
        max_new_tokens=args.max_new_tokens,
        context_length=args.context_length,
        temperature=args.temperature,
        top_k=args.top_k,
        eos_token=tokenizer.eos_token_id,
        device=args.device,
    )
    print(generated_text)