{ "cells": [ { "cell_type": "markdown", "id": "dabed352-d637-47d9-b1c0-b99c407f2f8d", "metadata": {}, "source": [ "# Data Exploration and Preparation\n", "\n", "Loads MNIST, inspects and normalises it, then compares three setups on the same test set: a dense baseline trained on the original images, the same architecture trained on augmented images, and a small CNN trained on augmented images." ] },
{ "cell_type": "code", "execution_count": 1, "id": "bc1b7bf2-84da-495d-939a-358c94859631", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-19 20:47:16.372082: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-01-19 20:47:16.396318: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-01-19 20:47:16.402421: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-01-19 20:47:16.420317: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from keras.datasets import mnist\n", "from keras import Sequential\n", "from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout" ] },
{ "cell_type": "code", "execution_count": null, "id": "6f0c2a9e-3d41-4b8e-9a57-c1e2d4f8b036", "metadata": {}, "outputs": [], "source": [ "# Seed Python, NumPy and the Keras backend so weight initialisation, shuffling\n", "# and data augmentation are repeatable between runs\n", "# (some GPU kernels may still be non-deterministic).\n", "SEED = 42\n", "keras.utils.set_random_seed(SEED)" ] },
{ "cell_type": "code", "execution_count": 2, "id": "f19a0fc3-be77-44a4-a013-a5c46e967c55", "metadata": {}, "outputs": [], "source": [ "# MNIST comes pre-split into training and test images with integer labels\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()" ] },
{ "cell_type": "markdown", "id": "4044a79e-b763-422c-a555-c7fa258b1f66", "metadata": {}, "source": [ "## Inspection of the dataset" ] }, { 
"cell_type": "code", "execution_count": 3, "id": "41ba16c1-b75f-4cd9-adf5-41506b733598", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train set dimensions (60000, 28, 28)\n", "train set labels [5 0 4 ... 5 6 8]\n", "-------------------\n", "test set dimensions (10000, 28, 28)\n" ] } ], "source": [ "# Array shapes (samples, height, width) and a preview of the training labels\n", "print('train set dimensions ', train_images.shape)\n", "print('train set labels', train_labels)\n", "print('-------------------')\n", "print('test set dimensions', test_images.shape)" ] },
{ "cell_type": "code", "execution_count": 4, "id": "a9100043-9679-4ad1-9e7b-14e34d68c4fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unique classes in the MNIST train set [0 1 2 3 4 5 6 7 8 9]\n", "Frequency of unique classes in MNIST train set (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]))\n" ] } ], "source": [ "# np.unique(..., return_counts=True) returns (classes, counts); element 0 is the class list\n", "train_class_counts = np.unique(train_labels, return_counts=True)\n", "print('Unique classes in the MNIST train set', train_class_counts[0])\n", "print('Frequency of unique classes in MNIST train set', train_class_counts)" ] },
{ "cell_type": "code", "execution_count": 5, "id": "8936b73d-00c0-40dd-bd36-e0f19ee52992", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unique classes in the MNIST test set [0 1 2 3 4 5 6 7 8 9]\n", "Frequency of unique classes in MNIST test set (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([ 980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]))\n" ] } ], "source": [ "test_class_counts = np.unique(test_labels, return_counts=True)\n", "print('Unique classes in the MNIST test set', test_class_counts[0])\n", "print('Frequency of unique classes in MNIST test set', test_class_counts)" ] },
{ "cell_type": "code", "execution_count": 6, "id": "4c550e46-f950-425d-9ae9-debdaa8971f2", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Show the first 10 training images with their labels to get an idea of the dataset\n", "fig, axes = plt.subplots(1, 10, figsize=(20, 4))\n", "\n", "for ax, image, label in zip(axes, train_images[:10], train_labels[:10]):\n", "    ax.imshow(image, cmap='gray')\n", "    ax.axis('off')\n", "    ax.set_title(f\"Label: {label}\", fontsize=8)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": 7, "id": "88206a02-ddf5-4f3b-94ba-3157733920fb", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Pixel value range of one raw image, before normalisation\n", "plt.figure()\n", "plt.imshow(train_images[0])\n", "plt.colorbar()\n", "plt.grid(False)\n", "plt.show()" ] },
{ "cell_type": "markdown", "id": "9e5d5a9d-326c-410a-ad78-02eb3c235d09", "metadata": {}, "source": [ "## Data preprocessing" ] },
{ "cell_type": "code", "execution_count": 8, "id": "9715b5e8-6223-4114-8416-476834e99133", "metadata": {}, "outputs": [], "source": [ "def _prepare_images(images):\n", "    \"\"\"Return `images` as float32 in [0, 1] with a trailing channel axis.\n", "\n", "    Integer arrays are cast to float32 and divided by 255; 3-D arrays\n", "    (N, H, W) get a channel axis appended. Arrays that are already floating\n", "    point and/or 4-D are passed through unchanged, so applying this twice\n", "    neither rescales nor re-expands the data.\n", "    \"\"\"\n", "    if np.issubdtype(images.dtype, np.integer):\n", "        # float32 is what Keras trains in; float64 would double the memory\n", "        images = images.astype(np.float32) / 255.0\n", "    if images.ndim == 3:\n", "        images = np.expand_dims(images, axis=-1)\n", "    return images\n", "\n", "\n", "def preprocess_data(train_images, test_images):\n", "    \"\"\"Normalise the image arrays for the models below.\n", "\n", "    Scales integer pixel values to float32 in [0, 1] and turns (N, 28, 28)\n", "    arrays into (N, 28, 28, 1), the layout expected by Conv2D and by\n", "    ImageDataGenerator.flow(). Already-preprocessed arrays are returned\n", "    unchanged, so the calling cell can be re-run safely.\n", "\n", "    Returns:\n", "        (train_images, test_images) after preprocessing.\n", "    \"\"\"\n", "    return _prepare_images(train_images), _prepare_images(test_images)" ] },
{ "cell_type": "code", "execution_count": 9, "id": "838d3810-00a7-4c69-935c-c4c59fb30306", "metadata": {}, "outputs": [], "source": [ "# Safe to re-run: preprocess_data leaves already-processed arrays unchanged\n", "train_images, test_images = preprocess_data(train_images, test_images)" ] },
{ "cell_type": "code", "execution_count": 10, "id": "c707e0a3-6bf5-49f3-b7e6-1b765f3c22b4", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Same image after preprocessing: the colour bar should now span [0, 1]\n", "plt.figure()\n", "plt.imshow(train_images[0])\n", "plt.colorbar()\n", "plt.grid(False)\n", "plt.show()" ] },
{ "cell_type": "markdown", "id": "3fc16921-a9e1-4cfe-aee8-a28c3b980678", "metadata": {}, "source": [ "# A simple dense network (MLP), trained with the original data\n", "\n", "This baseline has no convolutional layers: each image is flattened and passed through a single hidden `Dense` layer." ] },
{ "cell_type": "code", "execution_count": 11, "id": "53f60557-8a54-47f1-a031-8ffcdf99fee8", "metadata": {}, "outputs": [], "source": [ "# Baseline: Flatten -> Dense(128, ReLU) -> Dense(10, softmax); no convolutions\n", "model = Sequential([\n", "    Input(shape=(28, 28, 1)),\n", "    Flatten(),\n", "    Dense(128, activation=tf.nn.relu),\n", "    Dense(10, activation=tf.nn.softmax)\n", "])" ] },
{ "cell_type": "code", "execution_count": 12, "id": "f7a5bf92-cbe4-47f2-8c6f-62a96adf570d", "metadata": {}, "outputs": [], "source": [ "# Labels are integer class ids, hence the sparse variant of cross-entropy\n", "model.compile(optimizer='adam',\n", "              loss='sparse_categorical_crossentropy',\n", "              metrics=['accuracy'])" ] },
{ "cell_type": "code", "execution_count": 13, "id": "c5c5e756-0735-4e6d-86a5-33009a44e199", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 8ms/step - accuracy: 0.8804 - loss: 0.4226\n", "Epoch 2/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 9ms/step - accuracy: 0.9634 - loss: 0.1228\n", "Epoch 3/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 9ms/step - accuracy: 0.9766 - loss: 0.0772\n", "Epoch 4/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 8ms/step - accuracy: 0.9813 - loss: 0.0627\n", "Epoch 5/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 9ms/step - accuracy: 0.9869 - loss: 0.0447\n", "Epoch 6/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m16s\u001b[0m 8ms/step - accuracy: 0.9893 - loss: 0.0332\n", "Epoch 7/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 9ms/step - accuracy: 0.9926 - loss: 0.0250\n", "Epoch 8/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 8ms/step - accuracy: 0.9936 - loss: 0.0224\n", "Epoch 9/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 8ms/step - accuracy: 0.9942 - loss: 0.0187\n", "Epoch 10/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 8ms/step - accuracy: 0.9962 - loss: 0.0132\n" ] } ], "source": [ "# Keep the History object rather than letting the notebook echo its repr\n", "history = model.fit(train_images, train_labels, epochs=10)" ] },
{ "cell_type": "code", "execution_count": 14, "id": "3d879a06-4779-450b-a730-2fd403409392", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9739 - loss: 0.0968\n", "Test accuracy: 0.9768999814987183\n" ] } ], "source": [ "test_loss, test_acc = model.evaluate(test_images, test_labels)\n", "print('Test accuracy:', test_acc)" ] },
{ "cell_type": "markdown", "id": "55770481-81cd-4d9c-855f-2a4ad026d5c5", "metadata": {}, "source": [ "# Dataset augmentation" ] },
{ "cell_type": "code", "execution_count": 15, "id": "5cc13e61-aebf-46e1-94b1-a20d369e0f7d", "metadata": {}, "outputs": [], "source": [ "# NOTE: keras.src is Keras' internal package and ImageDataGenerator lives in its\n", "# legacy module, so this import path may change between Keras versions.\n", "from keras.src.legacy.preprocessing.image import ImageDataGenerator\n", "\n", "datagen = ImageDataGenerator(\n", "    rotation_range=40,  # NOTE(review): +/-40 degrees is aggressive for handwritten digits\n", "    width_shift_range=0.2,\n", "    height_shift_range=0.2,\n", "    shear_range=0.1,\n", "    zoom_range=0.2,\n", "    horizontal_flip=False,  # mirrored digits (e.g. 2, 3, 5, 7) are not valid digits\n", 
"    fill_mode='nearest'\n", ")\n", "\n", "# datagen.fit() is not needed: it only computes dataset statistics for the\n", "# featurewise_center, featurewise_std_normalization and zca_whitening options,\n", "# none of which are enabled above." ] },
{ "cell_type": "markdown", "id": "e9acbf8f-be21-4c07-8904-2b97262e7922", "metadata": {}, "source": [ "### Explore some of the augmented images" ] },
{ "cell_type": "code", "execution_count": 16, "id": "51fdf7c3-0016-4bd6-a99f-9ffebfffd6cf", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Take only images of one digit ('2') so the augmented variants are easy to compare\n", "label_of_interest = 2\n", "filtered_images = train_images[train_labels == label_of_interest]\n", "\n", "# One batch of 5 randomly augmented images of that digit\n", "images = next(datagen.flow(filtered_images, batch_size=5))\n", "\n", "fig, axes = plt.subplots(1, 5, figsize=(15, 3))  # 1 row, 5 images\n", "\n", "for ax, image in zip(axes, images):\n", "    ax.imshow(image.squeeze(), cmap='gray')\n", "    ax.axis('off')\n", "    ax.set_title(f\"Label: {label_of_interest}\")\n", "\n", "plt.tight_layout()\n", "plt.show()" ] },
{ "cell_type": "markdown", "id": "a49fee01-6270-4afa-8568-525f43978467", "metadata": {}, "source": [ "# Simple MLP with augmented dataset\n", "\n", "The baseline architecture is re-created with fresh weights and trained on augmented images. Its test accuracy can then be compared fairly with the baseline trained on the original images." ] },
{ "cell_type": "code", "execution_count": null, "id": "0e468ad2-5c10-461d-b39f-38ef3a56b826", "metadata": {}, "outputs": [], "source": [ "# Build a new, untrained copy of the baseline architecture. Re-compiling the\n", "# existing `model` would keep the weights it already learned from the original\n", "# images, turning this run into fine-tuning and biasing the comparison below.\n", "augmented_model = Sequential([\n", "    Input(shape=(28, 28, 1)),\n", "    Flatten(),\n", "    Dense(128, activation=tf.nn.relu),\n", "    Dense(10, activation=tf.nn.softmax)\n", "])\n", "augmented_model.compile(optimizer='adam',\n", "                        loss='sparse_categorical_crossentropy',\n", "                        metrics=['accuracy'])" ] },
{ "cell_type": "code", "execution_count": null, "id": "bb11d7d9-d7a9-495d-ae64-f4609fdbdfbd", "metadata": {}, "outputs": [], "source": [ "# Train on randomly augmented batches drawn from the generator\n", "augmented_history = augmented_model.fit(datagen.flow(train_images, train_labels, batch_size=32), epochs=10)" ] },
{ "cell_type": "code", "execution_count": null, "id": "5d4bbd2a-9e18-49dd-a555-d1841d5d68f2", "metadata": {}, "outputs": [], "source": [ "# Evaluate on the same, unaugmented test set as the baseline\n", "augmented_test_loss, augmented_test_acc = augmented_model.evaluate(test_images, test_labels)\n", "print('Augmented test accuracy:', augmented_test_acc)" ] },
{ "cell_type": "code", 
"execution_count": 21, "id": "9a788df7-c878-49f6-86b5-e385be7bf842", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Simple test accuracy: 0.9768999814987183\n", "Simple augmented test accuracy: 0.932200014591217\n" ] } ], "source": [ "# Same MLP architecture trained on original vs augmented images,\n", "# both evaluated on the same (unaugmented) test set\n", "print('Simple test accuracy:', test_acc)\n", "print('Simple augmented test accuracy:', augmented_test_acc)" ] },
{ "cell_type": "markdown", "id": "213269a1-95be-4f45-ae13-76ceaac35828", "metadata": {}, "source": [ "# A convolutional network (CNN) with augmented dataset\n", "\n", "Two convolution + max-pooling stages and dropout are added in front of the dense classifier. The network is trained on the same augmented data." ] },
{ "cell_type": "code", "execution_count": 22, "id": "89cca745-bda4-4a8a-92f0-8d31d508cc8e", "metadata": {}, "outputs": [], "source": [ "def build_advanced_model(input_shape=(28, 28, 1), num_classes=10, dropout_rate=0.5):\n", "    \"\"\"Build and compile a small CNN classifier.\n", "\n", "    Two Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32 and 64 filters,\n", "    followed by Flatten -> Dense(128, ReLU) -> Dropout -> Dense(softmax).\n", "\n", "    Args:\n", "        input_shape: shape of a single image, (height, width, channels).\n", "        num_classes: number of output classes (width of the softmax layer).\n", "        dropout_rate: fraction of the Dense(128) activations dropped in training.\n", "\n", "    Returns:\n", "        A Sequential model compiled with Adam, sparse categorical\n", "        cross-entropy (integer class labels, not one-hot) and accuracy.\n", "    \"\"\"\n", "    model = Sequential([\n", "        Input(shape=input_shape),\n", "        Conv2D(32, (3, 3), activation='relu'),\n", "        MaxPooling2D((2, 2)),\n", "        Conv2D(64, (3, 3), activation='relu'),\n", "        MaxPooling2D((2, 2)),\n", "        Flatten(),\n", "        Dense(128, activation='relu'),\n", "        Dropout(dropout_rate),\n", "        Dense(num_classes, activation='softmax')\n", "    ])\n", "    model.compile(optimizer='adam',\n", "                  loss='sparse_categorical_crossentropy',\n", "                  metrics=['accuracy'])\n", "    return model" ] },
{ "cell_type": "code", "execution_count": 23, "id": "c1dfe2cf-86f9-4f36-8184-1dd566ece339", "metadata": {}, "outputs": [], "source": [ "advanced_model = build_advanced_model()" ] },
{ "cell_type": "code", "execution_count": 24, "id": "93ba283c-6a0f-49c8-adf6-505ac51cbece", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m251s\u001b[0m 133ms/step - accuracy: 0.5385 - loss: 1.3305 - val_accuracy: 0.9559 - val_loss: 0.1485\n", "Epoch 2/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m246s\u001b[0m 131ms/step - accuracy: 
0.8474 - loss: 0.4887 - val_accuracy: 0.9647 - val_loss: 0.1102\n", "Epoch 3/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m241s\u001b[0m 129ms/step - accuracy: 0.8866 - loss: 0.3713 - val_accuracy: 0.9685 - val_loss: 0.0905\n", "Epoch 4/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m241s\u001b[0m 128ms/step - accuracy: 0.9064 - loss: 0.3077 - val_accuracy: 0.9735 - val_loss: 0.0727\n", "Epoch 5/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m245s\u001b[0m 130ms/step - accuracy: 0.9166 - loss: 0.2790 - val_accuracy: 0.9762 - val_loss: 0.0756\n", "Epoch 6/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m243s\u001b[0m 129ms/step - accuracy: 0.9195 - loss: 0.2690 - val_accuracy: 0.9792 - val_loss: 0.0639\n", "Epoch 7/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m245s\u001b[0m 130ms/step - accuracy: 0.9270 - loss: 0.2395 - val_accuracy: 0.9804 - val_loss: 0.0569\n", "Epoch 8/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m240s\u001b[0m 128ms/step - accuracy: 0.9340 - loss: 0.2231 - val_accuracy: 0.9774 - val_loss: 0.0704\n", "Epoch 9/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m243s\u001b[0m 130ms/step - accuracy: 0.9356 - loss: 0.2144 - val_accuracy: 0.9844 - val_loss: 0.0496\n", "Epoch 10/10\n", "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m240s\u001b[0m 128ms/step - accuracy: 0.9395 - loss: 0.2053 - val_accuracy: 0.9827 - val_loss: 0.0538\n" ] } ], "source": [ "# NOTE: validation_data is the test set, so the val_* numbers above are test-set\n", "# scores. Nothing here tunes on them (no callbacks), but they must not be used\n", "# for model selection; hold out part of the training data for that instead.\n", "advanced_history = advanced_model.fit(datagen.flow(train_images, train_labels, batch_size=32),\n", "                                      validation_data=(test_images, test_labels),\n", "                                      epochs=10)" ] },
{ "cell_type": "code", "execution_count": 25, "id": "ba5a04bc-f912-45ee-8c38-f860ae59687a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m9s\u001b[0m 28ms/step - accuracy: 0.9814 - loss: 0.0564\n", "Simple test accuracy: 0.9768999814987183\n", "Simple augmented test accuracy: 0.932200014591217\n", "Advanced test accuracy: 0.982699990272522\n" ] } ], "source": [ "advanced_test_loss, advanced_test_acc = advanced_model.evaluate(test_images, test_labels)\n", "print('Simple test accuracy:', test_acc)\n", "print('Simple augmented test accuracy:', augmented_test_acc)\n", "print('Advanced test accuracy:', advanced_test_acc)" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel) *", "language": "python", "name": "conda-base-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }