06. Transfer Learning with TensorFlow Part 3: Scaling up (🍔👁 Food Vision mini)
In the previous two notebooks (transfer learning part 1: feature extraction and part 2: fine-tuning) we've seen the power of transfer learning.
Now that we know our smaller modelling experiments are working, it's time to step things up a notch with more data.
This is a common practice in machine learning and deep learning: get a model working on a small amount of data before scaling it up to a larger amount of data.
🔑 Note: You haven't forgotten the machine learning practitioner's motto, have you? "Experiment, experiment, experiment."
It's time to get closer to our Food Vision project coming to life. In this notebook we're going to scale up from using 10 classes of the Food101 data to using all of the classes in the Food101 dataset.
Our goal is to beat the original Food101 paper's results with 10% of the training data.
Machine learning practitioners are serial experimenters. Start small, get a model working, see if your experiments work, then gradually scale them up to where you want to go (we're going to be looking at scaling up throughout this notebook).
What we're going to cover
We're going to go through the following with TensorFlow:
- Downloading and preparing 10% of the Food101 training data
- Training a feature extraction transfer learning model on 10% of the Food101 training data
- Fine-tuning our feature extraction model
- Saving and loading our trained model
- Evaluating the performance of our Food Vision model trained on 10% of the training data
- Finding our model's most wrong predictions
- Making predictions with our Food Vision model on custom images of food
How you can use this notebook
You can read through the descriptions and the code (it should all run, except for the cells which error on purpose), but there's a better option.
Write all of the code yourself.
Yes. I'm serious. Create a new notebook, and rewrite each line by yourself. Investigate it. See if you can break it. Why does it break?
You don't have to write the text descriptions but writing the code yourself is a great way to get hands-on experience.
Don't worry if you make mistakes, we all do. The way to get better and make fewer mistakes is to write more code.
📖 Resources:
- See the full set of course materials on GitHub: https://github.com/mrdbourke/tensorflow-deep-learning
- See updates for this notebook here: https://github.com/mrdbourke/tensorflow-deep-learning/discussions/549
# Are we using a GPU?
# If not, and you're in Google Colab, go to Runtime -> Change runtime type -> Hardware accelerator -> GPU
!nvidia-smi
Thu May 18 02:18:58 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P0    49W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
import datetime
print(f"Notebook last run (end-to-end): {datetime.datetime.now()}")
Notebook last run (end-to-end): 2023-05-18 02:18:58.717388
Creating helper functions
We've created a series of helper functions throughout the previous notebooks. Instead of rewriting them (tedious), we'll import the helper_functions.py file from the GitHub repo.
# Get helper functions file
!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
--2023-05-18 02:18:58-- https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 10246 (10K) [text/plain] Saving to: ‘helper_functions.py’ helper_functions.py 100%[===================>] 10.01K --.-KB/s in 0s 2023-05-18 02:18:58 (109 MB/s) - ‘helper_functions.py’ saved [10246/10246]
# Import series of helper functions for the notebook (we've created/used these in previous notebooks)
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys, walk_through_dir
101 Food Classes: Working with less data
So far we've confirmed the transfer learning models we've been using work pretty well with the 10 Food Classes dataset. Now it's time to step it up and see how they go with the full 101 Food Classes.
In the original Food101 dataset there are 1000 images per class (750 of each class in the training set and 250 of each class in the test set), totalling 101,000 images.
We could start modelling straight away on this large dataset, but in the spirit of continual experimentation, we're going to see how our previously working models go with 10% of the training data.
This means for each of the 101 food classes we'll be building a model on 75 training images and evaluating it on 250 test images.
Downloading and preprocessing the data
Just as before we'll download a subset of the Food101 dataset which has been extracted from the original dataset (to see the preprocessing of the data check out the Food Vision preprocessing notebook).
We download the data as a zip file so we'll use our unzip_data() function to unzip it.
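unzip_data() lives in helper_functions.py, so its body isn't shown here. Assuming it is a thin wrapper around Python's built-in zipfile module, a minimal equivalent might look like the sketch below. The name unzip_data_sketch and its extract_to parameter are hypothetical. Check helper_functions.py for the real implementation.

```python
import zipfile
from pathlib import Path

def unzip_data_sketch(filename, extract_to="."):
    """Extract every member of the zip archive `filename` into `extract_to`.

    Hypothetical stand-in for helper_functions.unzip_data(); the real helper may differ.
    """
    with zipfile.ZipFile(filename, "r") as zip_ref:  # context manager closes the archive for us
        zip_ref.extractall(extract_to)                 # creates any missing folders as needed
    return Path(extract_to)

# Usage (assumes the zip file from the cell above has already been downloaded):
# unzip_data_sketch("101_food_classes_10_percent.zip")
```

Using zipfile.ZipFile as a context manager closes the archive even if extraction fails partway through.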
# Download data from Google Storage (already preformatted)
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip
unzip_data("101_food_classes_10_percent.zip")
train_dir = "101_food_classes_10_percent/train/"
test_dir = "101_food_classes_10_percent/test/"
--2023-05-18 02:19:02-- https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.141.128, 142.251.2.128, 2607:f8b0:4023:c03::80, ... Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.141.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 1625420029 (1.5G) [application/zip] Saving to: ‘101_food_classes_10_percent.zip’ 101_food_classes_10 100%[===================>] 1.51G 196MB/s in 7.9s 2023-05-18 02:19:10 (197 MB/s) - ‘101_food_classes_10_percent.zip’ saved [1625420029/1625420029]
# How many images/classes are there?
walk_through_dir("101_food_classes_10_percent")
There are 2 directories and 0 images in '101_food_classes_10_percent'. There are 101 directories and 0 images in '101_food_classes_10_percent/train'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/dumplings'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/ceviche'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheese_plate'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/spring_rolls'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_dog'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/tacos'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/cup_cakes'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/samosa'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/gnocchi'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pad_thai'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_fries'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_carbonara'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/huevos_rancheros'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/frozen_yogurt'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/gyoza'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_bolognese'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/eggs_benedict'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/miso_soup'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/hummus'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/lasagna'. 
There are 0 directories and 75 images in '101_food_classes_10_percent/train/mussels'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/ice_cream'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/foie_gras'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_mousse'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/beet_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/poutine'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/garlic_bread'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/waffles'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/peking_duck'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/caprese_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_calamari'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/fish_and_chips'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/macaroni_and_cheese'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/crab_cakes'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_onion_soup'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/caesar_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/baklava'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pizza'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/escargots'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/carrot_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_bisque'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/beignets'. 
There are 0 directories and 75 images in '101_food_classes_10_percent/train/apple_pie'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/donuts'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/prime_rib'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheesecake'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/scallops'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/bruschetta'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pulled_pork_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pork_chop'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/croque_madame'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/sushi'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/paella'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/ramen'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/edamame'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/macarons'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/deviled_eggs'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/oysters'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_wings'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_and_sour_soup'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/onion_rings'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/churros'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_toast'. 
There are 0 directories and 75 images in '101_food_classes_10_percent/train/risotto'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_curry'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/club_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_quesadilla'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/hamburger'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/steak'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_carpaccio'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/tuna_tartare'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/greek_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/omelette'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/shrimp_and_grits'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/panna_cotta'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_cheese_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/clam_chowder'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/sashimi'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/cannoli'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/creme_brulee'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/seaweed_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/strawberry_shortcake'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/guacamole'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/breakfast_burrito'. 
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ravioli'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pho'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/takoyaki'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_salmon'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/pancakes'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/falafel'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/filet_mignon'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_tartare'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_roll_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_rice'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/bread_pudding'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/red_velvet_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/nachos'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/tiramisu'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/baby_back_ribs'. There are 0 directories and 75 images in '101_food_classes_10_percent/train/bibimbap'. There are 101 directories and 0 images in '101_food_classes_10_percent/test'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/dumplings'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/ceviche'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheese_plate'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/spring_rolls'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_dog'. 
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tacos'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/cup_cakes'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/samosa'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/gnocchi'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pad_thai'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_fries'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_carbonara'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/huevos_rancheros'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/frozen_yogurt'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/gyoza'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_bolognese'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/eggs_benedict'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/miso_soup'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/hummus'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/lasagna'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/mussels'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/ice_cream'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/foie_gras'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_mousse'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/beet_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/poutine'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/garlic_bread'. 
There are 0 directories and 250 images in '101_food_classes_10_percent/test/waffles'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/peking_duck'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/caprese_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_calamari'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/fish_and_chips'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/macaroni_and_cheese'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/crab_cakes'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_onion_soup'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/caesar_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/baklava'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pizza'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/escargots'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/carrot_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_bisque'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/beignets'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/apple_pie'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/donuts'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/prime_rib'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheesecake'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/scallops'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/bruschetta'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pulled_pork_sandwich'. 
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pork_chop'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/croque_madame'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/sushi'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/paella'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/ramen'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/edamame'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/macarons'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/deviled_eggs'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/oysters'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_wings'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_and_sour_soup'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/onion_rings'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/churros'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_toast'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/risotto'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_curry'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/club_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_quesadilla'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/hamburger'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/steak'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_carpaccio'. 
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tuna_tartare'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/greek_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/omelette'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/shrimp_and_grits'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/panna_cotta'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_cheese_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/clam_chowder'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/sashimi'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/cannoli'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/creme_brulee'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/seaweed_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/strawberry_shortcake'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/guacamole'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/breakfast_burrito'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/ravioli'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pho'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/takoyaki'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_salmon'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/pancakes'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/falafel'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/filet_mignon'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_tartare'. 
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_roll_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_rice'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/bread_pudding'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/red_velvet_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/nachos'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/tiramisu'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/baby_back_ribs'. There are 0 directories and 250 images in '101_food_classes_10_percent/test/bibimbap'.
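walk_through_dir() also comes from helper_functions.py. Its output above reads like a top-down walk of the directory tree. A plausible re-implementation is sketched below, assuming it is built on os.walk (walk_through_dir_sketch is a hypothetical name):

```python
import os

def walk_through_dir_sketch(dir_path):
    """Print (and return) how many subdirectories and files sit in each directory under dir_path.

    Hypothetical re-implementation of helper_functions.walk_through_dir() using os.walk.
    Note: it counts every file, whether or not it is actually an image.
    """
    lines = []
    for dirpath, dirnames, filenames in os.walk(dir_path):  # top-down: dir_path itself comes first
        line = f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'."
        print(line)
        lines.append(line)
    return lines

# Usage:
# walk_through_dir_sketch("101_food_classes_10_percent")
```

Because os.walk() yields the top directory first and then descends into each subdirectory, the first line always describes the folder you pass in. The order of sibling folders (e.g. dumplings before ceviche above) follows the filesystem's listing order, not alphabetical order.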
As before, our data comes in the standard image classification data format:
Example of file structure
10_food_classes_10_percent <- top level folder
└───train <- training images
│   └───pizza
│   │   │   1008104.jpg
│   │   │   1638227.jpg
│   │   │   ...
│   └───steak
│       │   1000205.jpg
│       │   1647351.jpg
│       │   ...
│
└───test <- testing images
│   └───pizza
│   │   │   1001116.jpg
│   │   │   1507019.jpg
│   │   │   ...
│   └───steak
│       │   100274.jpg
│       │   1653815.jpg
│       │   ...
Let's use the image_dataset_from_directory() function to turn our images and labels into a tf.data.Dataset, a TensorFlow datatype which allows us to pass it directly to our model.
For the test dataset, we're going to set shuffle=False so we can perform repeatable evaluation and visualization on it later.
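In the cell below, image_dataset_from_directory() infers each image's label from the name of the subdirectory it sits in. Class names are ordered alphanumerically (unless you pass class_names explicitly), and label_mode="categorical" turns each label into a one-hot vector. The pure-Python sketch below mimics that mapping for a handful of the class folders listed earlier. infer_class_names and one_hot are hypothetical helpers for illustration, not TensorFlow functions.

```python
def infer_class_names(subdirectory_names):
    """Return class names sorted alphanumerically, mirroring how inferred labels are ordered."""
    return sorted(subdirectory_names)

def one_hot(label_index, num_classes):
    """Encode an integer label as a one-hot list (the idea behind label_mode="categorical")."""
    vector = [0.0] * num_classes
    vector[label_index] = 1.0
    return vector

# A few class folders from the walk_through_dir() output, in the (unsorted) order they were listed
folders = ["dumplings", "ceviche", "cheese_plate", "apple_pie", "bibimbap"]
class_names = infer_class_names(folders)
print(class_names)  # ['apple_pie', 'bibimbap', 'ceviche', 'cheese_plate', 'dumplings']
print(one_hot(class_names.index("ceviche"), len(class_names)))  # [0.0, 0.0, 1.0, 0.0, 0.0]
```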
# Setup data inputs
import tensorflow as tf
IMG_SIZE = (224, 224)
train_data_all_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
label_mode="categorical",
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode="categorical",
image_size=IMG_SIZE,
shuffle=False) # don't shuffle test data for prediction analysis
Found 7575 files belonging to 101 classes. Found 25250 files belonging to 101 classes.
Wonderful! It looks like our data has been imported as expected with 75 images per class in the training set (75 images * 101 classes = 7575 images) and 25250 images in the test set (250 images * 101 classes = 25250 images).
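The cell above doesn't set batch_size, so image_dataset_from_directory() falls back to its default of 32 images per batch. The plain-Python arithmetic below checks the image counts and works out how many batches each dataset holds. The training batch count should match the 237 steps per epoch in the training log further down. The last line computes the validation_steps value used later to evaluate on 15% of the test data.

```python
import math

NUM_CLASSES = 101
TRAIN_IMAGES_PER_CLASS = 75   # 10% of the 750 training images per class
TEST_IMAGES_PER_CLASS = 250   # the full test set per class
BATCH_SIZE = 32               # default batch_size of image_dataset_from_directory()

num_train_images = NUM_CLASSES * TRAIN_IMAGES_PER_CLASS   # 7575
num_test_images = NUM_CLASSES * TEST_IMAGES_PER_CLASS     # 25250

# The last batch can be smaller than BATCH_SIZE, hence rounding up
train_batches = math.ceil(num_train_images / BATCH_SIZE)  # 237
test_batches = math.ceil(num_test_images / BATCH_SIZE)    # 790

# validation_steps=int(0.15 * len(test_data)) evaluates on roughly 15% of the test batches
validation_steps = int(0.15 * test_batches)               # 118

print(num_train_images, num_test_images, train_batches, test_batches, validation_steps)
```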
Train a big dog model with transfer learning on 10% of 101 food classes
Our food image data has been imported into TensorFlow, time to model it.
To keep our experiments swift, we're going to start by using feature extraction transfer learning with a pre-trained model for a few epochs and then fine-tune for a few more epochs.
More specifically, our goal will be to see if we can beat the baseline from the original Food101 paper (50.76% accuracy on 101 classes) with 10% of the training data and the following modelling setup:
- A ModelCheckpoint callback to save our progress during training, meaning we can experiment with further training later without having to train from scratch every time
- Data augmentation built right into the model
- A headless (no top layers) EfficientNetB0 architecture from tf.keras.applications as our base model
- A Dense layer with 101 output neurons (same as the number of food classes) and softmax activation as the output layer
- Categorical crossentropy as the loss function since we're dealing with more than two classes
- The Adam optimizer with the default settings
- Fitting for 5 full passes on the training data while evaluating on 15% of the test data
It seems like a lot but these are all things we've covered before in the Transfer Learning in TensorFlow Part 2: Fine-tuning notebook.
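Two pieces of that setup, the softmax output and the categorical crossentropy loss, are easy to sketch in NumPy. This is an illustration of the formulas rather than Keras's implementation, and the epsilon clipping value is an assumption. It also gives a handy sanity check: a model that spreads its probability evenly across all 101 classes scores a loss of ln(101) ≈ 4.615. Training losses well below that suggest the model is learning something.

```python
import numpy as np

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 along the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum(y_true * log(y_pred)); y_true is one-hot."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

num_classes = 101
y_true = np.eye(num_classes)[[0, 5, 42]]      # three one-hot labels
uniform_logits = np.zeros((3, num_classes))   # equal scores -> equal probabilities
loss = categorical_crossentropy(y_true, softmax(uniform_logits))
print(round(loss, 3))  # 4.615, i.e. ln(101)
```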
Let's start by creating the ModelCheckpoint callback.
Since we want our model to perform well on unseen data we'll set it to monitor the validation accuracy metric and save the model weights which score the best on that.
# Create checkpoint callback to save model for later use
checkpoint_path = "101_classes_10_percent_data_model_checkpoint"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, # save only the model weights
monitor="val_accuracy", # save the model weights which score the best validation accuracy
save_best_only=True) # only write weights to file when val_accuracy improves on the best seen so far
Checkpoint ready. Now let's create a small data augmentation model with the Sequential API. Because we're working with a reduced-size training set, this will help reduce the risk of our model overfitting the training data.
# Import the required modules for model creation
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
## NEW: Newer versions of TensorFlow (2.10+) can use the tensorflow.keras.layers API directly for data augmentation
data_augmentation = Sequential([
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# layers.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetB0 (its Keras implementation rescales inputs internally)
], name="data_augmentation")
## OLD
# # Setup data augmentation
# from tensorflow.keras.layers.experimental import preprocessing
# data_augmentation = Sequential([
# preprocessing.RandomFlip("horizontal"), # randomly flip images on horizontal edge
# preprocessing.RandomRotation(0.2), # randomly rotate images by a specific amount
# preprocessing.RandomHeight(0.2), # randomly adjust the height of an image by a specific amount
# preprocessing.RandomWidth(0.2), # randomly adjust the width of an image by a specific amount
# preprocessing.RandomZoom(0.2), # randomly zoom into an image
# # preprocessing.Rescaling(1./255) # keep for models like ResNet50V2, remove for EfficientNet
# ], name="data_augmentation")
Beautiful! We'll be able to insert the data_augmentation Sequential model as a layer in our Functional API model. That way, if we want to continue training our model at a later time, the data augmentation is already built right in.
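Each random augmentation layer transforms every image independently during training and passes images through unchanged at inference time. The NumPy sketch below mimics that behaviour for a horizontal flip. random_horizontal_flip is a hypothetical helper written for illustration, not the way layers.RandomFlip is implemented internally.

```python
import numpy as np

def random_horizontal_flip(images, rng, training=True, p=0.5):
    """Flip each image in a (batch, height, width, channels) array left-to-right with probability p.

    Hypothetical illustration of what layers.RandomFlip("horizontal") does conceptually.
    """
    if not training:
        return images                                # inference: no augmentation
    flip_mask = rng.random(images.shape[0]) < p      # one independent coin toss per image
    flipped = images[:, :, ::-1, :]                  # reverse the width axis
    return np.where(flip_mask[:, None, None, None], flipped, images)

rng = np.random.default_rng(seed=42)
batch = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)  # 2 tiny 4x4 RGB "images"
augmented = random_horizontal_flip(batch, rng)
print(augmented.shape)  # (2, 4, 4, 3): a flip never changes the shape
print(np.array_equal(random_horizontal_flip(batch, rng, training=False), batch))  # True
```

Flips and rotations keep an image's shape, but RandomHeight and RandomWidth resize images by a random factor during training. That is why the model summary later in this notebook reports the data_augmentation output shape as (None, None, None, 3).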
Speaking of Functional API models, time to put together a feature extraction transfer learning model using tf.keras.applications.efficientnet.EfficientNetB0 as our base model.
We'll import the base model using the parameter include_top=False so we can add on our own output layers, notably GlobalAveragePooling2D() (condense the outputs of the base model into a shape usable by the output layer) followed by a Dense layer.
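GlobalAveragePooling2D takes the mean of each feature map over its height and width. For 224x224 inputs, the headless EfficientNetB0 produces a 7x7 grid of 1280 features per image, so pooling turns a (batch, 7, 7, 1280) tensor into (batch, 1280). The NumPy sketch below performs the same reduction on random numbers standing in for real features. The batch size of 4 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a batch of EfficientNetB0 feature maps: (batch, height, width, channels)
feature_maps = rng.random((4, 7, 7, 1280), dtype=np.float32)

# Global average pooling = mean over the two spatial axes (channels_last layout)
pooled = feature_maps.mean(axis=(1, 2))
print(pooled.shape)  # (4, 1280)
```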
# Setup base model and freeze its layers (this will extract features)
base_model = tf.keras.applications.efficientnet.EfficientNetB0(include_top=False)
base_model.trainable = False
# Setup model architecture with trainable top layers
inputs = layers.Input(shape=(224, 224, 3), name="input_layer") # shape of input image
x = data_augmentation(inputs) # augment images (only happens during training)
x = base_model(x, training=False) # put the base model in inference mode so we can use it to extract features without updating the weights
x = layers.GlobalAveragePooling2D(name="global_average_pooling")(x) # pool the outputs of the base model
outputs = layers.Dense(len(train_data_all_10_percent.class_names), activation="softmax", name="output_layer")(x) # same number of outputs as classes
model = tf.keras.Model(inputs, outputs)
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5 16705208/16705208 [==============================] - 0s 0us/step
Figure: the model we've created, with 224x224 images as input, data augmentation as a layer, EfficientNetB0 as the backbone, a global average pooling layer and a dense output layer with one neuron per class (the figure shows 10 output neurons; our model uses 101, one per food class).
Model created. Let's inspect it.
# Get a summary of our model
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                                      Output Shape               Param #
=================================================================
 input_layer (InputLayer)                          [(None, 224, 224, 3)]      0
 data_augmentation (Sequential)                    (None, None, None, 3)      0
 efficientnetb0 (Functional)                       (None, None, None, 1280)   4049571
 global_average_pooling (GlobalAveragePooling2D)   (None, 1280)               0
 output_layer (Dense)                              (None, 101)                129381
=================================================================
Total params: 4,178,952
Trainable params: 129,381
Non-trainable params: 4,049,571
_________________________________________________________________
Looking good! Our Functional model has 5 layers, but some of those layers (data_augmentation and efficientnetb0) contain many layers of their own.
Notice the number of trainable and non-trainable parameters. It seems the only trainable parameters are within the output_layer, which is exactly what we're after with this first run of feature extraction: keep all the learned patterns in the base model (EfficientNetB0) frozen whilst allowing the model to tune its outputs to our custom data.
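We can double-check these numbers by hand. A Dense layer has one weight for every input-output pair plus one bias per output, so with 1280 pooled features feeding 101 output neurons (both taken from the summary above):

```python
# Numbers taken from the model.summary() output above
base_model_params = 4_049_571   # EfficientNetB0 (frozen)
num_features = 1280             # output size of the global average pooling layer
num_classes = 101               # Food101 classes

# One weight per (input, output) pair plus one bias per output neuron
output_layer_params = num_features * num_classes + num_classes
total_params = base_model_params + output_layer_params

print(output_layer_params)  # 129381
print(total_params)         # 4178952
```

Both results match the "Trainable params" and "Total params" lines, confirming that only the output layer will be updated during feature extraction.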
Time to compile and fit.
# Compile
model.compile(loss="categorical_crossentropy",
optimizer=tf.keras.optimizers.Adam(), # use Adam with default settings
metrics=["accuracy"])
# Fit
history_all_classes_10_percent = model.fit(train_data_all_10_percent,
epochs=5, # fit for 5 epochs to keep experiments quick
validation_data=test_data,
validation_steps=int(0.15 * len(test_data)), # evaluate on smaller portion of test data
callbacks=[checkpoint_callback]) # save best model weights to file
Epoch 1/5
237/237 [==============================] - 29s 59ms/step - loss: 3.3881 - accuracy: 0.2796 - val_loss: 2.4697 - val_accuracy: 0.4661
Epoch 2/5
237/237 [==============================] - 12s 51ms/step - loss: 2.1996 - accuracy: 0.4937 - val_loss: 2.0252 - val_accuracy: 0.5188
Epoch 3/5
237/237 [==============================] - 12s 51ms/step - loss: 1.8245 - accuracy: 0.5675 - val_loss: 1.8783 - val_accuracy: 0.5344
Epoch 4/5
237/237 [==============================] - 12s 50ms/step - loss: 1.6061 - accuracy: 0.6055 - val_loss: 1.8206 - val_accuracy: 0.5355
Epoch 5/5
237/237 [==============================] - 11s 48ms/step - loss: 1.4493 - accuracy: 0.6440 - val_loss: 1.7824 - val_accuracy: 0.5310
Woah! It looks like our model is getting some impressive results, but remember, during training our model was only evaluated on 15% of the test data. Let's see how it did on the whole test dataset.
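How much is "15% of the test data"? A quick back-of-the-envelope calculation, assuming a batch size of 32 (the default for tf.keras.utils.image_dataset_from_directory, and consistent with the 790 test batches shown in the evaluation output below) and the 25,250-image Food101 test set:

```python
import math

num_test_images = 25250  # Food101 test set size
batch_size = 32          # assumption: default batch size of image_dataset_from_directory

num_test_batches = math.ceil(num_test_images / batch_size)  # what len(test_data) returns
validation_steps = int(0.15 * num_test_batches)             # as passed to model.fit()
images_validated = validation_steps * batch_size

print(num_test_batches, validation_steps, images_validated)  # 790 118 3776
```

So each validation pass only sees about 3,776 images. Since test_data isn't shuffled, those are likely the first ~15 classes' worth of images (the Food101 test set has 250 images per class), which is one reason the numbers seen during training can differ from a full evaluation.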
# Evaluate model
results_feature_extraction_model = model.evaluate(test_data)
results_feature_extraction_model
790/790 [==============================] - 16s 21ms/step - loss: 1.5888 - accuracy: 0.5797
[1.5887829065322876, 0.5797227621078491]
Well, it looks like we just beat our baseline (the original Food101 paper reported 50.76% accuracy) with 10% of the training data! In under 5 minutes... that's the power of deep learning, and more precisely, transfer learning: leveraging what one model has learned on another dataset for our own dataset.
How do the loss curves look?
plot_loss_curves(history_all_classes_10_percent)
🤔 Question: What do these curves suggest? Hint: ideally, the two curves should be very similar to each other; if not, there may be some overfitting or underfitting.
Fine-tuning
Our feature extraction transfer learning model is performing well. Why don't we try to fine-tune a few layers in the base model and see if we gain any improvements?
The good news is, thanks to the ModelCheckpoint callback, we've got the saved weights of our already well-performing model, so if fine-tuning doesn't add any benefits, we can revert to them.
To fine-tune the base model, we'll first set its trainable attribute to True, unfreezing all of its frozen layers.
Then, since we've got a relatively small training dataset, we'll refreeze every layer except for the last 5, leaving those trainable.
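The [:-5] slice does the work here. Below is a pure-Python sketch of the same freezing pattern, using a hypothetical FakeLayer class (a stand-in for Keras layers with just a name and a trainable flag) and 238 layers, the same count EfficientNetB0's base_model.layers lists in this notebook (indices 0 to 237):

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True  # everything starts unfrozen, like base_model.trainable = True

layers_list = [FakeLayer(f"layer_{i}") for i in range(238)]

# Refreeze every layer except the last 5
for layer in layers_list[:-5]:
    layer.trainable = False

trainable_indices = [i for i, layer in enumerate(layers_list) if layer.trainable]
print(trainable_indices)  # [233, 234, 235, 236, 237]
```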
# Unfreeze all of the layers in the base model
base_model.trainable = True
# Refreeze every layer except for the last 5
for layer in base_model.layers[:-5]:
    layer.trainable = False
We just made a change to the layers in our model and what do we have to do every time we make a change to our model?
Recompile it.
Because we're fine-tuning, we'll use a 10x lower learning rate to ensure the updates to the previously trained weights aren't too large.
When fine-tuning and unfreezing layers of your pre-trained model, it's common practice to lower the learning rate you used for your feature extraction model. How much by? A 10x lower learning rate is usually a good place to start.
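Why does a lower learning rate keep updates small? With Adam, the size of each weight update is roughly bounded by the learning rate, whatever the size of the gradient. Here's a minimal NumPy sketch of the textbook Adam update (Kingma & Ba) for the very first step, starting from zero moment estimates. It's an illustration rather than Keras's exact implementation; beta_1=0.9, beta_2=0.999 and epsilon=1e-7 mirror the defaults of tf.keras.optimizers.Adam.

```python
import numpy as np

def adam_first_step(grad: np.ndarray, lr: float, beta_1: float = 0.9,
                    beta_2: float = 0.999, epsilon: float = 1e-7) -> np.ndarray:
    """Return the weight change textbook Adam applies on its first step (t=1)."""
    m = (1 - beta_1) * grad         # first moment estimate (starts at zero)
    v = (1 - beta_2) * grad ** 2    # second moment estimate (starts at zero)
    m_hat = m / (1 - beta_1 ** 1)   # bias correction for t=1
    v_hat = v / (1 - beta_2 ** 1)
    return -lr * m_hat / (np.sqrt(v_hat) + epsilon)

grads = np.array([5.0, -0.01, 0.3])  # made-up gradients of very different sizes

for lr in (1e-3, 1e-4):  # Adam's default vs. the 10x lower fine-tuning rate
    step = adam_first_step(grads, lr)
    print(lr, np.abs(step).max())  # the largest weight change is about lr
```

Dropping the learning rate from 1e-3 to 1e-4 shrinks every first-step weight change by the same factor of 10, which is the "don't disturb the pre-trained weights too much" effect we're after.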
# Recompile model with lower learning rate
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(1e-4), # 10x lower learning rate than default
metrics=['accuracy'])
Model recompiled. How about we make sure the layers we want are trainable?
# What layers in the model are trainable?
for layer in model.layers:
    print(layer.name, layer.trainable)
input_layer True
data_augmentation True
efficientnetb0 True
global_average_pooling True
output_layer True
# Check which layers are trainable
for layer_number, layer in enumerate(base_model.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_1 False
1 rescaling False
2 normalization False
3 rescaling_1 False
4 stem_conv_pad False
5 stem_conv False
6 stem_bn False
7 stem_activation False
8 block1a_dwconv False
9 block1a_bn False
10 block1a_activation False
11 block1a_se_squeeze False
12 block1a_se_reshape False
13 block1a_se_reduce False
14 block1a_se_expand False
15 block1a_se_excite False
16 block1a_project_conv False
17 block1a_project_bn False
18 block2a_expand_conv False
19 block2a_expand_bn False
20 block2a_expand_activation False
21 block2a_dwconv_pad False
22 block2a_dwconv False
23 block2a_bn False
24 block2a_activation False
25 block2a_se_squeeze False
26 block2a_se_reshape False
27 block2a_se_reduce False
28 block2a_se_expand False
29 block2a_se_excite False
30 block2a_project_conv False
31 block2a_project_bn False
32 block2b_expand_conv False
33 block2b_expand_bn False
34 block2b_expand_activation False
35 block2b_dwconv False
36 block2b_bn False
37 block2b_activation False
38 block2b_se_squeeze False
39 block2b_se_reshape False
40 block2b_se_reduce False
41 block2b_se_expand False
42 block2b_se_excite False
43 block2b_project_conv False
44 block2b_project_bn False
45 block2b_drop False
46 block2b_add False
47 block3a_expand_conv False
48 block3a_expand_bn False
49 block3a_expand_activation False
50 block3a_dwconv_pad False
51 block3a_dwconv False
52 block3a_bn False
53 block3a_activation False
54 block3a_se_squeeze False
55 block3a_se_reshape False
56 block3a_se_reduce False
57 block3a_se_expand False
58 block3a_se_excite False
59 block3a_project_conv False
60 block3a_project_bn False
61 block3b_expand_conv False
62 block3b_expand_bn False
63 block3b_expand_activation False
64 block3b_dwconv False
65 block3b_bn False
66 block3b_activation False
67 block3b_se_squeeze False
68 block3b_se_reshape False
69 block3b_se_reduce False
70 block3b_se_expand False
71 block3b_se_excite False
72 block3b_project_conv False
73 block3b_project_bn False
74 block3b_drop False
75 block3b_add False
76 block4a_expand_conv False
77 block4a_expand_bn False
78 block4a_expand_activation False
79 block4a_dwconv_pad False
80 block4a_dwconv False
81 block4a_bn False
82 block4a_activation False
83 block4a_se_squeeze False
84 block4a_se_reshape False
85 block4a_se_reduce False
86 block4a_se_expand False
87 block4a_se_excite False
88 block4a_project_conv False
89 block4a_project_bn False
90 block4b_expand_conv False
91 block4b_expand_bn False
92 block4b_expand_activation False
93 block4b_dwconv False
94 block4b_bn False
95 block4b_activation False
96 block4b_se_squeeze False
97 block4b_se_reshape False
98 block4b_se_reduce False
99 block4b_se_expand False
100 block4b_se_excite False
101 block4b_project_conv False
102 block4b_project_bn False
103 block4b_drop False
104 block4b_add False
105 block4c_expand_conv False
106 block4c_expand_bn False
107 block4c_expand_activation False
108 block4c_dwconv False
109 block4c_bn False
110 block4c_activation False
111 block4c_se_squeeze False
112 block4c_se_reshape False
113 block4c_se_reduce False
114 block4c_se_expand False
115 block4c_se_excite False
116 block4c_project_conv False
117 block4c_project_bn False
118 block4c_drop False
119 block4c_add False
120 block5a_expand_conv False
121 block5a_expand_bn False
122 block5a_expand_activation False
123 block5a_dwconv False
124 block5a_bn False
125 block5a_activation False
126 block5a_se_squeeze False
127 block5a_se_reshape False
128 block5a_se_reduce False
129 block5a_se_expand False
130 block5a_se_excite False
131 block5a_project_conv False
132 block5a_project_bn False
133 block5b_expand_conv False
134 block5b_expand_bn False
135 block5b_expand_activation False
136 block5b_dwconv False
137 block5b_bn False
138 block5b_activation False
139 block5b_se_squeeze False
140 block5b_se_reshape False
141 block5b_se_reduce False
142 block5b_se_expand False
143 block5b_se_excite False
144 block5b_project_conv False
145 block5b_project_bn False
146 block5b_drop False
147 block5b_add False
148 block5c_expand_conv False
149 block5c_expand_bn False
150 block5c_expand_activation False
151 block5c_dwconv False
152 block5c_bn False
153 block5c_activation False
154 block5c_se_squeeze False
155 block5c_se_reshape False
156 block5c_se_reduce False
157 block5c_se_expand False
158 block5c_se_excite False
159 block5c_project_conv False
160 block5c_project_bn False
161 block5c_drop False
162 block5c_add False
163 block6a_expand_conv False
164 block6a_expand_bn False
165 block6a_expand_activation False
166 block6a_dwconv_pad False
167 block6a_dwconv False
168 block6a_bn False
169 block6a_activation False
170 block6a_se_squeeze False
171 block6a_se_reshape False
172 block6a_se_reduce False
173 block6a_se_expand False
174 block6a_se_excite False
175 block6a_project_conv False
176 block6a_project_bn False
177 block6b_expand_conv False
178 block6b_expand_bn False
179 block6b_expand_activation False
180 block6b_dwconv False
181 block6b_bn False
182 block6b_activation False
183 block6b_se_squeeze False
184 block6b_se_reshape False
185 block6b_se_reduce False
186 block6b_se_expand False
187 block6b_se_excite False
188 block6b_project_conv False
189 block6b_project_bn False
190 block6b_drop False
191 block6b_add False
192 block6c_expand_conv False
193 block6c_expand_bn False
194 block6c_expand_activation False
195 block6c_dwconv False
196 block6c_bn False
197 block6c_activation False
198 block6c_se_squeeze False
199 block6c_se_reshape False
200 block6c_se_reduce False
201 block6c_se_expand False
202 block6c_se_excite False
203 block6c_project_conv False
204 block6c_project_bn False
205 block6c_drop False
206 block6c_add False
207 block6d_expand_conv False
208 block6d_expand_bn False
209 block6d_expand_activation False
210 block6d_dwconv False
211 block6d_bn False
212 block6d_activation False
213 block6d_se_squeeze False
214 block6d_se_reshape False
215 block6d_se_reduce False
216 block6d_se_expand False
217 block6d_se_excite False
218 block6d_project_conv False
219 block6d_project_bn False
220 block6d_drop False
221 block6d_add False
222 block7a_expand_conv False
223 block7a_expand_bn False
224 block7a_expand_activation False
225 block7a_dwconv False
226 block7a_bn False
227 block7a_activation False
228 block7a_se_squeeze False
229 block7a_se_reshape False
230 block7a_se_reduce False
231 block7a_se_expand False
232 block7a_se_excite False
233 block7a_project_conv True
234 block7a_project_bn True
235 top_conv True
236 top_bn True
237 top_activation True
Excellent! Time to fine-tune our model.
Another 5 epochs should be enough to see whether any benefits come about (though we could always try more).
We'll start the training off where the feature extraction model left off using the initial_epoch parameter in the fit() method.
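initial_epoch is easy to get off by one, because Keras stores 0-indexed epoch numbers in History.epoch but prints them 1-indexed. A pure-Python sketch of the bookkeeping (the epoch list below is what history_all_classes_10_percent.epoch should contain after 5 epochs; no TensorFlow needed):

```python
# Keras's History object records 0-indexed epoch numbers
feature_extraction_epochs = list(range(5))     # expected history.epoch: [0, 1, 2, 3, 4]
initial_epoch = feature_extraction_epochs[-1]   # 4, i.e. history.epoch[-1]
fine_tune_epochs = 10

# fit() runs epoch indices initial_epoch .. fine_tune_epochs - 1 and prints them 1-indexed
epochs_run = list(range(initial_epoch, fine_tune_epochs))
print([f"Epoch {i + 1}/{fine_tune_epochs}" for i in epochs_run])
print(len(epochs_run))  # 6
```

This matches the fine-tuning log below, which runs from Epoch 5/10 to Epoch 10/10: six epochs rather than five. Passing initial_epoch=history_all_classes_10_percent.epoch[-1] + 1 would resume at Epoch 6/10 and train exactly 5 more.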
# Fine-tune for 5 more epochs
fine_tune_epochs = 10 # model has already done 5 epochs, this is the total number of epochs we're after (5+5=10)
history_all_classes_10_percent_fine_tune = model.fit(train_data_all_10_percent,
epochs=fine_tune_epochs,
validation_data=test_data,
validation_steps=int(0.15 * len(test_data)), # validate on 15% of the test data
initial_epoch=history_all_classes_10_percent.epoch[-1]) # start from previous last epoch
Epoch 5/10
237/237 [==============================] - 51s 173ms/step - loss: 1.2195 - accuracy: 0.6764 - val_loss: 1.7268 - val_accuracy: 0.5392
Epoch 6/10
237/237 [==============================] - 25s 104ms/step - loss: 1.1019 - accuracy: 0.7006 - val_loss: 1.7007 - val_accuracy: 0.5493
Epoch 7/10
237/237 [==============================] - 22s 91ms/step - loss: 1.0089 - accuracy: 0.7291 - val_loss: 1.7452 - val_accuracy: 0.5426
Epoch 8/10
237/237 [==============================] - 20s 84ms/step - loss: 0.9453 - accuracy: 0.7480 - val_loss: 1.7133 - val_accuracy: 0.5506
Epoch 9/10
237/237 [==============================] - 19s 79ms/step - loss: 0.8790 - accuracy: 0.7673 - val_loss: 1.7370 - val_accuracy: 0.5471
Epoch 10/10
237/237 [==============================] - 17s 72ms/step - loss: 0.8351 - accuracy: 0.7768 - val_loss: 1.7431 - val_accuracy: 0.5498
Once again, during training we were only evaluating on a small portion of the test data. Let's find out how our model went on all of the test data.
# Evaluate fine-tuned model on the whole test dataset
results_all_classes_10_percent_fine_tune = model.evaluate(test_data)
results_all_classes_10_percent_fine_tune
790/790 [==============================] - 16s 21ms/step - loss: 1.5078 - accuracy: 0.6011
[1.5077574253082275, 0.6010693311691284]
Hmm... it seems like our model got a slight boost from fine-tuning (test accuracy went from about 58.0% to 60.1%).
We might get a better picture by using our compare_historys() function and seeing what the training curves say.
compare_historys(original_history=history_all_classes_10_percent,
new_history=history_all_classes_10_percent_fine_tune,
initial_epochs=5)
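compare_historys() lives in helper_functions.py and isn't shown in this notebook. The core idea, stitching two History.history dicts together epoch by epoch, can be sketched without the plotting; combine_histories is a hypothetical name, not the helper's actual code. The per-epoch numbers are copied from the two training logs above (the fine-tuning log contains 6 epochs):

```python
def combine_histories(original: dict, new: dict) -> dict:
    """Concatenate per-epoch metric lists from two History.history-style dicts (hypothetical helper)."""
    return {metric: original[metric] + new[metric] for metric in original}

# Per-epoch values copied from the training logs above
original = {"accuracy": [0.2796, 0.4937, 0.5675, 0.6055, 0.6440],
            "val_accuracy": [0.4661, 0.5188, 0.5344, 0.5355, 0.5310]}
fine_tune = {"accuracy": [0.6764, 0.7006, 0.7291, 0.7480, 0.7673, 0.7768],
             "val_accuracy": [0.5392, 0.5493, 0.5426, 0.5506, 0.5471, 0.5498]}

combined = combine_histories(original, fine_tune)
print(len(combined["accuracy"]))  # 11

# Gap between training and validation accuracy at the end of each phase
for label, idx in [("after feature extraction", 4), ("after fine-tuning", -1)]:
    gap = combined["accuracy"][idx] - combined["val_accuracy"][idx]
    print(f"{label}: train-val gap = {gap:.4f}")
```

The train/validation accuracy gap roughly doubles, from about 0.11 after feature extraction to about 0.23 after fine-tuning, which is the overfitting the curves hint at.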
It seems that after fine-tuning, our model's training metrics improved significantly but validation, not so much. Looks like our model is starting to overfit.
This is okay though: it's very often the case that fine-tuning leads to overfitting when the data a pre-trained model has been trained on is similar to your custom data.
In our case, our pre-trained model, EfficientNetB0, was trained on ImageNet, which contains many real-life pictures of food just like our food dataset.
If feature extraction already works well, the improvements you see from fine-tuning may not be as great as if your dataset was significantly different from the data your base model was pre-trained on.
# # Save model to drive so it can be used later
# model.save("drive/My Drive/tensorflow_course/101_food_class_10_percent_saved_big_dog_model")
Evaluating the performance of the big dog model across all different classes
We've got a trained and saved model which according to the evaluation metrics we've used is performing fairly well.
But metrics schmetrics, let's dive a little deeper into our model's performance and get some visualizations going.
To do so, we'll load in the saved model and use it to make some predictions on the test dataset.
🔑 Note: Evaluating a machine learning model is as important as training one. Metrics can be deceiving. You should always visualize your model's performance on unseen data to make sure you aren't being fooled by good-looking training numbers.
import tensorflow as tf
# Download pre-trained model from Google Storage (like a cooking show, I trained this model earlier, so the results may be different than above)
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
saved_model_path = "06_101_food_class_10_percent_saved_big_dog_model.zip"
unzip_data(saved_model_path)
# Note: loading a model will output a lot of 'WARNINGS', these can be ignored: https://www.tensorflow.org/tutorials/keras/save_and_load#save_checkpoints_during_training
# There's also a thread on GitHub trying to fix these warnings: https://github.com/tensorflow/tensorflow/issues/40166
# model = tf.keras.models.load_model("drive/My Drive/tensorflow_course/101_food_class_10_percent_saved_big_dog_model/") # path to drive model
model = tf.keras.models.load_model(saved_model_path.split(".")[0]) # don't include ".zip" in loaded model path
To make sure our loaded model is indeed a trained model, let's evaluate its performance on the test dataset.
# Check to see if loaded model is a trained model
loaded_loss, loaded_accuracy = model.evaluate(test_data)
loaded_loss, loaded_accuracy
790/790 [==============================] - 19s 22ms/step - loss: 1.8022 - accuracy: 0.6078
(1.8021684885025024, 0.6078019738197327)
Wonderful! Our loaded model (trained earlier) performs about as well as the model we just fine-tuned (60.78% vs. 60.11% test accuracy), so it's definitely a trained model. Let's make some predictions.
Making predictions with our trained model
To evaluate our trained model, we need to make some predictions with it and then compare those predictions to the test dataset.
Because the model has never seen the test dataset, this should give us an indication of how the model will perform in the real world on data similar to what it has been trained on.
To make predictions with our trained model, we can use the predict() method, passing it the test data.
Since our data is multi-class, doing this will return a prediction probability tensor for each sample.
In other words, every time the trained model sees an image, it will compare it to all of the patterns it learned during training and return an output for every class (all 101 of them) of how likely the image is to be that class.
# Make predictions with model
pred_probs = model.predict(test_data, verbose=1) # set verbosity to see how long it will take
790/790 [==============================] - 15s 18ms/step
We just passed all of the test images to our model and asked it to make a prediction on what food it thinks is in each.
So if we had 25250 images in the test dataset, how many predictions do you think we should have?
# How many predictions are there?
len(pred_probs)
25250
And if each image could be one of 101 classes, how many predictions do you think we'll have for each image?
# What's the shape of our predictions?
pred_probs.shape
(25250, 101)
What we've got is often referred to as a prediction probabilities tensor (or array).
Let's see what the first 10 look like.
# How do they look?
pred_probs[:10]
array([[5.9555572e-02, 3.5662119e-06, 4.1279349e-02, ..., 1.4194699e-09, 8.4039129e-05, 3.0820314e-03],
       [9.6338320e-01, 1.3765826e-09, 8.5042708e-04, ..., 5.4804143e-05, 7.8341188e-12, 9.7811304e-10],
       [9.5942634e-01, 3.2432759e-05, 1.4769287e-03, ..., 7.1438910e-07, 5.5323352e-07, 4.0179562e-05],
       ...,
       [4.7279873e-01, 1.2954312e-07, 1.4748254e-03, ..., 5.9630687e-04, 6.7163957e-05, 2.3532135e-05],
       [4.4502247e-02, 4.7261000e-07, 1.2174438e-01, ..., 6.2917384e-06, 7.5576504e-06, 3.6633476e-03],
       [7.2373080e-01, 1.9256416e-09, 5.2089705e-05, ..., 1.2218992e-03, 1.5755526e-09, 9.6206924e-05]], dtype=float32)
Alright, it seems like we've got a bunch of tensors of really small numbers. How about we zoom into one of them?
# We get one prediction probability per class
print(f"Number of prediction probabilities for sample 0: {len(pred_probs[0])}")
print(f"What prediction probability sample 0 looks like:\n {pred_probs[0]}")
print(f"The class with the highest predicted probability by the model for sample 0: {pred_probs[0].argmax()}")
Number of prediction probabilities for sample 0: 101
What prediction probability sample 0 looks like:
 [5.95555715e-02 3.56621194e-06 4.12793495e-02 1.06392506e-09 8.19964274e-09 8.61560245e-09 8.10409119e-07 8.49006710e-07 1.98095986e-05 7.99297993e-07 3.17695403e-09 9.83702989e-07 2.83578265e-04 7.78589082e-10 7.43659853e-04 3.87942600e-05 6.44463080e-06 2.50333073e-06 3.77276956e-05 2.05599179e-07 1.55429479e-05 8.10427650e-07 2.60885736e-06 2.03088433e-07 8.30327224e-07 5.42988437e-06 3.74273668e-06 1.32360203e-08 2.74555851e-03 2.79426695e-05 6.86718571e-10 2.53517483e-05 1.66382728e-04 7.57455043e-10 4.02367179e-04 1.30578828e-08 1.79280721e-06 1.43956686e-06 2.30789818e-02 8.24019480e-07 8.61712351e-07 1.69789450e-06 7.03542946e-06 1.85722051e-08 2.87577478e-07 7.99586633e-06 2.07466110e-06 1.86462174e-07 3.34909487e-08 3.17168073e-04 1.04948031e-05 8.55388123e-07 8.47566068e-01 1.05432355e-05 4.34703423e-07 3.72825380e-05 3.50601949e-05 3.25856490e-05 6.74493349e-05 1.27383810e-08 2.62344230e-10 1.03191360e-05 8.54040845e-05 1.05053311e-06 2.11897259e-06 3.72938448e-05 7.52453388e-08 2.50746176e-04 9.31413410e-07 1.24124577e-04 6.20197443e-06 1.24211716e-08 4.04207458e-05 6.82277772e-08 1.24867688e-06 5.15563556e-08 7.48909628e-08 7.54529829e-05 7.75354492e-05 6.31509749e-07 9.79340939e-07 2.18733803e-05 1.49467369e-05 1.39609512e-07 1.22257961e-05 1.90126449e-02 4.97255787e-05 4.59980038e-06 1.51860661e-07 3.38210441e-07 3.89491328e-09 1.64673807e-07 8.08345867e-05 4.90067578e-06 2.41164742e-07 2.32299317e-05 3.09399824e-04 3.10968826e-05 1.41946987e-09 8.40391294e-05 3.08203138e-03]
The class with the highest predicted probability by the model for sample 0: 52
As we discussed before, for each image tensor we pass to our model, because of the number of output neurons and activation function in the last layer (layers.Dense(len(train_data_all_10_percent.class_names), activation="softmax")), it outputs a prediction probability between 0 and 1 for each of the 101 classes.
And the index of the highest prediction probability can be considered what the model thinks is the most likely label. Similarly, the lower the prediction probability value, the less the model thinks that the target image is that specific class.
🔑 Note: Due to the nature of the softmax activation function, the sum of each of the prediction probabilities for a single sample will be 1 (or at least very close to 1), e.g. pred_probs[0].sum() ≈ 1.
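We can check this with a small, numerically stable NumPy softmax applied to made-up raw outputs (logits). This is a standalone illustration; in our model, the softmax happens inside the output layer.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtracting the max avoids overflow in exp
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

# Made-up logits for 2 samples and 4 classes
logits = np.array([[2.0, 1.0, 0.1, -1.0],
                   [0.5, 3.0, 0.5, 0.0]])

probs = softmax(logits)
print(probs.sum(axis=1))     # each row sums to (approximately) 1
print(probs.argmax(axis=1))  # [0 1]
```

The largest logit always ends up with the largest probability, so taking the argmax of the probabilities gives the same predicted class as taking it of the raw outputs.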
We can find the index of the maximum value in each prediction probability tensor using the argmax() method.
# Get the class predictions of each label
pred_classes = pred_probs.argmax(axis=1)
# How do they look?
pred_classes[:10]
array([52, 0, 0, 80, 79, 61, 29, 0, 85, 0])
Beautiful! We've now got the predicted class index for each of the samples in our test dataset.
We'll be able to compare these to the test dataset labels to further evaluate our model.
To get the test dataset labels, we can unravel our test_data object (which is in the form of a tf.data.Dataset) using the unbatch() method.
Doing this will give us access to the images and labels in the test dataset. Since the labels are in one-hot encoded format, we'll use the argmax() method to return the index of the label.
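Turning a one-hot label back into a class index is just an argmax along the class axis. Here's the idea on a tiny, made-up batch of one-hot labels (NumPy only); the loop below does the same thing one label at a time for test_data.

```python
import numpy as np

# Hypothetical one-hot labels for 4 images and 5 classes
one_hot_labels = np.array([[1, 0, 0, 0, 0],
                           [0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 0]], dtype="float32")

# argmax along the class axis returns the position of the single 1 in each row
label_indices = one_hot_labels.argmax(axis=1)
print(label_indices)  # [0 2 4 1]
```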
🔑 Note: This unravelling is why we set shuffle=False when creating the test data object. Otherwise, whenever we loaded the test dataset (like when making predictions), it would be shuffled every time, meaning if we tried to compare our predictions to the labels, they would be in different orders.
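To see why order matters, remember that accuracy is computed position by position. With made-up labels and predictions, shifting the labels by one place (np.roll standing in for a reshuffled dataset) halves the accuracy even though the predictions haven't changed:

```python
import numpy as np

# Hypothetical predictions that are all correct when kept in the original order
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 1, 2, 2])

aligned_accuracy = np.mean(y_true == y_pred)

# Simulate the labels coming back in a different order
shuffled_true = np.roll(y_true, 1)  # [2, 0, 0, 1, 1, 2]
misaligned_accuracy = np.mean(shuffled_true == y_pred)

print(aligned_accuracy, misaligned_accuracy)  # 1.0 0.5
```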
# Note: This might take a minute or so due to unravelling 790 batches
y_labels = []
for images, labels in test_data.unbatch(): # unbatch the test data and get images and labels
    y_labels.append(labels.numpy().argmax()) # append the index which has the largest value (labels are one-hot)
y_labels[:10] # check what they look like (unshuffled)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Nice! Since test_data isn't shuffled, the y_labels array comes back in the same order as the pred_classes array.
The final check is to see how many labels we've got.
# How many labels are there? (should be the same as how many prediction probabilities we have)
len(y_labels)
25250
As expected, the number of labels matches the number of images we've got. Time to compare our model's predictions with the ground truth labels.
Evaluating our model's predictions
A very simple evaluation is to use Scikit-Learn's accuracy_score() function, which compares truth labels to predicted labels and returns an accuracy score.
If we've created our y_labels and pred_classes arrays correctly, this should return the same accuracy value (or at least very close) as the evaluate() method we used earlier.
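For single-label classification like ours, accuracy boils down to the fraction of positions where the predicted label matches the true label. A NumPy sketch of the same calculation (manual_accuracy is a hypothetical helper, not part of Scikit-Learn):

```python
import numpy as np

def manual_accuracy(y_true, y_pred) -> float:
    """Fraction of positions where the predicted label equals the true label."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))

print(manual_accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 0.75

# 15,347 correct out of 25,250 test images reproduces the Scikit-Learn score below
print(15347 / 25250)  # 0.6078019801980198
```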
# Get accuracy score by comparing predicted classes to ground truth labels
from sklearn.metrics import accuracy_score
sklearn_accuracy = accuracy_score(y_labels, pred_classes)
sklearn_accuracy
0.6078019801980198
# Does the evaluate method compare to the Scikit-Learn measured accuracy?
import numpy as np
print(f"Close? {np.isclose(loaded_accuracy, sklearn_accuracy)} | Difference: {loaded_accuracy - sklearn_accuracy}")
Close? True | Difference: -6.378287120689663e-09
Okay, it looks like our pred_classes and y_labels arrays are in the same order.
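Why isn't the difference exactly zero? A likely explanation is precision: Keras computes its metrics in float32 by default, while Scikit-Learn's accuracy_score() returns a float64. Taking the 15,347 correct predictions implied by the accuracy above, float32 rounding alone produces a difference of the same tiny size (this reproduces the effect; it isn't Keras's metric code):

```python
import numpy as np

correct, total = 15347, 25250  # correct predictions implied by the accuracy above

sklearn_style = correct / total                   # float64, like accuracy_score()
keras_style = float(np.float32(correct / total))  # assumption: value held in float32, as Keras metrics are by default

difference = keras_style - sklearn_style
print(difference)                              # roughly -6.4e-09
print(np.isclose(keras_style, sklearn_style))  # True
```

np.isclose() uses rtol=1e-05 and atol=1e-08 by default, so a gap of a few billionths comfortably counts as "close".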
How about we get a little bit more visual with a confusion matrix?
To do so, we'll use the make_confusion_matrix function we created in a previous notebook.
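Under the hood, a confusion matrix is a table of counts: row i, column j holds how many samples of true class i were predicted as class j. A minimal NumPy sketch (confusion_counts is a hypothetical helper; the function below relies on Scikit-Learn's confusion_matrix, which uses the same row/column layout), plus the row normalization its norm option displays:

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes: int) -> np.ndarray:
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

# Made-up labels and predictions for 3 classes
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_counts(y_true, y_pred, n_classes=3)
print(cm)

# Row-normalizing turns counts into the fraction of each true class predicted as each class
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
print(cm_norm.round(2))
```

The diagonal holds the correct predictions, so a strong model shows a bright diagonal and mostly empty off-diagonal cells.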
# We'll import our make_confusion_matrix function from https://github.com/mrdbourke/tensorflow-deep-learning/blob/main/extras/helper_functions.py
# But if you run it out of the box, it doesn't really work for 101 classes...
# the cell below adds a little functionality to make it readable.
from helper_functions import make_confusion_matrix
# Note: The following confusion matrix code is a remix of Scikit-Learn's
# plot_confusion_matrix function - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
# Our function needs a different name to sklearn's plot_confusion_matrix
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):
    """Makes a labelled confusion matrix comparing predictions and ground truth labels.
    If classes is passed, confusion matrix will be labelled, if not, integer class values
    will be used.
    Args:
        y_true: Array of truth labels (must be same shape as y_pred).
        y_pred: Array of predicted labels (must be same shape as y_true).
        classes: Array of class labels (e.g. string form). If `None`, integer labels are used.
        figsize: Size of output figure (default=(10, 10)).
        text_size: Size of output figure text (default=15).
        norm: normalize values or not (default=False).
        savefig: save confusion matrix to file (default=False).
    Returns:
        None. Draws a labelled confusion matrix plot comparing y_true and y_pred.
    Example usage:
        make_confusion_matrix(y_true=test_labels, # ground truth test labels
                              y_pred=y_preds, # predicted labels
                              classes=class_names, # array of class label names
                              figsize=(15, 15),
                              text_size=10)
    """
    # Create the confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
    n_classes = cm.shape[0] # find the number of classes we're dealing with
    # Plot the figure and make it pretty
    fig, ax = plt.subplots(figsize=figsize)
    cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
    fig.colorbar(cax)
    # Is there a list of classes? (`is not None` also works when classes is a NumPy array)
    if classes is not None:
        labels = classes
    else:
        labels = np.arange(cm.shape[0])
    # Label the axes
    ax.set(title="Confusion Matrix",
           xlabel="Predicted label",
           ylabel="True label",
           xticks=np.arange(n_classes), # create enough axis slots for each class
           yticks=np.arange(n_classes),
           xticklabels=labels, # axes will be labeled with class names (if they exist) or ints
           yticklabels=labels)
    # Make x-axis labels appear on bottom
    ax.xaxis.set_label_position("bottom")
    ax.xaxis.tick_bottom()
    ### Added: Rotate xticks for readability & increase font size (required due to such a large confusion matrix)
    plt.xticks(rotation=70, fontsize=text_size)
    plt.yticks(fontsize=text_size)
    # Set the threshold for different colors
    threshold = (cm.max() + cm.min()) / 2.
    # Plot the text on each cell
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if norm:
            plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     size=text_size)
        else:
            plt.text(j, i, f"{cm[i, j]}",
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     size=text_size)
    # Save the figure to the current working directory
    if savefig:
        fig.savefig("confusion_matrix.png")
Right now our predictions and truth labels are in the form of integers; however, they'll be much easier to understand if we get their actual names. We can do so using the class_names attribute on our test_data object.
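Mapping integer predictions to names is plain list indexing, because the label indices line up with the order of class_names. This assumes test_data was created with tf.keras.utils.image_dataset_from_directory, which provides the class_names attribute and orders classes alphanumerically by folder name. A sketch with the first 10 class names from the output below and hypothetical predictions:

```python
# First 10 class names, copied from the class_names[:10] output below
class_names_subset = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
                      'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito']

# Hypothetical integer predictions (indices into the class names list)
example_pred_classes = [0, 7, 2, 9]

pred_names = [class_names_subset[i] for i in example_pred_classes]
print(pred_names)  # ['apple_pie', 'bibimbap', 'baklava', 'breakfast_burrito']
```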
# Get the class names
class_names = test_data.class_names
class_names[:10]
['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito']
101 class names and 25250 predictions and ground truth labels ready to go! Looks like our confusion matrix is going to be a big one!
# Plot a confusion matrix with all 25250 predictions, ground truth labels and 101 classes
make_confusion_matrix(y_true=y_labels,
y_pred=pred_classes,
classes=class_names,
figsize=(100, 100),
text_size=20,
norm=False,
savefig=True)