vvv-knyazeva committed
Commit de6799c · 1 Parent(s): db91526

Delete pages/gpt_v2.py

Files changed (1)
  1. pages/gpt_v2.py +0 -41
pages/gpt_v2.py DELETED
@@ -1,41 +0,0 @@
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-import torch
-import streamlit as st
-
-model = GPT2LMHeadModel.from_pretrained(
-    'sberbank-ai/rugpt3small_based_on_gpt2',
-    output_attentions = False,
-    output_hidden_states = False,
-)
-
-tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
-
-# Attach the saved fine-tuned weights to our model
-model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
-
-prompt = st.text_input('Enter the prompt text:')
-length = st.slider('Length of the generated sequence:', 1, 256, 16)
-num_samples = st.slider('Number of generations:', 1, 6, 1)
-temperature = st.slider('Temperature:', 1.0, 6.0, 1.0)
-selected_text = st.empty()
-
-def generate_text(model, tokenizer, prompt, length, num_samples, temperature, selected_text):
-    input_ids = tokenizer.encode(prompt, return_tensors='pt')
-    output_sequences = model.generate(
-        input_ids=input_ids,
-        max_length=length,
-        num_return_sequences=num_samples,
-        temperature=temperature
-    )
-
-    generated_texts = []
-    for output_sequence in output_sequences:
-        generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
-        generated_texts.append(generated_text)
-
-    selected_text.slider('Select a text:', 1, num_samples, 1)
-    return generated_texts[selected_text.value-1]
-
-if st.button('Generate text'):
-    text = generate_text(model, tokenizer, prompt, length, num_samples, temperature, selected_text)
-    st.write(text)
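
Note on the deleted script: `selected_text.value` is not a Streamlit attribute (a slider call returns the chosen value directly), `temperature` only takes effect when `do_sample=True`, and `generate()` rejects `num_return_sequences > 1` under plain greedy decoding, so the selection step would have failed at runtime. Below is a minimal sketch of how the same page could be wired up correctly; the model name and the `model.pt` checkpoint path are taken from the deleted file, while the caching, session-state handling, and English labels are assumptions, not the author's code.

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

MODEL_NAME = 'sberbank-ai/rugpt3small_based_on_gpt2'  # model name from the deleted file

@st.cache_resource
def load_model():
    # Load the base model and tokenizer once per session instead of on every rerun
    model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
    # 'model.pt' is the fine-tuned checkpoint referenced in the deleted script (assumed to exist)
    model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
    model.eval()
    return model, tokenizer

model, tokenizer = load_model()

prompt = st.text_input('Prompt text:')
length = st.slider('Generated sequence length:', 1, 256, 16)
num_samples = st.slider('Number of generations:', 1, 6, 1)
temperature = st.slider('Temperature:', 1.0, 6.0, 1.0)

if st.button('Generate text') and prompt:
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=length,
        num_return_sequences=num_samples,
        do_sample=True,           # required for temperature and multiple return sequences
        temperature=temperature,
    )
    # Keep the results across reruns so the selection slider below still has data
    st.session_state['generated'] = [
        tokenizer.decode(seq, clean_up_tokenization_spaces=True)
        for seq in output_sequences
    ]

if 'generated' in st.session_state:
    texts = st.session_state['generated']
    # Sliders return their value directly; there is no .value attribute to read
    choice = st.slider('Select a text:', 1, len(texts), 1) if len(texts) > 1 else 1
    st.write(texts[choice - 1])

Here st.cache_resource keeps the model in memory across Streamlit reruns, and st.session_state preserves the generated samples so that moving the selection slider does not discard them.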