JoPmt commited on
Commit
06c0e4e
·
1 Parent(s): 12deb6d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
5
+ from accelerate import Accelerator
6
+ accelerator = Accelerator(cpu=True)
7
+ cwd = "./models"
8
+
9
+ tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m"))
10
+ model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m"))
11
+ train_dataset = TextDataset(
12
+ tokenizer=tokenizer,
13
+ ## file_path='./train_text.txt',
14
+ file_path='./train_text.txt',
15
+ block_size=128
16
+ )
17
+ data_collator = DataCollatorForLanguageModeling(
18
+ tokenizer=tokenizer,
19
+ mlm=False
20
+ )
21
+ training_args = TrainingArguments(
22
+ output_dir=cwd,
23
+ overwrite_output_dir=True,
24
+ num_train_epochs=one,
25
+ per_device_train_batch_size=8,
26
+ save_steps=two,
27
+ save_total_limit=one,
28
+ )
29
+ trainer = Trainer(
30
+ model=model,
31
+ args=training_args,
32
+ data_collator=data_collator,
33
+ train_dataset=train_dataset,
34
+ )
35
+ trainer.train()
36
+ tokenizer.save_pretrained('./models')
37
+ trainer.save_model('./models', 'pytorch_model')
38
+ src = './config.json'
39
+ des = './models/config.json'
40
+ os.rename(src, des)
41
+ tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained("./models"))
42
+ model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("./models"))
43
+ def plex(input_text):
44
+ mnputs = tokenizer(input_text, return_tensors='pt')
45
+ prediction = model.generate(mnputs['input_ids'], min_length=20, max_length=150, num_return_sequences=1)
46
+ lines = tokenizer.decode(prediction[0]).splitlines()
47
+ return lines[0]
48
+
49
+ iface=gr.Interface(
50
+ fn=plex,
51
+ inputs=gr.Textbox(label="Prompt Finetuned Model"),
52
+ outputs=gr.Textbox(label="Generated_Text"),
53
+ title="GPT-Neo-125M fine-tuned on a small set of shortstories with Gradio",
54
+ description="Prompt for a short bedtime story.",
55
+ ##examples=gr.Examples(fn=fine_tune_llm,inputs=['./test.txt',"Once upon a time",2,2000],outputs=[gr.Textbox(),gr.File()],cache_examples=True,)
56
+ )
57
+ iface.launch()