Python笔记 #3 Matplotlib

本文最后更新于:2022年11月22日 下午

学习 Machine Learning 的时候发现需要用许多矩阵运算和画图的库,本文将以实用主义的方式记录每次遇到的新用法。

2021 年贵系的暑培新增了「科学计算」内容,本文部分内容参考了清华 LZJ 同学的教程。本文将持续更新。

Matplotlib 基础

绘图时最常用的模块是 Matplotlib 中的 pyplot 模块。绘图时先调用相关绘图函数,设置图像各种细节,最后调用 plt.show() 打开新窗口显示图片,此时程序会暂停。如果使用 %matplotlib inline 方法,可以使图片在 Cell 中显示而不开新窗口(适用于 Notebook)。

这时会出现一个基于 Qt 实现的交互窗口,程序执行到 plt.show() 时阻塞。在交互窗口中可以进一步调整图片格式细节或保存图片,关闭窗口后程序继续运行。下面以一个例子说明:

1
2
3
4
5
6
7
8
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-4, 4, 30) # [-4, 4] 中长度为 30 的等差数列
y = np.sin(x) # 获得 sin 值
plt.plot(x, y) # 以 x 为自变量,y 为因变量,绘制折线图
plt.show() # 显示图像,程序阻塞
plt.savefig('myFig.png') # 保存图像, 可以设置 dpi=1200/600/300 调整清晰度

格式字符串

与 Matlab 相似,Matplotlib 使用事先约定好的字符串代表绘图格式,将其写入 plt.plot() 即可,如 plt.plot(x, y, 'bo') 即蓝色圆圈标记。

  • color:绘制点的颜色,支持以下常用缩写(还有更多未列出的颜色全名),还可以用十六进制代码 c='#000' 指定。
字符颜色字符颜色字符颜色字符颜色
b蓝色r红色m洋红色k黑色
g绿色c青色y黄色w白色
  • color#xkcd:xkdc 调色盘总结了数百种最常用的颜色,比标准色更美观,使用 c='xkcd:pink' 格式即可使用。
全称颜色全称颜色全称颜色全称颜色
pink粉色sky blue天蓝色orange橘色light green嫩绿色
light purple嫩紫色lavender粉紫色tan褐色aqua海绿色
  • marker:绘制点的形状。
字符标记字符标记字符标记字符标记
.点标记>右三角标记p五边形标记D菱形标记
,像素标记1三叉戟标记*星形标记d菱形标记
o圆圈标记2三叉戟标记h六角形标记竖线标记
v倒三角标记3三叉戟标记H六角形标记_横线标记
^正三角标记4三叉戟标记++标记
<左三角标记s正方形标记xx标记
  • line:绘制线的形状。
字符格式字符格式字符格式字符格式
-实线--虚线-.点划线:点线

特定类型图

下面是一些常用的绘图函数。更多官网案例:https://matplotlib.org/stable/gallery/index.html。

折线图 plot()

plt.plot() 用于绘制折线图,需要一系列点作为因变量和自变量,函数接口如下:

1
matplotlib.pyplot.plot(*args, scalex=True, scaley=True, data=None, **kwargs)

观察接口,可以发现有以下的用法:

  • 传入两个等长数组或列表,前者是自变量,后者是因变量。
  • 只传入一个数组或列表,自变量默认从 0 开始整数递增。
  • 传入两个列表,各包含两个值,可以绘制一条直线。

散点图 scatter()

plt.scatter() 用于绘制散点图,需要两个等长数组或列表,函数接口如下:

1
matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, **kwargs)

其中只有 xy 为必填项,s 代表每个点的大小,可以是常数也可以是列表。

轮廓图 contour()

plt.contour() 用于绘制等高线图,也可绘制闭合的轮廓图,函数接口如下:

1
matplotlib.pyplot.contour(x, y, Z, [levels], **kwargs)

