# Spaces: Sleeping
import argparse

import torch
import yaml
from transformers import AutoTokenizer

from SmolLm3 import LlamaModel
from train import generate
def get_config(config_path):
    """Load and return the YAML configuration stored at *config_path*.

    Args:
        config_path: Path to a YAML file. For this script the parsed document
            is expected to be a dict with ``'model'`` and ``'tokenizer'``
            sections.

    Returns:
        The parsed YAML document.
    """
    # A context manager closes the handle deterministically; the previous
    # ``yaml.load(open(...))`` leaked it until garbage collection.
    # NOTE(review): FullLoader is kept for compatibility with existing config
    # files; switch to ``yaml.safe_load`` if configs may come from untrusted
    # sources.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    return config
def load_model_from_checkpoint(config_path, checkpoint_path, device):
    """Build a ``LlamaModel`` from a config file and load checkpoint weights.

    Args:
        config_path: Path to the YAML config; its ``'model'`` section is
            passed to the ``LlamaModel`` constructor.
        checkpoint_path: Path to a file written by ``torch.save``. Either a
            dict holding the weights under ``'model_state_dict'`` or a bare
            state dict.
        device: Device (string such as ``"cuda"``/``"cpu"`` or a
            ``torch.device``) the checkpoint is mapped to and the model is
            moved to.

    Returns:
        The loaded model, placed on *device* and switched to evaluation mode.
    """
    config = get_config(config_path)
    model = LlamaModel(config['model'])

    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    # Training checkpoints wrap the weights under 'model_state_dict'; fall back
    # to treating the whole object as a state dict so plain weight files load.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Modules wrapped by torch.compile save parameters with an '_orig_mod.'
    # prefix; strip it so keys match the uncompiled model built above.
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    # load_state_dict copies values into the model's existing parameters, which
    # were created on CPU by the constructor, so map_location alone does not
    # place the model on `device`. Move it explicitly so inputs on `device`
    # match, and disable training-only behaviour (e.g. dropout, if any).
    model.to(device)
    model.eval()
    return model
def get_tokenizer(config):
    """Instantiate the Hugging Face tokenizer named in *config*.

    The tokenizer is looked up from ``config['tokenizer']['tokenizer_name_or_path']``
    and its pad token is set to its EOS token.

    Returns:
        A ``(tokenizer, vocab_size)`` tuple, where ``vocab_size`` is the
        tokenizer's ``vocab_size`` attribute.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        config['tokenizer']['tokenizer_name_or_path']
    )
    # Reuse EOS for padding so a pad token is always defined.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): in Hugging Face tokenizers `vocab_size` typically excludes
    # tokens added after pretraining, whereas `len(tokenizer)` includes them —
    # confirm which one callers need.
    return tokenizer, tokenizer.vocab_size
def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
    """Generate a continuation of *input_text* with *model* and decode it.

    Args:
        model: The language model passed through to ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the result.
        input_text: Prompt string.
        max_new_tokens: Maximum number of tokens ``generate`` may append.
        context_length: Context window forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.
        eos_token: End-of-sequence token id forwarded to ``generate`` (the
            script passes ``tokenizer.eos_token_id``).
        device: Device the encoded prompt is moved to.

    Returns:
        The decoded text of the token ids returned by ``generate``. Whether the
        prompt is included depends on ``generate`` — presumably it is; confirm.
    """
    # Presumably a (1, seq_len) tensor of prompt token ids — TODO confirm.
    encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Pure inference: disable autograd so the forward passes inside `generate`
    # do not build computation graphs (saves memory and time).
    with torch.no_grad():
        generated_ids = generate(model,
                                 idx=encoded_text,
                                 max_new_tokens=max_new_tokens,
                                 context_length=context_length,
                                 temperature=temperature,
                                 top_k=top_k,
                                 eos_token=eos_token,
                                 device=device)
    # Drop the leading (batch) dimension before decoding.
    return tokenizer.decode(generated_ids.squeeze(0))
def main():
    """Parse CLI arguments, load the model and tokenizer, and print generated text."""
    parser = argparse.ArgumentParser(description='Generate text using the SmolLM model')
    parser.add_argument('--config_path', type=str, default="config_smollm2_135M.yaml",
                        help='Path to the config file')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the model checkpoint')
    # Default prompt spelling fixed ("Bernuli" -> "Bernoulli").
    parser.add_argument('--input_text', type=str, default="Bernoulli principle",
                        help='Input text prompt for generation')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--context_length', type=int, default=256,
                        help='Context length for generation')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Temperature for sampling')
    parser.add_argument('--top_k', type=int, default=5,
                        help='Top-k value for sampling')
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to run the model on (cuda/cpu)')
    args = parser.parse_args()

    config = get_config(args.config_path)
    model = load_model_from_checkpoint(args.config_path, args.checkpoint_path, args.device)
    print(model)

    tokenizer, vocab_size = get_tokenizer(config)
    print(tokenizer)
    print(vocab_size)

    generated_text = generate_text(
        model,
        tokenizer,
        args.input_text,
        args.max_new_tokens,
        args.context_length,
        args.temperature,
        args.top_k,
        tokenizer.eos_token_id,
        args.device,
    )
    print(generated_text)


if __name__ == "__main__":
    main()