pytorch学习记录02——多分支网络

本文使用pytorch框架来搭建一个多分支的神经网络,编程时借鉴了Inception的编程思想。
准备实现的网络结构如下图所示

pytorch学习记录02——多分支网络

每个分支的输入图片大小设置为64*64,卷积层和池化层的参数设置如下表所示

| 层 | 参数 |
| --- | --- |
| CONV1 | 3, 16, kernel_size=3, stride=1, padding=1 |
| Pooling1 | kernel_size=2, stride=2 |
| CONV2 | 16, 32, kernel_size=3, stride=1, padding=1 |
| Pooling2 | kernel_size=2, stride=2 |
| CONV3 | 32, 64, kernel_size=3, stride=1, padding=1 |
| Pooling3 | kernel_size=2, stride=2 |
| CONV4 | 64, 128, kernel_size=3, stride=1, padding=1 |
| Pooling4 | kernel_size=2, stride=2 |

首先,导入需要的库


import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim

接下来,定义网络模型


class ThreeInputsNet(nn.Module):
    """Three-branch CNN in the Inception spirit.

    Each branch runs its own stack of four (Conv2d -> MaxPool2d) stages on a
    64x64 RGB image (3->16->32->64->128 channels, spatial size 64->4), the
    three 128x4x4 feature maps are concatenated on the channel axis, and the
    result is flattened through three fully-connected layers to 3 outputs.
    """

    def __init__(self):
        super(ThreeInputsNet, self).__init__()

        # Channel progression shared by the three branches.
        channel_pairs = ((16, 32), (32, 64), (64, 128))

        # Register per-branch layers under the same attribute names
        # (convB_S / poolingB_1) and in the same order as the hand-written
        # version, so state_dict keys and seeded init are unchanged.
        for branch in (1, 2, 3):
            setattr(self, f"conv{branch}_1",
                    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1))
            setattr(self, f"pooling{branch}_1",
                    nn.MaxPool2d(kernel_size=2, stride=2))
            for stage, (c_in, c_out) in enumerate(channel_pairs, start=2):
                setattr(self, f"conv{branch}_{stage}",
                        nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))

        # Head: 3 branches x 128 channels x 4x4 spatial -> 3 classes.
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 3)
        self.outlayer2 = nn.Linear(128 * 3, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def forward(self, input1, input2, input3):
        """Run the three inputs through their branches and the shared head."""
        branch_maps = []
        for branch, x in enumerate((input1, input2, input3), start=1):
            # The pooling layer is stateless, so one instance per branch
            # is reused across all four stages.
            pool = getattr(self, f"pooling{branch}_1")
            for stage in range(1, 5):
                x = pool(getattr(self, f"conv{branch}_{stage}")(x))
            branch_maps.append(x)

        merged = torch.cat(branch_maps, dim=1)
        flat = merged.view(merged.size(0), -1)
        return self.outlayer3(self.outlayer2(self.outlayer1(flat)))

输入一些数据测试一下网络能否跑通

if __name__ == '__main__':
    # Smoke test: feed three all-ones 64x64 RGB batches through the network
    # and print the resulting output shape.
    batch_shape = (8, 3, 64, 64)
    dummy_inputs = [torch.ones(*batch_shape) for _ in range(3)]
    model = ThreeInputsNet()
    result = model(*dummy_inputs)
    print("out.shape:{}".format(result.shape))

完整代码


import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim

class ThreeInputsNet(nn.Module):
    """Three-branch CNN in the Inception spirit.

    Each branch runs four (Conv2d -> MaxPool2d) stages on a 64x64 RGB image
    (3->16->32->64->128 channels, spatial size 64->4); the three 128x4x4
    feature maps are concatenated on the channel axis and flattened through
    three fully-connected layers down to 3 outputs.

    NOTE(review): there are no activation functions between layers, so the
    conv/linear stack is affine end to end — presumably intentional for this
    shape-checking tutorial, but worth confirming before training.
    """

    def __init__(self):
        super(ThreeInputsNet, self).__init__()

        # Branch 1
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Branch 2
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Branch 3
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Head: 3 branches x 128 channels x 4x4 spatial -> 3 classes.
        # Fixed: the hidden width was 128 * 5 here, inconsistent with the
        # walk-through version of this class (which uses 128 * 3).
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 3)
        self.outlayer2 = nn.Linear(128 * 3, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def forward(self, input1, input2, input3):
        """Run each input through its branch, concatenate, and classify."""
        # The pooling layer is stateless, so one instance per branch is
        # reused across all four stages.
        out1 = self.pooling1_1(self.conv1_1(input1))
        out1 = self.pooling1_1(self.conv1_2(out1))
        out1 = self.pooling1_1(self.conv1_3(out1))
        out1 = self.pooling1_1(self.conv1_4(out1))

        out2 = self.pooling2_1(self.conv2_1(input2))
        out2 = self.pooling2_1(self.conv2_2(out2))
        out2 = self.pooling2_1(self.conv2_3(out2))
        out2 = self.pooling2_1(self.conv2_4(out2))

        out3 = self.pooling3_1(self.conv3_1(input3))
        out3 = self.pooling3_1(self.conv3_2(out3))
        out3 = self.pooling3_1(self.conv3_3(out3))
        out3 = self.pooling3_1(self.conv3_4(out3))

        # Concatenate branch feature maps along channels, then flatten.
        out = torch.cat((out1, out2, out3), dim=1)
        out = out.view(out.size(0), -1)
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out

if __name__ == '__main__':
    # Smoke test: feed three all-ones 64x64 RGB batches through the network
    # and print the resulting output shape.
    batch_shape = (8, 3, 64, 64)
    dummy_inputs = [torch.ones(*batch_shape) for _ in range(3)]
    model = ThreeInputsNet()
    result = model(*dummy_inputs)
    print("out.shape:{}".format(result.shape))

Original: https://blog.csdn.net/weixin_44136693/article/details/125573084
Author: Lonelysoul丶枫寒
Title: pytorch学习记录02——多分支网络

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/707880/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球