02. Neural Network Classification with TensorFlow¶
Okay, we've seen how to deal with a regression problem in TensorFlow; now let's look at how we can approach a classification problem.
A classification problem involves predicting whether something is one thing or another.
For example, you might want to:
- Predict whether or not someone has heart disease based on their health parameters. This is called binary classification since there are only two options.
- Decide whether a photo is of food, a person or a dog. This is called multi-class classification since there are more than two options.
- Predict what categories should be assigned to a Wikipedia article. This is called multi-label classification since a single article could have more than one category assigned.
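To make the three problem types concrete, here's a minimal NumPy sketch of how the labels for each might be stored. It assumes integer class labels plus one-hot / multi-hot encoding (a common convention, not the only one), and the arrays are made up for illustration.

```python
import numpy as np

# Binary classification: one label per example, either 0 or 1
# (e.g. 0 = no heart disease, 1 = heart disease)
binary_labels = np.array([0, 1, 1, 0])

# Multi-class classification: exactly one of several classes per example
class_names = ["food", "person", "dog"]
multiclass_labels = np.array([0, 2, 1, 2])
# One-hot encoding: one row per example, a single 1 in the column of its class
one_hot = np.eye(len(class_names), dtype=int)[multiclass_labels]

# Multi-label classification: each example can carry several labels at once
# (e.g. categories assigned to an article), stored as "multi-hot" rows
multi_hot = np.array([[1, 0, 1],   # article in categories 0 and 2
                      [0, 1, 0],   # article in category 1 only
                      [1, 1, 1]])  # article in all three categories

print(one_hot)
print(one_hot.sum(axis=1))    # every one-hot row sums to 1
print(multi_hot.sum(axis=1))  # multi-hot rows can sum to more than 1
```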
In this notebook, we're going to work through a number of different classification problems with TensorFlow. In other words, we'll take a set of inputs and predict which class that set of inputs belongs to.
What we're going to cover¶
Specifically, we're going to go through doing the following with TensorFlow:
- Architecture of a classification model
- Input shapes and output shapes
  - X: features/data (inputs)
  - y: labels (outputs)
  - "What class do the inputs belong to?"
- Creating custom data to view and fit
- Steps in modelling for binary and multiclass classification
- Creating a model
- Compiling a model
- Defining a loss function
- Setting up an optimizer
- Finding the best learning rate
- Creating evaluation metrics
- Fitting a model (getting it to find patterns in our data)
- Improving a model
- The power of non-linearity
- Evaluating classification models
- Visualizing the model ("visualize, visualize, visualize")
- Looking at training curves
- Compare predictions to ground truth (using our evaluation metrics)
How you can use this notebook¶
You can read through the descriptions and the code (it should all run, except for the cells which error on purpose), but there's a better option.
Write all of the code yourself.
Yes. I'm serious. Create a new notebook, and rewrite each line by yourself. Investigate it, see if you can break it, why does it break?
You don't have to write the text descriptions but writing the code yourself is a great way to get hands-on experience.
Don't worry if you make mistakes, we all do. The way to get better and make fewer mistakes is to write more code.
Typical architecture of a classification neural network¶
The word typical is on purpose.
Because the architecture of a classification neural network can widely vary depending on the problem you're working on.
However, there are some fundamentals all deep neural networks contain:
- An input layer.
- Some hidden layers.
- An output layer.
Much of the rest is up to the data analyst creating the model.
The following are some standard values you'll often use in your classification neural networks.
Hyperparameter | Binary Classification | Multiclass classification |
---|---|---|
Input layer shape | Same as number of features (e.g. 5 for age, sex, height, weight, smoking status in heart disease prediction) | Same as binary classification |
Hidden layer(s) | Problem specific, minimum = 1, maximum = unlimited | Same as binary classification |
Neurons per hidden layer | Problem specific, generally 10 to 100 | Same as binary classification |
Output layer shape | 1 (one class or the other) | 1 per class (e.g. 3 for food, person or dog photo) |
Hidden activation | Usually ReLU (rectified linear unit) | Same as binary classification |
Output activation | Sigmoid | Softmax |
Loss function | Cross entropy (tf.keras.losses.BinaryCrossentropy in TensorFlow) | Cross entropy (tf.keras.losses.CategoricalCrossentropy in TensorFlow) |
Optimizer | SGD (stochastic gradient descent), Adam | Same as binary classification |
Table 1: Typical architecture of a classification network. Source: Adapted from page 295 of Hands-On Machine Learning with Scikit-Learn, Keras & TensorFlow Book by Aurélien Géron
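The activation rows of the table can be written out in a few lines of NumPy for intuition. TensorFlow provides these as tf.keras.activations.relu, tf.keras.activations.sigmoid and tf.keras.activations.softmax (or via activation="relu" and friends in a layer); the hand-rolled versions below are only a sketch, not its implementation.

```python
import numpy as np

def relu(x):
    # Hidden activation: keeps positive values, zeroes out negatives
    return np.maximum(0, x)

def sigmoid(x):
    # Binary output activation: squashes any real number into (0, 1),
    # read as the probability of the positive class
    return 1 / (1 + np.exp(-x))

def softmax(logits):
    # Multi-class output activation: turns a row of scores into probabilities
    # that sum to 1. Subtracting the row max avoids overflow in np.exp
    # without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

print(relu(np.array([-2.0, 0.0, 3.0])))      # [0. 0. 3.]
print(sigmoid(np.array([0.0])))              # [0.5]
probs = softmax(np.array([[2.0, 1.0, 0.1]])) # e.g. scores for food/person/dog
print(probs.round(3), probs.sum())
```

Sigmoid gives one probability per example (hence an output shape of 1 for binary problems), while softmax spreads probability across one output per class.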
Don't worry if not much of the above makes sense right now, we'll get plenty of experience as we go through this notebook.
Let's start by importing TensorFlow as the common alias tf. For this notebook, make sure you're using version 2.x+.
import tensorflow as tf
print(tf.__version__)
import datetime
print(f"Notebook last run (end-to-end): {datetime.datetime.now()}")
2.12.0
Notebook last run (end-to-end): 2023-05-11 03:26:50.047328
Creating data to view and fit¶
We could start by importing a classification dataset but let's practice making some of our own classification data.
🔑 Note: It's common practice to get you and the model you build working on a toy (or simple) dataset before moving to your actual problem. Treat it as a rehearsal experiment before the actual experiment(s).
Since classification is predicting whether something is one thing or another, let's make some data to reflect that.
To do so, we'll use Scikit-Learn's make_circles() function.
from sklearn.datasets import make_circles
# Make 1000 examples
n_samples = 1000
# Create circles
X, y = make_circles(n_samples,
                    noise=0.03,
                    random_state=42)
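If you're curious what make_circles() is doing, here's a rough NumPy approximation. It's a sketch under assumptions, not Scikit-Learn's implementation: it assumes the default factor=0.8 (inner circle radius relative to the outer one), that outer-circle points get label 0 and inner-circle points label 1 (consistent with the samples shown below, e.g. X[0] has radius ≈ 0.79 and label 1), and Gaussian noise with standard deviation noise. The name make_circles_sketch is hypothetical, and because it uses a different random number generator it won't reproduce the exact values in this notebook.

```python
import numpy as np

def make_circles_sketch(n_samples=1000, noise=0.03, factor=0.8, seed=42):
    """Hypothetical, approximate stand-in for sklearn.datasets.make_circles."""
    rng = np.random.default_rng(seed)
    n_outer = n_samples // 2
    n_inner = n_samples - n_outer
    # Evenly spaced angles around each circle
    theta_outer = np.linspace(0, 2 * np.pi, n_outer, endpoint=False)
    theta_inner = np.linspace(0, 2 * np.pi, n_inner, endpoint=False)
    outer = np.c_[np.cos(theta_outer), np.sin(theta_outer)]           # radius 1
    inner = factor * np.c_[np.cos(theta_inner), np.sin(theta_inner)]  # radius = factor
    X = np.vstack([outer, inner])
    y = np.concatenate([np.zeros(n_outer, dtype=int), np.ones(n_inner, dtype=int)])
    # Jitter every point, then shuffle so the two classes are mixed together
    X = X + rng.normal(scale=noise, size=X.shape)
    order = rng.permutation(n_samples)
    return X[order], y[order]

X_sketch, y_sketch = make_circles_sketch()
radius = np.sqrt((X_sketch ** 2).sum(axis=1))
print(X_sketch.shape, y_sketch.shape)                              # (1000, 2) (1000,)
print(np.bincount(y_sketch))                                       # [500 500]
print(radius[y_sketch == 0].mean(), radius[y_sketch == 1].mean())  # roughly 1.0 and 0.8
```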
Wonderful, now that we've created some data, let's look at the features (X) and labels (y).
# Check out the features
X
array([[ 0.75424625, 0.23148074], [-0.75615888, 0.15325888], [-0.81539193, 0.17328203], ..., [-0.13690036, -0.81001183], [ 0.67036156, -0.76750154], [ 0.28105665, 0.96382443]])
# See the first 10 labels
y[:10]
array([1, 1, 1, 1, 0, 1, 1, 1, 1, 0])
Okay, we've seen some of our data and labels. How about we move towards visualizing them?
🔑 Note: One important step of starting any kind of machine learning project is to become one with the data. And one of the best ways to do this is to visualize the data you're working with as much as possible. The data explorer's motto is "visualize, visualize, visualize".
We'll start with a DataFrame.
# Make dataframe of features and labels
import pandas as pd
circles = pd.DataFrame({"X0":X[:, 0], "X1":X[:, 1], "label":y})
circles.head()
 | X0 | X1 | label |
---|---|---|---|
0 | 0.754246 | 0.231481 | 1 |
1 | -0.756159 | 0.153259 | 1 |
2 | -0.815392 | 0.173282 | 1 |
3 | -0.393731 | 0.692883 | 1 |
4 | 0.442208 | -0.896723 | 0 |
What kind of labels are we dealing with?
# Check out the different labels
circles.label.value_counts()
1    500
0    500
Name: label, dtype: int64
Alright, looks like we're dealing with a binary classification problem. It's binary because there are only two labels (0 or 1).
If there were more label options (e.g. 0, 1, 2, 3 or 4), it would be called multiclass classification.
Let's take our visualization a step further and plot our data.
# Visualize with a plot
import matplotlib.pyplot as plt
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdYlBu);
Nice! From the plot, can you guess what kind of model we might want to build?
How about we try and build one to classify blue or red dots? As in, a model which is able to distinguish blue from red dots.
🛠 Practice: Before pushing forward, you might want to spend 10 minutes playing around with the TensorFlow Playground. Try adjusting the different hyperparameters you see and click play to see a neural network train. I think you'll find the data very similar to what we've just created.
Input and output shapes¶
One of the most common issues you'll run into when building neural networks is shape mismatches.
More specifically, the shape of the input data and the shape of the output data.
In our case, we want to input X and get our model to predict y.
So let's check out the shapes of X and y.
# Check the shapes of our features and labels
X.shape, y.shape
((1000, 2), (1000,))
Hmm, where do these numbers come from?
# Check how many samples we have
len(X), len(y)
(1000, 1000)
So we've got as many X values as we do y values, which makes sense.
Let's check out one example of each.
# View the first example of features and labels
X[0], y[0]
(array([0.75424625, 0.23148074]), 1)
Alright, so we've got two X features which lead to one y value.
This means our neural network's input layer has to accept a tensor whose last dimension is two (one value per feature), and its output layer has to produce a tensor with at least one value.
🤔 Note: y having a shape of (1000,) can seem confusing. However, this is because all y values are actually scalars (single values) and therefore don't have a dimension. For now, think of your output shape as being at least the same size as one example of y (in our case, the output from our neural network has to be at least one value).
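Here's a quick NumPy sketch of the shapes involved, using small made-up arrays in place of X and y. It mimics what a single Dense(1) layer computes (inputs @ kernel + bias) to show why the kernel's first dimension has to match the number of features; it's an illustration, not TensorFlow code.

```python
import numpy as np

X_small = np.array([[0.75, 0.23],
                    [-0.76, 0.15],
                    [0.44, -0.90]])   # 3 samples, 2 features -> shape (3, 2)
y_small = np.array([1, 1, 0])         # 3 labels -> shape (3,)

print(X_small.shape, y_small.shape)   # (3, 2) (3,)
print(np.ndim(y_small[0]))            # 0 -> a single label is a scalar, no dimension

# A Dense(1)-style layer: kernel shape (n_features, units), bias shape (units,)
kernel = np.zeros((X_small.shape[1], 1))
bias = np.zeros(1)
outputs = X_small @ kernel + bias
print(outputs.shape)                  # (3, 1) -> one output value per sample

# Labels can be given a matching column shape if needed
print(y_small.reshape(-1, 1).shape)   # (3, 1)

# A 1-D array with one feature per sample needs an extra axis to look like
# (samples, features); later in this notebook tf.expand_dims does the same job
x_1d = np.arange(0, 20, 5)            # shape (4,)
print(np.expand_dims(x_1d, axis=-1).shape)  # (4, 1)
```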
Steps in modelling¶
Now that we know what data we have as well as the input and output shapes, let's see how we'd build a neural network to model it.
In TensorFlow, there are typically 3 fundamental steps to creating and training a model.
- Creating a model - piece together the layers of a neural network yourself (using the functional or sequential API) or import a previously built model (known as transfer learning).
- Compiling a model - defining how a model's performance should be measured (loss/metrics) as well as defining how it should improve (optimizer).
- Fitting a model - letting the model try to find patterns in the data (how does X get to y).
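Before handing these steps to Keras, it can help to see them stripped down. Below is a minimal NumPy sketch of the same three steps for a single-unit model with a sigmoid output (logistic regression), trained with plain gradient descent on binary cross-entropy. It's an illustration under simplifying assumptions (full-batch updates, made-up and linearly separable blob data, hypothetical helper names), not how TensorFlow implements compile() or fit().

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: two well-separated blobs (label 0 around (-2, -2), label 1 around (2, 2))
X_toy = np.vstack([rng.normal(-2, 1, size=(100, 2)), rng.normal(2, 1, size=(100, 2))])
y_toy = np.concatenate([np.zeros(100), np.ones(100)]).reshape(-1, 1)

# 1. "Create": one Dense(1)-like unit with a sigmoid output
weights = rng.normal(scale=0.01, size=(2, 1))
bias = np.zeros(1)

def predict_proba(X):
    return 1 / (1 + np.exp(-(X @ weights + bias)))

# 2. "Compile": choose a loss (binary cross-entropy) and an optimizer
#    (plain gradient descent with a learning rate of 0.1)
def binary_crossentropy(y_true, y_prob, eps=1e-7):
    y_prob = np.clip(y_prob, eps, 1 - eps)  # avoid log(0)
    return -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))

learning_rate = 0.1

# 3. "Fit": repeatedly nudge the parameters against the gradient of the loss
losses = []
for epoch in range(100):
    y_prob = predict_proba(X_toy)
    losses.append(binary_crossentropy(y_toy, y_prob))
    grad_logits = (y_prob - y_toy) / len(X_toy)  # d(loss)/d(logits) for sigmoid + BCE
    weights -= learning_rate * X_toy.T @ grad_logits
    bias -= learning_rate * grad_logits.sum(axis=0)

accuracy = np.mean((predict_proba(X_toy) > 0.5) == (y_toy == 1))
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}, accuracy: {accuracy:.2f}")
```

Keras wraps the same ideas: the layers define the model, compile() records the loss and optimizer, and fit() runs the update loop (in mini-batches rather than on the full dataset).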
Let's see these in action using the Sequential API to build a model for our classification data. Then we'll step through each.
# Set random seed
tf.random.set_seed(42)
# 1. Create the model using the Sequential API
model_1 = tf.keras.Sequential([
    tf.keras.layers.Dense(1)
])
# 2. Compile the model
model_1.compile(loss=tf.keras.losses.BinaryCrossentropy(), # binary since we are working with 2 classes (0 & 1)
                optimizer=tf.keras.optimizers.SGD(),
                metrics=['accuracy'])
# 3. Fit the model
model_1.fit(X, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 5s 5ms/step - loss: 3.9204 - accuracy: 0.4810
Epoch 2/5
32/32 [==============================] - 0s 4ms/step - loss: 0.7762 - accuracy: 0.4910
Epoch 3/5
32/32 [==============================] - 0s 4ms/step - loss: 0.7171 - accuracy: 0.4910
Epoch 4/5
32/32 [==============================] - 0s 4ms/step - loss: 0.7014 - accuracy: 0.4930
Epoch 5/5
32/32 [==============================] - 0s 5ms/step - loss: 0.6961 - accuracy: 0.4900
<keras.callbacks.History at 0x7f5d8c0f9360>
Looking at the accuracy metric, our model performs poorly (50% accuracy on a binary classification problem is the equivalent of guessing), but what if we trained it for longer?
# Train our model for longer (more chances to look at the data)
model_1.fit(X, y, epochs=200, verbose=0) # set verbose=0 to remove training updates
model_1.evaluate(X, y)
32/32 [==============================] - 0s 2ms/step - loss: 0.6935 - accuracy: 0.5000
[0.6934829950332642, 0.5]
Even after 200 more passes over the data, it's still performing as if it's guessing.
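Notice the loss hovering around 0.693. That's not a coincidence: it's the natural log of 2, which is the binary cross-entropy you get when a model predicts a probability of 0.5 for every example. A small NumPy check, using a hand-written version of the standard binary cross-entropy formula (for illustration, not TensorFlow's implementation) and made-up balanced labels:

```python
import numpy as np

y_true = np.array([1, 0] * 500)      # 1000 balanced labels, like our circles data
y_prob = np.full(y_true.shape, 0.5)  # a "model" that predicts 0.5 for everything

# Binary cross-entropy: -mean(y * log(p) + (1 - y) * log(1 - p))
bce = -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
print(bce, np.log(2))                # both ~0.6931

# Any constant prediction gets one class right and the other wrong,
# so on balanced labels its accuracy is stuck at 50%
for constant_guess in (0, 1):
    print(constant_guess, np.mean(y_true == constant_guess))  # 0.5 each time
```

So a loss stuck near 0.693 alongside ~50% accuracy is a good hint the model hasn't learned anything beyond a coin flip.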
What if we added an extra layer and trained for a little longer?
# Set random seed
tf.random.set_seed(42)
# 1. Create the model (same as model_1 but with an extra layer)
model_2 = tf.keras.Sequential([
    tf.keras.layers.Dense(1), # add an extra layer
    tf.keras.layers.Dense(1)
])
# 2. Compile the model
model_2.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                optimizer=tf.keras.optimizers.SGD(),
                metrics=['accuracy'])
# 3. Fit the model
model_2.fit(X, y, epochs=100, verbose=0) # set verbose=0 to make the output print less
<keras.callbacks.History at 0x7f5d008171f0>
# Evaluate the model
model_2.evaluate(X, y)
32/32 [==============================] - 0s 2ms/step - loss: 0.6933 - accuracy: 0.5000
[0.693259596824646, 0.5]
Still no better than guessing (~50% accuracy)... hmm...?
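Part of the reason the extra layer didn't help: a Dense layer with no activation function computes inputs @ kernel + bias, and stacking two such linear maps still gives a single linear map. Here's a NumPy check with random weights (shapes chosen to match model_2's two Dense(1) layers); it illustrates the algebra and isn't TensorFlow code.

```python
import numpy as np

rng = np.random.default_rng(42)
X_demo = rng.normal(size=(5, 2))  # 5 samples, 2 features

# Two stacked Dense(1) layers without activations, like model_2
W1, b1 = rng.normal(size=(2, 1)), rng.normal(size=(1,))
W2, b2 = rng.normal(size=(1, 1)), rng.normal(size=(1,))
two_layers = (X_demo @ W1 + b1) @ W2 + b2

# The same function written as ONE linear layer
W_combined = W1 @ W2              # shape (2, 1)
b_combined = b1 @ W2 + b2         # shape (1,)
one_layer = X_demo @ W_combined + b_combined

print(np.allclose(two_layers, one_layer))  # True
```

The same collapse applies to any stack of activation-free Dense layers, whatever their widths, so extra layers and neurons alone still produce a straight-line decision boundary.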
Let's remind ourselves of a couple more ways we can improve our models.
Improving a model¶
To improve our model, we can alter almost every part of the 3 steps we went through before.
- Creating a model - here you might want to add more layers, increase the number of hidden units (also called neurons) within each layer, change the activation functions of each layer.
- Compiling a model - you might want to choose a different optimization function (such as the Adam optimizer, which is usually pretty good for many problems) or perhaps change the learning rate of the optimization function.
- Fitting a model - perhaps you could fit a model for more epochs (leave it training for longer).
There are many different ways to potentially improve a neural network. Some of the most common include: increasing the number of layers (making the network deeper), increasing the number of hidden units (making the network wider) and changing the learning rate. Because these values are all human-changeable, they're referred to as hyperparameters, and the practice of trying to find the best hyperparameters is referred to as hyperparameter tuning.
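Hyperparameter tuning often starts as a simple grid: list candidate values for each hyperparameter and try every combination. A minimal sketch using Python's itertools.product; the candidate values are arbitrary examples, and building, fitting and evaluating a model for each configuration is left out.

```python
from itertools import product

# Candidate values for a few hyperparameters (arbitrary examples)
search_space = {
    "hidden_layers": [1, 2],
    "units_per_layer": [10, 100],
    "learning_rate": [0.001, 0.01, 0.1],
}

# Every combination of the candidate values, one dict per configuration
configs = [dict(zip(search_space, values)) for values in product(*search_space.values())]

print(len(configs))  # 2 * 2 * 3 = 12
print(configs[0])    # {'hidden_layers': 1, 'units_per_layer': 10, 'learning_rate': 0.001}
```

Grids grow multiplicatively with each hyperparameter you add, which is why it pays to start with only a few candidate values.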
How about we try adding more neurons, an extra layer and our friend the Adam optimizer?
Surely doing this will result in predictions better than guessing...
Note: The following message (below this one) can be ignored if you're running TensorFlow 2.8.0+; the error seems to have been fixed.
Note: If you're using TensorFlow 2.7.0+ (but not 2.8.0+) the original code from the following cells may have caused some errors. They've since been updated to fix those errors. You can see explanations on what happened at the following resources:
# Set random seed
tf.random.set_seed(42)
# 1. Create the model (this time 3 layers)
model_3 = tf.keras.Sequential([
    # Before TensorFlow 2.7.0
    # tf.keras.layers.Dense(100), # add 100 dense neurons
    # With TensorFlow 2.7.0
    # tf.keras.layers.Dense(100, input_shape=(None, 1)), # add 100 dense neurons
    ## After TensorFlow 2.8.0 ##
    tf.keras.layers.Dense(100), # add 100 dense neurons
    tf.keras.layers.Dense(10), # add another layer with 10 neurons
    tf.keras.layers.Dense(1)
])
# 2. Compile the model
model_3.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(), # use Adam instead of SGD
                metrics=['accuracy'])
# 3. Fit the model
model_3.fit(X, y, epochs=100, verbose=1) # fit for 100 passes of the data
Epoch 1/100 32/32 [==============================] - 2s 3ms/step - loss: 3.7522 - accuracy: 0.4440 Epoch 2/100 32/32 [==============================] - 0s 3ms/step - loss: 1.9375 - accuracy: 0.4660 Epoch 3/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7301 - accuracy: 0.5000 Epoch 4/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7154 - accuracy: 0.5000 Epoch 5/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7062 - accuracy: 0.5000 Epoch 6/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7002 - accuracy: 0.5000 Epoch 7/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6967 - accuracy: 0.5000 Epoch 8/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6953 - accuracy: 0.5000 Epoch 9/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6942 - accuracy: 0.5000 Epoch 10/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6939 - accuracy: 0.5000 Epoch 11/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6939 - accuracy: 0.5000 Epoch 12/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6938 - accuracy: 0.4820 Epoch 13/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6939 - accuracy: 0.4660 Epoch 14/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6940 - accuracy: 0.4920 Epoch 15/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6938 - accuracy: 0.4740 Epoch 16/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6937 - accuracy: 0.4880 Epoch 17/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.4920 Epoch 18/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6937 - accuracy: 0.4840 Epoch 19/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6939 - accuracy: 0.4810 Epoch 20/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6940 - accuracy: 0.4730 Epoch 
21/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6939 - accuracy: 0.4710 Epoch 22/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6937 - accuracy: 0.4390 Epoch 23/100 32/32 [==============================] - 0s 4ms/step - loss: 0.6937 - accuracy: 0.4650 Epoch 24/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6948 - accuracy: 0.4910 Epoch 25/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6950 - accuracy: 0.4890 Epoch 26/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6940 - accuracy: 0.5030 Epoch 27/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.5170 Epoch 28/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6936 - accuracy: 0.5170 Epoch 29/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.4610 Epoch 30/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6949 - accuracy: 0.4840 Epoch 31/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.4880 Epoch 32/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6945 - accuracy: 0.4910 Epoch 33/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6940 - accuracy: 0.4700 Epoch 34/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.4810 Epoch 35/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.5090 Epoch 36/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6937 - accuracy: 0.4800 Epoch 37/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6940 - accuracy: 0.4920 Epoch 38/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.4730 Epoch 39/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6938 - accuracy: 0.4810 Epoch 40/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.4910 Epoch 
41/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6957 - accuracy: 0.4910 Epoch 42/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6953 - accuracy: 0.4690 Epoch 43/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6944 - accuracy: 0.5110 Epoch 44/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6950 - accuracy: 0.4740 Epoch 45/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6939 - accuracy: 0.4940 Epoch 46/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6939 - accuracy: 0.4890 Epoch 47/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6954 - accuracy: 0.4950 Epoch 48/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.4960 Epoch 49/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6953 - accuracy: 0.4670 Epoch 50/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6942 - accuracy: 0.4640 Epoch 51/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6950 - accuracy: 0.4910 Epoch 52/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6947 - accuracy: 0.5220 Epoch 53/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6968 - accuracy: 0.4900 Epoch 54/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6953 - accuracy: 0.5130 Epoch 55/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6953 - accuracy: 0.5120 Epoch 56/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6958 - accuracy: 0.4780 Epoch 57/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6956 - accuracy: 0.4780 Epoch 58/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6952 - accuracy: 0.5060 Epoch 59/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6974 - accuracy: 0.5180 Epoch 60/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6965 - accuracy: 0.4860 Epoch 
61/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6965 - accuracy: 0.4280 Epoch 62/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6952 - accuracy: 0.4800 Epoch 63/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6944 - accuracy: 0.4780 Epoch 64/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6955 - accuracy: 0.5020 Epoch 65/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6964 - accuracy: 0.4720 Epoch 66/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6950 - accuracy: 0.5040 Epoch 67/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6960 - accuracy: 0.4550 Epoch 68/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.4810 Epoch 69/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6945 - accuracy: 0.5330 Epoch 70/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6967 - accuracy: 0.4660 Epoch 71/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6949 - accuracy: 0.4620 Epoch 72/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6954 - accuracy: 0.5060 Epoch 73/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6958 - accuracy: 0.5010 Epoch 74/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6953 - accuracy: 0.4980 Epoch 75/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6957 - accuracy: 0.5110 Epoch 76/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6966 - accuracy: 0.4610 Epoch 77/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6954 - accuracy: 0.4960 Epoch 78/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6961 - accuracy: 0.4590 Epoch 79/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6977 - accuracy: 0.5000 Epoch 80/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.5290 Epoch 
81/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6969 - accuracy: 0.4970 Epoch 82/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6957 - accuracy: 0.4960 Epoch 83/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6978 - accuracy: 0.4510 Epoch 84/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6949 - accuracy: 0.4980 Epoch 85/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6961 - accuracy: 0.4780 Epoch 86/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6981 - accuracy: 0.4790 Epoch 87/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6956 - accuracy: 0.4580 Epoch 88/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6969 - accuracy: 0.4590 Epoch 89/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6969 - accuracy: 0.4760 Epoch 90/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6967 - accuracy: 0.4410 Epoch 91/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6951 - accuracy: 0.4990 Epoch 92/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6973 - accuracy: 0.4870 Epoch 93/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6981 - accuracy: 0.4580 Epoch 94/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6970 - accuracy: 0.4990 Epoch 95/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6966 - accuracy: 0.4680 Epoch 96/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6962 - accuracy: 0.4980 Epoch 97/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6963 - accuracy: 0.4620 Epoch 98/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6966 - accuracy: 0.4900 Epoch 99/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6968 - accuracy: 0.5280 Epoch 100/100 32/32 [==============================] - 0s 3ms/step - loss: 0.6968 - accuracy: 0.4710
<keras.callbacks.History at 0x7f5d00280220>
Still!
We've pulled out a few tricks but our model isn't even doing better than guessing.
Let's make some visualizations to see what's happening.
🔑 Note: Whenever your model is performing strangely or there's something going on with your data you're not quite sure of, remember these three words: visualize, visualize, visualize. Inspect your data, inspect your model, inspect your model's predictions.
To visualize our model's predictions, we're going to create a function plot_decision_boundary() which:
- Takes in a trained model, features (X) and labels (y).
- Creates a meshgrid of the different X values.
- Makes predictions across the meshgrid.
- Plots the predictions as well as a line between the different zones (where each unique class falls).
If this sounds confusing, let's see it in code and then see the output.
🔑 Note: If you're ever unsure of what a function does, try unraveling it and writing it line by line for yourself to see what it does. Break it into small parts and see what each part outputs.
import numpy as np
import matplotlib.pyplot as plt
def plot_decision_boundary(model, X, y):
    """
    Plots the decision boundary created by a model predicting on X.
    This function has been adapted from two phenomenal resources:
     1. CS231n - https://cs231n.github.io/neural-networks-case-study/
     2. Made with ML basics - https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/08_Neural_Networks.ipynb
    """
    # Define the axis boundaries of the plot and create a meshgrid
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    # Create X values (we're going to predict on all of these)
    x_in = np.c_[xx.ravel(), yy.ravel()] # stack 2D arrays together: https://numpy.org/devdocs/reference/generated/numpy.c_.html
    # Make predictions using the trained model
    y_pred = model.predict(x_in)
    # Check for multi-class
    if model.output_shape[-1] > 1: # checks the final dimension of the model's output shape, if this is > (greater than) 1, it's multi-class
        print("doing multiclass classification...")
        # We have to reshape our predictions to get them ready for plotting
        y_pred = np.argmax(y_pred, axis=1).reshape(xx.shape)
    else:
        print("doing binary classifcation...")
        y_pred = np.round(np.max(y_pred, axis=1)).reshape(xx.shape)
    # Plot decision boundary
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
Now that we've got a function to plot our model's decision boundary (the cut-off point it's making between red and blue dots), let's try it out.
# Check out the predictions our model is making
plot_decision_boundary(model_3, X, y)
313/313 [==============================] - 0s 1ms/step doing binary classifcation...
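Where does the 313 come from? The function builds a 100 x 100 grid, so the model predicts on 10,000 points, and model.predict() processes them in batches of 32 by default (10,000 / 32 = 312.5, rounded up to 313). Here's the grid-building part on its own in NumPy, with a hypothetical stand-in for model.predict (a hand-written rule that labels points within radius 0.9 of the origin as class 1) so it runs without a trained model:

```python
import numpy as np

# Same grid construction as plot_decision_boundary, over a fixed range
xx, yy = np.meshgrid(np.linspace(-1.2, 1.2, 100),
                     np.linspace(-1.2, 1.2, 100))
x_in = np.c_[xx.ravel(), yy.ravel()]  # every grid point as an (x0, x1) row
print(x_in.shape)                     # (10000, 2)

# Number of batches predict() needs with a batch size of 32
batch_size = 32
print(int(np.ceil(len(x_in) / batch_size)))  # 313

def fake_predict(points):
    """Hypothetical stand-in for model.predict: probability-like output of shape (n, 1)."""
    inside = np.sqrt((points ** 2).sum(axis=1)) < 0.9
    return inside.astype(float).reshape(-1, 1)

# Binary branch of plot_decision_boundary: round, then reshape back onto the grid
y_pred = np.round(np.max(fake_predict(x_in), axis=1)).reshape(xx.shape)
print(y_pred.shape)                   # (100, 100)
```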
Looks like our model is trying to draw a straight line through the data.
What's wrong with doing this?
The main issue is our data isn't separable by a straight line.
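One way to see this: in the original (x0, x1) coordinates, a straight line leaves inner and outer points on both sides, while the distance from the origin separates the two circles almost perfectly. A NumPy sketch with freshly generated circle data (approximating make_circles, with an assumed inner radius of 0.8 and noise of 0.03) and two hand-picked rules, not learned models: a straight line x0 > 0, and a circle of radius 0.9.

```python
import numpy as np

rng = np.random.default_rng(42)
theta = rng.uniform(0, 2 * np.pi, size=1000)
labels = np.repeat([0, 1], 500)            # 0 = outer circle, 1 = inner circle
radii = np.where(labels == 0, 1.0, 0.8)
points = np.c_[radii * np.cos(theta), radii * np.sin(theta)]
points += rng.normal(scale=0.03, size=points.shape)

# Straight-line rule: call everything right of the vertical axis "inner"
line_pred = (points[:, 0] > 0).astype(int)
# Circular rule: call everything closer than 0.9 to the origin "inner"
circle_pred = (np.sqrt((points ** 2).sum(axis=1)) < 0.9).astype(int)

print("straight line accuracy:", np.mean(line_pred == labels))  # close to 0.5
print("circle accuracy:", np.mean(circle_pred == labels))       # close to 1.0
```

A straight line can't trace that circular boundary, which is why the model will need some form of non-linearity.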
In a regression problem, our model might work. In fact, let's try it.
# Set random seed
tf.random.set_seed(42)
# Create some regression data
X_regression = np.arange(0, 1000, 5)
y_regression = np.arange(100, 1100, 5)
# Split it into training and test sets
X_reg_train = X_regression[:150]
X_reg_test = X_regression[150:]
y_reg_train = y_regression[:150]
y_reg_test = y_regression[150:]
# Fit our model to the data
# Note: Before TensorFlow 2.7.0, this line would work
# model_3.fit(X_reg_train, y_reg_train, epochs=100)
# After TensorFlow 2.7.0, see here for more: https://github.com/mrdbourke/tensorflow-deep-learning/discussions/278
model_3.fit(tf.expand_dims(X_reg_train, axis=-1),
            y_reg_train,
            epochs=100)
Epoch 1/100
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-19-5dd5867236b4> in <cell line: 19>() 17 18 # After TensorFlow 2.7.0, see here for more: https://github.com/mrdbourke/tensorflow-deep-learning/discussions/278 ---> 19 model_3.fit(tf.expand_dims(X_reg_train, axis=-1), 20 y_reg_train, 21 epochs=100) /usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs) 68 # To get the full stack trace, call: 69 # `tf.debugging.disable_traceback_filtering()` ---> 70 raise e.with_traceback(filtered_tb) from None 71 finally: 72 del filtered_tb /usr/local/lib/python3.10/dist-packages/keras/engine/training.py in tf__train_function(iterator) 13 try: 14 do_return = True ---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope) 16 except: 17 do_return = False ValueError: in user code: File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 1284, in train_function * return step_function(self, iterator) File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 1268, in step_function ** outputs = model.distribute_strategy.run(run_step, args=(data,)) File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 1249, in run_step ** outputs = model.train_step(data) File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 1050, in train_step y_pred = self(x, training=True) File "/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler raise e.with_traceback(filtered_tb) from None File "/usr/local/lib/python3.10/dist-packages/keras/engine/input_spec.py", line 280, in assert_input_compatibility raise ValueError( ValueError: Exception encountered when calling layer 'sequential_2' (type Sequential). 
Input 0 of layer "dense_3" is incompatible with the layer: expected axis -1 of input shape to have value 2, but received input with shape (None, 1) Call arguments received by layer 'sequential_2' (type Sequential): • inputs=tf.Tensor(shape=(None, 1), dtype=int64) • training=True • mask=None
model_3.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_3 (Dense)             (None, 100)               300
 dense_4 (Dense)             (None, 10)                1010
 dense_5 (Dense)             (None, 1)                 11
=================================================================
Total params: 1,321
Trainable params: 1,321
Non-trainable params: 0
_________________________________________________________________
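The Param # column can be checked by hand: a Dense layer has one weight per (input, unit) pair plus one bias per unit. Here's that arithmetic for model_3, assuming the 2 input features of our circles data:

```python
def dense_params(n_inputs, n_units):
    # kernel: n_inputs x n_units weights, plus n_units biases
    return n_inputs * n_units + n_units

layer_sizes = [2, 100, 10, 1]  # 2 input features -> Dense(100) -> Dense(10) -> Dense(1)
params = [dense_params(n_in, n_out) for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
print(params, sum(params))     # [300, 1010, 11] 1321

# With a single input feature, the first layer would need a different kernel
print(dense_params(1, 100))    # 200, not the 300 already built into model_3
```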
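Oh wait... model_3 was built and compiled for our binary classification data. Its first layer was built to expect inputs with 2 features (hence the error about expecting a value of 2 on the last axis of the input shape), and its loss function is meant for binary classification.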
No trouble, we can recreate it for a regression problem.
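For reference, the regression loss swapped in below, mean absolute error (MAE), is the average absolute difference between targets and predictions. A NumPy version on made-up numbers; the hand-written function is only an illustration, not TensorFlow's implementation of tf.keras.losses.mae.

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    # Average of the absolute differences, in the same units as the targets
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))

y_true = [100, 200, 300]
y_pred = [110, 190, 330]  # errors of 10, 10 and 30
print(mean_absolute_error(y_true, y_pred))  # 16.666...
```

Because MAE is in the units of the target, a training loss of around 40 means predictions on the training data are off by about 40 on average.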
# Setup random seed
tf.random.set_seed(42)
# Recreate the model
model_3 = tf.keras.Sequential([
    tf.keras.layers.Dense(100),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Dense(1)
])
# Change the loss and metrics of our compiled model
model_3.compile(loss=tf.keras.losses.mae, # change the loss function to be regression-specific
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['mae']) # change the metric to be regression-specific
# Fit the recompiled model
model_3.fit(tf.expand_dims(X_reg_train, axis=-1),
            y_reg_train,
            epochs=100)
Epoch 1/100 5/5 [==============================] - 1s 6ms/step - loss: 351.1467 - mae: 351.1467 Epoch 2/100 5/5 [==============================] - 0s 4ms/step - loss: 229.6038 - mae: 229.6038 Epoch 3/100 5/5 [==============================] - 0s 4ms/step - loss: 112.1480 - mae: 112.1480 Epoch 4/100 5/5 [==============================] - 0s 4ms/step - loss: 51.5032 - mae: 51.5032 Epoch 5/100 5/5 [==============================] - 0s 4ms/step - loss: 79.7073 - mae: 79.7073 Epoch 6/100 5/5 [==============================] - 0s 4ms/step - loss: 64.2647 - mae: 64.2647 Epoch 7/100 5/5 [==============================] - 0s 4ms/step - loss: 43.0651 - mae: 43.0651 Epoch 8/100 5/5 [==============================] - 0s 4ms/step - loss: 50.1065 - mae: 50.1065 Epoch 9/100 5/5 [==============================] - 0s 4ms/step - loss: 41.9964 - mae: 41.9964 Epoch 10/100 5/5 [==============================] - 0s 4ms/step - loss: 45.2621 - mae: 45.2621 Epoch 11/100 5/5 [==============================] - 0s 4ms/step - loss: 42.6393 - mae: 42.6393 Epoch 12/100 5/5 [==============================] - 0s 4ms/step - loss: 42.2500 - mae: 42.2500 Epoch 13/100 5/5 [==============================] - 0s 3ms/step - loss: 41.2409 - mae: 41.2409 Epoch 14/100 5/5 [==============================] - 0s 3ms/step - loss: 41.9115 - mae: 41.9115 Epoch 15/100 5/5 [==============================] - 0s 3ms/step - loss: 41.1004 - mae: 41.1004 Epoch 16/100 5/5 [==============================] - 0s 4ms/step - loss: 41.4659 - mae: 41.4659 Epoch 17/100 5/5 [==============================] - 0s 4ms/step - loss: 41.3447 - mae: 41.3447 Epoch 18/100 5/5 [==============================] - 0s 4ms/step - loss: 41.2736 - mae: 41.2736 Epoch 19/100 5/5 [==============================] - 0s 4ms/step - loss: 41.0783 - mae: 41.0783 Epoch 20/100 5/5 [==============================] - 0s 4ms/step - loss: 41.1026 - mae: 41.1026 Epoch 21/100 5/5 [==============================] - 0s 4ms/step - loss: 41.0839 - mae: 41.0839 Epoch 
22/100 5/5 [==============================] - 0s 4ms/step - loss: 41.1524 - mae: 41.1524 Epoch 23/100 5/5 [==============================] - 0s 4ms/step - loss: 41.0649 - mae: 41.0649 Epoch 24/100 5/5 [==============================] - 0s 4ms/step - loss: 41.0189 - mae: 41.0189 Epoch 25/100 5/5 [==============================] - 0s 4ms/step - loss: 40.9399 - mae: 40.9399 Epoch 26/100 5/5 [==============================] - 0s 4ms/step - loss: 40.8702 - mae: 40.8702 Epoch 27/100 5/5 [==============================] - 0s 4ms/step - loss: 40.8792 - mae: 40.8792 Epoch 28/100 5/5 [==============================] - 0s 4ms/step - loss: 40.9051 - mae: 40.9051 Epoch 29/100 5/5 [==============================] - 0s 4ms/step - loss: 41.0040 - mae: 41.0040 Epoch 30/100 5/5 [==============================] - 0s 4ms/step - loss: 40.7552 - mae: 40.7552 Epoch 31/100 5/5 [==============================] - 0s 4ms/step - loss: 41.4596 - mae: 41.4596 Epoch 32/100 5/5 [==============================] - 0s 4ms/step - loss: 41.1033 - mae: 41.1033 Epoch 33/100 5/5 [==============================] - 0s 4ms/step - loss: 41.3175 - mae: 41.3175 Epoch 34/100 5/5 [==============================] - 0s 4ms/step - loss: 41.1769 - mae: 41.1769 Epoch 35/100 5/5 [==============================] - 0s 4ms/step - loss: 40.5617 - mae: 40.5617 Epoch 36/100 5/5 [==============================] - 0s 3ms/step - loss: 41.1154 - mae: 41.1154 Epoch 37/100 5/5 [==============================] - 0s 4ms/step - loss: 40.7481 - mae: 40.7481 Epoch 38/100 5/5 [==============================] - 0s 4ms/step - loss: 40.2204 - mae: 40.2204 Epoch 39/100 5/5 [==============================] - 0s 4ms/step - loss: 40.7754 - mae: 40.7754 Epoch 40/100 5/5 [==============================] - 0s 4ms/step - loss: 40.4291 - mae: 40.4291 Epoch 41/100 5/5 [==============================] - 0s 4ms/step - loss: 40.3978 - mae: 40.3978 Epoch 42/100 5/5 [==============================] - 0s 4ms/step - loss: 40.2725 - mae: 40.2725 Epoch 
43/100 5/5 [==============================] - 0s 4ms/step - loss: 40.4337 - mae: 40.4337 Epoch 44/100 5/5 [==============================] - 0s 4ms/step - loss: 40.1374 - mae: 40.1374 Epoch 45/100 5/5 [==============================] - 0s 4ms/step - loss: 40.4168 - mae: 40.4168 Epoch 46/100 5/5 [==============================] - 0s 4ms/step - loss: 40.2554 - mae: 40.2554 Epoch 47/100 5/5 [==============================] - 0s 4ms/step - loss: 40.3564 - mae: 40.3564 Epoch 48/100 5/5 [==============================] - 0s 4ms/step - loss: 40.0230 - mae: 40.0230 Epoch 49/100 5/5 [==============================] - 0s 4ms/step - loss: 40.6025 - mae: 40.6025 Epoch 50/100 5/5 [==============================] - 0s 4ms/step - loss: 40.0258 - mae: 40.0258 Epoch 51/100 5/5 [==============================] - 0s 4ms/step - loss: 40.1666 - mae: 40.1666 Epoch 52/100 5/5 [==============================] - 0s 4ms/step - loss: 40.5203 - mae: 40.5203 Epoch 53/100 5/5 [==============================] - 0s 4ms/step - loss: 40.6270 - mae: 40.6270 Epoch 54/100 5/5 [==============================] - 0s 3ms/step - loss: 40.6488 - mae: 40.6488 Epoch 55/100 5/5 [==============================] - 0s 3ms/step - loss: 41.1507 - mae: 41.1507 Epoch 56/100 5/5 [==============================] - 0s 3ms/step - loss: 41.7925 - mae: 41.7925 Epoch 57/100 5/5 [==============================] - 0s 4ms/step - loss: 40.8292 - mae: 40.8292 Epoch 58/100 5/5 [==============================] - 0s 4ms/step - loss: 40.2886 - mae: 40.2886 Epoch 59/100 5/5 [==============================] - 0s 4ms/step - loss: 41.2423 - mae: 41.2423 Epoch 60/100 5/5 [==============================] - 0s 4ms/step - loss: 40.1771 - mae: 40.1771 Epoch 61/100 5/5 [==============================] - 0s 4ms/step - loss: 39.6253 - mae: 39.6253 Epoch 62/100 5/5 [==============================] - 0s 3ms/step - loss: 40.7286 - mae: 40.7286 Epoch 63/100 5/5 [==============================] - 0s 4ms/step - loss: 39.5326 - mae: 39.5326 Epoch 
64/100 5/5 [==============================] - 0s 4ms/step - loss: 39.6509 - mae: 39.6509 Epoch 65/100 5/5 [==============================] - 0s 4ms/step - loss: 39.9142 - mae: 39.9142 Epoch 66/100 5/5 [==============================] - 0s 6ms/step - loss: 40.1273 - mae: 40.1273 Epoch 67/100 5/5 [==============================] - 0s 4ms/step - loss: 39.8303 - mae: 39.8303 Epoch 68/100 5/5 [==============================] - 0s 4ms/step - loss: 39.5528 - mae: 39.5528 Epoch 69/100 5/5 [==============================] - 0s 4ms/step - loss: 39.8911 - mae: 39.8911 Epoch 70/100 5/5 [==============================] - 0s 4ms/step - loss: 40.1772 - mae: 40.1772 Epoch 71/100 5/5 [==============================] - 0s 4ms/step - loss: 40.7013 - mae: 40.7013 Epoch 72/100 5/5 [==============================] - 0s 4ms/step - loss: 38.9850 - mae: 38.9850 Epoch 73/100 5/5 [==============================] - 0s 4ms/step - loss: 39.7226 - mae: 39.7226 Epoch 74/100 5/5 [==============================] - 0s 4ms/step - loss: 39.2922 - mae: 39.2922 Epoch 75/100 5/5 [==============================] - 0s 4ms/step - loss: 39.5231 - mae: 39.5231 Epoch 76/100 5/5 [==============================] - 0s 4ms/step - loss: 39.1719 - mae: 39.1719 Epoch 77/100 5/5 [==============================] - 0s 4ms/step - loss: 39.1334 - mae: 39.1334 Epoch 78/100 5/5 [==============================] - 0s 3ms/step - loss: 39.4218 - mae: 39.4218 Epoch 79/100 5/5 [==============================] - 0s 4ms/step - loss: 39.0527 - mae: 39.0527 Epoch 80/100 5/5 [==============================] - 0s 4ms/step - loss: 38.5616 - mae: 38.5616 Epoch 81/100 5/5 [==============================] - 0s 4ms/step - loss: 39.0586 - mae: 39.0586 Epoch 82/100 5/5 [==============================] - 0s 4ms/step - loss: 39.5188 - mae: 39.5188 Epoch 83/100 5/5 [==============================] - 0s 4ms/step - loss: 38.8511 - mae: 38.8511 Epoch 84/100 5/5 [==============================] - 0s 4ms/step - loss: 38.6502 - mae: 38.6502 Epoch 
85/100 5/5 [==============================] - 0s 4ms/step - loss: 38.6555 - mae: 38.6555 Epoch 86/100 5/5 [==============================] - 0s 4ms/step - loss: 38.4175 - mae: 38.4175 Epoch 87/100 5/5 [==============================] - 0s 4ms/step - loss: 38.6160 - mae: 38.6160 Epoch 88/100 5/5 [==============================] - 0s 4ms/step - loss: 38.4141 - mae: 38.4141 Epoch 89/100 5/5 [==============================] - 0s 4ms/step - loss: 38.5214 - mae: 38.5214 Epoch 90/100 5/5 [==============================] - 0s 4ms/step - loss: 38.2880 - mae: 38.2880 Epoch 91/100 5/5 [==============================] - 0s 4ms/step - loss: 38.0676 - mae: 38.0676 Epoch 92/100 5/5 [==============================] - 0s 4ms/step - loss: 38.6172 - mae: 38.6172 Epoch 93/100 5/5 [==============================] - 0s 4ms/step - loss: 38.8183 - mae: 38.8183 Epoch 94/100 5/5 [==============================] - 0s 4ms/step - loss: 37.8435 - mae: 37.8435 Epoch 95/100 5/5 [==============================] - 0s 4ms/step - loss: 38.1737 - mae: 38.1737 Epoch 96/100 5/5 [==============================] - 0s 4ms/step - loss: 38.2674 - mae: 38.2674 Epoch 97/100 5/5 [==============================] - 0s 4ms/step - loss: 37.8406 - mae: 37.8406 Epoch 98/100 5/5 [==============================] - 0s 4ms/step - loss: 38.8478 - mae: 38.8478 Epoch 99/100 5/5 [==============================] - 0s 4ms/step - loss: 38.0126 - mae: 38.0126 Epoch 100/100 5/5 [==============================] - 0s 4ms/step - loss: 37.7153 - mae: 37.7153
<keras.callbacks.History at 0x7f5cd57c5d50>
Okay, it seems like our model is learning something (the mae value trends down with each epoch), so let's plot its predictions.
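The mae in the logs stands for mean absolute error. It is the average of the absolute differences between the true values and the predictions, so lower is better. Here is a minimal pure-Python sketch of the formula. The function name and the sample numbers are made up for illustration and are not the notebook's regression data.

```python
def mean_absolute_error(y_true, y_pred):
    """Average of |y_true - y_pred| over all samples."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical targets and predictions, for illustration only
targets = [100.0, 200.0, 300.0]
predictions = [110.0, 190.0, 330.0]

# Absolute errors are 10, 10 and 30, so the mean is 50 / 3
print(mean_absolute_error(targets, predictions))
```

Because the errors are not squared, a single large miss pulls mae up linearly rather than quadratically.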
# Make predictions with our trained model on the test features (X), not the test labels (y)
y_reg_preds = model_3.predict(X_reg_test)
# Plot the model's predictions against our regression data
plt.figure(figsize=(10, 7))
plt.scatter(X_reg_train, y_reg_train, c='b', label='Training data')
plt.scatter(X_reg_test, y_reg_test, c='g', label='Testing data')
plt.scatter(X_reg_test, y_reg_preds.squeeze(), c='r', label='Predictions')
plt.legend();
2/2 [==============================] - 0s 4ms/step
Okay, the predictions aren't perfect (if they were, the red dots would line up with the green ones), but they look better than complete guessing.
So our model must be learning something on this regression data...
That means there must be something we're missing for our classification problem.
The missing piece: Non-linearity
Okay, so we've seen our neural network can model straight lines (with results a little better than guessing).
What about non-straight (non-linear) lines?
If we're going to model our classification data (the red and blue circles), we're going to need some non-linear lines.
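Why can't stacking more linear layers fix this? Composing two linear layers, W2(W1·x + b1) + b2, simplifies to (W2·W1)·x + (W2·b1 + b2), which is just another single linear layer. So a network built only from linear layers can still only draw straight (flat) boundaries, however deep it is. The pure-Python sketch below checks this numerically. The weights are arbitrary made-up values, and the helpers matvec, matmul and vadd are hypothetical names used only for this illustration.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def matmul(A, B):
    """Multiply matrix A by matrix B (both lists of rows)."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def vadd(u, v):
    """Element-wise vector addition."""
    return [a + b for a, b in zip(u, v)]

# Made-up weights for two Dense-style layers with linear activation
W1, b1 = [[1.0, 2.0], [0.5, -1.0]], [0.1, -0.2]  # 2 inputs -> 2 hidden units
W2, b2 = [[3.0, -1.0]], [0.5]                    # 2 hidden units -> 1 output

def two_linear_layers(x):
    hidden = vadd(matvec(W1, x), b1)     # layer 1: W1 x + b1
    return vadd(matvec(W2, hidden), b2)  # layer 2: W2 h + b2

# The equivalent single linear layer
W = matmul(W2, W1)            # W2 W1
b = vadd(matvec(W2, b1), b2)  # W2 b1 + b2

def one_linear_layer(x):
    return vadd(matvec(W, x), b)

# Both functions give the same outputs (up to floating-point rounding)
for x in ([0.0, 0.0], [1.0, -2.0], [3.5, 0.25]):
    print(x, two_linear_layers(x), one_linear_layer(x))
```

A non-linear activation between the layers (such as ReLU) breaks this collapse, which is what lets the network bend its decision boundary.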
🔨 Practice: Before we get to the next steps, I'd encourage you to spend 10 minutes playing around with the TensorFlow Playground (check out what its data has in common with our own classification data). In particular, try out the "Activation" setting. Once you're done, come back.
Did you try out the activation options? If so, what did you find?
If you didn't, don't worry; let's see it in code.
We're going to replicate the neural network you can see at this link: TensorFlow Playground.
The neural network we're going to recreate with TensorFlow code. See it live at TensorFlow Playground.
The main change from the models we've built before is the use of the activation keyword argument.
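The activation argument sets the function applied to a layer's output. Below is a pure-Python sketch of three common choices, written to mirror what tf.keras.activations.linear, tf.keras.activations.relu and tf.keras.activations.sigmoid compute for a single number with their default settings (relu accepts extra parameters that this sketch ignores). Keep in mind that linear returns its input unchanged, so a layer using it adds no non-linearity.

```python
import math

def linear(x):
    """Identity: returns the input unchanged (no non-linearity)."""
    return x

def relu(x):
    """Rectified linear unit: max(0, x)."""
    return max(0.0, x)

def sigmoid(x):
    """Logistic sigmoid: squashes any real number into the range (0, 1)."""
    # Note: math.exp overflows for very negative x (below about -709);
    # fine for this small demo.
    return 1.0 / (1.0 + math.exp(-x))

for x in (-2.0, 0.0, 2.0):
    print(f"x={x:+.1f}  linear={linear(x):+.3f}  relu={relu(x):.3f}  sigmoid={sigmoid(x):.3f}")
```

ReLU is a common choice for hidden layers, while sigmoid is often used on the output of a binary classifier because its output can be read as a probability.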
# Set the random seed
tf.random.set_seed(42)
# Create the model
model_4 = tf.keras.Sequential([
tf.keras.layers.Dense(1, activation=tf.keras.activations.linear), # 1 hidden layer with linear activation
tf.keras.layers.Dense(1) # output layer
])
# Compile the model
model_4.compile(loss=tf.keras.losses.binary_crossentropy,
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), # note: "lr" used to be what was used, now "learning_rate" is favoured
metrics=["accuracy"])
# Fit the model
history = model_4.fit(X, y, epochs=100)
Epoch 1/100 32/32 [==============================] - 1s 3ms/step - loss: 4.2764 - accuracy: 0.4990 Epoch 2/100 32/32 [==============================] - 0s 3ms/step - loss: 4.1347 - accuracy: 0.4950 Epoch 3/100 32/32 [==============================] - 0s 3ms/step - loss: 4.1151 - accuracy: 0.4870 Epoch 4/100 32/32 [==============================] - 0s 3ms/step - loss: 4.0430 - accuracy: 0.4750 Epoch 5/100 32/32 [==============================] - 0s 3ms/step - loss: 3.9497 - accuracy: 0.4680 Epoch 6/100 32/32 [==============================] - 0s 3ms/step - loss: 3.9062 - accuracy: 0.4610 Epoch 7/100 32/32 [==============================] - 0s 3ms/step - loss: 3.8450 - accuracy: 0.4570 Epoch 8/100 32/32 [==============================] - 0s 3ms/step - loss: 3.7964 - accuracy: 0.4490 Epoch 9/100 32/32 [==============================] - 0s 3ms/step - loss: 3.7485 - accuracy: 0.4440 Epoch 10/100 32/32 [==============================] - 0s 3ms/step - loss: 3.7203 - accuracy: 0.4400 Epoch 11/100 32/32 [==============================] - 0s 3ms/step - loss: 3.6246 - accuracy: 0.4360 Epoch 12/100 32/32 [==============================] - 0s 3ms/step - loss: 3.5373 - accuracy: 0.4360 Epoch 13/100 32/32 [==============================] - 0s 3ms/step - loss: 3.4786 - accuracy: 0.4350 Epoch 14/100 32/32 [==============================] - 0s 3ms/step - loss: 3.3771 - accuracy: 0.4350 Epoch 15/100 32/32 [==============================] - 0s 3ms/step - loss: 3.2890 - accuracy: 0.4410 Epoch 16/100 32/32 [==============================] - 0s 3ms/step - loss: 3.1248 - accuracy: 0.4400 Epoch 17/100 32/32 [==============================] - 0s 3ms/step - loss: 2.8467 - accuracy: 0.4370 Epoch 18/100 32/32 [==============================] - 0s 3ms/step - loss: 2.7801 - accuracy: 0.4360 Epoch 19/100 32/32 [==============================] - 0s 3ms/step - loss: 2.6968 - accuracy: 0.4360 Epoch 20/100 32/32 [==============================] - 0s 3ms/step - loss: 2.3279 - accuracy: 0.4330 Epoch 
21/100 32/32 [==============================] - 0s 3ms/step - loss: 2.1250 - accuracy: 0.4340 Epoch 22/100 32/32 [==============================] - 0s 3ms/step - loss: 2.0046 - accuracy: 0.4350 Epoch 23/100 32/32 [==============================] - 0s 3ms/step - loss: 1.8705 - accuracy: 0.4350 Epoch 24/100 32/32 [==============================] - 0s 3ms/step - loss: 1.7323 - accuracy: 0.4350 Epoch 25/100 32/32 [==============================] - 0s 3ms/step - loss: 1.5785 - accuracy: 0.4360 Epoch 26/100 32/32 [==============================] - 0s 3ms/step - loss: 1.3957 - accuracy: 0.4370 Epoch 27/100 32/32 [==============================] - 0s 3ms/step - loss: 1.1862 - accuracy: 0.4390 Epoch 28/100 32/32 [==============================] - 0s 3ms/step - loss: 1.0745 - accuracy: 0.4410 Epoch 29/100 32/32 [==============================] - 0s 3ms/step - loss: 1.0452 - accuracy: 0.4420 Epoch 30/100 32/32 [==============================] - 0s 3ms/step - loss: 1.0294 - accuracy: 0.4420 Epoch 31/100 32/32 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.4430 Epoch 32/100 32/32 [==============================] - 0s 3ms/step - loss: 1.0058 - accuracy: 0.4430 Epoch 33/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9966 - accuracy: 0.4430 Epoch 34/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9883 - accuracy: 0.4430 Epoch 35/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9806 - accuracy: 0.4430 Epoch 36/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9735 - accuracy: 0.4420 Epoch 37/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9671 - accuracy: 0.4420 Epoch 38/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.4420 Epoch 39/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9555 - accuracy: 0.4420 Epoch 40/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9501 - accuracy: 0.4420 Epoch 
41/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9449 - accuracy: 0.4420 Epoch 42/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9400 - accuracy: 0.4400 Epoch 43/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9351 - accuracy: 0.4400 Epoch 44/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.4390 Epoch 45/100 32/32 [==============================] - 0s 4ms/step - loss: 0.9264 - accuracy: 0.4390 Epoch 46/100 32/32 [==============================] - 0s 4ms/step - loss: 0.9222 - accuracy: 0.4380 Epoch 47/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9183 - accuracy: 0.4380 Epoch 48/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9143 - accuracy: 0.4380 Epoch 49/100 32/32 [==============================] - 0s 4ms/step - loss: 0.9105 - accuracy: 0.4370 Epoch 50/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9068 - accuracy: 0.4370 Epoch 51/100 32/32 [==============================] - 0s 3ms/step - loss: 0.9032 - accuracy: 0.4370 Epoch 52/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8998 - accuracy: 0.4350 Epoch 53/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8964 - accuracy: 0.4340 Epoch 54/100 32/32 [==============================] - 0s 4ms/step - loss: 0.8931 - accuracy: 0.4320 Epoch 55/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8899 - accuracy: 0.4320 Epoch 56/100 32/32 [==============================] - 0s 4ms/step - loss: 0.8867 - accuracy: 0.4320 Epoch 57/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8836 - accuracy: 0.4310 Epoch 58/100 32/32 [==============================] - 0s 4ms/step - loss: 0.8805 - accuracy: 0.4310 Epoch 59/100 32/32 [==============================] - 0s 4ms/step - loss: 0.8776 - accuracy: 0.4310 Epoch 60/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8747 - accuracy: 0.4300 Epoch 
61/100 32/32 [==============================] - 0s 4ms/step - loss: 0.8719 - accuracy: 0.4300 Epoch 62/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8692 - accuracy: 0.4280 Epoch 63/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8664 - accuracy: 0.4280 Epoch 64/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.4260 Epoch 65/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8612 - accuracy: 0.4260 Epoch 66/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8585 - accuracy: 0.4250 Epoch 67/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8560 - accuracy: 0.4260 Epoch 68/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8535 - accuracy: 0.4250 Epoch 69/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8510 - accuracy: 0.4250 Epoch 70/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8486 - accuracy: 0.4240 Epoch 71/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8461 - accuracy: 0.4240 Epoch 72/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8437 - accuracy: 0.4240 Epoch 73/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8414 - accuracy: 0.4220 Epoch 74/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8391 - accuracy: 0.4210 Epoch 75/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8368 - accuracy: 0.4200 Epoch 76/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8346 - accuracy: 0.4170 Epoch 77/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8324 - accuracy: 0.4170 Epoch 78/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8302 - accuracy: 0.4150 Epoch 79/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8280 - accuracy: 0.4150 Epoch 80/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8259 - accuracy: 0.4150 Epoch 
81/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8238 - accuracy: 0.4140 Epoch 82/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8218 - accuracy: 0.4140 Epoch 83/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8198 - accuracy: 0.4140 Epoch 84/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8178 - accuracy: 0.4140 Epoch 85/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8159 - accuracy: 0.4120 Epoch 86/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8140 - accuracy: 0.4130 Epoch 87/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8121 - accuracy: 0.4130 Epoch 88/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8102 - accuracy: 0.4140 Epoch 89/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8084 - accuracy: 0.4150 Epoch 90/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8066 - accuracy: 0.4140 Epoch 91/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8048 - accuracy: 0.4150 Epoch 92/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8031 - accuracy: 0.4170 Epoch 93/100 32/32 [==============================] - 0s 3ms/step - loss: 0.8013 - accuracy: 0.4180 Epoch 94/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7996 - accuracy: 0.4200 Epoch 95/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7979 - accuracy: 0.4200 Epoch 96/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7962 - accuracy: 0.4220 Epoch 97/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7946 - accuracy: 0.4220 Epoch 98/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7930 - accuracy: 0.4250 Epoch 99/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7913 - accuracy: 0.4290 Epoch 100/100 32/32 [==============================] - 0s 3ms/step - loss: 0.7896 - accuracy: 0.4300
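Okay, our model performs a little worse than guessing (it finishes at 0.43 accuracy, below the roughly 50% we'd expect from randomly guessing between two evenly sized classes).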
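Two ideas help read those logs. First, binary cross-entropy assumes the model outputs probabilities in [0, 1], but model_4's output layer has no activation, so its raw outputs can land anywhere. Keras's binary_crossentropy (with its default from_logits=False) clips predictions into a small range around [0, 1] to avoid log(0), so outputs on the wrong side get a large penalty; this is one likely reason for the high early loss values. Second, with this loss Keras typically resolves the "accuracy" metric to binary accuracy, which thresholds predictions at 0.5 before comparing them with the labels. The pure-Python sketch below illustrates both under those assumptions. The epsilon of 1e-7 matches Keras's default backend epsilon, but the clipping is simplified, and the labels and outputs are made up.

```python
import math

EPSILON = 1e-7  # assumption: mirrors Keras's default backend epsilon

def binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy, clipping predictions into [EPSILON, 1 - EPSILON]."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, EPSILON), 1 - EPSILON)  # keep log() finite
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions matching the label after thresholding."""
    hits = sum(int((p > threshold) == bool(t)) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)

labels = [0, 1, 0, 1]
probabilities = [0.1, 0.9, 0.2, 0.8]        # well-behaved outputs in [0, 1]
raw_linear_outputs = [1.7, -0.4, 2.3, 0.6]  # unbounded outputs, no sigmoid

print(binary_crossentropy(labels, probabilities))       # small loss
print(binary_crossentropy(labels, raw_linear_outputs))  # much larger loss
print(binary_accuracy(labels, raw_linear_outputs))

# Baseline: always predicting class 1 on these balanced labels scores 0.5
print(binary_accuracy(labels, [1.0] * len(labels)))
```

Comparing against a simple baseline like "always predict one class" is a quick sanity check before reading too much into an accuracy number.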
Let's remind ourselves what our data looks like.
# Check out our data
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdYlBu);
And let's see how our model is making predictions on it.
# Check the decision boundary (blue is blue class, yellow is the crossover, red is red class)
plot_decision_boundary(model_4, X, y)
313/313 [==============================] - 0s 1ms/step doing binary classifcation...
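plot_decision_boundary was defined earlier in the notebook. Conceptually, a decision-boundary plot evaluates the model on a dense grid of points covering the data, thresholds each prediction into a class, and colours the grid by class. (The 313/313 above is consistent with predicting on a 100 x 100 grid of 10,000 points in Keras's default predict batches of 32.) The pure-Python sketch below builds such a grid without plotting. The name class_grid and the toy circular classifier are hypothetical stand-ins for the real helper and for model_4.predict; a full version would typically use NumPy's meshgrid and matplotlib's contourf.

```python
def class_grid(predict_fn, x_min, x_max, y_min, y_max, steps=5, threshold=0.5):
    """Evaluate predict_fn on a steps x steps grid and threshold into 0/1 labels.

    Rows run from y_min (first row) to y_max (last row);
    columns run from x_min to x_max.
    """
    xs = [x_min + i * (x_max - x_min) / (steps - 1) for i in range(steps)]
    ys = [y_min + j * (y_max - y_min) / (steps - 1) for j in range(steps)]
    return [[int(predict_fn(x, y) > threshold) for x in xs] for y in ys]

def toy_circle_model(x, y):
    """Hypothetical stand-in for a trained model: 'probability' 1 inside radius 0.6."""
    return 1.0 if x * x + y * y < 0.36 else 0.0

# Print the grid top row first, so larger y values appear at the top
for row in reversed(class_grid(toy_circle_model, -1.0, 1.0, -1.0, 1.0)):
    print("".join("#" if label else "." for label in row))
```

A purely linear model can only split such a grid with a straight line, which is why it can't separate one circle of points from another surrounding it.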