File size: 1,179 Bytes
b233041
 
 
3bdf427
 
 
b233041
8b2cc38
b233041
 
4d7a2e1
 
 
 
 
 
 
 
 
 
 
 
b233041
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# りんなGPT-2-medium ファインチューニングやってみた

# パッケージのインストール
# pip install transformers==4.23.1
# pip install evaluate==0.3.0
# pip install sentencepiece==0.1.97

# %%time

# ファインチューニングの実行
# python ./transformers/examples/pytorch/language-modeling/run_clm.py \
#    --model_name_or_path=rinna/japanese-gpt2-medium \
#    --train_file=natsumesouseki.txt \
#   --validation_file=natsumesouseki.txt \
#    --do_train \
#    --do_eval \
#    --num_train_epochs=3 \
#    --save_steps=5000 \
#    --save_total_limit=3 \
#    --per_device_train_batch_size=1 \
#    --per_device_eval_batch_size=1 \
#    --output_dir=output/

from transformers import T5Tokenizer, AutoModelForCausalLM

# トークナイザーとモデルの準備
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")

# 推論の実行
def Chat(prompt, ):
    input = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input, do_sample=True, max_length=300, num_return_sequences=5)
    return print(tokenizer.batch_decode(output))