nikhil-kumar committed
Commit c07a03d · verified · 1 Parent(s): 92a1845

Upload train_model.py

Files changed (1)
  1. train_model.py +46 -0
train_model.py ADDED
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Load dataset
dataset = load_dataset('json', data_files='flirty_dataset.json')
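
# NOTE: the dataset file itself is not part of this commit; tokenize_function
# below assumes each record carries a 'prompt' field, e.g. (hypothetical):
#   {"prompt": "Are you a keyboard? Because you're just my type."}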

# load_dataset with a single data file yields only a "train" split, so carve
# out a held-out set for the per-epoch evaluation configured below
dataset = dataset["train"].train_test_split(test_size=0.1)

# Tokenizer and model
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token by default; reuse EOS so padding="max_length" works
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Tokenize dataset
def tokenize_function(examples):
    return tokenizer(examples['prompt'], truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(tokenize_function, batched=True)

# The collator copies input_ids into labels so the Trainer can compute
# a causal language-modeling loss
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training arguments
training_args = TrainingArguments(
    output_dir="./fine_tuned_gpt2",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is available
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],  # held-out split created above
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# Save model
trainer.save_model("./fine_tuned_gpt2")
tokenizer.save_pretrained("./fine_tuned_gpt2")
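
Once training finishes, the checkpoint can be loaded back for generation. A
minimal sketch (not part of the commit), assuming the script above completed
and wrote to ./fine_tuned_gpt2; the prompt string is illustrative:

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned weights and tokenizer from the output directory
tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_gpt2")
model = AutoModelForCausalLM.from_pretrained("./fine_tuned_gpt2")

# Sample a continuation for a short prompt
inputs = tokenizer("Hey, I was wondering", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    do_sample=True,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))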