{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"\n",
"\n",
"[View Source Code](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/02_pytorch_classification.ipynb) | [View Slides](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/slides/02_pytorch_classification.pdf) | [Watch Video Walkthrough](https://youtu.be/Z_ikDlimN6A?t=30691) "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r8C1WSzsHC7x"
},
"source": [
"# 02. PyTorch Neural Network Classification\n",
"\n",
"## What is a classification problem?\n",
"\n",
"A [classification problem](https://en.wikipedia.org/wiki/Statistical_classification) involves predicting whether something is one thing or another.\n",
"\n",
"For example, you might want to:\n",
"\n",
"| Problem type | What is it? | Example |\n",
"| ----- | ----- | ----- |\n",
"| **Binary classification** | Target can be one of two options, e.g. yes or no | Predict whether or not someone has heart disease based on their health parameters. |\n",
"| **Multi-class classification** | Target can be one of more than two options | Decide whether a photo is of food, a person or a dog. |\n",
"| **Multi-label classification** | Target can be assigned more than one option | Predict what categories should be assigned to a Wikipedia article (e.g. mathematics, science & philosophy). |\n",
"\n",
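"\n",
"A minimal sketch, assuming PyTorch, of how the three problem types typically differ in the shape of a model's raw outputs (logits) and in the activation used to turn those logits into predictions. The logits are random placeholders rather than outputs of a trained model, and the class names in the comments are illustrative only.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(42)  # make the random placeholder logits reproducible\n",
"\n",
"# Binary classification: one logit per sample -> sigmoid -> probability of class 1.\n",
"binary_logits = torch.randn(2, 1)  # 2 samples\n",
"binary_probs = torch.sigmoid(binary_logits)\n",
"binary_preds = torch.round(binary_probs)  # 0. or 1. per sample\n",
"\n",
"# Multi-class classification: one logit per class -> softmax -> probabilities that sum to 1.\n",
"multiclass_logits = torch.randn(2, 3)  # e.g. food, person, dog\n",
"multiclass_probs = torch.softmax(multiclass_logits, dim=1)\n",
"multiclass_preds = multiclass_probs.argmax(dim=1)  # exactly one class index per sample\n",
"\n",
"# Multi-label classification: one logit per label -> independent sigmoids.\n",
"multilabel_logits = torch.randn(2, 3)  # e.g. mathematics, science, philosophy\n",
"multilabel_probs = torch.sigmoid(multilabel_logits)\n",
"multilabel_preds = (multilabel_probs > 0.5).int()  # zero, one or several labels can be 1\n",
"\n",
"print(binary_preds.squeeze(dim=1))\n",
"print(multiclass_preds)\n",
"print(multilabel_preds)\n",
"```\n",
"\n",
"The key design difference is that softmax makes the classes compete (their probabilities sum to 1), whereas per-label sigmoids score each label independently.\n",
"\n",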
"The first ten samples of a dataset with two input features, `X1` and `X2`, and a binary `label` (0 or 1):\n",
"\n",
"| | X1 | X2 | label |\n",
"| ----- | ----- | ----- | ----- |\n",
"| 0 | 0.754246 | 0.231481 | 1 |\n",
"| 1 | -0.756159 | 0.153259 | 1 |\n",
"| 2 | -0.815392 | 0.173282 | 1 |\n",
"| 3 | -0.393731 | 0.692883 | 1 |\n",
"| 4 | 0.442208 | -0.896723 | 0 |\n",
"| 5 | -0.479646 | 0.676435 | 1 |\n",
"| 6 | -0.013648 | 0.803349 | 1 |\n",
"| 7 | 0.771513 | 0.147760 | 1 |\n",
"| 8 | -0.169322 | -0.793456 | 1 |\n",
"| 9 | -0.121486 | 1.021509 | 0 |\n",
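"\n",
"A minimal sketch, assuming PyTorch, of turning the ten rows above into tensors and mapping a binary classifier's raw outputs to labels of 0 or 1. The single `nn.Linear` layer is a hypothetical, untrained model, so its accuracy means nothing; the point is the logits -> probabilities -> labels pipeline.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"# The ten rows from the table: [X1, X2] features and their labels.\n",
"X = torch.tensor([\n",
"    [0.754246, 0.231481],\n",
"    [-0.756159, 0.153259],\n",
"    [-0.815392, 0.173282],\n",
"    [-0.393731, 0.692883],\n",
"    [0.442208, -0.896723],\n",
"    [-0.479646, 0.676435],\n",
"    [-0.013648, 0.803349],\n",
"    [0.771513, 0.147760],\n",
"    [-0.169322, -0.793456],\n",
"    [-0.121486, 1.021509],\n",
"], dtype=torch.float32)\n",
"y = torch.tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 0], dtype=torch.float32)\n",
"\n",
"# Hypothetical, untrained model: 2 input features -> 1 output logit.\n",
"torch.manual_seed(42)\n",
"model = nn.Linear(in_features=2, out_features=1)\n",
"\n",
"with torch.inference_mode():\n",
"    logits = model(X).squeeze(dim=1)       # raw outputs, shape [10]\n",
"    pred_probs = torch.sigmoid(logits)     # probability of label 1\n",
"    pred_labels = torch.round(pred_probs)  # 0. or 1.\n",
"    accuracy = (pred_labels == y).float().mean().item()\n",
"\n",
"print(X.shape, y.shape, pred_labels.shape)\n",
"print(f'accuracy (untrained model): {accuracy:.2f}')\n",
"```\n",
"\n",
"Keeping the labels as `float32` matches what PyTorch's binary cross-entropy losses expect for their targets.\n",
"\n",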
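"\n",
"The steps of a PyTorch training loop:\n",
"\n",
"| Number | Step name | What does it do? | Code example |\n",
"| ----- | ----- | ----- | ----- |\n",
"| 1 | Forward pass | The model goes through all of the training data once, performing its `forward()` function calculations. | `model(x_train)` |\n",
"| 2 | Calculate the loss | The model's outputs (predictions) are compared to the ground truth labels to measure how wrong they are. | `loss = loss_fn(y_pred, y_train)` |\n",
"| 3 | Zero gradients | The optimizer's gradients are set to zero (PyTorch accumulates them by default) so they can be recalculated for this training step. | `optimizer.zero_grad()` |\n",
"| 4 | Perform backpropagation on the loss | Computes the gradient of the loss with respect to every model parameter to be updated (each parameter with `requires_grad=True`). This is known as **backpropagation**, hence \"backwards\". | `loss.backward()` |\n",
"| 5 | Step the optimizer (gradient descent) | Updates the parameters with `requires_grad=True` using their gradients with respect to the loss, in order to improve them. | `optimizer.step()` |\n",
"\n",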
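"\n",
"A minimal, self-contained sketch of the training loop steps above for a binary classifier, assuming PyTorch. The toy data (label 1 when `X1 + X2 > 0`), the small `nn.Sequential` model, the choice of `nn.BCEWithLogitsLoss` and `torch.optim.SGD`, and the hyperparameters (`lr=0.1`, 500 epochs) are all illustrative assumptions, not settings taken from this notebook.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"# Hypothetical toy data: label is 1 when X1 + X2 > 0, else 0.\n",
"torch.manual_seed(42)\n",
"x_train = torch.randn(200, 2)\n",
"y_train = (x_train.sum(dim=1) > 0).float()\n",
"\n",
"# Small binary classifier that outputs one raw logit per sample.\n",
"model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))\n",
"loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid internally, so it takes logits\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
"\n",
"losses = []\n",
"for epoch in range(500):\n",
"    model.train()\n",
"    # 1. Forward pass\n",
"    y_logits = model(x_train).squeeze(dim=1)\n",
"    # 2. Calculate the loss (on logits, as BCEWithLogitsLoss expects)\n",
"    loss = loss_fn(y_logits, y_train)\n",
"    losses.append(loss.item())\n",
"    # 3. Zero the gradients left over from the previous step\n",
"    optimizer.zero_grad()\n",
"    # 4. Backpropagation: gradient of the loss for every parameter with requires_grad=True\n",
"    loss.backward()\n",
"    # 5. Step the optimizer: update the parameters using those gradients\n",
"    optimizer.step()\n",
"\n",
"model.eval()\n",
"with torch.inference_mode():\n",
"    y_pred = torch.round(torch.sigmoid(model(x_train).squeeze(dim=1)))\n",
"    accuracy = (y_pred == y_train).float().mean().item()\n",
"\n",
"print(f'final loss: {losses[-1]:.4f}, train accuracy: {accuracy:.2f}')\n",
"```\n",
"\n",
"The order of steps 3 and 4 matters: `optimizer.zero_grad()` must run before `loss.backward()`, otherwise the new gradients are added to those from the previous step.\n",
"\n",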