import numpy as np
深入理解 axis
axis
的英文含义为轴,它在 NumPy 中是许多函数的一个重要参数。这些函数往往都支持复杂的多维数组汇总运算,axis
的值决定了程序员可以对高维矩阵的任何一维做汇总运算。
这个参数取名为轴是相当形象的。下面这个简单的例子就能很好地帮助我们理解这个参数为什么叫做 axis
。
Z = np.arange(12).reshape(4, 3)
Z
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
我们设置 axis=0
,发现每一列的平均值被求解出来排成了一行:
np.mean(Z, axis=0)
array([4.5, 5.5, 6.5])
我们设置 axis=1
,发现每一行的平均值被求解出来排成了一行:
np.mean(Z, axis=1)
array([ 1., 4., 7., 10.])
上面的程序段首先定义了一个大小为 16 的一维数组并将之打包成 $4 \times 4$ 矩阵,参数 axis
先后分别设置为 0 和 1。观察输出结果可知:第一次是求出了每列的平均值,第二次是求出了每行的平均值。
由此,我们可以得出一个初步的结论:axis=1
是对矩阵的行汇总,axis=0
是对矩阵的列汇总。这个过程和 Excel 中的一些公式和函数的功能很相近,特别是和其中的分类汇总功能很相近。axis
确实就是一个轴,它在矩阵中平移(在垂直轴方向上分类),所到之处沿平行轴方向汇总。
例如当 axis=1
时,说明要在 index=1
的列那一维上做汇总操作。此时的轴一定是水平的,因为在平行轴方向 axis
所表示的那一维是变化的,是要对整个域汇总的。相应地,会在另一维(行)上遍历。所得数组一定是降低了一维变成了一维数组。
按照这个运算法则,可以将之推广到多维数组。这个时候的数组不能简单表示成一维向量或者二维矩阵的形式了,只能线性地用括号匹配原则表示。但是 axis
的含义没有本质区别,可以总结为一句话:设 axis=i
,则 NumPy 沿着第 i
个下标变化的方向进行汇总操作。下面是一个三维的例子:
Z = np.arange(24).reshape(2, 3, 4)
Z
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
我们设置 axis=0
,发现大小为 2 的那一维被汇总了,只剩下 $3 \times 4$ 矩阵。
np.mean(Z, axis=0)
array([[ 6., 7., 8., 9.],
[10., 11., 12., 13.],
[14., 15., 16., 17.]])
我们设置 axis=1
,发现大小为 3 的那一维被汇总了,只剩下 $2 \times 4$ 矩阵。
np.mean(Z, axis=1)
array([[ 4., 5., 6., 7.],
[16., 17., 18., 19.]])
我们设置 axis=2
,发现大小为 4 的那一维被汇总了,只剩下 $2 \times 3$ 矩阵。
np.mean(Z, axis=2)
array([[ 1.5, 5.5, 9.5],
[13.5, 17.5, 21.5]])
NumPy 统计函数
NumPy 统计函数经常涉及到 axis 参数的使用。
A = np.arange(12).reshape(3, 4)
A
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
求和
np.sum(A, axis=0)
array([12, 15, 18, 21])
np.sum(A, axis=1)
array([ 6, 22, 38])
平均数
np.mean(A, axis=0)
array([4., 5., 6., 7.])
np.mean(A, axis=1)
array([1.5, 5.5, 9.5])
np.average
专门用于求解加权平均数,w
为权重矩阵。
w = np.random.randint(1, 10, size=(3, 4))
w
array([[2, 7, 3, 5],
[7, 4, 5, 3],
[2, 4, 2, 7]])
np.average(A, axis=0, weights=w)
array([4. , 4.2 , 5.6 , 7.53333333])
np.average(A, axis=1, weights=w)
array([1.64705882, 5.21052632, 9.93333333])
方差与标准差
方差
np.var(A, axis=0)
array([10.66666667, 10.66666667, 10.66666667, 10.66666667])
np.var(A, axis=1)
array([1.25, 1.25, 1.25])
标准差
np.std(A, axis=0)
array([3.26598632, 3.26598632, 3.26598632, 3.26598632])
np.std(A, axis=1)
array([1.11803399, 1.11803399, 1.11803399])
不支持 axis 参数的统计函数
最大值(最小值)
np.max(A)
11
np.min(A)
0
最大值(最小值)索引
np.argmax(A)
11
np.argmin(A)
0
极差,即最大值与最小值之差
np.ptp(A)
11
中位数
np.median(A)
5.5
axis 参数的应用
拼接两个矩阵
x = np.arange(9).reshape(3, 3)
y = np.arange(-3, 6).reshape(3, 3)
x, y
(array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]),
array([[-3, -2, -1],
[ 0, 1, 2],
[ 3, 4, 5]]))
np.concatenate((x, y), axis=0)
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[-3, -2, -1],
[ 0, 1, 2],
[ 3, 4, 5]])
np.concatenate((x, y), axis=1)
array([[ 0, 1, 2, -3, -2, -1],
[ 3, 4, 5, 0, 1, 2],
[ 6, 7, 8, 3, 4, 5]])