其中 xy 是一维的数组,代表绘制的网格点的横纵坐标;Z 是形状为 (len(X), len(Y)) 的二维数组,代表绘制点的高度。参数 levels 是一个列表,如果空缺则默认绘制一组等距的等高线,也可以传入 [0, 0.5, 1] 这样的列表,则只会显示对应高度的轮廓线。

网格的横纵坐标通常由 Numpy 中的 merhgrid 生成,完整代码如下:

1
2
3
4
5
u = np.linspace(-1, 1, 50)
v = np.linspace(-1, 1, 50)
U, V = np.meshgrid(u, v) # 生成 50x50 网格点矩阵
z = (map_feature(U.flatten(), V.flatten()) @ theta).reshape((50, 50)) # 计算等高线
plt.contour(u, v, z, [0], colors='r') # 绘制轮廓图,取高度为 0 的点

热力图 imshow()

plt.imshow() 本身是用于绘制二维数字图像,但图像有更好的工具,因此主要用于绘制热力图(heatmap),配套使用的还有色值柱,函数接口如下:

1
2
matplotlib.pyplot.imshow(X, cmap=None, interpolation=None, **kwargs)
matplotlib.pyplot.colorbar()

其中 X 是二维数组,数组元素的值就是热力值,cmap色彩风格,可选的值有:plt.cm.hotcoolgraybonewhitespringsummerautumnwinter

非必要的参数如,interpolation设置为 ’nearest’ 可以将相邻的相同的颜色连成片。最后调用 plt.colorbar() 可以把色值柱附在图像旁边。

直方图 hist()

plt.hist() 用于绘制直方图(Histogram),一种特殊的柱状图。,函数接口如下:

1
matplotlib.pyplot.hist(x, bins=None, range=None, **kwargs)

其中 x 为一维数组,可以是浮点数,bins 为直方图的组数,会自动量化原始数据。非必要的参数 range 可输入一个范围元组,默认为 (x.min(), x.max()),如果设置了则会依据 range 来划分直方图的组宽。

子图布局

Matplotlib 有一个概念 subplot:包含在 Figure 对象中的小型 Axes 对象。这允许我们在一幅图中创建很多个子图,方便对比数据。在前面绘制单张图时,可以不声明 Figure 对象作为所有内容的容器,但绘制子图时则必须声明。总共有三种方法。

任意位置 add_axes()

先调用 plt.figure() 创建 Figure 对象,接着调用 fig.add_axes() 在图表的任意位置添加子图,函数接口如下:

1
ax = fig.add_axes(rect, projection=None, polar=False, **kwargs)

其中只有 rect 是必要的,这是一个四个浮点数的列表 [left, bottom, width, height],分别代表子图左下角的坐标,子图的宽度和高度。这四个数字的取值范围都是 0 到 1,代表相对位置和大小

最后在每个子图里用各自的类型图进行绘制即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fig = plt.figure(figsize=(10, 7))

# 创建子图
ax1 = fig.add_axes([0, 0.5, 0.45, 0.45]) # 在图表的左上角创建一个子图
ax2 = fig.add_axes([0.5, 0, 0.45, 0.45]) # 在图表的右下方创建一个子图

# 左上角子图:曲线图
x1 = np.linspace(-10, 10, 100)
ax1.plot(x1, np.sin(x1), color="red")

# 右下角子图:柱状图
x2 = ["a", "b", "c", "d", "e", "f"]
y2 = [1.2, 1.3, 2.5, 0.25, 5, 1.56]
ax2.bar(x2, y2, color="blue")

对齐网格 subplot()

plt.subplot() 用于在一张图里绘制多个子图,最常用,函数接口如下:

1
ax = matplotlib.pyplot.subplot(nrows=1, ncols=1, index, **kwargs)

其中只有 nrowncols 表示总共有多少子图,index 代表其中第几个。如:plt.subplot(2,2,1),也可以缩写为 plt.subplot(221)。紧跟在 plt.subplot() 语句后面的语句绘制的就是 index 所指向的图,绘制完再次使用 plt.subplot() 语句切换到下一张子图。

