基于pytorch的CNN猫狗图分

1.所需模块
2.前提知识
3.CNN简要
4.基本框架
5.代码

.

1.所需的模块

相关的作用在用到的时候单独讲

import numpy as np
import matplotlib.pyplot as plt
import torch
import os
from PIL import Image
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader as DataLoader
import torch.utils.data

2.前提知识

1.训练集:(training data set):以下简称 Trset, Trset 类似于为了让计算机记住某些 特征,而存储的一些带有 标记的数据的一个 集合

2.测试集 : (testing data set ):以下简称 Teset, 用于 检测 一数据是什么 类别集合
3.激励函数的作用 (这里不多讲,csdn上面有很多)
4.简单神经网络构架

3.CNN简要:

CNN(Convolutional Neural Networks), 卷积神经网络,以 卷积的基本操作而命名,简单点的主要分3个部分: 输入层(Input)卷积层(Conv)池化层(Pool), 和 全连层(FC)

基于pytorch的CNN猫狗图分
根据上图
输入层:根据第一层Conv(Conv1),该层数据一共有3层,故Conc1的输入层是1个3通道的图片,事实也是这样,彩图是3通道( RGB)(3个feature map)的,灰图则1个通道(1 个 feature map);并且没每个像素点的范围为[0,255](像素点)。一般图片的数据形式则是 [h w c] ,其中对应的字母分别为 图片的 高,宽,通道数。
卷积层
用于取特征,由卷积核对输入层图像进行卷积操作以提取图像特征。
另外卷积核(下图移动的部分)**:1个卷积核生成1个feature map,即卷积输出的图像通道数与卷积核的个数一致,卷积核的尺寸为(S×S×C×N),其中C表示卷积核深度,必须与输入层图像的通道数一致。
基于pytorch的CNN猫狗图分
浅显易懂吧

池化层
主要用于图像下采样,降低图像分辨率,减少区域内图像的特征数。本文用的池化方法为max pooling,max pooling就是在池化核大小区域内选择最大的数值作为输出结果。
池化的演示:

基于pytorch的CNN猫狗图分

全连层
用于分类的操作,若卷积后的图像尺寸为(h×w×c),需分成n类,则全连层的作用为将[h×w×c]的矩阵转换成[n×1]的矩阵。

; 4.基本框架:

准备数据:将数据集中的数据整理成程序代码可识别读取的形式。
搭建网络:利用PyTorch提供的API搭建设计的网络。
训练网络:把1中准备好的数据送入2中搭建的网络中进行训练,获得网络各节点权值参数(model)。
测试网络:导入3中获取的参数,并输入网络一个数据,然后评估网络的输出结果。
代码实现
代码前言:准备数据之前,先吧同一文件夹下的 data文件准备好,data文件包train 和test文件,其中train 里的文件要命名为 cat(dog).x.jpg ,test里的从0排序就行了。
大概就是这样的:
data文件目录下:

基于pytorch的CNN猫狗图分
基于pytorch的CNN猫狗图分
data->test
基于pytorch的CNN猫狗图分

5.代码

getdata的代码如下:

import os
import torch
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as Trans
img_size = 200

tran = Trans.Compose([Trans.Resize(img_size), Trans.CenterCrop([img_size, img_size]), Trans.ToTensor()])

class DogsVSCatsDataset(data.Dataset):
    def __init__(self, mode, dir):
        self.data_size = 0
        self.img_list=[]
        self.img_label =[]
        self.trans=tran
        self.mode =mode

        if self.mode =='train':
            dir += '/train/'
            for file in os.listdir(dir):
                self.img_list.append(dir+file)
                self.data_size += 1
                name = file.split(sep='.')
                label_x =0
                if name[0] =='cat':
                    label_x =1
                self.img_label.append(label_x)

        elif self.mode == 'test':
            dir +='/test/'
            for file in os.listdir(dir):
                self.img_list.append(dir+file)
                self.data_size +=1
                self.img_label.append(2)
        else:
            print("没有这个mode")

    def __getitem__(self,item):
        if self.mode =='train':
            img =Image.open(self.img_list[item])
            label_y = self.img_label[item]
            return self.trans(img), torch.LongTensor([label_y])
        elif self.mode=='test':
            img =Image.open(self.img_list[item])
            return self.trans(img)
        else:
            print("None")
    def __len__(self):
        return self.data_size

network代码如下

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data as data

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 16,3, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 16, 3,padding=1)

        self.fc1 = torch.nn.Linear(50*50*16, 128)
        self.fc2 = torch.nn.Linear(128, 64)
        self.fc3 = torch.nn.Linear(64,2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x =self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return F.softmax(x, dim=1)

上面的代码中,Conv2中的pading等参数作用,可以看看这篇blog

train 代码如下:

from getdata import DogsVSCatsDataset as DVCD
from torch.utils.data import DataLoader as DataLoader
from network import Net
import torch
from torch.autograd import Variable
import torch.nn as nn

dataset_dir = './data/'

model_dir = './model/'
workers = 10
batch_size = 16
lr = 0.001
nepoch = 1

def train():
    datafile = DVCD('train', dataset_dir)
    dataloader = DataLoader(datafile, batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)

    print('Dataset loaded! length of train set is {0}'.format(len(datafile)))

    model = Net()

    model = nn.DataParallel(model)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    Lossfuc = torch.nn.CrossEntropyLoss()

    cnt = 0
    for epoch in range(nepoch):
        for img, label in dataloader:
            img, label = Variable(img), Variable(label)
            out = model(img)
            loss = Lossfuc(out, label.squeeze())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            cnt += 1

            print('Epoch:{0},Frame:{1}, train_loss {2}'.format(epoch, cnt*batch_size, loss/batch_size))
    torch.save(model.state_dict(), '{0}/model.pth'.format(model_dir))
if __name__ == '__main__':
    train()

test 的代码如下:

from getdata import DogsVSCatsDataset as DVCD
from network import Net
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

data_dir = './data/'
model_file = './model/model.pth'
N = 10

model =Net()
model = nn.DataParallel(model)
model.load_state_dict(torch.load(model_file))
model.eval()

datafile = DVCD('test', data_dir)

index = np.random.randint(0, datafile.data_size, 1)[0]
img = datafile.__getitem__(index)

img = img.unsqueeze(0)
img = Variable(img)

out =model(img)
out = F.softmax(out, dim=1)
print(out.data)
if out[0, 0]>out[0, 1]:
    print("the picture is a cat")
else:
    print("the picture is a dog")
img = Image.open(datafile.img_list[index])
plt.figure('image')
plt.imshow(img)
plt.show()

效果图:

基于pytorch的CNN猫狗图分

mood:

基于pytorch的CNN猫狗图分
没有谁忘不了谁吧

Original: https://blog.csdn.net/qq_57862276/article/details/124067565
Author: dai _ tu
Title: 基于pytorch的CNN猫狗图分

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/640832/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球