文章目录
注意力汇聚:Nadaraya-Watson核回归
框架下的注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间交互形成了注意力汇聚,注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。在本节中,我们将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。1964年提出的Nadaraya-Watson核回归模型是⼀个简单但完整的例⼦,可以⽤于演⽰具有注意⼒机制的机器学习
import torch
from torch import nn
from d2l import torch as d2l
1 – 生成数据集
n_train = 50
x_train,_ = torch.sort(torch.rand(n_train) * 5)
def f(x):
return 2 * torch.sin(x) + x**0.8
y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,))
x_test = torch.arange(0,5,0.1)
y_truth = f(x_test)
n_test = len(x_test)
n_test
50
下面的函数将绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数f(标记为”Truth”),以及学习得到的预测函数(标记为”Pred”)
def plot_kernel_reg(y_hat):
d2l.plot(x_test,[y_truth,y_hat],'x','y',legend=['Truth','Pred'],xlim=[0,5],ylim=[-1,5])
d2l.plt.plot(x_train,y_train,'o',alpha=0.5);
2 – 平均汇聚
y_hat = torch.repeat_interleave(y_train.mean(),n_test)
plot_kernel_reg(y_hat)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j6TRUXSQ-1662988499736)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107812.svg)]
3 – 非参数注意力汇聚
X_repeat = x_test.repeat_interleave(n_train).reshape((-1,n_train))
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2,dim=1)
y_hat = torch.matmul(attention_weights,y_train)
plot_kernel_reg(y_hat)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bHHxRgOe-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107813.svg)]
现在,我们来观察注意力的权重,这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知,”查询-键”对越接近,注意力汇聚的注意力权重就越高
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d4R52D1O-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107814.svg)]
4 – 带参数注意力汇聚
; 批量矩阵乘法
X = torch.ones((2,1,4))
Y = torch.ones((2,4,6))
torch.bmm(X,Y).shape
torch.Size([2, 1, 6])
在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值
weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
torch.bmm(weights.unsqueeze(1),values.unsqueeze(-1))
tensor([[[ 4.5000]],
[[14.5000]]])
定义模型
基于带参数的注意力汇聚,使用小批量矩阵乘法,定义Nadaraya-Watson核回归的带参数版本为:
class NWKernelRegression(nn.Module):
def __init__(self,**kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,),requires_grad = True))
def forward(self,queries,keys,values):
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))
self.attention_weights = nn.functional.softmax(
-((queries - keys) * self.w)**2 /2 ,dim=1)
return torch.bmm(self.attention_weights.unsqueeze(1),
values.unsqueeze(-1)).reshape(-1)
训练
接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的”键-值”对进行计算,从而得到其对应的预测输出
X_tile = x_train.repeat((n_train,1))
Y_tile = y_train.repeat((n_train,1))
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.5)
animator = d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[1,5])
for epoch in range(5):
trainer.zero_grad()
l = loss(net(x_train,keys,values),y_train)
l.sum().backward()
trainer.step()
print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
animator.add(epoch + 1, float(l.sum()))
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5nsQsico-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107815.svg)]
如下所示,训练完带参数的注意力汇聚模型后,我们发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑
keys = x_train.repeat((n_test, 1))
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V9IQPxat-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107816.svg)]
为什么新的模型更不平滑了呢?我们看一下输出结果的绘制图:与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rRT7kL7O-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107817.svg)]
5 – 小结
- Nadaraya-Watson核回归时具有注意力机制的机器学习范例
- Nadaraya-Watson核回归的注意⼒汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于你将值所对应的键核查询作为输入的函数
- 注意力汇聚可以分为非参数型核带参数型
Original: https://blog.csdn.net/mynameisgt/article/details/126823006
Author: 未来影子
Title: 注意力机制 – 注意力汇聚:Nadaraya-Watson核回归
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/635210/
转载文章受原作者版权保护。转载请注明原作者出处!