05. Transfer Learning with TensorFlow Part 2: Fine-tuning¶
In the previous section, we saw how we could leverage feature extraction transfer learning to get far better results on our Food Vision project than building our own models (even with less data).
Now we're going to cover another type of transfer learning: fine-tuning.
In fine-tuning transfer learning, the pre-trained model weights from another model are unfrozen and tweaked during training to better suit your own data.
For feature extraction transfer learning, you may only train the top 1-3 layers of a pre-trained model with your own data, whereas in fine-tuning transfer learning, you might train 1-3+ layers of a pre-trained model (where the '+' indicates that many or all of the layers could be trained).
Feature extraction transfer learning vs. fine-tuning transfer learning. The main difference between the two is that in fine-tuning, more layers of the pre-trained model get unfrozen and tuned on custom data. This fine-tuning usually takes more data than feature extraction to be effective.
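To make the distinction concrete, here's a minimal sketch of what freezing and unfreezing looks like in Keras (illustrative only, we'll build and train real models later in this notebook, and the choice of 10 layers here is arbitrary):

import tensorflow as tf

# Hypothetical pre-trained base model (we create a real one later on)
base_model = tf.keras.applications.EfficientNetB0(include_top=False)

# Feature extraction: freeze ALL of the base model's layers,
# only the new layers you stack on top get trained
base_model.trainable = False

# Fine-tuning: unfreeze some (or all) of the base model's layers
# so their weights get updated on your own data
base_model.trainable = True
for layer in base_model.layers[:-10]:  # keep everything except the last 10 layers frozen
    layer.trainable = False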
What we're going to cover¶
We're going to go through the following with TensorFlow:
- Introduce fine-tuning, a type of transfer learning to modify a pre-trained model to be more suited to your data
- Using the Keras Functional API (a different way to build models in Keras)
- Using a smaller dataset to experiment faster (e.g. 1-10% of training samples of 10 classes of food)
- Data augmentation (how to make your training dataset more diverse without adding more data)
- Running a series of modelling experiments on our Food Vision data
- Model 0: a transfer learning model using the Keras Functional API
- Model 1: a feature extraction transfer learning model on 1% of the data with data augmentation
- Model 2: a feature extraction transfer learning model on 10% of the data with data augmentation
- Model 3: a fine-tuned transfer learning model on 10% of the data
- Model 4: a fine-tuned transfer learning model on 100% of the data
- Introduce the ModelCheckpoint callback to save intermediate training results
- Compare model experiments results using TensorBoard
How you can use this notebook¶
You can read through the descriptions and the code (it should all run, except for the cells which error on purpose), but there's a better option.
Write all of the code yourself.
Yes. I'm serious. Create a new notebook and rewrite each line yourself. Investigate the code, see if you can break it, and if it breaks, work out why.
You don't have to write the text descriptions but writing the code yourself is a great way to get hands-on experience.
Don't worry if you make mistakes, we all do. The way to get better and make fewer mistakes is to write more code.
import datetime
print(f"Notebook last run (end-to-end): {datetime.datetime.now()}")
Notebook last run (end-to-end): 2023-05-12 08:01:58.291253
Note: As of TensorFlow 2.10+ there seem to be issues with the tf.keras.applications.efficientnet models (used later on) when loading weights via the load_weights() method. To get around this error you can use TensorFlow 2.9.0 (though it may be fixed in future versions of TensorFlow). See more on the course GitHub issue.
# Note: Issues with TensorFlow 2.10+, however, TensorFlow 2.9 seems to work better
# This may be fixed in the future.
# see: https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB0
# and here: https://github.com/keras-team/keras/issues/16983
# Install TensorFlow 2.9.0 ("-U" stands for "upgrade", "-q" stands for "quiet")
!pip install -U -q tensorflow==2.9.0
import tensorflow as tf
print(f"TensorFlow version: {tf.__version__}")
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 511.7/511.7 MB 3.0 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 63.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.6/42.6 kB 4.3 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.8/5.8 MB 98.9 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 438.7/438.7 kB 32.2 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 50.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.9/4.9 MB 66.8 MB/s eta 0:00:00 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. tensorflow-datasets 4.9.2 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible. tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.19.6 which is incompatible. TensorFlow version: 2.9.0
# Are we using a GPU? (if not & you're using Google Colab, go to Runtime -> Change Runtime Type -> Hardware Accelerator: GPU)
!nvidia-smi
Fri May 12 08:01:58 2023 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |===============================+======================+======================| | 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | | N/A 42C P8 9W / 70W | 0MiB / 15360MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
Creating helper functions¶
Throughout your machine learning experiments, you'll likely come across snippets of code you want to use over and over again.
For example, a plotting function which plots a model's history object (see plot_loss_curves() below).
You could recreate these functions over and over again.
But as you might've guessed, rewriting the same functions becomes tedious.
One of the solutions is to store them in a helper script such as helper_functions.py and then import the necessary functionality when you need it.
For example, you might write:
from helper_functions import plot_loss_curves
...
plot_loss_curves(history)
Let's see what this looks like.
# Get helper_functions.py script from course GitHub
!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
# Import helper functions we're going to use
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir
--2023-05-12 08:01:58-- https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 10246 (10K) [text/plain] Saving to: ‘helper_functions.py’ helper_functions.py 100%[===================>] 10.01K --.-KB/s in 0s 2023-05-12 08:01:58 (90.4 MB/s) - ‘helper_functions.py’ saved [10246/10246]
Wonderful, now we've got a bunch of helper functions we can use throughout the notebook without having to rewrite them from scratch each time.
🔑 Note: If you're running this notebook in Google Colab, when it times out Colab will delete the helper_functions.py file. So to use the functions imported above, you'll have to rerun the cell.
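If you'd rather not rely on remembering to rerun it, an optional guard like this (hypothetical, not part of the course code) re-downloads the script only when it's missing:

import os

# Re-download helper_functions.py if Colab has deleted it (e.g. after a timeout)
if not os.path.exists("helper_functions.py"):
    !wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py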
10 Food Classes: Working with less data¶
We saw in the previous notebook that we could get great results with only 10% of the training data using transfer learning with TensorFlow Hub.
In this notebook, we're going to continue working with smaller subsets of the data, except this time we'll have a look at how we can use the in-built pretrained models within the tf.keras.applications module, as well as how to fine-tune them to our own custom dataset.
We'll also practice using a new but similar data loader function to the one we've used before, image_dataset_from_directory(), which is part of the tf.keras.preprocessing module.
Finally, we'll also be practicing using the Keras Functional API for building deep learning models. The Functional API is a more flexible way to create models than the tf.keras.Sequential API.
We'll explore each of these in more detail as we go.
Let's start by downloading some data.
# Get 10% of the data of the 10 classes
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip
unzip_data("10_food_classes_10_percent.zip")
--2023-05-12 08:01:59-- https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.97.128, 74.125.203.128, 74.125.204.128, ... Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.97.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 168546183 (161M) [application/zip] Saving to: ‘10_food_classes_10_percent.zip’ 10_food_classes_10_ 100%[===================>] 160.74M 26.9MB/s in 7.0s 2023-05-12 08:02:06 (23.1 MB/s) - ‘10_food_classes_10_percent.zip’ saved [168546183/168546183]
The dataset we're downloading is the 10 food classes dataset (from Food 101) with 10% of the training images we used in the previous notebook.
🔑 Note: You can see how this dataset was created in the image data modification notebook.
# Walk through 10 percent data directory and list number of files
walk_through_dir("10_food_classes_10_percent")
There are 2 directories and 0 images in '10_food_classes_10_percent'. There are 10 directories and 0 images in '10_food_classes_10_percent/train'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/hamburger'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/steak'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/grilled_salmon'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_curry'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/sushi'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/ramen'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/ice_cream'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/pizza'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_wings'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/fried_rice'. There are 10 directories and 0 images in '10_food_classes_10_percent/test'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/hamburger'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/steak'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_curry'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/sushi'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/ramen'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/ice_cream'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/pizza'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_wings'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/fried_rice'.
We can see that each of the training directories contains 75 images and each of the testing directories contains 250 images.
Let's define our training and test filepaths.
# Create training and test directories
train_dir = "10_food_classes_10_percent/train/"
test_dir = "10_food_classes_10_percent/test/"
Now we've got some image data, we need a way of loading it into a TensorFlow-compatible format.
Previously, we've used the ImageDataGenerator class. And while this works well and is still very commonly used, this time we're going to use the image_dataset_from_directory() function.
It works much the same way as ImageDataGenerator's flow_from_directory method, meaning your images need to be in the following file format:
Example of file structure
10_food_classes_10_percent <- top level folder
└───train <- training images
│ └───pizza
│ │ │ 1008104.jpg
│ │ │ 1638227.jpg
│ │ │ ...
│ └───steak
│ │ 1000205.jpg
│ │ 1647351.jpg
│ │ ...
│
└───test <- testing images
│ └───pizza
│ │ │ 1001116.jpg
│ │ │ 1507019.jpg
│ │ │ ...
│ └───steak
│ │ 100274.jpg
│ │ 1653815.jpg
│ │ ...
One of the main benefits of using tf.keras.preprocessing.image_dataset_from_directory() rather than ImageDataGenerator is that it creates a tf.data.Dataset object rather than a generator. The main advantage of this is that the tf.data.Dataset API is much more efficient (faster) than the ImageDataGenerator API, which is paramount for larger datasets.
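To give a feel for where that speed comes from, here's a minimal, hypothetical sketch of the kind of pipeline optimizations the tf.data.Dataset API offers (our loaders below work fine without these, but they matter on bigger datasets):

import tensorflow as tf

# A toy dataset standing in for our image data
example_dataset = tf.data.Dataset.from_tensor_slices(tf.range(1000))

# tf.data.Dataset methods let TensorFlow overlap data loading with model training
example_dataset = (example_dataset
                   .batch(32)                    # group samples into batches of 32
                   .cache()                      # keep samples in memory after the first pass
                   .prefetch(tf.data.AUTOTUNE))  # prepare the next batch(es) ahead of time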
Let's see it in action.
# Create data inputs
import tensorflow as tf
IMG_SIZE = (224, 224) # define image size
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(directory=train_dir,
image_size=IMG_SIZE,
label_mode="categorical", # what type are the labels?
batch_size=32) # batch_size is 32 by default, this is generally a good number
test_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
image_size=IMG_SIZE,
label_mode="categorical")
Found 750 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
Wonderful! Looks like our dataloaders have found the correct number of images for each dataset.
For now, the main parameters we're concerned about in the image_dataset_from_directory() function are:
- directory - the filepath of the target directory we're loading images in from.
- image_size - the target size of the images we're going to load in (height, width).
- batch_size - the batch size of the images we're going to load in. For example, if the batch_size is 32 (the default), batches of 32 images and labels at a time will be passed to the model.
There are more we could play around with if we needed to in the tf.keras.preprocessing documentation.
If we check the training data datatype we should see it as a BatchDataset with shapes relating to our data.
# Check the training data datatype
train_data_10_percent
<BatchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>
In the above output:
- (None, 224, 224, 3) refers to the tensor shape of our images where None is the batch size, 224 is the height (and width) and 3 is the color channels (red, green, blue).
- (None, 10) refers to the tensor shape of the labels where None is the batch size and 10 is the number of possible labels (the 10 different food classes).
- Both image tensors and labels are of the datatype tf.float32.
The batch_size is None due to it only being used during model training. You can think of None as a placeholder waiting to be filled with the batch_size parameter from image_dataset_from_directory().
Another benefit of using the tf.data.Dataset API is the associated methods which come with it.
For example, if we want to find the names of the classes we're working with, we can use the class_names attribute.
# Check out the class names of our dataset
train_data_10_percent.class_names
['chicken_curry', 'chicken_wings', 'fried_rice', 'grilled_salmon', 'hamburger', 'ice_cream', 'pizza', 'ramen', 'steak', 'sushi']
Or if we wanted to see an example batch of data, we could use the take() method.
# See an example batch of data
for images, labels in train_data_10_percent.take(1):
print(images, labels)
tf.Tensor( [[[[1.26438774e+02 7.42091827e+01 2.76683655e+01] [1.58372452e+02 1.03438782e+02 4.90357170e+01] [1.59454086e+02 1.01163269e+02 3.66581650e+01] ... [5.59999771e+01 1.42755041e+01 7.85714340e+00] [5.03826332e+01 1.09489832e+01 4.09691763e+00] [4.97550926e+01 1.08725061e+01 3.46937799e+00]] [[7.62091751e+01 2.67091808e+01 9.18363333e-02] [1.17862244e+02 6.69285736e+01 1.99948978e+01] [1.40668365e+02 8.53673477e+01 2.80816307e+01] ... [9.07142944e+01 3.61428604e+01 2.92551079e+01] [8.44897766e+01 3.19336510e+01 2.40611992e+01] [8.15918350e+01 2.90969448e+01 2.23570766e+01]] [[7.56530609e+01 2.63622475e+01 9.79592502e-01] [1.16301025e+02 6.44438858e+01 2.46581688e+01] [1.44556122e+02 9.01938782e+01 3.92908173e+01] ... [1.13260178e+02 4.56173210e+01 3.64030342e+01] [1.09301033e+02 4.44438934e+01 3.54438934e+01] [1.14005219e+02 4.91480713e+01 4.01480713e+01]] ... [[2.14009933e+02 1.74305740e+02 1.49020035e+02] [2.27316193e+02 1.90760010e+02 1.63918167e+02] [2.43117279e+02 2.09234528e+02 1.80974350e+02] ... [1.51877243e+02 7.43057785e+01 5.36629868e+01] [1.57622421e+02 7.66224289e+01 5.96224289e+01] [1.73219696e+02 9.12196884e+01 7.72196884e+01]] [[1.62892624e+02 1.12744644e+02 8.56272736e+01] [1.96586380e+02 1.51091431e+02 1.22290428e+02] [2.24224228e+02 1.80765030e+02 1.51566055e+02] ... [1.47790298e+02 6.97903061e+01 4.93617744e+01] [1.74857300e+02 9.58572998e+01 7.91328583e+01] [1.94315231e+02 1.12315231e+02 1.00315231e+02]] [[1.50694168e+02 9.49798050e+01 6.53369827e+01] [1.48933365e+02 9.51936035e+01 6.60557556e+01] [1.72203308e+02 1.23907364e+02 9.41981964e+01] ... [1.56739639e+02 7.87396469e+01 5.83111191e+01] [1.63668152e+02 8.46681595e+01 6.96681595e+01] [1.69540970e+02 8.75409698e+01 7.55409698e+01]]] [[[8.05612278e+00 8.05612278e+00 4.13265377e-01] [8.64285755e+00 8.64285755e+00 6.42857194e-01] [1.03622456e+01 1.03622456e+01 2.36224508e+00] ... [1.00000000e+00 1.00000000e+00 1.00000000e+00] [1.07144165e+00 1.07144165e+00 1.07144165e+00] [2.00000000e+00 2.00000000e+00 2.00000000e+00]] [[8.07142830e+00 8.07142830e+00 7.14282990e-02] [8.13775444e+00 8.13775444e+00 1.37754560e-01] [9.08673477e+00 9.08673477e+00 1.08673441e+00] ... [1.00000000e+00 1.00000000e+00 0.00000000e+00] [2.00000000e+00 2.00000000e+00 0.00000000e+00] [2.00000000e+00 2.00000000e+00 0.00000000e+00]] [[1.20000000e+01 9.00000000e+00 1.57142830e+00] [1.20000000e+01 9.00000000e+00 1.57142830e+00] [1.20000000e+01 9.00000000e+00 1.57142830e+00] ... [2.00000000e+00 3.00000000e+00 0.00000000e+00] [2.27041864e+00 3.27041864e+00 0.00000000e+00] [3.00000000e+00 4.00000000e+00 0.00000000e+00]] ... [[8.29439163e+01 4.13011284e+01 1.35153913e+01] [7.82704926e+01 3.66277046e+01 8.84196854e+00] [8.28623657e+01 4.07144623e+01 1.19541960e+01] ... [2.43117233e+02 2.02612167e+02 4.68007126e+01] [2.38642975e+02 1.93357269e+02 4.39745407e+01] [2.27362091e+02 1.81989761e+02 4.54284058e+01]] [[8.47142181e+01 4.37142220e+01 1.37142210e+01] [9.19336090e+01 5.09336090e+01 2.08009300e+01] [8.48978882e+01 4.35254517e+01 1.13264341e+01] ... [2.43586838e+02 2.04729660e+02 4.09592323e+01] [2.35561142e+02 1.93505005e+02 3.18469086e+01] [2.25719376e+02 1.85597015e+02 3.15714722e+01]] [[6.70813370e+01 2.80813332e+01 1.23972869e+00] [7.28517151e+01 3.18517189e+01 2.02029824e+00] [8.16430359e+01 4.10001755e+01 5.71953487e+00] ... 
[2.34586792e+02 1.92464203e+02 4.54591522e+01] [2.30785675e+02 1.90785675e+02 3.30254860e+01] [2.26260101e+02 1.88260101e+02 2.80865364e+01]]] [[[0.00000000e+00 0.00000000e+00 0.00000000e+00] [0.00000000e+00 0.00000000e+00 0.00000000e+00] [2.57142878e+00 2.57142878e+00 2.57142878e+00] ... [4.58674622e+01 4.58674622e+01 4.64287033e+01] [5.46173553e+01 5.46173553e+01 5.66173553e+01] [5.55714645e+01 5.55714645e+01 5.75714645e+01]] [[1.00000000e+00 1.00000000e+00 1.00000000e+00] [1.00000000e+00 1.00000000e+00 1.00000000e+00] [3.19897962e+00 3.19897962e+00 3.19897962e+00] ... [4.95153694e+01 4.95153694e+01 5.11174507e+01] [5.31428566e+01 5.31428566e+01 5.51428566e+01] [5.42397995e+01 5.42397995e+01 5.62397995e+01]] [[1.00000000e+00 1.00000000e+00 1.00000000e+00] [1.00000000e+00 1.00000000e+00 1.00000000e+00] [2.38265324e+00 2.38265324e+00 2.38265324e+00] ... [5.15714722e+01 5.15714722e+01 5.35714722e+01] [5.42704201e+01 5.34847031e+01 5.78418465e+01] [5.73571777e+01 5.65714645e+01 6.09286041e+01]] ... [[1.88494873e+02 1.90494873e+02 1.87494873e+02] [1.88857132e+02 1.90857132e+02 1.87857132e+02] [1.87999985e+02 1.89999985e+02 1.88999985e+02] ... [1.58999969e+02 1.59999969e+02 1.63999969e+02] [1.60000000e+02 1.61000000e+02 1.65000000e+02] [1.57285645e+02 1.58285645e+02 1.62285645e+02]] [[1.85714264e+02 1.87714264e+02 1.84714264e+02] [1.87142838e+02 1.89142838e+02 1.86142838e+02] [1.86683655e+02 1.88683655e+02 1.87683655e+02] ... [1.57285675e+02 1.58285675e+02 1.62285675e+02] [1.59071411e+02 1.60071411e+02 1.64071411e+02] [1.57285645e+02 1.58285645e+02 1.62285645e+02]] [[1.86056122e+02 1.88056122e+02 1.85056122e+02] [1.87071426e+02 1.89071426e+02 1.86071426e+02] [1.88637756e+02 1.90637756e+02 1.89637756e+02] ... [1.56076523e+02 1.57076523e+02 1.61076523e+02] [1.56357147e+02 1.57357147e+02 1.61357147e+02] [1.57000000e+02 1.58000000e+02 1.62000000e+02]]] ... [[[5.76428566e+01 1.16428576e+01 2.26428566e+01] [5.60379448e+01 1.30379467e+01 2.20379467e+01] [5.31004486e+01 1.20000000e+01 1.86294651e+01] ... [1.35033386e+02 9.16136780e+01 5.73749161e+01] [9.58234482e+01 4.86760979e+01 2.28904228e+01] [8.35336380e+01 3.25336380e+01 1.15336380e+01]] [[5.64308014e+01 1.04308033e+01 2.14308033e+01] [5.69575920e+01 1.39575891e+01 2.29575901e+01] [5.30000000e+01 1.20000000e+01 1.84285717e+01] ... [1.32890579e+02 8.56650696e+01 5.24017639e+01] [9.76650085e+01 4.64194641e+01 2.21851196e+01] [8.73862381e+01 3.53862381e+01 1.43862381e+01]] [[5.65200882e+01 1.05200891e+01 2.22343750e+01] [5.53437500e+01 1.23437500e+01 2.01718750e+01] [5.24843750e+01 1.14843750e+01 1.66004467e+01] ... [1.21912888e+02 7.75244293e+01 4.84218597e+01] [9.08257904e+01 4.34039078e+01 1.99061756e+01] [8.74799576e+01 3.91049576e+01 2.03192673e+01]] ... [[6.66116104e+01 5.12544632e+01 5.93258934e+01] [6.95535736e+01 5.45535736e+01 5.77901802e+01] [6.95937500e+01 5.77834816e+01 5.41651764e+01] ... [1.72537857e+02 1.09895065e+02 5.11093330e+01] [1.67843658e+02 1.05843658e+02 4.88436584e+01] [1.55285645e+02 9.32856445e+01 3.58168488e+01]] [[5.40290184e+01 3.90290184e+01 4.39308052e+01] [5.78370552e+01 4.28370552e+01 4.59218750e+01] [6.19464302e+01 4.79464302e+01 4.73526802e+01] ... [1.71616028e+02 1.09187500e+02 4.97589722e+01] [1.66903931e+02 1.04903931e+02 4.79039307e+01] [1.55526779e+02 9.25267792e+01 3.85267792e+01]] [[3.48727684e+01 2.03415184e+01 2.19352684e+01] [3.52544632e+01 2.02544632e+01 2.32544632e+01] [3.74263382e+01 2.30691967e+01 2.67924099e+01] ... 
[1.71580292e+02 1.09151764e+02 4.96227989e+01] [1.60207535e+02 9.77008286e+01 4.22209358e+01] [1.59042572e+02 9.60425644e+01 4.30425644e+01]]] [[[1.10000000e+02 6.32397957e+01 0.00000000e+00] [1.18780609e+02 7.49744949e+01 7.37755108e+00] [1.06168365e+02 6.54438782e+01 1.02551079e+00] ... [9.51429443e+01 5.70000000e+01 2.21426392e+00] [1.04714348e+02 6.15000267e+01 8.57146740e+00] [1.10357109e+02 6.33571091e+01 9.35710812e+00]] [[1.14693878e+02 7.08877563e+01 4.36224604e+00] [1.14423470e+02 7.30051041e+01 7.14285660e+00] [1.10346939e+02 7.12346954e+01 6.29081631e+00] ... [9.44694366e+01 5.59285736e+01 1.34179497e+00] [1.00285751e+02 5.70714264e+01 4.14286995e+00] [1.07311264e+02 5.74490089e+01 4.40309334e+00]] [[1.07239792e+02 6.82346954e+01 7.65306503e-02] [1.13382652e+02 7.48112259e+01 7.59694004e+00] [1.13469383e+02 7.48520432e+01 9.94387722e+00] ... [9.81888351e+01 5.96173630e+01 5.04589176e+00] [1.02857155e+02 5.68571548e+01 4.85715580e+00] [1.09576607e+02 5.95766068e+01 6.57660770e+00]] ... [[9.08060455e+01 6.93775177e+01 2.37751961e+00] [9.02449646e+01 6.91990509e+01 4.21435165e+00] [8.26632767e+01 6.32806396e+01 9.18274522e-02] ... [5.38824654e+01 2.38824654e+01 1.11722851e+00] [4.99592323e+01 2.15307064e+01 0.00000000e+00] [5.27040253e+01 2.40612335e+01 2.35202789e+00]] [[8.45306549e+01 6.55765839e+01 2.39839837e-01] [8.84488907e+01 7.05815887e+01 4.51524401e+00] [8.39744873e+01 6.90051041e+01 1.98979414e+00] ... [4.98469048e+01 2.14183769e+01 6.12294711e-02] [5.55152206e+01 2.75152206e+01 3.64789557e+00] [5.30714417e+01 2.50714417e+01 3.07144165e+00]] [[8.37753220e+01 6.57753220e+01 1.10198092e+00] [8.21327438e+01 6.69949951e+01 6.63330257e-01] [8.74032364e+01 7.44032364e+01 6.40323353e+00] ... [5.51223755e+01 2.66938477e+01 2.69384837e+00] [4.85256996e+01 2.05256996e+01 1.73511118e-01] [5.52040520e+01 2.92040520e+01 6.20405197e+00]]] [[[3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] ... [3.60000000e+01 1.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.90000000e+01 0.00000000e+00 0.00000000e+00]] [[3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] ... [3.75714722e+01 2.14263916e-01 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00]] [[3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] [3.60000000e+01 0.00000000e+00 0.00000000e+00] ... [3.80459137e+01 0.00000000e+00 0.00000000e+00] [3.75714264e+01 2.14285851e-01 0.00000000e+00] [3.60000000e+01 1.00000000e+00 0.00000000e+00]] ... [[3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] ... [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00]] [[3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] ... [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00]] [[3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] ... 
[3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00] [3.80000000e+01 0.00000000e+00 0.00000000e+00]]]], shape=(32, 224, 224, 3), dtype=float32) tf.Tensor( [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]], shape=(32, 10), dtype=float32)
Notice how the image arrays come out as tensors of pixel values whereas the labels come out as one-hot encodings (e.g. [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] for hamburger).
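If you'd like to see those one-hot labels in human-readable form, a quick optional trick (using the labels tensor from the batch above) is to combine tf.argmax with the class_names attribute we saw earlier:

# Turn the one-hot labels from the batch above back into class names
class_names = train_data_10_percent.class_names
label_indexes = tf.argmax(labels, axis=1).numpy()  # position of the "1" in each one-hot vector
print([class_names[i] for i in label_indexes])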
Model 0: Building a transfer learning model using the Keras Functional API¶
Alright, our data is tensor-ified, let's build a model.
To do so we're going to be using the tf.keras.applications module, as it contains a series of already trained (on ImageNet) computer vision models, as well as the Keras Functional API to construct our model.
We're going to go through the following steps:
1. Instantiate a pre-trained base model object by choosing a target model such as EfficientNetB0 from tf.keras.applications, setting the include_top parameter to False (we do this because we're going to create our own top, which are the output layers for the model).
2. Set the base model's trainable attribute to False to freeze all of the weights in the pre-trained model.
3. Define an input layer for our model, for example, what shape of data should our model expect?
4. [Optional] Normalize the inputs to our model if it requires it. Some computer vision models such as ResNet50V2 require their inputs to be between 0 & 1.

🤔 Note: As of writing, the EfficientNet models in the tf.keras.applications module do not require images to be normalized (pixel values between 0 and 1) on input, whereas many of the other models do. I posted an issue to the TensorFlow GitHub about this and they confirmed it.

5. Pass the inputs to the base model.
6. Pool the outputs of the base model into a shape compatible with the output activation layer (turn base model output tensors into the same shape as the label tensors). This can be done using tf.keras.layers.GlobalAveragePooling2D() or tf.keras.layers.GlobalMaxPooling2D(), though the former is more common in practice.
7. Create an output activation layer using tf.keras.layers.Dense() with the appropriate activation function and number of neurons.
8. Combine the inputs and outputs layer into a model using tf.keras.Model().
9. Compile the model using the appropriate loss function and choice of optimizer.
10. Fit the model for the desired number of epochs and with the necessary callbacks (in our case, we'll start off with the TensorBoard callback).
Woah... that sounds like a lot. Before we get ahead of ourselves, let's see it in practice.
# 1. Create base model with tf.keras.applications
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
# 2. Freeze the base model (so the pre-learned patterns remain)
base_model.trainable = False
# 3. Create inputs into the base model
inputs = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")
# 4. If using ResNet50V2, add this to speed up convergence, remove for EfficientNet
# x = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)(inputs)
# 5. Pass the inputs to the base_model (note: using tf.keras.applications, EfficientNet inputs don't have to be normalized)
x = base_model(inputs)
# Check data shape after passing it to base_model
print(f"Shape after base_model: {x.shape}")
# 6. Average pool the outputs of the base model (aggregate all the most important information, reduce number of computations)
x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
print(f"After GlobalAveragePooling2D(): {x.shape}")
# 7. Create the output activation layer
outputs = tf.keras.layers.Dense(10, activation="softmax", name="output_layer")(x)
# 8. Combine the inputs with the outputs into a model
model_0 = tf.keras.Model(inputs, outputs)
# 9. Compile the model
model_0.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=["accuracy"])
# 10. Fit the model (we use less steps for validation so it's faster)
history_10_percent = model_0.fit(train_data_10_percent,
epochs=5,
steps_per_epoch=len(train_data_10_percent),
validation_data=test_data_10_percent,
# Go through less of the validation data so epochs are faster (we want faster experiments!)
validation_steps=int(0.25 * len(test_data_10_percent)),
# Track our model's training logs for visualization later
callbacks=[create_tensorboard_callback("transfer_learning", "10_percent_feature_extract")])
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5 16705208/16705208 [==============================] - 1s 0us/step Shape after base_model: (None, 7, 7, 1280) After GlobalAveragePooling2D(): (None, 1280) Saving TensorBoard log files to: transfer_learning/10_percent_feature_extract/20230512-080215 Epoch 1/5 24/24 [==============================] - 23s 264ms/step - loss: 1.8800 - accuracy: 0.4133 - val_loss: 1.3234 - val_accuracy: 0.7270 Epoch 2/5 24/24 [==============================] - 4s 133ms/step - loss: 1.0934 - accuracy: 0.7653 - val_loss: 0.8853 - val_accuracy: 0.8076 Epoch 3/5 24/24 [==============================] - 4s 135ms/step - loss: 0.7876 - accuracy: 0.8293 - val_loss: 0.7109 - val_accuracy: 0.8372 Epoch 4/5 24/24 [==============================] - 3s 128ms/step - loss: 0.6420 - accuracy: 0.8520 - val_loss: 0.6223 - val_accuracy: 0.8487 Epoch 5/5 24/24 [==============================] - 3s 127ms/step - loss: 0.5426 - accuracy: 0.8867 - val_loss: 0.5699 - val_accuracy: 0.8569
Nice! After a minute or so of training, our model performs incredibly well on both the training set (87%+ accuracy) and the test set (~86% accuracy).
This is incredible. All thanks to the power of transfer learning.
It's important to note the kind of transfer learning we used here is called feature extraction transfer learning, similar to what we did with the TensorFlow Hub models.
In other words, we passed our custom data to an already pre-trained model (EfficientNetB0), asked it "what patterns do you see?" and then put our own output layer on top to make sure the outputs were tailored to our desired number of classes.
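If you want to confirm the base model really is frozen, a quick optional check is to count the trainable variables, only the new Dense output layer's kernel and bias should remain:

# Check that only the new output layer is trainable (a Dense layer = kernel + bias = 2 variables)
print(f"Base model trainable: {base_model.trainable}")  # should be False
print(f"Trainable variables in model_0: {len(model_0.trainable_variables)}")  # should be 2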
We also used the Keras Functional API to build our model rather than the Sequential API. For now, the benefits of this may not seem clear, but when you start to build more sophisticated models, you'll probably want to use the Functional API. So it's important to have exposure to this way of building models.
📖 Resource: To see the benefits and use cases of the Functional API versus the Sequential API, check out the TensorFlow Functional API documentation.
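As a taste of that flexibility, here's a minimal sketch (hypothetical and unrelated to Food Vision) of a model the Sequential API can't express: two inputs processed by separate branches and merged before a shared output layer:

import tensorflow as tf

# Two separate inputs (e.g. an image embedding and some tabular metadata)
input_a = tf.keras.layers.Input(shape=(128,), name="input_a")
input_b = tf.keras.layers.Input(shape=(16,), name="input_b")

# Process each input with its own branch of layers
x_a = tf.keras.layers.Dense(32, activation="relu")(input_a)
x_b = tf.keras.layers.Dense(8, activation="relu")(input_b)

# Merge the branches and create a single output
combined = tf.keras.layers.Concatenate()([x_a, x_b])
outputs = tf.keras.layers.Dense(10, activation="softmax")(combined)

# Multiple inputs, one output: not possible with tf.keras.Sequential
multi_input_model = tf.keras.Model(inputs=[input_a, input_b], outputs=outputs)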
Let's inspect the layers in our model, starting with the base.
# Check layers in our base model
for layer_number, layer in enumerate(base_model.layers):
print(layer_number, layer.name)
0 input_1 1 rescaling 2 normalization 3 tf.math.truediv 4 stem_conv_pad 5 stem_conv 6 stem_bn 7 stem_activation 8 block1a_dwconv 9 block1a_bn 10 block1a_activation 11 block1a_se_squeeze 12 block1a_se_reshape 13 block1a_se_reduce 14 block1a_se_expand 15 block1a_se_excite 16 block1a_project_conv 17 block1a_project_bn 18 block2a_expand_conv 19 block2a_expand_bn 20 block2a_expand_activation 21 block2a_dwconv_pad 22 block2a_dwconv 23 block2a_bn 24 block2a_activation 25 block2a_se_squeeze 26 block2a_se_reshape 27 block2a_se_reduce 28 block2a_se_expand 29 block2a_se_excite 30 block2a_project_conv 31 block2a_project_bn 32 block2b_expand_conv 33 block2b_expand_bn 34 block2b_expand_activation 35 block2b_dwconv 36 block2b_bn 37 block2b_activation 38 block2b_se_squeeze 39 block2b_se_reshape 40 block2b_se_reduce 41 block2b_se_expand 42 block2b_se_excite 43 block2b_project_conv 44 block2b_project_bn 45 block2b_drop 46 block2b_add 47 block3a_expand_conv 48 block3a_expand_bn 49 block3a_expand_activation 50 block3a_dwconv_pad 51 block3a_dwconv 52 block3a_bn 53 block3a_activation 54 block3a_se_squeeze 55 block3a_se_reshape 56 block3a_se_reduce 57 block3a_se_expand 58 block3a_se_excite 59 block3a_project_conv 60 block3a_project_bn 61 block3b_expand_conv 62 block3b_expand_bn 63 block3b_expand_activation 64 block3b_dwconv 65 block3b_bn 66 block3b_activation 67 block3b_se_squeeze 68 block3b_se_reshape 69 block3b_se_reduce 70 block3b_se_expand 71 block3b_se_excite 72 block3b_project_conv 73 block3b_project_bn 74 block3b_drop 75 block3b_add 76 block4a_expand_conv 77 block4a_expand_bn 78 block4a_expand_activation 79 block4a_dwconv_pad 80 block4a_dwconv 81 block4a_bn 82 block4a_activation 83 block4a_se_squeeze 84 block4a_se_reshape 85 block4a_se_reduce 86 block4a_se_expand 87 block4a_se_excite 88 block4a_project_conv 89 block4a_project_bn 90 block4b_expand_conv 91 block4b_expand_bn 92 block4b_expand_activation 93 block4b_dwconv 94 block4b_bn 95 block4b_activation 96 block4b_se_squeeze 97 block4b_se_reshape 98 block4b_se_reduce 99 block4b_se_expand 100 block4b_se_excite 101 block4b_project_conv 102 block4b_project_bn 103 block4b_drop 104 block4b_add 105 block4c_expand_conv 106 block4c_expand_bn 107 block4c_expand_activation 108 block4c_dwconv 109 block4c_bn 110 block4c_activation 111 block4c_se_squeeze 112 block4c_se_reshape 113 block4c_se_reduce 114 block4c_se_expand 115 block4c_se_excite 116 block4c_project_conv 117 block4c_project_bn 118 block4c_drop 119 block4c_add 120 block5a_expand_conv 121 block5a_expand_bn 122 block5a_expand_activation 123 block5a_dwconv 124 block5a_bn 125 block5a_activation 126 block5a_se_squeeze 127 block5a_se_reshape 128 block5a_se_reduce 129 block5a_se_expand 130 block5a_se_excite 131 block5a_project_conv 132 block5a_project_bn 133 block5b_expand_conv 134 block5b_expand_bn 135 block5b_expand_activation 136 block5b_dwconv 137 block5b_bn 138 block5b_activation 139 block5b_se_squeeze 140 block5b_se_reshape 141 block5b_se_reduce 142 block5b_se_expand 143 block5b_se_excite 144 block5b_project_conv 145 block5b_project_bn 146 block5b_drop 147 block5b_add 148 block5c_expand_conv 149 block5c_expand_bn 150 block5c_expand_activation 151 block5c_dwconv 152 block5c_bn 153 block5c_activation 154 block5c_se_squeeze 155 block5c_se_reshape 156 block5c_se_reduce 157 block5c_se_expand 158 block5c_se_excite 159 block5c_project_conv 160 block5c_project_bn 161 block5c_drop 162 block5c_add 163 block6a_expand_conv 164 block6a_expand_bn 165 block6a_expand_activation 166 block6a_dwconv_pad 167 block6a_dwconv 168 
block6a_bn 169 block6a_activation 170 block6a_se_squeeze 171 block6a_se_reshape 172 block6a_se_reduce 173 block6a_se_expand 174 block6a_se_excite 175 block6a_project_conv 176 block6a_project_bn 177 block6b_expand_conv 178 block6b_expand_bn 179 block6b_expand_activation 180 block6b_dwconv 181 block6b_bn 182 block6b_activation 183 block6b_se_squeeze 184 block6b_se_reshape 185 block6b_se_reduce 186 block6b_se_expand 187 block6b_se_excite 188 block6b_project_conv 189 block6b_project_bn 190 block6b_drop 191 block6b_add 192 block6c_expand_conv 193 block6c_expand_bn 194 block6c_expand_activation 195 block6c_dwconv 196 block6c_bn 197 block6c_activation 198 block6c_se_squeeze 199 block6c_se_reshape 200 block6c_se_reduce 201 block6c_se_expand 202 block6c_se_excite 203 block6c_project_conv 204 block6c_project_bn 205 block6c_drop 206 block6c_add 207 block6d_expand_conv 208 block6d_expand_bn 209 block6d_expand_activation 210 block6d_dwconv 211 block6d_bn 212 block6d_activation 213 block6d_se_squeeze 214 block6d_se_reshape 215 block6d_se_reduce 216 block6d_se_expand 217 block6d_se_excite 218 block6d_project_conv 219 block6d_project_bn 220 block6d_drop 221 block6d_add 222 block7a_expand_conv 223 block7a_expand_bn 224 block7a_expand_activation 225 block7a_dwconv 226 block7a_bn 227 block7a_activation 228 block7a_se_squeeze 229 block7a_se_reshape 230 block7a_se_reduce 231 block7a_se_expand 232 block7a_se_excite 233 block7a_project_conv 234 block7a_project_bn 235 top_conv 236 top_bn 237 top_activation
Wow, that's a lot of layers... to handcode all of those would've taken a fairly long time, yet we can still take advantage of them thanks to the power of transfer learning.
How about a summary of the base model?
base_model.summary()
Model: "efficientnetb0" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, None, None, 0 [] 3)] rescaling (Rescaling) (None, None, None, 0 ['input_1[0][0]'] 3) normalization (Normalization) (None, None, None, 7 ['rescaling[0][0]'] 3) tf.math.truediv (TFOpLambda) (None, None, None, 0 ['normalization[0][0]'] 3) stem_conv_pad (ZeroPadding2D) (None, None, None, 0 ['tf.math.truediv[0][0]'] 3) stem_conv (Conv2D) (None, None, None, 864 ['stem_conv_pad[0][0]'] 32) stem_bn (BatchNormalization) (None, None, None, 128 ['stem_conv[0][0]'] 32) stem_activation (Activation) (None, None, None, 0 ['stem_bn[0][0]'] 32) block1a_dwconv (DepthwiseConv2 (None, None, None, 288 ['stem_activation[0][0]'] D) 32) block1a_bn (BatchNormalization (None, None, None, 128 ['block1a_dwconv[0][0]'] ) 32) block1a_activation (Activation (None, None, None, 0 ['block1a_bn[0][0]'] ) 32) block1a_se_squeeze (GlobalAver (None, 32) 0 ['block1a_activation[0][0]'] agePooling2D) block1a_se_reshape (Reshape) (None, 1, 1, 32) 0 ['block1a_se_squeeze[0][0]'] block1a_se_reduce (Conv2D) (None, 1, 1, 8) 264 ['block1a_se_reshape[0][0]'] block1a_se_expand (Conv2D) (None, 1, 1, 32) 288 ['block1a_se_reduce[0][0]'] block1a_se_excite (Multiply) (None, None, None, 0 ['block1a_activation[0][0]', 32) 'block1a_se_expand[0][0]'] block1a_project_conv (Conv2D) (None, None, None, 512 ['block1a_se_excite[0][0]'] 16) block1a_project_bn (BatchNorma (None, None, None, 64 ['block1a_project_conv[0][0]'] lization) 16) block2a_expand_conv (Conv2D) (None, None, None, 1536 ['block1a_project_bn[0][0]'] 96) block2a_expand_bn (BatchNormal (None, None, None, 384 ['block2a_expand_conv[0][0]'] ization) 96) block2a_expand_activation (Act (None, None, None, 0 ['block2a_expand_bn[0][0]'] ivation) 96) block2a_dwconv_pad (ZeroPaddin (None, None, None, 0 ['block2a_expand_activation[0][0] g2D) 96) '] block2a_dwconv (DepthwiseConv2 (None, None, None, 864 ['block2a_dwconv_pad[0][0]'] D) 96) block2a_bn (BatchNormalization (None, None, None, 384 ['block2a_dwconv[0][0]'] ) 96) block2a_activation (Activation (None, None, None, 0 ['block2a_bn[0][0]'] ) 96) block2a_se_squeeze (GlobalAver (None, 96) 0 ['block2a_activation[0][0]'] agePooling2D) block2a_se_reshape (Reshape) (None, 1, 1, 96) 0 ['block2a_se_squeeze[0][0]'] block2a_se_reduce (Conv2D) (None, 1, 1, 4) 388 ['block2a_se_reshape[0][0]'] block2a_se_expand (Conv2D) (None, 1, 1, 96) 480 ['block2a_se_reduce[0][0]'] block2a_se_excite (Multiply) (None, None, None, 0 ['block2a_activation[0][0]', 96) 'block2a_se_expand[0][0]'] block2a_project_conv (Conv2D) (None, None, None, 2304 ['block2a_se_excite[0][0]'] 24) block2a_project_bn (BatchNorma (None, None, None, 96 ['block2a_project_conv[0][0]'] lization) 24) block2b_expand_conv (Conv2D) (None, None, None, 3456 ['block2a_project_bn[0][0]'] 144) block2b_expand_bn (BatchNormal (None, None, None, 576 ['block2b_expand_conv[0][0]'] ization) 144) block2b_expand_activation (Act (None, None, None, 0 ['block2b_expand_bn[0][0]'] ivation) 144) block2b_dwconv (DepthwiseConv2 (None, None, None, 1296 ['block2b_expand_activation[0][0] D) 144) '] block2b_bn (BatchNormalization (None, None, None, 576 ['block2b_dwconv[0][0]'] ) 144) block2b_activation (Activation (None, None, None, 0 ['block2b_bn[0][0]'] ) 144) block2b_se_squeeze (GlobalAver (None, 144) 0 ['block2b_activation[0][0]'] 
agePooling2D) block2b_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block2b_se_squeeze[0][0]'] block2b_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block2b_se_reshape[0][0]'] block2b_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block2b_se_reduce[0][0]'] block2b_se_excite (Multiply) (None, None, None, 0 ['block2b_activation[0][0]', 144) 'block2b_se_expand[0][0]'] block2b_project_conv (Conv2D) (None, None, None, 3456 ['block2b_se_excite[0][0]'] 24) block2b_project_bn (BatchNorma (None, None, None, 96 ['block2b_project_conv[0][0]'] lization) 24) block2b_drop (Dropout) (None, None, None, 0 ['block2b_project_bn[0][0]'] 24) block2b_add (Add) (None, None, None, 0 ['block2b_drop[0][0]', 24) 'block2a_project_bn[0][0]'] block3a_expand_conv (Conv2D) (None, None, None, 3456 ['block2b_add[0][0]'] 144) block3a_expand_bn (BatchNormal (None, None, None, 576 ['block3a_expand_conv[0][0]'] ization) 144) block3a_expand_activation (Act (None, None, None, 0 ['block3a_expand_bn[0][0]'] ivation) 144) block3a_dwconv_pad (ZeroPaddin (None, None, None, 0 ['block3a_expand_activation[0][0] g2D) 144) '] block3a_dwconv (DepthwiseConv2 (None, None, None, 3600 ['block3a_dwconv_pad[0][0]'] D) 144) block3a_bn (BatchNormalization (None, None, None, 576 ['block3a_dwconv[0][0]'] ) 144) block3a_activation (Activation (None, None, None, 0 ['block3a_bn[0][0]'] ) 144) block3a_se_squeeze (GlobalAver (None, 144) 0 ['block3a_activation[0][0]'] agePooling2D) block3a_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block3a_se_squeeze[0][0]'] block3a_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block3a_se_reshape[0][0]'] block3a_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block3a_se_reduce[0][0]'] block3a_se_excite (Multiply) (None, None, None, 0 ['block3a_activation[0][0]', 144) 'block3a_se_expand[0][0]'] block3a_project_conv (Conv2D) (None, None, None, 5760 ['block3a_se_excite[0][0]'] 40) block3a_project_bn (BatchNorma (None, None, None, 160 ['block3a_project_conv[0][0]'] lization) 40) block3b_expand_conv (Conv2D) (None, None, None, 9600 ['block3a_project_bn[0][0]'] 240) block3b_expand_bn (BatchNormal (None, None, None, 960 ['block3b_expand_conv[0][0]'] ization) 240) block3b_expand_activation (Act (None, None, None, 0 ['block3b_expand_bn[0][0]'] ivation) 240) block3b_dwconv (DepthwiseConv2 (None, None, None, 6000 ['block3b_expand_activation[0][0] D) 240) '] block3b_bn (BatchNormalization (None, None, None, 960 ['block3b_dwconv[0][0]'] ) 240) block3b_activation (Activation (None, None, None, 0 ['block3b_bn[0][0]'] ) 240) block3b_se_squeeze (GlobalAver (None, 240) 0 ['block3b_activation[0][0]'] agePooling2D) block3b_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block3b_se_squeeze[0][0]'] block3b_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block3b_se_reshape[0][0]'] block3b_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block3b_se_reduce[0][0]'] block3b_se_excite (Multiply) (None, None, None, 0 ['block3b_activation[0][0]', 240) 'block3b_se_expand[0][0]'] block3b_project_conv (Conv2D) (None, None, None, 9600 ['block3b_se_excite[0][0]'] 40) block3b_project_bn (BatchNorma (None, None, None, 160 ['block3b_project_conv[0][0]'] lization) 40) block3b_drop (Dropout) (None, None, None, 0 ['block3b_project_bn[0][0]'] 40) block3b_add (Add) (None, None, None, 0 ['block3b_drop[0][0]', 40) 'block3a_project_bn[0][0]'] block4a_expand_conv (Conv2D) (None, None, None, 9600 ['block3b_add[0][0]'] 240) block4a_expand_bn (BatchNormal (None, None, None, 960 ['block4a_expand_conv[0][0]'] ization) 240) block4a_expand_activation (Act (None, None, None, 0 ['block4a_expand_bn[0][0]'] 
ivation) 240) block4a_dwconv_pad (ZeroPaddin (None, None, None, 0 ['block4a_expand_activation[0][0] g2D) 240) '] block4a_dwconv (DepthwiseConv2 (None, None, None, 2160 ['block4a_dwconv_pad[0][0]'] D) 240) block4a_bn (BatchNormalization (None, None, None, 960 ['block4a_dwconv[0][0]'] ) 240) block4a_activation (Activation (None, None, None, 0 ['block4a_bn[0][0]'] ) 240) block4a_se_squeeze (GlobalAver (None, 240) 0 ['block4a_activation[0][0]'] agePooling2D) block4a_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block4a_se_squeeze[0][0]'] block4a_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block4a_se_reshape[0][0]'] block4a_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block4a_se_reduce[0][0]'] block4a_se_excite (Multiply) (None, None, None, 0 ['block4a_activation[0][0]', 240) 'block4a_se_expand[0][0]'] block4a_project_conv (Conv2D) (None, None, None, 19200 ['block4a_se_excite[0][0]'] 80) block4a_project_bn (BatchNorma (None, None, None, 320 ['block4a_project_conv[0][0]'] lization) 80) block4b_expand_conv (Conv2D) (None, None, None, 38400 ['block4a_project_bn[0][0]'] 480) block4b_expand_bn (BatchNormal (None, None, None, 1920 ['block4b_expand_conv[0][0]'] ization) 480) block4b_expand_activation (Act (None, None, None, 0 ['block4b_expand_bn[0][0]'] ivation) 480) block4b_dwconv (DepthwiseConv2 (None, None, None, 4320 ['block4b_expand_activation[0][0] D) 480) '] block4b_bn (BatchNormalization (None, None, None, 1920 ['block4b_dwconv[0][0]'] ) 480) block4b_activation (Activation (None, None, None, 0 ['block4b_bn[0][0]'] ) 480) block4b_se_squeeze (GlobalAver (None, 480) 0 ['block4b_activation[0][0]'] agePooling2D) block4b_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4b_se_squeeze[0][0]'] block4b_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4b_se_reshape[0][0]'] block4b_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block4b_se_reduce[0][0]'] block4b_se_excite (Multiply) (None, None, None, 0 ['block4b_activation[0][0]', 480) 'block4b_se_expand[0][0]'] block4b_project_conv (Conv2D) (None, None, None, 38400 ['block4b_se_excite[0][0]'] 80) block4b_project_bn (BatchNorma (None, None, None, 320 ['block4b_project_conv[0][0]'] lization) 80) block4b_drop (Dropout) (None, None, None, 0 ['block4b_project_bn[0][0]'] 80) block4b_add (Add) (None, None, None, 0 ['block4b_drop[0][0]', 80) 'block4a_project_bn[0][0]'] block4c_expand_conv (Conv2D) (None, None, None, 38400 ['block4b_add[0][0]'] 480) block4c_expand_bn (BatchNormal (None, None, None, 1920 ['block4c_expand_conv[0][0]'] ization) 480) block4c_expand_activation (Act (None, None, None, 0 ['block4c_expand_bn[0][0]'] ivation) 480) block4c_dwconv (DepthwiseConv2 (None, None, None, 4320 ['block4c_expand_activation[0][0] D) 480) '] block4c_bn (BatchNormalization (None, None, None, 1920 ['block4c_dwconv[0][0]'] ) 480) block4c_activation (Activation (None, None, None, 0 ['block4c_bn[0][0]'] ) 480) block4c_se_squeeze (GlobalAver (None, 480) 0 ['block4c_activation[0][0]'] agePooling2D) block4c_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4c_se_squeeze[0][0]'] block4c_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4c_se_reshape[0][0]'] block4c_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block4c_se_reduce[0][0]'] block4c_se_excite (Multiply) (None, None, None, 0 ['block4c_activation[0][0]', 480) 'block4c_se_expand[0][0]'] block4c_project_conv (Conv2D) (None, None, None, 38400 ['block4c_se_excite[0][0]'] 80) block4c_project_bn (BatchNorma (None, None, None, 320 ['block4c_project_conv[0][0]'] lization) 80) block4c_drop (Dropout) (None, None, None, 0 
... (remaining rows truncated: blocks 4c through 7a repeat the same expand conv → depthwise conv → squeeze-and-excite → project pattern seen above, ending with the top_conv, top_bn and top_activation layers) ...
==================================================================================================
Total params: 4,049,571
Trainable params: 0
Non-trainable params: 4,049,571
__________________________________________________________________________________________________
You can see how each layer has a certain number of parameters. Since we're using a pre-trained model, you can think of all of these parameters as patterns the base model has learned on another dataset. And because we set base_model.trainable = False, these patterns remain as they are during training (they're frozen and don't get updated).
Alright, that was the base model. Let's see the summary of our overall model.
# Check summary of model constructed with Functional API
model_0.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_layer (InputLayer) [(None, 224, 224, 3)] 0 efficientnetb0 (Functional) (None, None, None, 1280) 4049571 global_average_pooling_laye (None, 1280) 0 r (GlobalAveragePooling2D) output_layer (Dense) (None, 10) 12810 ================================================================= Total params: 4,062,381 Trainable params: 12,810 Non-trainable params: 4,049,571 _________________________________________________________________
Our overall model has four layers, but really, one of those layers (efficientnetb0) contains 236 layers of its own.
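If you want to verify this yourself, you can look the base model up by its layer name and count its sub-layers. A minimal sketch (assuming model_0 has been built as above):

# Look up the nested EfficientNetB0 model by its layer name and count its sub-layers
base = model_0.get_layer("efficientnetb0")
print(len(base.layers)) # 236, as mentioned above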
You can see how the output shape started out as (None, 224, 224, 3)
for the input layer (the shape of our images) but was transformed to be (None, 10)
by the output layer (the shape of our labels), where None
is the placeholder for the batch size.
Notice, too, that the only trainable parameters in the model are those in the output layer.
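We can confirm this by checking the trainable attribute of each of model_0's layers; a quick sketch:

# Check which of our model's layers are trainable
for layer in model_0.layers:
  print(layer.name, layer.trainable)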
How do our model's training curves look?
# Check out our model's training curves
plot_loss_curves(history_10_percent)
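(If you're coding along and don't have plot_loss_curves() defined, here's a minimal sketch of what such a function might look like, assuming a standard Keras History object with loss and accuracy metrics — the helper we use in the course does something similar.)

# Minimal sketch of a loss-curve plotting helper (hypothetical re-implementation)
import matplotlib.pyplot as plt

def plot_loss_curves_sketch(history):
  epochs = range(len(history.history["loss"]))
  # Plot training vs validation loss
  plt.plot(epochs, history.history["loss"], label="training_loss")
  plt.plot(epochs, history.history["val_loss"], label="val_loss")
  plt.title("Loss"); plt.xlabel("Epochs"); plt.legend()
  # Plot training vs validation accuracy on a separate figure
  plt.figure()
  plt.plot(epochs, history.history["accuracy"], label="training_accuracy")
  plt.plot(epochs, history.history["val_accuracy"], label="val_accuracy")
  plt.title("Accuracy"); plt.xlabel("Epochs"); plt.legend()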
Getting a feature vector from a trained model¶
🤔 Question: What happens with the
tf.keras.layers.GlobalAveragePooling2D()
layer? I haven't seen it before.
The tf.keras.layers.GlobalAveragePooling2D()
layer transforms a 4D tensor into a 2D tensor by averaging the values across the inner axes (the height and width dimensions).
The previous sentence is a bit of a mouthful, so let's see an example.
# Define input tensor shape (same number of dimensions as the output of efficientnetb0)
input_shape = (1, 4, 4, 3)
# Create a random tensor
tf.random.set_seed(42)
input_tensor = tf.random.normal(input_shape)
print(f"Random input tensor:\n {input_tensor}\n")
# Pass the random tensor through a global average pooling 2D layer
global_average_pooled_tensor = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
print(f"2D global average pooled random tensor:\n {global_average_pooled_tensor}\n")
# Check the shapes of the different tensors
print(f"Shape of input tensor: {input_tensor.shape}")
print(f"Shape of 2D global averaged pooled input tensor: {global_average_pooled_tensor.shape}")
Random input tensor:
 [[[[ 0.3274685  -0.8426258   0.3194337 ]
   [-1.4075519  -2.3880599  -1.0392479 ]
   [-0.5573232   0.539707    1.6994323 ]
   [ 0.28893656 -1.5066116  -0.2645474 ]]

  [[-0.59722406 -1.9171132  -0.62044144]
   [ 0.8504023  -0.40604794 -3.0258412 ]
   [ 0.9058464   0.29855987 -0.22561555]
   [-0.7616443  -1.8917141  -0.93847126]]

  [[ 0.77852213 -0.47338897  0.97772694]
   [ 0.24694404  0.20573747 -0.5256233 ]
   [ 0.32410017  0.02545409 -0.10638497]
   [-0.6369475   1.1603122   0.2507359 ]]

  [[-0.41728503  0.4012578  -1.4145443 ]
   [-0.5931857  -1.6617213   0.33567193]
   [ 0.10815629  0.23479682 -0.56668764]
   [-0.35819843  0.88698614  0.52744764]]]]

2D global average pooled random tensor:
 [[-0.09368646 -0.45840448 -0.2885598 ]]

Shape of input tensor: (1, 4, 4, 3)
Shape of 2D global averaged pooled input tensor: (1, 3)
You can see the tf.keras.layers.GlobalAveragePooling2D()
layer condensed the input tensor from shape (1, 4, 4, 3)
to (1, 3)
. It did so by averaging the input_tensor
across the middle two axes.
We can replicate this behaviour using tf.reduce_mean() and specifying the appropriate axes.
# This is the same as GlobalAveragePooling2D()
tf.reduce_mean(input_tensor, axis=[1, 2]) # average across the middle axes
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-0.09368646, -0.45840448, -0.2885598 ]], dtype=float32)>
Doing this not only makes the output of the base model compatible with the input shape requirement of our output layer (tf.keras.layers.Dense()
, it also condenses the information found by the base model into a lower-dimensional feature vector.
🔑 Note: One of the reasons feature extraction transfer learning is named the way it is is that a pre-trained model often outputs a feature vector (a long tensor of numbers, in our case, the output of the tf.keras.layers.GlobalAveragePooling2D() layer) which later layers can then extract patterns from.
🛠 Practice: Do the same as the above cell but for
tf.keras.layers.GlobalMaxPool2D()
.
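(One way to check your work on this exercise is to compare the layer's output against tf.reduce_max() over the same axes — a minimal sketch, reusing the input_tensor from above:)

# Sketch: GlobalMaxPool2D takes the max (rather than the mean) across the middle axes
global_max_pooled_tensor = tf.keras.layers.GlobalMaxPool2D()(input_tensor)
print(global_max_pooled_tensor)
print(tf.reduce_max(input_tensor, axis=[1, 2])) # should match the layer's output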
Running a series of transfer learning experiments¶
We've seen the incredible results of transfer learning on 10% of the training data, but what about 1% of the training data?
What kind of results do you think we can get using 100x less data than the original CNN models we built ourselves?
Why don't we answer that question while running the following modelling experiments:
- model_1: Use feature extraction transfer learning on 1% of the training data with data augmentation.
- model_2: Use feature extraction transfer learning on 10% of the training data with data augmentation.
- model_3: Use fine-tuning transfer learning on 10% of the training data with data augmentation.
- model_4: Use fine-tuning transfer learning on 100% of the training data with data augmentation.
While all of the experiments will be run on different versions of the training data, they will all be evaluated on the same test dataset. This ensures the results of each experiment are as comparable as possible.
All experiments will be done using the EfficientNetB0
model within the tf.keras.applications
module.
To make sure we're keeping track of our experiments, we'll use our create_tensorboard_callback()
function to save each model's training logs.
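As a quick reminder of how that plugs in (the dir_name and experiment_name values below are illustrative, assuming the create_tensorboard_callback() helper defined earlier in the course):

# Sketch: pass the TensorBoard callback to fit() for each experiment
# tensorboard_cb = create_tensorboard_callback(dir_name="transfer_learning",
#                                              experiment_name="1_percent_data_aug")
# model.fit(..., callbacks=[tensorboard_cb])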
We'll construct each model using the Keras Functional API, and instead of implementing data augmentation with the ImageDataGenerator
class as we have previously, we're going to build it right into the model using the tf.keras.layers.experimental.preprocessing
module.
Let's begin by downloading the data for experiment 1, using feature extraction transfer learning on 1% of the training data with data augmentation.
# Download and unzip data
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip
unzip_data("10_food_classes_1_percent.zip")
# Create training and test dirs
train_dir_1_percent = "10_food_classes_1_percent/train/"
test_dir = "10_food_classes_1_percent/test/"
--2023-05-12 08:02:54--  https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.204.128, 64.233.188.128, 64.233.189.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.204.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 133612354 (127M) [application/zip]
Saving to: ‘10_food_classes_1_percent.zip’

10_food_classes_1_p 100%[===================>] 127.42M  25.7MB/s    in 6.0s

2023-05-12 08:03:00 (21.2 MB/s) - ‘10_food_classes_1_percent.zip’ saved [133612354/133612354]
How many images are we working with?
# Walk through 1 percent data directory and list number of files
walk_through_dir("10_food_classes_1_percent")
There are 2 directories and 0 images in '10_food_classes_1_percent'.
There are 10 directories and 0 images in '10_food_classes_1_percent/train'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/hamburger'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/steak'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/grilled_salmon'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_curry'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/sushi'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/ramen'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/ice_cream'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/pizza'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_wings'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/fried_rice'.
There are 10 directories and 0 images in '10_food_classes_1_percent/test'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/hamburger'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/steak'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_curry'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/sushi'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/ramen'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/ice_cream'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/pizza'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_wings'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/fried_rice'.
Alright, looks like we've only got seven images of each class. This should be a bit of a challenge for our model.
🔑 Note: As with the 10% of data subset, the 1% of images were chosen at random from the original full training dataset. The test images are the same as the ones which have previously been used. If you want to see how this data was preprocessed, check out the Food Vision Image Preprocessing notebook.
Time to load our images as tf.data.Dataset objects. To do so, we'll use the image_dataset_from_directory() method.
import tensorflow as tf
IMG_SIZE = (224, 224)
train_data_1_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_1_percent,
label_mode="categorical",
batch_size=32, # default
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode="categorical",
image_size=IMG_SIZE)
Found 70 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
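Since tf.data.Dataset objects yield batches of (image, label) tensors, it's worth peeking at one batch to confirm the shapes; a minimal sketch:

# Inspect a single batch of images and one-hot labels
for images, labels in train_data_1_percent.take(1):
  print(images.shape, labels.shape) # e.g. (32, 224, 224, 3) (32, 10)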
Data loaded. Time to augment it.
Adding data augmentation right into the model¶
Previously we've used the different parameters of the ImageDataGenerator
class to augment our training images. This time, we're going to build data augmentation right into the model.
How?
Using the tf.keras.layers.experimental.preprocessing
module and creating a dedicated data augmentation layer.
This is a relatively new feature, added in TensorFlow 2.2+, but it's very powerful. Adding a data augmentation layer to the model has the following benefits:
- Preprocessing of the images (augmenting them) happens on the GPU rather than on the CPU (much faster).
- Images are best preprocessed on the GPU, whereas text and structured data are more suited to being preprocessed on the CPU.
- Image data augmentation only happens during training so we can still export our whole model and use it elsewhere. And if someone else wanted to train the same model as us, including the same kind of data augmentation, they could.
Example of using data augmentation as the first layer within a model (EfficientNetB0).
🤔 Note: At the time of writing, the preprocessing layers we're using for data augmentation have experimental status within the TensorFlow library. This means that although the layers should be considered stable, the code may change slightly in a future version of TensorFlow. For more information on the other preprocessing layers available and the different methods of data augmentation, check out the Keras preprocessing layers guide and the TensorFlow data augmentation guide.
To use data augmentation right within our model, we'll create a Keras Sequential model consisting only of data preprocessing layers. We can then use this Sequential model within another Functional model.
If that sounds confusing, it'll make sense once we create it in code.
The data augmentation transformations we're going to use are:
- RandomFlip - flips an image on the horizontal or vertical axis.
- RandomRotation - randomly rotates an image by a specified amount.
- RandomZoom - randomly zooms into an image by a specified amount.
- RandomHeight - randomly shifts an image's height by a specified amount.
- RandomWidth - randomly shifts an image's width by a specified amount.
- Rescaling - normalizes the image pixel values to be between 0 and 1. This is worth mentioning because it's required for some image models, but since the tf.keras.applications implementation of EfficientNetB0 has rescaling built in, it's not required here.
There are more options, but these will do for now.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# from tensorflow.keras.layers.experimental import preprocessing
# NEW: Newer versions of TensorFlow (2.10+) can use the tensorflow.keras.layers API directly for data augmentation
data_augmentation = keras.Sequential([
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetB0
], name="data_augmentation")
# # UPDATE: Previous versions of TensorFlow (e.g. 2.4 and below) used the tensorflow.keras.layers.experimental.preprocessing API
# # Create a data augmentation stage with horizontal flipping, rotations, zooms
# data_augmentation = keras.Sequential([
# preprocessing.RandomFlip("horizontal"),
# preprocessing.RandomRotation(0.2),
# preprocessing.RandomZoom(0.2),
# preprocessing.RandomHeight(0.2),
# preprocessing.RandomWidth(0.2),
# # preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetB0
# ], name ="data_augmentation")
And that's it! Our data augmentation Sequential model is ready to go. As you'll see shortly, we'll be able to slot this "model" as a layer into our transfer learning model later on.
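For a preview of what that slotting-in looks like (a sketch only — the actual construction of model_1 comes later, and base_model here stands in for a frozen pre-trained model):

# Sketch: the augmentation block sits right after the input layer
# inputs = layers.Input(shape=(224, 224, 3), name="input_layer")
# x = data_augmentation(inputs)      # augmentations only run during training
# x = base_model(x, training=False)  # ...then into a frozen base model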
But before we do that, let's test it out by passing random images through it.
# View a random image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import random
target_class = random.choice(train_data_1_percent.class_names) # choose a random class
target_dir = "10_food_classes_1_percent/train/" + target_class # create the target directory
random_image = random.choice(os.listdir(target_dir)) # choose a random image from target directory
random_image_path = target_dir + "/" + random_image # create the chosen random image path
img = mpimg.imread(random_image_path) # read in the chosen target image
plt.imshow(img) # plot the target image
plt.title(f"Original random image from class: {target_class}")
plt.axis(False); # turn off the axes
# Augment the image
augmented_img = data_augmentation(tf.expand_dims(img, axis=0)) # data augmentation model requires shape (None, height, width, 3)
plt.figure()
plt.imshow(tf.squeeze(augmented_img)/255.) # augmented output is float32, so rescale to [0, 1] for plotting
plt.title(f"Augmented random image from class: {target_class}")
plt.axis(False);
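One more thing worth knowing: these augmentation layers only alter images when called in training mode, so at inference (prediction) time they act as a pass-through. A quick sketch to verify, reusing the img from above:

# Augmentation layers are no-ops at inference time (training=False)
passthrough_img = data_augmentation(tf.expand_dims(img, axis=0), training=False)
print(tf.reduce_max(tf.abs(passthrough_img[0] - tf.cast(img, tf.float32)))) # expect ~0.0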