本文使用 PyTorch 框架搭建一个多分支神经网络，编程时借鉴了 Inception 的编程思想。
准备实现的网络结构如下图所示（图略）：三个分支各自处理一张输入图片，每个分支依次经过 4 组"卷积 + 最大池化"，三个分支的输出在通道维拼接后展平，再经过三层全连接层得到输出。
每个分支的输入图片大小设置为 64×64，卷积层和池化层的参数设置如下表所示：
| 层 | 参数 |
|---|---|
| CONV1 | 3, 16, kernel_size=3, stride=1, padding=1 |
| Pooling1 | kernel_size=2, stride=2 |
| CONV2 | 16, 32, kernel_size=3, stride=1, padding=1 |
| Pooling2 | kernel_size=2, stride=2 |
| CONV3 | 32, 64, kernel_size=3, stride=1, padding=1 |
| Pooling3 | kernel_size=2, stride=2 |
| CONV4 | 64, 128, kernel_size=3, stride=1, padding=1 |
| Pooling4 | kernel_size=2, stride=2 |
首先，导入需要的库：
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
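接下来，定义网络模型。每次池化（kernel_size=2, stride=2）都会把特征图的高和宽减半，因此 64×64 的输入经过 4 次池化后变为 128 通道、4×4 大小的特征图；三个分支在通道维拼接并展平后得到 3×128×4×4 维的特征，这正是第一个全连接层的输入维度：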
class ThreeInputsNet(nn.Module):
    """Three-branch CNN that fuses three input images into one prediction.

    Each branch applies four 3x3 convolutions (channels 3 -> 16 -> 32 -> 64
    -> 128, stride 1, padding 1). Each convolution is followed by an optional
    ReLU and a 2x2 max-pool with stride 2. The three branch outputs are
    concatenated along the channel dimension and flattened. The result then
    passes through three fully connected layers.

    Args:
        num_classes: Size of the final output layer (default 3, as in the
            original network).
        input_size: Height/width of each square input image. The four
            2x2/stride-2 poolings shrink it to ``input_size // 16``; the
            default 64 yields the original 4x4 feature maps.
        use_relu: Apply ReLU after every convolution and after the first two
            fully connected layers. Without it, ``outlayer1`` -> ``outlayer2``
            -> ``outlayer3`` collapse into a single affine map. Pass ``False``
            to reproduce the original activation-free forward pass.

    Raises:
        ValueError: If ``input_size`` is smaller than 16, which leaves no
            spatial extent for the fourth pooling stage.
    """

    def __init__(self, num_classes=3, input_size=64, *, use_relu=True):
        super(ThreeInputsNet, self).__init__()
        if input_size < 16:
            raise ValueError(
                "input_size must be at least 16, got {}".format(input_size)
            )
        self.use_relu = use_relu

        # Attribute names match the original so state_dict keys are unchanged.
        # MaxPool2d has no parameters; one pooling module per branch is reused
        # after every convolution in that branch.
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Each pool floors the side length in half; four of them give // 16.
        spatial = input_size // 16
        self.outlayer1 = nn.Linear(3 * 128 * spatial * spatial, 128 * 3)
        self.outlayer2 = nn.Linear(128 * 3, 256)
        self.outlayer3 = nn.Linear(256, num_classes)

    def _act(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ReLU when ``use_relu`` is set, otherwise return ``x`` unchanged."""
        return torch.relu(x) if self.use_relu else x

    def _branch(self, x: torch.Tensor, convs, pool) -> torch.Tensor:
        """Run one branch: conv -> optional ReLU -> pool, for each conv in order."""
        for conv in convs:
            x = pool(self._act(conv(x)))
        return x

    def forward(self, input1: torch.Tensor, input2: torch.Tensor,
                input3: torch.Tensor) -> torch.Tensor:
        """Fuse three image batches and return one output row per sample.

        Args:
            input1, input2, input3: Image batches for the three branches. Each
                is expected to be (N, 3, input_size, input_size). The 3
                channels are fixed by the first convolutions, and the spatial
                size by ``outlayer1``'s input width.

        Returns:
            Tensor of shape (N, num_classes) with raw scores (no softmax).
        """
        out1 = self._branch(
            input1,
            (self.conv1_1, self.conv1_2, self.conv1_3, self.conv1_4),
            self.pooling1_1,
        )
        out2 = self._branch(
            input2,
            (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4),
            self.pooling2_1,
        )
        out3 = self._branch(
            input3,
            (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4),
            self.pooling3_1,
        )
        # Concatenate along channels, then flatten everything except the batch.
        out = torch.cat((out1, out2, out3), dim=1)
        out = torch.flatten(out, 1)
        out = self._act(self.outlayer1(out))
        out = self._act(self.outlayer2(out))
        return self.outlayer3(out)
输入一些数据，测试网络能否跑通（batch 大小为 8，预期输出形状为 torch.Size([8, 3])）：
if __name__ == '__main__':
    # Smoke test: three identical all-ones batches of eight 3-channel 64x64 images.
    batch_shape = (8, 3, 64, 64)
    input1, input2, input3 = [torch.ones(batch_shape) for _ in range(3)]
    net = ThreeInputsNet()
    output = net(input1, input2, input3)
    print("out.shape:{}".format(output.shape))
完整代码如下（注意：这里全连接隐藏层的宽度为 128*5，与上文片段中的 128*3 不同，两种设置都能跑通）：
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
class ThreeInputsNet(nn.Module):
    """Three-branch CNN that fuses three input images into one prediction.

    Each branch applies four 3x3 convolutions (channels 3 -> 16 -> 32 -> 64
    -> 128, stride 1, padding 1). Each convolution is followed by an optional
    ReLU and a 2x2 max-pool with stride 2. The three branch outputs are
    concatenated along the channel dimension and flattened. The result then
    passes through three fully connected layers (hidden widths 128*5 and 256).

    Args:
        num_classes: Size of the final output layer (default 3, as in the
            original network).
        input_size: Height/width of each square input image. The four
            2x2/stride-2 poolings shrink it to ``input_size // 16``; the
            default 64 yields the original 4x4 feature maps.
        use_relu: Apply ReLU after every convolution and after the first two
            fully connected layers. Without it, ``outlayer1`` -> ``outlayer2``
            -> ``outlayer3`` collapse into a single affine map. Pass ``False``
            to reproduce the original activation-free forward pass.

    Raises:
        ValueError: If ``input_size`` is smaller than 16, which leaves no
            spatial extent for the fourth pooling stage.
    """

    def __init__(self, num_classes=3, input_size=64, *, use_relu=True):
        super(ThreeInputsNet, self).__init__()
        if input_size < 16:
            raise ValueError(
                "input_size must be at least 16, got {}".format(input_size)
            )
        self.use_relu = use_relu

        # Attribute names match the original so state_dict keys are unchanged.
        # MaxPool2d has no parameters; one pooling module per branch is reused
        # after every convolution in that branch.
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Each pool floors the side length in half; four of them give // 16.
        spatial = input_size // 16
        self.outlayer1 = nn.Linear(3 * 128 * spatial * spatial, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, num_classes)

    def _act(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ReLU when ``use_relu`` is set, otherwise return ``x`` unchanged."""
        return torch.relu(x) if self.use_relu else x

    def _branch(self, x: torch.Tensor, convs, pool) -> torch.Tensor:
        """Run one branch: conv -> optional ReLU -> pool, for each conv in order."""
        for conv in convs:
            x = pool(self._act(conv(x)))
        return x

    def forward(self, input1: torch.Tensor, input2: torch.Tensor,
                input3: torch.Tensor) -> torch.Tensor:
        """Fuse three image batches and return one output row per sample.

        Args:
            input1, input2, input3: Image batches for the three branches. Each
                is expected to be (N, 3, input_size, input_size). The 3
                channels are fixed by the first convolutions, and the spatial
                size by ``outlayer1``'s input width.

        Returns:
            Tensor of shape (N, num_classes) with raw scores (no softmax).
        """
        out1 = self._branch(
            input1,
            (self.conv1_1, self.conv1_2, self.conv1_3, self.conv1_4),
            self.pooling1_1,
        )
        out2 = self._branch(
            input2,
            (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4),
            self.pooling2_1,
        )
        out3 = self._branch(
            input3,
            (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4),
            self.pooling3_1,
        )
        # Concatenate along channels, then flatten everything except the batch.
        out = torch.cat((out1, out2, out3), dim=1)
        out = torch.flatten(out, 1)
        out = self._act(self.outlayer1(out))
        out = self._act(self.outlayer2(out))
        return self.outlayer3(out)
if __name__ == '__main__':
    # Smoke test: feed the same all-ones batch shape to every branch and
    # report the shape of the fused output.
    shape = (8, 3, 64, 64)
    input1, input2, input3 = (torch.ones(shape) for _ in range(3))
    net = ThreeInputsNet()
    output = net(input1, input2, input3)
    print("out.shape:{}".format(output.shape))
Original: https://blog.csdn.net/weixin_44136693/article/details/125573084
Author: Lonelysoul丶枫寒
Title: pytorch学习记录02——多分支网络
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/707880/
转载文章受原作者版权保护。转载请注明原作者出处!