plt.plot(x, cos_x, label='cos x')
plt.plot(x, sin_x, label='sin x')
plt.legend()
plt.show()
Plotting (Good)
Figure: The various parts of a matplotlib figure. (From matplotlib.org)
What to Expect in this Chapter
In the previous chapter, we saw how you could use matplotlib to produce simple, decent-looking plots. However, we have barely tapped the full power of what matplotlib can do. For this, I need to introduce you to a different way of speaking to matplotlib. So far, the ‘dialect’ we have used to speak to matplotlib is called the MATLAB-like pyplot (plt) interface. From here onward, I will show you how to use the other, more powerful ‘dialect’ called the Object Oriented (OO) interface. This way of talking to matplotlib gives us more nuanced control over what is going on by letting us manipulate each of the axes directly.
Comparing the two ‘dialects’
Some nomenclature
Before going ahead, let’s distinguish between a matplotlib figure and an axis.
A figure is simple; it is the full canvas you use to draw stuff on. An axis is an individual set of mathematical axes we use for plotting (matplotlib calls this an Axes object). So one figure can have multiple axes, as shown below, where we have a (single) figure with four axes.
By the way, you already encountered a situation with multiple axes in the last chapter when we used twinx(). It is common to struggle with the concept of axes, but don’t worry; it will become clearer later.
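If the figure/axes distinction still feels abstract, here is a small sketch you can run. It assumes matplotlib’s non-interactive Agg backend (so no window is needed) and uses fig.axes, the list of axes a figure holds, to show that one figure can own several axes.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

# One figure (the canvas) containing a 2 x 2 grid of axes
fig, ax = plt.subplots(nrows=2, ncols=2)

print(len(fig.axes))                            # the figure keeps a list of its axes
print(all(a.figure is fig for a in fig.axes))   # every axes belongs to the same figure
```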
To see how the OO interface works, let me create the same plot using both ‘dialects’. The plot I will be making is shown below.
We need some data.
First, let’s generate some data to plot.
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-np.pi, np.pi, num=100)
cos_x = np.cos(x)
sin_x = np.sin(x)
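If you want to check what np.linspace() gives you, the sketch below confirms that, by default, both endpoints are included, so the spacing between the 100 points is \((\pi - (-\pi))/99\).

```python
import numpy as np

x = np.linspace(-np.pi, np.pi, num=100)

print(len(x))                                   # 100 points
print(x[0] == -np.pi, x[-1] == np.pi)           # both endpoints are included by default
print(np.isclose(x[1] - x[0], 2 * np.pi / 99))  # spacing is (stop - start) / (num - 1)
```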
Here comes the comparison
pyplot Interface
plt.plot(x, cos_x, label='cos x')
plt.plot(x, sin_x, label='sin x')
plt.legend()
plt.show()
OO Interface
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(x, cos_x, label='cos x')
ax.plot(x, sin_x, label='sin x')
ax.legend()
plt.show()
Both sets of code will produce the same plot. For the OO interface, we have to start by using subplots() to ask matplotlib to create a figure and an axis. matplotlib obliges and gives us a figure (fig) and an axis (ax).
Although I have used the variables fig and ax, you are free to call them what you like; these are simply the names commonly used in the documentation. In this example, I need only one row and one column. But, if I want, I can ask for a grid like in the plot right at the top.
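To see how the two ‘dialects’ relate, here is a small check, assuming the Agg backend. pyplot functions act on the ‘current’ figure and axes, which right after plt.subplots() are exactly the fig and ax it returned.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-np.pi, np.pi, num=100)

fig, ax = plt.subplots(nrows=1, ncols=1)

# pyplot functions act on the "current" figure and axes ...
print(plt.gcf() is fig, plt.gca() is ax)

# ... so this pyplot call draws on the same axes that ax.plot() uses
plt.plot(x, np.cos(x), label='cos x')
ax.plot(x, np.sin(x), label='sin x')
print(len(ax.get_lines()))  # both curves live on ax
```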
My advice is to use the pyplot interface for quick and dirty plots and the OO interface for more complex plots that demand control and finesse.
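One reason the OO interface copes well with complex plots is that you can hand an axis to a function and let it draw there. Below is a minimal sketch, assuming the Agg backend; plot_with_fill() is a hypothetical helper made up for illustration, not part of matplotlib.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np


def plot_with_fill(ax, x, y, label):
    """Hypothetical helper: draw y(x) on the given axes and shade down to 0."""
    line, = ax.plot(x, y, label=label)
    ax.fill_between(x, 0, y, alpha=.25)
    ax.legend()
    return line


x = np.linspace(-np.pi, np.pi, num=100)
fig, ax = plt.subplots(nrows=2, ncols=1)

# The same function can target whichever axes we pass in
plot_with_fill(ax[0], x, np.cos(x), 'cos x')
plot_with_fill(ax[1], x, np.sin(x), 'sin x')
```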
Using the OO Interface
Getting ax
Let me show you how to split the plots into two separate panes (with two axes) arranged as a column, as shown above. This will also allow us to get comfortable with using ax.
To get the above plot, we need to ask for two rows (nrows=2) and one column (ncols=1).
fig, ax = plt.subplots(ncols=1, nrows=2)
This should give me two axes so that I can plot in both panes.
What is ax?
Let’s quickly check a few more details about ax.
What is ax?
type(ax)
<class 'numpy.ndarray'>
So ax is a NumPy array!
What size is ax?
ax.shape
(2,)
As expected, ax has two ‘things’.
What is contained in ax?
type(ax[0])
<class 'matplotlib.axes._subplots.AxesSubplot'>
This is a matplotlib axis. (Recent versions of matplotlib report it as <class 'matplotlib.axes._axes.Axes'> instead.)
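The shape of ax depends on the grid you ask for. The sketch below (assuming the Agg backend) records the shape for a few grids. With subplots()’s default squeeze=True, a 1 x 1 grid gives a bare axis rather than an array, a single row or column gives a 1D array, and passing squeeze=False always gives a 2D array.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

shapes = {}
for nrows, ncols in [(1, 1), (2, 1), (1, 3), (2, 2)]:
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
    # None marks the case where subplots() returned a single axis, not an array
    shapes[(nrows, ncols)] = ax.shape if isinstance(ax, np.ndarray) else None
    plt.close(fig)
print(shapes)

# squeeze=False keeps the 2D array even for a single axis
fig, ax = plt.subplots(nrows=1, ncols=1, squeeze=False)
print(ax.shape)
```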
Plots in a column
Now, let’s use these axes to plot our data.
# First plot
ax[0].plot(x, cos_x, label='cos x')
# Second plot
ax[1].plot(x, sin_x, label='sin x')
Notice how I use the different axes to plot the \(\cos(x)\) and \(\sin(x)\) data.
Legends
If we want to show the legends, we have to call legend() for each axis:
# First plot
ax[0].legend()
# Second plot
ax[1].legend()
However, a more sensible way to do this is with a for loop that iterates through the items in ax:
for a in ax:
    a.legend()
Let’s also add a grid to each of the axes.
for a in ax:
    a.legend()
    a.grid(alpha=.25)
Tweaks
Let’s tweak a few details of the plots.
- Let’s change the size of the plot by specifying a figure size (figsize).
- Let’s ask that the plots share the \(x\)-axis using sharex.
fig, ax = plt.subplots(nrows=2, ncols=1,
                       figsize=(10, 5),  # width x height = 10 x 5 inches!
                       sharex=True)
- Let’s add an \(x\) label.
ax[1].set_xlabel('$x$')
Take note that we have to use set_xlabel() with the OO interface and not xlabel().
- Let’s fill the area between \(0\) and each curve.
ax[0].fill_between(x, 0, cos_x, alpha=.25)
ax[1].fill_between(x, 0, sin_x, alpha=.25)
- Let’s add a super title to the figure.
fig.suptitle(r'$\sin(x)$ and $\cos(x)$')
- Finally, let’s ask matplotlib to make any necessary adjustments to the layout to make our plot look nice by calling tight_layout(). To convince yourself of the utility of tight_layout(), try producing the plot with and without it.
fig.tight_layout()
Et voilà!
Here is the full code.
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-np.pi, np.pi, num=100)
cos_x = np.cos(x)
sin_x = np.sin(x)

fig, ax = plt.subplots(nrows=2, ncols=1,
                       figsize=(10, 5),  # width x height = 10 x 5 inches!
                       sharex=True)

ax[0].plot(x, cos_x, label=r'$\cos(x)$')
ax[0].fill_between(x, 0, cos_x, alpha=.25)
ax[1].plot(x, sin_x, label=r'$\sin(x)$')
ax[1].fill_between(x, 0, sin_x, alpha=.25)

for a in ax:
    a.legend()
    a.grid(alpha=.25)
    a.set_ylabel('$y$', rotation=0)

ax[1].set_xlabel('$x$')

fig.suptitle(r'$\sin(x)$ and $\cos(x)$')
fig.tight_layout()
plt.show()
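To see what sharex=True actually does, here is a quick check, assuming the Agg backend: changing the \(x\) limits of one axis also changes the other. The same snippet shows that figsize is interpreted as (width, height) in inches.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 5), sharex=True)

# Setting the x-limits on the top axis propagates to the shared bottom axis
ax[0].set_xlim(-2, 2)
print(ax[1].get_xlim())

print(fig.get_size_inches())  # (width, height) in inches
```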
More rows and columns
Now I will show you how to work with a grid of axes like that shown above.
I have intentionally used simple straight lines so we can focus on the details of the plotting. For example, there are two sets of \(x\) values, and I have used np.ones_like() to generate an array of 1’s that has the same length as the corresponding x array.
Take a look at the code and see if things make sense. I will take you through the details in a bit.
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2,
                       figsize=(10, 4),
                       sharex='col', sharey='col')

# Some variables to access the axes and improve readability
top_left, top_right, bottom_left, bottom_right = ax.flatten()

# x data for plotting
x1 = np.linspace(0, 10, 100)
x2 = np.linspace(10, 20, 100)

top_left.plot(x1, np.ones_like(x1))
top_right.plot(x2, 2*np.ones_like(x2))
bottom_left.plot(x1, 3*np.ones_like(x1))
bottom_right.plot(x2, 4*np.ones_like(x2))

for a in ax.flatten():
    a.grid(alpha=.25)

plt.tight_layout()
plt.show()
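If np.ones_like() is new to you, this small sketch shows that it returns an array of 1’s with the same shape and data type as its input, which is why multiplying it by a number gives a horizontal line at that height.

```python
import numpy as np

x1 = np.linspace(0, 10, 100)

y = np.ones_like(x1)        # same shape and dtype as x1, filled with 1
print(y.shape, y.dtype)     # (100,) float64
print(np.all(2 * y == 2))   # scaling gives a constant line at y = 2
```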
Using ax
The most important thing you need to understand is how to access ax. Of course, we know there must be four axes; but how is ax structured?
If you check ax.shape you will get (2, 2). So, ax is organised just like the grid of plots, as a 2 x 2 array. This means I can access each of the axes as follows:
ax[0, 0].plot(x1, np.ones_like(x1))
ax[0, 1].plot(x2, 2*np.ones_like(x2))
ax[1, 0].plot(x1, 3*np.ones_like(x1))
ax[1, 1].plot(x2, 4*np.ones_like(x2))
This is a perfectly valid way to use ax. However, when I have to tweak each of the axes separately, I find it easier to use descriptive variable names. I can do this by:
top_left = ax[0, 0]
top_right = ax[0, 1]
bottom_left = ax[1, 0]
bottom_right = ax[1, 1]
You can also use:
top_left, top_right, bottom_left, bottom_right = ax.flatten()
flatten() takes the 2D array and flattens it, row by row, into a 1D array; unpacking then takes care of the assignments.
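Here is a NumPy-only sketch of how the 2 x 2 indexing and flatten() line up. I use an array of strings as a stand-in for ax so that each position is labelled; flatten() reads the elements row by row (NumPy’s default ‘C’ order).

```python
import numpy as np

# Stand-in for a 2 x 2 ax array: each entry names its (row, column) position
grid = np.array([['row0-col0', 'row0-col1'],
                 ['row1-col0', 'row1-col1']])

print(grid.shape)   # (2, 2)
print(grid[0, 1])   # row 0, column 1, i.e. the top-right position

# flatten() reads row by row, so unpacking matches the reading order
top_left, top_right, bottom_left, bottom_right = grid.flatten()
print(top_right, bottom_left)
```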
Accessing all axes
You will often want to apply changes to all the axes, like in the case of the grid. You could do this by:
top_left.grid(alpha=.25)
top_right.grid(alpha=.25)
bottom_left.grid(alpha=.25)
bottom_right.grid(alpha=.25)
But this is repetitive and requires a lot of typing. It is much nicer to use a for loop.
for a in ax.flatten():
    a.grid(alpha=.25)
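As an alternative to ax.flatten(), you can loop over ax.flat, which iterates over the elements without building a flattened copy. A minimal sketch, assuming the Agg backend:

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2)

# ax.flat walks through every axis in the 2 x 2 array, row by row
for a in ax.flat:
    a.grid(alpha=.25)
    a.set_ylabel('$y$', rotation=0)

print(sum(1 for _ in ax.flat))  # number of axes visited
```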
Other useful plots
In this section, I will quickly show you some useful plots we can generate with matplotlib. I will also use a few different plot styles I commonly use so that you can get a feel for changing styles.
Histograms
A histogram is a valuable tool for showing distributions of data. For this example, I have extracted some actual data from sg.gov related to the mean monthly earnings of graduates from the various universities in Singapore.
Here are the links to my data files:
| Mean basic monthly earnings by graduates | Data file |
|---|---|
| All | sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv |
| NUS Only | sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv |
import numpy as np
import matplotlib.pyplot as plt

data = {}

filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv'
data['All'] = np.loadtxt(filename, skiprows=1)

filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv'
data['NUS'] = np.loadtxt(filename, skiprows=1)

plt.style.use('bmh')
plt.hist([data['All'], data['NUS']],
         bins=50,                # How many bins to split the data
         label=['All', 'NUS']
         )
plt.xlabel('Mean of Basic Monthly Earnings (S$)')
plt.ylabel('Number of Students')
plt.legend()
plt.show()
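To get a feel for what bins does, the sketch below uses a handful of made-up numbers (not the survey data) and compares plt.hist() with np.histogram(), assuming the Agg backend. plt.hist() returns the counts, the bin edges and the drawn bars, and its counts should match np.histogram() for the same number of bins.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

# Made-up numbers for illustration only
values = [1, 2, 2, 3, 3, 3]

counts, edges = np.histogram(values, bins=3)
print(counts)   # [1 2 3]
print(edges)    # 4 edges for 3 equal-width bins spanning min to max

# plt.hist() does the same binning and returns counts, edges and the bars
n, bins, patches = plt.hist(values, bins=3)
print(np.array_equal(n, counts), np.allclose(bins, edges))
```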
Scatter plots
Scatter plots are created by putting a marker at an \((x,y)\) point you specify. They are simple yet powerful.
I will be lazy and use the same data as the previous example. But, since I need some values for \(x\), I am going to use range() along with len() to generate the positions 0, 1, 2, ... with one position per value in the dataset.
import numpy as np
import matplotlib.pyplot as plt

data = {}
for label in ['All', 'NUS']:
    filename = f'sg-gov-graduate-employment-survey_basic_monthly_mean_{label.lower()}.csv'
    data[label] = np.loadtxt(filename, skiprows=1)

# In matplotlib 3.6 and later, this style is named 'seaborn-v0_8-darkgrid'
plt.style.use('seaborn-darkgrid')

for label, numbers in data.items():
    x = range(len(numbers))
    y = numbers
    plt.scatter(x, y, label=label, alpha=.5)

plt.xlabel('Position in the list')
plt.ylabel('Mean of Basic Monthly Earnings (S$)')
plt.legend()
plt.show()
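A quick way to convince yourself that plt.scatter() places one marker at each \((x,y)\) pair is to ask the returned collection for its marker positions with get_offsets(). The numbers below are made up for illustration, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

# Made-up y values for illustration only
numbers = np.array([3200., 3500., 4100.])
x = range(len(numbers))      # 0, 1, 2: one position per value

sc = plt.scatter(x, numbers, label='demo', alpha=.5)

# Each row of the offsets is the (x, y) location of one marker
print(np.asarray(sc.get_offsets()))
```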
Bar charts
I am using some dummy data for a hypothetical class for this example. I extract the keys and values from the dictionary and convert them to lists to pass to bar(). Use barh() if you want horizontal bars.
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('ggplot')
plt.bar(majors, numbers)
plt.xlabel('Majors')
plt.ylabel('Number of Students')
plt.show()
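Here is a sketch, assuming the Agg backend, that draws the same dummy data with bar() and barh() side by side using the OO interface. For bar() the values become the bar heights; for barh() they become the bar widths.

```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

bars = ax[0].bar(majors, numbers)    # vertical bars: values set the heights
hbars = ax[1].barh(majors, numbers)  # horizontal bars: values set the widths

print([float(b.get_height()) for b in bars])
print([float(b.get_width()) for b in hbars])

fig.tight_layout()
```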
Pie charts
I am not a big fan of pie charts, but they have their uses. Let me reuse the previous data from the dummy class.
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('fivethirtyeight')
plt.pie(numbers,
        labels=majors,
        autopct='%1.1f%%',   # How to format the percentages
        startangle=-90
        )
plt.title('Percentage of each major')
plt.show()
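The percentages printed by autopct come from dividing each value by the total. Below is a pure-Python sketch of that arithmetic for the dummy class, using the same '%1.1f%%' format string; it assumes matplotlib normalises the values by their sum, as it does here because they add up to more than 1.

```python
student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}

total = sum(student_numbers.values())   # 35 students in the dummy class

# Each wedge is count / total; autopct formats 100 * that fraction
labels = {major: '%1.1f%%' % (100 * count / total)
          for major, count in student_numbers.items()}
print(labels)
```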