1
2
3
4
5
6
7
fig = plt.figure(figsize=(10, 7))  # 如果对整体有标题、标签,在此处添加

ax1 = plt.subplot(1,2,1)
ax1.plot() # 如果对子图有标题、标签,在此处添加

ax2 = plt.subplot(1,2,2)
ax2.plot() # 如果对子图有标题、标签,在此处添加

但是,上述方法不适用于大量子图的绘制,如 100 张子图拼接等,因此有另一个相似的方法 plt.subplots()

1
fig, axes = matplotlib.pyplot.subplots(nrows=1, ncols=1, *, sharex=False, sharey=False, **fig_kw)

同时返回一个固定的 Figure 对象(设置的参数一起写即可)和一个 Axes 对象二维列表,通过循环可以遍历:

1
2
3
4
5
6
7
8
9
10
11
12
# 创建一个图形对象,拆分为2*3的网格,包含6个坐标对象
fig, axes = plt.subplots(
nrows=2, # 定义行数
ncols=3, # 定义列数
sharex=True, # 是否共享x轴坐标
sharey=True, # 是否共享y轴坐标
figsize=(10, 7) # 图像大小
)

for i in range(2):
for j in range(3):
axes[i, j].plot()

自由网格 GridSpec

如果想创建不规则的子图,部分子图更大,展示核心信息,有的子图较小,展现辅助信息,plt.GridSpec() 可实现这一点,工作原理是先创建一个网格状的蓝图,然后合并部分子图。函数接口如下:

1
grid = matplotlib.gridspec.GridSpec(nrows, ncols, figure=None, **kwargs)

使用时调用 plt.GridSpec 创建网格状的二维数组 grid,通过切片和索引按需求合并子图,最后调用 ax.plot() 将数据映射到图表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
fig = plt.figure(figsize=(10, 7))

grid = plt.GridSpec(nrows=2, ncols=3, figure=fig) # 创建 2*3 的网格蓝图

# 合并子图
ax1 = plt.subplot(grid[0, 0])
ax2 = plt.subplot(grid[0, 1:]) # 合并(0,1)和(0,2)位置的子图
ax3 = plt.subplot(grid[1, 0:2]) # 合并(1,0)和(1,1)位置的子图
ax4 = plt.subplot(grid[1, 2])

x = np.linspace(0, 10, 30)

ax1.plot(x, np.sin(x), "-r")
ax2.plot(x, np.cos(x), "-ob")
ax3.plot(x, np.sin(x + 10), "-oy")
ax4.plot(x, np.cos(x + 10), "-g")

装饰输出

除了上述常用的绘图函数,Matplotlib 还带有各种绘图组件,用于装饰输出。

全局配置 rcParams

Matplotlib 使用 MRC(Matplotlib Resource Configurations)配置文件来自定义各种属性,我们称之为 rc 配置或者 rc 参数。使用 rcParams 可以控制几乎所有的默认属性:视图窗口、线条、颜色、样式、坐标轴、网格、文本、字体等属性。

通常在 Notebook 中使用,使用方法如下:

1
2
3
4
5
6
from matplotlib import pyplot as plt

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

画板 figure()

图例 legend()

注意 label 不要拼错了

坐标轴

去掉坐标轴的轴线、刻度、标签:plt.axis('off')

去掉坐标轴的刻度、标签:plt.xticks([])plt.yticks([])

去掉坐标轴的刻度,但保留标签:plt.tick_params(left=False, bottom=False)

去掉坐标轴的标签,但保留刻度(适用于画 Attention 热力图):

1
2
3
ax = plt.gca()
ax.axes.xaxis.set_ticklabels([])
ax.axes.yaxis.set_ticklabels([])

文字说明

  • plt.xlabel('X Label')
  • plt.ylabel('Y Label')
  • plt.title('TITLE')

Python笔记 #3 Matplotlib
https://hwcoder.top/Python-Note-3
作者
Wei He
发布于
2021年9月27日
许可协议