Spaces:
Runtime error
Runtime error
# りんなGPT-2-medium ファインチューニングやってみた | |
# %%time | |
# ファインチューニングの実行 | |
# python ./transformers/examples/pytorch/language-modeling/run_clm.py \ | |
# --model_name_or_path=rinna/japanese-gpt2-medium \ | |
# --train_file=natsumesouseki.txt \ | |
# --validation_file=natsumesouseki.txt \ | |
# --do_train \ | |
# --do_eval \ | |
# --num_train_epochs=3 \ | |
# --save_steps=5000 \ | |
# --save_total_limit=3 \ | |
# --per_device_train_batch_size=1 \ | |
# --per_device_eval_batch_size=1 \ | |
# --output_dir=output/ | |
from transformers import T5Tokenizer, AutoModelForCausalLM | |
import gradio as gr | |
import torch | |
# トークナイザーとモデルの準備 | |
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium") | |
# model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium") | |
model = AutoModelForCausalLM.from_pretrained("output/") | |
# 平均/分散の値を正規化 | |
model.eval() | |
# 推論の実行 | |
#def Chat(prompt): | |
# input = tokenizer.encode(prompt, return_tensors="pt") | |
# output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5) | |
# return tokenizer.batch_decode(output) | |
def Chat(prompt): | |
num = 3 | |
input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False) | |
#with torch.no_grad(): | |
output = model.generate( | |
input_ids, | |
max_length=300, # 最長の文章長 | |
min_length=100, # 最短の文章長 | |
do_sample=True, | |
top_k=500, # 上位{top_k}個の文章を保持 | |
top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる | |
pad_token_id=tokenizer.pad_token_id, | |
bos_token_id=tokenizer.bos_token_id, | |
eos_token_id=tokenizer.eos_token_id, | |
#bad_word_ids=[[tokenizer.unk_token_id]], | |
num_return_sequences=num # 生成する文章の数 | |
) | |
decoded = tokenizer.decode(output.tolist()[0]) | |
return decoded | |
app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="夏目漱石GPT") | |
app.launch() |