import datetime
print(f"Last updated: {datetime.datetime.now()}")
Last updated: 2024-09-06 13:12:51.657220
What is matplotlib?¶
Matplotlib is a visualization library for Python.
As in, if you want to display something in a chart or graph, matplotlib can help you do that programmatically.
Many of the graphics you'll see in machine learning research papers or presentations are made with matplotlib.
Why matplotlib?¶
Matplotlib is part of the standard Python data stack (pandas, NumPy, matplotlib, Jupyter).
It has terrific integration with many other Python libraries.
pandas uses matplotlib as a backend to help visualize data in DataFrames.
What does this notebook cover?¶
A central idea in matplotlib is the concept of a "plot" (hence the name).
So we're going to practice making a series of different plots, which is a way to visually represent data.
Since there are basically limitless ways to create a plot, we're going to focus on making and customizing (making them look pretty) a few common types of plots.
Where can I get help?¶
If you get stuck or think of something you'd like to do which this notebook doesn't cover, don't fear!
The recommended steps you take are:
- Try it - Since matplotlib is very friendly, your first step should be to use what you know and try to figure out the answer to your own question (getting it wrong is part of the process). If in doubt, run your code.
- Search for it - If trying it on your own doesn't work, since someone else has probably tried to do something similar, try searching for your problem in the following places (either via a search engine or direct):
- matplotlib documentation - the best place for learning all of the vast functionality of matplotlib. Bonus: You can see a series of matplotlib cheatsheets on the matplotlib website.
- Stack Overflow - this is the developer's Q&A hub. It's full of questions and answers to different problems across a wide range of software development topics, and chances are there's one related to your problem.
- ChatGPT - ChatGPT is very good at explaining code, however, it can make mistakes. Best to verify the code it writes first before using it. Try asking "Can you explain the following code for me? {your code here}" and then continue with follow up questions from there. But always be careful using generated code. Avoid blindly copying something you couldn't reproduce yourself with enough effort.
An example of searching for a matplotlib feature might be:
"how to colour the bars of a matplotlib plot"
Searching this on Google leads to this documentation page on the matplotlib website: https://matplotlib.org/stable/gallery/lines_bars_and_markers/bar_colors.html
The next steps here are to read through the page and see if it relates to your problem. If it does, great, take the code/information you need and rewrite it to suit your own problem.
- Ask for help - If you've been through the above 2 steps and you're still stuck, you might want to ask your question on Stack Overflow or in the ZTM Discord chat. Remember to be as specific as possible and provide details on what you've tried.
Remember, you don't have to learn all of these functions off by heart to begin with.
What's most important is remembering to continually ask yourself, "what am I trying to visualize?"
Start by answering that question and then practicing finding the code which does it.
Let's get to visualizing some data!
0. Importing matplotlib¶
We'll start by importing matplotlib.pyplot.
Why pyplot?
Because pyplot is a submodule for creating interactive plots programmatically.
pyplot is often imported as the alias plt.
Note: In older notebooks and tutorials of matplotlib, you may see the magic command %matplotlib inline. This was required to view plots inside a notebook, however, as of 2020 it is mostly no longer required.
# Older versions of Jupyter Notebooks and matplotlib required this magic command
# %matplotlib inline
# Import matplotlib and matplotlib.pyplot
import matplotlib
import matplotlib.pyplot as plt
print(f"matplotlib version: {matplotlib.__version__}")
matplotlib version: 3.9.2
1. 2 ways of creating plots¶
There are two main ways of creating plots in matplotlib.
- matplotlib.pyplot.plot() - Recommended for simple plots (e.g. x and y).
- matplotlib.pyplot.XX (where XX can be one of many methods, this is known as the object-oriented API) - Recommended for more complex plots (for example plt.subplots() to create multiple plots on the same Figure, we'll get to this later).
Both of these methods are still often created by building off import matplotlib.pyplot as plt as a base.
Let's start simple.
# Create a simple plot, without the semi-colon
plt.plot()
[]
# With the semi-colon
plt.plot();
# You could use plt.show() if you want
plt.plot()
plt.show()
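The [] printed by the first cell above is the return value of plt.plot(): a list of the line objects it created, which is empty when no data is passed. In a notebook, a trailing semi-colon simply hides that printed value. A minimal sketch (the variable names are only for illustration):

```python
import matplotlib.pyplot as plt

# plt.plot() returns a list of the lines it drew
no_lines = plt.plot()            # no data passed, so the list is empty
lines = plt.plot([1, 2, 3, 4])   # one dataset, so the list holds one line object

print(no_lines)
print(type(lines[0]).__name__)
```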
# Let's add some data
plt.plot([1, 2, 3, 4]);
# Create some data
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]
A few quick things about a plot:
- x is the horizontal axis.
- y is the vertical axis.
- In a data point, x usually comes first, e.g. (3, 4) would be (x=3, y=4).
- The same happens in matplotlib.pyplot.plot(), x comes before y, e.g. plt.plot(x, y).
# Now a y-value too!
plt.plot(x, y);
Now let's try using the object-orientated version.
We'll start by creating a figure with plt.figure().
And then we'll add an axes with add_subplot.
# Creating a plot with the object-orientated version
fig = plt.figure() # create a figure
ax = fig.add_subplot() # add an axes
plt.show()
A note on the terminology:
- A Figure (e.g. fig = plt.figure()) is the final image in matplotlib (and it may contain one or more Axes), often shortened to fig.
- The Axes are an individual plot (e.g. ax = fig.add_subplot()), often shortened to ax.
  - One Figure can contain one or more Axes.
- The Axis are x (horizontal), y (vertical), z (depth).
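The relationship between these three terms can be checked directly on the objects. The following is a small sketch rather than part of the original notebook; it only uses the Figure and Axes created the same way as above:

```python
import matplotlib.pyplot as plt

fig = plt.figure()       # the Figure (the overall image)
ax = fig.add_subplot()   # one Axes (an individual plot) on that Figure

# A Figure keeps a list of the Axes it contains
print(len(fig.axes))
print(fig.axes[0] is ax)

# Each Axes owns an x-axis object and a y-axis object
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)
```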
Now let's add some data to our previous plot.
# Add some data to our previous plot
fig = plt.figure()
ax = fig.add_axes([1, 1, 1, 1]) # [left, bottom, width, height] as fractions of the Figure
ax.plot(x, y)
plt.show()
But there's an easier way we can use matplotlib.pyplot to help us create a Figure with multiple potential Axes.
And that's with plt.subplots().
# Create a Figure and multiple potential Axes and add some data
fig, ax = plt.subplots()
ax.plot(x, y);
Anatomy of a Matplotlib Figure¶
Matplotlib offers almost unlimited options for creating plots.
However, let's break down some of the main terms.
- Figure - The base canvas of all matplotlib plots. The overall thing you're plotting is a Figure, often shortened to fig.
- Axes - One Figure can have one or multiple Axes, for example, a Figure with multiple subplots could have 4 Axes (2 rows and 2 columns). Often shortened to ax.
- Axis - A particular dimension of an Axes, for example, the x-axis or y-axis.
# This is where the object orientated name comes from
type(fig), type(ax)
(matplotlib.figure.Figure, matplotlib.axes._axes.Axes)
A quick Matplotlib Workflow¶
The following workflow is a standard practice when creating a matplotlib plot:
- Import matplotlib - For example, import matplotlib.pyplot as plt.
- Prepare data - This may be from an existing dataset (data analysis) or from the outputs of a machine learning model (data science).
- Setup the plot - In other words, create the Figure and various Axes.
- Plot data to the Axes - Send the relevant data to the target Axes.
- Customize the plot - Add a title, decorate the colours, label each Axis.
- Save (optional) and show - See what your masterpiece looks like and save it to file if necessary.
# A matplotlib workflow
# 0. Import and get matplotlib ready
# %matplotlib inline # Not necessary in newer versions of Jupyter (e.g. 2022 onwards)
import matplotlib.pyplot as plt
# 1. Prepare data
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]
# 2. Setup plot (Figure and Axes)
fig, ax = plt.subplots(figsize=(10,10))
# 3. Plot data
ax.plot(x, y)
# 4. Customize plot
ax.set(title="Sample Simple Plot", xlabel="x-axis", ylabel="y-axis")
# 5. Save & show
fig.savefig("../images/simple-plot.png")
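One thing to watch for in the save step: fig.savefig() writes to the path it is given but does not create missing folders, so the target directory has to exist first. A small sketch of one way to handle that, assuming a hypothetical temporary folder instead of the notebook's own path:

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# Hypothetical output folder (a temporary directory, used here only for illustration)
output_dir = Path(tempfile.mkdtemp()) / "images"
output_dir.mkdir(parents=True, exist_ok=True)  # create the folder if it's missing

output_path = output_dir / "simple-plot.png"
fig.savefig(output_path, dpi=150, bbox_inches="tight")  # dpi and bbox_inches are optional extras
print(output_path.exists())
```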
2. Making the most common type of plots using NumPy arrays¶
Most of figuring out what kind of plot to use is getting a feel for the data, then seeing what kind of plot suits it best.
Matplotlib visualizations are built on NumPy arrays. So in this section we'll build some of the most common types of plots using NumPy arrays.
- Line plot - ax.plot() (this is the default plot in matplotlib)
- Scatter plot - ax.scatter()
- Bar plot - ax.bar()
- Histogram plot - ax.hist()
We'll see how all of these can be created as methods of the Axes returned by matplotlib.pyplot.subplots().
Resource: Remember you can see many of the different kinds of matplotlib plot types in the documentation.
To make sure we have access to NumPy, we'll import it as np.
import numpy as np
Creating a line plot¶
Line is the default type of visualization in Matplotlib. Usually, unless specified otherwise, your plots will start out as lines.
Line plots are great for seeing trends over time.
# Create an array
x = np.linspace(0, 10, 100)
x[:10]
array([0. , 0.1010101 , 0.2020202 , 0.3030303 , 0.4040404 , 0.50505051, 0.60606061, 0.70707071, 0.80808081, 0.90909091])
# The default plot is line
fig, ax = plt.subplots()
ax.plot(x, x**2);
Creating a scatter plot¶
Scatter plots can be great for when you've got many different individual data points and you'd like to see how they interact with each other without being connected.
# Need to recreate our figure and axis instances when we want a new figure
fig, ax = plt.subplots()
ax.scatter(x, np.exp(x));
fig, ax = plt.subplots()
ax.scatter(x, np.sin(x));
# You can make plots from a dictionary
nut_butter_prices = {"Almond butter": 10,
"Peanut butter": 8,
"Cashew butter": 12}
fig, ax = plt.subplots()
ax.bar(nut_butter_prices.keys(), nut_butter_prices.values())
ax.set(title="Dan's Nut Butter Store", ylabel="Price ($)");
fig, ax = plt.subplots()
ax.barh(list(nut_butter_prices.keys()), list(nut_butter_prices.values()));
Creating a histogram plot¶
Histogram plots are excellent for showing the distribution of data.
For example, you might want to show the distribution of ages of a population or wages in a city.
# Make some data from a normal distribution
x = np.random.randn(1000) # pulls data from a normal distribution
fig, ax = plt.subplots()
ax.hist(x);
x = np.random.random(1000) # random data from a uniform distribution over [0, 1)
fig, ax = plt.subplots()
ax.hist(x);
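The two histograms above look different because the samples come from different distributions. A quick way to see this without a plot is to compare the ranges of the samples. This sketch uses a seeded NumPy Generator (an assumption made here for reproducibility) rather than the np.random.randn() and np.random.random() calls above, but it draws from the same two distributions:

```python
import numpy as np

rng = np.random.default_rng(42)  # seeded generator so the numbers are repeatable

normal_samples = rng.standard_normal(1000)  # normal distribution: mean 0, standard deviation 1
uniform_samples = rng.random(1000)          # uniform distribution over [0, 1)

print(f"normal:  min={normal_samples.min():.2f}, max={normal_samples.max():.2f}")
print(f"uniform: min={uniform_samples.min():.2f}, max={uniform_samples.max():.2f}")
```

The normal samples spread on both sides of zero, while the uniform samples never leave the range from 0 up to (but not including) 1.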
Creating Figures with multiple Axes with Subplots¶
Subplots allow you to create multiple Axes on the same Figure (multiple plots within the same plot).
Subplots are helpful because you start with one plot per Figure but scale it up to more when necessary.
For example, let's create a subplot that shows many of the above datasets on the same Figure.
We can do so by creating multiple Axes with plt.subplots() and setting the nrows (number of rows) and ncols (number of columns) parameters to reflect how many Axes we'd like.
nrows and ncols parameters are multiplicative, meaning plt.subplots(nrows=2, ncols=2) will create 2*2=4 total Axes.
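To make the rows-times-columns idea concrete, here is a minimal sketch that inspects what plt.subplots() hands back when more than one Axes is requested. The loop and titles are illustrative only:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

print(ax.shape)  # (rows, columns) of the array of Axes

# Flatten the 2D array to loop over every Axes in turn
for i, single_ax in enumerate(ax.flatten()):
    single_ax.set_title(f"Axes {i}")
```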
Resource: You can see a sensational number of examples for creating Subplots in the matplotlib documentation.
# Option 1: Create 4 subplots with each Axes having its own variable name
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2,
ncols=2,
figsize=(10, 5))
# Plot data to each Axes
ax1.plot(x, x/2);
ax2.scatter(np.random.random(10), np.random.random(10));
ax3.bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax4.hist(np.random.randn(1000));
# Option 2: Create 4 subplots with a single ax variable
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))
# Index the ax variable to plot data
ax[0, 0].plot(x, x/2);
ax[0, 1].scatter(np.random.random(10), np.random.random(10));
ax[1, 0].bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax[1, 1].hist(np.random.randn(1000));
3. Plotting data directly with pandas¶
Matplotlib has a tight integration with pandas too.
You can directly plot from a pandas DataFrame with DataFrame.plot().
Let's see the following plots directly from a pandas DataFrame:
- Line
- Scatter
- Bar
- Hist
To plot data with pandas, we first have to import it as pd.
import pandas as pd
Now we need some data to check out.
# Let's import the car_sales dataset
# Note: The following two lines load the same data, one does it from a local file path, the other does it from a URL.
# car_sales = pd.read_csv("../data/car-sales.csv") # load data from local file
car_sales = pd.read_csv("https://raw.githubusercontent.com/mrdbourke/zero-to-mastery-ml/master/data/car-sales.csv") # load data from raw URL (original: https://github.com/mrdbourke/zero-to-mastery-ml/blob/master/data/car-sales.csv)
car_sales
| | Make | Colour | Odometer (KM) | Doors | Price |
---|---|---|---|---|---|
0 | Toyota | White | 150043 | 4 | $4,000.00 |
1 | Honda | Red | 87899 | 4 | $5,000.00 |
2 | Toyota | Blue | 32549 | 3 | $7,000.00 |
3 | BMW | Black | 11179 | 5 | $22,000.00 |
4 | Nissan | White | 213095 | 4 | $3,500.00 |
5 | Toyota | Green | 99213 | 4 | $4,500.00 |
6 | Honda | Blue | 45698 | 4 | $7,500.00 |
7 | Honda | Blue | 54738 | 4 | $7,000.00 |
8 | Toyota | White | 60000 | 4 | $6,250.00 |
9 | Nissan | White | 31600 | 4 | $9,700.00 |
Line plot from a pandas DataFrame¶
To understand examples, I often find I have to repeat them (code them myself) rather than just read them.
To begin understanding plotting with pandas, let's recreate a section of the pandas Chart visualization documents.
# Start with some dummy data
ts = pd.Series(np.random.randn(1000),
index=pd.date_range('1/1/2025', periods=1000))
# Note: ts = short for time series (data over time)
ts
2025-01-01    1.724020
2025-01-02   -0.530374
2025-01-03    2.247190
2025-01-04    0.077367
2025-01-05   -1.035777
                ...
2027-09-23   -1.467224
2027-09-24   -0.588671
2027-09-25   -0.394004
2027-09-26    1.327045
2027-09-27   -0.160190
Freq: D, Length: 1000, dtype: float64
Great! We've got some random values across time.
Now let's add up the data cumulatively over time with Series.cumsum() (cumsum is short for cumulative sum or continually adding one thing to the next and so on).
# Add up the values cumulatively
ts.cumsum()
2025-01-01     1.724020
2025-01-02     1.193646
2025-01-03     3.440836
2025-01-04     3.518203
2025-01-05     2.482426
                ...
2027-09-23    32.888806
2027-09-24    32.300135
2027-09-25    31.906130
2027-09-26    33.233175
2027-09-27    33.072985
Freq: D, Length: 1000, dtype: float64
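As a tiny illustration of what a cumulative sum does, here is the same method on four hand-picked numbers (these values are made up for the example):

```python
import pandas as pd

values = pd.Series([1, 2, 3, 4])
running_total = values.cumsum()  # 1, 1+2, 1+2+3, 1+2+3+4
print(running_total.tolist())    # [1, 3, 6, 10]
```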
We can now visualize the values by calling the plot() method on the Series and specifying the kind of plot we'd like with the kind parameter.
In our case, the kind we'd like is a line plot, hence kind="line" (this is the default for the plot() method).
# Plot the values over time with a line plot (note: both of these will return the same thing)
# ts.cumsum().plot() # kind="line" is set by default
ts.cumsum().plot(kind="line");
Working with actual data¶
Let's do a little data manipulation on our car_sales
DataFrame.
# Import the car sales data
car_sales = pd.read_csv("https://raw.githubusercontent.com/mrdbourke/zero-to-mastery-ml/master/data/car-sales.csv")
# Remove price column symbols
car_sales["Price"] = car_sales["Price"].str.replace(r'[\$\,\.]', '',
                                                    regex=True) # Tell pandas to replace using regex (the raw string avoids invalid escape warnings)
car_sales
| | Make | Colour | Odometer (KM) | Doors | Price |
---|---|---|---|---|---|
0 | Toyota | White | 150043 | 4 | 400000 |
1 | Honda | Red | 87899 | 4 | 500000 |
2 | Toyota | Blue | 32549 | 3 | 700000 |
3 | BMW | Black | 11179 | 5 | 2200000 |
4 | Nissan | White | 213095 | 4 | 350000 |
5 | Toyota | Green | 99213 | 4 | 450000 |
6 | Honda | Blue | 45698 | 4 | 750000 |
7 | Honda | Blue | 54738 | 4 | 700000 |
8 | Toyota | White | 60000 | 4 | 625000 |
9 | Nissan | White | 31600 | 4 | 970000 |
# Remove last two zeros
car_sales["Price"] = car_sales["Price"].str[:-2]
car_sales
| | Make | Colour | Odometer (KM) | Doors | Price |
---|---|---|---|---|---|
0 | Toyota | White | 150043 | 4 | 4000 |
1 | Honda | Red | 87899 | 4 | 5000 |
2 | Toyota | Blue | 32549 | 3 | 7000 |
3 | BMW | Black | 11179 | 5 | 22000 |
4 | Nissan | White | 213095 | 4 | 3500 |
5 | Toyota | Green | 99213 | 4 | 4500 |
6 | Honda | Blue | 45698 | 4 | 7500 |
7 | Honda | Blue | 54738 | 4 | 7000 |
8 | Toyota | White | 60000 | 4 | 6250 |
9 | Nissan | White | 31600 | 4 | 9700 |
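The two cleaning steps above (strip every symbol, then drop the last two characters) rely on every price ending in ".00". One alternative, sketched here on a few hand-typed values rather than the real dataset, is to remove only the dollar sign and thousands separator and let pandas parse the number. This is not the approach the notebook takes, just a variation to compare:

```python
import pandas as pd

# A few example prices in the same format as the car sales data
prices = pd.Series(["$4,000.00", "$22,000.00", "$3,500.00"])

cleaned = (prices
           .str.replace(r"[$,]", "", regex=True)  # drop "$" and "," but keep the decimal point
           .astype(float)                          # "4000.00" -> 4000.0
           .astype(int))                           # 4000.0 -> 4000

print(cleaned.tolist())
```

A side effect of this variation is that the column ends up numeric straight away, so no later conversion from string is needed.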
# Add a date column
car_sales["Sale Date"] = pd.date_range("1/1/2024", periods=len(car_sales))
car_sales
| | Make | Colour | Odometer (KM) | Doors | Price | Sale Date |
---|---|---|---|---|---|---|
0 | Toyota | White | 150043 | 4 | 4000 | 2024-01-01 |
1 | Honda | Red | 87899 | 4 | 5000 | 2024-01-02 |
2 | Toyota | Blue | 32549 | 3 | 7000 | 2024-01-03 |
3 | BMW | Black | 11179 | 5 | 22000 | 2024-01-04 |
4 | Nissan | White | 213095 | 4 | 3500 | 2024-01-05 |
5 | Toyota | Green | 99213 | 4 | 4500 | 2024-01-06 |
6 | Honda | Blue | 45698 | 4 | 7500 | 2024-01-07 |
7 | Honda | Blue | 54738 | 4 | 7000 | 2024-01-08 |
8 | Toyota | White | 60000 | 4 | 6250 | 2024-01-09 |
9 | Nissan | White | 31600 | 4 | 9700 | 2024-01-10 |
# Make total sales column (doesn't work, adds as string)
#car_sales["Total Sales"] = car_sales["Price"].cumsum()
# Oops... we want them as ints, not strings
car_sales["Total Sales"] = car_sales["Price"].astype(int).cumsum()
car_sales
| | Make | Colour | Odometer (KM) | Doors | Price | Sale Date | Total Sales |
---|---|---|---|---|---|---|---|
0 | Toyota | White | 150043 | 4 | 4000 | 2024-01-01 | 4000 |
1 | Honda | Red | 87899 | 4 | 5000 | 2024-01-02 | 9000 |
2 | Toyota | Blue | 32549 | 3 | 7000 | 2024-01-03 | 16000 |
3 | BMW | Black | 11179 | 5 | 22000 | 2024-01-04 | 38000 |
4 | Nissan | White | 213095 | 4 | 3500 | 2024-01-05 | 41500 |
5 | Toyota | Green | 99213 | 4 | 4500 | 2024-01-06 | 46000 |
6 | Honda | Blue | 45698 | 4 | 7500 | 2024-01-07 | 53500 |
7 | Honda | Blue | 54738 | 4 | 7000 | 2024-01-08 | 60500 |
8 | Toyota | White | 60000 | 4 | 6250 | 2024-01-09 | 66750 |
9 | Nissan | White | 31600 | 4 | 9700 | 2024-01-10 | 76450 |
car_sales.plot(x='Sale Date', y='Total Sales');
Scatter plot from a pandas DataFrame¶
You can create scatter plots from a pandas DataFrame by using the kind="scatter"
parameter.
However, you'll often find that certain plots require certain kinds of data (e.g. some plots require certain columns to be numeric).
# Note: In previous versions of matplotlib and pandas, having the "Price" column as a string would
# return an error
car_sales["Price"] = car_sales["Price"].astype(str)
# Plot a scatter plot
car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter");
Having the Price column as an int returns a much better looking y-axis.
# Convert Price to int
car_sales["Price"] = car_sales["Price"].astype(int)
# Plot a scatter plot
car_sales.plot(x="Odometer (KM)", y="Price", kind='scatter');
Bar plot from a pandas DataFrame¶
Let's see how we can plot a bar plot from a pandas DataFrame.
First, we'll create some data.
# Create 10 random samples across 4 columns
x = np.random.rand(10, 4)
x
array([[0.63664747, 0.11886476, 0.96687683, 0.62490457], [0.9623542 , 0.75100119, 0.08098382, 0.83857796], [0.49430885, 0.00545069, 0.89374991, 0.99877205], [0.89788013, 0.15844467, 0.50083739, 0.72846574], [0.51719877, 0.00978263, 0.74440314, 0.70385373], [0.17211921, 0.42804418, 0.16401737, 0.66153094], [0.39768996, 0.00628579, 0.71681382, 0.83828817], [0.75507146, 0.73571561, 0.30901804, 0.4720662 ], [0.46070935, 0.93093698, 0.01335433, 0.91765471], [0.77798775, 0.70517195, 0.05298553, 0.68972541]])
# Turn the data into a DataFrame
df = pd.DataFrame(x, columns=['a', 'b', 'c', 'd'])
df
| | a | b | c | d |
---|---|---|---|---|
0 | 0.636647 | 0.118865 | 0.966877 | 0.624905 |
1 | 0.962354 | 0.751001 | 0.080984 | 0.838578 |
2 | 0.494309 | 0.005451 | 0.893750 | 0.998772 |
3 | 0.897880 | 0.158445 | 0.500837 | 0.728466 |
4 | 0.517199 | 0.009783 | 0.744403 | 0.703854 |
5 | 0.172119 | 0.428044 | 0.164017 | 0.661531 |
6 | 0.397690 | 0.006286 | 0.716814 | 0.838288 |
7 | 0.755071 | 0.735716 | 0.309018 | 0.472066 |
8 | 0.460709 | 0.930937 | 0.013354 | 0.917655 |
9 | 0.777988 | 0.705172 | 0.052986 | 0.689725 |
We can plot a bar chart directly with the bar() method on the DataFrame.
# Plot a bar chart
df.plot.bar();
And we can also do the same thing passing the kind="bar" parameter to DataFrame.plot().
# Plot a bar chart with the kind parameter
df.plot(kind='bar');
Let's try a bar plot on the car_sales DataFrame.
This time we'll specify the x and y axis values.
# Plot a bar chart from car_sales DataFrame
car_sales.plot(x="Make",
y="Odometer (KM)",
kind="bar");
Histogram plot from a pandas DataFrame¶
We can plot a histogram plot from our car_sales DataFrame using DataFrame.plot.hist() or DataFrame.plot(kind="hist").
Histograms are great for seeing the distribution or the spread of data.
car_sales["Odometer (KM)"].plot.hist(bins=10); # default number of bins (or groups) is 10
car_sales["Odometer (KM)"].plot(kind="hist");
By changing the bins parameter, we can put our data into different numbers of collections.
For example, by default bins=10 (10 groups of data), let's see what happens when we change it to bins=20.
# Default number of bins is 10, here we try 20
car_sales["Odometer (KM)"].plot.hist(bins=20);
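To see what the bins parameter changes without looking at a plot, the counts can be computed with np.histogram(), which does the same kind of binning a histogram plot relies on. The sketch below copies the odometer values by hand from the car sales table above:

```python
import numpy as np

# Odometer (KM) values copied from the car sales table
odometer = np.array([150043, 87899, 32549, 11179, 213095,
                     99213, 45698, 54738, 60000, 31600])

counts_10, edges_10 = np.histogram(odometer, bins=10)
counts_20, edges_20 = np.histogram(odometer, bins=20)

# n bins always means n counts and n + 1 bin edges
print(len(counts_10), len(edges_10))
print(len(counts_20), len(edges_20))

# Every sample still lands in exactly one bin
print(counts_10.sum(), counts_20.sum())
```

More bins means narrower groups, so the same 10 samples are spread more thinly across them.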
To practice, let's create a histogram of the Price column.
# Create a histogram of the Price column
car_sales["Price"].plot.hist(bins=10);
And to practice even further, how about we try another dataset?
Namely, let's create some plots using the heart disease dataset we've worked on before.
# Import the heart disease dataset
# Note: The following two lines create the same DataFrame, one just loads data from a local filepath whereas the other downloads it directly from a URL.
# heart_disease = pd.read_csv("../data/heart-disease.csv") # load from local file path (requires data to be downloaded)
heart_disease = pd.read_csv("https://raw.githubusercontent.com/mrdbourke/zero-to-mastery-ml/master/data/heart-disease.csv") # load directly from raw URL (source: https://github.com/mrdbourke/zero-to-mastery-ml/blob/master/data/heart-disease.csv)
heart_disease.head()
| | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
# Create a histogram of the age column
heart_disease["age"].plot.hist(bins=50);
What does this tell you about the spread of heart disease data across different ages?
Creating a plot with multiple Axes from a pandas DataFrame¶
We can also create a series of plots (multiple Axes on one Figure) from a DataFrame using the subplots=True parameter.
First, let's remind ourselves what the data looks like.
# Inspect the data
heart_disease.head()
| | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
Since all of our columns are numeric in value, let's try and create a histogram of each column.
heart_disease.plot.hist(figsize=(5, 20),
subplots=True);
Hmmm... is this a very helpful plot?
Perhaps not.
Sometimes you can visualize too much on the one plot and it becomes confusing.
Best to start with less and gradually increase.
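A possible alternative, not used in the original notebook, is DataFrame.hist(), which draws one histogram per numeric column on a grid of Axes. The sketch below uses made-up data (the column names echo the heart disease dataset, but the values are random and hypothetical):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Made-up data with columns on very different scales
df = pd.DataFrame({"age": rng.integers(30, 80, size=200),
                   "chol": rng.normal(250, 50, size=200),
                   "sex": rng.integers(0, 2, size=200),
                   "oldpeak": rng.random(200) * 4})

# One histogram per column, laid out on a grid of Axes
axes = df.hist(figsize=(8, 6), bins=20)
print(axes.shape)
```

Starting with a handful of columns like this and adding more as needed follows the same "start with less" advice.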
4. Plotting more advanced plots from a pandas DataFrame¶
It's possible to achieve far more complicated and detailed plots from a pandas DataFrame.
Let's practice using the heart_disease DataFrame.
And as an example, let's do some analysis on people over 50 years of age.
To do so, let's start by creating a plot directly from pandas and then using the object-orientated API (plt.subplots()) to build upon it.
# Perform data analysis on patients over 50
over_50 = heart_disease[heart_disease["age"] > 50]
over_50
| | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
5 | 57 | 1 | 0 | 140 | 192 | 0 | 1 | 148 | 0 | 0.4 | 1 | 0 | 1 | 1 |
6 | 56 | 0 | 1 | 140 | 294 | 0 | 0 | 153 | 0 | 1.3 | 1 | 0 | 2 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
297 | 59 | 1 | 0 | 164 | 176 | 1 | 0 | 90 | 0 | 1.0 | 1 | 2 | 1 | 0 |
298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
208 rows × 14 columns
Now let's create a scatter plot directly from the pandas DataFrame.
This is quite easy to do but is a bit limited in terms of customization.
Let's visualize the cholesterol levels of patients over 50.
We can visualize which patients have or don't have heart disease by colouring the samples to be in line with the target column (e.g. 0 = no heart disease, 1 = heart disease).
# Create a scatter plot directly from the pandas DataFrame
over_50.plot(kind="scatter",
x="age",
y="chol",
c="target", # colour the dots by target value
figsize=(10, 6));
We can recreate the same plot using plt.subplots() and then passing the Axes variable (ax) to the pandas plot() method.
# Create a Figure and Axes instance
fig, ax = plt.subplots(figsize=(10, 6))
# Plot data from the DataFrame to the ax object
over_50.plot(kind="scatter",
x="age",
y="chol",
c="target",
ax=ax); # set the target Axes
# Customize the x-axis limits (to be within our target age ranges)
ax.set_xlim([45, 100]);
Now instead of plotting directly from the pandas DataFrame, we can make a bit more of a comprehensive plot by plotting data directly to a target Axes instance.
# Create Figure and Axes instance
fig, ax = plt.subplots(figsize=(10, 6))
# Plot data directly to the Axes instance
scatter = ax.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"]) # Colour the data with the "target" column
# Customize the plot parameters
ax.set(title="Heart Disease and Cholesterol Levels",
xlabel="Age",
ylabel="Cholesterol");
# Setup the legend
ax.legend(*scatter.legend_elements(),
title="Target");
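The legend line above may look a little cryptic. scatter.legend_elements() returns two lists, the legend handles (markers) and their labels, and the * unpacks them into ax.legend(). A minimal sketch with made-up values standing in for the heart disease columns:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
scatter = ax.scatter([51, 55, 60, 65],      # made-up ages
                     [230, 250, 210, 300],  # made-up cholesterol values
                     c=[0, 1, 0, 1])        # made-up target values

# legend_elements() gives one handle and one label per unique colour value
handles, labels = scatter.legend_elements()
print(len(handles), len(labels))

# Same as ax.legend(*scatter.legend_elements(), title="Target")
ax.legend(handles, labels, title="Target")
```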
What if we wanted a horizontal line going across with the mean of over_50["chol"]?
We do so with the Axes.axhline() method.
# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))
# Plot the data
scatter = ax.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"])
# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
xlabel="Age",
ylabel="Cholesterol");
# Add a legend
ax.legend(*scatter.legend_elements(),
title="Target")
# Add a meanline
ax.axhline(over_50["chol"].mean(),
linestyle="--"); # style the line to make it look nice
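To check where the mean line actually sits, the object returned by axhline() can be inspected. This is a small sketch with four made-up cholesterol values, so the mean is easy to verify by hand ((230 + 250 + 210 + 300) / 4 = 247.5):

```python
import matplotlib.pyplot as plt
import numpy as np

ages = np.array([51, 55, 60, 65])      # made-up ages
chol = np.array([230, 250, 210, 300])  # made-up cholesterol values

fig, ax = plt.subplots()
ax.scatter(ages, chol)

# axhline() draws a horizontal line across the Axes at the given y value
mean_line = ax.axhline(y=chol.mean(), linestyle="--", color="r")
print(mean_line.get_ydata())
```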
Plotting multiple plots on the same figure (adding another plot to an existing one)¶
Sometimes you'll want to visualize multiple features of a dataset or results of a model in one Figure.
You can achieve this by adding data to multiple Axes on the same Figure.
The plt.subplots() method helps you create Figures with a desired number of Axes in a desired configuration.
Using the nrows (number of rows) and ncols (number of columns) parameters you can control the number of Axes on the Figure.
For example:
- nrows=2, ncols=1 = 2x1 = a Figure with 2 Axes
- nrows=5, ncols=5 = 5x5 = a Figure with 25 Axes
Let's create a plot with 2 Axes.
On the first Axes (Axes 0), we'll plot heart disease against cholesterol levels (chol).
On the second Axes (Axes 1), we'll plot heart disease against max heart rate levels (thalach).
# Setup plot (2 rows, 1 column)
fig, (ax0, ax1) = plt.subplots(nrows=2, # 2 rows
ncols=1, # 1 column
sharex=True, # both plots should use the same x-axis
figsize=(10, 8))
# ---------- Axes 0: Heart Disease and Cholesterol Levels ----------
# Add data for ax0
scatter = ax0.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"])
# Customize ax0
ax0.set(title="Heart Disease and Cholesterol Levels",
ylabel="Cholesterol")
ax0.legend(*scatter.legend_elements(), title="Target")
# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
color='b',
linestyle='--',
label="Average")
# ---------- Axes 1: Heart Disease and Max Heart Rate Levels ----------
# Add data for ax1
scatter = ax1.scatter(over_50["age"],
over_50["thalach"],
c=over_50["target"])
# Customize ax1
ax1.set(title="Heart Disease and Max Heart Rate Levels",
xlabel="Age",
ylabel="Max Heart Rate")
ax1.legend(*scatter.legend_elements(), title="Target")
# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
color='b',
linestyle='--',
label="Average")
# Title the figure
fig.suptitle('Heart Disease Analysis',
fontsize=16,
fontweight='bold');
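The sharex=True argument used above links the x-axes of the two Axes together. A quick way to confirm this, sketched here with made-up values, is to change the x-limits on one Axes and read them back from the other:

```python
import matplotlib.pyplot as plt

fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=True)

ax0.plot([50, 60, 70], [200, 250, 300])  # made-up values
ax1.plot([50, 60, 70], [150, 140, 130])  # made-up values

# Changing the x-limits on one Axes also changes them on the other
ax0.set_xlim(45, 80)
print(ax1.get_xlim())
```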
5. Customizing your plots (making them look pretty)¶
If you're not a fan of the default matplotlib styling, there are plenty of ways to make your plots look prettier.
The more visually appealing your plot, the higher the chance people are going to want to look at it.
However, be careful not to overdo the customizations, as they may hinder the information being conveyed.
Some of the things you can customize include:
- Axis limits - The range in which your data is displayed.
- Colors - The colors that appear on the plot to represent different data.
- Overall style - Matplotlib has several different styles built-in which offer different overall themes for your plots, you can see examples of these in the matplotlib style sheets reference documentation.
- Legend - One of the most informative pieces of information on a Figure can be the legend, you can modify the legend of an Axes with the plt.legend() method.
Let's start by exploring different styles built into matplotlib.
Customizing the style of plots¶
Matplotlib comes with several built-in styles that are all created with an overall theme.
You can see what styles are available by using plt.style.available.
Resources:
- To see what many of the available styles look like, you can refer to the matplotlib style sheets reference documentation.
- For a deeper guide on customizing, refer to the Customizing Matplotlib with style sheets and rcParams tutorial.
# Check the available styles
plt.style.available
['Solarize_Light2', '_classic_test_patch', '_mpl-gallery', '_mpl-gallery-nogrid', 'bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-v0_8', 'seaborn-v0_8-bright', 'seaborn-v0_8-colorblind', 'seaborn-v0_8-dark', 'seaborn-v0_8-dark-palette', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8-deep', 'seaborn-v0_8-muted', 'seaborn-v0_8-notebook', 'seaborn-v0_8-paper', 'seaborn-v0_8-pastel', 'seaborn-v0_8-poster', 'seaborn-v0_8-talk', 'seaborn-v0_8-ticks', 'seaborn-v0_8-white', 'seaborn-v0_8-whitegrid', 'tableau-colorblind10']
Before we change the style of a plot, let's remind ourselves what the default plot style looks like.
# Plot before changing style
car_sales["Price"].plot();
Wonderful!
Now let's change the style of our future plots using the plt.style.use(style) method.
Where the style parameter is one of the available matplotlib styles.
How about we try "seaborn-v0_8-whitegrid" (seaborn is another common visualization library built on top of matplotlib)?
# Change the style of our future plots
plt.style.use("seaborn-v0_8-whitegrid")
# Plot the same plot as before
car_sales["Price"].plot();
Wonderful!
Notice the slightly different styling of the plot?
Some styles change more than others.
How about we try "fivethirtyeight"?
# Change the plot style
plt.style.use("fivethirtyeight")
car_sales["Price"].plot();
Ohhh that's a nice looking plot!
Does the style carry over for another type of plot?
How about we try a scatter plot?
car_sales.plot(x="Odometer (KM)",
y="Price",
kind="scatter");
It does!
Looks like we may need to adjust the spacing on our x-axis though.
What about another style?
Let's try "ggplot".
# Change the plot style
plt.style.use("ggplot")
car_sales["Price"].plot.hist(bins=10);
Cool!
Now how can we go back to the default style?
Hint: with "default".
# Change the plot style back to the default
plt.style.use("default")
car_sales["Price"].plot.hist();
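If you only want a style for one plot, switching styles and then switching back can get tedious. One option is the plt.style.context() context manager, which applies a style only inside a with block. Here's a small sketch using made-up values in place of the car_sales data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

values = [4000, 5000, 7000, 22000, 3500]  # made-up prices for illustration

# Record one setting so we can confirm it is restored afterwards
facecolor_before = plt.rcParams["axes.facecolor"]

# The style only applies inside the with block
with plt.style.context("ggplot"):
    fig, ax = plt.subplots()
    ax.plot(values)
    facecolor_inside = plt.rcParams["axes.facecolor"]

facecolor_after = plt.rcParams["axes.facecolor"]
print(facecolor_before, facecolor_inside, facecolor_after)
```

Plots created after the with block go back to whatever style was active before it, so there's no need to call plt.style.use("default") afterwards.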
Customizing the title, legend and axis labels¶
When you have a matplotlib Figure or Axes object, you can customize many of its attributes by using the Axes.set() method.
For example, you can change the:
- xlabel - Label on the x-axis.
- ylim - Limits of the y-axis.
- xticks - Locations of the x-ticks.
- And much more in the documentation.
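As a quick illustration of the attributes listed above, here's a minimal sketch that sets all three in a single Axes.set() call and then reads them back. The data is made up for the example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [10, 20, 15, 30])  # made-up data

# Several attributes can be set in one call
ax.set(xlabel="Step",
       ylim=(0, 40),
       xticks=[0, 1, 2, 3])

# Each keyword has a matching getter
print(ax.get_xlabel(), ax.get_ylim(), ax.get_xticks())
```

Each keyword passed to set() corresponds to a set_<name>() method on the Axes, so ax.set(xlabel="Step") and ax.set_xlabel("Step") do the same thing.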
Rather than talking about it, let's practice!
First, we'll create some random data and then put it into a DataFrame.
Then we'll make a plot from that DataFrame and see how to customize it.
# Create random data
x = np.random.randn(10, 4)
x
array([[ 1.17212975, 0.46563975, -1.90589871, -1.19235958], [-0.63717099, -0.08598952, -0.14465387, 0.54449588], [-1.60294003, 0.96718789, -0.13203246, 0.37619322], [-1.08186882, -1.7225243 , -1.91029832, -1.42247578], [-0.22936709, 1.79289551, 0.24236151, -0.11114891], [-0.22966661, -0.04768414, 0.74157096, -1.71206472], [-0.15221366, -0.34325158, 0.96609502, -1.03521241], [ 1.09157697, -0.77361491, 0.35805583, 0.91628358], [ 0.15352594, -1.22128756, -0.45763768, -1.3302614 ], [-0.86535615, -0.4931282 , -0.43404157, 0.55973627]])
# Turn data into DataFrame with simple column names
df = pd.DataFrame(x,
columns=['a', 'b', 'c', 'd'])
df
a | b | c | d | |
---|---|---|---|---|
0 | 1.172130 | 0.465640 | -1.905899 | -1.192360 |
1 | -0.637171 | -0.085990 | -0.144654 | 0.544496 |
2 | -1.602940 | 0.967188 | -0.132032 | 0.376193 |
3 | -1.081869 | -1.722524 | -1.910298 | -1.422476 |
4 | -0.229367 | 1.792896 | 0.242362 | -0.111149 |
5 | -0.229667 | -0.047684 | 0.741571 | -1.712065 |
6 | -0.152214 | -0.343252 | 0.966095 | -1.035212 |
7 | 1.091577 | -0.773615 | 0.358056 | 0.916284 |
8 | 0.153526 | -1.221288 | -0.457638 | -1.330261 |
9 | -0.865356 | -0.493128 | -0.434042 | 0.559736 |
Now let's plot the data from the DataFrame in a bar chart.
This time we'll save the plot to a variable called ax
(short for Axes).
# Create a bar plot
ax = df.plot(kind="bar")
# Check the type of the ax variable
type(ax)
matplotlib.axes._axes.Axes
Excellent!
We can see our ax variable is a matplotlib Axes object (older matplotlib versions reported this type as AxesSubplot), which allows us to use all of the methods available in matplotlib for Axes.
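If you'd rather check this in code than read the type name, an isinstance() check is one way to do it. This sketch recreates a similar random DataFrame, so the numbers will differ from the ones above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import numpy as np
import pandas as pd
from matplotlib.axes import Axes

df = pd.DataFrame(np.random.randn(10, 4), columns=["a", "b", "c", "d"])
ax = df.plot(kind="bar")

# The object pandas returns is a regular matplotlib Axes
print(isinstance(ax, Axes))

# So Axes methods such as set_title() and get_title() work on it
ax.set_title("Random numbers")
print(ax.get_title())
```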
Let's set a few attributes of the plot with the set() method.
Namely, we'll change the title, xlabel and ylabel to communicate what's being displayed.
# Recreate the ax object
ax = df.plot(kind="bar")
# Set various attributes
ax.set(title="Random Number Bar Graph from DataFrame",
xlabel="Row number",
ylabel="Random number");
Notice the legend is up in the top left corner. We can change that if we like with the loc parameter of the legend() method.
loc can be set as a string to reflect where the legend should be.
By default it is set to loc="best", which means matplotlib will try to figure out the best positioning for it.
Let's try changing it to loc="upper right".
# Recreate the ax object
ax = df.plot(kind="bar")
# Set various attributes
ax.set(title="Random Number Bar Graph from DataFrame",
xlabel="Row number",
ylabel="Random number")
# Change the legend position
ax.legend(loc="upper right");
Nice!
Is that a better fit?
Perhaps not, but it goes to show how you can change the legend position if needed.
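To compare a few loc values side by side, you could draw the same line on several Axes and give each legend a different position. This is only a sketch with made-up data; the four strings below are a subset of the values loc accepts.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

locations = ["upper left", "upper right", "lower left", "lower right"]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6))

# One Axes per legend position
for ax, loc in zip(axes.flatten(), locations):
    ax.plot([0, 1, 2], [0, 1, 4], label="made-up line")
    ax.legend(loc=loc)
    ax.set(title=f'loc="{loc}"')
```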
Customizing the colours of plots with colormaps (cmap)¶
Colour is one of the most important features of a plot.
It can help to separate different kinds of information.
And with the right colours, plots can be fun to look at and invite you to learn more.
Matplotlib provides many different colour options through matplotlib.colormaps.
Let's see how we can change the colours of a matplotlib plot via the cmap parameter (cmap is short for colormap).
We'll start by creating a scatter plot with the default cmap value (cmap="viridis").
# Setup the Figure and Axes
fig, ax = plt.subplots(figsize=(10, 6))
# Create a scatter plot with no cmap change (use default colormap)
scatter = ax.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"],
cmap="viridis") # default cmap value
# Add attributes to the plot
ax.set(title="Heart Disease and Cholesterol Levels",
xlabel="Age",
ylabel="Cholesterol");
ax.axhline(y=over_50["chol"].mean(),
c='b',
linestyle='--',
label="Average");
ax.legend(*scatter.legend_elements(),
title="Target");
Wonderful!
That plot doesn't look too bad.
But what if we wanted to change the colours?
There are many different cmap parameter options available in the colormap reference.
How about we try cmap="winter"?
We can also change the colour of the horizontal line using the color parameter and setting it to a string of the colour we'd like (e.g. color="r" for red).
fig, ax = plt.subplots(figsize=(10, 6))
# Setup scatter plot with different cmap
scatter = ax.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"],
cmap="winter") # Change cmap value
# Add attributes to the plot with different color line
ax.set(title="Heart Disease and Cholesterol Levels",
xlabel="Age",
ylabel="Cholesterol")
ax.axhline(y=over_50["chol"].mean(),
color="r", # Change color of line to "r" (for red)
linestyle='--',
label="Average");
ax.legend(*scatter.legend_elements(),
title="Target");
Woohoo!
The first plot looked nice, but I think I prefer the colours of this new plot.
For more on choosing colormaps in matplotlib, there's a sensational and in-depth tutorial in the matplotlib documentation.
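If you'd like to explore the colormaps from code, the matplotlib.colormaps registry mentioned earlier can be looped over and indexed by name. A rough sketch, assuming a matplotlib version recent enough to include matplotlib.colormaps:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook

# The registry behaves like a mapping from name to Colormap object
names = list(matplotlib.colormaps)
print(len(names), "colormaps registered")
print("winter" in matplotlib.colormaps)

# A Colormap is callable: pass a value between 0 and 1, get back an RGBA colour
winter = matplotlib.colormaps["winter"]
low, mid, high = winter(0.0), winter(0.5), winter(1.0)
print(low, mid, high)
```

This is roughly what happens when you pass c=over_50["target"] along with a cmap: the values are scaled to the range 0 to 1 and then looked up in the colormap.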
Customizing the xlim & ylim¶
Matplotlib is pretty good at setting the ranges of values on the x-axis and the y-axis.
But as you might've guessed, you can customize these to suit your needs.
You can change the ranges of different axis values using the xlim and ylim parameters inside of the set() method.
To practice, let's recreate our double Axes plot from before with the default x-axis and y-axis values.
We'll add in the colour updates from the previous section too.
# Recreate double Axes plot from above with colour updates
fig, (ax0, ax1) = plt.subplots(nrows=2,
ncols=1,
sharex=True,
figsize=(10, 7))
# ---------- Axis 0 ----------
scatter = ax0.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"],
cmap="winter")
ax0.set(title="Heart Disease and Cholesterol Levels",
ylabel="Cholesterol")
# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
color="r",
linestyle="--",
label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")
# ---------- Axis 1 ----------
scatter = ax1.scatter(over_50["age"],
over_50["thalach"],
c=over_50["target"],
cmap="winter")
ax1.set(title="Heart Disease and Max Heart Rate Levels",
xlabel="Age",
ylabel="Max Heart Rate")
# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
color="r",
linestyle="--",
label="Average");
ax1.legend(*scatter.legend_elements(),
title="Target")
# Title the figure
fig.suptitle("Heart Disease Analysis",
fontsize=16,
fontweight="bold");
Now let's recreate the plot from above but this time we'll change the axis limits.
We can do so by using Axes.set(xlim=[50, 80]) or Axes.set(ylim=[60, 220]), where the inputs to xlim and ylim are a list of two numbers defining the lower and upper limits of the axis.
For example, xlim=[50, 80] will set the x-axis values to start at 50 and end at 80.
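Before applying this to the full heart disease plot, here's a stripped-back sketch of the same idea on made-up numbers. It prints the limits matplotlib chose automatically and then the limits after calling set().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

ages = [45, 52, 58, 63, 71, 77]        # made-up values
chol = [210, 250, 190, 300, 240, 280]  # made-up values

fig, ax = plt.subplots()
ax.scatter(ages, chol)
print("Automatic limits:", ax.get_xlim(), ax.get_ylim())

# Override the automatic limits
ax.set(xlim=[50, 80], ylim=[150, 350])
print("Custom limits:", ax.get_xlim(), ax.get_ylim())
```

Note that setting limits only changes what's visible. The point with an age of 45 is still in the data, it just falls outside the displayed x-axis range.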
# Recreate the plot from above with custom x and y axis ranges
fig, (ax0, ax1) = plt.subplots(nrows=2,
ncols=1,
sharex=True,
figsize=(10, 7))
scatter = ax0.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"],
cmap='winter')
ax0.set(title="Heart Disease and Cholesterol Levels",
ylabel="Cholesterol",
xlim=[50, 80]) # set the x-axis ranges
# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
color="r",
linestyle="--",
label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")
# Axis 1 (second row, first column)
scatter = ax1.scatter(over_50["age"],
over_50["thalach"],
c=over_50["target"],
cmap='winter')
ax1.set(title="Heart Disease and Max Heart Rate Levels",
xlabel="Age",
ylabel="Max Heart Rate",
ylim=[60, 220]) # change the y-axis range
# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
color="r",
linestyle="--",
label="Average");
ax1.legend(*scatter.legend_elements(),
title="Target")
# Title the figure
fig.suptitle("Heart Disease Analysis",
fontsize=16,
fontweight="bold");
Now that's a nice looking plot!
Let's figure out how we'd save it.
6. Saving plots¶
Once you've got a nice looking plot that you're happy with, the next thing is going to be sharing it with someone else, whether in a report, blog post, presentation or something similar.
You can save matplotlib Figures with plt.savefig(fname="your_plot_file_name") (or the equivalent Figure.savefig() method), where fname is the target filename you'd like to save the plot to.
Before we save our plot, let's recreate it.
# Recreate the plot from above with custom x and y axis ranges
fig, (ax0, ax1) = plt.subplots(nrows=2,
ncols=1,
sharex=True,
figsize=(10, 7))
scatter = ax0.scatter(over_50["age"],
over_50["chol"],
c=over_50["target"],
cmap='winter')
ax0.set(title="Heart Disease and Cholesterol Levels",
ylabel="Cholesterol",
xlim=[50, 80]) # set the x-axis ranges
# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
color="r",
linestyle="--",
label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")
# Axis 1 (second row, first column)
scatter = ax1.scatter(over_50["age"],
over_50["thalach"],
c=over_50["target"],
cmap='winter')
ax1.set(title="Heart Disease and Max Heart Rate Levels",
xlabel="Age",
ylabel="Max Heart Rate",
ylim=[60, 220]) # change the y-axis range
# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
color="r",
linestyle="--",
label="Average");
ax1.legend(*scatter.legend_elements(),
title="Target")
# Title the figure
fig.suptitle("Heart Disease Analysis",
fontsize=16,
fontweight="bold");
Nice!
We can save our plots to several different kinds of filetypes.
And we can check these filetypes with fig.canvas.get_supported_filetypes().
# Check the supported filetypes
fig.canvas.get_supported_filetypes()
{'eps': 'Encapsulated Postscript', 'jpg': 'Joint Photographic Experts Group', 'jpeg': 'Joint Photographic Experts Group', 'pdf': 'Portable Document Format', 'pgf': 'PGF code for LaTeX', 'png': 'Portable Network Graphics', 'ps': 'Postscript', 'raw': 'Raw RGBA bitmap', 'rgba': 'Raw RGBA bitmap', 'svg': 'Scalable Vector Graphics', 'svgz': 'Scalable Vector Graphics', 'tif': 'Tagged Image File Format', 'tiff': 'Tagged Image File Format', 'webp': 'WebP Image Format'}
Image filetypes such as jpg and png are excellent for blog posts and presentations, whereas the pgf or pdf filetypes may be better for reports and papers.
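To see the difference in practice, the sketch below saves one small made-up plot as both a png and a pdf. It writes to a temporary folder (an assumption for the example, rather than the ../images/ folder used in this notebook) and relies on savefig() picking the format from the file extension.

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])  # made-up data

out_dir = Path(tempfile.mkdtemp())  # temporary folder for the example files

# Raster format: dpi controls how many pixels the saved image has
png_path = out_dir / "example.png"
fig.savefig(png_path, dpi=100)

# Vector format: stays sharp when zoomed or printed
pdf_path = out_dir / "example.pdf"
fig.savefig(pdf_path)

print(png_path.stat().st_size, pdf_path.stat().st_size)
```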
One last look at our Figure, which is saved to the fig
variable.
fig
Beautiful!
Now let's save it to file.
# Save the file
fig.savefig(fname="../images/heart-disease-analysis.png",
dpi=100)
File saved!
Let's try and display it.
We can do so with the HTML code:
<img src="../images/heart-disease-analysis.png" alt="a plot showing a heart disease analysis comparing the presence of heart disease, cholesterol levels and heart rate on patients over 50"/>
And changing the cell below to markdown.
Note: Because the plot is highly visual, it's important to make sure there is an alt="some_text_here" attribute available when displaying the image, as this attribute is used to make the plot more accessible to those with visual impairments. For more on displaying images with HTML, see the Mozilla documentation.
Finally, if we wanted to start making more and different Figures, we can reset our fig
variable by creating another plot.
# Resets figure
fig, ax = plt.subplots()
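One thing worth knowing when running code as a script: reassigning fig doesn't close the earlier Figure, since pyplot keeps track of every Figure it creates until it's closed. (Notebooks using the inline backend usually close Figures for you once they've been displayed.) A small sketch of closing Figures by hand with plt.close():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean slate

fig_a, ax_a = plt.subplots()
fig_b, ax_b = plt.subplots()
print(plt.get_fignums())  # numbers of the Figures pyplot is tracking

# Close a single Figure
plt.close(fig_a)
open_after_one = plt.get_fignums()

# Close everything that's left
plt.close("all")
open_after_all = plt.get_fignums()
print(open_after_one, open_after_all)
```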
If you're creating plots and saving them like this often, you might put the steps into a function that follows the matplotlib workflow, to save writing excess code.
# Potential matplotlib workflow function
def plotting_workflow(data):
    # 1. Manipulate data
    # 2. Create plot
    fig, ax = plt.subplots()
    # 3. Plot data
    ax.plot(data)
    # 4. Customize plot
    # 5. Save plot
    # 6. Return plot
    return fig
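To make the outline above more concrete, here's one way the steps could be filled in. Treat it as a sketch rather than a recommended design: the function name plot_sorted_values, its parameters and the sorting step are all made up for illustration.

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt


def plot_sorted_values(values, title, fname=None):
    """Hypothetical helper that follows the six workflow steps."""
    # 1. Manipulate data
    data = sorted(values)
    # 2. Create plot
    fig, ax = plt.subplots(figsize=(6, 4))
    # 3. Plot data
    ax.plot(data, marker="o")
    # 4. Customize plot
    ax.set(title=title, xlabel="Rank", ylabel="Value")
    # 5. Save plot (only when a filename is given)
    if fname is not None:
        fig.savefig(fname, dpi=100)
    # 6. Return plot
    return fig


# Example usage with made-up values, saving to a temporary folder
out_path = Path(tempfile.mkdtemp()) / "sorted-values.png"
fig = plot_sorted_values([3, 1, 2], title="Sorted values", fname=out_path)
```

Returning the Figure means whoever calls the function can keep customizing it or save it again in a different format.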
Extra resources¶
We've covered a fair bit here.
But really we've only scratched the surface of what's possible with matplotlib.
So for more, I'd recommend going through the following:
- Matplotlib quick start guide - Try rewriting all the code in this guide to get familiar with it.
- Matplotlib plot types guide - Inside you'll get an idea of just how many kinds of plots are possible with matplotlib.
- Matplotlib lifecycle of a plot guide - A sensational ground-up walkthrough of the many different things you can do with a plot.