数学中(教科书、大学课堂、数学相关的科普视频),一个矩阵的向量往往是竖着的, 一列作为一个vector,这一点numpy库也是这样默认的。
但是在机器学习以torch框架为例,一个有意义的向量或者说embedding 是横着的。
因为numpy库默认是一列是一个向量而torch等机器学习框架默认一行是一个向量,所以
torch.cov(X)
和 numpy.cov(X.T)
是相等的。
自行实现
torch在较高版本中才有 torch.cov
函数,低版本的需要自行实现。
因为大部分博客都是数学风格的,在减掉均值后,大部分写X X T XX^T X X T算协方差矩阵,这是默认以列为一个vector,一定要注意。
因为torch的一个向量是一个横行,所以自行实现其实是X T X X^TX X T X
def torch_cov(input_vec:torch.tensor):
x = input_vec- torch.mean(input_vec,axis=0)
cov_matrix = torch.matmul(x.T, x) / (x.shape[0]-1)
return cov_matrix
这样子可以和numpy的cov比较一下:
vecs=torch.tensor([[1,2,3,4],[2,2,3,4]]).float()
vecs_np=vecs.numpy()
cov = np.cov(vecs_np.T)
array([[0.5, 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ]])
torch_cov(vecs)
tensor([[0.5000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000]])
二者是一样的。
直面矩阵的数学解释
对于矩阵M M M来说, 1行为一个高维变量x i x_i x i 应当表示成
[ x 1 x 2 x 3 ] \left[ \begin{matrix} x_1\ x_2\ x_3\ \end{matrix} \right]⎣⎡x 1 x 2 x 3 ⎦⎤
计算均值μ \mu μ,应当是对x i x_i x i 求μ \mu μ,μ = 1 N ∑ N x i \mu=\frac1N\sum_Nx_i μ=N 1 N ∑x i 所以μ \mu μ也是一个高维(与x同维度)的向量。
M − μ M-\mu M −μ变换应当表示成
X = [ x 1 − μ x 2 − μ x 3 − μ ] = [ x 1 ′ x 2 ′ x 3 ′ ] X=\left[ \begin{matrix} x_1-\mu\ x_2-\mu\ x_3-\mu\ \end{matrix} \right]=\left[ \begin{matrix} x_1’\ x_2’\ x_3’\ \end{matrix} \right]X =⎣⎡x 1 −μx 2 −μx 3 −μ⎦⎤=⎣⎡x 1 ′x 2 ′x 3 ′⎦⎤
我们把变换后的M M M写做X X X,变换后的x i x_i x i 写作x i ′ x’_i x i ′。
协方差矩阵Σ \Sigma Σ的意义是各个维度之间相互的方差,则应当是
1 3 X T X = 1 3 [ x 1 ′ , x 2 ′ , x 3 ′ ] [ x 1 ′ x 2 ′ x 3 ′ ] = Σ \frac13X^TX=\frac13\left[ \begin{matrix} x_1′, x_2′, x_3’\ \end{matrix} \right]\left[ \begin{matrix} x_1’\ x_2’\ x_3’\ \end{matrix} \right]=\Sigma 3 1 X T X =3 1 [x 1 ′,x 2 ′,x 3 ′]⎣⎡x 1 ′x 2 ′x 3 ′⎦⎤=Σ
直观解释是这个乘法Σ \Sigma Σ最左上角的元素,恰好是x i ′ x’_i x i ′第1维对第1维的自我方差,此时可以确认是正确意义的协方差矩阵。
当然,算完之后还要乘变量个1 3 \frac13 3 1 或者1 3 − 1 \frac1{3-1}3 −1 1 。
Original: https://blog.csdn.net/Yonggie/article/details/124757929
Author: Yonggie
Title: 协方差矩阵在torch和numpy中的比较,自行实现torch协方差矩阵
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709380/
转载文章受原作者版权保护。转载请注明原作者出处!