{"cells":[{"cell_type":"markdown","metadata":{"id":"0-7S1J6Jq7nc"},"source":["# Fine-Tuning BERT as a `ToxicityModel`\n","\n","1. First, intall `transformers`, `tlr`, and `codecarbon`."]},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":35124,"status":"ok","timestamp":1686695247928,"user":{"displayName":"Nicholas Corrêa","userId":"09736120585766268588"},"user_tz":-120},"id":"Fx7pg9eT62-d","outputId":"86700c17-1c4a-4fd2-9edb-bf468b152c06"},"outputs":[{"name":"stdout","output_type":"stream","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting transformers\n"," Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m84.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)\n"," Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m31.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Collecting 
tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n"," Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m107.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n"," Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m81.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n","Installing collected packages: tokenizers, safetensors, huggingface-hub, transformers\n","Successfully installed huggingface-hub-0.15.1 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.30.2\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting trl\n"," Downloading 
trl-0.4.4-py3-none-any.whl (68 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.4/68.4 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from trl) (2.0.1+cu118)\n","Requirement already satisfied: transformers>=4.18.0 in /usr/local/lib/python3.10/dist-packages (from trl) (4.30.2)\n","Requirement already satisfied: numpy>=1.18.2 in /usr/local/lib/python3.10/dist-packages (from trl) (1.22.4)\n","Collecting accelerate (from trl)\n"," Downloading accelerate-0.20.3-py3-none-any.whl (227 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.6/227.6 kB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting datasets (from trl)\n"," Downloading datasets-2.12.0-py3-none-any.whl (474 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m474.6/474.6 kB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.12.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.4.0->trl) (3.25.2)\n","Requirement already 
satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.4.0->trl) (16.0.5)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (0.15.1)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (2.27.1)\n","Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (0.13.3)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (0.3.1)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.18.0->trl) (4.65.0)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate->trl) (5.9.5)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (9.0.0)\n","Collecting dill<0.3.7,>=0.3.0 (from datasets->trl)\n"," Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (1.5.3)\n","Collecting xxhash (from datasets->trl)\n"," Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 
kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.5/212.5 kB\u001b[0m \u001b[31m27.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multiprocess (from datasets->trl)\n"," Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (2023.4.0)\n","Collecting aiohttp (from datasets->trl)\n"," Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m77.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting responses<0.19 (from datasets->trl)\n"," Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (23.1.0)\n","Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (2.0.12)\n","Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->trl)\n"," Downloading multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3 (from aiohttp->datasets->trl)\n"," Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n","Collecting yarl<2.0,>=1.0 (from aiohttp->datasets->trl)\n"," Downloading yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (268 kB)\n","\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m37.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting frozenlist>=1.1.1 (from aiohttp->datasets->trl)\n"," Downloading frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (149 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m149.6/149.6 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiosignal>=1.1.2 (from aiohttp->datasets->trl)\n"," Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.18.0->trl) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.18.0->trl) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.18.0->trl) (3.4)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4.0->trl) (2.1.2)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2022.7.1)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4.0->trl) (1.3.0)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets->trl) (1.16.0)\n","Installing collected packages: xxhash, multidict, frozenlist, dill, async-timeout, yarl, responses, multiprocess, aiosignal, aiohttp, datasets, accelerate, trl\n","Successfully installed 
accelerate-0.20.3 aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.12.0 dill-0.3.6 frozenlist-1.3.3 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 trl-0.4.4 xxhash-3.2.0 yarl-1.9.2\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting codecarbon\n"," Downloading codecarbon-2.2.3-py3-none-any.whl (174 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.1/174.1 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting arrow (from codecarbon)\n"," Downloading arrow-1.2.3-py3-none-any.whl (66 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.4/66.4 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from codecarbon) (1.5.3)\n","Collecting pynvml (from codecarbon)\n"," Downloading pynvml-11.5.0-py3-none-any.whl (53 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from codecarbon) (2.27.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from codecarbon) (5.9.5)\n","Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from codecarbon) (9.0.0)\n","Collecting fuzzywuzzy (from codecarbon)\n"," Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)\n","Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from codecarbon) (8.1.3)\n","Requirement already satisfied: python-dateutil>=2.7.0 in /usr/local/lib/python3.10/dist-packages (from arrow->codecarbon) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in 
/usr/local/lib/python3.10/dist-packages (from pandas->codecarbon) (2022.7.1)\n","Requirement already satisfied: numpy>=1.21.0 in /usr/local/lib/python3.10/dist-packages (from pandas->codecarbon) (1.22.4)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->codecarbon) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->codecarbon) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->codecarbon) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->codecarbon) (3.4)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7.0->arrow->codecarbon) (1.16.0)\n","Installing collected packages: fuzzywuzzy, pynvml, arrow, codecarbon\n","Successfully installed arrow-1.2.3 codecarbon-2.2.3 fuzzywuzzy-0.18.0 pynvml-11.5.0\n"]}],"source":["%pip install transformers\n","%pip install trl\n","%pip install codecarbon"]},{"cell_type":"markdown","metadata":{"id":"Y6xzGtxPrMaF"},"source":["2. Downloas the `toxic-aira-dataset` from the Hub."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":485},"executionInfo":{"elapsed":3284,"status":"ok","timestamp":1686695275355,"user":{"displayName":"Nicholas Corrêa","userId":"09736120585766268588"},"user_tz":-120},"id":"DtCgCgEr62C9","outputId":"0f900bf3-e28c-48af-c0a5-3cb59ad7b2ea"},"outputs":[],"source":["from datasets import load_dataset\n","\n","dataset = load_dataset(\"nicholasKluge/toxic-aira-dataset\", split=\"portuguese\")\n","\n","print(\"Dataset loaded.\")"]},{"cell_type":"markdown","metadata":{"id":"kQQ1DkB5rjpS"},"source":["3. Download your base model for fine-tuning. 
Here we are using `bert-base-cased` for the English toxicity model and `bert-base-portuguese-cased` for the Portuguese version.
f445a919fcd","0926d0a6585e42d3bac64bc990de3ffb","31d1021c5c5f42519fe09bea18dbf861","a9f25d3f315c495aade054923004a54a","a3f4b0a75d314367a57de76c7746d59a","5b5536d2b75740a490880c42db5a644e","d84914413b9d42ef9ae61b3c189d5016","98d262899fec41ee9ee38fdd495cbd9f","c38f5079edc8488d9b3e66b523c2cbcc","bcad9ecd0eba4beca9f54e435fe8dbe1","d49ab184be864a9aa64b5281b38e9b6c","e731b16de2584650b088b191e380c5c5","a7be2a41a88147a4bdceede2396874a3","69f36174d5034891a52882dc47cba515","7daaf64785f546728e67afdf3f3d6b64","da0a9e043f2f48aca6056bbd6c89fad2","3c51287c34334bfe8191def47e66f9c3"]},"executionInfo":{"elapsed":12469,"status":"ok","timestamp":1686695299044,"user":{"displayName":"Nicholas Corrêa","userId":"09736120585766268588"},"user_tz":-120},"id":"W3d0nlgO62DB","outputId":"7b92bb51-c3d8-473b-fe59-1eab38379e95"},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a488c62affc846caaee5d659b615d523","version_major":2,"version_minor":0},"text/plain":["Downloading (…)lve/main/config.json: 0%| | 0.00/647 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"fdd8a8b585e548d8a2831c900d8ac94c","version_major":2,"version_minor":0},"text/plain":["Downloading pytorch_model.bin: 0%| | 0.00/438M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"name":"stderr","output_type":"stream","text":["Some weights of the model checkpoint at neuralmind/bert-base-portuguese-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n","- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n","- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n","Some weights of BertForSequenceClassification were not initialized from the model checkpoint at neuralmind/bert-base-portuguese-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n","You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f1e1a329e8bf443482232d9ccacd42d8","version_major":2,"version_minor":0},"text/plain":["Downloading (…)okenizer_config.json: 0%| | 0.00/43.0 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"6a6da07dff5649208c175b04ed67f718","version_major":2,"version_minor":0},"text/plain":["Downloading (…)solve/main/vocab.txt: 0%| | 0.00/210k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"648c9937b52847c19b0330b91284acb4","version_major":2,"version_minor":0},"text/plain":["Downloading (…)in/added_tokens.json: 0%| | 0.00/2.00 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"d84914413b9d42ef9ae61b3c189d5016","version_major":2,"version_minor":0},"text/plain":["Downloading (…)cial_tokens_map.json: 0%| | 0.00/112 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Model (neuralmind/bert-base-portuguese-cased) ready.\n"]}],"source":["from transformers import AutoModelForSequenceClassification, AutoTokenizer\n","import torch\n","\n","model_name = \"neuralmind/bert-base-portuguese-cased\" # 
\"neuralmind/bert-base-portuguese-cased\" bert-base-cased\n","\n","model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n","tokenizer = AutoTokenizer.from_pretrained(model_name)\n","\n","if tokenizer.pad_token is None:\n"," tokenizer.pad_token = tokenizer.eos_token\n"," model.config.pad_token_id = model.config.eos_token_id\n","\n","print(f\"Model ({model_name}) ready.\")"]},{"cell_type":"markdown","metadata":{"id":"kvq6hWQ4sAlw"},"source":["4. Preprocess the dataset to be compatible with the `RewardTrainer` from `tlr`."]},{"cell_type":"code","execution_count":5,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":17,"referenced_widgets":["c5ac436aba8843c683e4be85d6515a13","4561ac2bab0c470d9e3fcf55ca599050","363700a107c54ae3a9f575e4756dc846","d49393afdbbd46f18002f5e5d9ba8ff4","150446fd7f4e4edf9374ff361c4933d2","6a47d27e478649a1a3c54355ef47107e","789c96ffda924a7685467fbd97bbec2e","6b45725239984a238ef863bcd6804783","0640f08ed65e41228eb4e30336cbfa6c","39f89c00fdeb4c01b6d1f663253a7991","3c6f9d1a236b4a31a7945d3343077724"]},"executionInfo":{"elapsed":19681,"status":"ok","timestamp":1686695323719,"user":{"displayName":"Nicholas Corrêa","userId":"09736120585766268588"},"user_tz":-120},"id":"JwrPjIbo62DC","outputId":"df482268-7fff-452a-ff16-fe3e2ac77e43"},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c5ac436aba8843c683e4be85d6515a13","version_major":2,"version_minor":0},"text/plain":["Map: 0%| | 0/16730 [00:00, ? 
examples/s]"]},"metadata":{},"output_type":"display_data"}],"source":["def preprocess(examples):\n"," kwargs = {\"padding\": \"max_length\", \"truncation\": True, \"max_length\": 350, \"return_tensors\": \"pt\"}\n","\n"," non_toxic_response = examples[\"non_toxic_response\"]\n"," toxic_response_response = examples[\"toxic_response\"]\n","\n"," # Then tokenize these modified fields.\n"," tokens_non_toxic = tokenizer.encode_plus(non_toxic_response, **kwargs)\n"," tokens_toxic = tokenizer.encode_plus(toxic_response_response, **kwargs)\n","\n"," return {\n"," \"input_ids_chosen\": tokens_non_toxic[\"input_ids\"][0], \"attention_mask_chosen\": tokens_non_toxic[\"attention_mask\"][0],\n"," \"input_ids_rejected\": tokens_toxic[\"input_ids\"][0], \"attention_mask_rejected\": tokens_toxic[\"attention_mask\"][0]\n"," }\n","\n","formatted_dataset = dataset.map(preprocess)\n","formatted_dataset = formatted_dataset.train_test_split()"]},{"cell_type":"markdown","metadata":{"id":"rpytDiCusMk4"},"source":["5. Train your model while tracking the CO2 emissions. 
🌱"]},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"elapsed":1220169,"status":"ok","timestamp":1686696553380,"user":{"displayName":"Nicholas Corrêa","userId":"09736120585766268588"},"user_tz":-120},"id":"KwAoxTTF62DD","outputId":"2a24f901-297d-487f-df93-f52a3c0f5881"},"outputs":[{"name":"stderr","output_type":"stream","text":["/usr/local/lib/python3.10/dist-packages/trl/trainer/reward_trainer.py:125: UserWarning: When using RewardDataCollatorWithPadding, you should set `max_length` in the RewardTrainer's init it will be set to `512` by default, but you should do it yourself in the future.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/trl/trainer/reward_trainer.py:136: UserWarning: When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n"," warnings.warn(\n","[codecarbon INFO @ 22:29:00] [setup] RAM Tracking...\n","[codecarbon INFO @ 22:29:00] [setup] GPU Tracking...\n","[codecarbon INFO @ 22:29:00] Tracking Nvidia GPU via pynvml\n","[codecarbon INFO @ 22:29:00] [setup] CPU Tracking...\n","[codecarbon WARNING @ 22:29:00] No CPU tracking mode found. Falling back on CPU constant mode.\n","[codecarbon WARNING @ 22:29:02] We saw that you have a Intel(R) Xeon(R) CPU @ 2.20GHz but we don't know it. 
Please contact us.\n","[codecarbon INFO @ 22:29:02] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU @ 2.20GHz\n","[codecarbon INFO @ 22:29:02] >>> Tracker's metadata:\n","[codecarbon INFO @ 22:29:02] Platform system: Linux-5.15.107+-x86_64-with-glibc2.31\n","[codecarbon INFO @ 22:29:02] Python version: 3.10.12\n","[codecarbon INFO @ 22:29:02] CodeCarbon version: 2.2.3\n","[codecarbon INFO @ 22:29:02] Available RAM : 83.481 GB\n","[codecarbon INFO @ 22:29:02] CPU count: 12\n","[codecarbon INFO @ 22:29:02] CPU model: Intel(R) Xeon(R) CPU @ 2.20GHz\n","[codecarbon INFO @ 22:29:02] GPU count: 1\n","[codecarbon INFO @ 22:29:02] GPU model: 1 x NVIDIA A100-SXM4-40GB\n","/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n"," warnings.warn(\n","You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n","/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2395: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n"," warnings.warn(\n","Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"]},{"data":{"text/html":["\n","
Step | \n","Training Loss | \n","Validation Loss | \n","Accuracy | \n","
---|---|---|---|
200 | \n","0.278900 | \n","0.256261 | \n","0.900550 | \n","
400 | \n","0.173800 | \n","0.246119 | \n","0.902940 | \n","
600 | \n","0.119500 | \n","0.240692 | \n","0.908917 | \n","
800 | \n","0.047700 | \n","0.342544 | \n","0.902223 | \n","
"],"text/plain":["