{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "JmYRMwEOYkbU" }, "source": [ "# `Fire Detect - ViT`\n", "\n", "Fine-tunes the pre-trained `google/vit-base-patch16-224-in21k` Vision Transformer to classify images as *Fire Needed Action*, *Normal Conditions* or *Smoky Environment*." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "c6rbUun0tdC5" }, "outputs": [], "source": [ "# %pip (rather than !pip) installs into the environment of the running kernel.\n", "# imbalanced-learn provides the RandomOverSampler imported below; transformers is pinned\n", "# to a released range instead of the moving git main branch so reruns get the same API.\n", "%pip install -q evaluate datasets accelerate imbalanced-learn\n", "%pip install -q \"transformers>=4.46,<5\"" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "id": "dhLosa2Utm5M" }, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "import gc\n", "import numpy as np\n", "import pandas as pd\n", "import itertools\n", "from collections import Counter\n", "import matplotlib.pyplot as plt\n", "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score\n", "from imblearn.over_sampling import RandomOverSampler\n", "import evaluate\n", "from datasets import Dataset, Image, ClassLabel\n", "from transformers import (\n", "    TrainingArguments,\n", "    Trainer,\n", "    ViTImageProcessor,\n", "    ViTForImageClassification,\n", "    DefaultDataCollator\n", ")\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from torchvision.transforms import (\n", "    CenterCrop,\n", "    Compose,\n", "    Normalize,\n", "    RandomRotation,\n", "    RandomResizedCrop,\n", "    RandomHorizontalFlip,\n", "    RandomAdjustSharpness,\n", "    Resize,\n", "    ToTensor\n", ")\n", "from PIL import Image as PILImage\n", "from PIL import ImageFile\n", "\n", "# Enable loading truncated images\n", "ImageFile.LOAD_TRUNCATED_IMAGES = True" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "20RJuU8_uY2k" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "dataset = load_dataset(\"--your--dataset-goes--here\", split=\"train\")" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "id": "o8rgwG0nuc00" }, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "# Collect the source path and the label of every example.\n", "# str(example['image']) would only give the decoded PIL object's repr, not a path,\n", "# so the image column is read undecoded: each value is then a dict with a 'path'\n", "# entry (None when the image is stored only as bytes). This also avoids decoding\n", "# every image just to list the files.\n", "undecoded_images = dataset.cast_column(\"image\", Image(decode=False))\n", "\n", "file_names = [example[\"image\"][\"path\"] for example in undecoded_images]\n", "labels = list(undecoded_images[\"label\"])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Yz8qs87tuhjs" }, "outputs": [], "source": [ "# Print the total number of file names and labels\n", "print(len(file_names), len(labels))" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "id": "7CW5l8Td_V-4" }, "outputs": [], "source": [ "# Create a pandas dataframe from the collected file names and labels\n", "df = pd.DataFrame.from_dict({\"image\": file_names, \"label\": labels})" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "9SZz49oNBSHf" }, "outputs": [], "source": [ "print(df.shape)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZubCrfrhBZGo" }, "outputs": [], "source": [ "df.head()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "oFwJ-2_B_br5" }, "outputs": [], "source": [ "df['label'].unique()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZzX_P-onunr7" }, "outputs": [], "source": [ "# Balance the classes by randomly duplicating examples of the minority classes.\n", "# The balanced result goes into `df_balanced` so that `df` keeps the original\n", "# distribution and this cell can be re-run without compounding the resampling.\n", "ros = RandomOverSampler(random_state=83)\n", "df_balanced, y_balanced = ros.fit_resample(df.drop(columns=['label']), df['label'])\n", "df_balanced['label'] = y_balanced\n", "del y_balanced\n", "gc.collect()\n", "\n", "df_balanced['label'].value_counts()" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "id": "WaJ_C30L_N_L" }, "outputs": [], "source": [ "# Class distribution before and after oversampling.\n", "pd.DataFrame({\"original\": df['label'].value_counts(), \"balanced\": df_balanced['label'].value_counts()})" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ha4Bpgoz7dfu" }, "outputs": [], "source": [ "dataset[10][\"image\"]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {
"id": "-MfUvn2A-tBc" }, "outputs": [], "source": [ "labels_subset = labels[:5]\n", "\n", "# Printing the subset of labels to inspect the content.\n", "print(labels_subset)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "id": "9u1W0MBhBpMA" }, "outputs": [], "source": [ "# Define the new list of unique labels\n", "labels_list = ['Fire Needed Action', 'Normal Conditions', 'Smoky Environment']\n", "\n", "# Initialize dictionaries to map labels to IDs and vice versa\n", "label2id, id2label = {}, {}\n", "for i, label in enumerate(labels_list):\n", "    label2id[label] = i\n", "    id2label[i] = label\n", "\n", "# Create ClassLabels object\n", "ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4CfU5GJkByam", "outputId": "6d206be1-ad03-41c3-a6d4-7127f490f037" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mapping of IDs to Labels: {0: 'Fire Needed Action', 1: 'Normal Conditions', 2: 'Smoky Environment'} \n", "\n", "Mapping of Labels to IDs: {'Fire Needed Action': 0, 'Normal Conditions': 1, 'Smoky Environment': 2}\n" ] } ], "source": [ "# Print the resulting dictionaries for reference\n", "print(\"Mapping of IDs to Labels:\", id2label, '\\n')\n", "print(\"Mapping of Labels to IDs:\", label2id)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 81, "referenced_widgets": [ "482c35ac85834319987d7b901227d688", "f1ebc501cb5f4f5e83b29076f5aff39b", "a145e3c6cd854b3aa47cf66e36e4cb04", "4104f20f77494fa1aea8aa06e87697a6", "b8925ceffe9144b88bf53389732f5307", "38050770fce14be7b51db96afc1d6785", "6c69d5012544464e8ec4f40ffb3d89f3", "2eb097c8ba994ce691a6bec25c09bf78", "3fab640eea2347588b7c2d692ca1c2ee", "36aa6bf214f7415f9fbce160e18c8822", "abd15e0e6fb94080862d151f4eba260c", "e821343bf5d54850b0d53fc572cbf7e0", "4f31647cf0f0465380c727a913476014", "3c27508b812f4e8f92d83c61d0ffbcc4", "18dcc402180e4cb6b1cb58710826f42b", "52d346a2ed8b4b9095e55a45c2a50522", "1211ac4560ca4710af89b2efdf22171a", "58709f429a31423aa848f64bb9badfe3", "88b5d89d805a483f9b251be0e54d02f3", "1f5a57aed02c47939241ec2af541df73", "0747a395d3e84e64bd8ea81a824e3061", "e4f04ffcb3e94a67a76d3b4ec0977375" ] }, "id": "M9XI2VNYB35G", "outputId": "b39d946d-841f-493b-f023-11f88fba4a7c" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Map: 0%| | 0/6060 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "482c35ac85834319987d7b901227d688" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Casting the dataset: 0%| | 0/6060 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "e821343bf5d54850b0d53fc572cbf7e0" } }, "metadata": {} } ], "source": [ "\n", "# Mapping labels to IDs.\n", "# With batched=True, `example` is a batch: example['label'] is a list and\n", "# ClassLabel.str2int converts every entry of it at once.\n", "def map_label2id(example):\n", "    example['label'] = ClassLabels.str2int(example['label'])\n", "    return example\n", "\n", "dataset = dataset.map(map_label2id, batched=True)\n", "\n", "# Casting label column to ClassLabel Object\n", "dataset = dataset.cast_column('label', ClassLabels)\n", "\n", "# Splitting the dataset into training and testing sets using a 60-40 split ratio.\n", "# A fixed seed makes the shuffle, and therefore every downstream metric, reproducible.\n", "dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column=\"label\", seed=42)\n", "\n", "# Extracting the training data from the split dataset.\n", "train_data = dataset['train']\n", "\n", "# Extracting the testing data from the split dataset.\n", "test_data = dataset['test']" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "57f0f1a79fb543f580b15560b5ca088a", "adf6dae0b8474e09afe808103fd7b89d", "b48a214a8151453c8776d9102f133a6e",
"b4f7398a20414319bfdb53061f64a1cc", "24d919b10295493faa9e6064ba291fc1", "2e87625fa6a74137ae9ebf57857a6eb1", "c90266df84db48e184bf3e28e00ca8e6", "511f0b07b20342dbbefbd1c3e5e13b97", "615daaf1d20b47fcaef4462992cb813b", "7ba7ee22989f43b79fcc533096f3fc08", "b232c157ed9a4c4593b6eac0a40796a4" ] }, "id": "FvEEn_iDB8uX", "outputId": "d6cd3d21-1707-425f-e268-f693ea880ef6" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "preprocessor_config.json: 0%| | 0.00/160 [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "57f0f1a79fb543f580b15560b5ca088a" } }, "metadata": {} } ], "source": [ "# Define the pre-trained ViT model string\n", "model_str = \"google/vit-base-patch16-224-in21k\"\n", "\n", "# Create a processor for ViT model input\n", "processor = ViTImageProcessor.from_pretrained(model_str)\n", "\n", "# Retrieve the image mean and standard deviation used for normalization\n", "image_mean, image_std = processor.image_mean, processor.image_std\n", "size = processor.size[\"height\"]\n", "\n", "# Define transformations for training and validation data\n", "_train_transforms = Compose(\n", "    [\n", "        Resize((size, size)),\n", "        RandomRotation(90),\n", "        RandomAdjustSharpness(2),\n", "        ToTensor(),\n", "        Normalize(mean=image_mean, std=image_std)\n", "    ]\n", ")\n", "\n", "_val_transforms = Compose(\n", "    [\n", "        Resize((size, size)),\n", "        ToTensor(),\n", "        Normalize(mean=image_mean, std=image_std)\n", "    ]\n", ")\n", "\n", "# Define functions to apply transformations\n", "def train_transforms(examples):\n", "    examples['pixel_values'] = [_train_transforms(image.convert(\"RGB\")) for image in examples['image']]\n", "    return examples\n", "\n", "def val_transforms(examples):\n", "    examples['pixel_values'] = [_val_transforms(image.convert(\"RGB\")) for image in examples['image']]\n", "    return examples\n", "\n", "# Set transforms for training and test data\n", "train_data.set_transform(train_transforms)\n", "test_data.set_transform(val_transforms)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "id": "9nEcswLACNVt" }, "outputs": [], "source": [ "def collate_fn(examples):\n", "    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", "    labels = torch.tensor([example['label'] for example in examples])\n", "    return {\"pixel_values\": pixel_values, \"labels\": labels}" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 153, "referenced_widgets": [ "59f7229c7f394a18ae09f1e955060cd6", "9c1b11fe20e74b41a9353dfa114e4164", "b32da91ce5434373b857701dba8bbec5", "e7ec2e04c576441f832b2c667d452877", "b2e9913dda2d4a3585288d96a2a6a079", "6ea76cd9ef734777a05f9abeec0579be", "5a79c774a5c84cd2bedb4b67079dd867", "00c22c77a94d4a5594a1a595cc6ca6d0", "5ef50ca91aaa4ed59bd9f7673725c359", "9d5b033d7adc4001a6acfa3a1e3266fe", "14e1f5bf75b444dfafb6a109f286844a", "5fb0b18027d14b9bb362c518c53d56d4", "c8689011aa3048f78fc0dd2c73c757ea", "e91481739fc748fd92c30cb934565f03", "49a92101ab754ca1bf8dc8d042a1d5c5", "37cc330abb9d4b4c946e13e1ef279b83", "224d5bd48e1e4fe385de2af3d7aee0d4", "2d76bac331f04e9689175aa90005a545", "72be780a86414d9a8928496ae3f66b96", "9cd9b6bdbfbd4354b3cd5c4eca14eff1", "52184f59a7554977bea1e5e402958fa9", "ebb6e365b87d42af8952f76f22243390" ] }, "id": "goTy8XdcCS7t", "outputId": "c8199676-0733-4ecf-a135-e2799b4056ea" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "config.json: 0%| | 0.00/502 [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "59f7229c7f394a18ae09f1e955060cd6" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "model.safetensors: 0%| | 0.00/346M [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "5fb0b18027d14b9bb362c518c53d56d4" } }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "85.800963\n" ] } ], "source": [ "# Create a ViTForImageClassification model\n", "model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))\n", "model.config.id2label = id2label\n", "model.config.label2id = label2id\n", "\n", "# Calculate and print the number of trainable parameters in millions for the model.\n", "print(model.num_parameters(only_trainable=True) / 1e6)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "fc63188eae5c41d5933f40e540903742", "5a31be32215a4b28ae20a9df7c6e83dd", "5e7e22b7529347d18b9dad6335bbf43b", "e2b810035c954fa0872ad44e82cab363", "6fd546c2794e40e8841eb698a1a5c762", "d1ae2f2bbee24533ae4d1faae0a09374", "5e3a188468df4a2a984aac5201fa22f3", "0accce4bc641429f8b257aba26dc9625", "c416a98c75ed44a8b7c1c1640808419c", "11ddc555dade435c95b151ef5378584a", "af1a0d847c8a4fd28401cad2130a4846" ] }, "id": "bRleo0I-CqPv", "outputId": "84be13e1-71ec-46cc-d9e2-3b6a29c59c56" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading builder script: 0%| | 0.00/4.20k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "fc63188eae5c41d5933f40e540903742" } }, "metadata": {} } ], "source": [ "# Load the accuracy metric from a module named 'evaluate'\n", "accuracy = evaluate.load(\"accuracy\")\n", "\n", "# Define a function 'compute_metrics' to calculate evaluation metrics\n", "def compute_metrics(eval_pred):\n", "    # Extract model predictions from the evaluation prediction object\n", "    predictions = eval_pred.predictions\n", "\n", "    # Extract true labels from the evaluation prediction object\n", "    label_ids = eval_pred.label_ids\n", "\n", "    # Calculate accuracy using the loaded accuracy metric\n", "    # Convert model predictions to class labels by selecting the class with the highest probability (argmax)\n", "    predicted_labels = predictions.argmax(axis=1)\n", "\n", "    # Calculate accuracy score by comparing predicted labels to true labels\n", "    acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']\n", "\n", "    # Return the computed accuracy as a dictionary with the key \"accuracy\"\n", "    return {\n", "        \"accuracy\": acc_score\n", "    }" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "id": "xtOJErdaCcy-" }, "outputs": [], "source": [ "# Define training arguments\n", "args = TrainingArguments(\n", "    output_dir=\"Fire-Normal-Smoke\",\n", "    logging_dir='./logs',\n", "    # `eval_strategy` replaces the `evaluation_strategy` keyword, which recent\n", "    # transformers releases no longer accept.\n", "    eval_strategy=\"epoch\",\n", "    learning_rate=5e-6,\n", "    per_device_train_batch_size=64,\n", "    per_device_eval_batch_size=16,\n", "    num_train_epochs=6,\n", "    weight_decay=0.02,\n", "    warmup_steps=50,\n", "    remove_unused_columns=False,\n", "    save_strategy='epoch',\n", "    load_best_model_at_end=True,\n", "    save_total_limit=1,\n", "    report_to=\"none\"\n", ")" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "id": "qX_SWsL3C2DN" }, "outputs": [], "source": [ "# Create a Trainer instance\n", "trainer = Trainer(\n", "    model,\n", "    args,\n", "    train_dataset=train_data,\n", "    eval_dataset=test_data,\n", "    data_collator=collate_fn,\n", "    compute_metrics=compute_metrics,\n", "    # `processing_class` supersedes the deprecated `tokenizer` argument; the image\n", "    # processor passed here is saved alongside the model by trainer.save_model().\n", "    processing_class=processor,\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "0CnRoQf5C4D-" }, "outputs": [], "source": [ "trainer.evaluate()" ] },
{ "cell_type": "code", "execution_count": null,
"metadata": { "id": "iZUij-VwELXU" }, "outputs": [], "source": [ "trainer.train()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "b8caa0kpH9uf" }, "outputs": [], "source": [ "trainer.evaluate()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "id": "GQo5AFJ1N1GC" }, "outputs": [], "source": [ "outputs = trainer.predict(test_data)\n", "print(outputs.metrics)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "vZE0GZXnIVtZ" }, "outputs": [], "source": [ "# Extract the true labels from the model outputs\n", "y_true = outputs.label_ids\n", "\n", "# Predict the labels by selecting the class with the highest probability\n", "y_pred = outputs.predictions.argmax(1)\n", "\n", "# Every class index, passed explicitly so the confusion matrix and the report always\n", "# have one row per entry of labels_list, even if a class is missing from both\n", "# y_true and y_pred (otherwise tick labels / target_names would be misaligned).\n", "class_ids = list(range(len(labels_list)))\n", "\n", "# Define a function to plot a confusion matrix\n", "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):\n", "    \"\"\"\n", "    This function plots a confusion matrix.\n", "\n", "    Parameters:\n", "        cm (array-like): Confusion matrix as returned by sklearn.metrics.confusion_matrix.\n", "        classes (list): List of class names, e.g., ['Class 0', 'Class 1'].\n", "        title (str): Title for the plot.\n", "        cmap (matplotlib colormap): Colormap for the plot.\n", "        figsize (tuple): Figure size in inches, passed to plt.figure.\n", "    \"\"\"\n", "    # Create a figure with a specified size\n", "    plt.figure(figsize=figsize)\n", "\n", "    # Display the confusion matrix as an image with a colormap\n", "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", "    plt.title(title)\n", "    plt.colorbar()\n", "\n", "    # Define tick marks and labels for the classes on the axes\n", "    tick_marks = np.arange(len(classes))\n", "    plt.xticks(tick_marks, classes, rotation=90)\n", "    plt.yticks(tick_marks, classes)\n", "\n", "    fmt = '.0f'\n", "    # Add text annotations to the plot indicating the values in the cells;\n", "    # white text on dark cells (above half of the maximum), black otherwise.\n", "    thresh = cm.max() / 2.0\n", "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", "        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", "    # Label the axes\n", "    plt.ylabel('True label')\n", "    plt.xlabel('Predicted label')\n", "\n", "    # Ensure the plot layout is tight\n", "    plt.tight_layout()\n", "    # Display the plot\n", "    plt.show()\n", "\n", "# Calculate accuracy and F1 score.\n", "# The result is named test_accuracy (not `accuracy`) so the `evaluate` metric object\n", "# that compute_metrics reads is not overwritten by a float; otherwise any later\n", "# trainer.evaluate() call would fail.\n", "test_accuracy = accuracy_score(y_true, y_pred)\n", "f1 = f1_score(y_true, y_pred, average='macro')\n", "\n", "# Display accuracy and F1 score\n", "print(f\"Accuracy: {test_accuracy:.4f}\")\n", "print(f\"F1 Score: {f1:.4f}\")\n", "\n", "# Get the confusion matrix if there are a small number of labels\n", "if len(labels_list) <= 150:\n", "    # Compute the confusion matrix\n", "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n", "\n", "    # Plot the confusion matrix using the defined function\n", "    plot_confusion_matrix(cm, labels_list, figsize=(8, 6))\n", "\n", "# Finally, display classification report\n", "print()\n", "print(\"Classification report:\")\n", "print()\n", "print(classification_report(y_true, y_pred, labels=class_ids, target_names=labels_list, digits=4))" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "id": "Qj1F9FLgIedG" }, "outputs": [], "source": [ "trainer.save_model()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "F-jYKwyWIiue" }, "outputs": [], "source": [ "# Import the 'pipeline' function from the 'transformers' library.\n", "from transformers import pipeline\n", "\n", "# Load the fine-tuned model from the directory trainer.save_model() wrote to\n", "# (trainer.args.output_dir) together with the image processor used in training.\n", "# Run on GPU 0 only when CUDA is available, otherwise on the CPU (-1).\n", "pipe = pipeline('image-classification', model=trainer.args.output_dir, image_processor=processor, device=0 if torch.cuda.is_available() else -1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Iy0WsizHIm_m" }, "outputs": [], "source": [ "image = test_data[1][\"image\"]\n", "image" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "TG7rBUMnIpXl" }, "outputs": [], "source": [ "pipe(image)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CtTHa2_DIq0x" }, "outputs": [], "source": [ "id2label[test_data[1][\"label\"]]" ] },
{ "cell_type": "code",
"execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17, "referenced_widgets": [ "d35991d81d964d1c91f7980206bcde8a", "e20b7b3e6bd84b61bdf797b89491a186", "6595a7dea024482e9739c99557b4b7e1", "23015c067ab24d9ab1eed28a031a167c", "62609b47efc744c28b3c95c7f0837e4f", "bad52f9d97e444f698772960c7ddf05c", "90a4e2b717ec452f868467f31c55052d", "b7e8db77d1e64878a2a198233892d1e0", "ae01737748d84f57ade069c2a66e95d8", "1bdc9c959a5b4a0abc2bafb6647e8645", "7e41ea4d67684b16b2d559b7a660b6d3", "160bbd00f1f44be58a418d005e7dcde5", "107b31ff64ce43febd9928e28953cc33", "804f29bc6f804fad8a6fad1cca73ba47", "2dac1441168345f89d10cfc88863f5eb", "8a2eba49a0c647e78e01752342df9432", "7c81b574f418477fbbb54209c2176904", "457fa3d625c64c128eb8cc6f8b28e8cd", "541b70018449416bbf776894239874a7", "ea8be29437dd4f22a8f21004b9aa6db5" ] }, "id": "BYY0rKBJIsgI", "outputId": "cae1d0df-4220-4b61-fa17-cfabf0f9e9f1" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='