{ "cells": [ { "cell_type": "markdown", "id": "c2dc16c4-c7f4-4945-ba91-6430a51e6f5a", "metadata": { "id": "c2dc16c4-c7f4-4945-ba91-6430a51e6f5a" }, "source": [ "[View Source Code](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/03_pytorch_computer_vision.ipynb) | [View Slides](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/slides/03_pytorch_computer_vision.pdf) | [Watch Video Walkthrough](https://youtu.be/Z_ikDlimN6A?t=50417) " ] }, { "cell_type": "markdown", "id": "08f47c6a-3318-4e3f-8bb3-c520e00e63dd", "metadata": { "id": "08f47c6a-3318-4e3f-8bb3-c520e00e63dd" }, "source": [ "# 03. PyTorch Computer Vision\n", "\n", "[Computer vision](https://en.wikipedia.org/wiki/Computer_vision) is the art of teaching a computer to see.\n", "\n", "For example, it could involve building a model to classify whether a photo is of a cat or a dog ([binary classification](https://developers.google.com/machine-learning/glossary#binary-classification)).\n", "\n", "Or whether a photo is of a cat, dog or chicken ([multi-class classification](https://developers.google.com/machine-learning/glossary#multi-class-classification)).\n", "\n", "Or identifying where a car appears in a video frame ([object detection](https://en.wikipedia.org/wiki/Object_detection)).\n", "\n", "Or figuring out where different objects in an image can be separated ([panoptic segmentation](https://arxiv.org/abs/1801.00868)).\n", "\n", "![example computer vision problems](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-computer-vision-problems.png)\n", "*Example computer vision problems for binary classification, multiclass classification, object detection and segmentation.*" ] }, { "cell_type": "markdown", "id": "19179a39-0c6c-40f7-9891-09e17d107ecf", "metadata": { "id": "19179a39-0c6c-40f7-9891-09e17d107ecf" }, "source": [ "## Where does computer vision get used?\n", "\n", "If you use a smartphone, you've already used computer vision.\n", 
"\n", "Camera and photo apps use [computer vision to enhance](https://machinelearning.apple.com/research/panoptic-segmentation) and sort images.\n", "\n", "Modern cars use [computer vision](https://youtu.be/j0z4FweCy4M?t=2989) to avoid other cars and stay within lane lines.\n", "\n", "Manufacturers use computer vision to identify defects in various products.\n", "\n", "Security cameras use computer vision to detect potential intruders.\n", "\n", "In essence, anything that can be described in a visual sense can be a potential computer vision problem." ] }, { "cell_type": "markdown", "id": "412e8bd1-0e6b-4ad6-8506-b28a8f669dc1", "metadata": { "id": "412e8bd1-0e6b-4ad6-8506-b28a8f669dc1" }, "source": [ "## What we're going to cover\n", "\n", "We're going to apply the PyTorch Workflow we've been learning in the past couple of sections to computer vision.\n", "\n", "![a PyTorch workflow with a computer vision focus](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-pytorch-computer-vision-workflow.png)\n", "\n", "Specifically, we're going to cover:\n", "\n", "| **Topic** | **Contents** |\n", "| ----- | ----- |\n", "| **0. Computer vision libraries in PyTorch** | PyTorch has a bunch of built-in helpful computer vision libraries, let's check them out. |\n", "| **1. Load data** | To practice computer vision, we'll start with some images of different pieces of clothing from [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist). |\n", "| **2. Prepare data** | We've got some images, let's load them in with a [PyTorch `DataLoader`](https://pytorch.org/docs/stable/data.html) so we can use them with our training loop. |\n", "| **3. Model 0: Building a baseline model** | Here we'll create a multi-class classification model to learn patterns in the data, we'll also choose a **loss function**, **optimizer** and build a **training loop**. | \n", "| **4. 
Making predictions and evaluating model 0** | Let's make some predictions with our baseline model and evaluate them. |\n", "| **5. Setup device agnostic code for future models** | It's best practice to write device-agnostic code, so let's set it up. |\n", "| **6. Model 1: Adding non-linearity** | Experimenting is a large part of machine learning, let's try and improve upon our baseline model by adding non-linear layers. |\n", "| **7. Model 2: Convolutional Neural Network (CNN)** | Time to get computer vision specific and introduce the powerful convolutional neural network architecture. |\n", "| **8. Comparing our models** | We've built three different models, let's compare them. |\n", "| **9. Evaluating our best model** | Let's make some predictions on random images and evaluate our best model. |\n", "| **10. Making a confusion matrix** | A confusion matrix is a great way to evaluate a classification model, let's see how we can make one. |\n", "| **11. Saving and loading the best performing model** | Since we might want to use our model later, let's save it and make sure it loads back in correctly. |" ] }, { "cell_type": "markdown", "id": "cddf62c3-f5e5-4f7e-852a-2ad6d38b7399", "metadata": { "id": "cddf62c3-f5e5-4f7e-852a-2ad6d38b7399" }, "source": [ "## Where can you get help?\n", "\n", "All of the materials for this course [live on GitHub](https://github.com/mrdbourke/pytorch-deep-learning).\n", "\n", "If you run into trouble, you can ask a question on the course [GitHub Discussions page](https://github.com/mrdbourke/pytorch-deep-learning/discussions) there too.\n", "\n", "And of course, there's the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) and [PyTorch developer forums](https://discuss.pytorch.org/), a very helpful place for all things PyTorch. " ] }, { "cell_type": "markdown", "id": "a0bedcfc-e12a-4a81-9913-84c6a888742a", "metadata": { "id": "a0bedcfc-e12a-4a81-9913-84c6a888742a" }, "source": [ "## 0. 
Computer vision libraries in PyTorch\n", "\n", "Before we get started writing code, let's talk about some PyTorch computer vision libraries you should be aware of.\n", "\n", "| PyTorch module | What does it do? |\n", "| ----- | ----- |\n", "| [`torchvision`](https://pytorch.org/vision/stable/index.html) | Contains datasets, model architectures and image transformations often used for computer vision problems. |\n", "| [`torchvision.datasets`](https://pytorch.org/vision/stable/datasets.html) | Here you'll find many example computer vision datasets for a range of problems from image classification, object detection, image captioning, video classification and more. It also contains [a series of base classes for making custom datasets](https://pytorch.org/vision/stable/datasets.html#base-classes-for-custom-datasets). |\n", "| [`torchvision.models`](https://pytorch.org/vision/stable/models.html) | This module contains well-performing and commonly used computer vision model architectures implemented in PyTorch, you can use these with your own problems. | \n", "| [`torchvision.transforms`](https://pytorch.org/vision/stable/transforms.html) | Often images need to be transformed (turned into numbers/processed/augmented) before being used with a model, common image transformations are found here. | \n", "| [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) | Base dataset class for PyTorch. | \n", "| [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#module-torch.utils.data) | Creates a Python iterable over a dataset (created with `torch.utils.data.Dataset`). 
|\n", "\n", "> **Note:** The `torch.utils.data.Dataset` and `torch.utils.data.DataLoader` classes aren't only for computer vision in PyTorch, they are capable of dealing with many different types of data.\n", "\n", "Now we've covered some of the most important PyTorch computer vision libraries, let's import the relevant dependencies.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "c263a60d-d788-482f-b9e7-9cab4f6b1f72", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c263a60d-d788-482f-b9e7-9cab4f6b1f72", "outputId": "20ba933b-6026-475f-a8d9-12cf416aff74" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PyTorch version: 2.0.1+cu118\n", "torchvision version: 0.15.2+cu118\n" ] } ], "source": [ "# Import PyTorch\n", "import torch\n", "from torch import nn\n", "\n", "# Import torchvision \n", "import torchvision\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "\n", "# Import matplotlib for visualization\n", "import matplotlib.pyplot as plt\n", "\n", "# Check versions\n", "# Note: your PyTorch version shouldn't be lower than 1.10.0 and torchvision version shouldn't be lower than 0.11\n", "print(f\"PyTorch version: {torch.__version__}\\ntorchvision version: {torchvision.__version__}\")" ] }, { "cell_type": "markdown", "id": "48d6bfe7-91da-44eb-9ab6-7c41c1e9fa8e", "metadata": { "id": "48d6bfe7-91da-44eb-9ab6-7c41c1e9fa8e" }, "source": [ "## 1. 
Getting a dataset\n", "\n", "To begin working on a computer vision problem, let's get a computer vision dataset.\n", "\n", "We're going to start with FashionMNIST.\n", "\n", "MNIST stands for Modified National Institute of Standards and Technology.\n", "\n", "The [original MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database) contains thousands of examples of handwritten digits (from 0 to 9) and was used to build computer vision models to identify numbers for postal services.\n", "\n", "[FashionMNIST](https://github.com/zalandoresearch/fashion-mnist), made by Zalando Research, is a similar setup. \n", "\n", "Except it contains grayscale images of 10 different kinds of clothing.\n", "\n", "![example image of FashionMNIST](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-fashion-mnist-slide.png)\n", "*`torchvision.datasets` contains a lot of example datasets you can use to practice writing computer vision code on. FashionMNIST is one of those datasets. 
And since it has 10 different image classes (different types of clothing), it's a multi-class classification problem.*\n", "\n", "Later, we'll be building a computer vision neural network to identify the different styles of clothing in these images.\n", "\n", "PyTorch has a bunch of common computer vision datasets stored in `torchvision.datasets`.\n", "\n", "This includes FashionMNIST in [`torchvision.datasets.FashionMNIST()`](https://pytorch.org/vision/main/generated/torchvision.datasets.FashionMNIST.html).\n", "\n", "To download it, we provide the following parameters:\n", "* `root: str` - which folder do you want to download the data to?\n", "* `train: bool` - do you want the training or test split?\n", "* `download: bool` - should the data be downloaded?\n", "* `transform: torchvision.transforms` - what transformations would you like to do on the data?\n", "* `target_transform` - you can transform the targets (labels) if you like too.\n", "\n", "Many other datasets in `torchvision` have these parameter options." 
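, "\n", "Before downloading the data, it's worth seeing what `transform=ToTensor()` will do to each image. Below is a minimal sketch on a synthetic 28x28 grayscale image (the pixel values are made up for illustration, they aren't a real FashionMNIST sample). `ToTensor()` takes a PIL image (or a `uint8` NumPy array) with values in `[0, 255]` and returns a `torch.float32` tensor with values scaled to `[0, 1]` and the color channels dimension first." ] }, { "cell_type": "code", "execution_count": null, "id": "3f9b2c1e-7a4d-4e8b-9c6f-1d2e3a4b5c6d", "metadata": { "id": "3f9b2c1e-7a4d-4e8b-9c6f-1d2e3a4b5c6d" }, "outputs": [], "source": [ "# Sketch: what ToTensor() does to a single grayscale image\n", "import numpy as np\n", "from PIL import Image\n", "from torchvision.transforms import ToTensor\n", "\n", "# Synthetic 28x28 image with uint8 pixel values 0-255 (illustrative only, not real data)\n", "pixels = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)\n", "pil_image = Image.fromarray(pixels) # a 2D uint8 array becomes a grayscale (mode \"L\") PIL image\n", "\n", "tensor_image = ToTensor()(pil_image)\n", "print(f\"Shape: {tensor_image.shape}\") # expected: torch.Size([1, 28, 28]) -> [color_channels, height, width]\n", "print(f\"Dtype: {tensor_image.dtype}\") # expected: torch.float32\n", "print(f\"Min: {tensor_image.min():.1f}, Max: {tensor_image.max():.1f}\") # expected: 0.0 and 1.0 (pixels divided by 255)" 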
] }, { "cell_type": "code", "execution_count": 2, "id": "486f8377-6810-4367-859d-69dccc7aef95", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "486f8377-6810-4367-859d-69dccc7aef95", "outputId": "877f93b2-12c5-477e-92bf-3ec3f1449282" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 26421880/26421880 [00:01<00:00, 16189161.14it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", "\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 29515/29515 [00:00<00:00, 269809.67it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", "\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4422102/4422102 [00:00<00:00, 4950701.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", "\n", "Downloading 
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5148/5148 [00:00<00:00, 4744512.63it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Setup training data\n", "train_data = datasets.FashionMNIST(\n", " root=\"data\", # where to download data to?\n", " train=True, # get training data\n", " download=True, # download data if it doesn't exist on disk\n", " transform=ToTensor(), # images come as PIL format, we want to turn into Torch tensors\n", " target_transform=None # you can transform labels as well\n", ")\n", "\n", "# Setup testing data\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\",\n", " train=False, # get test data\n", " download=True,\n", " transform=ToTensor()\n", ")" ] }, { "cell_type": "markdown", "id": "a63246f6-3645-49de-88fe-ec18e78bfbaf", "metadata": { "id": "a63246f6-3645-49de-88fe-ec18e78bfbaf" }, "source": [ "Let's check out the first sample of the training data." 
] }, { "cell_type": "code", "execution_count": 3, "id": "43bfd3d9-a132-41e8-8ccd-5ae25a7da59a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "43bfd3d9-a132-41e8-8ccd-5ae25a7da59a", "outputId": "1595e80b-6a3f-4171-a128-ec506b4d8326" }, "outputs": [ { "data": { "text/plain": [ "(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,\n", " 0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0039, 0.0039, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,\n", " 0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,\n", " 0.0157, 0.0000, 0.0000, 0.0118],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,\n", " 0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0471, 0.0392, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,\n", " 0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 
0.4275, 0.2510, 0.0902,\n", " 0.3020, 0.5098, 0.2824, 0.0588],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,\n", " 0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,\n", " 0.5529, 0.3451, 0.6745, 0.2588],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,\n", " 0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,\n", " 0.4824, 0.7686, 0.8980, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,\n", " 0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,\n", " 0.8745, 0.9608, 0.6784, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,\n", " 0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,\n", " 0.8627, 0.9529, 0.7922, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,\n", " 0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,\n", " 0.8863, 0.7725, 0.8196, 0.2039],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,\n", " 0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,\n", " 0.9608, 0.4667, 0.6549, 0.2196],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,\n", " 0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,\n", " 0.8510, 0.8196, 0.3608, 0.0000],\n", " [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,\n", " 0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 
0.8745, 0.8980, 0.8431,\n", " 0.8549, 1.0000, 0.3020, 0.0000],\n", " [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,\n", " 0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,\n", " 0.8784, 0.9569, 0.6235, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,\n", " 0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,\n", " 0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,\n", " 0.9137, 0.9333, 0.8431, 0.0000],\n", " [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,\n", " 0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,\n", " 0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,\n", " 0.8627, 0.9098, 0.9647, 0.0000],\n", " [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,\n", " 0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,\n", " 0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,\n", " 0.8706, 0.8941, 0.8824, 0.0000],\n", " [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,\n", " 0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,\n", " 0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,\n", " 0.8745, 0.8784, 0.8980, 0.1137],\n", " [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,\n", " 0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,\n", " 0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,\n", " 0.8627, 0.8667, 0.9020, 0.2627],\n", " [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,\n", " 0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,\n", " 0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,\n", " 0.7098, 0.8039, 0.8078, 0.4510],\n", " [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,\n", " 0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,\n", " 0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 
0.7529, 0.6902, 0.6118,\n", " 0.6549, 0.6941, 0.8235, 0.3608],\n", " [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,\n", " 0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,\n", " 0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,\n", " 0.7529, 0.8471, 0.6667, 0.0000],\n", " [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,\n", " 0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,\n", " 0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,\n", " 0.3882, 0.2275, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,\n", " 0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000]]]),\n", " 9)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# See first training sample\n", "image, label = train_data[0]\n", "image, label" ] }, { "cell_type": "markdown", "id": "9ad9d782-06cb-4591-ae3c-3a8b2389a1b2", "metadata": { "id": "9ad9d782-06cb-4591-ae3c-3a8b2389a1b2" }, "source": [ "### 1.1 Input and output shapes of a computer vision model\n", "\n", "We've got a big tensor of values (the image) leading to a single value for the target (the label).\n", "\n", "Let's see the image shape." 
] }, { "cell_type": "code", "execution_count": 4, "id": "c2997d9f-b574-4d23-aa34-1a4df1751226", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c2997d9f-b574-4d23-aa34-1a4df1751226", "outputId": "d9c4283b-aab8-410f-dd7f-03f08d943366" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 28, 28])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# What's the shape of the image?\n", "image.shape" ] }, { "cell_type": "markdown", "id": "b5326a05-f807-448d-99a3-6d03fc8739f8", "metadata": { "id": "b5326a05-f807-448d-99a3-6d03fc8739f8" }, "source": [ "The shape of the image tensor is `[1, 28, 28]` or more specifically:\n", "\n", "```\n", "[color_channels=1, height=28, width=28]\n", "```\n", "\n", "Having `color_channels=1` means the image is grayscale.\n", "\n", "![example input and output shapes of the fashionMNIST problem](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-computer-vision-input-and-output-shapes.png)\n", "*Various problems will have various input and output shapes. But the premise remains: encode data into numbers, build a model to find patterns in those numbers, convert those patterns into something meaningful.*\n", "\n", "If `color_channels=3`, the image comes in pixel values for red, green and blue (this is also known as the [RGB color model](https://en.wikipedia.org/wiki/RGB_color_model)).\n", "\n", "The order of our current tensor is often referred to as `CHW` (Color Channels, Height, Width).\n", "\n", "There's debate on whether images should be represented as `CHW` (color channels first) or `HWC` (color channels last).\n", "\n", "> **Note:** You'll also see `NCHW` and `NHWC` formats where `N` stands for *number of images*. For example, if you have a `batch_size=32`, your tensor shape may be `[32, 1, 28, 28]`. 
We'll cover batch sizes later.\n", "\n", "PyTorch generally accepts `NCHW` (channels first) as the default for many operators.\n", "\n", "However, the PyTorch team notes that `NHWC` (channels last) can perform better and is [considered best practice](https://pytorch.org/blog/tensor-memory-format-matters/#pytorch-best-practice). \n", "\n", "For now, since our dataset and models are relatively small, this won't make too much of a difference.\n", "\n", "But keep it in mind for when you're working on larger image datasets and using convolutional neural networks (we'll see these later).\n", "\n", "Let's check out more shapes of our data." ] }, { "cell_type": "code", "execution_count": 5, "id": "fc4f768c-c3f6-454d-a633-673ad1d6eca0", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fc4f768c-c3f6-454d-a633-673ad1d6eca0", "outputId": "fcac1ff4-5b9a-4459-a05e-77482f0e6776" }, "outputs": [ { "data": { "text/plain": [ "(60000, 60000, 10000, 10000)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# How many samples are there? \n", "len(train_data.data), len(train_data.targets), len(test_data.data), len(test_data.targets)" ] }, { "cell_type": "markdown", "id": "6e0267d5-946b-4c53-af69-61acd3527972", "metadata": { "id": "6e0267d5-946b-4c53-af69-61acd3527972" }, "source": [ "So we've got 60,000 training samples and 10,000 testing samples.\n", "\n", "What classes are there?\n", "\n", "We can find these via the `.classes` attribute." 
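, "\n", "Each integer label (such as the `9` we got for the first training sample) is an index into this list of class names. As a quick sketch, the `FashionMNIST` dataset also provides the reverse mapping, from class name to label index, via its `class_to_idx` attribute." ] }, { "cell_type": "code", "execution_count": null, "id": "8c4d2e7a-1b3f-4a6d-b5e9-0f7a2c9d4e81", "metadata": { "id": "8c4d2e7a-1b3f-4a6d-b5e9-0f7a2c9d4e81" }, "outputs": [], "source": [ "# Sketch: convert between integer labels and class names\n", "# (assumes train_data and label from the cells above)\n", "print(train_data.class_to_idx) # class name -> label index\n", "print(f\"Label {label} -> {train_data.classes[label]}\") # label index -> class name" 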
] }, { "cell_type": "code", "execution_count": 6, "id": "e22849c6-d93f-4b38-8403-5ebf0deaf008", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e22849c6-d93f-4b38-8403-5ebf0deaf008", "outputId": "6e18aa0f-b8a0-45ee-9f4e-8931bcdfbec0" }, "outputs": [ { "data": { "text/plain": [ "['T-shirt/top',\n", " 'Trouser',\n", " 'Pullover',\n", " 'Dress',\n", " 'Coat',\n", " 'Sandal',\n", " 'Shirt',\n", " 'Sneaker',\n", " 'Bag',\n", " 'Ankle boot']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# See classes\n", "class_names = train_data.classes\n", "class_names" ] }, { "cell_type": "markdown", "id": "abdd225c-5742-4d9c-8e8d-fb30a9c3cb6e", "metadata": { "id": "abdd225c-5742-4d9c-8e8d-fb30a9c3cb6e" }, "source": [ "Sweet! It looks like we're dealing with 10 different kinds of clothes.\n", "\n", "Because we're working with 10 different classes, it means our problem is **multi-class classification**.\n", "\n", "Let's get visual." ] }, { "cell_type": "markdown", "id": "fb625d80-6a98-471e-a758-4de0ce0f3a64", "metadata": { "id": "fb625d80-6a98-471e-a758-4de0ce0f3a64" }, "source": [ "### 1.2 Visualizing our data" ] }, { "cell_type": "code", "execution_count": 7, "id": "b1df1f2c-28c9-43bf-aaef-cf996c9ae1c5", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 469 }, "id": "b1df1f2c-28c9-43bf-aaef-cf996c9ae1c5", "outputId": "9bbdbb0d-eed3-408a-bd7b-03aa22cb35bb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: torch.Size([1, 28, 28])\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkhklEQVR4nO3de3SU9b3v8c/kNgSYTAghNwkYUEAFYkshplhESYG0xwPK7tHWswo9Li0YXEXarQu3ilq70+La1lOLes7aLdS1xNuqyJZtOVVogrQJyu1QaptCGgUlCRfNTMh1kvmdPzhGI9ffwyS/JLxfa81aZOb58Px4eJJPnszMNz5jjBEAAL0szvUCAAAXJwoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQT0kp07d2ru3LlKSUlRIBDQ7NmztWfPHtfLApzxMQsO6Hm7du3S9OnTlZubq+9///uKRqN6+umn9fHHH+udd97R+PHjXS8R6HUUENALvvnNb6qiokL79+/X8OHDJUm1tbUaN26cZs+erd/+9reOVwj0Pn4EB/SCt99+W0VFRV3lI0nZ2dm67rrrtHHjRp04ccLh6gA3KCCgF7S1tSk5OfmU+wcPHqz29nbt27fPwaoAtyggoBeMHz9elZWV6uzs7Lqvvb1d27dvlyR99NFHrpYGOEMBAb3grrvu0t///nfdfvvteu+997Rv3z5997vfVW1trSSppaXF8QqB3kcBAb1g8eLFuv/++7Vu3TpdddVVmjRpkqqrq3XvvfdKkoYOHep4hUDvo4CAXvKTn/xE9fX1evvtt7V37169++67ikajkqRx48Y5Xh3Q+3gZNuDQtGnTVFtbqw8++EBxcXw/iIsLZzzgyEsvvaR3331Xy5Yto3xwUeIKCOgFW7du1aOPPqrZs2dr+PDhqqys1Jo1a/T1r39dr7/+uhISElwvEeh1nPVAL7jkkksUHx+vxx9/XI2NjcrLy9Njjz2m5cuXUz64aHEFBABwgh88AwCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgRJ97A0I0GtXhw4cVCATk8/lcLwcAYMkYo8bGRuXk5Jx1ykefK6DDhw8rNzfX9TIAABfo0KFDGjly5Bkf73MFFAgEJEnX6htKUKLj1QAAbHUoom16o+vr+Zn0WAGtXr1ajz/+uOrq6pSfn6+nnnpK06ZNO2fu0x+7JShRCT4KCAD6nf8/X+dcT6P0yIsQXnrpJS1fvlwrV67Url27lJ+frzlz5ujIkSM9sTsAQD/UIwX0xBNP6I477tD3vvc9XXnllXr22Wc1ePBg/frXv+6J3QEA+qGYF1B7e7t27typoqKiz3YSF6eioiJVVFScsn1bW5vC4XC3GwBg4It5AR07dkydnZ3KzMzsdn9mZqbq6upO2b60tFTBYLDrxivgAODi4PyNqCtWrFAoFOq6HTp0yPWSAAC9IOavgktPT1d8fLzq6+u73V9fX6+srKxTtvf7/fL7/bFeBgCgj4v5FVBSUpKmTJmizZs3d90XjUa1efNmFRYWxnp3AIB+qkfeB7R8+XItXLhQX/nKVzRt2jQ9+eSTampq0ve+972e2B0AoB/qkQK65ZZbdPToUT300EOqq6vT1VdfrU2bNp3ywgQAwMXLZ4wxrhfxeeFwWMFgUDM1j0kIANAPdZiIyrRBoVBIKSkpZ9zO+avgAAAXJwoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOBEgusFAH2Kz2efMSb26ziN+OFp1plP5ozztK+UdZWectY8HG9fQqJ1xkTarTN9npdz1aseOse5AgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAk
KCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJxhGCnyOLz7eOmM6OqwzcVdfaZ356/eH2u+nxToiSUpsmmadSWiJ2u/n9zusM706WNTLsFQP55B89tcCvXkcfAl2VeEzRjqPTwuugAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACYaRAp9jO3RR8jaM9NCcVOvMbYVvW2f+eHSMdUaSPvBnWWdMsv1+EooKrTPjnv7IOtPx/kHrjCTJGPuIh/PBi/hhw7wFOzvtI+Gw1fbGnN8x4AoIAOAEBQQAcCLmBfTwww/L5/N1u02YMCHWuwEA9HM98hzQVVddpbfeeuuznXj4uToAYGDrkWZISEhQVpb9k5gAgItHjzwHtH//fuXk5GjMmDG67bbbdPDgmV+B0tbWpnA43O0GABj4Yl5ABQUFWrt2rTZt2qRnnnlGNTU1+trXvqbGxsbTbl9aWqpgMNh1y83NjfWSAAB9UMwLqLi4WN/61rc0efJkzZkzR2+88YYaGhr08ssvn3b7FStWKBQKdd0OHToU6yUBAPqgHn91QGpqqsaNG6cDBw6c9nG/3y+/39/TywAA9DE9/j6gEydOqLq6WtnZ2T29KwBAPxLzAvrRj36k8vJyvf/++/rTn/6km266SfHx8fr2t78d610BAPqxmP8I7sMPP9S3v/1tHT9+XCNGjNC1116ryspKjRgxIta7AgD0YzEvoBdffDHWfyXQa6Ktrb2yn/YvnbDO/FNwh3VmUFzEOiNJ5XFR68xHW+xfwdo52f44fPBEwDoT3f1V64wkDd9nP7gzZXetdebYjEusM0en2A9KlaTMSvvMsLeqrbY30Xbp2Lm3YxYcAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADjR47+QDnDC5/OWM/YDHk/8t2usM9+9ssw6Ux2xnyg/Mulj64wkfStnp33ov9tnfll1nXWm6R9B60zcEG+DO+uusf8e/aN59v9PJtJhnRm2y9uX77iF9daZcPsYq+07Iq3ShvNYi/VKAACIAQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJxgGjZ6l9cp1X3YNfe9Y525fuh7PbCSU10ib1Ogm0ySdaahc4h1ZuWV/2mdOTouYJ2JGG9f6v59/1etMyc8TOuO77D/vLjmf+y2zkjSgrR3rTOrfjvJavsOEzmv7bgCAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnGEaK3mW8Dcfsy/afyLDOHE8Zap2p60i1zgyPP2GdkaRAXIt15tLEY9aZo532g0XjE6PWmXYTb52RpEeuet0603pFonUm0ddpnfnqoMPWGUn61nvftc4M0T887etcuAICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcYRgpcoBF++4Gfg3wR60ySr8M6czgyzDojSftbxltn/h62H8o6N/Mv1pmIh8Gi8fI2BNfLkNCcxE+sM63GfoCp/Rl00vRM+8Giezzu61y4AgIAOEEBAQCcsC6grVu36sYbb1ROTo58Pp9ee+21bo8bY/TQQw8pOztbycnJKioq0v79+2O1XgDAAGFdQE1NTcrPz9fq1atP+/iqVav0i1/8Qs8++6y2b9+uIUOGaM6cOWptbb3gxQIABg7rFyEUFxeruLj4tI8ZY/Tkk0/qgQce0Lx58yRJzz33nDIzM/Xaa6/p1ltvvbDVAgAGjJg+B1RTU6O6ujoVFRV13RcMBlVQUKCKiorTZtra2hQOh7vdAAADX0wLqK6uTpKUmZnZ7f7MzMyux76otLRUwWCw65a
bmxvLJQEA+ijnr4JbsWKFQqFQ1+3QoUOulwQA6AUxLaCsrCxJUn19fbf76+vrux77Ir/fr5SUlG43AMDAF9MCysvLU1ZWljZv3tx1Xzgc1vbt21VYWBjLXQEA+jnrV8GdOHFCBw4c6Pq4pqZGe/bsUVpamkaNGqVly5bpscce0+WXX668vDw9+OCDysnJ0fz582O5bgBAP2ddQDt27ND111/f9fHy5cslSQsXLtTatWt17733qqmpSXfeeacaGhp07bXXatOmTRo0aFDsVg0A6Pd8xhhvU/p6SDgcVjAY1EzNU4LPfkAf+jifzz4Sbz980nTYD+6UpPhh9sM7b634s/1+fPafdkc7AtaZ1Phm64wklTfYDyP9y/HTP897No+O/w/rzK7mS60zOUn2A0Ilb8fv/fZ068zl/tO/SvhsfvdJvnVGknIHfWyd+f2yGVbbd3S0alvZIwqFQmd9Xt/5q+AAABcnCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnLD+dQzABfEwfN2XYH+aep2Gfej2K6wzNwx+3Trzp9ZLrDMjEhqtMxFjP0lckrL9IetMILPVOtPQOdg6k5ZwwjrT2JlsnZGkwXFt1hkv/09fTjpmnbnnrS9bZyQpMPG4dSYl0e5aJXqe1zZcAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAEwwjRa/yJSZZZ6Kt9kMuvUr/c7t15lhnonUmNa7ZOpPk67TOtHscRvrVtBrrzFEPAz93teRZZwLxLdaZEXH2A0IlKTfRfnDnn1tzrTNvNF1mnbn9v7xlnZGkF/73160zSZv+ZLV9nImc33bWKwEAIAYoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4MTFPYzU5/MWS7AfPumL99D1cfaZaGub/X6i9kMuvTIR+2Gfvel//q9fWmcOdaRaZ+oi9pnUePsBpp3ydo5XtgStM4Pizm8A5eeNSAhbZ8JR+6GnXjVGB1lnIh4GwHo5dvcN32+dkaRXQ0Wecj2BKyAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcGLADCP1Jdj/U0xHh6d9eRmoaexnDQ5ILfOmWWcOzbcflnrbl96xzkhSXUfAOrO7+VLrTDC+xTozJM5+0GyrsR+cK0mH24dZZ7wM1ExLOGGdyfAwwLTTePte+6OI/XHwwsug2Q877I+dJDX+10brTOpznnZ1TlwBAQCcoIAAAE5YF9DWrVt14403KicnRz6fT6+99lq3xxctWiSfz9ftNnfu3FitFwAwQFgXUFNTk/Lz87V69eozbjN37lzV1tZ23V544YULWiQAYOCxfua+uLhYxcXFZ93G7/crKyvL86IAAANfjzwHVFZWpoyMDI0fP15LlizR8ePHz7htW1ubwuFwtxsAYOCLeQHNnTtXzz33nDZv3qyf/exnKi8vV3FxsTo7T/9S2tLSUgWDwa5bbm5urJcEAOiDYv4+oFtvvbXrz5MmTdLkyZM1duxYlZWVadasWadsv2LFCi1fvrzr43A4TAkBwEWgx1+GPWbMGKWnp+vAgQOnfdzv9yslJaXbDQAw8PV4AX344Yc6fvy4srOze3pXAIB+xPpHcCdOnOh2NVNTU6M9e/YoLS1NaWlpeuSRR7RgwQJlZWWpurpa9957ry677DLNmTMnpgsHAPRv1gW0Y8cOXX/99V0ff/r8zcKFC/XMM89o7969+s1vfqOGhgbl5ORo9uzZ+vGPfyy/3x+7VQMA+j2fMca4XsTnhcNhBYNBzdQ8Jfi8DVLsixKy7d8XFcnLtM58fMVg60xzls86I0lXf+Ov1plFmdusM0c77Z8XTPR5GzT
b2JlsnclKbLDObAldaZ0ZmmA/jNTL0FNJ+nLy+9aZhqj9uZeT8Il15r4D/2SdyRxsP4BTkv599BvWmYiJWmeqIvbfoAfi7IciS9LbzZdZZ9ZfOcJq+w4TUZk2KBQKnfV5fWbBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwImY/0puV9qKp1pnMv7lH572dXXKh9aZK5Ptp0C3Ru2ngQ+Ki1hn3mu5xDojSc3RJOvM/nb7qeChDvspy/E++4nEknSkPWCd+beaIuvM5mnPWmceODzXOhOX7G3Y/fHOodaZBUPDHvZkf45/f9RW68yYpCPWGUna2GT/izQPR4ZZZzITQ9aZSxOPWmck6ebA360z62U3Dft8cQUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE702WGkvoQE+Xznv7yCf33Xeh+zAn+xzkhSs/FbZ7wMFvUy1NCLYEKzp1xbxP70ORJJ8bQvW+P8dZ5yN6Xssc5s/WWBdeba1rutM9U3rLHObG6Jt85I0tEO+/+nW2tusM7sOphrnbnm0hrrzKTAR9YZydsg3EB8q3Um0ddhnWmK2n8dkqTKVvtBsz2FKyAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcKLPDiOtXTJF8f5B5739w8GnrPex7uNrrDOSlDvoY+vM6KRj1pn85A+sM14E4uyHJ0rS+BT7AYobm0ZaZ8oaJlhnshMbrDOS9HbzWOvMiw8/bp1ZdM8PrTOFbyy2zoQv9fY9ZscQY51JyT9unXngS/9pnUnydVpnGjrth4pKUpq/yTqTGu9tuK8tL0ORJSkQ12KdiR9/mdX2prNN2n/u7bgCAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAn+uww0sFHoopPip739hvDV1vvY0zyUeuMJB2LBKwz/+fEJOvMyORPrDPBePtBg5f566wzkrSnNdU6s+noVdaZnOSwdaY+ErTOSNLxyBDrTHPUfijkr37+hHXm3+qLrDM3pe2yzkhSfpL9YNGGqP33s++1Z1lnGqPnP6T4U60m0TojSSEPQ0wDHj4HI8b+S3G8Of+vj5+XGmc/LDU8abjV9h2RVoaRAgD6LgoIAOCEVQGVlpZq6tSpCgQCysjI0Pz581VVVdVtm9bWVpWUlGj48OEaOnSoFixYoPr6+pguGgDQ/1kVUHl5uUpKSlRZWak333xTkUhEs2fPVlPTZ7+06Z577tHrr7+uV155ReXl5Tp8+LBuvvnmmC8cANC/WT3ztWnTpm4fr127VhkZGdq5c6dmzJihUCikX/3qV1q3bp1uuOEGSdKaNWt0xRVXqLKyUtdc4+03kAIABp4Leg4oFApJktLS0iRJO3fuVCQSUVHRZ6/WmTBhgkaNGqWKiorT/h1tbW0Kh8PdbgCAgc9zAUWjUS1btkzTp0/XxIkTJUl1dXVKSkpSampqt20zMzNVV3f6l/qWlpYqGAx23XJzc70uCQDQj3guoJKSEu3bt08vvvjiBS1gxYoVCoVCXbdDhw5d0N8HAOgfPL0RdenSpdq4caO2bt2qkSNHdt2flZWl9vZ2NTQ0dLsKqq+vV1bW6d9w5vf75ffbv5EPANC/WV0BGWO0dOlSrV+/Xlu2bFFeXl63x6dMmaLExERt3ry5676qqiodPHhQhYWFsVkxAGBAsLoCKikp0bp167RhwwYFAoGu53WCwaCSk5MVDAZ1++23a/ny5UpLS1NKSoruvvtuFRYW8go4AEA3VgX0zDPPSJJmzpzZ7f41a9Zo0aJFkqSf//zniouL04IFC9TW1qY5c+bo6aefjsliAQADh88YY1wv4vPC4bC
CwaBmXPugEhLOf+jg1Cd3Wu9rXzjHOiNJmYMarTOTh35onalqth/UeLglxTozOCFinZGk5Hj7XIexf91Lht/+eI/y2w/TlKRAnP0gySRfp3Wm08Prf65KOmydOdgxzDojSXUdqdaZ95rtP5+GJdgPxvyzh8/b5o4k64wktXXaP03e2mGfCfpbrTNT0z6wzkhSnOy/5K/7j+usto+2tuofj/2LQqGQUlLO/DWJWXAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwwtNvRO0Ncdv2Ks6XeN7bv/L76db7eHDeK9YZSSpvmGCd2Vg3yToTbrf/TbEjBjdZZ1IS7adNS1Jaov2+gh6mHw/ydVhnPukYYp2RpLa48z/nPtUpn3Wmri1onflj9HLrTCQab52RpDYPOS/T0T9uT7fO5CSHrDONHec/Wf/z3m9Ms84cCw21zrQOtv9SvK1zrHVGkuZm/cU6k3zE7hzvbDu/7bkCAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnfMYY43oRnxcOhxUMBjVT85RgMYzUi9Bt13jKjbmryjozLbXGOrMrPMo6c9DD8MRI1Nv3IYlxUevM4MR268wgD0Muk+I7rTOSFCf7T4eoh2GkQ+Ltj8OQhDbrTEpCq3VGkgLx9rk4n/354EW8h/+jd0KXxn4hZxDw8P/UYew/BwuD1dYZSfp1zVetM8FvHLDavsNEVKYNCoVCSklJOeN2XAEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBN9dxhp3M12w0ij3oZP9pamBQXWmYL737XPBOwHFE5IqrfOSFKi7IdPDvIwsHJInP2wz1aPp7WX78i2teRaZzo97GnLJ1dYZyIehlxKUn3zmQdInkmixwGwtqLG/nxo6fA22DjUMsg6Ex9nf+61lqVbZ4a/Zz+kV5L8b9h/XbHFMFIAQJ9GAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACf67jBSzbMbRgrPfFMnecq1ZCVbZ/zH26wzjaPt95NS3WSdkaS4tg7rTPT//tXTvoCBimGkAIA+jQICADhhVUClpaWaOnWqAoGAMjIyNH/+fFVVVXXbZubMmfL5fN1uixcvjumiAQD9n1UBlZeXq6SkRJWVlXrzzTcViUQ0e/ZsNTV1/3n7HXfcodra2q7bqlWrYrpoAED/l2Cz8aZNm7p9vHbtWmVkZGjnzp2aMWNG1/2DBw9WVlZWbFYIABiQLug5oFAoJElKS0vrdv/zzz+v9PR0TZw4UStWrFBzc/MZ/462tjaFw+FuNwDAwGd1BfR50WhUy5Yt0/Tp0zVx4sSu+7/zne9o9OjRysnJ0d69e3XfffepqqpKr7766mn/ntLSUj3yyCNelwEA6Kc8vw9oyZIl+t3vfqdt27Zp5MiRZ9xuy5YtmjVrlg4cOKCxY8ee8nhbW5va2j57b0g4HFZubi7vA+pFvA/oM7wPCLhw5/s+IE9XQEuXLtXGjRu1devWs5aPJBUUFEjSGQvI7/fL7/d7WQYAoB+zKiBjjO6++26tX79eZWVlysvLO2dmz549kqTs7GxPCwQADExWBVRSUqJ169Zpw4YNCgQCqqurkyQFg0ElJyerurpa69at0ze+8Q0NHz5ce/fu1T333KMZM2Zo8uTJPfIPAAD0T1YF9Mwzz0g6+WbTz1uzZo0WLVqkpKQkvfXWW3ryySfV1NSk3NxcLViwQA888EDMFgwAGBisfwR3Nrm5uSovL7+gBQEALg6eX4aNgcO8+2dPuUExXseZpPypl3YkKdp7uwIuegwjBQA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQo
IAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcCLB9QK+yBgjSepQRDKOFwMAsNahiKTPvp6fSZ8roMbGRknSNr3heCUAgAvR2NioYDB4xsd95lwV1cui0agOHz6sQCAgn8/X7bFwOKzc3FwdOnRIKSkpjlboHsfhJI7DSRyHkzgOJ/WF42CMUWNjo3JychQXd+ZnevrcFVBcXJxGjhx51m1SUlIu6hPsUxyHkzgOJ3EcTuI4nOT6OJztyudTvAgBAOAEBQQAcKJfFZDf79fKlSvl9/tdL8UpjsNJHIeTOA4ncRxO6k/Hoc+9CAEAcHHoV1dAAICBgwICADhBAQEAnKCAAABOUEAAACf6TQGtXr1al156qQYNGqSCggK98847rpfU6x5++GH5fL5utwkTJrheVo/bunWrbrzxRuXk5Mjn8+m1117r9rgxRg899JCys7OVnJysoqIi7d+/381ie9C5jsOiRYtOOT/mzp3rZrE9pLS0VFOnTlUgEFBGRobmz5+vqqqqbtu0traqpKREw4cP19ChQ7VgwQLV19c7WnHPOJ/jMHPmzFPOh8WLFzta8en1iwJ66aWXtHz5cq1cuVK7du1Sfn6+5syZoyNHjrheWq+76qqrVFtb23Xbtm2b6yX1uKamJuXn52v16tWnfXzVqlX6xS9+oWeffVbbt2/XkCFDNGfOHLW2tvbySnvWuY6DJM2dO7fb+fHCCy/04gp7Xnl5uUpKSlRZWak333xTkUhEs2fPVlNTU9c299xzj15//XW98sorKi8v1+HDh3XzzTc7XHXsnc9xkKQ77rij2/mwatUqRys+A9MPTJs2zZSUlHR93NnZaXJyckxpaanDVfW+lStXmvz8fNfLcEqSWb9+fdfH0WjUZGVlmccff7zrvoaGBuP3+80LL7zgYIW944vHwRhjFi5caObNm+dkPa4cOXLESDLl5eXGmJP/94mJieaVV17p2uavf/2rkWQqKipcLbPHffE4GGPMddddZ37wgx+4W9R56PNXQO3t7dq5c6eKioq67ouLi1NRUZEqKiocrsyN/fv3KycnR2PGjNFtt92mgwcPul6SUzU1Naqrq+t2fgSDQRUUFFyU50dZWZkyMjI0fvx4LVmyRMePH3e9pB4VCoUkSWlpaZKknTt3KhKJdDsfJkyYoFGjRg3o8+GLx+FTzz//vNLT0zVx4kStWLFCzc3NLpZ3Rn1uGvYXHTt2TJ2dncrMzOx2f2Zmpv72t785WpUbBQUFWrt2rcaPH6/a2lo98sgj+trXvqZ9+/YpEAi4Xp4TdXV1knTa8+PTxy4Wc+fO1c0336y8vDxVV1fr/vvvV3FxsSoqKhQfH+96eTEXjUa1bNkyTZ8+XRMnTpR08nxISkpSampqt20H8vlwuuMgSd/5znc0evRo5eTkaO/evbrvvvtUVVWlV1991eFqu+vzBYTPFBcXd/158uTJKigo0OjRo/Xyyy/r9ttvd7gy9AW33npr158nTZqkyZMna+zYsSorK9OsWbMcrqxnlJSUaN++fRfF86Bnc6bjcOedd3b9edKkScrOztasWbNUXV2tsWPH9vYyT6vP/wguPT1d8fHxp7yKpb6+XllZWY5W1TekpqZq3LhxOnDggOulOPPpOcD5caoxY8YoPT19QJ4fS5cu1caNG/WHP/yh2+8Py8rKUnt7uxoaGrptP1DPhzMdh9MpKCiQpD51PvT5AkpKStKUKVO0efPmrvui0ag2b96swsJChytz78SJE6qurlZ2drbrpTiTl5enrKysbudHOBzW9u3bL/rz48MPP9Tx48cH1PlhjNHSpUu1fv16bdmyRXl5ed0enzJlihITE7udD1VVVTp48OCAOh/OdRxOZ8+ePZLUt84H16+COB8vvvii8fv9Zu3atea9994zd955p0lNTTV1dXW
ul9arfvjDH5qysjJTU1Nj/vjHP5qioiKTnp5ujhw54nppPaqxsdHs3r3b7N6920gyTzzxhNm9e7f54IMPjDHG/PSnPzWpqalmw4YNZu/evWbevHkmLy/PtLS0OF55bJ3tODQ2Npof/ehHpqKiwtTU1Ji33nrLfPnLXzaXX365aW1tdb30mFmyZIkJBoOmrKzM1NbWdt2am5u7tlm8eLEZNWqU2bJli9mxY4cpLCw0hYWFDlcde+c6DgcOHDCPPvqo2bFjh6mpqTEbNmwwY8aMMTNmzHC88u76RQEZY8xTTz1lRo0aZZKSksy0adNMZWWl6yX1ultuucVkZ2ebpKQkc8kll5hbbrnFHDhwwPWyetwf/vAHI+mU28KFC40xJ1+K/eCDD5rMzEzj9/vNrFmzTFVVldtF94CzHYfm5mYze/ZsM2LECJOYmGhGjx5t7rjjjgH3Tdrp/v2SzJo1a7q2aWlpMXfddZcZNmyYGTx4sLnppptMbW2tu0X3gHMdh4MHD5oZM2aYtLQ04/f7zWWXXWb++Z//2YRCIbcL/wJ+HxAAwIk+/xwQAGBgooAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJ/4fSFZm765APLcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "image, label = train_data[0]\n", "print(f\"Image shape: {image.shape}\")\n", "plt.imshow(image.squeeze()) # image shape is [1, 28, 28] (colour channels, height, width)\n", "plt.title(label);" ] }, { "cell_type": "markdown", "id": "adb19c5c-2f2b-4aaf-8300-256f3594e2db", "metadata": { "id": "adb19c5c-2f2b-4aaf-8300-256f3594e2db" }, "source": [ "We can turn the image into grayscale using the `cmap` parameter of `plt.imshow()`." ] }, { "cell_type": "code", "execution_count": 8, "id": "92f09917-88f7-4446-b65f-baae586914c9", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 452 }, "id": "92f09917-88f7-4446-b65f-baae586914c9", "outputId": "c702456b-607c-4214-8e03-4bd0b22b097f" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(image.squeeze(), cmap=\"gray\")\n", "plt.title(class_names[label]);" ] }, { "cell_type": "markdown", "id": "9a09388a-d754-485f-aa26-4e7a0f782967", "metadata": { "id": "9a09388a-d754-485f-aa26-4e7a0f782967" }, "source": [ "Beautiful, well as beautiful as a pixelated grayscale ankle boot can get.\n", "\n", "Let's view a few more." ] }, { "cell_type": "code", "execution_count": 9, "id": "7188ed7a-5959-48c4-ac7f-19129a2adc83", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 752 }, "id": "7188ed7a-5959-48c4-ac7f-19129a2adc83", "outputId": "98d50938-b984-4725-8949-d85bf3143555" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot more images\n", "torch.manual_seed(42)\n", "fig = plt.figure(figsize=(9, 9))\n", "rows, cols = 4, 4\n", "for i in range(1, rows * cols + 1):\n", "    random_idx = torch.randint(0, len(train_data), size=[1]).item()\n", "    img, label = train_data[random_idx]\n", "    fig.add_subplot(rows, cols, i)\n", "    plt.imshow(img.squeeze(), cmap=\"gray\")\n", "    plt.title(class_names[label])\n", "    plt.axis(False);" ] }, { "cell_type": "markdown", "id": "f356fbe9-95b1-4f81-a82d-dc15b3adc06a", "metadata": { "id": "f356fbe9-95b1-4f81-a82d-dc15b3adc06a" }, "source": [ "Hmmm, this dataset doesn't look too aesthetic.\n", "\n", "But the principles we're going to learn on how to build a model for it will be similar across a wide range of computer vision problems.\n", "\n", "In essence, we take pixel values and build a model to find patterns in them that we can use on future pixel values.\n", "\n", "Plus, even for this small dataset (yes, even 60,000 images in deep learning is considered quite small), could you write a program to classify each one of them?\n", "\n", "You probably could.\n", "\n", "But I think coding a model in PyTorch would be faster.\n", "\n", "> **Question:** Do you think the above data can be modeled with only straight (linear) lines? Or do you think you'd also need non-straight (non-linear) lines?" ] }, { "cell_type": "markdown", "id": "43cdd23d-bd1f-4e8c-ba20-22d2b6ac14b1", "metadata": { "id": "43cdd23d-bd1f-4e8c-ba20-22d2b6ac14b1" }, "source": [ "## 2. Prepare DataLoader\n", "\n", "Now we've got a dataset ready to go.\n", "\n", "The next step is to prepare it with a [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) or `DataLoader` for short.\n", "\n", "The `DataLoader` does what you think it might do.\n", "\n", "It helps load data into a model, both for training and for inference.\n", "\n", "It turns a large `Dataset` into a Python iterable of smaller chunks.\n", "\n", "These smaller chunks are called **batches** or **mini-batches**, and their size is set by the `batch_size` parameter.\n", "\n", "Why do this?\n", "\n", "Because it's more computationally efficient.\n", "\n", "In an ideal world you could do the forward pass and backward pass across all of your data at once.\n", "\n", "But once you start using really large datasets, unless you've got infinite computing power, it's easier to break them up into batches.\n", "\n", "It also gives your model more opportunities to improve.\n", "\n", "With **mini-batches** (small portions of the data), gradient descent is performed more often per epoch (once per mini-batch rather than once per epoch).\n", "\n", "What's a good batch size?\n", "\n", "[32 is a good place to start](https://twitter.com/ylecun/status/989610208497360896?s=20&t=N96J_jotN--PYuJk2WcjMw) for a fair amount of problems.\n", "\n", "But since this is a value you can set (a **hyperparameter**), you can try all different kinds of values, though powers of 2 are the most common choice (e.g. 32, 64, 128, 256, 512).\n", "\n", "![an example of what a batched dataset looks like](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-batching-fashionmnist.png)\n", "*Batching FashionMNIST with a batch size of 32 and shuffle turned on. A similar batching process will occur for other datasets but will differ depending on the batch size.*\n", "\n", "Let's create `DataLoader`s for our training and test sets. 
" ] }, { "cell_type": "code", "execution_count": 10, "id": "bb2dbf90-a326-43cb-b25b-71af142fafeb", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bb2dbf90-a326-43cb-b25b-71af142fafeb", "outputId": "1f563408-3f50-4e8c-a15f-53e2f918b1ac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataloaders: (, )\n", "Length of train dataloader: 1875 batches of 32\n", "Length of test dataloader: 313 batches of 32\n" ] } ], "source": [ "from torch.utils.data import DataLoader\n", "\n", "# Setup the batch size hyperparameter\n", "BATCH_SIZE = 32\n", "\n", "# Turn datasets into iterables (batches)\n", "train_dataloader = DataLoader(train_data, # dataset to turn into iterable\n", " batch_size=BATCH_SIZE, # how many samples per batch? \n", " shuffle=True # shuffle data every epoch?\n", ")\n", "\n", "test_dataloader = DataLoader(test_data,\n", " batch_size=BATCH_SIZE,\n", " shuffle=False # don't necessarily have to shuffle the testing data\n", ")\n", "\n", "# Let's check out what we've created\n", "print(f\"Dataloaders: {train_dataloader, test_dataloader}\") \n", "print(f\"Length of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}\")\n", "print(f\"Length of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "7a925ee7-484b-4149-be8f-3ad790172a5f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7a925ee7-484b-4149-be8f-3ad790172a5f", "outputId": "85815bd7-39e9-44ed-b974-9e30fff5b5ad" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 1, 28, 28]), torch.Size([32]))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check out what's inside the training dataloader\n", "train_features_batch, train_labels_batch = next(iter(train_dataloader))\n", "train_features_batch.shape, train_labels_batch.shape" ] }, { "cell_type": "markdown", "id": 
"4fee4cf8-ab73-4c81-8e5e-3c81691e799c", "metadata": { "id": "4fee4cf8-ab73-4c81-8e5e-3c81691e799c" }, "source": [ "And we can see that the data remains unchanged by checking a single sample. " ] }, { "cell_type": "code", "execution_count": 12, "id": "c863d66a-49be-43be-84dc-372a5d6fc2c2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 463 }, "id": "c863d66a-49be-43be-84dc-372a5d6fc2c2", "outputId": "1052cbcb-6186-4dfe-b5f0-6968bde9fb21" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image size: torch.Size([1, 28, 28])\n", "Label: 6, label size: torch.Size([])\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQuUlEQVR4nO3dX6gfdP3H8fd355ydv9vOGbZl6raT+QcmNmoqXRitGhJUkC5ICCyCCsu7ugh2mxcSQiRIXim7CDFEulCD6A+EyaJCisniKJktmW7u2DnH8z3/PL+L4E1Df+28P23f7Zw9Hpd6Xn6/+/o9PvfV7W1ndXV1NQAgIjZd7CcAwKVDFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFLgsdDqd+Pa3v33Or3v00Uej0+nE3/72twv/pOASJAqse3/+85/j0KFDsXv37hgaGoqrrroqDh48GD/60Y8u+GPff//98dRTT13wx4Fe6bh9xHr23HPPxYEDB2LXrl1xzz33xPvf//549dVX4/nnn4+XXnoppqamIuLfnxS+9a1vxUMPPfRf/3orKyuxtLQUg4OD0el0zvn4Y2NjcejQoXj00UfPxw8HLrr+i/0E4H/x/e9/P7Zt2xa///3vY3x8/Kw/9/rrr5f/en19fdHX1/dfv2Z1dTW63W4MDw+X//pwqfOvj1jXXnrppdi7d++7ghARsWPHjnf9saeeeipuuummGBwcjL1798azzz571p9/r/+msGfPnvjsZz8bP//5z2P//v0xPDwcP/7xj6PT6cTc3Fw89thj0el0otPpxFe+8pXz/COE3hIF1rXdu3fHH/7wh/jLX/5yzq/97W9/G/fee2986UtfigceeCC63W7cddddcfr06XNujx8/HnfffXccPHgwfvjDH8a+ffviyJEjMTg4GLfffnscOXIkjhw5Et/4xjfOxw8LLhr/+oh17Tvf+U585jOfiX379sWtt94at99+e3zqU5+KAwcOxMDAwFlf++KLL8axY8fi2muvjYiIAwcOxIc//OH4yU9+cs5fmTQ1NRXPPvts3HHHHWf98W9+85vxwQ9+ML785S+f3x8YXCQ+KbCuHTx4MH73u9/F5z//+XjhhRfigQceiDvuuCOuuuqq+NnPfnbW137605/OIERE3HzzzbF169Z4+eWXz/k4k5OT7woCbESiwLp3yy23xJNPPhlnzpyJo0ePxve+972YmZmJQ4cOxbFjx/Lrdu3a9a7txMREnDlz5pyPMTk5eV6fM1yqRIENY/PmzXHLLbfE/fffHw8//HAsLS3FE088kX/+//tVRWv5Vdl+pRGXC1FgQ9q/f39ERLz22msX
9HHW8nsZYD0RBda1X/3qV+/5M/2nn346IiJuuOGGC/r4o6OjMT09fUEfA3rJrz5iXbvvvvvi7bffji984Qtx4403xuLiYjz33HPx+OOPx549e+KrX/3qBX38j370o/GLX/wiHnzwwfjABz4Qk5OTcdttt13Qx4QLSRRY137wgx/EE088EU8//XQ88sgjsbi4GLt27Yp77703Dh8+/J6/qe18evDBB+PrX/96HD58OObn5+Oee+4RBdY1t48ASP6bAgBJFABIogBAEgUAkigAkEQBgLTm36fgt/Nzsezevbu8+fjHP17e/PGPfyxv3ve+95U3v/71r8ubVi3ft36V+sa1lr+3PikAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCt+f/R7CBeb7W+3hvxmNnDDz9c3uzdu7e8+elPf1re3HnnneXNQw89VN5EtD2/jciRv3YO4gFQIgoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAKn/Yj8B3tulfsBr586d5c0nP/nJpsc6depUeTMyMlLefPe73y1vpqeny5uPfexj5U1ExOnTp8ub48ePlzf//Oc/y5teutS/N9Y7nxQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYDUWV3jycFOp3Ohnwv/4aabbmra7du3r7z50Ic+1PRYVZOTk027LVu2lDfXXXddedPymrdccH3++efLm4iIbdu2lTfPPPNMedPtdsubf/zjH+XN0aNHy5uIiFdeeaVpx9ouzPqkAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA5CBeD9x8883lzRe/+MWmxzp27Fh5s7y8XN688cYb5c3+/fvLm4iIO++8s7x57LHHypuvfe1r5U3Lcbarr766vImI+Pvf/17ePPLII+XN+Ph4eXPFFVeUN9u3by9vItp+TKdPn256rI3GQTwASkQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACA5iNcD9913X3nz5ptvNj1Wy4G2sbGx8qa/v7+8ef3118ubiIjZ2dnyZuvWreXN3XffXd6cOHGivPnNb35T3kRErKyslDc7d+4sb7rdbnnT8s+HK6+8sryJiFhcXCxvHn/88abH2mgcxAOgRBQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAFL9qhlle/bsKW/OnDnT9FgTExNNu17YsWNH027Lli3lzTvvvFPeLC8vlzcvvvhieTMwMFDeRETs2rWrvGk5bjc0NFTetBzr27Sp7eek119/fdOOtfFJAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIAyUG8ohtvvLG8WV1dLW+2bdtW3kS0HUBrOQQ3Pz9f3nQ6nfImou3Y2vDwcHnTcoTw5MmT5U3Lgb+Itte8v7/+Ld7yfmh5v27durW8iYhYWFgob2644Yby5vjx4+XNRuCTAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkFxJLfrEJz5R3rRcW9y8eXN5ExExMTFR3szOzpY309PT5U1fX195ExGxtLRU3oyOjpY3r732WnmzaVPvfl41NzdX3uzYsaO8GRwcLG927txZ3pw4caK8iWh7j3/kIx8pb1xJBeCyJwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAMlBvKLrrruuvPnTn/5U3kxNTZU3ERG33XZbeTM+Pl7e9PfX3zqnTp0qbyLajgMODAyUN2+++WZ50/LcxsbGypuIiIWFhfJm69at5U3L+6HlQOIrr7xS3kREXH/99eVNy5G/y5VPCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASJf1QbyWw1qzs7PlTV9fX3mz
srJS3kREdDqd8mZ5ebm8mZiYKG8WFxfLm4iIbrdb3rQcnWt5zbdt21betBypi2g76tZysK/lcVr+3o6MjJQ3ERFvvPFGedPy9/aaa64pb1599dXy5lLjkwIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFANJlfRDvyiuvLG9ajrO1HNZqPR539dVXlzdTU1PlzdzcXHnTquU1bzkE12JhYaG8aTmqGNH2OuzcubO8aTke13KAcGBgoLxp1fI67Nu3r7xxEA+ADUUUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQOqurq6tr+sJO50I/lw1r9+7d5c2WLVuaHutzn/tceTM4OFjenDhxoryZn58vbyIiZmZmypuWK6lr/FY4S6+u5ka0XRV95513ypvt27eXN9dee21588wzz5Q3EREnT54sb44dO9aTx7nUreU97pMCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSg3gbzMTERHlz+PDh8uavf/1refP222+XNxFtR91ajsetrKyUNy3PrWUTETE2NtaTTctr9+STT5Y3U1NT5Q3/GwfxACgRBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA1H+xn8DF1HLkr1eHAVuPpnW73fJmjTcRz9LfX3/rtGwiIhYXF8ublqNuLcfjTp48Wd4MDQ2VNxERy8vL5U3La9fyOJf6cbuW79uW74uNwCcFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCky/ogXouWQ3W9OqIXETE/P9+TTcvBuVYtB9pafkwtB9AGBwd78jgREZs3by5vRkdHy5uZmZny5lJ3uR63a+GTAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUA0mV9EK9XR7Iu9WNcCwsL5U1/f/2t09fXV95ERAwPD5c3Q0ND5U3L4cKWTctRxVYjIyPlzenTpy/AM2G98EkBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIl/WVVP5teXm5vGm5XDo7O1veRLRdcW25XtpyWfVf//pXebNpU9vPxXp1xXV6erq8YePwSQGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAMlBPJqOpvX31986fX195U1E26G6FktLS+VNy3Nreb0j2l7zlsOFLQcS2Th8UgAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQHIQj1hZWSlvNm2q/3yi9RBcy2ONjo6WNy3H7ebm5sqbxcXF8qZVyxHClvcDG4dPCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASA7iEUtLS+XNyMhIedPf3/Z263a75c3AwEB5s7y8XN5MT0+XN+Pj4+VNRNtxu9bXnMuXTwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiuZdGk0+n0ZBPRdgjuzJkz5c0VV1xR3rQet+uVoaGhnmzYOHxSACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkiupxNLSUnnT319/6ywvL5c3EREDAwM92QwPD5c3c3Nz5U232y1vIiIGBwebdlUtrx0bh08KACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIDuIR8/Pz5U3L8bi+vr7yJiJidna2vOl0Oj15nJmZmfJmZGSkvImIWFlZ6cmm9XAhG4NPCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASA7i0XQ8bvPmzT3ZRLQd3xsfHy9vhoaGyptut9uTx2nV8linTp26AM/k3VredxERq6ur
5/mZ8J98UgAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQHIQj6ajaQsLC+XNli1bypuIiL6+vvLmrbfeKm9anl8vj9u1GBsbK29aXjs2Dp8UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA5ErqJarT6TTtVldXy5uZmZny5tZbby1vfvnLX5Y3EREDAwPlTct10NHR0fKm2+2WNy2vd0TE8PBweTM+Pl7eTE9PlzdsHD4pAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgdVbXeEGt9UAbG9PevXvLm6WlpabHuuaaa8qbycnJ8mb79u3lzcmTJ8ub1u+lt956q7w5ceJEeXP06NHyhvVhLf+490kBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgCpf61fuMa7eQCsYz4pAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJD+DweYWJOnM3TKAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Show a sample\n", "torch.manual_seed(42)\n", "random_idx = torch.randint(0, len(train_features_batch), size=[1]).item()\n", "img, label = train_features_batch[random_idx], train_labels_batch[random_idx]\n", "plt.imshow(img.squeeze(), cmap=\"gray\")\n", "plt.title(class_names[label])\n", "plt.axis(\"Off\");\n", "print(f\"Image size: {img.shape}\")\n", "print(f\"Label: {label}, label size: {label.shape}\")" ] }, { "cell_type": "markdown", "id": "db1695cf-f53d-4c7c-ad39-dfed76533125", "metadata": { "id": "db1695cf-f53d-4c7c-ad39-dfed76533125" }, "source": [ "## 3. Model 0: Build a baseline model\n", "\n", "Data loaded and prepared!\n", "\n", "Time to build a **baseline model** by subclassing `nn.Module`.\n", "\n", "A **baseline model** is one of the simplest models you can imagine.\n", "\n", "You use the baseline as a starting point and try to improve upon it with subsequent, more complicated models.\n", "\n", "Our baseline will consist of two [`nn.Linear()`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) layers.\n", "\n", "We've done this in a previous section, but there's going to be one slight difference.\n", "\n", "Because we're working with image data, we're going to use a different layer to start things off.\n", "\n", "And that's the [`nn.Flatten()`](https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html) layer.\n", "\n", "`nn.Flatten()` compresses the dimensions of a tensor into a single vector (by default it flattens every dimension from `start_dim=1` onwards and leaves the first dimension as it is).\n", "\n", "This is easier to understand when you see it." 
] }, { "cell_type": "code", "execution_count": 13, "id": "405319f1-f242-4bd9-90f5-3abdc50782ac", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "405319f1-f242-4bd9-90f5-3abdc50782ac", "outputId": "742cd0fe-c95f-4201-a469-f12733625784" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape before flattening: torch.Size([1, 28, 28]) -> [color_channels, height, width]\n", "Shape after flattening: torch.Size([1, 784]) -> [color_channels, height*width]\n" ] } ], "source": [ "# Create a flatten layer\n", "flatten_model = nn.Flatten() # all nn modules function as a model (can do a forward pass)\n", "\n", "# Get a single sample\n", "x = train_features_batch[0]\n", "\n", "# Flatten the sample\n", "output = flatten_model(x) # perform forward pass\n", "\n", "# Print out what happened\n", "print(f\"Shape before flattening: {x.shape} -> [color_channels, height, width]\")\n", "print(f\"Shape after flattening: {output.shape} -> [color_channels, height*width]\")\n", "\n", "# Try uncommenting below and see what happens\n", "#print(x)\n", "#print(output)" ] }, { "cell_type": "markdown", "id": "86bb7806-fca6-45af-8111-3e00e38f5be9", "metadata": { "id": "86bb7806-fca6-45af-8111-3e00e38f5be9" }, "source": [ "The `nn.Flatten()` layer took our shape from `[color_channels, height, width]` to `[color_channels, height*width]`.\n", "\n", "Why do this?\n", "\n", "Because we've now turned our pixel data from height and width dimensions into one long **feature vector**.\n", "\n", "And `nn.Linear()` layers like their inputs to be in the form of feature vectors.\n", "\n", "Let's create our first model using `nn.Flatten()` as the first layer. 
" ] }, { "cell_type": "code", "execution_count": 14, "id": "1449f427-6859-41ae-8133-50b58ffbce72", "metadata": { "id": "1449f427-6859-41ae-8133-50b58ffbce72" }, "outputs": [], "source": [ "from torch import nn\n", "class FashionMNISTModelV0(nn.Module):\n", "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):\n", "        super().__init__()\n", "        self.layer_stack = nn.Sequential(\n", "            nn.Flatten(), # neural networks like their inputs in vector form\n", "            nn.Linear(in_features=input_shape, out_features=hidden_units), # in_features = number of features in a data sample (784 pixels)\n", "            nn.Linear(in_features=hidden_units, out_features=output_shape)\n", "        )\n", "    \n", "    def forward(self, x):\n", "        return self.layer_stack(x)" ] }, { "cell_type": "markdown", "id": "4d1b50bf-d00b-485c-be00-b3e4de156fab", "metadata": { "id": "4d1b50bf-d00b-485c-be00-b3e4de156fab" }, "source": [ "Wonderful!\n", "\n", "We've got a baseline model class we can use; now let's instantiate a model.\n", "\n", "We'll need to set the following parameters:\n", "* `input_shape=784` - this is how many features you've got going into the model, in our case, it's one for every pixel in the target image (28 pixels high by 28 pixels wide = 784 features).\n", "* `hidden_units=10` - number of units/neurons in the hidden layer(s), this number could be whatever you want but to keep the model small we'll start with `10`.\n", "* `output_shape=len(class_names)` - since we're working with a multi-class classification problem, we need an output neuron per class in our dataset.\n", "\n", "Let's create an instance of our model and send it to the CPU for now (we'll soon run a small test comparing `model_0` on CPU vs. a similar model on GPU)." 
] }, { "cell_type": "code", "execution_count": 15, "id": "dd18384a-76f9-4b5a-a013-fda077f16865", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dd18384a-76f9-4b5a-a013-fda077f16865", "outputId": "e4b63839-d012-40db-a7f7-967a146fe566" }, "outputs": [ { "data": { "text/plain": [ "FashionMNISTModelV0(\n", "  (layer_stack): Sequential(\n", "    (0): Flatten(start_dim=1, end_dim=-1)\n", "    (1): Linear(in_features=784, out_features=10, bias=True)\n", "    (2): Linear(in_features=10, out_features=10, bias=True)\n", "  )\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(42)\n", "\n", "# Need to set up model with input parameters\n", "model_0 = FashionMNISTModelV0(input_shape=784, # one for every pixel (28x28)\n", "                              hidden_units=10, # how many units in the hidden layer\n", "                              output_shape=len(class_names) # one for every class\n", ")\n", "model_0.to(\"cpu\") # keep model on CPU to begin with " ] }, { "cell_type": "markdown", "id": "03243179-1cdc-45d9-8b8c-82538ac02e9c", "metadata": { "id": "03243179-1cdc-45d9-8b8c-82538ac02e9c" }, "source": [ "### 3.1 Setup loss, optimizer and evaluation metrics\n", "\n", "Since we're working on a classification problem, let's bring in our [`helper_functions.py` script](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/helper_functions.py) and subsequently the `accuracy_fn()` we defined in [notebook 02](https://www.learnpytorch.io/02_pytorch_classification/).\n", "\n", "> **Note:** Rather than importing and using our own accuracy function or evaluation metric(s), you could import various evaluation metrics from the [TorchMetrics package](https://torchmetrics.readthedocs.io/en/latest/)." 
] }, { "cell_type": "code", "execution_count": 16, "id": "31c91f17-d810-46a4-97c3-c734f93430b1", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "31c91f17-d810-46a4-97c3-c734f93430b1", "outputId": "d2333811-f5fa-426f-90a7-c884fe2493df" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading helper_functions.py\n" ] } ], "source": [ "import requests\n", "from pathlib import Path \n", "\n", "# Download helper functions from Learn PyTorch repo (if not already downloaded)\n", "if Path(\"helper_functions.py\").is_file():\n", "    print(\"helper_functions.py already exists, skipping download\")\n", "else:\n", "    print(\"Downloading helper_functions.py\")\n", "    # Note: you need the \"raw\" GitHub URL for this to work\n", "    request = requests.get(\"https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py\")\n", "    with open(\"helper_functions.py\", \"wb\") as f:\n", "        f.write(request.content)" ] }, { "cell_type": "code", "execution_count": 17, "id": "ce3d13b8-f018-4b44-8bba-375074dc4c5f", "metadata": { "id": "ce3d13b8-f018-4b44-8bba-375074dc4c5f" }, "outputs": [], "source": [ "# Import accuracy metric\n", "from helper_functions import accuracy_fn # Note: could also use torchmetrics.Accuracy(task='multiclass', num_classes=len(class_names)).to(device)\n", "\n", "# Setup loss function and optimizer\n", "loss_fn = nn.CrossEntropyLoss() # this is also called \"criterion\"/\"cost function\" in some places\n", "optimizer = torch.optim.SGD(params=model_0.parameters(), lr=0.1)" ] }, { "cell_type": "markdown", "id": "4109f867-83f2-4394-a925-8acdc63ccffe", "metadata": { "id": "4109f867-83f2-4394-a925-8acdc63ccffe" }, "source": [ "### 3.2 Creating a function to time our experiments\n", "\n", "Loss function and optimizer ready!\n", "\n", "It's time to start training a model.\n", "\n", "But how about we do a little experiment while we train?\n", "\n", "I mean, let's make a timing function to measure the 
time it takes our model to train on CPU versus using a GPU.\n", "\n", "We'll train this model on the CPU but the next one on the GPU and see what happens.\n", "\n", "Our timing function will use the [`timeit.default_timer()` function](https://docs.python.org/3/library/timeit.html#timeit.default_timer) from the Python [`timeit` module](https://docs.python.org/3/library/timeit.html)." ] }, { "cell_type": "code", "execution_count": 18, "id": "31adc3fe-ce90-4b4e-b0d4-3613abae5714", "metadata": { "id": "31adc3fe-ce90-4b4e-b0d4-3613abae5714" }, "outputs": [], "source": [ "from timeit import default_timer as timer \n", "def print_train_time(start: float, end: float, device: torch.device = None):\n", "    \"\"\"Prints difference between start and end time.\n", "\n", "    Args:\n", "        start (float): Start time of computation (preferred in timeit format). \n", "        end (float): End time of computation.\n", "        device (torch.device, optional): Device that compute is running on. Defaults to None.\n", "\n", "    Returns:\n", "        float: time between start and end in seconds (higher is longer).\n", "    \"\"\"\n", "    total_time = end - start\n", "    print(f\"Train time on {device}: {total_time:.3f} seconds\")\n", "    return total_time" ] }, { "cell_type": "markdown", "id": "07b9560e-f5dc-45d6-b3b2-ddae17a71b34", "metadata": { "id": "07b9560e-f5dc-45d6-b3b2-ddae17a71b34" }, "source": [ "### 3.3 Creating a training loop and training a model on batches of data\n", "\n", "Beautiful!\n", "\n", "Looks like we've got all of the pieces of the puzzle ready to go: a timer, a loss function, an optimizer, a model and, most importantly, some data.\n", "\n", "Let's now create a training loop and a testing loop to train and evaluate our model.\n", "\n", "We'll be using the same steps as the previous notebook(s), though since our data is now in batch form, we'll add another loop to loop through our data batches.\n", "\n", "Our data batches are contained within our `DataLoader`s, `train_dataloader` and `test_dataloader` for 
the training and test data splits respectively.\n", "\n", "A batch is `BATCH_SIZE` samples of `X` (features) and `y` (labels), since we're using `BATCH_SIZE=32`, our batches have 32 samples of images and targets.\n", "\n", "And since we're computing on batches of data, our loss and evaluation metrics will be calculated **per batch** rather than across the whole dataset.\n", "\n", "This means we'll have to divide our loss and accuracy values by the number of batches in each dataset's respective dataloader. \n", "\n", "Let's step through it: \n", "1. Loop through epochs.\n", "2. Loop through training batches, perform training steps, calculate the train loss *per batch*.\n", "3. Loop through testing batches, perform testing steps, calculate the test loss *per batch*.\n", "4. Print out what's happening.\n", "5. Time it all (for fun).\n", "\n", "A fair few steps but...\n", "\n", "...if in doubt, code it out. " ] }, { "cell_type": "code", "execution_count": 19, "id": "c07bbf10-81e3-47f0-990d-9a4a838276ab", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 587, "referenced_widgets": [ "0bd8f8b5ff4d4b50b03e3a65cc1446f0", "430d171cfd584196ad0fa3e1cd0a286c", "618fb3cf63a94da9ad5f29a3d9a87ac3", "3524e24faad44aa38926b40b2d590f6b", "c01ca4def9224135ad367b6f8dbbae62", "decc1966e6a84973839efc0c65f51790", "39fc424b6cef4e98a80a342f530be99b", "e929063168354b018bbf0bb45fdfef1f", "d62646457b284fcb8aeac382b77eb942", "5c0883aa74f94568850741dad118cb88", "e44697d8dd41492e8619a860b3911e19" ] }, "id": "c07bbf10-81e3-47f0-990d-9a4a838276ab", "outputId": "3fb70da8-1a65-42bb-a684-85f0d1dd11c0" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0bd8f8b5ff4d4b50b03e3a65cc1446f0", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00 pred_prob -> pred_labels)\n", " \n", " # Scale loss and acc to find the average loss/acc per batch\n", " loss /= len(data_loader)\n", " acc /= len(data_loader)\n", " \n", " return 
{\"model_name\": model.__class__.__name__, # only works when model was created with a class\n", " \"model_loss\": loss.item(),\n", " \"model_acc\": acc}\n", "\n", "# Calculate model 0 results on test dataset\n", "model_0_results = eval_model(model=model_0, data_loader=test_dataloader,\n", " loss_fn=loss_fn, accuracy_fn=accuracy_fn\n", ")\n", "model_0_results" ] }, { "cell_type": "markdown", "id": "a39c3042-1262-4d1f-b33e-c8e2ba6781d3", "metadata": { "id": "a39c3042-1262-4d1f-b33e-c8e2ba6781d3" }, "source": [ "Looking good!\n", "\n", "We can use this dictionary to compare the baseline model results to other models later on." ] }, { "cell_type": "markdown", "id": "3b76784d-4cdb-43d2-a6da-8e4da9a812a9", "metadata": { "id": "3b76784d-4cdb-43d2-a6da-8e4da9a812a9" }, "source": [ "## 5. Setup device agnostic-code (for using a GPU if there is one)\n", "We've seen how long it takes to train ma PyTorch model on 60,000 samples on CPU.\n", "\n", "> **Note:** Model training time is dependent on hardware used. Generally, more processors means faster training and smaller models on smaller datasets will often train faster than large models and large datasets.\n", "\n", "Now let's setup some [device-agnostic code](https://pytorch.org/docs/stable/notes/cuda.html#best-practices) for our models and data to run on GPU if it's available.\n", "\n", "If you're running this notebook on Google Colab, and you don't a GPU turned on yet, it's now time to turn one on via `Runtime -> Change runtime type -> Hardware accelerator -> GPU`. If you do this, your runtime will likely reset and you'll have to run all of the cells above by going `Runtime -> Run before`." 
] }, { "cell_type": "code", "execution_count": 21, "id": "17b69fe9-f974-4538-922c-20c5cc8220cc", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "17b69fe9-f974-4538-922c-20c5cc8220cc", "outputId": "10c3b74b-4db7-4a30-8c3a-5a259d1f54b8" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'cuda'" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Setup device agnostic code\n", "import torch\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "device" ] }, { "cell_type": "markdown", "id": "514021a8-d6f2-47f3-ab50-55f844e42310", "metadata": { "id": "514021a8-d6f2-47f3-ab50-55f844e42310" }, "source": [ "Beautiful!\n", "\n", "Let's build another model." ] }, { "cell_type": "markdown", "id": "d7893907-5f82-4c5e-8fde-fa542a9f25af", "metadata": { "id": "d7893907-5f82-4c5e-8fde-fa542a9f25af" }, "source": [ "## 6. Model 1: Building a better model with non-linearity\n", "\n", "We learned about [the power of non-linearity in notebook 02](https://www.learnpytorch.io/02_pytorch_classification/#6-the-missing-piece-non-linearity).\n", "\n", "Seeing the data we've been working with, do you think it needs non-linear functions?\n", "\n", "And remember, linear means straight and non-linear means non-straight.\n", "\n", "Let's find out.\n", "\n", "We'll do so by recreating a similar model to before, except this time we'll put non-linear functions (`nn.ReLU()`) in between each linear layer." 
] }, { "cell_type": "code", "execution_count": 22, "id": "2ccce5f2-b1e5-47a6-a7f3-6bc096b35ffb", "metadata": { "id": "2ccce5f2-b1e5-47a6-a7f3-6bc096b35ffb" }, "outputs": [], "source": [ "# Create a model with non-linear and linear layers\n", "class FashionMNISTModelV1(nn.Module):\n", "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):\n", "        super().__init__()\n", "        self.layer_stack = nn.Sequential(\n", "            nn.Flatten(), # flatten inputs into single vector\n", "            nn.Linear(in_features=input_shape, out_features=hidden_units),\n", "            nn.ReLU(),\n", "            nn.Linear(in_features=hidden_units, out_features=output_shape),\n", "            nn.ReLU()\n", "        )\n", "    \n", "    def forward(self, x: torch.Tensor):\n", "        return self.layer_stack(x)" ] }, { "cell_type": "markdown", "id": "4b4b7a2f-4834-4aa1-a8e2-b6e3e2b49224", "metadata": { "id": "4b4b7a2f-4834-4aa1-a8e2-b6e3e2b49224" }, "source": [ "That looks good.\n", "\n", "Now let's instantiate it with the same settings we used before.\n", "\n", "We'll need `input_shape=784` (equal to the number of features of our image data), `hidden_units=10` (starting small and the same as our baseline model) and `output_shape=len(class_names)` (one output unit per class).\n", "\n", "> **Note:** Notice how we kept most of the settings of our model the same except for one change: adding non-linear layers. This is standard practice when running a series of machine learning experiments: change one thing, see what happens, then do it again, and again, and again." 
] }, { "cell_type": "code", "execution_count": 23, "id": "907091ec-7e46-470b-a305-788a3009b837", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "907091ec-7e46-470b-a305-788a3009b837", "outputId": "4cecd2df-2918-4368-fa33-7aea8f958d8f" }, "outputs": [ { "data": { "text/plain": [ "device(type='cuda', index=0)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(42)\n", "model_1 = FashionMNISTModelV1(input_shape=784, # number of input features\n", " hidden_units=10,\n", " output_shape=len(class_names) # number of output classes desired\n", ").to(device) # send model to GPU if it's available\n", "next(model_1.parameters()).device # check model device" ] }, { "cell_type": "markdown", "id": "b54a4e9d-a7ad-404c-920f-485fcff18a92", "metadata": { "id": "b54a4e9d-a7ad-404c-920f-485fcff18a92" }, "source": [ "### 6.1 Setup loss, optimizer and evaluation metrics\n", "\n", "As usual, we'll set up a loss function, an optimizer and an evaluation metric (we could use multiple evaluation metrics but we'll stick with accuracy for now)." ] }, { "cell_type": "code", "execution_count": 24, "id": "fe7e463b-d46c-4f00-853c-fdf0a28d74c8", "metadata": { "id": "fe7e463b-d46c-4f00-853c-fdf0a28d74c8" }, "outputs": [], "source": [ "from helper_functions import accuracy_fn\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(params=model_1.parameters(), \n", " lr=0.1)" ] }, { "cell_type": "markdown", "id": "1eb30af6-a355-49a2-a59f-25169fd27a6e", "metadata": { "id": "1eb30af6-a355-49a2-a59f-25169fd27a6e" }, "source": [ "### 6.2 Functionizing training and test loops\n", "\n", "So far we've been writing train and test loops over and over. 
\n", "\n", "Let's write them again but this time we'll put them in functions so they can be called again and again.\n", "\n", "And because we're using device-agnostic code now, we'll be sure to call `.to(device)` on our feature (`X`) and target (`y`) tensors.\n", "\n", "For the training loop we'll create a function called `train_step()` which takes in a model, a `DataLoader`, a loss function, an optimizer and an accuracy function.\n", "\n", "The testing loop will be similar but it'll be called `test_step()` and it'll take in a model, a `DataLoader`, a loss function and an evaluation function.\n", "\n", "> **Note:** Since these are functions, you can customize them in any way you like. What we're making here can be considered barebones training and testing functions for our specific classification use case." ] }, { "cell_type": "code", "execution_count": 25, "id": "3d239ed2-4028-4603-8db3-ffca2b727819", "metadata": { "id": "3d239ed2-4028-4603-8db3-ffca2b727819" }, "outputs": [], "source": [ "def train_step(model: torch.nn.Module,\n", "               data_loader: torch.utils.data.DataLoader,\n", "               loss_fn: torch.nn.Module,\n", "               optimizer: torch.optim.Optimizer,\n", "               accuracy_fn,\n", "               device: torch.device = device):\n", "    train_loss, train_acc = 0, 0\n", "    model.to(device)\n", "    model.train() # put model in training mode\n", "    for batch, (X, y) in enumerate(data_loader):\n", "        # Send data to target device\n", "        X, y = X.to(device), y.to(device)\n", "\n", "        # 1. Forward pass\n", "        y_pred = model(X)\n", "\n", "        # 2. Calculate loss\n", "        loss = loss_fn(y_pred, y)\n", "        train_loss += loss\n", "        train_acc += accuracy_fn(y_true=y,\n", "                                 y_pred=y_pred.argmax(dim=1)) # Go from logits -> pred labels\n", "\n", "        # 3. Optimizer zero grad\n", "        optimizer.zero_grad()\n", "\n", "        # 4. Loss backward\n", "        loss.backward()\n", "\n", "        # 5. Optimizer step\n",
"        optimizer.step()\n", "\n", "    # Calculate loss and accuracy per epoch and print out what's happening\n", "    train_loss /= len(data_loader)\n", "    train_acc /= len(data_loader)\n", "    print(f\"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%\")\n", "\n", "def test_step(data_loader: torch.utils.data.DataLoader,\n", "              model: torch.nn.Module,\n", "              loss_fn: torch.nn.Module,\n", "              accuracy_fn,\n", "              device: torch.device = device):\n", "    test_loss, test_acc = 0, 0\n", "    model.to(device)\n", "    model.eval() # put model in eval mode\n", "    # Turn on inference context manager\n", "    with torch.inference_mode(): \n", "        for X, y in data_loader:\n", "            # Send data to target device\n", "            X, y = X.to(device), y.to(device)\n", "            \n", "            # 1. Forward pass\n", "            test_pred = model(X)\n", "            \n", "            # 2. Calculate loss and accuracy\n", "            test_loss += loss_fn(test_pred, y)\n", "            test_acc += accuracy_fn(y_true=y,\n", "                y_pred=test_pred.argmax(dim=1) # Go from logits -> pred labels\n", "            )\n", "        \n", "        # Adjust metrics and print out\n", "        test_loss /= len(data_loader)\n", "        test_acc /= len(data_loader)\n", "        print(f\"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\\n\")" ] }, { "cell_type": "markdown", "id": "e44121b6-c4be-4909-9175-dc9bd8dc6273", "metadata": { "id": "e44121b6-c4be-4909-9175-dc9bd8dc6273" }, "source": [ "Woohoo!\n", "\n", "Now that we've got some functions for training and testing our model, let's run them.\n", "\n", "We'll do so inside another loop for each epoch.\n", "\n", "That way, for each epoch we're doing a training step and a testing step.\n", "\n", "> **Note:** You can customize how often you do a testing step. Sometimes people do them every five or ten epochs or, in our case, every epoch.\n", "\n", "Let's also time things to see how long our code takes to run on the GPU." 
] }, { "cell_type": "code", "execution_count": 26, "id": "2bb8094b-01a0-4b84-9526-ba8888d04901", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 327, "referenced_widgets": [ "3ee8f4a32dae40a2954869aa28d511af", "9bdbfed6aaa64648ac9624541a719134", "a7e31e6725a0417495bb5d8d9bb0eedb", "8a07bf3a83cf44b09ebec23372699dd4", "4da7f6dcecfc44928a784709a2f85c67", "85241944b82749bda4b5b6ff50f484b2", "b139c87d10be44229d2f65d356912c25", "b684374f8a3c41cb887142dd2c4a0c94", "325e5b7b95db4289b3ee1bd6dbfc4a6c", "987db9e4bab746ff9d393aa1409cf628", "dd5dcc8d0c424965ba5a329efbf725cc" ] }, "id": "2bb8094b-01a0-4b84-9526-ba8888d04901", "outputId": "83769d2d-6f3b-4704-e443-cfc4ef52cc81" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3ee8f4a32dae40a2954869aa28d511af", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00 **Note:** The training time on CUDA vs CPU will depend largely on the quality of the CPU/GPU you're using. Read on for a more detailed explanation.\n", "\n", "> **Question:** \"I used a GPU but my model didn't train faster, why might that be?\"\n", ">\n", "> **Answer:** Well, one reason could be that your dataset and model are both so small (like the dataset and model we're working with) that the benefits of using a GPU are outweighed by the time it actually takes to transfer the data there.\n", "> \n", "> Copying data from CPU memory (the default) to GPU memory creates a small bottleneck.\n", ">\n", "> So for smaller models and datasets, the CPU might actually be the optimal place to compute on.\n", ">\n", "> But for larger datasets and models, the compute speed a GPU offers usually far outweighs the cost of getting the data there.\n", ">\n", "> However, this is largely dependent on the hardware you're using. With practice, you'll get a feel for where the best place to train your models is. 
\n", "\n", "Let's evaluate our trained `model_1` using our `eval_model()` function and see how it went." ] }, { "cell_type": "code", "execution_count": 27, "id": "32a544e3-9dbe-4aa1-b074-22e28b8f2f2a", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 381 }, "id": "32a544e3-9dbe-4aa1-b074-22e28b8f2f2a", "outputId": "bab29648-1e35-4f01-9efe-fa4d2030cddb" }, "outputs": [ { "ename": "RuntimeError", "evalue": "ignored", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Note: This will error due to `eval_model()` not using device agnostic code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m model_1_results = eval_model(model=model_1, \n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mdata_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36meval_model\u001b[0;34m(model, data_loader, loss_fn, accuracy_fn)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# Make predictions with the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# Accumulate the loss and accuracy values per batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 217\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 218\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)" ] } ], "source": [ "torch.manual_seed(42)\n", "\n", "# Note: This will error due to `eval_model()` not using device agnostic code \n", "model_1_results = eval_model(model=model_1, \n", " data_loader=test_dataloader,\n", " loss_fn=loss_fn, \n", " accuracy_fn=accuracy_fn) \n", "model_1_results " ] }, { "cell_type": "markdown", "id": "6a3481a5-489d-4db9-ac95-c3ce385978b7", "metadata": { "id": "6a3481a5-489d-4db9-ac95-c3ce385978b7" }, "source": [ "Oh no! 
\n", "\n", "It looks like our `eval_model()` function errors out with:\n", "\n", "> `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)`\n", "\n", "It's because we've set up our data and model to use device-agnostic code but not our evaluation function.\n", "\n", "How about we fix that by passing a target `device` parameter to our `eval_model()` function?\n", "\n", "Then we'll try calculating the results again." ] }, { "cell_type": "code", "execution_count": 28, "id": "f3665d99-1adc-4d9f-bfc6-e5601a80691c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f3665d99-1adc-4d9f-bfc6-e5601a80691c", "outputId": "05312922-d30b-4c09-9989-963a4a579bf8" }, "outputs": [ { "data": { "text/plain": [ "{'model_name': 'FashionMNISTModelV1',\n", " 'model_loss': 0.6850008964538574,\n", " 'model_acc': 75.01996805111821}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Move values to device\n", "torch.manual_seed(42)\n", "def eval_model(model: torch.nn.Module, \n", "               data_loader: torch.utils.data.DataLoader, \n", "               loss_fn: torch.nn.Module, \n", "               accuracy_fn, \n", "               device: torch.device = device):\n", "    \"\"\"Evaluates a given model on a given dataset.\n", "\n", "    Args:\n", "        model (torch.nn.Module): A PyTorch model capable of making predictions on data_loader.\n", "        data_loader (torch.utils.data.DataLoader): The target dataset to predict on.\n", "        loss_fn (torch.nn.Module): The loss function of model.\n", "        accuracy_fn: An accuracy function to compare the model's predictions to the truth labels.\n", "        device (str, optional): Target device to compute on. Defaults to device.\n",
"\n", "    Returns:\n", "        (dict): Results of model making predictions on data_loader.\n", "    \"\"\"\n", "    loss, acc = 0, 0\n", "    model.eval()\n", "    with torch.inference_mode():\n", "        for X, y in data_loader:\n", "            # Send data to the target device\n", "            X, y = X.to(device), y.to(device)\n", "            y_pred = model(X)\n", "            loss += loss_fn(y_pred, y)\n", "            acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))\n", "        \n", "        # Scale loss and acc\n", "        loss /= len(data_loader)\n", "        acc /= len(data_loader)\n", "    return {\"model_name\": model.__class__.__name__, # only works when model was created with a class\n", "            \"model_loss\": loss.item(),\n", "            \"model_acc\": acc}\n", "\n", "# Calculate model 1 results with device-agnostic code \n", "model_1_results = eval_model(model=model_1, data_loader=test_dataloader,\n", " loss_fn=loss_fn, accuracy_fn=accuracy_fn,\n", " device=device\n", ")\n", "model_1_results" ] }, { "cell_type": "code", "execution_count": 29, "id": "a9e916cf-f873-4481-a983-bac26ce4cac2", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a9e916cf-f873-4481-a983-bac26ce4cac2", "outputId": "5cdb9f7f-366c-4c14-9afa-f2d1d4e0267d" }, "outputs": [ { "data": { "text/plain": [ "{'model_name': 'FashionMNISTModelV0',\n", " 'model_loss': 0.47663894295692444,\n", " 'model_acc': 83.42651757188499}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check baseline results\n", "model_0_results" ] }, { "cell_type": "markdown", "id": "340cbf14-e83f-4981-8a93-5fedb6b51418", "metadata": { "id": "340cbf14-e83f-4981-8a93-5fedb6b51418" }, "source": [ "Woah, in this case, it looks like adding non-linearities to our model made it perform worse than the baseline.\n", "\n", "That's something to note in machine learning: sometimes the thing you thought should work doesn't. 
\n", "\n", "And then the thing you thought might not work does.\n", "\n", "It's part science, part art.\n", "\n", "From the looks of things, it seems like our model is **overfitting** on the training data.\n", "\n", "Overfitting means our model is learning the training data well but those patterns aren't generalizing to the testing data.\n", "\n", "Two of the main ways to fix overfitting are:\n", "1. Using a smaller or different model (some models fit certain kinds of data better than others).\n", "2. Using a larger dataset (the more data, the more chance a model has to learn generalizable patterns).\n", "\n", "There are more, but I'm going to leave that as a challenge for you to explore.\n", "\n", "Try searching online for \"ways to prevent overfitting in machine learning\" and see what comes up.\n", "\n", "In the meantime, let's take a look at number 1: using a different model." ] }, { "cell_type": "markdown", "id": "ac22d685-1b8d-4215-90de-c0476cb0fbdf", "metadata": { "id": "ac22d685-1b8d-4215-90de-c0476cb0fbdf" }, "source": [ "## 7. Model 2: Building a Convolutional Neural Network (CNN)\n", "\n", "Alright, time to step things up a notch.\n", "\n", "It's time to create a [Convolutional Neural Network](https://en.wikipedia.org/wiki/Convolutional_neural_network) (CNN or ConvNet).\n", "\n", "CNNs are known for their ability to find patterns in visual data.\n", "\n", "And since we're dealing with visual data, let's see if using a CNN model can improve upon our baseline.\n", "\n", "The CNN model we're going to be using is known as TinyVGG from the [CNN Explainer](https://poloclub.github.io/cnn-explainer/) website.\n", "\n", "It follows the typical structure of a convolutional neural network:\n", "\n", "`Input layer -> [Convolutional layer -> activation layer -> pooling layer] -> Output layer`\n", "\n", "The contents of `[Convolutional layer -> activation layer -> pooling layer]` can be upscaled and repeated multiple times, depending on requirements. 
" ] }, { "cell_type": "markdown", "id": "9c358955-1d20-4903-b872-a239d2753d88", "metadata": { "id": "9c358955-1d20-4903-b872-a239d2753d88" }, "source": [ "### What model should I use?\n", "\n", "> **Question:** Wait, you say CNNs are good for images; are there any other model types I should be aware of?\n", "\n", "Good question.\n", "\n", "This table is a good general guide for which model to use (though there are exceptions).\n", "\n", "| **Problem type** | **Model to use (generally)** | **Code example** |\n", "| ----- | ----- | ----- |\n", "| Structured data (Excel spreadsheets, row and column data) | Gradient boosted models, Random Forests, XGBoost | [`sklearn.ensemble`](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.ensemble), [XGBoost library](https://xgboost.readthedocs.io/en/stable/) |\n", "| Unstructured data (images, audio, language) | Convolutional Neural Networks, Transformers | [`torchvision.models`](https://pytorch.org/vision/stable/models.html), [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) | \n", "\n", "> **Note:** The table above is only for reference; the model you end up using will be highly dependent on the problem you're working on and the constraints you have (amount of data, latency requirements).\n", "\n", "Enough talking about models; let's now build a CNN that replicates the model on the [CNN Explainer website](https://poloclub.github.io/cnn-explainer/).\n", "\n", "![TinyVGG architecture, as set up by the CNN Explainer website](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-cnn-explainer-model.png)\n", "\n", "To do so, we'll leverage the [`nn.Conv2d()`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) and [`nn.MaxPool2d()`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html) layers from `torch.nn`.\n" ] }, { "cell_type": "code", "execution_count": 30, "id": "dce60214-63fd-46e2-89ba-125445ac76b7", "metadata": { "colab": { "base_uri": 
"https://localhost:8080/" }, "id": "dce60214-63fd-46e2-89ba-125445ac76b7", "outputId": "5ae97191-bb41-4e58-e7f1-914b612cbb60" }, "outputs": [ { "data": { "text/plain": [ "FashionMNISTModelV2(\n", " (block_1): Sequential(\n", " (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU()\n", " (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU()\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (block_2): Sequential(\n", " (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU()\n", " (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU()\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (classifier): Sequential(\n", " (0): Flatten(start_dim=1, end_dim=-1)\n", " (1): Linear(in_features=490, out_features=10, bias=True)\n", " )\n", ")" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a convolutional neural network \n", "class FashionMNISTModelV2(nn.Module):\n", "    \"\"\"\n", "    Model architecture copying TinyVGG from: \n", "    https://poloclub.github.io/cnn-explainer/\n", "    \"\"\"\n", "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):\n", "        super().__init__()\n", "        self.block_1 = nn.Sequential(\n", "            nn.Conv2d(in_channels=input_shape, \n", "                      out_channels=hidden_units, \n", "                      kernel_size=3, # how big is the square that's going over the image?\n", "                      stride=1, # default\n", "                      padding=1), # options = \"valid\" (no padding) or \"same\" (output has same shape as input) or int for specific number \n", "            nn.ReLU(),\n", "            nn.Conv2d(in_channels=hidden_units, \n", "                      out_channels=hidden_units,\n", "                      kernel_size=3,\n", "                      stride=1,\n", "                      padding=1),\n", "            nn.ReLU(),\n", "            nn.MaxPool2d(kernel_size=2,\n", "                         stride=2) # default stride value is same as kernel_size\n", "        )\n",
"        self.block_2 = nn.Sequential(\n", "            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),\n", "            nn.ReLU(),\n", "            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),\n", "            nn.ReLU(),\n", "            nn.MaxPool2d(2)\n", "        )\n", "        self.classifier = nn.Sequential(\n", "            nn.Flatten(),\n", "            # Where did this in_features shape come from? \n", "            # It's because each layer of our network compresses and changes the shape of our input data.\n", "            # With 28x28 inputs, the padded 3x3 convs keep height/width, and each MaxPool2d(2) halves them: 28 -> 14 -> 7.\n", "            nn.Linear(in_features=hidden_units*7*7, \n", "                      out_features=output_shape)\n", "        )\n", "    \n", "    def forward(self, x: torch.Tensor):\n", "        x = self.block_1(x)\n", "        # print(x.shape)\n", "        x = self.block_2(x)\n", "        # print(x.shape)\n", "        x = self.classifier(x)\n", "        # print(x.shape)\n", "        return x\n", "\n", "torch.manual_seed(42)\n", "model_2 = FashionMNISTModelV2(input_shape=1, \n", "    hidden_units=10, \n", "    output_shape=len(class_names)).to(device)\n", "model_2" ] }, { "cell_type": "markdown", "id": "0a20f25e-cc16-4f85-a69b-62008c01d0ed", "metadata": { "id": "0a20f25e-cc16-4f85-a69b-62008c01d0ed" }, "source": [ "Nice!\n", "\n", "Our biggest model yet!\n", "\n", "What we've done is a common practice in machine learning.\n", "\n", "Find a model architecture somewhere and replicate it with code. " ] }, { "cell_type": "markdown", "id": "6478cc5a-7b33-425d-9ab3-6d40168a1aee", "metadata": { "id": "6478cc5a-7b33-425d-9ab3-6d40168a1aee" }, "source": [ "### 7.1 Stepping through `nn.Conv2d()`\n", "\n", "We could start using our model above and see what happens, but let's first step through the two new layers we've added:\n", "* [`nn.Conv2d()`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html), also known as a convolutional layer.\n", "* [`nn.MaxPool2d()`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html), also known as a max pooling layer.\n", "\n", "> **Question:** What does the \"2d\" in `nn.Conv2d()` stand for?\n", ">\n", "> The 2d is for 2-dimensional data. That is, our images have two dimensions: height and width. 
Yes, there's a color channel dimension, but each color channel has two dimensions too: height and width.\n", ">\n", "> For data with other dimensions (such as 1D for text or 3D for 3D objects) there's also `nn.Conv1d()` and `nn.Conv3d()`. \n", "\n", "To test the layers out, let's create some toy data just like the data used on CNN Explainer." ] }, { "cell_type": "code", "execution_count": 31, "id": "058b01ac-3f6a-4472-bcbf-3377974e3254", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "058b01ac-3f6a-4472-bcbf-3377974e3254", "outputId": "c404a8dd-d804-4993-bc2b-e4fdbf02b62d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch shape: torch.Size([32, 3, 64, 64]) -> [batch_size, color_channels, height, width]\n", "Single image shape: torch.Size([3, 64, 64]) -> [color_channels, height, width]\n", "Single image pixel values:\n", "tensor([[[ 1.9269, 1.4873, 0.9007, ..., 1.8446, -1.1845, 1.3835],\n", " [ 1.4451, 0.8564, 2.2181, ..., 0.3399, 0.7200, 0.4114],\n", " [ 1.9312, 1.0119, -1.4364, ..., -0.5558, 0.7043, 0.7099],\n", " ...,\n", " [-0.5610, -0.4830, 0.4770, ..., -0.2713, -0.9537, -0.6737],\n", " [ 0.3076, -0.1277, 0.0366, ..., -2.0060, 0.2824, -0.8111],\n", " [-1.5486, 0.0485, -0.7712, ..., -0.1403, 0.9416, -0.0118]],\n", "\n", " [[-0.5197, 1.8524, 1.8365, ..., 0.8935, -1.5114, -0.8515],\n", " [ 2.0818, 1.0677, -1.4277, ..., 1.6612, -2.6223, -0.4319],\n", " [-0.1010, -0.4388, -1.9775, ..., 0.2106, 0.2536, -0.7318],\n", " ...,\n", " [ 0.2779, 0.7342, -0.3736, ..., -0.4601, 0.1815, 0.1850],\n", " [ 0.7205, -0.2833, 0.0937, ..., -0.1002, -2.3609, 2.2465],\n", " [-1.3242, -0.1973, 0.2920, ..., 0.5409, 0.6940, 1.8563]],\n", "\n", " [[-0.7978, 1.0261, 1.1465, ..., 1.2134, 0.9354, -0.0780],\n", " [-1.4647, -1.9571, 0.1017, ..., -1.9986, -0.7409, 0.7011],\n", " [-1.3938, 0.8466, -1.7191, ..., -1.1867, 0.1320, 0.3407],\n", " ...,\n", " [ 0.8206, -0.3745, 1.2499, ..., -0.0676, 0.0385, 0.6335],\n", " [-0.5589, 
-0.3393, 0.2347, ..., 2.1181, 2.4569, 1.3083],\n", " [-0.4092, 1.5199, 0.2401, ..., -0.2558, 0.7870, 0.9924]]])\n" ] } ], "source": [ "torch.manual_seed(42)\n", "\n", "# Create sample batch of random numbers with same size as image batch\n", "images = torch.randn(size=(32, 3, 64, 64)) # [batch_size, color_channels, height, width]\n", "test_image = images[0] # get a single image for testing\n", "print(f\"Image batch shape: {images.shape} -> [batch_size, color_channels, height, width]\")\n", "print(f\"Single image shape: {test_image.shape} -> [color_channels, height, width]\") \n", "print(f\"Single image pixel values:\\n{test_image}\")" ] }, { "cell_type": "markdown", "id": "bd3291c2-854e-4d0c-97b9-8bf46085fc43", "metadata": { "id": "bd3291c2-854e-4d0c-97b9-8bf46085fc43" }, "source": [ "Let's create an example `nn.Conv2d()` with various parameters:\n", "* `in_channels` (int) - Number of channels in the input image.\n", "* `out_channels` (int) - Number of channels produced by the convolution.\n", "* `kernel_size` (int or tuple) - Size of the convolving kernel/filter.\n", "* `stride` (int or tuple, optional) - How big of a step the convolving kernel takes at a time. Default: 1.\n", "* `padding` (int, tuple, str) - Padding added to all four sides of input. 
Default: 0.\n", "\n", "![example of going through the different parameters of a Conv2d layer](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-conv2d-layer.gif)\n", "\n", "*Example of what happens when you change the hyperparameters of a `nn.Conv2d()` layer.*" ] }, { "cell_type": "code", "execution_count": 32, "id": "ebd39562-1dad-40e3-90f5-750a5dac24e2", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ebd39562-1dad-40e3-90f5-750a5dac24e2", "outputId": "b61154fb-c5f7-4c3f-c619-4bde6edb4d16" }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 1.5396, 0.0516, 0.6454, ..., -0.3673, 0.8711, 0.4256],\n", " [ 0.3662, 1.0114, -0.5997, ..., 0.8983, 0.2809, -0.2741],\n", " [ 1.2664, -1.4054, 0.3727, ..., -0.3409, 1.2191, -0.0463],\n", " ...,\n", " [-0.1541, 0.5132, -0.3624, ..., -0.2360, -0.4609, -0.0035],\n", " [ 0.2981, -0.2432, 1.5012, ..., -0.6289, -0.7283, -0.5767],\n", " [-0.0386, -0.0781, -0.0388, ..., 0.2842, 0.4228, -0.1802]],\n", "\n", " [[-0.2840, -0.0319, -0.4455, ..., -0.7956, 1.5599, -1.2449],\n", " [ 0.2753, -0.1262, -0.6541, ..., -0.2211, 0.1999, -0.8856],\n", " [-0.5404, -1.5489, 0.0249, ..., -0.5932, -1.0913, -0.3849],\n", " ...,\n", " [ 0.3870, -0.4064, -0.8236, ..., 0.1734, -0.4330, -0.4951],\n", " [-0.1984, -0.6386, 1.0263, ..., -0.9401, -0.0585, -0.7833],\n", " [-0.6306, -0.2052, -0.3694, ..., -1.3248, 0.2456, -0.7134]],\n", "\n", " [[ 0.4414, 0.5100, 0.4846, ..., -0.8484, 0.2638, 1.1258],\n", " [ 0.8117, 0.3191, -0.0157, ..., 1.2686, 0.2319, 0.5003],\n", " [ 0.3212, 0.0485, -0.2581, ..., 0.2258, 0.2587, -0.8804],\n", " ...,\n", " [-0.1144, -0.1869, 0.0160, ..., -0.8346, 0.0974, 0.8421],\n", " [ 0.2941, 0.4417, 0.5866, ..., -0.1224, 0.4814, -0.4799],\n", " [ 0.6059, -0.0415, -0.2028, ..., 0.1170, 0.2521, -0.4372]],\n", "\n", " ...,\n", "\n", " [[-0.2560, -0.0477, 0.6380, ..., 0.6436, 0.7553, -0.7055],\n", " [ 1.5595, -0.2209, -0.9486, ..., -0.4876, 0.7754, 0.0750],\n", " [-0.0797, 0.2471, 
1.1300, ..., 0.1505, 0.2354, 0.9576],\n", " ...,\n", " [ 1.1065, 0.6839, 1.2183, ..., 0.3015, -0.1910, -0.1902],\n", " [-0.3486, -0.7173, -0.3582, ..., 0.4917, 0.7219, 0.1513],\n", " [ 0.0119, 0.1017, 0.7839, ..., -0.3752, -0.8127, -0.1257]],\n", "\n", " [[ 0.3841, 1.1322, 0.1620, ..., 0.7010, 0.0109, 0.6058],\n", " [ 0.1664, 0.1873, 1.5924, ..., 0.3733, 0.9096, -0.5399],\n", " [ 0.4094, -0.0861, -0.7935, ..., -0.1285, -0.9932, -0.3013],\n", " ...,\n", " [ 0.2688, -0.5630, -1.1902, ..., 0.4493, 0.5404, -0.0103],\n", " [ 0.0535, 0.4411, 0.5313, ..., 0.0148, -1.0056, 0.3759],\n", " [ 0.3031, -0.1590, -0.1316, ..., -0.5384, -0.4271, -0.4876]],\n", "\n", " [[-1.1865, -0.7280, -1.2331, ..., -0.9013, -0.0542, -1.5949],\n", " [-0.6345, -0.5920, 0.5326, ..., -1.0395, -0.7963, -0.0647],\n", " [-0.1132, 0.5166, 0.2569, ..., 0.5595, -1.6881, 0.9485],\n", " ...,\n", " [-0.0254, -0.2669, 0.1927, ..., -0.2917, 0.1088, -0.4807],\n", " [-0.2609, -0.2328, 0.1404, ..., -0.1325, -0.8436, -0.7524],\n", " [-1.1399, -0.1751, -0.8705, ..., 0.1589, 0.3377, 0.3493]]],\n", " grad_fn=)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(42)\n", "\n", "# Create a convolutional layer with same dimensions as TinyVGG \n", "# (try changing any of the parameters and see what happens)\n", "conv_layer = nn.Conv2d(in_channels=3,\n", " out_channels=10,\n", " kernel_size=3,\n", " stride=1,\n", " padding=0) # also try using \"valid\" or \"same\" here \n", "\n", "# Pass the data through the convolutional layer\n", "conv_layer(test_image) # Note: If running PyTorch <1.11.0, this will error because of shape issues (nn.Conv2d() expects a 4D tensor as input) " ] }, { "cell_type": "markdown", "id": "cb0184ad-5c16-4e1c-bcfa-70ecf15377da", "metadata": { "id": "cb0184ad-5c16-4e1c-bcfa-70ecf15377da" }, "source": [ "If we try to pass a single image in with a PyTorch version older than 1.11.0, we get a shape mismatch error:\n", "\n", "> `RuntimeError: Expected 4-dimensional input for 
4-dimensional weight [10, 3, 3, 3], but got 3-dimensional input of size [3, 64, 64] instead`\n", ">\n", "> **Note:** If you're running PyTorch 1.11.0+, this error won't occur.\n", "\n", "This is because our `nn.Conv2d()` layer expects a 4-dimensional tensor as input with size `(N, C, H, W)` or `[batch_size, color_channels, height, width]`.\n", "\n", "Right now our single image `test_image` only has a shape of `[color_channels, height, width]` or `[3, 64, 64]`.\n", "\n", "We can fix this for a single image using `test_image.unsqueeze(dim=0)` to add an extra dimension for `N`." ] }, { "cell_type": "code", "execution_count": 33, "id": "abba741d-a1ed-44ed-ba53-41d589433a2c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "abba741d-a1ed-44ed-ba53-41d589433a2c", "outputId": "9dd8151d-376c-4342-c379-91fcb6468706" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 3, 64, 64])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Add extra dimension to test image\n", "test_image.unsqueeze(dim=0).shape" ] }, { "cell_type": "code", "execution_count": 34, "id": "c7280a49-4ee0-452b-a514-61115b6a444c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c7280a49-4ee0-452b-a514-61115b6a444c", "outputId": "87bf7e37-c1a7-44a4-eef2-a02eb0489147" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 10, 62, 62])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pass test image with extra dimension through conv_layer\n", "conv_layer(test_image.unsqueeze(dim=0)).shape" ] }, { "cell_type": "markdown", "id": "181df81b-7c5a-46cc-b8d5-a592bf755a13", "metadata": { "id": "181df81b-7c5a-46cc-b8d5-a592bf755a13" }, "source": [ "Hmm, notice what happens to our shape (it matches the output of the first layer of TinyVGG on [CNN Explainer](https://poloclub.github.io/cnn-explainer/)): we get a different number of channels as well as a different height and width.\n", "\n", 
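"Where do the new height and width of `62` come from? The number of output channels is set directly by `out_channels=10`, while the [`nn.Conv2d()` documentation](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) gives a formula for the output height (the width works the same way). Simplified under the assumption of the default `dilation=1`, it reads:\n", "\n", "$$H_{out} = \\left\\lfloor \\frac{H_{in} + 2 \\times \\text{padding} - \\text{kernel\\_size}}{\\text{stride}} \\right\\rfloor + 1$$\n", "\n", "Plugging in the values for `conv_layer` (`kernel_size=3`, `stride=1`, `padding=0`) gives $\\lfloor (64 + 0 - 3) / 1 \\rfloor + 1 = 62$.\n", "\n", 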
"What if we changed the values of `conv_layer`?" ] }, { "cell_type": "code", "execution_count": 35, "id": "04445d45-cf2f-4c1d-b215-bc50865a207a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "04445d45-cf2f-4c1d-b215-bc50865a207a", "outputId": "eaa97fb8-52c0-493d-eac3-f2e23df9b01c" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 10, 30, 30])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(42)\n", "# Create a new conv_layer with different values (try setting these to whatever you like)\n", "conv_layer_2 = nn.Conv2d(in_channels=3, # same number of color channels as our input image\n", " out_channels=10,\n", " kernel_size=(5, 5), # a tuple sets (height, width); kernels are usually square, so kernel_size=5 also works\n", " stride=2,\n", " padding=0)\n", "\n", "# Pass single image through new conv_layer_2 (this calls nn.Conv2d()'s forward() method on the input)\n", "conv_layer_2(test_image.unsqueeze(dim=0)).shape" ] }, { "cell_type": "markdown", "id": "b27dbdbb-3e32-4ffa-803e-cf943d96c72b", "metadata": { "id": "b27dbdbb-3e32-4ffa-803e-cf943d96c72b" }, "source": [ "Woah, we get another shape change.\n", "\n", "Now our image is of shape `[1, 10, 30, 30]` (it will be different if you use different values) or `[batch_size=1, channels=10, height=30, width=30]`.\n", "\n", "What's going on here?\n", "\n", "Behind the scenes, our `nn.Conv2d()` is compressing the information stored in the image.\n", "\n", "It does this by performing an operation (a convolution) between the input (our test image) and its internal parameters.\n", "\n", "The goal of this is similar to all of the other neural networks we've been building.\n", "\n", "Data goes in and, with some help from the optimizer, the layers update their internal parameters (patterns) to lower the loss function.\n", "\n", "The only difference is *how* the different layers use their parameters to transform the input or, in PyTorch terms, the operation performed in the layer's 
`forward()` method.\n", "\n", "If we check out our `conv_layer_2.state_dict()` we'll find a similar weight and bias setup as we've seen before." ] }, { "cell_type": "code", "execution_count": 36, "id": "46027ed1-c3a7-46bd-bab7-17f8c20e354b", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "46027ed1-c3a7-46bd-bab7-17f8c20e354b", "outputId": "bc493b18-1ef3-41c6-9de9-5dfeaa4c259e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('weight', tensor([[[[ 0.0883, 0.0958, -0.0271, 0.1061, -0.0253],\n", " [ 0.0233, -0.0562, 0.0678, 0.1018, -0.0847],\n", " [ 0.1004, 0.0216, 0.0853, 0.0156, 0.0557],\n", " [-0.0163, 0.0890, 0.0171, -0.0539, 0.0294],\n", " [-0.0532, -0.0135, -0.0469, 0.0766, -0.0911]],\n", "\n", " [[-0.0532, -0.0326, -0.0694, 0.0109, -0.1140],\n", " [ 0.1043, -0.0981, 0.0891, 0.0192, -0.0375],\n", " [ 0.0714, 0.0180, 0.0933, 0.0126, -0.0364],\n", " [ 0.0310, -0.0313, 0.0486, 0.1031, 0.0667],\n", " [-0.0505, 0.0667, 0.0207, 0.0586, -0.0704]],\n", "\n", " [[-0.1143, -0.0446, -0.0886, 0.0947, 0.0333],\n", " [ 0.0478, 0.0365, -0.0020, 0.0904, -0.0820],\n", " [ 0.0073, -0.0788, 0.0356, -0.0398, 0.0354],\n", " [-0.0241, 0.0958, -0.0684, -0.0689, -0.0689],\n", " [ 0.1039, 0.0385, 0.1111, -0.0953, -0.1145]]],\n", "\n", "\n", " [[[-0.0903, -0.0777, 0.0468, 0.0413, 0.0959],\n", " [-0.0596, -0.0787, 0.0613, -0.0467, 0.0701],\n", " [-0.0274, 0.0661, -0.0897, -0.0583, 0.0352],\n", " [ 0.0244, -0.0294, 0.0688, 0.0785, -0.0837],\n", " [-0.0616, 0.1057, -0.0390, -0.0409, -0.1117]],\n", "\n", " [[-0.0661, 0.0288, -0.0152, -0.0838, 0.0027],\n", " [-0.0789, -0.0980, -0.0636, -0.1011, -0.0735],\n", " [ 0.1154, 0.0218, 0.0356, -0.1077, -0.0758],\n", " [-0.0384, 0.0181, -0.1016, -0.0498, -0.0691],\n", " [ 0.0003, -0.0430, -0.0080, -0.0782, -0.0793]],\n", "\n", " [[-0.0674, -0.0395, -0.0911, 0.0968, -0.0229],\n", " [ 0.0994, 0.0360, -0.0978, 0.0799, -0.0318],\n", " [-0.0443, -0.0958, -0.1148, 0.0330, -0.0252],\n", " [ 
0.0450, -0.0948, 0.0857, -0.0848, -0.0199],\n", " [ 0.0241, 0.0596, 0.0932, 0.1052, -0.0916]]],\n", "\n", "\n", " [[[ 0.0291, -0.0497, -0.0127, -0.0864, 0.1052],\n", " [-0.0847, 0.0617, 0.0406, 0.0375, -0.0624],\n", " [ 0.1050, 0.0254, 0.0149, -0.1018, 0.0485],\n", " [-0.0173, -0.0529, 0.0992, 0.0257, -0.0639],\n", " [-0.0584, -0.0055, 0.0645, -0.0295, -0.0659]],\n", "\n", " [[-0.0395, -0.0863, 0.0412, 0.0894, -0.1087],\n", " [ 0.0268, 0.0597, 0.0209, -0.0411, 0.0603],\n", " [ 0.0607, 0.0432, -0.0203, -0.0306, 0.0124],\n", " [-0.0204, -0.0344, 0.0738, 0.0992, -0.0114],\n", " [-0.0259, 0.0017, -0.0069, 0.0278, 0.0324]],\n", "\n", " [[-0.1049, -0.0426, 0.0972, 0.0450, -0.0057],\n", " [-0.0696, -0.0706, -0.1034, -0.0376, 0.0390],\n", " [ 0.0736, 0.0533, -0.1021, -0.0694, -0.0182],\n", " [ 0.1117, 0.0167, -0.0299, 0.0478, -0.0440],\n", " [-0.0747, 0.0843, -0.0525, -0.0231, -0.1149]]],\n", "\n", "\n", " [[[ 0.0773, 0.0875, 0.0421, -0.0805, -0.1140],\n", " [-0.0938, 0.0861, 0.0554, 0.0972, 0.0605],\n", " [ 0.0292, -0.0011, -0.0878, -0.0989, -0.1080],\n", " [ 0.0473, -0.0567, -0.0232, -0.0665, -0.0210],\n", " [-0.0813, -0.0754, 0.0383, -0.0343, 0.0713]],\n", "\n", " [[-0.0370, -0.0847, -0.0204, -0.0560, -0.0353],\n", " [-0.1099, 0.0646, -0.0804, 0.0580, 0.0524],\n", " [ 0.0825, -0.0886, 0.0830, -0.0546, 0.0428],\n", " [ 0.1084, -0.0163, -0.0009, -0.0266, -0.0964],\n", " [ 0.0554, -0.1146, 0.0717, 0.0864, 0.1092]],\n", "\n", " [[-0.0272, -0.0949, 0.0260, 0.0638, -0.1149],\n", " [-0.0262, -0.0692, -0.0101, -0.0568, -0.0472],\n", " [-0.0367, -0.1097, 0.0947, 0.0968, -0.0181],\n", " [-0.0131, -0.0471, -0.1043, -0.1124, 0.0429],\n", " [-0.0634, -0.0742, -0.0090, -0.0385, -0.0374]]],\n", "\n", "\n", " [[[ 0.0037, -0.0245, -0.0398, -0.0553, -0.0940],\n", " [ 0.0968, -0.0462, 0.0306, -0.0401, 0.0094],\n", " [ 0.1077, 0.0532, -0.1001, 0.0458, 0.1096],\n", " [ 0.0304, 0.0774, 0.1138, -0.0177, 0.0240],\n", " [-0.0803, -0.0238, 0.0855, 0.0592, -0.0731]],\n", "\n", " [[-0.0926, 
-0.0789, -0.1140, -0.0891, -0.0286],\n", " [ 0.0779, 0.0193, -0.0878, -0.0926, 0.0574],\n", " [-0.0859, -0.0142, 0.0554, -0.0534, -0.0126],\n", " [-0.0101, -0.0273, -0.0585, -0.1029, -0.0933],\n", " [-0.0618, 0.1115, -0.0558, -0.0775, 0.0280]],\n", "\n", " [[ 0.0318, 0.0633, 0.0878, 0.0643, -0.1145],\n", " [ 0.0102, 0.0699, -0.0107, -0.0680, 0.1101],\n", " [-0.0432, -0.0657, -0.1041, 0.0052, 0.0512],\n", " [ 0.0256, 0.0228, -0.0876, -0.1078, 0.0020],\n", " [ 0.1053, 0.0666, -0.0672, -0.0150, -0.0851]]],\n", "\n", "\n", " [[[-0.0557, 0.0209, 0.0629, 0.0957, -0.1060],\n", " [ 0.0772, -0.0814, 0.0432, 0.0977, 0.0016],\n", " [ 0.1051, -0.0984, -0.0441, 0.0673, -0.0252],\n", " [-0.0236, -0.0481, 0.0796, 0.0566, 0.0370],\n", " [-0.0649, -0.0937, 0.0125, 0.0342, -0.0533]],\n", "\n", " [[-0.0323, 0.0780, 0.0092, 0.0052, -0.0284],\n", " [-0.1046, -0.1086, -0.0552, -0.0587, 0.0360],\n", " [-0.0336, -0.0452, 0.1101, 0.0402, 0.0823],\n", " [-0.0559, -0.0472, 0.0424, -0.0769, -0.0755],\n", " [-0.0056, -0.0422, -0.0866, 0.0685, 0.0929]],\n", "\n", " [[ 0.0187, -0.0201, -0.1070, -0.0421, 0.0294],\n", " [ 0.0544, -0.0146, -0.0457, 0.0643, -0.0920],\n", " [ 0.0730, -0.0448, 0.0018, -0.0228, 0.0140],\n", " [-0.0349, 0.0840, -0.0030, 0.0901, 0.1110],\n", " [-0.0563, -0.0842, 0.0926, 0.0905, -0.0882]]],\n", "\n", "\n", " [[[-0.0089, -0.1139, -0.0945, 0.0223, 0.0307],\n", " [ 0.0245, -0.0314, 0.1065, 0.0165, -0.0681],\n", " [-0.0065, 0.0277, 0.0404, -0.0816, 0.0433],\n", " [-0.0590, -0.0959, -0.0631, 0.1114, 0.0987],\n", " [ 0.1034, 0.0678, 0.0872, -0.0155, -0.0635]],\n", "\n", " [[ 0.0577, -0.0598, -0.0779, -0.0369, 0.0242],\n", " [ 0.0594, -0.0448, -0.0680, 0.0156, -0.0681],\n", " [-0.0752, 0.0602, -0.0194, 0.1055, 0.1123],\n", " [ 0.0345, 0.0397, 0.0266, 0.0018, -0.0084],\n", " [ 0.0016, 0.0431, 0.1074, -0.0299, -0.0488]],\n", "\n", " [[-0.0280, -0.0558, 0.0196, 0.0862, 0.0903],\n", " [ 0.0530, -0.0850, -0.0620, -0.0254, -0.0213],\n", " [ 0.0095, -0.1060, 0.0359, -0.0881, 
-0.0731],\n", " [-0.0960, 0.1006, -0.1093, 0.0871, -0.0039],\n", " [-0.0134, 0.0722, -0.0107, 0.0724, 0.0835]]],\n", "\n", "\n", " [[[-0.1003, 0.0444, 0.0218, 0.0248, 0.0169],\n", " [ 0.0316, -0.0555, -0.0148, 0.1097, 0.0776],\n", " [-0.0043, -0.1086, 0.0051, -0.0786, 0.0939],\n", " [-0.0701, -0.0083, -0.0256, 0.0205, 0.1087],\n", " [ 0.0110, 0.0669, 0.0896, 0.0932, -0.0399]],\n", "\n", " [[-0.0258, 0.0556, -0.0315, 0.0541, -0.0252],\n", " [-0.0783, 0.0470, 0.0177, 0.0515, 0.1147],\n", " [ 0.0788, 0.1095, 0.0062, -0.0993, -0.0810],\n", " [-0.0717, -0.1018, -0.0579, -0.1063, -0.1065],\n", " [-0.0690, -0.1138, -0.0709, 0.0440, 0.0963]],\n", "\n", " [[-0.0343, -0.0336, 0.0617, -0.0570, -0.0546],\n", " [ 0.0711, -0.1006, 0.0141, 0.1020, 0.0198],\n", " [ 0.0314, -0.0672, -0.0016, 0.0063, 0.0283],\n", " [ 0.0449, 0.1003, -0.0881, 0.0035, -0.0577],\n", " [-0.0913, -0.0092, -0.1016, 0.0806, 0.0134]]],\n", "\n", "\n", " [[[-0.0622, 0.0603, -0.1093, -0.0447, -0.0225],\n", " [-0.0981, -0.0734, -0.0188, 0.0876, 0.1115],\n", " [ 0.0735, -0.0689, -0.0755, 0.1008, 0.0408],\n", " [ 0.0031, 0.0156, -0.0928, -0.0386, 0.1112],\n", " [-0.0285, -0.0058, -0.0959, -0.0646, -0.0024]],\n", "\n", " [[-0.0717, -0.0143, 0.0470, -0.1130, 0.0343],\n", " [-0.0763, -0.0564, 0.0443, 0.0918, -0.0316],\n", " [-0.0474, -0.1044, -0.0595, -0.1011, -0.0264],\n", " [ 0.0236, -0.1082, 0.1008, 0.0724, -0.1130],\n", " [-0.0552, 0.0377, -0.0237, -0.0126, -0.0521]],\n", "\n", " [[ 0.0927, -0.0645, 0.0958, 0.0075, 0.0232],\n", " [ 0.0901, -0.0190, -0.0657, -0.0187, 0.0937],\n", " [-0.0857, 0.0262, -0.1135, 0.0605, 0.0427],\n", " [ 0.0049, 0.0496, 0.0001, 0.0639, -0.0914],\n", " [-0.0170, 0.0512, 0.1150, 0.0588, -0.0840]]],\n", "\n", "\n", " [[[ 0.0888, -0.0257, -0.0247, -0.1050, -0.0182],\n", " [ 0.0817, 0.0161, -0.0673, 0.0355, -0.0370],\n", " [ 0.1054, -0.1002, -0.0365, -0.1115, -0.0455],\n", " [ 0.0364, 0.1112, 0.0194, 0.1132, 0.0226],\n", " [ 0.0667, 0.0926, 0.0965, -0.0646, 0.1062]],\n", "\n", " [[ 
0.0699, -0.0540, -0.0551, -0.0969, 0.0290],\n", " [-0.0936, 0.0488, 0.0365, -0.1003, 0.0315],\n", " [-0.0094, 0.0527, 0.0663, -0.1148, 0.1059],\n", " [ 0.0968, 0.0459, -0.1055, -0.0412, -0.0335],\n", " [-0.0297, 0.0651, 0.0420, 0.0915, -0.0432]],\n", "\n", " [[ 0.0389, 0.0411, -0.0961, -0.1120, -0.0599],\n", " [ 0.0790, -0.1087, -0.1005, 0.0647, 0.0623],\n", " [ 0.0950, -0.0872, -0.0845, 0.0592, 0.1004],\n", " [ 0.0691, 0.0181, 0.0381, 0.1096, -0.0745],\n", " [-0.0524, 0.0808, -0.0790, -0.0637, 0.0843]]]])), ('bias', tensor([ 0.0364, 0.0373, -0.0489, -0.0016, 0.1057, -0.0693, 0.0009, 0.0549,\n", " -0.0797, 0.1121]))])\n" ] } ], "source": [ "# Check out the conv_layer_2 internal parameters\n", "print(conv_layer_2.state_dict())" ] }, { "cell_type": "markdown", "id": "8b708eb6-ae46-4d8c-a8a4-1392827d3e37", "metadata": { "id": "8b708eb6-ae46-4d8c-a8a4-1392827d3e37" }, "source": [ "Look at that! A bunch of random numbers for the weight and bias tensors.\n", "\n", "The shapes of these are determined by the arguments we passed to `nn.Conv2d()` when we set it up.\n", "\n", "Let's check them out." 
] }, { "cell_type": "code", "execution_count": 37, "id": "e5518d61-c0b7-4351-b5ea-4d6b6144291a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e5518d61-c0b7-4351-b5ea-4d6b6144291a", "outputId": "14ef701e-c82d-4fae-9dd6-e18352817cf2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "conv_layer_2 weight shape: \n", "torch.Size([10, 3, 5, 5]) -> [out_channels=10, in_channels=3, kernel_size=5, kernel_size=5]\n", "\n", "conv_layer_2 bias shape: \n", "torch.Size([10]) -> [out_channels=10]\n" ] } ], "source": [ "# Get shapes of weight and bias tensors within conv_layer_2\n", "print(f\"conv_layer_2 weight shape: \\n{conv_layer_2.weight.shape} -> [out_channels=10, in_channels=3, kernel_size=5, kernel_size=5]\")\n", "print(f\"\\nconv_layer_2 bias shape: \\n{conv_layer_2.bias.shape} -> [out_channels=10]\")" ] }, { "cell_type": "markdown", "id": "f0de23c7-4501-4156-80a4-ac889a636a42", "metadata": { "id": "f0de23c7-4501-4156-80a4-ac889a636a42" }, "source": [ "> **Question:** What should we set the parameters of our `nn.Conv2d()` layers to?\n", ">\n", "> That's a good one. But similar to many other things in machine learning, the values of these aren't set in stone (and recall, because these values are ones we can set ourselves, they're referred to as \"**hyperparameters**\"). \n", ">\n", "> The best way to find out is to try out different values and see how they affect your model's performance.\n", ">\n", "> Or even better, find a working example on a problem similar to yours (like we've done with TinyVGG) and copy it. \n", "\n", "We're working with a different kind of layer here from what we've seen before.\n", "\n", "But the premise remains the same: start with random numbers and update them to better represent the data." 
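] }, { "cell_type": "markdown", "id": "conv-output-size-sketch-md", "metadata": {}, "source": [ "To make the effect of these hyperparameters on shapes concrete, here's a small sketch (an addition, not part of the original TinyVGG code). `conv_output_size()` is a hypothetical helper implementing the output-size formula from the [`nn.Conv2d()`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) and [`nn.MaxPool2d()`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html) docs (assuming the default `ceil_mode=False` for pooling).\n", "\n", "It checks the formula against a real `nn.Conv2d()` layer, traces a `28x28` FashionMNIST image through the two blocks of `model_2` (assuming `hidden_units=10`) to show where `in_features=hidden_units*7*7` comes from, and counts the layer's learnable parameters." ] }, { "cell_type": "code", "execution_count": null, "id": "conv-output-size-sketch-code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "# Hypothetical helper: output height (or width) of nn.Conv2d / nn.MaxPool2d with ceil_mode=False\n", "def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:\n", "    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1\n", "\n", "# Compare the formula with a real layer using the same settings as conv_layer_2\n", "sketch_conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5, stride=2, padding=0)\n", "sketch_input = torch.randn(1, 3, 64, 64)\n", "print(sketch_conv(sketch_input).shape[-1], conv_output_size(64, kernel_size=5, stride=2))  # 30 30\n", "\n", "# Trace a 28x28 FashionMNIST image through model_2's two blocks\n", "feature_size = 28\n", "for _ in range(2):\n", "    feature_size = conv_output_size(feature_size, kernel_size=3, padding=1)  # Conv2d with padding=1 keeps the size\n", "    feature_size = conv_output_size(feature_size, kernel_size=3, padding=1)  # second Conv2d, size unchanged\n", "    feature_size = conv_output_size(feature_size, kernel_size=2, stride=2)  # MaxPool2d halves it\n", "\n", "sketch_hidden_units = 10  # assumption: the value used to create model_2\n", "print(feature_size, sketch_hidden_units * feature_size * feature_size)  # 7 490\n", "\n", "# Learnable parameters: 10*3*5*5 weights + 10 biases\n", "print(sum(p.numel() for p in sketch_conv.parameters()))  # 760" 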
] }, { "cell_type": "markdown", "id": "6370d45d-ca44-4fa0-a2d7-efaf0a207b91", "metadata": { "id": "6370d45d-ca44-4fa0-a2d7-efaf0a207b91" }, "source": [ "### 7.2 Stepping through `nn.MaxPool2d()`\n", "\n", "Now let's check out what happens when we move data through `nn.MaxPool2d()`." ] }, { "cell_type": "code", "execution_count": 38, "id": "1164c753-19d9-43b7-a04f-017d0f7188c3", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1164c753-19d9-43b7-a04f-017d0f7188c3", "outputId": "9c46f08e-928d-4ee4-e43c-b402113fc2b4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test image original shape: torch.Size([3, 64, 64])\n", "Test image with unsqueezed dimension: torch.Size([1, 3, 64, 64])\n", "Shape after going through conv_layer(): torch.Size([1, 10, 62, 62])\n", "Shape after going through conv_layer() and max_pool_layer(): torch.Size([1, 10, 31, 31])\n" ] } ], "source": [ "# Print out original image shape without and with unsqueezed dimension\n", "print(f\"Test image original shape: {test_image.shape}\")\n", "print(f\"Test image with unsqueezed dimension: {test_image.unsqueeze(dim=0).shape}\")\n", "\n", "# Create a sample nn.MaxPool2d() layer\n", "max_pool_layer = nn.MaxPool2d(kernel_size=2)\n", "\n", "# Pass data through just the conv_layer\n", "test_image_through_conv = conv_layer(test_image.unsqueeze(dim=0))\n", "print(f\"Shape after going through conv_layer(): {test_image_through_conv.shape}\")\n", "\n", "# Pass data through the max pool layer\n", "test_image_through_conv_and_max_pool = max_pool_layer(test_image_through_conv)\n", "print(f\"Shape after going through conv_layer() and max_pool_layer(): {test_image_through_conv_and_max_pool.shape}\")" ] }, { "cell_type": "markdown", "id": "de029abd-6674-4bfa-99ab-322f339f89f4", "metadata": { "id": "de029abd-6674-4bfa-99ab-322f339f89f4" }, "source": [ "Notice how the shape changes going in and out of the `nn.MaxPool2d()` layer.\n", "\n", "The `kernel_size` of the 
`nn.MaxPool2d()` layer affects the size of the output.\n", "\n", "In our case, the height and width are halved, from a `62x62` image to a `31x31` image.\n", "\n", "Let's see this work with a smaller tensor." ] }, { "cell_type": "code", "execution_count": 39, "id": "e6a2b196-4845-4b40-9212-e75406e88875", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e6a2b196-4845-4b40-9212-e75406e88875", "outputId": "5a5e5df1-8e25-4061-d223-398ecfc7ef4c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random tensor:\n", "tensor([[[[0.3367, 0.1288],\n", " [0.2345, 0.2303]]]])\n", "Random tensor shape: torch.Size([1, 1, 2, 2])\n", "\n", "Max pool tensor:\n", "tensor([[[[0.3367]]]]) <- this is the maximum value from random_tensor\n", "Max pool tensor shape: torch.Size([1, 1, 1, 1])\n" ] } ], "source": [ "torch.manual_seed(42)\n", "# Create a random tensor with a similar number of dimensions to our images\n", "random_tensor = torch.randn(size=(1, 1, 2, 2))\n", "print(f\"Random tensor:\\n{random_tensor}\")\n", "print(f\"Random tensor shape: {random_tensor.shape}\")\n", "\n", "# Create a max pool layer\n", "max_pool_layer = nn.MaxPool2d(kernel_size=2) # see what happens when you change the kernel_size value \n", "\n", "# Pass the random tensor through the max pool layer\n", "max_pool_tensor = max_pool_layer(random_tensor)\n", "print(f\"\\nMax pool tensor:\\n{max_pool_tensor} <- this is the maximum value from random_tensor\")\n", "print(f\"Max pool tensor shape: {max_pool_tensor.shape}\")" ] }, { "cell_type": "markdown", "id": "002e586e-dcb3-40fe-a7dd-a1c18a3b8da0", "metadata": { "id": "002e586e-dcb3-40fe-a7dd-a1c18a3b8da0" }, "source": [ "Notice the final two dimensions of `random_tensor` and `max_pool_tensor`: they go from `[2, 2]` to `[1, 1]`.\n", "\n", "In essence, they get halved.\n", "\n", "And the change would be different for different values of `kernel_size` for `nn.MaxPool2d()`.\n", "\n", "Also notice the value left over in 
`max_pool_tensor` is the **maximum** value from `random_tensor`.\n", "\n", "What's happening here?\n", "\n", "This is another important piece of the puzzle of neural networks.\n", "\n", "Essentially, **every layer in a neural network is trying to compress data from higher dimensional space to lower dimensional space**. \n", "\n", "In other words, take a lot of numbers (raw data) and learn patterns in those numbers, patterns that are predictive whilst also being *smaller* in size than the original values.\n", "\n", "From an artificial intelligence perspective, you could consider the whole goal of a neural network to be *compressing* information.\n", "\n", "![each layer of a neural network compresses the original input data into a smaller representation that is (hopefully) capable of making predictions on future input data](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/03-conv-net-as-compression.png)\n", "\n", "This means that, from the point of view of a neural network, intelligence is compression.\n", "\n", "This is the idea behind the `nn.MaxPool2d()` layer: take the maximum value from a portion of a tensor and disregard the rest.\n", "\n", "In essence, it lowers the dimensionality of a tensor whilst still retaining a (hopefully) significant portion of the information.\n", "\n", "It is the same story for a `nn.Conv2d()` layer.\n", "\n", "Except instead of just taking the maximum, the `nn.Conv2d()` performs a convolutional operation on the data (see this in action on the [CNN Explainer webpage](https://poloclub.github.io/cnn-explainer/)).\n", "\n", "> **Exercise:** What do you think the [`nn.AvgPool2d()`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) layer does? Try making a random tensor like we did above and passing it through. 
Check the input and output shapes as well as the input and output values.\n", "\n", "> **Extra-curriculum:** Look up \"most common convolutional neural networks\". What architectures do you find? Are any of them contained within the [`torchvision.models`](https://pytorch.org/vision/stable/models.html) library? What do you think you could do with these?" ] }, { "cell_type": "markdown", "id": "39a3c646-52f0-4f4b-8527-2fc33d0dfb13", "metadata": { "id": "39a3c646-52f0-4f4b-8527-2fc33d0dfb13" }, "source": [ "### 7.3 Setup a loss function and optimizer for `model_2`\n", "\n", "We've stepped through the layers of our first CNN enough.\n", "\n", "But remember, if something still isn't clear, try starting small.\n", "\n", "Pick a single layer of a model, pass some data through it and see what happens.\n", "\n", "Now it's time to move forward and get to training!\n", "\n", "Let's set up a loss function and an optimizer.\n", "\n", "We'll use the same functions as before: `nn.CrossEntropyLoss()` as the loss function (since we're working with multi-class classification data).\n", "\n", "And `torch.optim.SGD()` as the optimizer to optimize `model_2.parameters()` with a learning rate of `0.1`." ] }, { "cell_type": "code", "execution_count": 40, "id": "06a76a1b-5f6f-4018-bf7b-8388b385476f", "metadata": { "id": "06a76a1b-5f6f-4018-bf7b-8388b385476f" }, "outputs": [], "source": [ "# Set up loss and optimizer\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(params=model_2.parameters(), \n", " lr=0.1)" ] }, { "cell_type": "markdown", "id": "758bc223-a244-4604-a07a-e2fc2f96c2f6", "metadata": { "id": "758bc223-a244-4604-a07a-e2fc2f96c2f6" }, "source": [ "### 7.4 Training and testing `model_2` using our training and test functions\n", "\n", "Loss and optimizer ready!\n", "\n", "Time to train and test.\n", "\n", "We'll use the `train_step()` and `test_step()` functions we created before.\n", "\n", "We'll also time the training so we can compare it with our other models." 
] }, { "cell_type": "code", "execution_count": 41, "id": "861d126e-d876-40b3-9b7a-66cfc2f1bf05", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 327, "referenced_widgets": [ "2b9c90ceb8554eaaaaf33acacecbcc11", "9051de473b1b456592576115260f0c48", "a315b7b535e4461ca11a6d96ca74411c", "4cd01ed6c2534d4e80a0af2e9da02052", "eb3ca30526e24fff9194d4e82436df99", "6117713ba6c9490fab837cfbad7d442b", "d426a4e9fff447ea95c7256372e272b7", "fdcf4db7208d42b5acb35bf56c72fd82", "606df8221adf48308f48c86ac54475e8", "ebc6218d07f54005b0e46553b83464e1", "03435ac81db84ae6bb6cb99bb78167c2" ] }, "id": "861d126e-d876-40b3-9b7a-66cfc2f1bf05", "outputId": "77ae1601-bf72-4239-b3f6-9fb99a1c62c0" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2b9c90ceb8554eaaaaf33acacecbcc11", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
model_namemodel_lossmodel_acc
0FashionMNISTModelV00.47663983.426518
1FashionMNISTModelV10.68500175.019968
2FashionMNISTModelV20.32857088.378594
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", " \n", " " ], "text/plain": [ " model_name model_loss model_acc\n", "0 FashionMNISTModelV0 0.476639 83.426518\n", "1 FashionMNISTModelV1 0.685001 75.019968\n", "2 FashionMNISTModelV2 0.328570 88.378594" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "compare_results = pd.DataFrame([model_0_results, model_1_results, model_2_results])\n", "compare_results" ] }, { "cell_type": "markdown", "id": "c67f3fb5-ce7b-40b8-86a0-2797492de0ef", "metadata": { "id": "c67f3fb5-ce7b-40b8-86a0-2797492de0ef" }, "source": [ "Nice!\n", "\n", "We can add the training time values too." ] }, { "cell_type": "code", "execution_count": 44, "id": "297af38f-e69f-4c6f-9027-fcaf0482a55c", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 143 }, "id": "297af38f-e69f-4c6f-9027-fcaf0482a55c", "outputId": "67c01781-78c2-47e8-a1ee-452869bea5e2" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " " ], "text/plain": [ " model_name model_loss model_acc training_time\n", "0 FashionMNISTModelV0 0.476639 83.426518 32.348722\n", "1 FashionMNISTModelV1 0.685001 75.019968 36.877976\n", "2 FashionMNISTModelV2 0.328570 88.378594 44.249765" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Add training times to results comparison\n", "compare_results[\"training_time\"] = [total_train_time_model_0,\n", "                                    total_train_time_model_1,\n", "                                    total_train_time_model_2]\n", "compare_results" ] }, { "cell_type": "markdown", "id": "fbbe5832-1081-4c76-8d5b-06c7a06da7b9", "metadata": { "id": "fbbe5832-1081-4c76-8d5b-06c7a06da7b9" }, "source": [ "It looks like our CNN model (`FashionMNISTModelV2`) performed the best (lowest loss, highest accuracy) but also had the longest training time (about 44 seconds versus about 32 seconds for the baseline in this run).\n", "\n", "Our baseline model (`FashionMNISTModelV0`) also performed better than `model_1` (`FashionMNISTModelV1`).\n", "\n", "### Performance-speed tradeoff\n", "\n", "Something to be aware of in machine learning is the **performance-speed** tradeoff.\n", "\n", "Generally, you get better performance out of a larger, more complex model (like we did with `model_2`).\n", "\n", "However, this performance increase often comes at the cost of training speed and inference speed.\n", "\n", "> **Note:** The training times you get will depend heavily on the hardware you use. \n", ">\n", "> Generally, the more CPU cores you have, the faster your models will train on CPU. The same goes for GPUs.\n", "> \n", "> Newer hardware will also often train models faster because it incorporates more recent technological advances.\n", "\n", "How about we get visual?" 
] }, { "cell_type": "code", "execution_count": 45, "id": "5eb0df60-9318-47d0-adce-f8788ed3999e", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 449 }, "id": "5eb0df60-9318-47d0-adce-f8788ed3999e", "outputId": "d904e039-1544-49f5-bed7-b17555d03b5a" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize our model results\n", "compare_results.set_index(\"model_name\")[\"model_acc\"].plot(kind=\"barh\")\n", "plt.xlabel(\"accuracy (%)\")\n", "plt.ylabel(\"model\");" ] }, { "cell_type": "markdown", "id": "0ba50d51-adb3-4e49-9b9a-85173e747352", "metadata": { "id": "0ba50d51-adb3-4e49-9b9a-85173e747352" }, "source": [ "## 9. Make and evaluate random predictions with best model\n", "\n", "Now that we've compared our models to each other, let's further evaluate our best-performing model, `model_2`.\n", "\n", "To do so, let's create a function `make_predictions()` that takes a model and some data to predict on, and returns the model's prediction probabilities." ] }, { "cell_type": "code", "execution_count": 46, "id": "d1d5d3e7-9601-4141-8bd7-9abbd016bf6c", "metadata": { "id": "d1d5d3e7-9601-4141-8bd7-9abbd016bf6c" }, "outputs": [], "source": [ "def make_predictions(model: torch.nn.Module, data: list, device: torch.device = device):\n", "    pred_probs = []\n", "    model.eval()\n", "    with torch.inference_mode():\n", "        for sample in data:\n", "            # Prepare sample\n", "            sample = torch.unsqueeze(sample, dim=0).to(device) # Add an extra dimension and send sample to device\n", "\n", "            # Forward pass (model outputs raw logit)\n", "            pred_logit = model(sample)\n", "\n", "            # Get prediction probability (logit -> prediction probability)\n", "            pred_prob = torch.softmax(pred_logit.squeeze(), dim=0) # note: perform softmax on the \"logits\" dimension, not \"batch\" dimension (in this case we have a batch size of 1, so can perform on dim=0)\n", "\n", "            # Get pred_prob off GPU for further calculations\n", "            pred_probs.append(pred_prob.cpu())\n", "\n", "    # Stack the pred_probs to turn list into a tensor\n", "    return torch.stack(pred_probs)" ] }, { "cell_type": "code", "execution_count": 47, "id": "420c7461-eaa9-4459-9e68-53574c758765", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "420c7461-eaa9-4459-9e68-53574c758765", 
"outputId": "f3dd6437-4f0f-4bc2-f9e6-d0969df63a52" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test sample image shape: torch.Size([1, 28, 28])\n", "Test sample label: 5 (Sandal)\n" ] } ], "source": [ "import random\n", "random.seed(42)\n", "test_samples = []\n", "test_labels = []\n", "for sample, label in random.sample(list(test_data), k=9):\n", "    test_samples.append(sample)\n", "    test_labels.append(label)\n", "\n", "# View the first test sample shape and label\n", "print(f\"Test sample image shape: {test_samples[0].shape}\\nTest sample label: {test_labels[0]} ({class_names[test_labels[0]]})\")" ] }, { "cell_type": "markdown", "id": "e9f40dd9-7987-42a9-84cc-65dc912a6345", "metadata": { "id": "e9f40dd9-7987-42a9-84cc-65dc912a6345" }, "source": [ "And now we can use our `make_predictions()` function to predict on `test_samples`." 
] }, { "cell_type": "code", "execution_count": 49, "id": "79de2ac1-7d4b-4f81-ae8a-90099bca2a3d", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "79de2ac1-7d4b-4f81-ae8a-90099bca2a3d", "outputId": "918b07bc-4545-4401-84d5-8796ff5acf4c" }, "outputs": [ { "data": { "text/plain": [ "tensor([[2.4012e-07, 6.5406e-08, 4.8069e-08, 2.1070e-07, 1.4175e-07, 9.9992e-01,\n", " 2.1711e-07, 1.6177e-05, 3.7849e-05, 2.7548e-05],\n", " [1.5646e-02, 8.9752e-01, 3.6928e-04, 6.7402e-02, 1.2920e-02, 4.9539e-05,\n", " 5.6485e-03, 1.9456e-04, 2.0808e-04, 3.7861e-05]])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Make predictions on test samples with model 2\n", "pred_probs= make_predictions(model=model_2, \n", " data=test_samples)\n", "\n", "# View first two prediction probabilities list\n", "pred_probs[:2]" ] }, { "cell_type": "markdown", "id": "22d3c080-4eb6-4b5d-a5c4-2319e78228af", "metadata": { "id": "22d3c080-4eb6-4b5d-a5c4-2319e78228af" }, "source": [ "Excellent!\n", "\n", "And now we can go from prediction probabilities to prediction labels by taking the `torch.argmax()` of the output of the `torch.softmax()` activation function." 
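] }, { "cell_type": "markdown", "id": "softmax-argmax-sketch-md", "metadata": {}, "source": [ "Why does taking the `argmax()` of the softmax output give us the predicted class? Below is a minimal sketch using made-up logits. `toy_logits` is a hypothetical tensor for illustration only, not an output of `model_2`.\n", "\n", "`torch.softmax()` rescales each row into probabilities that sum to 1 but keeps the values in the same order. So `argmax()` over the probabilities should pick the same class as `argmax()` over the raw logits." ] }, { "cell_type": "code", "execution_count": null, "id": "softmax-argmax-sketch-code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical logits for a batch of 2 samples and 3 classes (made-up values for illustration)\n", "toy_logits = torch.tensor([[2.0, -1.0, 0.5],\n", "                           [0.1, 0.3, -2.0]])\n", "\n", "# Softmax over the class dimension (dim=1) turns each row into prediction probabilities\n", "toy_probs = torch.softmax(toy_logits, dim=1)\n", "print(toy_probs.sum(dim=1))  # each row sums to 1 (up to floating point error)\n", "\n", "# Argmax over the class dimension gives one predicted label per sample\n", "print(toy_probs.argmax(dim=1))  # tensor([0, 1])\n", "\n", "# Softmax preserves the ordering, so argmax on the raw logits gives the same labels\n", "print(torch.equal(toy_probs.argmax(dim=1), toy_logits.argmax(dim=1)))  # True"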
] }, { "cell_type": "code", "execution_count": 50, "id": "f9d97bcc-4310-4851-a1f8-6bcd757e9b26", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f9d97bcc-4310-4851-a1f8-6bcd757e9b26", "outputId": "9d0f0bf9-a641-45e7-af77-6621fd1cfcc4" }, "outputs": [ { "data": { "text/plain": [ "tensor([5, 1, 7, 4, 3, 0, 4, 7, 1])" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Turn the prediction probabilities into prediction labels by taking the argmax()\n", "pred_classes = pred_probs.argmax(dim=1)\n", "pred_classes" ] }, { "cell_type": "code", "execution_count": 51, "id": "1141af97-0990-4920-83d4-c13cca3f9abc", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1141af97-0990-4920-83d4-c13cca3f9abc", "outputId": "c69cddd4-bbe9-495e-d477-6ea0a6c7d8de" }, "outputs": [ { "data": { "text/plain": [ "([5, 1, 7, 4, 3, 0, 4, 7, 1], tensor([5, 1, 7, 4, 3, 0, 4, 7, 1]))" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Are our predictions in the same form as our test labels? \n", "test_labels, pred_classes" ] }, { "cell_type": "markdown", "id": "4ea04387-c9ad-424f-8297-defd7b685683", "metadata": { "id": "4ea04387-c9ad-424f-8297-defd7b685683" }, "source": [ "Now that our predicted classes are in the same format as our test labels, we can compare them.\n", "\n", "Since we're dealing with image data, let's stay true to the data explorer's motto. \n", "\n", "\"Visualize, visualize, visualize!\"" ] }, { "cell_type": "code", "execution_count": 52, "id": "679cb5f7-bb66-42dd-a4d6-400b27b7c019", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 749 }, "id": "679cb5f7-bb66-42dd-a4d6-400b27b7c019", "outputId": "3aae0abe-9c19-4054-d8db-7e00403666aa" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot predictions\n", "plt.figure(figsize=(9, 9))\n", "nrows = 3\n", "ncols = 3\n", "for i, sample in enumerate(test_samples):\n", "    # Create a subplot\n", "    plt.subplot(nrows, ncols, i+1)\n", "\n", "    # Plot the target image\n", "    plt.imshow(sample.squeeze(), cmap=\"gray\")\n", "\n", "    # Find the prediction label (in text form, e.g. \"Sandal\")\n", "    pred_label = class_names[pred_classes[i]]\n", "\n", "    # Get the truth label (in text form, e.g. \"T-shirt\")\n", "    truth_label = class_names[test_labels[i]]\n", "\n", "    # Create the title text of the plot\n", "    title_text = f\"Pred: {pred_label} | Truth: {truth_label}\"\n", "\n", "    # Check for equality and change title colour accordingly\n", "    if pred_label == truth_label:\n", "        plt.title(title_text, fontsize=10, c=\"g\") # green text if correct\n", "    else:\n", "        plt.title(title_text, fontsize=10, c=\"r\") # red text if wrong\n", "    plt.axis(False);" ] }, { "cell_type": "markdown", "id": "5ce6dc44-90a5-48c3-91a5-810fa084d98b", "metadata": { "id": "5ce6dc44-90a5-48c3-91a5-810fa084d98b" }, "source": [ "Well, well, well, doesn't that look good!\n", "\n", "Not bad for a couple dozen lines of PyTorch code!" ] }, { "cell_type": "markdown", "id": "ab108078-6770-4cb9-ac62-a761ff159aba", "metadata": { "id": "ab108078-6770-4cb9-ac62-a761ff159aba" }, "source": [ "## 10. Making a confusion matrix for further prediction evaluation\n", "\n", "There are many [different evaluation metrics](https://www.learnpytorch.io/02_pytorch_classification/#9-more-classification-evaluation-metrics) we can use for classification problems. \n", "\n", "One of the most visual is a [confusion matrix](https://www.dataschool.io/simple-guide-to-confusion-matrix-terminology/).\n", "\n", "A confusion matrix shows where your classification model confuses one class for another, by comparing its predictions with the true labels.\n", "\n", "To make a confusion matrix, we'll go through three steps:\n", "1. 
Make predictions with our trained model, `model_2` (a confusion matrix compares predictions to true labels).\n", "2. Make a confusion matrix using [`torchmetrics.ConfusionMatrix`](https://torchmetrics.readthedocs.io/en/latest/references/modules.html?highlight=confusion#confusionmatrix).\n", "3. Plot the confusion matrix using [`mlxtend.plotting.plot_confusion_matrix()`](http://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/).\n", "\n", "Let's start by making predictions with our trained model." ] }, { "cell_type": "code", "execution_count": 53, "id": "065b8090-c9c5-43df-b5c1-b45ba33af1be", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "d3ab200da5f940d5b45396f83bd835e2", "f35a13b3e55342aeb24b188c1d81a9e5", "4a282c1974524bd3a7eba45fd3112129", "44d4196e99a4412f893ba8ac4672915d", "12d1a54d4107428eae2e64ff0a255c50", "4d6eb654b2794b0a95f31ac94b52a4ca", "fe5cff037f714657996f0541baee39f3", "0670e3e758e6486b9cf4e2797b4b619a", "3c590fc27b624584ba564e18bc42a2e4", "629ca5b704b84a958d4ee477907f64a1", "4d7c25dcdde8414382be4cf63a9cacf9" ] }, "id": "065b8090-c9c5-43df-b5c1-b45ba33af1be", "outputId": "92a8bee2-71f5-4504-d534-cc63138c413d" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d3ab200da5f940d5b45396f83bd835e2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Making predictions: 0%| | 0/313 [00:00 prediction probabilities -> predictions labels\n", " y_pred = torch.softmax(y_logit, dim=1).argmax(dim=1) # note: perform softmax on the \"logits\" dimension, not \"batch\" dimension (in this case we have a batch size of 32, so can perform on dim=1)\n", " # Put predictions on CPU for evaluation\n", " y_preds.append(y_pred.cpu())\n", "# Concatenate list of predictions into a tensor\n", "y_pred_tensor = torch.cat(y_preds)" ] }, { "cell_type": "markdown", "id": "362002d9-ec41-4c74-a210-b5d4f53410c4", "metadata": { "id": "362002d9-ec41-4c74-a210-b5d4f53410c4" }, 
"source": [ "Wonderful!\n", "\n", "Now that we've got predictions, let's go through steps 2 & 3:\n", "2. Make a confusion matrix using [`torchmetrics.ConfusionMatrix`](https://torchmetrics.readthedocs.io/en/latest/references/modules.html?highlight=confusion#confusionmatrix).\n", "3. Plot the confusion matrix using [`mlxtend.plotting.plot_confusion_matrix()`](http://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/).\n", "\n", "First we'll need to make sure we've got `torchmetrics` and `mlxtend` installed (these two libraries will help us make and visualize a confusion matrix).\n", "\n", "> **Note:** If you're using Google Colab, the default version of `mlxtend` installed is 0.14.0 (as of March 2022). However, for the parameters of the `plot_confusion_matrix()` function we'd like to use, we need 0.19.0 or higher. " ] }, { "cell_type": "code", "execution_count": 54, "id": "e6c0a05d-d3e0-4b86-9ef7-ee6ea5629b07", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e6c0a05d-d3e0-4b86-9ef7-ee6ea5629b07", "outputId": "b37df16c-c292-4347-807c-91c97bf81f20" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.2/519.2 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m54.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hmlxtend version: 0.22.0\n" ] } ], "source": [ "# See if torchmetrics and mlxtend exist (and mlxtend is recent enough), if not, install them\n", "try:\n", "    import torchmetrics, mlxtend\n", "    print(f\"mlxtend version: {mlxtend.__version__}\")\n", "    assert int(mlxtend.__version__.split(\".\")[1]) >= 19, \"mlxtend version should be 0.19.0 or higher\"\n", "except (ImportError, AssertionError):\n", "    !pip install -q torchmetrics -U mlxtend # <- Note: If you're using Google Colab, this may require restarting the runtime\n", "    import torchmetrics, mlxtend\n", "    print(f\"mlxtend version: {mlxtend.__version__}\")" ] }, { "cell_type": "markdown", "id": "5245ede6-fd7f-40ad-a0b3-ae678544b84a", "metadata": { "id": "5245ede6-fd7f-40ad-a0b3-ae678544b84a" }, "source": [ "To plot the confusion matrix, we need to make sure we've got an [`mlxtend`](http://rasbt.github.io/mlxtend/) version of 0.19.0 or higher." ] }, { "cell_type": "code", "execution_count": 55, "id": "21383f88-a2dd-4678-94c6-479c592da0ab", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "21383f88-a2dd-4678-94c6-479c592da0ab", "outputId": "faffbe4c-9c86-4a20-cbd6-c7e8e48e81a5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.22.0\n" ] } ], "source": [ "# Import mlxtend upgraded version\n", "import mlxtend \n", "print(mlxtend.__version__)\n", "assert int(mlxtend.__version__.split(\".\")[1]) >= 19 # should be version 0.19.0 or higher" ] }, { "cell_type": "markdown", "id": "c91b9346-e25f-48ab-967e-425649331dc6", "metadata": { "id": "c91b9346-e25f-48ab-967e-425649331dc6" }, "source": [ "With `torchmetrics` and `mlxtend` installed, let's make a confusion matrix!\n", "\n", "First we'll create a `torchmetrics.ConfusionMatrix` instance, telling it how many classes we're dealing with by setting `num_classes=len(class_names)` (and what kind of problem it is with `task='multiclass'`).\n", "\n", "Then we'll create a confusion matrix (in tensor format) by passing our model's predictions (`preds=y_pred_tensor`) and the true targets (`target=test_data.targets`) to that instance.\n", "\n", "Finally we can plot our confusion matrix using the `plot_confusion_matrix()` function from `mlxtend.plotting`." ] }, { "cell_type": "code", "execution_count": 56, "id": "7aed6d76-ad1c-429e-b8e0-c80572e3ebf4", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 667 }, "id": "7aed6d76-ad1c-429e-b8e0-c80572e3ebf4", "outputId": "ae34ae74-2038-4037-f01d-77a807e4de9b" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from torchmetrics import ConfusionMatrix\n", "from mlxtend.plotting import plot_confusion_matrix\n", "\n", "# 2. Setup confusion matrix instance and compare predictions to targets\n", "confmat = ConfusionMatrix(num_classes=len(class_names), task='multiclass')\n", "confmat_tensor = confmat(preds=y_pred_tensor,\n", "                         target=test_data.targets)\n", "\n", "# 3. Plot the confusion matrix\n", "fig, ax = plot_confusion_matrix(\n", "    conf_mat=confmat_tensor.numpy(), # matplotlib likes working with NumPy \n", "    class_names=class_names, # turn the row and column labels into class names\n", "    figsize=(10, 7)\n", ");" ] }, { "cell_type": "markdown", "id": "381c1c93-df30-451c-b65e-5d4c1680dc30", "metadata": { "id": "381c1c93-df30-451c-b65e-5d4c1680dc30" }, "source": [ "Woah! Doesn't that look good?\n", "\n", "We can see our model does fairly well since most of the dark squares are down the diagonal from top left to bottom right (an ideal model will have values only in these squares and 0 everywhere else).\n", "\n", "The model gets most \"confused\" on classes that are similar, for example predicting \"Pullover\" for images that are actually labelled \"Shirt\".\n", "\n", "And the same for predicting \"Shirt\" for images that are actually labelled \"T-shirt/top\".\n", "\n", "This kind of information is often more helpful than a single accuracy metric because it tells us *where* a model is getting things wrong.\n", "\n", "It also hints at *why* the model may be getting certain things wrong.\n", "\n", "It's understandable the model sometimes predicts \"Shirt\" for images labelled \"T-shirt/top\", since the two classes can look very similar.\n", "\n", "We can use this kind of information to further inspect our models and data to see how they could be improved.\n", "\n", "> **Exercise:** Use the trained `model_2` to make predictions on the test FashionMNIST dataset. 
Then plot some predictions where the model was wrong alongside what the label of the image should've been. After visualizing these predictions, do you think it's more of a modelling error or a data error? As in, could the model do better or are the labels of the data too close to each other (e.g. a \"Shirt\" label is too close to \"T-shirt/top\")?" ] }, { "cell_type": "markdown", "id": "25818e83-89de-496d-8b56-af4fc9f2acc5", "metadata": { "id": "25818e83-89de-496d-8b56-af4fc9f2acc5" }, "source": [ "## 11. Save and load best performing model\n", "\n", "Let's finish this section off by saving and loading in our best-performing model.\n", "\n", "Recall from [notebook 01](https://www.learnpytorch.io/01_pytorch_workflow/#5-saving-and-loading-a-pytorch-model) that we can save and load a PyTorch model using a combination of:\n", "* `torch.save` - a function to save a whole PyTorch model or a model's `state_dict()`. \n", "* `torch.load` - a function to load in a saved PyTorch object.\n", "* `torch.nn.Module.load_state_dict()` - a function to load a saved `state_dict()` into an existing model instance.\n", "\n", "You can read more about these three in the [PyTorch saving and loading models documentation](https://pytorch.org/tutorials/beginner/saving_loading_models.html).\n", "\n", "For now, let's save our `model_2`'s `state_dict()` then load it back in and evaluate it to make sure the save and load went correctly. 
" ] }, { "cell_type": "code", "execution_count": 57, "id": "d058e8fa-560f-4350-a154-49593ff403c9", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "d058e8fa-560f-4350-a154-49593ff403c9", "outputId": "0156a518-dae2-4b25-999a-c0a77ef7ef7c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving model to: models/03_pytorch_computer_vision_model_2.pth\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "# Create models directory (if it doesn't already exist), see: https://docs.python.org/3/library/pathlib.html#pathlib.Path.mkdir\n", "MODEL_PATH = Path(\"models\")\n", "MODEL_PATH.mkdir(parents=True, # create parent directories if needed\n", "                 exist_ok=True # if models directory already exists, don't error\n", ")\n", "\n", "# Create model save path\n", "MODEL_NAME = \"03_pytorch_computer_vision_model_2.pth\"\n", "MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n", "\n", "# Save the model state dict\n", "print(f\"Saving model to: {MODEL_SAVE_PATH}\")\n", "torch.save(obj=model_2.state_dict(), # saving the state_dict() saves the model's learned parameters rather than the whole model object\n", "           f=MODEL_SAVE_PATH)" ] }, { "cell_type": "markdown", "id": "a1542284-8132-42ba-b00d-57e9b9037e4e", "metadata": { "id": "a1542284-8132-42ba-b00d-57e9b9037e4e" }, "source": [ "Now that we've got a saved model `state_dict()`, we can load it back in using a combination of `load_state_dict()` and `torch.load()`.\n", "\n", "Since we're using `load_state_dict()`, we'll need to create a new instance of `FashionMNISTModelV2()` with the same input parameters as our saved model `state_dict()`."
] }, { "cell_type": "code", "execution_count": 58, "id": "634a8f7a-3013-4b45-b365-49b286d3c478", "metadata": { "id": "634a8f7a-3013-4b45-b365-49b286d3c478" }, "outputs": [], "source": [ "# Create a new instance of FashionMNISTModelV2 (the same class as our saved state_dict())\n", "# Note: loading the state_dict() will error if the shapes here aren't the same as the saved version\n", "loaded_model_2 = FashionMNISTModelV2(input_shape=1,\n", "                                     hidden_units=10, # try changing this to 128 and seeing what happens\n", "                                     output_shape=10)\n", "\n", "# Load in the saved state_dict()\n", "loaded_model_2.load_state_dict(torch.load(f=MODEL_SAVE_PATH))\n", "\n", "# Send model to the target device (GPU if available)\n", "loaded_model_2 = loaded_model_2.to(device)" ] }, { "cell_type": "markdown", "id": "feeaebf4-6040-4fa5-852d-5eb8d2bbb94c", "metadata": { "id": "feeaebf4-6040-4fa5-852d-5eb8d2bbb94c" }, "source": [ "Now that we've got a loaded model, we can evaluate it with `eval_model()` to make sure its parameters give the same results as `model_2` did before saving. " ] }, { "cell_type": "code", "execution_count": 59, "id": "3e3bcd06-d99b-47bc-8828-9e3903285599", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3e3bcd06-d99b-47bc-8828-9e3903285599", "outputId": "c0ee1d5f-9573-4e1a-8430-ee09fb4d72cd" }, "outputs": [ { "data": { "text/plain": [ "{'model_name': 'FashionMNISTModelV2',\n", " 'model_loss': 0.3285697102546692,\n", " 'model_acc': 88.37859424920129}" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate loaded model\n", "torch.manual_seed(42)\n", "\n", "loaded_model_2_results = eval_model(\n", "    model=loaded_model_2,\n", "    data_loader=test_dataloader,\n", "    loss_fn=loss_fn,\n", "    accuracy_fn=accuracy_fn\n", ")\n", "\n", "loaded_model_2_results" ] }, { "cell_type": "markdown", "id": "c2b37855-c0da-4834-a2d4-a0faa8410b65", "metadata": { "id": "c2b37855-c0da-4834-a2d4-a0faa8410b65" }, "source": [ "Do these results look the same as `model_2_results`?" 
] }, { "cell_type": "code", "execution_count": 60, "id": "68544254-c99a-47ec-a32f-9816c21a993e", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "68544254-c99a-47ec-a32f-9816c21a993e", "outputId": "74b8d4ca-d35a-4f70-e8b9-ed54f034358e" }, "outputs": [ { "data": { "text/plain": [ "{'model_name': 'FashionMNISTModelV2',\n", " 'model_loss': 0.3285697102546692,\n", " 'model_acc': 88.37859424920129}" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_2_results" ] }, { "cell_type": "markdown", "id": "0ee07f93-4344-4c7a-8b1d-92a56034e7b2", "metadata": { "id": "0ee07f93-4344-4c7a-8b1d-92a56034e7b2" }, "source": [ "We can find out if two tensors are close to each other using `torch.isclose()` and passing in a tolerance level of closeness via the parameters `atol` (absolute tolerance) and `rtol` (relative tolerance).\n", "\n", "Element-wise, `torch.isclose(input, other)` returns `True` where `|input - other| <= atol + rtol * |other|`.\n", "\n", "If our model's results are close, the output of `torch.isclose()` should be `True`." ] }, { "cell_type": "code", "execution_count": 61, "id": "48dcf0ba-7e00-4406-8aaa-41918856361a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "48dcf0ba-7e00-4406-8aaa-41918856361a", "outputId": "47324300-0d00-46de-d130-1283ad044ef8" }, "outputs": [ { "data": { "text/plain": [ "tensor(True)" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check to see if results are close to each other (if they are very far away, there may be an error)\n", "torch.isclose(torch.tensor(model_2_results[\"model_loss\"]),\n", "              torch.tensor(loaded_model_2_results[\"model_loss\"]),\n", "              atol=1e-08, # absolute tolerance\n", "              rtol=0.0001) # relative tolerance" ] }, { "cell_type": "markdown", "id": "c3969b7d-9955-4b6f-abf8-fe8eedf233a9", "metadata": { "id": "c3969b7d-9955-4b6f-abf8-fe8eedf233a9" }, "source": [ "## Exercises\n", "\n", "All of the exercises are focused on practicing the code in the sections above.\n", "\n", "You should be able to 
complete them by referencing each section or by following the resource(s) linked.\n", "\n", "All exercises should be completed using [device-agnostic code](https://pytorch.org/docs/stable/notes/cuda.html#device-agnostic-code).\n", "\n", "**Resources:**\n", "* [Exercise template notebook for 03](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/exercises/03_pytorch_computer_vision_exercises.ipynb)\n", "* [Example solutions notebook for 03](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/solutions/03_pytorch_computer_vision_exercise_solutions.ipynb) (try the exercises *before* looking at this)\n", "\n", "1. What are 3 areas in industry where computer vision is currently being used?\n", "2. Search \"what is overfitting in machine learning\" and write down a sentence about what you find. \n", "3. Search \"ways to prevent overfitting in machine learning\", write down 3 of the things you find and a sentence about each. **Note:** there are lots of these, so don't worry too much about all of them, just pick 3 and start with those.\n", "4. Spend 20 minutes reading and clicking through the [CNN Explainer website](https://poloclub.github.io/cnn-explainer/).\n", "   * Upload your own example image using the \"upload\" button and see what happens in each layer of a CNN as your image passes through it.\n", "5. Load the [`torchvision.datasets.MNIST()`](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST) train and test datasets.\n", "6. Visualize at least 5 different samples of the MNIST training dataset.\n", "7. Turn the MNIST train and test datasets into dataloaders using `torch.utils.data.DataLoader` with `batch_size=32`.\n", "8. Recreate `model_2` used in this notebook (the same model from the [CNN Explainer website](https://poloclub.github.io/cnn-explainer/), also known as TinyVGG) capable of fitting on the MNIST dataset.\n", "9. Train the model you built in exercise 8. 
on CPU and GPU and see how long it takes on each.\n", "10. Make predictions using your trained model and visualize at least 5 of them, comparing the prediction to the target label.\n", "11. Plot a confusion matrix comparing your model's predictions to the truth labels.\n", "12. Create a random tensor of shape `[1, 3, 64, 64]` and pass it through a `nn.Conv2d()` layer with various hyperparameter settings (these can be any settings you choose). What do you notice if the `kernel_size` parameter goes up and down?\n", "13. Use a model similar to the trained `model_2` from this notebook to make predictions on the test [`torchvision.datasets.FashionMNIST`](https://pytorch.org/vision/main/generated/torchvision.datasets.FashionMNIST.html) dataset. \n", "    * Then plot some predictions where the model was wrong alongside what the label of the image should've been. \n", "    * After visualizing these predictions, do you think it's more of a modelling error or a data error? \n", "    * As in, could the model do better or are the labels of the data too close to each other (e.g. a \"Shirt\" label is too close to \"T-shirt/top\")?\n", "\n", "## Extra-curriculum\n", "* **Watch:** [MIT's Introduction to Deep Computer Vision](https://www.youtube.com/watch?v=iaSUYvmCekI&list=PLtBw6njQRU-rwp5__7C0oIVt26ZgjG9NI&index=3) lecture. This will give you a great intuition for convolutional neural networks.\n", "* Spend 10 minutes clicking through the different options of the [PyTorch vision library](https://pytorch.org/vision/stable/index.html). What different modules are available?\n", "* Look up \"most common convolutional neural networks\". What architectures do you find? Are any of them contained within the [`torchvision.models`](https://pytorch.org/vision/stable/models.html) library? 
What do you think you could do with these?\n", "* For a large number of pretrained PyTorch computer vision models as well as many different extensions to PyTorch's computer vision functionalities check out the [PyTorch Image Models library `timm`](https://github.com/rwightman/pytorch-image-models/) (Torch Image Models) by Ross Wightman." ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "machine_shape": "hm", "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "3fbe1355223f7b2ffc113ba3ade6a2b520cadace5d5ec3e828c83ce02eb221bf" } } }, "nbformat": 4, "nbformat_minor": 5 }