In the previous article, we introduced the core ideas and workflow of transfer learning and walked through an example to deepen our understanding.
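Previous article: Overview of Transfer Learning
Obtaining Pretrained Models
Both PyTorch and TensorFlow ship with many pretrained models.
In PyTorch, they are available through the torchvision.models module, which includes AlexNet, the VGG family, the ResNet family, SqueezeNet, DenseNet and others. Passing pretrained=True to a model constructor downloads the pretrained weights. In TensorFlow, they are built into tf.keras.applications, and you can also download models yourself from TensorFlow Hub.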
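```python
# TensorFlow / Keras: modules containing pretrained architectures
from tensorflow.keras.applications import vgg16, resnet
# PyTorch / torchvision: model classes (pretrained weights come from
# constructor functions such as models.resnet18)
from torchvision.models import AlexNet, VGG, ResNet
from torchvision.models import SqueezeNet, DenseNet
```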
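An Example
The following example gives an intuitive feel for transfer learning. The pretrained model is ResNet18, and the process is divided into eight steps.
Note: all code comes from Chapter 3 of 《深入浅出Embedding》.
1. Import modules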
```python
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime
```
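2. Load the data
Load the CIFAR-10 dataset. On the first run, download must be set to True so the data is fetched. The images are also preprocessed: the training set gets random resized crops and horizontal flips for augmentation, the validation set is resized and center-cropped to 224×224, and both are normalized.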
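```python
# Training transforms: random crop and flip for augmentation, then ImageNet normalization
trans_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation transforms: deterministic resize and center crop, same normalization
trans_valid = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```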
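To make the last two transforms concrete: `ToTensor()` scales 8-bit pixel values from [0, 255] to [0.0, 1.0], and `Normalize` then computes `(x - mean) / std` separately for each channel, using the ImageNet statistics the pretrained weights were trained with. The plain-Python sketch below applies the same arithmetic to a single RGB pixel; the helper name `normalize_pixel` is hypothetical and not part of torchvision.

```python
# ImageNet channel statistics used in the transforms above (R, G, B)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor() + Normalize() for one pixel given as three 0-255 values."""
    # Step 1 (ToTensor): scale each channel from [0, 255] to [0.0, 1.0]
    scaled = [v / 255.0 for v in rgb_uint8]
    # Step 2 (Normalize): subtract the channel mean, divide by the channel std
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


if __name__ == "__main__":
    print(normalize_pixel((255, 255, 255)))  # a white pixel
    print(normalize_pixel((0, 0, 0)))        # a black pixel
```

A white pixel maps to roughly (2.25, 2.43, 2.64) and a black pixel to roughly (-2.12, -2.04, -1.80), so inputs end up centered near zero, which is the distribution the pretrained ResNet18 expects.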
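```python
# CIFAR-10 is stored under ./data and downloaded on the first run.
# On Windows, num_workers > 0 requires running this under `if __name__ == '__main__':`
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=trans_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=trans_valid)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
# Class names in CIFAR-10 label order (labels 0-9)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```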
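CIFAR-10 has 50,000 training images and 10,000 test images. With `batch_size=64`, a DataLoader with the default `drop_last=False` yields ceil(N / 64) batches per epoch, the last one being smaller; this is the `len(trainloader)` / `len(testloader)` that the training function later divides by. A quick check in plain Python (the helper `num_batches` is a hypothetical name, not a PyTorch function):

```python
import math

CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch for a sized dataset."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)


if __name__ == "__main__":
    print(num_batches(50_000, 64))  # training batches per epoch
    print(num_batches(10_000, 64))  # test batches per epoch
    # size of the final, partial training batch
    print(50_000 - 64 * (num_batches(50_000, 64) - 1))
    # CIFAR-10 labels are integers 0-9 that index into the class tuple
    print(CLASSES[3])
```

Because the last batch holds only 16 images, averaging per-batch accuracy over `len(loader)` weights it slightly more than the others; the effect on the reported numbers is small.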
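Note: downloading the dataset from within the code can be slow. Alternatively, download the dataset archive manually, place it under the root path, and rerun the code with download set to False.
3. Download the pretrained model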
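```python
# Download ImageNet-pretrained ResNet18 weights.
# In torchvision 0.13+ `pretrained=True` is deprecated; the equivalent is
# models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
net = models.resnet18(pretrained=True)
```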
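This step also takes a while, so be patient. If it fails, first download the .pth weight file manually, then load the model with the following statements: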
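```python
# Load a manually downloaded weight file (a state_dict) into an untrained ResNet18
pthfile = r'/workspace/resnet18-f37072fd.pth'
state_dict = torch.load(pthfile, map_location='cpu')
net = models.resnet18(pretrained=False)  # weights=None in torchvision 0.13+
net.load_state_dict(state_dict)
```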
4. Freeze the model parameters
Freeze all pretrained parameters so that they are not updated during training:
```python
# Disable gradients for every pretrained parameter
for param in net.parameters():
    param.requires_grad = False
```
5. Replace the output classifier
Replace the original 1000-class output layer with one that outputs only 10 classes:
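```python
# Use a GPU if one is available (change the index to select a different card)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ResNet18's final feature vector has 512 dimensions (net.fc.in_features).
# The new layer is created after freezing, so its parameters are trainable.
net.fc = nn.Linear(512, 10)
```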
6. Compare total and trainable parameter counts after freezing
```python
# Count all parameters, then only those that will receive gradients
total_params = sum(p.numel() for p in net.parameters())
print('Total parameters: {}'.format(total_params))
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Trainable parameters: {}'.format(total_trainable_params))
```
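A fully connected layer `nn.Linear(in_features, out_features)` holds `in_features * out_features` weights plus `out_features` biases. Assuming torchvision's ResNet18 has 11,689,512 parameters in its original 1000-class form (the figure commonly reported for that model; the counting code above will confirm it on your installed version), the expected output of step 6 can be worked out by hand. The helper `linear_params` is a hypothetical name used only for this illustration.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


# Assumption: total parameters of torchvision's resnet18 with its 1000-class head
RESNET18_TOTAL = 11_689_512

old_fc = linear_params(512, 1000)   # original ImageNet head
new_fc = linear_params(512, 10)     # replacement CIFAR-10 head
backbone = RESNET18_TOTAL - old_fc  # frozen layers, unchanged by the swap

if __name__ == "__main__":
    print("fc params (1000 classes):", old_fc)
    print("fc params (10 classes):", new_fc)
    print("total after the swap:", backbone + new_fc)
    # only the new fc layer still has requires_grad=True
    print("trainable after freezing:", new_fc)
```

Under that assumption, training touches only 5,130 of roughly 11.2 million parameters, which is why this kind of fine-tuning is cheap.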
7. Define the loss function and optimizer
```python
criterion = nn.CrossEntropyLoss()
# Only the new classifier's parameters are handed to the optimizer
optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3,
                            weight_decay=1e-3, momentum=0.9)
```
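Because only `net.fc.parameters()` is handed to the optimizer, only the new layer is ever updated. For each such parameter, `torch.optim.SGD` with its default `dampening=0` and `nesterov=False` first adds `weight_decay * p` to the gradient, then accumulates the result into a momentum buffer, and finally steps against that buffer. Below is a scalar sketch of this rule in plain Python; the function name is hypothetical.

```python
def sgd_momentum_step(p, grad, buf, lr=1e-3, weight_decay=1e-3, momentum=0.9):
    """One SGD step for a scalar parameter; returns (new_p, new_buf).

    buf is None before the first step, matching how PyTorch lazily creates
    the momentum buffer.
    """
    d_p = grad + weight_decay * p   # L2 weight decay folded into the gradient
    if buf is None:
        buf = d_p                   # first step: buffer starts as the gradient
    else:
        buf = momentum * buf + d_p  # exponential accumulation of past steps
    return p - lr * buf, buf


if __name__ == "__main__":
    p, buf = 1.0, None
    for step in range(3):
        p, buf = sgd_momentum_step(p, grad=0.5, buf=buf)
        print(step, p, buf)
```

With a constant gradient, the buffer grows toward grad / (1 - momentum), so momentum=0.9 lets the effective step size climb toward about ten times lr.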
We also need an evaluation metric and a training function:
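```python
# Evaluation metric: fraction of samples whose highest-scoring class matches the label
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)  # index of the largest logit in each row
    num_correct = (pred_label == label).sum().item()
    return num_correct / total
```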
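`nn.CrossEntropyLoss` applies log-softmax to the raw logits and, with the default `reduction='mean'`, averages the negative log-probability of the correct class over the batch, while `get_acc` counts how often the arg-max logit equals the label. Both can be reproduced by hand for a toy batch of two samples and three classes. This pure-Python sketch mirrors those definitions; the helper names are hypothetical.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean of -log softmax(logits)[label] over the batch."""
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        m = max(logits)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


def accuracy(logits_batch, labels):
    """Fraction of rows whose highest logit sits at the label index."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits_batch]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


if __name__ == "__main__":
    logits = [[2.0, 1.0, 0.0],   # predicts class 0
              [0.5, 0.2, 1.5]]   # predicts class 2
    labels = [0, 1]
    print(cross_entropy(logits, labels))
    print(accuracy(logits, labels))
```

Here the first sample is classified correctly and the second is not, so the accuracy is 0.5, while the loss still distinguishes how confident each prediction was.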
```python
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        # Train mode; note that frozen BatchNorm layers still update their running stats
        net = net.train()
        for im, label in train_data:
            im = im.to(device)        # (bs, 3, h, w)
            label = label.to(device)  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            with torch.no_grad():  # gradients are not needed for validation
                for im, label in valid_data:
                    im = im.to(device)        # (bs, 3, h, w)
                    label = label.to(device)  # (bs,)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)
```
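The timing lines in each epoch convert the elapsed `timedelta` into hours, minutes and seconds with two `divmod` calls. Isolated as a small helper (the name `format_elapsed` is hypothetical), the same arithmetic looks like this. Note that `timedelta.seconds` excludes whole days, so, as in the training loop, an interval longer than 24 hours would wrap around.

```python
from datetime import timedelta


def format_elapsed(delta):
    """Format a timedelta the same way train() builds time_str."""
    h, remainder = divmod(delta.seconds, 3600)  # whole hours, leftover seconds
    m, s = divmod(remainder, 60)                # whole minutes, leftover seconds
    return "Time %02d:%02d:%02d" % (h, m, s)


if __name__ == "__main__":
    print(format_elapsed(timedelta(seconds=3725)))
    print(format_elapsed(timedelta(minutes=4, seconds=9)))
```

Using `int(delta.total_seconds())` instead of `delta.seconds` would avoid the wrap-around if epochs ever run that long.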
8. Train and validate the model
Finally, move the model to the device and start training:
```python
net = net.to(device)
train(net, trainloader, testloader, 20, optimizer, criterion)
```
References:
《深入浅出Embedding》
https://www.ptorch.com/docs/1/models
Original: https://blog.csdn.net/qq_27388259/article/details/120540776
Author: 整得咔咔响
Title: 迁移学习实例 (Transfer Learning Example)