import torch
import gradio as gr

from smollm_training import SmolLMConfig, tokenizer, SmolLM


# Load the model
def load_model():
    config = SmolLMConfig()
    model = SmolLM(config)  # Create the base model instead of the Lightning wrapper

    # Load just the model weights
    state_dict = torch.load("model_weights.pt", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
    """Generate text based on the prompt."""
    try:
        # Encode the prompt
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")

        # Move to the model's device if needed
        device = next(model.parameters()).device
        prompt_ids = prompt_ids.to(device)

        # Generate text
        with torch.no_grad():
            generated_ids = model.generate(  # Call generate directly on the base model
                prompt_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
            )

        # Decode the generated text
        generated_text = tokenizer.decode(generated_ids[0].tolist())
        return generated_text
    except Exception as e:
        return f"An error occurred: {str(e)}"


# Load the model globally
model = load_model()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Once upon a time...",
            lines=3,
        ),
        gr.Slider(
            minimum=50,
            maximum=500,
            value=100,
            step=10,
            label="Maximum number of tokens",
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="SmolLM Text Generator",
    description="Enter a prompt and the model will generate a continuation.",
    examples=[
        ["Once upon a time", 100],
        ["The future of AI is", 200],
        ["In a galaxy far far away", 150],
    ],
)

if __name__ == "__main__":
    demo.launch()
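
# --- Note: producing model_weights.pt ---
# A minimal sketch of how the weights file loaded above might be exported,
# assuming training used a PyTorch Lightning wrapper (hypothetical class name
# `LitSmolLM`) whose `model` attribute holds the base SmolLM; the checkpoint
# path is illustrative:
#
#     lit_module = LitSmolLM.load_from_checkpoint("checkpoints/last.ckpt")
#     torch.save(lit_module.model.state_dict(), "model_weights.pt")
#
# Saving only the state_dict (rather than the full Lightning checkpoint) keeps
# the file free of optimizer state, so load_model() can restore it directly
# into the plain SmolLM module.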