利用卷积神经网络处理cifar图像分类

这是一个图像分类的比赛CIFAR( CIFAR-10 – Object Recognition in Images )

首先我们需要下载数据文件,地址:

http://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10数据集包含10个类别的60000个32×32彩色图像,每个类别6000个图像。有50000张训练图像和10000张测试图像。

数据集分为五个训练批次和一个测试批次,每个批次具有10000张图像。测试批次包含每个类别中恰好1000张随机选择的图像。训练批次按随机顺序包含其余图像,但是某些训练批次可能包含比另一类更多的图像。在它们之间,培训批次精确地包含每个班级的5000张图像。

这些类是完全互斥的。汽车和卡车之间没有重叠。”汽车”包括轿车,SUV和类似的东西。”卡车”仅包括大型卡车。都不包括皮卡车。

利用卷积神经网络处理cifar图像分类

详细代码:

1.导包

1 import numpy as np
 2
 3 # 序列化和反序列化
 4 import pickle
 5
 6 from sklearn.preprocessing import OneHotEncoder
 7
 8 import warnings
 9 warnings.filterwarnings('ignore')
10
11 import tensorflow as tf

2.数据加载

1 def unpickle(file):
 2      3     with open(file, 'rb') as fo:
 4         dict = pickle.load(fo, encoding='ISO-8859-1')
 5     return dict
 6
 7 # def unpickle(file):
 8 #     import pickle
 9 #     with open(file, 'rb') as fo:
10 #         dict = pickle.load(fo, encoding='bytes')
11 #     return dict
12
13 labels = []
14 X_train = []
15 for i in range(1,6):
16     data = unpickle('./cifar-10-batches-py/data_batch_%d'%(i))
17     labels.append(data['labels'])
18     X_train.append(data['data'])
19
20 # 将list类型转换为ndarray
21 y_train = np.array(labels).reshape(-1)
22 X_train = np.array(X_train)
23
24 # reshape
25 X_train = X_train.reshape(-1,3072)
26
27 # 目标值概率
28 one_hot = OneHotEncoder()
29 y_train =one_hot.fit_transform(y_train.reshape(-1,1)).toarray()
30 display(X_train.shape,y_train.shape)

3.构建神经网络

1 X = tf.placeholder(dtype=tf.float32,shape = [None,3072])
 2 y = tf.placeholder(dtype=tf.float32,shape = [None,10])
 3 kp = tf.placeholder(dtype=tf.float32)
 4
 5 def gen_v(shape):
 6     return tf.Variable(tf.truncated_normal(shape = shape))
 7
 8 def conv(input_,filter_,b):
 9     conv = tf.nn.relu(tf.nn.conv2d(input_,filter_,strides=[1,1,1,1],padding='SAME') + b)
10     return tf.nn.max_pool(conv,[1,3,3,1],[1,2,2,1],'SAME')
11
12 def net_work(input_,kp):
13
14 #     形状改变,4维
15     input_ = tf.reshape(input_,shape = [-1,32,32,3])
16 #     第一层
17     filter1 = gen_v(shape = [3,3,3,64])
18     b1 = gen_v(shape = [64])
19     conv1 = conv(input_,filter1,b1)
20 #     归一化
21     conv1 = tf.layers.batch_normalization(conv1,training=True)
22
23 #     第二层
24     filter2 = gen_v([3,3,64,128])
25     b2 = gen_v(shape = [128])
26     conv2 = conv(conv1,filter2,b2)
27     conv2 = tf.layers.batch_normalization(conv2,training=True)
28
29 #     第三层
30     filter3 = gen_v([3,3,128,256])
31     b3 = gen_v([256])
32     conv3 = conv(conv2,filter3,b3)
33     conv3 = tf.layers.batch_normalization(conv3,training=True)
34
35 #     第一层全连接层
36     dense = tf.reshape(conv3,shape = [-1,4*4*256])
37     fc1_w = gen_v(shape = [4*4*256,1024])
38     fc1_b = gen_v([1024])
39     fc1 = tf.matmul(dense,fc1_w) + fc1_b
40     fc1 = tf.layers.batch_normalization(fc1,training=True)
41     fc1 = tf.nn.relu(fc1)
42 #     fc1.shape = [-1,1024]
43
44
45 #     dropout
46     dp = tf.nn.dropout(fc1,keep_prob=kp)
47
48 #     第二层全连接层
49     fc2_w = gen_v(shape = [1024,1024])
50     fc2_b = gen_v(shape = [1024])
51     fc2 = tf.nn.relu(tf.layers.batch_normalization(tf.matmul(dp,fc2_w) + fc2_b,training=True))
52
53 #     输出层
54     out_w = gen_v(shape = [1024,10])
55     out_b = gen_v(shape = [10])
56     out = tf.matmul(fc2,out_w) + out_b
57     return out

4.损失函数准确率

1 out = net_work(X,kp)
 2
 3 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=out))
 4
 5 # 准确率
 6 y_ = tf.nn.softmax(out)
 7
 8 # equal 相当于 ==
 9 accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,axis = -1),tf.argmax(y_,axis = 1)),tf.float16))
10 accuracy

5.最优化

1 opt = tf.train.AdamOptimizer().minimize(loss)
2 opt

6.开启训练

1 epoches = 50000
 2 saver = tf.train.Saver()
 3
 4 index = 0
 5 def next_batch(X,y):
 6     global index
 7     batch_X = X[index*128:(index+1)*128]
 8     batch_y = y[index*128:(index+1)*128]
 9     index+=1
