101: Squared Waffle charts in matplotlib
In this tutorial, I will show you how to create waffle charts using Python and Matplotlib. For more Matplotlib charts, check out the gallery:

This is what we will create in matplotlib:

Import the packages
We will need the following packages:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import decimal
from matplotlib.lines import Line2D
Generate the data
We could go straight from NumPy to Matplotlib, but most data projects use pandas to transform the data, so I am using a pandas DataFrame as the starting point.
# Colour for each country; used for the waffle squares and the legend.
color_dict = {"Norway": "#2B314D", "Denmark": "#A54836", "Sweden": "#5375D4"}
xy_ticklabel_color = "#101628"
xy_label_color = "#101628"

# Number of sites per country for the two years being compared.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)

# Custom sort: 2022 (rank 4) before 2004 (rank 5), then Norway, Denmark, Sweden
# within each year. One lookup table serves both columns because the year keys
# and the country keys never collide.
sort_order_dict = {"Denmark": 2, "Sweden": 3, "Norway": 1, 2004: 5, 2022: 4}
df = df.sort_values(by=["year", "countries"], key=lambda col: col.map(sort_order_dict))

# Map each country to its colour.
df["color"] = df.countries.map(color_dict)
df["sub_total"] = df.groupby("year")["sites"].transform("sum")


def percent_of_total(shares, total=100):
    """Split ``total`` integer cells across ``shares`` proportionally.

    Uses the largest-remainder method. Every share first gets the floor of its
    exact proportion, and the cells still missing go to the shares with the
    largest fractional parts. Rounding each share on its own can total 99 or
    101 (e.g. three equal shares -> 33 + 33 + 33). This method always totals
    exactly ``total``, which the waffle needs: one colour per grid cell.

    Parameters
    ----------
    shares : pandas.Series
        Non-negative counts whose sum is positive.
    total : int, default 100
        Number of cells to distribute (100 for a 10 x 10 waffle).

    Returns
    -------
    pandas.Series
        Integer cell counts with the same index as ``shares``.
    """
    exact = shares * total / shares.sum()
    cells = np.floor(exact).astype(int)
    missing = total - int(cells.sum())
    # Stable sort, so ties go to the row that comes first in the sorted frame.
    remainders = (exact - cells).sort_values(ascending=False, kind="mergesort")
    cells.loc[remainders.index[:missing]] += 1
    return cells


# Whole-number percentage of each year's total; each year sums to exactly 100.
df["pct_group"] = df.groupby("year")["sites"].transform(percent_of_total)
df
|   | year | countries | sites | color | sub_total | pct_group |
|---|---|---|---|---|---|---|
| 5 | 2022 | Norway | 8 | #2B314D | 33 | 24 |
| 3 | 2022 | Denmark | 10 | #A54836 | 33 | 30 |
| 1 | 2022 | Sweden | 15 | #5375D4 | 33 | 46 |
| 4 | 2004 | Norway | 5 | #2B314D | 22 | 23 |
| 2 | 2004 | Denmark | 4 | #A54836 | 22 | 18 |
| 0 | 2004 | Sweden | 13 | #5375D4 | 22 | 59 |
Next, we create the basis for the waffle: the x and y coordinates of a 10 × 10 grid, one point per square:
# Coordinates of the 100 cells of a 10 x 10 waffle.
# X is the row number (1, 1, ..., 1, 2, 2, ...; each value ten times in a row).
# Y cycles through the column numbers (1..10, 1..10, ...).
X, Y = (grid.ravel() for grid in np.meshgrid(np.arange(1, 11), np.arange(1, 11), indexing="ij"))
Extract the years and countries
# Distinct years and countries, in the order they first appear in the sorted df.
years, countries = (df[column].unique() for column in ("year", "countries"))
Plot the chart
# One waffle panel per year, side by side. squeeze=False keeps `axes` a 2-D
# array even when there is only one year, so .ravel() always works.
fig, axes = plt.subplots(
    ncols=len(years), figsize=(8, 3.7), facecolor="#FFFFFF", squeeze=False
)

for ax, year in zip(axes.ravel(), years):
    year_rows = df[df.year == year]
    # pct_group is the number of cells each country owns, so repeating each
    # colour that many times gives one colour per grid cell.
    # astype(int) also accepts Decimal-valued counts.
    dot_colors = np.repeat(
        year_rows["color"].to_numpy(),
        year_rows["pct_group"].astype(int).to_numpy(),
    )
    if len(dot_colors) != len(X):
        raise ValueError(
            f"{year}: cell counts sum to {len(dot_colors)}, expected {len(X)}"
        )
    # Y is plotted horizontally and X vertically, so the cells fill row by row
    # from the bottom.
    ax.scatter(Y, X, s=300, marker="s", c=dot_colors)
    ax.set_xlabel(year, color=xy_label_color, size=14)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines[["top", "left", "right", "bottom"]].set_visible(False)

# Legend: one square marker per country. Colours come from color_dict keyed by
# country, so each label is paired with its own colour regardless of which rows
# the last panel happened to contain.
handles = [
    Line2D([0], [0], color=color_dict[country], marker="s", linestyle="", markersize=12)
    for country in countries
]
fig.legend(
    handles,
    countries,
    labelcolor=xy_label_color,
    prop=dict(size=10),  # font size is set here; `fontsize` is ignored when `prop` is given
    bbox_to_anchor=(0.5, -0.3),
    loc="lower center",
    ncols=len(countries),
    frameon=False,
)
The result:

Was this helpful?
Reader Interactions