import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
device = torch.device("cpu")


@st.cache_resource  # cache the model across Streamlit reruns (requires Streamlit >= 1.18)
def load_model():
    """Load the RuGPT-3 tokenizer and model once instead of on every rerun."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    return tokenizer, model


tokenizer, model = load_model()

def generate_text(prompt):
    """Generate a sampled continuation of `prompt`."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = model.generate(
        input_ids,
        max_length=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 models define no pad token; this silences the generate() warning
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

st.title("RuGPT-3 Demo")

prompt = st.text_input("Введите текст:")  # "Enter text:"

if prompt:
    response = generate_text(prompt)
    st.write(response)
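
# To try the demo locally (assuming this script is saved as app.py):
#   streamlit run app.py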