【nn.Parameter】Pytorch特征融合自适应权重设置(可学习权重使用)

2021年11月17日11:32:14
今天我们来完成Pytorch自适应可学习权重系数,在进行特征融合时,给不同特征图分配可学习的权重!

原文:基于自适应特征融合与转换的小样本图像分类(2021)
期刊:计算机工程与应用(中文核心、CSCD扩展版)

实现这篇论文里面多特征融合的分支!

实现自适应特征处理模块如下图所示:

【nn.Parameter】Pytorch特征融合自适应权重设置(可学习权重使用)
特征融合公式如下:
F f f = α 1 ∗ F i d + α 2 ∗ F d c o n v + α 3 ∗ F max ⁡ + α 4 ∗ F a v g a i = e w i Σ j e w j ( i = 1 , 2 , 3 , 4 ; j = 1 , 2 , 3 , 4 ) \begin{aligned} &F_{f f}=\alpha_{1} * F_{i d}+\alpha_{2} * F_{dconv}+\alpha_{3} * F_{\max }+\alpha_{4} * F_{a v g} \ &a_{i}=\frac{e^{w_{i}}}{\Sigma_{j} e^{w_{j}}}(i=1,2,3,4 ; j=1,2,3,4)\end{aligned}​F f f ​=α1 ​∗F i d ​+α2 ​∗F d c o n v ​+α3 ​∗F max ​+α4 ​∗F a v g ​a i ​=Σj ​e w j ​e w i ​​(i =1 ,2 ,3 ,4 ;j =1 ,2 ,3 ,4 )​
其中,α i \alpha_i αi ​为归一化权重,Σ α i = 1 \Sigma\alpha_i=1 Σαi ​=1,w i w_i w i ​为初始化权重系数。

结构分析:

  1. 对于一个输入的特征图,有四个分支
  2. 从上往下,第一个分支用的是Maxpooling进行最大池化提取局部特征
  3. 第二个分支用的是Avgpooling进行平均池化提取全局特征
  4. 第三个分支,原文中讲的是”用两组1×1卷积将特征的通道减半压缩,一是为了减少参数量防止过拟合,二是方便后续进行卷积特征拼接进行加性融合;接着在第一组1×1卷积后加入两组3×3卷积来替代5×5卷积后按通道进行拼接(Combine按通道拼接)。”原文将这个分支称作双卷积分支DConv,卷积能提取丰富特征,在拼接后接入一个SE注意力模块
  5. 第三个分支,是残差分支Identity,把输入直接跳跃连接加过去,保留原始特征

模型分析:

  1. 分析下模块结构,既然对于特征融合,最后的操作是Add,那么4个分支输出的特征图大小和维度是相同的!跳跃连接时原图的大小和维度都没有变,所以我们让四个分支的输出和原图大小保持一致
  2. 原文在3.2.2参数设置里面说:最大池化支路池化尺寸设为3,平均池化分支池化尺寸设为2
  3. 初始化各特征的权重全为1,使用nn.Parameter实现
  4. 输入图像的大小为3×84×84

AFP模块Pytorch实现

"""
Author: yida
Time is: 2021/11/17 15:45
this Code:
1.实现中的自适应特征处理模块AFP
2.演示: nn.Parameter的使用
"""
import torch
import torch.nn as nn

class AFP(nn.Module):
    def __init__(self):
        super(AFP, self).__init__()

        self.branch1 = nn.Sequential(
            nn.MaxPool2d(3, 1, padding=1),
        )
        self.branch2 = nn.Sequential(
            nn.AvgPool2d(3, 1, padding=1),
        )

        self.branch3_1 = nn.Sequential(
            nn.Conv2d(3, 1, 1),
            nn.Conv2d(1, 1, 3, padding=1),
            nn.Conv2d(1, 1, 3, padding=1),
        )

        self.branch3_2 = nn.Sequential(
            nn.Conv2d(3, 2, 1),
            nn.Conv2d(2, 2, 3, padding=1)
        )

        self.branch_SE = SEblock(channel=3)

        self.w = nn.Parameter(torch.ones(4))

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)

        b3_1 = self.branch3_1(x)
        b3_2 = self.branch3_2(x)
        b3_Combine = torch.cat((b3_1, b3_2), dim=1)
        b3 = self.branch_SE(b3_Combine)

        b4 = x

        print("b1:", b1.shape)
        print("b2:", b2.shape)
        print("b3:", b3.shape)
        print("b4:", b4.shape)

        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        w3 = torch.exp(self.w[2]) / torch.sum(torch.exp(self.w))
        w4 = torch.exp(self.w[3]) / torch.sum(torch.exp(self.w))

        x_out = b1 * w1 + b2 * w2 + b3 * w3 + b4 * w4
        print("特征融合结果:", x_out.shape)
        return x_out

