05. Transfer Learning with TensorFlow Part 2: Fine-tuning¶
In the previous section, we saw how we could leverage feature extraction transfer learning to get far better results on our Food Vision project than building our own models (even with less data).
Now we're going to cover another type of transfer learning: fine-tuning.
In fine-tuning transfer learning, the pre-trained weights from another model are unfrozen and tweaked during training to better suit your own data.
With feature extraction transfer learning, you may only train the top 1-3 layers of a pre-trained model with your own data, whereas with fine-tuning transfer learning, you might train 1-3+ layers of a pre-trained model (where the '+' indicates that many or all of the layers could be trained).
Feature extraction transfer learning vs. fine-tuning transfer learning. The main difference between the two is that in fine-tuning, more layers of the pre-trained model get unfrozen and tuned on custom data. This fine-tuning usually takes more data than feature extraction to be effective.
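To make the difference concrete, here's a minimal sketch of what freezing and unfreezing looks like in code (we'll do this properly later in the notebook, and the choice of how many layers to unfreeze here is purely illustrative):
# Minimal sketch: feature extraction vs. fine-tuning (illustrative only)
import tensorflow as tf
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False) # example pre-trained model
# Feature extraction: freeze the whole pre-trained model, only the layers we add on top get trained
base_model.trainable = False
# Fine-tuning: unfreeze some (or all) of the pre-trained layers so their weights get slightly updated on our own data
base_model.trainable = True
for layer in base_model.layers[:-10]: # keep everything except the last 10 layers frozen (10 is arbitrary here)
    layer.trainable = False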
What we're going to cover¶
We're going to go through the following with TensorFlow:
- Introduce fine-tuning, a type of transfer learning to modify a pre-trained model to be more suited to your data
- Using the Keras Functional API (a different way to build models in Keras)
- Using a smaller dataset to experiment faster (e.g. 1-10% of training samples of 10 classes of food)
- Data augmentation (how to make your training dataset more diverse without adding more data)
- Running a series of modelling experiments on our Food Vision data
- Model 0: a transfer learning model using the Keras Functional API
- Model 1: a feature extraction transfer learning model on 1% of the data with data augmentation
- Model 2: a feature extraction transfer learning model on 10% of the data with data augmentation
- Model 3: a fine-tuned transfer learning model on 10% of the data
- Model 4: a fine-tuned transfer learning model on 100% of the data
- Introduce the ModelCheckpoint callback to save intermediate training results
- Compare model experiment results using TensorBoard
How you can use this notebook¶
You can read through the descriptions and the code (it should all run, except for the cells which error on purpose), but there's a better option.
Write all of the code yourself.
Yes. I'm serious. Create a new notebook and rewrite each line by yourself. Investigate the code, see if you can break it, and figure out why it breaks.
You don't have to write the text descriptions but writing the code yourself is a great way to get hands-on experience.
Don't worry if you make mistakes, we all do. The way to get better and make fewer mistakes is to write more code.
import datetime
print(f"Notebook last run (end-to-end): {datetime.datetime.now()}")
Notebook last run (end-to-end): 2023-08-18 01:39:51.865865
🔑 Note: As of TensorFlow 2.10+, there seem to be issues with the tf.keras.applications.efficientnet models (used later on) when loading weights via the load_weights() method. To fix this, I've updated the code to use tf.keras.applications.efficientnet_v2 instead. It's a small change but results in far fewer errors. You can see the full write-up on the course GitHub.
import tensorflow as tf
print(f"TensorFlow version: {tf.__version__}")
TensorFlow version: 2.12.0
# Are we using a GPU? (if not & you're using Google Colab, go to Runtime -> Change Runtime Type -> Hardware Accelerator: GPU)
!nvidia-smi
Fri Aug 18 01:39:54 2023 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |===============================+======================+======================| | 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 | | N/A 36C P0 24W / 300W | 0MiB / 16384MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
Creating helper functions¶
Throughout your machine learning experiments, you'll likely come across snippets of code you want to use over and over again.
For example, a plotting function which plots a model's history
object (see plot_loss_curves()
below).
You could recreate these functions over and over again.
But as you might've guessed, rewriting the same functions becomes tedious.
One of the solutions is to store them in a helper script such as helper_functions.py
. And then import the necessary functionality when you need it.
For example, you might write:
from helper_functions import plot_loss_curves
...
plot_loss_curves(history)
Let's see what this looks like.
# Get helper_functions.py script from course GitHub
!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
# Import helper functions we're going to use
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir
--2023-08-18 01:39:54-- https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 10246 (10K) [text/plain] Saving to: ‘helper_functions.py’ helper_functions.py 100%[===================>] 10.01K --.-KB/s in 0s 2023-08-18 01:39:54 (41.3 MB/s) - ‘helper_functions.py’ saved [10246/10246]
Wonderful, now we've got a bunch of helper functions we can use throughout the notebook without having to rewrite them from scratch each time.
🔑 Note: If you're running this notebook in Google Colab, when it times out Colab will delete the
helper_functions.py
file. So to use the functions imported above, you'll have to rerun the cell.
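If you're curious what one of these helpers does under the hood, here's a simplified sketch of a loss-curve plotting function (the actual plot_loss_curves() in helper_functions.py may differ slightly):
# Simplified sketch of a loss-curve plotting helper (the real plot_loss_curves() may differ)
import matplotlib.pyplot as plt

def plot_loss_curves_sketch(history):
    """Plots training/validation loss and accuracy curves from a Keras History object."""
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    accuracy = history.history["accuracy"]
    val_accuracy = history.history["val_accuracy"]
    epochs = range(len(loss))

    # Plot loss curves
    plt.plot(epochs, loss, label="training_loss")
    plt.plot(epochs, val_loss, label="val_loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Plot accuracy curves (on a new figure)
    plt.figure()
    plt.plot(epochs, accuracy, label="training_accuracy")
    plt.plot(epochs, val_accuracy, label="val_accuracy")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()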
10 Food Classes: Working with less data¶
We saw in the previous notebook that we could get great results with only 10% of the training data using transfer learning with TensorFlow Hub.
In this notebook, we're going to continue to work with smaller subsets of the data, except this time we'll have a look at how we can use the in-built pretrained models within the tf.keras.applications
module as well as how to fine-tune them to our own custom dataset.
We'll also practice using a new but similar data loading function to the one we've used before, image_dataset_from_directory()
which is part of the tf.keras.utils
module.
Finally, we'll also be practicing using the Keras Functional API for building deep learning models. The Functional API is a more flexible way to create models than the tf.keras.Sequential API.
We'll explore each of these in more detail as we go.
Let's start by downloading some data.
# Get 10% of the data of the 10 classes
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip
unzip_data("10_food_classes_10_percent.zip")
--2023-08-18 01:39:55-- https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.203.128, 74.125.199.128, 74.125.195.128, ... Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.203.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 168546183 (161M) [application/zip] Saving to: ‘10_food_classes_10_percent.zip’ 10_food_classes_10_ 100%[===================>] 160.74M 242MB/s in 0.7s 2023-08-18 01:39:56 (242 MB/s) - ‘10_food_classes_10_percent.zip’ saved [168546183/168546183]
The dataset we're downloading is the 10 food classes dataset (from Food 101) with 10% of the training images we used in the previous notebook.
🔑 Note: You can see how this dataset was created in the image data modification notebook.
# Walk through 10 percent data directory and list number of files
walk_through_dir("10_food_classes_10_percent")
There are 2 directories and 0 images in '10_food_classes_10_percent'. There are 10 directories and 0 images in '10_food_classes_10_percent/train'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/ramen'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_curry'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/pizza'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/ice_cream'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/grilled_salmon'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/steak'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_wings'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/hamburger'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/sushi'. There are 0 directories and 75 images in '10_food_classes_10_percent/train/fried_rice'. There are 10 directories and 0 images in '10_food_classes_10_percent/test'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/ramen'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_curry'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/pizza'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/ice_cream'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/steak'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_wings'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/hamburger'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/sushi'. There are 0 directories and 250 images in '10_food_classes_10_percent/test/fried_rice'.
We can see that each of the training directories contains 75 images and each of the testing directories contains 250 images.
Let's define our training and test filepaths.
# Create training and test directories
train_dir = "10_food_classes_10_percent/train/"
test_dir = "10_food_classes_10_percent/test/"
Now that we've got some image data, we need a way of loading it into a TensorFlow-compatible format.
Previously, we've used the ImageDataGenerator
class.
However, as of August 2023, this class is deprecated and isn't recommended for future usage (it's too slow).
Because of this, we'll move onto using tf.keras.utils.image_dataset_from_directory()
.
This method expects image data in the following file format:
Example of file structure
10_food_classes_10_percent <- top level folder
└───train <- training images
│ └───pizza
│ │ │ 1008104.jpg
│ │ │ 1638227.jpg
│ │ │ ...
│ └───steak
│ │ 1000205.jpg
│ │ 1647351.jpg
│ │ ...
│
└───test <- testing images
│ └───pizza
│ │ │ 1001116.jpg
│ │ │ 1507019.jpg
│ │ │ ...
│ └───steak
│ │ 100274.jpg
│ │ 1653815.jpg
│ │ ...
One of the main benefits of using tf.keras.preprocessing.image_dataset_from_directory()
rather than ImageDataGenerator
is that it creates a tf.data.Dataset
object rather than a generator.
The main advantage of this is that the tf.data.Dataset API is much more efficient (faster) than the ImageDataGenerator API, which is paramount for larger datasets.
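As a quick aside, part of that speed comes from being able to chain performance methods onto a tf.data.Dataset. We won't need these for our small datasets, but here's a minimal sketch of what they look like:
# Sketch only: common tf.data performance methods (not required for this notebook)
import tensorflow as tf

def speed_up_dataset(dataset):
    return (dataset
            .cache() # keep samples in memory after the first epoch
            .shuffle(buffer_size=1000) # shuffle the order of samples each epoch
            .prefetch(tf.data.AUTOTUNE)) # prepare the next batch while the current one is training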
Let's see image_dataset_from_directory() in action.
# Create data inputs
import tensorflow as tf
IMG_SIZE = (224, 224) # define image size
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(directory=train_dir,
image_size=IMG_SIZE,
label_mode="categorical", # what type are the labels?
batch_size=32) # batch_size is 32 by default, this is generally a good number
test_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
image_size=IMG_SIZE,
label_mode="categorical")
Found 750 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
Wonderful! Looks like our dataloaders have found the correct number of images for each dataset.
For now, the main parameters we're concerned about in the image_dataset_from_directory()
function are:
- directory - the filepath of the target directory we're loading images in from.
- image_size - the target size of the images we're going to load in (height, width).
- batch_size - the batch size of the images we're going to load in. For example, if the batch_size is 32 (the default), batches of 32 images and labels at a time will be passed to the model.
There are more parameters we could play around with if we needed to in the documentation.
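For example, a few more parameters you might find handy (a quick sketch reusing the train_dir and IMG_SIZE defined above; see the documentation for the full list and default values):
# Sketch: a few other image_dataset_from_directory() parameters we could use
example_data = tf.keras.preprocessing.image_dataset_from_directory(directory=train_dir,
                                                                   image_size=IMG_SIZE,
                                                                   label_mode="categorical",
                                                                   shuffle=True, # shuffle data before batching (True by default)
                                                                   seed=42, # random seed for reproducible shuffling
                                                                   color_mode="rgb") # load images with 3 color channels (the default)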
If we check the training data datatype we should see it as a BatchDataset
with shapes relating to our data.
# Check the training data datatype
train_data_10_percent
<_BatchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>
In the above output:
- (None, 224, 224, 3) refers to the tensor shape of our images, where None is the batch size, 224 is the height (and width) and 3 is the color channels (red, green, blue).
- (None, 10) refers to the tensor shape of the labels, where None is the batch size and 10 is the number of possible labels (the 10 different food classes).
- Both image tensors and labels are of the datatype tf.float32.
The batch_size
is None
due to it only being used during model training. You can think of None
as a placeholder waiting to be filled with the batch_size
parameter from image_dataset_from_directory()
.
Another benefit of using the tf.data.Dataset API is the associated methods which come with it.
For example, if we want to find the names of the classes we're working with, we could use the class_names attribute.
# Check out the class names of our dataset
train_data_10_percent.class_names
['chicken_curry', 'chicken_wings', 'fried_rice', 'grilled_salmon', 'hamburger', 'ice_cream', 'pizza', 'ramen', 'steak', 'sushi']
Or if we wanted to see an example batch of data, we could use the take()
method.
# See an example batch of data
for images, labels in train_data_10_percent.take(1):
print(images, labels)
tf.Tensor( [[[[1.18658157e+02 1.34658173e+02 1.34658173e+02] [1.18117348e+02 1.34117340e+02 1.34117340e+02] [1.19637756e+02 1.35637756e+02 1.35637756e+02] ... [5.63165398e+01 9.91634979e+01 9.18114929e+01] [6.20816345e+01 1.09224518e+02 1.03224518e+02] [6.39487991e+01 1.12260063e+02 1.08489639e+02]] [[1.06357147e+02 1.22357147e+02 1.21357147e+02] [1.09709190e+02 1.25709190e+02 1.24709190e+02] [1.12872452e+02 1.28872452e+02 1.27872452e+02] ... [6.32701912e+01 1.05913071e+02 9.67702332e+01] [6.22856750e+01 1.07270393e+02 1.02076508e+02] [5.61224899e+01 1.01571533e+02 9.63112946e+01]] [[9.16428604e+01 1.06071434e+02 1.05857147e+02] [9.68418427e+01 1.11270416e+02 1.11056129e+02] [1.00045921e+02 1.14474495e+02 1.14260208e+02] ... [8.43824387e+01 1.24550812e+02 1.16336548e+02] [7.37497177e+01 1.15162979e+02 1.08994621e+02] [2.96886349e+01 7.09029236e+01 6.69029236e+01]] ... [[7.94541016e+01 8.68878021e+01 8.90306473e+01] [8.81582336e+01 9.64031525e+01 9.78163910e+01] [9.55510254e+01 1.03428596e+02 1.04954079e+02] ... [1.24428589e+02 1.21428589e+02 1.16000000e+02] [1.22801048e+02 1.19586754e+02 1.12158165e+02] [1.22933716e+02 1.19719421e+02 1.11862244e+02]] [[9.76941833e+01 1.07643173e+02 1.08668678e+02] [1.03255264e+02 1.13250168e+02 1.12250168e+02] [9.99133377e+01 1.10097023e+02 1.08500069e+02] ... [1.30443985e+02 1.27443985e+02 1.18586807e+02] [1.29714355e+02 1.25714355e+02 1.14790840e+02] [1.32857300e+02 1.27000122e+02 1.15071533e+02]] [[9.17858047e+01 1.03785805e+02 1.01785805e+02] [8.95970154e+01 1.01597015e+02 9.95970154e+01] [8.95051575e+01 1.01505157e+02 9.75051575e+01] ... [1.35357208e+02 1.33505112e+02 1.22775513e+02] [1.34025513e+02 1.30025513e+02 1.18357147e+02] [1.33086792e+02 1.29857208e+02 1.15545952e+02]]] [[[1.00561228e+01 1.34846935e+01 1.30714283e+01] [1.87397976e+01 1.10255098e+01 8.76530552e+00] [2.44948978e+01 1.09336739e+01 4.22448969e+00] ... [1.84438972e+01 1.08571215e+01 8.43875980e+00] [1.86428566e+01 1.16428576e+01 5.64285707e+00] [1.95867577e+01 1.18724718e+01 4.22961426e+00]] [[1.50255108e+01 1.14030609e+01 1.49285717e+01] [1.97806129e+01 9.20408058e+00 1.28622456e+01] [1.88571415e+01 7.15816259e+00 6.84183645e+00] ... [1.63265114e+01 1.22703876e+01 9.66830635e+00] [1.59438534e+01 1.30000134e+01 7.72956896e+00] [1.64080982e+01 1.43316650e+01 6.19381332e+00]] [[1.38571434e+01 7.85714293e+00 9.42857170e+00] [1.84438782e+01 1.00867357e+01 1.43010216e+01] [1.81887741e+01 1.03367348e+01 1.41887751e+01] ... [1.38112078e+01 1.21683502e+01 8.71933651e+00] [1.24285717e+01 1.30000000e+01 7.42857170e+00] [1.30765381e+01 1.40765381e+01 8.29082394e+00]] ... [[1.65714722e+01 1.17857361e+01 8.50511646e+00] [1.60561237e+01 1.00561237e+01 1.15153484e+01] [1.53060913e+01 9.21426392e+00 1.06938648e+01] ... [1.58775539e+01 1.18316450e+01 8.78573608e+00] [1.82703991e+01 1.32703981e+01 1.02703981e+01] [1.81377335e+01 1.31377335e+01 9.35199738e+00]] [[1.50000000e+01 1.00000000e+01 6.64285707e+00] [1.59948969e+01 9.99489689e+00 1.19847174e+01] [1.60714417e+01 1.00000000e+01 1.42143250e+01] ... [1.70561352e+01 1.30561342e+01 1.04846621e+01] [1.68622208e+01 1.28622208e+01 9.86222076e+00] [1.79540749e+01 1.29540758e+01 8.95407581e+00]] [[1.63571777e+01 1.13571777e+01 8.00003529e+00] [1.60000000e+01 1.00000000e+01 1.37142868e+01] [1.56428223e+01 8.64282227e+00 1.59234619e+01] ... 
[1.78521194e+01 1.44949112e+01 1.32806473e+01] [1.59285583e+01 1.19285583e+01 8.92855835e+00] [1.73571777e+01 1.23571777e+01 8.35717773e+00]]] [[[9.96428604e+01 4.36428566e+01 2.66428566e+01] [1.00974487e+02 4.49744911e+01 2.79744892e+01] [1.00928566e+02 4.55714302e+01 2.83571434e+01] ... [1.59892273e+02 1.26616730e+02 9.66270294e+01] [9.49591599e+01 5.68213806e+01 3.52959061e+01] [8.11578674e+01 3.91578674e+01 2.31578693e+01]] [[9.77397995e+01 4.17397957e+01 2.47397957e+01] [9.90051041e+01 4.30051041e+01 2.60051022e+01] [1.00770409e+02 4.54132690e+01 2.81989803e+01] ... [6.65968018e+01 3.01681900e+01 2.42336082e+00] [6.98725586e+01 2.89388580e+01 9.08175468e+00] [8.22501984e+01 3.64389687e+01 2.14440765e+01]] [[1.03571426e+02 4.71428566e+01 3.23571434e+01] [1.00071426e+02 4.40714264e+01 2.90714283e+01] [9.82602081e+01 4.22602043e+01 2.72602043e+01] ... [8.19692001e+01 4.12548027e+01 1.43263178e+01] [8.23265686e+01 3.61989899e+01 1.27245321e+01] [8.15460205e+01 3.23317337e+01 1.40460205e+01]] ... [[7.69943161e+01 4.02750015e+01 7.95883179e+00] [1.49617050e+02 1.07101807e+02 6.84590378e+01] [1.72382751e+02 1.23142990e+02 7.40257111e+01] ... [1.24050819e+02 1.10096733e+02 6.00713196e+01] [1.25897758e+02 1.11469231e+02 6.21835861e+01] [1.36408173e+02 1.21979645e+02 7.45511169e+01]] [[5.31785240e+01 2.17499924e+01 2.08678961e+00] [8.83721466e+01 5.16629753e+01 2.31681156e+01] [1.69887405e+02 1.21387421e+02 8.01731873e+01] ... [1.03188866e+02 9.80715027e+01 4.75154610e+01] [1.07494804e+02 9.97652588e+01 5.28469086e+01] [1.18214325e+02 1.07500092e+02 6.25715332e+01]] [[4.53569336e+01 1.83263817e+01 7.04070187e+00] [5.42549515e+01 2.05866871e+01 1.83662802e-01] [1.09019783e+02 6.42289886e+01 2.82239151e+01] ... [9.87245026e+01 9.85969162e+01 4.91582108e+01] [8.90457077e+01 8.59538651e+01 4.18059921e+01] [9.70410385e+01 9.20410385e+01 5.07553978e+01]]] ... [[[2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] ... [1.03785736e+02 1.04785736e+02 9.07857361e+01] [1.04071442e+02 1.05071442e+02 9.10714417e+01] [1.05000000e+02 1.06000000e+02 9.20000000e+01]] [[2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] ... [1.04729614e+02 1.05729614e+02 9.17296143e+01] [1.04933678e+02 1.05933678e+02 9.19336777e+01] [1.05000000e+02 1.06000000e+02 9.20000000e+01]] [[2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] [2.54000000e+02 2.54000000e+02 2.54000000e+02] ... [1.05382637e+02 1.06382637e+02 9.23826370e+01] [1.05000000e+02 1.06000000e+02 9.20000000e+01] [1.05000000e+02 1.06000000e+02 9.20000000e+01]] ... [[1.51000366e+02 1.23148277e+02 8.75717087e+01] [1.81531052e+02 1.54173843e+02 1.14031006e+02] [2.01418442e+02 1.77464355e+02 1.29372528e+02] ... [1.56698837e+02 1.34270309e+02 9.72243958e+01] [1.67561020e+02 1.44703964e+02 1.07703964e+02] [1.80540848e+02 1.55683792e+02 1.19469528e+02]] [[1.66280731e+02 1.41760361e+02 1.06020538e+02] [1.74851913e+02 1.51122360e+02 1.13341789e+02] [1.67749680e+02 1.45892563e+02 1.07321213e+02] ... [1.67688797e+02 1.49688797e+02 1.11688797e+02] [1.50270462e+02 1.32270462e+02 9.61377792e+01] [1.44224365e+02 1.26224358e+02 9.02243576e+01]] [[1.86229233e+02 1.67770065e+02 1.31999649e+02] [1.52626785e+02 1.34698318e+02 1.02295227e+02] [1.28321167e+02 1.11397705e+02 8.26171722e+01] ... 
[1.54229584e+02 1.36306122e+02 9.60765305e+01] [1.69928619e+02 1.51928619e+02 1.12020470e+02] [1.49765457e+02 1.34765457e+02 9.57654572e+01]]] [[[1.60000000e+01 9.00000000e+00 0.00000000e+00] [1.80373096e+01 8.03730869e+00 0.00000000e+00] [2.04422836e+01 8.15210438e+00 1.52678585e+00] ... [1.41512177e+02 8.60389099e+01 5.67488060e+01] [1.31497604e+02 7.44976120e+01 4.74976120e+01] [9.87061234e+01 4.07061234e+01 1.82527027e+01]] [[1.72560596e+01 9.60108471e+00 1.92857170e+00] [1.90000000e+01 8.00000000e+00 2.00000000e+00] [2.05267868e+01 7.23660707e+00 3.52678585e+00] ... [1.37288544e+02 7.97384033e+01 5.13768768e+01] [1.32715057e+02 7.09780884e+01 4.58486443e+01] [1.25001762e+02 6.13413200e+01 3.88709908e+01]] [[1.80000000e+01 8.78571415e+00 4.42857170e+00] [2.12598858e+01 7.25988531e+00 4.25988531e+00] [2.27633934e+01 7.23660707e+00 6.52678585e+00] ... [1.35598724e+02 7.42236252e+01 4.85731277e+01] [1.30317886e+02 6.42071915e+01 3.97330399e+01] [1.44314301e+02 7.44399185e+01 5.00696678e+01]] ... [[5.50862360e+00 4.50862360e+00 5.08623779e-01] [4.74013519e+00 3.74013519e+00 0.00000000e+00] [4.00000000e+00 3.00000000e+00 0.00000000e+00] ... [2.16505070e+01 1.66505070e+01 1.26505070e+01] [2.12142639e+01 1.62142639e+01 1.22142639e+01] [2.12142639e+01 1.62142639e+01 1.22142639e+01]] [[5.60107565e+00 4.60107565e+00 6.01075709e-01] [5.00000000e+00 4.00000000e+00 0.00000000e+00] [5.00000000e+00 4.00000000e+00 0.00000000e+00] ... [2.64329834e+01 2.14329834e+01 1.74329834e+01] [2.13169250e+01 1.63169250e+01 1.23169250e+01] [1.90251942e+01 1.40251951e+01 1.00251951e+01]] [[3.76879120e+00 2.76879120e+00 0.00000000e+00] [5.02072906e+00 4.02072906e+00 2.07291134e-02] [5.84790373e+00 4.84790373e+00 8.47903728e-01] ... [2.81019630e+01 2.31019630e+01 1.91019630e+01] [2.44732056e+01 1.94732056e+01 1.54732056e+01] [2.17098389e+01 1.67098389e+01 1.27098389e+01]]] [[[1.38484695e+02 1.24229591e+02 1.24357140e+02] [1.37551025e+02 1.25474495e+02 1.27500008e+02] [1.41642853e+02 1.30642853e+02 1.37071426e+02] ... [1.36806229e+02 7.66633682e+01 7.10919342e+01] [1.23571205e+02 7.72651825e+01 6.48824997e+01] [8.62853622e+01 5.75712547e+01 3.92854691e+01]] [[1.33163269e+02 1.16306122e+02 1.16331627e+02] [1.33005096e+02 1.17852043e+02 1.20714287e+02] [1.35816330e+02 1.23244896e+02 1.30459183e+02] ... [9.15152588e+01 5.30867348e+01 5.65152817e+01] [7.96835403e+01 5.30356674e+01 5.03111496e+01] [5.42649994e+01 4.33110542e+01 3.53110161e+01]] [[1.42214294e+02 1.21928574e+02 1.23071426e+02] [1.40397949e+02 1.24183670e+02 1.27112244e+02] [1.40928558e+02 1.28188782e+02 1.35739792e+02] ... [5.83366356e+01 4.21683273e+01 5.07397346e+01] [5.45663376e+01 4.67398643e+01 5.04082108e+01] [5.12038918e+01 5.11376419e+01 5.01376076e+01]] ... [[7.41377640e+01 3.99948845e+01 2.08571262e+01] [6.82295914e+01 3.51989784e+01 1.61989784e+01] [6.40000000e+01 3.34285736e+01 1.58571434e+01] ... [1.67765320e+02 1.39551056e+02 1.03193848e+02] [1.69755112e+02 1.43755112e+02 1.06755119e+02] [1.64994873e+02 1.39209167e+02 1.01566284e+02]] [[7.61020660e+01 3.91938820e+01 2.01479759e+01] [7.46531754e+01 3.96531448e+01 2.07194729e+01] [6.97041855e+01 3.91327591e+01 2.15613327e+01] ... [1.71642914e+02 1.43428650e+02 1.07071442e+02] [1.69999908e+02 1.43999908e+02 1.06999908e+02] [1.64494888e+02 1.39494888e+02 9.94948883e+01]] [[7.87857056e+01 4.10153008e+01 2.20153008e+01] [7.00408401e+01 3.48979874e+01 1.59694157e+01] [7.56938705e+01 4.48469353e+01 2.68469372e+01] ... 
[1.64714294e+02 1.36500031e+02 1.00142822e+02] [1.55663269e+02 1.30663269e+02 9.06632614e+01] [1.59474701e+02 1.34474701e+02 9.44746933e+01]]]], shape=(32, 224, 224, 3), dtype=float32) tf.Tensor( [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(32, 10), dtype=float32)
Notice how the image arrays come out as tensors of pixel values whereas the labels come out as one-hot encodings (e.g. [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
for hamburger
).
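If you ever want to turn a one-hot label back into a human-readable class name, you can combine tf.argmax() with the class_names attribute we used above (a small sketch):
# Sketch: convert a one-hot label back into a class name
class_names = train_data_10_percent.class_names
for images, labels in train_data_10_percent.take(1):
    label_index = tf.argmax(labels[0]).numpy() # index of the "1" in the one-hot vector
    print(f"One-hot label: {labels[0].numpy()} -> class name: {class_names[label_index]}")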
Model 0: Building a transfer learning model using the Keras Functional API¶
Alright, our data is tensor-ified, let's build a model.
To do so we're going to be using the tf.keras.applications
module, as it contains a series of computer vision models already trained on ImageNet, as well as the Keras Functional API to construct our model.
We're going to go through the following steps:
- Instantiate a pre-trained base model object by choosing a target model such as EfficientNetV2B0 from tf.keras.applications.efficientnet_v2, setting the include_top parameter to False (we do this because we're going to create our own top, which are the output layers for the model).
- Set the base model's trainable attribute to False to freeze all of the weights in the pre-trained model.
- Define an input layer for our model, for example, what shape of data should our model expect?
- [Optional] Normalize the inputs to our model if it requires it. Some computer vision models such as ResNet50V2 require their inputs to be between 0 & 1.
🤔 Note: As of writing, the EfficientNet (and EfficientNetV2) models in the tf.keras.applications module do not require images to be normalized (pixel values between 0 and 1) on input, whereas many of the other models do. I posted an issue to the TensorFlow GitHub about this and they confirmed it.
- Pass the inputs to the base model.
- Pool the outputs of the base model into a shape compatible with the output activation layer (turn the base model's output tensors into the same shape as the label tensors). This can be done using tf.keras.layers.GlobalAveragePooling2D() or tf.keras.layers.GlobalMaxPooling2D(), though the former is more common in practice.
- Create an output activation layer using tf.keras.layers.Dense() with the appropriate activation function and number of neurons.
- Combine the inputs and outputs layers into a model using tf.keras.Model().
- Compile the model using the appropriate loss function and choice of optimizer.
- Fit the model for the desired number of epochs and with the necessary callbacks (in our case, we'll start off with the TensorBoard callback).
Woah... that sounds like a lot. Before we get ahead of ourselves, let's see it in practice.
# 1. Create base model with tf.keras.applications
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
# OLD
# base_model = tf.keras.applications.EfficientNetB0(include_top=False)
# 2. Freeze the base model (so the pre-learned patterns remain)
base_model.trainable = False
# 3. Create inputs into the base model
inputs = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")
# 4. If using ResNet50V2, add this to speed up convergence, remove for EfficientNetV2
# x = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)(inputs)
# 5. Pass the inputs to the base_model (note: using tf.keras.applications, EfficientNetV2 inputs don't have to be normalized)
x = base_model(inputs)
# Check data shape after passing it to base_model
print(f"Shape after base_model: {x.shape}")
# 6. Average pool the outputs of the base model (aggregate all the most important information, reduce number of computations)
x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
print(f"After GlobalAveragePooling2D(): {x.shape}")
# 7. Create the output activation layer
outputs = tf.keras.layers.Dense(10, activation="softmax", name="output_layer")(x)
# 8. Combine the inputs with the outputs into a model
model_0 = tf.keras.Model(inputs, outputs)
# 9. Compile the model
model_0.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=["accuracy"])
# 10. Fit the model (we use fewer steps for validation so it's faster)
history_10_percent = model_0.fit(train_data_10_percent,
epochs=5,
steps_per_epoch=len(train_data_10_percent),
validation_data=test_data_10_percent,
# Go through less of the validation data so epochs are faster (we want faster experiments!)
validation_steps=int(0.25 * len(test_data_10_percent)),
# Track our model's training logs for visualization later
callbacks=[create_tensorboard_callback("transfer_learning", "10_percent_feature_extract")])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b0_notop.h5 24274472/24274472 [==============================] - 0s 0us/step Shape after base_model: (None, 7, 7, 1280) After GlobalAveragePooling2D(): (None, 1280) Saving TensorBoard log files to: transfer_learning/10_percent_feature_extract/20230818-014007 Epoch 1/5 24/24 [==============================] - 20s 156ms/step - loss: 1.8704 - accuracy: 0.4387 - val_loss: 1.2824 - val_accuracy: 0.7418 Epoch 2/5 24/24 [==============================] - 2s 85ms/step - loss: 1.1395 - accuracy: 0.7533 - val_loss: 0.8783 - val_accuracy: 0.8010 Epoch 3/5 24/24 [==============================] - 2s 67ms/step - loss: 0.8327 - accuracy: 0.8147 - val_loss: 0.7168 - val_accuracy: 0.8322 Epoch 4/5 24/24 [==============================] - 2s 85ms/step - loss: 0.6915 - accuracy: 0.8453 - val_loss: 0.6149 - val_accuracy: 0.8520 Epoch 5/5 24/24 [==============================] - 2s 86ms/step - loss: 0.5813 - accuracy: 0.8720 - val_loss: 0.5555 - val_accuracy: 0.8569
Nice! After a minute or so of training our model performs incredibly well on both the training and test sets.
This is incredible.
All thanks to the power of transfer learning!
It's important to note the kind of transfer learning we used here is called feature extraction transfer learning, similar to what we did with the TensorFlow Hub models.
In other words, we passed our custom data to an already pre-trained model (EfficientNetV2B0
), asked it "what patterns do you see?" and then put our own output layer on top to make sure the outputs were tailored to our desired number of classes.
We also used the Keras Functional API to build our model rather than the Sequential API. For now, the benefits of this may not seem clear, but when you start to build more sophisticated models, you'll probably want to use the Functional API. So it's important to have exposure to this way of building models.
📖 Resource: To see the benefits and use cases of the Functional API versus the Sequential API, check out the TensorFlow Functional API documentation.
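To get a feel for the difference, here's the same tiny model written with both APIs (this isn't the model we're building, just a toy comparison sketch):
# Toy sketch: the same small model built with the Sequential API vs the Functional API

# Sequential API: a straight stack of layers
toy_sequential_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(224, 224, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation="softmax")
])

# Functional API: define the inputs and outputs explicitly, then connect them
toy_inputs = tf.keras.layers.Input(shape=(224, 224, 3))
toy_x = tf.keras.layers.GlobalAveragePooling2D()(toy_inputs)
toy_outputs = tf.keras.layers.Dense(10, activation="softmax")(toy_x)
toy_functional_model = tf.keras.Model(toy_inputs, toy_outputs)
The Functional API takes a couple more lines for a simple stack like this, but it can express things the Sequential API can't, such as multiple inputs/outputs and branching layers.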
Let's inspect the layers in our model, starting with the base.
# Check layers in our base model
for layer_number, layer in enumerate(base_model.layers):
print(layer_number, layer.name)
0 input_1 1 rescaling 2 normalization 3 stem_conv 4 stem_bn 5 stem_activation 6 block1a_project_conv 7 block1a_project_bn 8 block1a_project_activation 9 block2a_expand_conv 10 block2a_expand_bn 11 block2a_expand_activation 12 block2a_project_conv 13 block2a_project_bn 14 block2b_expand_conv 15 block2b_expand_bn 16 block2b_expand_activation 17 block2b_project_conv 18 block2b_project_bn 19 block2b_drop 20 block2b_add 21 block3a_expand_conv 22 block3a_expand_bn 23 block3a_expand_activation 24 block3a_project_conv 25 block3a_project_bn 26 block3b_expand_conv 27 block3b_expand_bn 28 block3b_expand_activation 29 block3b_project_conv 30 block3b_project_bn 31 block3b_drop 32 block3b_add 33 block4a_expand_conv 34 block4a_expand_bn 35 block4a_expand_activation 36 block4a_dwconv2 37 block4a_bn 38 block4a_activation 39 block4a_se_squeeze 40 block4a_se_reshape 41 block4a_se_reduce 42 block4a_se_expand 43 block4a_se_excite 44 block4a_project_conv 45 block4a_project_bn 46 block4b_expand_conv 47 block4b_expand_bn 48 block4b_expand_activation 49 block4b_dwconv2 50 block4b_bn 51 block4b_activation 52 block4b_se_squeeze 53 block4b_se_reshape 54 block4b_se_reduce 55 block4b_se_expand 56 block4b_se_excite 57 block4b_project_conv 58 block4b_project_bn 59 block4b_drop 60 block4b_add 61 block4c_expand_conv 62 block4c_expand_bn 63 block4c_expand_activation 64 block4c_dwconv2 65 block4c_bn 66 block4c_activation 67 block4c_se_squeeze 68 block4c_se_reshape 69 block4c_se_reduce 70 block4c_se_expand 71 block4c_se_excite 72 block4c_project_conv 73 block4c_project_bn 74 block4c_drop 75 block4c_add 76 block5a_expand_conv 77 block5a_expand_bn 78 block5a_expand_activation 79 block5a_dwconv2 80 block5a_bn 81 block5a_activation 82 block5a_se_squeeze 83 block5a_se_reshape 84 block5a_se_reduce 85 block5a_se_expand 86 block5a_se_excite 87 block5a_project_conv 88 block5a_project_bn 89 block5b_expand_conv 90 block5b_expand_bn 91 block5b_expand_activation 92 block5b_dwconv2 93 block5b_bn 94 block5b_activation 95 block5b_se_squeeze 96 block5b_se_reshape 97 block5b_se_reduce 98 block5b_se_expand 99 block5b_se_excite 100 block5b_project_conv 101 block5b_project_bn 102 block5b_drop 103 block5b_add 104 block5c_expand_conv 105 block5c_expand_bn 106 block5c_expand_activation 107 block5c_dwconv2 108 block5c_bn 109 block5c_activation 110 block5c_se_squeeze 111 block5c_se_reshape 112 block5c_se_reduce 113 block5c_se_expand 114 block5c_se_excite 115 block5c_project_conv 116 block5c_project_bn 117 block5c_drop 118 block5c_add 119 block5d_expand_conv 120 block5d_expand_bn 121 block5d_expand_activation 122 block5d_dwconv2 123 block5d_bn 124 block5d_activation 125 block5d_se_squeeze 126 block5d_se_reshape 127 block5d_se_reduce 128 block5d_se_expand 129 block5d_se_excite 130 block5d_project_conv 131 block5d_project_bn 132 block5d_drop 133 block5d_add 134 block5e_expand_conv 135 block5e_expand_bn 136 block5e_expand_activation 137 block5e_dwconv2 138 block5e_bn 139 block5e_activation 140 block5e_se_squeeze 141 block5e_se_reshape 142 block5e_se_reduce 143 block5e_se_expand 144 block5e_se_excite 145 block5e_project_conv 146 block5e_project_bn 147 block5e_drop 148 block5e_add 149 block6a_expand_conv 150 block6a_expand_bn 151 block6a_expand_activation 152 block6a_dwconv2 153 block6a_bn 154 block6a_activation 155 block6a_se_squeeze 156 block6a_se_reshape 157 block6a_se_reduce 158 block6a_se_expand 159 block6a_se_excite 160 block6a_project_conv 161 block6a_project_bn 162 block6b_expand_conv 163 block6b_expand_bn 164 block6b_expand_activation 165 
block6b_dwconv2 166 block6b_bn 167 block6b_activation 168 block6b_se_squeeze 169 block6b_se_reshape 170 block6b_se_reduce 171 block6b_se_expand 172 block6b_se_excite 173 block6b_project_conv 174 block6b_project_bn 175 block6b_drop 176 block6b_add 177 block6c_expand_conv 178 block6c_expand_bn 179 block6c_expand_activation 180 block6c_dwconv2 181 block6c_bn 182 block6c_activation 183 block6c_se_squeeze 184 block6c_se_reshape 185 block6c_se_reduce 186 block6c_se_expand 187 block6c_se_excite 188 block6c_project_conv 189 block6c_project_bn 190 block6c_drop 191 block6c_add 192 block6d_expand_conv 193 block6d_expand_bn 194 block6d_expand_activation 195 block6d_dwconv2 196 block6d_bn 197 block6d_activation 198 block6d_se_squeeze 199 block6d_se_reshape 200 block6d_se_reduce 201 block6d_se_expand 202 block6d_se_excite 203 block6d_project_conv 204 block6d_project_bn 205 block6d_drop 206 block6d_add 207 block6e_expand_conv 208 block6e_expand_bn 209 block6e_expand_activation 210 block6e_dwconv2 211 block6e_bn 212 block6e_activation 213 block6e_se_squeeze 214 block6e_se_reshape 215 block6e_se_reduce 216 block6e_se_expand 217 block6e_se_excite 218 block6e_project_conv 219 block6e_project_bn 220 block6e_drop 221 block6e_add 222 block6f_expand_conv 223 block6f_expand_bn 224 block6f_expand_activation 225 block6f_dwconv2 226 block6f_bn 227 block6f_activation 228 block6f_se_squeeze 229 block6f_se_reshape 230 block6f_se_reduce 231 block6f_se_expand 232 block6f_se_excite 233 block6f_project_conv 234 block6f_project_bn 235 block6f_drop 236 block6f_add 237 block6g_expand_conv 238 block6g_expand_bn 239 block6g_expand_activation 240 block6g_dwconv2 241 block6g_bn 242 block6g_activation 243 block6g_se_squeeze 244 block6g_se_reshape 245 block6g_se_reduce 246 block6g_se_expand 247 block6g_se_excite 248 block6g_project_conv 249 block6g_project_bn 250 block6g_drop 251 block6g_add 252 block6h_expand_conv 253 block6h_expand_bn 254 block6h_expand_activation 255 block6h_dwconv2 256 block6h_bn 257 block6h_activation 258 block6h_se_squeeze 259 block6h_se_reshape 260 block6h_se_reduce 261 block6h_se_expand 262 block6h_se_excite 263 block6h_project_conv 264 block6h_project_bn 265 block6h_drop 266 block6h_add 267 top_conv 268 top_bn 269 top_activation
Wow, that's a lot of layers... hand-coding all of those would've taken a fairly long time, yet we can still take advantage of them thanks to the power of transfer learning.
How about a summary of the base model?
base_model.summary()
Model: "efficientnetv2-b0" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, None, None, 0 [] 3)] rescaling (Rescaling) (None, None, None, 0 ['input_1[0][0]'] 3) normalization (Normalization) (None, None, None, 0 ['rescaling[0][0]'] 3) stem_conv (Conv2D) (None, None, None, 864 ['normalization[0][0]'] 32) stem_bn (BatchNormalization) (None, None, None, 128 ['stem_conv[0][0]'] 32) stem_activation (Activation) (None, None, None, 0 ['stem_bn[0][0]'] 32) block1a_project_conv (Conv2D) (None, None, None, 4608 ['stem_activation[0][0]'] 16) block1a_project_bn (BatchNorma (None, None, None, 64 ['block1a_project_conv[0][0]'] lization) 16) block1a_project_activation (Ac (None, None, None, 0 ['block1a_project_bn[0][0]'] tivation) 16) block2a_expand_conv (Conv2D) (None, None, None, 9216 ['block1a_project_activation[0][0 64) ]'] block2a_expand_bn (BatchNormal (None, None, None, 256 ['block2a_expand_conv[0][0]'] ization) 64) block2a_expand_activation (Act (None, None, None, 0 ['block2a_expand_bn[0][0]'] ivation) 64) block2a_project_conv (Conv2D) (None, None, None, 2048 ['block2a_expand_activation[0][0] 32) '] block2a_project_bn (BatchNorma (None, None, None, 128 ['block2a_project_conv[0][0]'] lization) 32) block2b_expand_conv (Conv2D) (None, None, None, 36864 ['block2a_project_bn[0][0]'] 128) block2b_expand_bn (BatchNormal (None, None, None, 512 ['block2b_expand_conv[0][0]'] ization) 128) block2b_expand_activation (Act (None, None, None, 0 ['block2b_expand_bn[0][0]'] ivation) 128) block2b_project_conv (Conv2D) (None, None, None, 4096 ['block2b_expand_activation[0][0] 32) '] block2b_project_bn (BatchNorma (None, None, None, 128 ['block2b_project_conv[0][0]'] lization) 32) block2b_drop (Dropout) (None, None, None, 0 ['block2b_project_bn[0][0]'] 32) block2b_add (Add) (None, None, None, 0 ['block2b_drop[0][0]', 32) 'block2a_project_bn[0][0]'] block3a_expand_conv (Conv2D) (None, None, None, 36864 ['block2b_add[0][0]'] 128) block3a_expand_bn (BatchNormal (None, None, None, 512 ['block3a_expand_conv[0][0]'] ization) 128) block3a_expand_activation (Act (None, None, None, 0 ['block3a_expand_bn[0][0]'] ivation) 128) block3a_project_conv (Conv2D) (None, None, None, 6144 ['block3a_expand_activation[0][0] 48) '] block3a_project_bn (BatchNorma (None, None, None, 192 ['block3a_project_conv[0][0]'] lization) 48) block3b_expand_conv (Conv2D) (None, None, None, 82944 ['block3a_project_bn[0][0]'] 192) block3b_expand_bn (BatchNormal (None, None, None, 768 ['block3b_expand_conv[0][0]'] ization) 192) block3b_expand_activation (Act (None, None, None, 0 ['block3b_expand_bn[0][0]'] ivation) 192) block3b_project_conv (Conv2D) (None, None, None, 9216 ['block3b_expand_activation[0][0] 48) '] block3b_project_bn (BatchNorma (None, None, None, 192 ['block3b_project_conv[0][0]'] lization) 48) block3b_drop (Dropout) (None, None, None, 0 ['block3b_project_bn[0][0]'] 48) block3b_add (Add) (None, None, None, 0 ['block3b_drop[0][0]', 48) 'block3a_project_bn[0][0]'] block4a_expand_conv (Conv2D) (None, None, None, 9216 ['block3b_add[0][0]'] 192) block4a_expand_bn (BatchNormal (None, None, None, 768 ['block4a_expand_conv[0][0]'] ization) 192) block4a_expand_activation (Act (None, None, None, 0 ['block4a_expand_bn[0][0]'] ivation) 192) block4a_dwconv2 (DepthwiseConv (None, None, None, 1728 ['block4a_expand_activation[0][0] 
2D) 192) '] block4a_bn (BatchNormalization (None, None, None, 768 ['block4a_dwconv2[0][0]'] ) 192) block4a_activation (Activation (None, None, None, 0 ['block4a_bn[0][0]'] ) 192) block4a_se_squeeze (GlobalAver (None, 192) 0 ['block4a_activation[0][0]'] agePooling2D) block4a_se_reshape (Reshape) (None, 1, 1, 192) 0 ['block4a_se_squeeze[0][0]'] block4a_se_reduce (Conv2D) (None, 1, 1, 12) 2316 ['block4a_se_reshape[0][0]'] block4a_se_expand (Conv2D) (None, 1, 1, 192) 2496 ['block4a_se_reduce[0][0]'] block4a_se_excite (Multiply) (None, None, None, 0 ['block4a_activation[0][0]', 192) 'block4a_se_expand[0][0]'] block4a_project_conv (Conv2D) (None, None, None, 18432 ['block4a_se_excite[0][0]'] 96) block4a_project_bn (BatchNorma (None, None, None, 384 ['block4a_project_conv[0][0]'] lization) 96) block4b_expand_conv (Conv2D) (None, None, None, 36864 ['block4a_project_bn[0][0]'] 384) block4b_expand_bn (BatchNormal (None, None, None, 1536 ['block4b_expand_conv[0][0]'] ization) 384) block4b_expand_activation (Act (None, None, None, 0 ['block4b_expand_bn[0][0]'] ivation) 384) block4b_dwconv2 (DepthwiseConv (None, None, None, 3456 ['block4b_expand_activation[0][0] 2D) 384) '] block4b_bn (BatchNormalization (None, None, None, 1536 ['block4b_dwconv2[0][0]'] ) 384) block4b_activation (Activation (None, None, None, 0 ['block4b_bn[0][0]'] ) 384) block4b_se_squeeze (GlobalAver (None, 384) 0 ['block4b_activation[0][0]'] agePooling2D) block4b_se_reshape (Reshape) (None, 1, 1, 384) 0 ['block4b_se_squeeze[0][0]'] block4b_se_reduce (Conv2D) (None, 1, 1, 24) 9240 ['block4b_se_reshape[0][0]'] block4b_se_expand (Conv2D) (None, 1, 1, 384) 9600 ['block4b_se_reduce[0][0]'] block4b_se_excite (Multiply) (None, None, None, 0 ['block4b_activation[0][0]', 384) 'block4b_se_expand[0][0]'] block4b_project_conv (Conv2D) (None, None, None, 36864 ['block4b_se_excite[0][0]'] 96) block4b_project_bn (BatchNorma (None, None, None, 384 ['block4b_project_conv[0][0]'] lization) 96) block4b_drop (Dropout) (None, None, None, 0 ['block4b_project_bn[0][0]'] 96) block4b_add (Add) (None, None, None, 0 ['block4b_drop[0][0]', 96) 'block4a_project_bn[0][0]'] block4c_expand_conv (Conv2D) (None, None, None, 36864 ['block4b_add[0][0]'] 384) block4c_expand_bn (BatchNormal (None, None, None, 1536 ['block4c_expand_conv[0][0]'] ization) 384) block4c_expand_activation (Act (None, None, None, 0 ['block4c_expand_bn[0][0]'] ivation) 384) block4c_dwconv2 (DepthwiseConv (None, None, None, 3456 ['block4c_expand_activation[0][0] 2D) 384) '] block4c_bn (BatchNormalization (None, None, None, 1536 ['block4c_dwconv2[0][0]'] ) 384) block4c_activation (Activation (None, None, None, 0 ['block4c_bn[0][0]'] ) 384) block4c_se_squeeze (GlobalAver (None, 384) 0 ['block4c_activation[0][0]'] agePooling2D) block4c_se_reshape (Reshape) (None, 1, 1, 384) 0 ['block4c_se_squeeze[0][0]'] block4c_se_reduce (Conv2D) (None, 1, 1, 24) 9240 ['block4c_se_reshape[0][0]'] block4c_se_expand (Conv2D) (None, 1, 1, 384) 9600 ['block4c_se_reduce[0][0]'] block4c_se_excite (Multiply) (None, None, None, 0 ['block4c_activation[0][0]', 384) 'block4c_se_expand[0][0]'] block4c_project_conv (Conv2D) (None, None, None, 36864 ['block4c_se_excite[0][0]'] 96) block4c_project_bn (BatchNorma (None, None, None, 384 ['block4c_project_conv[0][0]'] lization) 96) block4c_drop (Dropout) (None, None, None, 0 ['block4c_project_bn[0][0]'] 96) block4c_add (Add) (None, None, None, 0 ['block4c_drop[0][0]', 96) 'block4b_add[0][0]'] block5a_expand_conv (Conv2D) (None, None, None, 55296 ['block4c_add[0][0]'] 576) 
block5a_expand_bn (BatchNormal (None, None, None, 2304 ['block5a_expand_conv[0][0]'] ization) 576) block5a_expand_activation (Act (None, None, None, 0 ['block5a_expand_bn[0][0]'] ivation) 576) block5a_dwconv2 (DepthwiseConv (None, None, None, 5184 ['block5a_expand_activation[0][0] 2D) 576) '] block5a_bn (BatchNormalization (None, None, None, 2304 ['block5a_dwconv2[0][0]'] ) 576) block5a_activation (Activation (None, None, None, 0 ['block5a_bn[0][0]'] ) 576) block5a_se_squeeze (GlobalAver (None, 576) 0 ['block5a_activation[0][0]'] agePooling2D) block5a_se_reshape (Reshape) (None, 1, 1, 576) 0 ['block5a_se_squeeze[0][0]'] block5a_se_reduce (Conv2D) (None, 1, 1, 24) 13848 ['block5a_se_reshape[0][0]'] block5a_se_expand (Conv2D) (None, 1, 1, 576) 14400 ['block5a_se_reduce[0][0]'] block5a_se_excite (Multiply) (None, None, None, 0 ['block5a_activation[0][0]', 576) 'block5a_se_expand[0][0]'] block5a_project_conv (Conv2D) (None, None, None, 64512 ['block5a_se_excite[0][0]'] 112) block5a_project_bn (BatchNorma (None, None, None, 448 ['block5a_project_conv[0][0]'] lization) 112) block5b_expand_conv (Conv2D) (None, None, None, 75264 ['block5a_project_bn[0][0]'] 672) block5b_expand_bn (BatchNormal (None, None, None, 2688 ['block5b_expand_conv[0][0]'] ization) 672) block5b_expand_activation (Act (None, None, None, 0 ['block5b_expand_bn[0][0]'] ivation) 672) block5b_dwconv2 (DepthwiseConv (None, None, None, 6048 ['block5b_expand_activation[0][0] 2D) 672) '] block5b_bn (BatchNormalization (None, None, None, 2688 ['block5b_dwconv2[0][0]'] ) 672) block5b_activation (Activation (None, None, None, 0 ['block5b_bn[0][0]'] ) 672) block5b_se_squeeze (GlobalAver (None, 672) 0 ['block5b_activation[0][0]'] agePooling2D) block5b_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5b_se_squeeze[0][0]'] block5b_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5b_se_reshape[0][0]'] block5b_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5b_se_reduce[0][0]'] block5b_se_excite (Multiply) (None, None, None, 0 ['block5b_activation[0][0]', 672) 'block5b_se_expand[0][0]'] block5b_project_conv (Conv2D) (None, None, None, 75264 ['block5b_se_excite[0][0]'] 112) block5b_project_bn (BatchNorma (None, None, None, 448 ['block5b_project_conv[0][0]'] lization) 112) block5b_drop (Dropout) (None, None, None, 0 ['block5b_project_bn[0][0]'] 112) block5b_add (Add) (None, None, None, 0 ['block5b_drop[0][0]', 112) 'block5a_project_bn[0][0]'] block5c_expand_conv (Conv2D) (None, None, None, 75264 ['block5b_add[0][0]'] 672) block5c_expand_bn (BatchNormal (None, None, None, 2688 ['block5c_expand_conv[0][0]'] ization) 672) block5c_expand_activation (Act (None, None, None, 0 ['block5c_expand_bn[0][0]'] ivation) 672) block5c_dwconv2 (DepthwiseConv (None, None, None, 6048 ['block5c_expand_activation[0][0] 2D) 672) '] block5c_bn (BatchNormalization (None, None, None, 2688 ['block5c_dwconv2[0][0]'] ) 672) block5c_activation (Activation (None, None, None, 0 ['block5c_bn[0][0]'] ) 672) block5c_se_squeeze (GlobalAver (None, 672) 0 ['block5c_activation[0][0]'] agePooling2D) block5c_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5c_se_squeeze[0][0]'] block5c_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5c_se_reshape[0][0]'] block5c_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5c_se_reduce[0][0]'] block5c_se_excite (Multiply) (None, None, None, 0 ['block5c_activation[0][0]', 672) 'block5c_se_expand[0][0]'] block5c_project_conv (Conv2D) (None, None, None, 75264 ['block5c_se_excite[0][0]'] 112) block5c_project_bn (BatchNorma (None, None, None, 448 
['block5c_project_conv[0][0]'] lization) 112) block5c_drop (Dropout) (None, None, None, 0 ['block5c_project_bn[0][0]'] 112) block5c_add (Add) (None, None, None, 0 ['block5c_drop[0][0]', 112) 'block5b_add[0][0]'] block5d_expand_conv (Conv2D) (None, None, None, 75264 ['block5c_add[0][0]'] 672) block5d_expand_bn (BatchNormal (None, None, None, 2688 ['block5d_expand_conv[0][0]'] ization) 672) block5d_expand_activation (Act (None, None, None, 0 ['block5d_expand_bn[0][0]'] ivation) 672) block5d_dwconv2 (DepthwiseConv (None, None, None, 6048 ['block5d_expand_activation[0][0] 2D) 672) '] block5d_bn (BatchNormalization (None, None, None, 2688 ['block5d_dwconv2[0][0]'] ) 672) block5d_activation (Activation (None, None, None, 0 ['block5d_bn[0][0]'] ) 672) block5d_se_squeeze (GlobalAver (None, 672) 0 ['block5d_activation[0][0]'] agePooling2D) block5d_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5d_se_squeeze[0][0]'] block5d_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5d_se_reshape[0][0]'] block5d_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5d_se_reduce[0][0]'] block5d_se_excite (Multiply) (None, None, None, 0 ['block5d_activation[0][0]', 672) 'block5d_se_expand[0][0]'] block5d_project_conv (Conv2D) (None, None, None, 75264 ['block5d_se_excite[0][0]'] 112) block5d_project_bn (BatchNorma (None, None, None, 448 ['block5d_project_conv[0][0]'] lization) 112) block5d_drop (Dropout) (None, None, None, 0 ['block5d_project_bn[0][0]'] 112) block5d_add (Add) (None, None, None, 0 ['block5d_drop[0][0]', 112) 'block5c_add[0][0]'] block5e_expand_conv (Conv2D) (None, None, None, 75264 ['block5d_add[0][0]'] 672) block5e_expand_bn (BatchNormal (None, None, None, 2688 ['block5e_expand_conv[0][0]'] ization) 672) block5e_expand_activation (Act (None, None, None, 0 ['block5e_expand_bn[0][0]'] ivation) 672) block5e_dwconv2 (DepthwiseConv (None, None, None, 6048 ['block5e_expand_activation[0][0] 2D) 672) '] block5e_bn (BatchNormalization (None, None, None, 2688 ['block5e_dwconv2[0][0]'] ) 672) block5e_activation (Activation (None, None, None, 0 ['block5e_bn[0][0]'] ) 672) block5e_se_squeeze (GlobalAver (None, 672) 0 ['block5e_activation[0][0]'] agePooling2D) block5e_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5e_se_squeeze[0][0]'] block5e_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5e_se_reshape[0][0]'] block5e_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5e_se_reduce[0][0]'] block5e_se_excite (Multiply) (None, None, None, 0 ['block5e_activation[0][0]', 672) 'block5e_se_expand[0][0]'] block5e_project_conv (Conv2D) (None, None, None, 75264 ['block5e_se_excite[0][0]'] 112) block5e_project_bn (BatchNorma (None, None, None, 448 ['block5e_project_conv[0][0]'] lization) 112) block5e_drop (Dropout) (None, None, None, 0 ['block5e_project_bn[0][0]'] 112) block5e_add (Add) (None, None, None, 0 ['block5e_drop[0][0]', 112) 'block5d_add[0][0]'] block6a_expand_conv (Conv2D) (None, None, None, 75264 ['block5e_add[0][0]'] 672) block6a_expand_bn (BatchNormal (None, None, None, 2688 ['block6a_expand_conv[0][0]'] ization) 672) block6a_expand_activation (Act (None, None, None, 0 ['block6a_expand_bn[0][0]'] ivation) 672) block6a_dwconv2 (DepthwiseConv (None, None, None, 6048 ['block6a_expand_activation[0][0] 2D) 672) '] block6a_bn (BatchNormalization (None, None, None, 2688 ['block6a_dwconv2[0][0]'] ) 672) block6a_activation (Activation (None, None, None, 0 ['block6a_bn[0][0]'] ) 672) block6a_se_squeeze (GlobalAver (None, 672) 0 ['block6a_activation[0][0]'] agePooling2D) block6a_se_reshape (Reshape) (None, 1, 1, 672) 
 ... (blocks 6a through 6h repeat the same pattern: expand conv -> batch norm -> depthwise conv -> squeeze-and-excite -> project conv, all frozen) ...
 top_conv (Conv2D)              (None, None, None, 1280)   245760    ['block6h_add[0][0]']
 top_bn (BatchNormalization)    (None, None, None, 1280)   5120      ['top_conv[0][0]']
 top_activation (Activation)    (None, None, None, 1280)   0         ['top_bn[0][0]']
==================================================================================================
Total params: 5,919,312
Trainable params: 0
Non-trainable params: 5,919,312
__________________________________________________________________________________________________
You can see how each of the different layers has a certain number of parameters. Since we're using a pre-trained model, you can think of all of these parameters as patterns the base model has learned on another dataset. And because we set base_model.trainable = False, these patterns remain as they are during training (they're frozen and don't get updated).
Alright that was the base model, let's see the summary of our overall model.
# Check summary of model constructed with Functional API
model_0.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_layer (InputLayer) [(None, 224, 224, 3)] 0 efficientnetv2-b0 (Function (None, None, None, 1280) 5919312 al) global_average_pooling_laye (None, 1280) 0 r (GlobalAveragePooling2D) output_layer (Dense) (None, 10) 12810 ================================================================= Total params: 5,932,122 Trainable params: 12,810 Non-trainable params: 5,919,312 _________________________________________________________________
Our overall model has four layers, but really, one of those layers (efficientnetv2-b0) has 269 layers of its own.
You can see how the output shape started out as (None, 224, 224, 3)
for the input layer (the shape of our images) but was transformed to be (None, 10)
by the output layer (the shape of our labels), where None
is the placeholder for the batch size.
Notice, too, that the only trainable parameters in the model are those in the output layer.
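To verify this for yourself, you could inspect each layer's trainable attribute and count the model's trainable variables. The following is a quick sketch (using model_0 as constructed above); the output layer's weights and biases should be the only trainable variables.
# Check which layers of model_0 are trainable
for layer_number, layer in enumerate(model_0.layers):
  print(layer_number, layer.name, layer.trainable)

# Count the trainable variables (should be 2: the output layer's kernel and bias)
print(f"Number of trainable variables: {len(model_0.trainable_variables)}")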
How do our model's training curves look?
# Check out our model's training curves
plot_loss_curves(history_10_percent)
Getting a feature vector from a trained model¶
🤔 Question: What happens with the
tf.keras.layers.GlobalAveragePooling2D()
layer? I haven't seen it before.
The tf.keras.layers.GlobalAveragePooling2D()
layer transforms a 4D tensor into a 2D tensor by averaging the values across the inner axes (the height and width dimensions).
The previous sentence is a bit of a mouthful, so let's see an example.
# Define input tensor shape (same number of dimensions as the output of efficientnetv2-b0)
input_shape = (1, 4, 4, 3)
# Create a random tensor
tf.random.set_seed(42)
input_tensor = tf.random.normal(input_shape)
print(f"Random input tensor:\n {input_tensor}\n")
# Pass the random tensor through a global average pooling 2D layer
global_average_pooled_tensor = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
print(f"2D global average pooled random tensor:\n {global_average_pooled_tensor}\n")
# Check the shapes of the different tensors
print(f"Shape of input tensor: {input_tensor.shape}")
print(f"Shape of 2D global averaged pooled input tensor: {global_average_pooled_tensor.shape}")
Random input tensor: [[[[ 0.3274685 -0.8426258 0.3194337 ] [-1.4075519 -2.3880599 -1.0392479 ] [-0.5573232 0.539707 1.6994323 ] [ 0.28893656 -1.5066116 -0.2645474 ]] [[-0.59722406 -1.9171132 -0.62044144] [ 0.8504023 -0.40604794 -3.0258412 ] [ 0.9058464 0.29855987 -0.22561555] [-0.7616443 -1.8917141 -0.93847126]] [[ 0.77852213 -0.47338897 0.97772694] [ 0.24694404 0.20573747 -0.5256233 ] [ 0.32410017 0.02545409 -0.10638497] [-0.6369475 1.1603122 0.2507359 ]] [[-0.41728503 0.4012578 -1.4145443 ] [-0.5931857 -1.6617213 0.33567193] [ 0.10815629 0.23479682 -0.56668764] [-0.35819843 0.88698614 0.52744764]]]] 2D global average pooled random tensor: [[-0.09368646 -0.45840448 -0.2885598 ]] Shape of input tensor: (1, 4, 4, 3) Shape of 2D global averaged pooled input tensor: (1, 3)
You can see the tf.keras.layers.GlobalAveragePooling2D()
layer condensed the input tensor from shape (1, 4, 4, 3)
to (1, 3)
. It did so by averaging the input_tensor
across the middle two axes.
We can replicate this operation using the tf.reduce_mean()
operation and specifying the appropriate axes.
# This is the same as GlobalAveragePooling2D()
tf.reduce_mean(input_tensor, axis=[1, 2]) # average across the middle axes
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-0.09368646, -0.45840448, -0.2885598 ]], dtype=float32)>
Doing this not only makes the output of the base model compatible with the input shape requirement of our output layer (tf.keras.layers.Dense()
), it also condenses the information found by the base model into a lower-dimensional feature vector.
🔑 Note: One reason feature extraction transfer learning has its name is that the pretrained model outputs a feature vector (a long tensor of numbers, in our case the output of the tf.keras.layers.GlobalAveragePooling2D() layer), which subsequent layers can then use to extract patterns relevant to our own problem.
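To see what this feature vector looks like for our actual model, we could build a small feature extractor that stops at the pooling layer. The snippet below is a minimal sketch: the dummy batch is made up purely to check shapes, and the layer name global_average_pooling_layer comes from the model summary above.
# Build a small feature extractor that stops at the global average pooling layer
feature_extractor = tf.keras.Model(inputs=model_0.inputs,
                                   outputs=model_0.get_layer("global_average_pooling_layer").output)

# Pass a dummy batch of the right shape through and inspect the feature vector
dummy_batch = tf.random.normal(shape=(1, 224, 224, 3))
feature_vector = feature_extractor(dummy_batch)
print(f"Feature vector shape: {feature_vector.shape}") # (1, 1280) -> one 1280-long vector per image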
🛠 Practice: Do the same as the above cell but for
tf.keras.layers.GlobalMaxPool2D()
.
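If you'd like to check your answer afterwards, one possible solution looks like this (reusing the input_tensor from above): GlobalMaxPool2D() takes the maximum rather than the mean across the middle axes.
# Pass the same random tensor through a global max pooling 2D layer
global_max_pooled_tensor = tf.keras.layers.GlobalMaxPool2D()(input_tensor)
print(f"2D global max pooled random tensor:\n {global_max_pooled_tensor}\n")

# Replicate the operation with tf.reduce_max across the middle axes
print(tf.reduce_max(input_tensor, axis=[1, 2]))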
Running a series of transfer learning experiments¶
We've seen the incredible results of transfer learning on 10% of the training data, but what about 1% of the training data?
What kind of results do you think we can get using 100x less data than the original CNN models we built ourselves?
Why don't we answer that question while running the following modelling experiments:
- Model 1: Use feature extraction transfer learning on 1% of the training data with data augmentation.
- Model 2: Use feature extraction transfer learning on 10% of the training data with data augmentation and save the results to a checkpoint.
- Model 3: Fine-tune the Model 2 checkpoint on 10% of the training data with data augmentation.
- Model 4: Fine-tune the Model 2 checkpoint on 100% of the training data with data augmentation.
While all of the experiments will be run on different versions of the training data, they will all be evaluated on the same test dataset; this ensures the results of each experiment are as comparable as possible.
All experiments will be done using the EfficientNetV2B0
model within the tf.keras.applications.efficientnet_v2
module.
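For reference, creating the frozen base model for these experiments looks something like the following sketch (the full model-building code appears with each experiment below).
# Create a frozen EfficientNetV2B0 base model (no top layers, we'll add our own classifier)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = False # freeze the pretrained patterns for feature extraction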
To make sure we're keeping track of our experiments, we'll use our create_tensorboard_callback()
function to log all of the model training logs.
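The create_tensorboard_callback() function comes from the course's helper functions script downloaded earlier in the notebook. If you're writing everything yourself, a minimal sketch of such a function might look like this (the course version may differ slightly):
import datetime

def create_tensorboard_callback(dir_name, experiment_name):
  # Log to a timestamped directory so each experiment gets its own set of logs
  log_dir = dir_name + "/" + experiment_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
  print(f"Saving TensorBoard log files to: {log_dir}")
  return tensorboard_callback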
We'll construct each model using the Keras Functional API and instead of implementing data augmentation in the ImageDataGenerator
class as we have previously, we're going to build it right into the model using the tf.keras.layers
module.
Let's begin by downloading the data for experiment 1, using feature extraction transfer learning on 1% of the training data with data augmentation.
# Download and unzip data
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip
unzip_data("10_food_classes_1_percent.zip")
# Create training and test dirs
train_dir_1_percent = "10_food_classes_1_percent/train/"
test_dir = "10_food_classes_1_percent/test/"
--2023-08-18 01:40:38-- https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.20.128, 108.177.98.128, 74.125.197.128, ... Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.20.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 133612354 (127M) [application/zip] Saving to: ‘10_food_classes_1_percent.zip’ 10_food_classes_1_p 100%[===================>] 127.42M 222MB/s in 0.6s 2023-08-18 01:40:39 (222 MB/s) - ‘10_food_classes_1_percent.zip’ saved [133612354/133612354]
How many images are we working with?
# Walk through 1 percent data directory and list number of files
walk_through_dir("10_food_classes_1_percent")
There are 2 directories and 0 images in '10_food_classes_1_percent'. There are 10 directories and 0 images in '10_food_classes_1_percent/train'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/ramen'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_curry'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/pizza'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/ice_cream'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/grilled_salmon'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/steak'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_wings'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/hamburger'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/sushi'. There are 0 directories and 7 images in '10_food_classes_1_percent/train/fried_rice'. There are 10 directories and 0 images in '10_food_classes_1_percent/test'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/ramen'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_curry'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/pizza'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/ice_cream'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/steak'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_wings'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/hamburger'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/sushi'. There are 0 directories and 250 images in '10_food_classes_1_percent/test/fried_rice'.
Alright, looks like we've only got seven images of each class; this should be a bit of a challenge for our model.
🔑 Note: As with the 10% of data subset, the 1% of images were chosen at random from the original full training dataset. The test images are the same as the ones which have previously been used. If you want to see how this data was preprocessed, check out the Food Vision Image Preprocessing notebook.
Time to load our images in as tf.data.Dataset objects. To do so, we'll use the image_dataset_from_directory() method.
import tensorflow as tf
IMG_SIZE = (224, 224)
train_data_1_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_1_percent,
label_mode="categorical",
batch_size=32, # default
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode="categorical",
image_size=IMG_SIZE)
Found 70 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
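Before augmenting, we can double-check which classes our tf.data.Dataset objects picked up via the class_names attribute that image_dataset_from_directory() populates.
# Check the class names our dataset picked up from the directory structure
train_data_1_percent.class_names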
Data loaded. Time to augment it.
Adding data augmentation right into the model¶
Previously we've used the different parameters of the ImageDataGenerator
class to augment our training images, this time we're going to build data augmentation right into the model.
How?
Using the tf.keras.layers
module and creating a dedicated data augmentation layer.
This is a relatively new feature added in TensorFlow 2.10+ but it's very powerful. Adding a data augmentation layer to the model has the following benefits:
- Preprocessing of the images (augmenting them) happens on the GPU rather than on the CPU (much faster).
- Images are best preprocessed on the GPU, whereas text and structured data are more suited to being preprocessed on the CPU.
- Image data augmentation only happens during training so we can still export our whole model and use it elsewhere. And if someone else wanted to train the same model as us, including the same kind of data augmentation, they could.
Example of using data augmentation as the first layer within a model (EfficientNetB0).
📚 Resource: For more information on different methods of data augmentation, check out the TensorFlow data augmentation guide.
To use data augmentation right within our model, we'll create a Keras Sequential model consisting of only data preprocessing layers. We can then use this Sequential model within another Functional model.
If that sounds confusing, it'll make sense once we create it in code.
The data augmentation transformations we're going to use are:
- tf.keras.layers.RandomFlip - flips image on horizontal or vertical axis.
- tf.keras.layers.RandomRotation - randomly rotates image by a specified amount.
- tf.keras.layers.RandomZoom - randomly zooms into an image by a specified amount.
- tf.keras.layers.RandomHeight - randomly shifts image height by a specified amount.
- tf.keras.layers.RandomWidth - randomly shifts image width by a specified amount.
- tf.keras.layers.Rescaling - normalizes the image pixel values to be between 0 and 1. This is worth mentioning because it's required for some image models, but since we're using tf.keras.applications.efficientnet_v2.EfficientNetV2B0, it's not required (the pretrained model implements its own rescaling).
There are more options, but these will do for now.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# from tensorflow.keras.layers.experimental import preprocessing
# NEW: Newer versions of TensorFlow (2.10+) can use the tensorflow.keras.layers API directly for data augmentation
data_augmentation = keras.Sequential([
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetV2B0
], name ="data_augmentation")
# # UPDATE: Previous versions of TensorFlow (e.g. 2.4 and below) used the tensorflow.keras.layers.experimental.preprocessing API
# # Create a data augmentation stage with horizontal flipping, rotations, zooms
# data_augmentation = keras.Sequential([
# preprocessing.RandomFlip("horizontal"),
# preprocessing.RandomRotation(0.2),
# preprocessing.RandomZoom(0.2),
# preprocessing.RandomHeight(0.2),
# preprocessing.RandomWidth(0.2),
# # preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetV2B0
# ], name ="data_augmentation")
And that's it! Our data augmentation Sequential model is ready to go. In a moment, we'll slot this "model" in as a layer of our transfer learning model.
But before we do that, let's test it out by passing random images through it.
# View a random image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import random
target_class = random.choice(train_data_1_percent.class_names) # choose a random class
target_dir = "10_food_classes_1_percent/train/" + target_class # create the target directory
random_image = random.choice(os.listdir(target_dir)) # choose a random image from target directory
random_image_path = target_dir + "/" + random_image # create the chosen random image path
img = mpimg.imread(random_image_path) # read in the chosen target image
plt.imshow(img) # plot the target image
plt.title(f"Original random image from class: {target_class}")
plt.axis(False); # turn off the axes
# Augment the image
augmented_img = data_augmentation(tf.expand_dims(img, axis=0)) # data augmentation model requires shape (None, height, width, 3)
plt.figure()
plt.imshow(tf.squeeze(augmented_img)/255.) # scale back to [0, 1] for plotting (the augmented output is a float tensor in the 0-255 range)
plt.title(f"Augmented random image from class: {target_class}")
plt.axis(False);
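If the plotted "augmented" image ever looks identical to the original, it may be because the random augmentation layers only apply their transformations in training mode (at inference time they pass inputs through unchanged). Depending on your TensorFlow version, calling the Sequential model directly may run it in inference mode, so you can force augmentation by passing training=True explicitly.
# Force the augmentation layers to run in training mode (they're a no-op at inference time)
augmented_img = data_augmentation(tf.expand_dims(img, axis=0), training=True)
plt.figure()
plt.imshow(tf.squeeze(augmented_img)/255.)
plt.title(f"Augmented random image from class: {target_class} (training=True)")
plt.axis(False);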