10     if index == 390:
11         index = 0
12     return batch_X,batch_y
13
14 test = unpickle('./cifar-10-batches-py/test_batch')
15 y_test = test['labels']
16 y_test = np.array(y_test)
17 X_test = test['data']
18 y_test = one_hot.transform(y_test.reshape(-1,1)).toarray()
19 y_test[:10]
20
21 with tf.Session() as sess:
22     sess.run(tf.global_variables_initializer())
23     for i in range(epoches):
24         batch_X,batch_y = next_batch(X_train,y_train)
25         opt_,loss_ = sess.run([opt,loss],feed_dict = {X:batch_X,y:batch_y,kp:0.5})
26         print('----------------------------',loss_)
27         if i % 100 == 0:
28             score_test = sess.run(accuracy,feed_dict = {X:X_test,y:y_test,kp:1.0})
29             score_train = sess.run(accuracy,feed_dict = {X:batch_X,y:batch_y,kp:1.0})
30             print('iter count:%d。mini_batch loss:%0.4f。训练数据上的准确率:%0.4f。测试数据上准确率:%0.4f'%
31                   (i+1,loss_,score_train,score_test))

这个准确率只达到了百分之80

如果想提高准确率,还需要进一步优化,调参

Original: https://www.cnblogs.com/xiuercui/p/12047336.html
Author: 程序界第一佳丽
Title: 利用卷积神经网络处理cifar图像分类

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/577565/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • Redis深入浅出

    bash;gutter:true; RDB持久化是指在指定的时间间隔内将内存中的数据集快照写入磁盘,实际操作过程是fork一个子进程,先将数据集写入临时文件,写入成功后,再替换之前…

    Linux 2023年5月28日
    0121
  • Linux快速入门(七)效率工具(Vim)

    Vim编辑器 所有的 Linux系统都会内建一个 Vi文本编辑器,而 Vim是从 Vi发展出来的一个高度可配置的文本编辑器,旨在高效的创建和更改任何类型的文本,它还可以根据文件的扩…

    Linux 2023年6月6日
    0109
  • linux free命令available小于free值

    问题:前段时间在做服务器巡检时发现系统可用内存值小于空闲内存值 分析:查询网上各种资料,都说的是 available=free + buff/cache 这样一个大致计算方式,按这…

    Linux 2023年6月14日
    0185
  • LM算法探讨(附python代码)

    1. 案例分析 考虑如下公式: [\gamma_i=\frac{2\pi}{\lambda}\times 2 \sqrt{(x_i-x_p)^2+(y_i-y_p)^2+(z_i-…

    Linux 2023年6月14日
    0168
  • Shell 脚本是什么?

    一个 Shell 脚本是一个文本文件,包含一个或多个命令。作为系统管理员,我们经常需要使用多个命令来完成一项任务,我们可以添加这些所有命令在一个文本文件(Shell 脚本)来完成这…

    Linux 2023年5月28日
    0113
  • 【Python】AttributeError: ‘Rotation’ object has no attribute ‘from_dcm’

    报错的代码如下: from scipy.spatial.transform import Rotation def dcm2euler(mats: np.ndarray, seq:…

    Linux 2023年6月13日
    082
  • MySQL的约束

    主键约束 能够唯一确定一张表中的一条记录,通过给某个字段添加约束,就可以使得该字段不重复且不为空 create table user( id int primary key, na…

    Linux 2023年6月7日
    097
  • Kibana 7.15.x [error][savedobjects-service] [.kibana] Action failed with ‘Request timed out’. Retrying attempt 报错处理。

    1、报错 近日在windows平台使用7.15.2 的elasticsearch 和kibana 时候,在开启es cmd窗口后,kibana无法启动,报错误下。 log [09:…

    Linux 2023年6月6日
    0135
  • MANIFEST.MF文件对Import-Package/Export-Package重排列

    众所周知,MANIFEST.MF文件中的空格开头的行是相当于拼接在上一行末尾的。很多又长又乱的Import-Package或者Export-Package,有时候想要搜索某个pac…

    Linux 2023年6月13日
    0122
  • protobuf 的交叉编译使用(C++)

    为了提高通信效率,可以采用 protobuf 替代 XML 和 Json 数据交互格式,protobuf 相对来说数据量小,在进程间通信或者设备之间通信能够提高通信速率。下面介绍 …

    Linux 2023年6月7日
    0158
  • 大数据之Hadoop集群的HDFS压力测试

    测试HDFS写性能 原文:sw-code1)写测试的原理 2)测试内容:向HDFS集群写10个128MB的文件(3个机器每个4核,2 * 4 = 8 < 10 < 3 …

    Linux 2023年6月8日
    0106
  • Redis时延问题分析及应对

    Redis时延问题分析及应对 Redis的事件循环在一个线程中处理,作为一个单线程程序,重要的是要保证事件处理的时延短,这样,事件循环中的后续任务才不会阻塞;当redis的数据量达…

    Linux 2023年5月28日
    0101
  • vscode配置指南,美化技巧

    "workbench.colorCustomizations": { "editor.selectionBackground": &quot…

    Linux 2023年6月14日
    0101
  • 《Redis开发与运维》——(七)Redis阻塞(脑图)

    posted @2021-01-09 15:06 雪山上的蒲公英 阅读(90 ) 评论() 编辑 / 返回顶部代码 / Original: https://www.cnblogs….

    Linux 2023年5月28日
    0121
  • 如何配置VLAN

    一、vlan的概念与作用 首先,在学习如何配置vlan时我们先要了解一下为什么要配置vlan?vlan在平常的工作中有什么作用? vlan:虚拟的划分网段 即虚拟网络,在平常的工作…

    Linux 2023年6月6日
    0175
  • Redis (error) NOAUTH Authentication required.

    首先查看redis设置密码没 表示没有设置密码,设置redis密码 这个时候查看密码是会报错的。 需要noauth身份验证。 修改密码 Original: https://www….

    Linux 2023年5月28日
    0114
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球