MobileNet Series (3): Building MobileNetV2 in PyTorch and Training It with Transfer Learning

The MobileNetV2 network structure is shown below. For a detailed explanation of the network, see the earlier post: MobileNet Series (2): MobileNet-V2 Network Explained.


Figure 1: MobileNet V2 network architecture

As the architecture table shows, the model is essentially a stack of inverted residual (bottleneck) blocks, followed by a standard 1x1 convolution, a 7x7 average-pooling downsampling layer, and finally a 1x1 convolution that produces the output. The key to building this network is the inverted residual block: once it is implemented, assembling the rest of the network is straightforward.

Building the Network in PyTorch

In model.py we first define the network's basic building blocks.
Almost every convolution in MobileNetV2 is composed of Conv + BN + ReLU6.

Convolution block

Conv+BN+ReLU6

# model.py
import torch
import torch.nn as nn


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):
        # "same" padding for odd kernel sizes
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU6(inplace=True)
        )

Note that groups=1 builds a standard convolution; if groups equals in_channel, the layer becomes a depthwise (DW) convolution. Because a BN layer follows, the convolution needs no bias term, so bias is set to False.
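As a quick illustration (a minimal sketch of my own, not part of the original post), the groups argument is what turns this block into a depthwise convolution, and the depthwise version has far fewer parameters. Continuing in model.py, with ConvBNReLU defined above:

standard = ConvBNReLU(32, 32, kernel_size=3, groups=1)    # ordinary 3x3 convolution
depthwise = ConvBNReLU(32, 32, kernel_size=3, groups=32)  # DW convolution: one filter per channel

x = torch.randn(1, 32, 56, 56)
print(standard(x).shape, depthwise(x).shape)  # both: torch.Size([1, 32, 56, 56])
print(sum(p.numel() for p in standard.parameters()))   # 9280  (9216 conv weights + 64 BN params)
print(sum(p.numel() for p in depthwise.parameters()))  # 352   (288 conv weights + 64 BN params)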

Inverted residual block

We define an InvertedResidual class that inherits from the nn.Module parent class. The structure of the inverted residual block is shown below:

(Figure: inverted residual block structure)
The inverted residual block is similar to an ordinary residual block, except that the shape is reversed: an ordinary residual block is wide at both ends and narrow in the middle, while the inverted residual block is narrow at both ends and wide in the middle (see MobileNet Series (2): MobileNet-V2 Network Explained for details). The number of DW convolution kernels equals the number of input channels, and each DW kernel operates on a single channel, so the DW convolution leaves the channel count unchanged.
class InvertedResidual(nn.Module):
    def __init__(self, in_channel, out_channel, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_channel = in_channel * expand_ratio
        # the shortcut is only used when stride == 1 and the input/output channels match
        self.use_shortcut = stride == 1 and in_channel == out_channel

        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise conv that expands the channels
            layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
        layers.extend([
            # 3x3 depthwise conv (groups == hidden_channel)
            ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel),
            # 1x1 pointwise conv (linear, no activation) that projects the channels back down
            nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)
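A quick sanity check (my own sketch, not from the original post) shows when the shortcut is actually used:

block1 = InvertedResidual(32, 32, stride=1, expand_ratio=6)  # in == out, stride 1 -> shortcut used
block2 = InvertedResidual(32, 64, stride=2, expand_ratio=6)  # channels/stride change -> no shortcut

x = torch.randn(1, 32, 56, 56)
print(block1.use_shortcut, block1(x).shape)  # True  torch.Size([1, 32, 56, 56])
print(block2.use_shortcut, block2(x).shape)  # False torch.Size([1, 64, 28, 28])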

MobileNet V2 network structure

We define a MobileNetV2 class that inherits from nn.Module. The complete network-building code is as follows:

class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        # alpha is the width multiplier; channel counts are rounded to multiples of round_nearest
        input_channel = _make_divisible(32 * alpha, round_nearest)
        last_channel = _make_divisible(1280 * alpha, round_nearest)

        # t: expand ratio, c: output channels, n: number of blocks, s: stride of the first block
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1]
        ]

        features = []
        # conv1: 3x3 convolution with stride 2
        features.append(ConvBNReLU(3, input_channel, stride=2))
        # stack the inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * alpha, round_nearest)
            for i in range(n):
                # only the first block of each stage uses stride s; the rest use stride 1
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # final 1x1 convolution before the classifier
        features.append(ConvBNReLU(input_channel, last_channel, 1))

        self.features = nn.Sequential(*features)

        # classifier head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes)
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
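For a quick check (my own addition, not from the original post), the model can be instantiated and run on a dummy batch, for example at the bottom of model.py:

if __name__ == '__main__':
    model = MobileNetV2(num_classes=5)
    x = torch.randn(2, 3, 224, 224)  # a dummy batch of two 224x224 RGB images
    print(model(x).shape)            # torch.Size([2, 5])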

The _make_divisible function comes from the official TensorFlow implementation:

def _make_divisible(ch, divisor=8, min_ch=None):
    """
    Round the channel count ch to the nearest multiple of divisor (but never below min_ch),
    which keeps channel numbers hardware friendly.
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # make sure rounding down does not reduce the channel count by more than 10%
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch
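For example (my own illustration), with a width multiplier alpha = 0.75 the raw channel counts get rounded to multiples of 8:

print(_make_divisible(32 * 0.75))    # 24  (already a multiple of 8)
print(_make_divisible(96 * 0.75))    # 72
print(_make_divisible(320 * 0.75))   # 240
print(_make_divisible(1280 * 0.75))  # 960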

Model Training

First, let's look at how to download the official pretrained weights, using MobileNet as an example:

import torchvision.models.mobilenet

Ctrl+click torchvision.models.mobilenet to jump into the official source code. There you will find a model_urls dictionary; the URL it contains is the download link for the pretrained weights:

model_urls= {
    'mobilenet_v2':'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
}

Copy the model URL into a download tool such as Thunder (Xunlei), download the file, place it in the current project directory, and name it mobilenet_v2.pth.
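Alternatively (a sketch of my own, not from the original post), the weights can be fetched directly in Python with torch.hub and saved under the same file name used below:

import torch
from torch.hub import load_state_dict_from_url

# download (and cache) the official pretrained weights, then save them as ./mobilenet_v2.pth
state_dict = load_state_dict_from_url(
    'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', progress=True)
torch.save(state_dict, './mobilenet_v2.pth')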

Training script

train.py

1. Import Python packages

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from tqdm import tqdm  # progress bars used in the training loop below
from model import MobileNetV2

# use the GPU if one is available; `device` is referenced throughout the script below
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2. Data preparation

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

data_root = os.path.abspath(os.path.join(os.getcwd(),'../..'))
image_path=data_root +"/data_set/flower_data/"

train_dataset = datasets.ImageFolder(root=image_path + "train", transform=data_transform["train"])
train_num = len(train_dataset)

# invert class_to_idx: {class_name: index} -> {index: class_name}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())

# save the index-to-class mapping for later use at inference time
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
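For reference (my own sketch; it assumes the 5-class flower_photos dataset used throughout this series), the generated class_indices.json can be inspected like this:

with open('class_indices.json', 'r') as f:
    class_indict = json.load(f)
print(class_indict)
# expected output for the flower dataset:
# {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}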

3. Load the model

net = MobileNetV2(num_classes=5)
model_weight_path = "./mobilenet_v2.pth"

assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
pre_weights = torch.load(model_weight_path, map_location=device)

# keep only the pretrained weights whose shapes match (this drops the 1000-class classifier)
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze the feature extractor; only the classifier will be trained
for param in net.features.parameters():
    param.requires_grad = False
net.to(device)
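To confirm that only the classifier will be updated (a small check of my own, not in the original post):

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['classifier.1.weight', 'classifier.1.bias']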

4. Model training


loss_function = nn.CrossEntropyLoss()

# only optimize the parameters that are not frozen (i.e. the classifier)
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

epochs = 5  # number of training epochs (adjust as needed)
best_acc = 0.0
save_path = './MobileNetV2.pth'
train_steps = len(train_loader)

for epoch in range(epochs):

    net.train()
    running_loss=0.0
    train_bar=tqdm(train_loader)
    for step,data in enumerate(train_bar):
        images,labels=data
        optimizer.zero_grad()
        logits=net(images.to(device))
        loss=loss_function(logits,labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss +=loss.item()

        train_bar.desc="train epoch [{} / {}] loss:{:.3f}".format(epoch+1,epochs,loss)

    net.eval()
    acc=0.0
    with torch.no_grad():
        val_bar=tqdm(validate_loader)
        for val_data in val_bar:
            val_images,val_labels=val_data
            outputs = net(val_images.to(device))

            predict_y= torch.max(outputs,dim=1)[1]
            acc += torch.eq(predict_y,val_labels.to(device)).sum().item()

            val_bar.desc ="valid epoch [{}/{}]".format(epoch+1,epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
print('Finished Training')
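After training, the saved weights can be used for prediction. Here is a minimal inference sketch (my own addition; "test.jpg" is a hypothetical image path, and it reuses data_transform, cla_dict, and device from the script above):

from PIL import Image

net = MobileNetV2(num_classes=5).to(device)
net.load_state_dict(torch.load('./MobileNetV2.pth', map_location=device))
net.eval()

img = data_transform["val"](Image.open("test.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    output = net(img.to(device))
    predict = torch.softmax(output, dim=1)
    print(cla_dict[int(torch.argmax(predict))], predict.max().item())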

Source code download

Original: https://blog.csdn.net/weixin_38346042/article/details/125358925
Author: @BangBang
Title: MobileNet Series (3): Building MobileNetV2 in PyTorch and Training It with Transfer Learning
