# SmolLMTextGenerator / model_testing.py
# crpatel — "gradio app" (fb26382)
import argparse
from SmolLm3 import LlamaModel
import yaml
import torch
from transformers import AutoTokenizer
from train import generate
def get_config(config_path):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The object produced by ``yaml.load``. Callers in this file index it
        as ``config['model']`` and ``config['tokenizer']``.
    """
    # A context manager closes the file handle deterministically; the previous
    # ``open(...)`` call left it open until garbage collection.
    with open(config_path, "r") as config_file:
        # FullLoader kept for backward compatibility with existing configs.
        # NOTE(review): consider yaml.safe_load if config files may come from
        # untrusted sources.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    return config
def load_model_from_checkpoint(config_path, checkpoint_path, device):
    """Build a ``LlamaModel`` from a config file and load checkpoint weights into it.

    Args:
        config_path: Path to the YAML config; its ``'model'`` section is passed
            to ``LlamaModel``.
        checkpoint_path: Path to a checkpoint loadable by ``torch.load`` whose
            ``'model_state_dict'`` entry holds the model weights.
        device: Device identifier (e.g. ``"cuda"`` or ``"cpu"``) used both to map
            the checkpoint tensors and to place the returned model.

    Returns:
        The model with loaded weights, moved to ``device`` and in eval mode.
    """
    config = get_config(config_path)
    model = LlamaModel(config['model'])
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    state_dict = checkpoint['model_state_dict']
    # Strip the '_orig_mod.' wrapper prefix from parameter names (presumably
    # added because the checkpoint was saved from a torch.compile'd model) so
    # the keys match the plain module.
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone does not place the model on `device`; move it
    # explicitly so it matches the device of the encoded prompt.
    model.to(device)
    # This loader is used for inference; disable training-only behaviour
    # (e.g. dropout, if LlamaModel has any).
    model.eval()
    return model
def get_tokenizer(config):
    """Create the tokenizer named in the config and report its vocabulary size.

    Args:
        config: Parsed config; ``config['tokenizer']['tokenizer_name_or_path']``
            names the pretrained tokenizer to load.

    Returns:
        A ``(tokenizer, vocab_size)`` tuple, where ``vocab_size`` is the
        tokenizer's ``vocab_size`` attribute.
    """
    name_or_path = config['tokenizer']['tokenizer_name_or_path']
    tok = AutoTokenizer.from_pretrained(name_or_path)
    # Reuse the end-of-sequence token as the padding token.
    tok.pad_token = tok.eos_token
    return tok, tok.vocab_size
def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
    """Encode a prompt, sample a continuation with ``generate``, and decode it.

    Args:
        model: Model passed through to ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        input_text: Prompt string.
        max_new_tokens: Upper bound on tokens to sample.
        context_length: Context window forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate``.
        eos_token: Token id forwarded to ``generate`` as the stop token.
        device: Device the encoded prompt is moved to.

    Returns:
        The decoded string for the first (squeezed) row of ``generate``'s output.
    """
    prompt_ids = tokenizer.encode(input_text, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)
    sampling_options = dict(
        max_new_tokens=max_new_tokens,
        context_length=context_length,
        temperature=temperature,
        top_k=top_k,
        eos_token=eos_token,
        device=device,
    )
    output_ids = generate(model, idx=prompt_ids, **sampling_options)
    return tokenizer.decode(output_ids.squeeze(0))
if __name__ == "__main__":
    # Command-line entry point: load a checkpoint, build the tokenizer, and
    # print a sampled continuation of --input_text.
    parser = argparse.ArgumentParser(description='Generate text using the SmolLM model')
    parser.add_argument('--config_path', type=str, default="config_smollm2_135M.yaml",
                        help='Path to the config file')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the model checkpoint')
    parser.add_argument('--input_text', type=str, default="Bernoulli principle",
                        help='Input text prompt for generation')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--context_length', type=int, default=256,
                        help='Context length for generation')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Temperature for sampling')
    parser.add_argument('--top_k', type=int, default=5,
                        help='Top-k value for sampling')
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to run the model on (cuda/cpu)')
    args = parser.parse_args()

    config = get_config(args.config_path)
    model = load_model_from_checkpoint(args.config_path, args.checkpoint_path, args.device)
    # Put the weights on the same device as the encoded prompt and disable
    # training-only layers; both calls are idempotent.
    model.to(args.device)
    model.eval()
    print(model)

    tokenizer, vocab_size = get_tokenizer(config)
    print(tokenizer)
    print(vocab_size)

    # Sampling needs no gradients; avoid recording autograd graphs for every
    # forward pass (harmless if `generate` already disables them).
    with torch.no_grad():
        generated_text = generate_text(
            model,
            tokenizer,
            args.input_text,
            args.max_new_tokens,
            args.context_length,
            args.temperature,
            args.top_k,
            tokenizer.eos_token_id,
            args.device,
        )
    print(generated_text)