class SEblock(nn.Module):
    def __init__(self, channel, r=0.5):
        super(SEblock, self).__init__()

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel * r)),
            nn.ReLU(),
            nn.Linear(int(channel * r), channel),
            nn.Sigmoid(),
        )

    def forward(self, x):

        branch = self.global_avg_pool(x)
        branch = branch.view(branch.size(0), -1)

        weight = self.fc(branch)

        h, w = weight.shape
        weight = torch.reshape(weight, (h, w, 1, 1))

        scale = weight * x
        return scale

if __name__ == '__main__':
    model = AFP()
    print(model)
    inputs = torch.randn(10, 3, 84, 84)
    print("输入维度为: ", inputs.shape)
    outputs = model(inputs)
    print("输出维度为: ", outputs.shape)

    for name, p in model.named_parameters():
        if name == 'w':
            print("特征权重: ", name)
            w0 = (torch.exp(p[0]) / torch.sum(torch.exp(p))).item()
            w1 = (torch.exp(p[1]) / torch.sum(torch.exp(p))).item()
            w2 = (torch.exp(p[2]) / torch.sum(torch.exp(p))).item()
            w3 = (torch.exp(p[3]) / torch.sum(torch.exp(p))).item()
            print(w0, w1, w2, w3)

nn.Parameter:上图特征融合中的权重系数w i w_i w i ​

nn.Parameter的使用:可学习权重设置

【nn.Parameter】Pytorch特征融合自适应权重设置(可学习权重使用)

更新记录

  • 2022年04月14日17:37:53

最近看有不少同学关注此博客,所以,我就找一个可以直接运行的手写数字识别代码,把可学习参数放进去,在训练时输出; 为了让大家能够对可学习参数的变化,有更好的理解。

手写数字识别代码

"""
Author: yida
Time is: 2022/3/6 09:30
this Code: 代码原文: https://www.cnblogs.com/wj-1314/p/9842719.html
- 代码: 手写数字识别, 源码参考上面的链接, 仅仅包含两个卷积层的手写数字识别 对每个卷积层设置一个权重系数w
- 可直接运行, torchvision会自动下载手写数字识别的数据集 存放在当前文件夹 ./mnist 模型保存为./model.pth
- 未实现测试功能 大家可以自行添加
- 为了便于大家更好的理解可学习参数
- 直接放到代码里面, 边训练边输出, 方便各位理解
"""
import os

import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))

        self.fc = nn.Linear(7 * 7 * 32, 10)
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, x):

        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        out = self.conv1(x) * w1
        out = self.conv2(out) * w2
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def get_variable(x):
    x = Variable(x)
    return x.cuda() if torch.cuda.is_available() else x

if __name__ == '__main__':

    num_epochs = 5
    batch_size = 100
    learning_rate = 0.001

    train_dataset = normal_datasets.MNIST(
        root='./mnist/',
        train=True,
        transform=transforms.ToTensor(),
        download=True)

    test_dataset = normal_datasets.MNIST(root='./mnist/',
                                         train=False,
                                         transform=transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    model = CNN()
    if torch.cuda.is_available():
        model = model.cuda()

    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = get_variable(images)
            labels = get_variable(labels)
            outputs = model(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.item()))

                for name, p in model.named_parameters():
                    if name == 'w':
                        print("特征权重: ", name)
                        w0 = (torch.exp(p[0]) / torch.sum(torch.exp(p))).item()
                        w1 = (torch.exp(p[1]) / torch.sum(torch.exp(p))).item()
                        print("w0={} w1={}".format(w0, w1))
                        print("")

    print("训练完成...")
    torch.save(model.state_dict(), './model.pth')

【推荐阅读】

Pytorch-GPU安装教程大合集(Perfect完美系列)

Original: https://blog.csdn.net/weixin_43312117/article/details/121374486
Author: 陈嘿萌
Title: 【nn.Parameter】Pytorch特征融合自适应权重设置(可学习权重使用)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/705515/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球