# 构建逻辑回归模型识别MNIST手写字——单个神经元

## 1、导入库

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
print("Tensorflow&#x7248;&#x672C;&#x662F;:",tf.__version__)


## 2、数据获取

MNIST 数据集可在http://yann.lecun.com/exdb/mnist/获取

TensorFlow提供了数据集读取方法(1.x和2.0版本提供的方法不同)

mnist = tf.keras.datasets.mnist


MNIST数据集文件在读取时如果指定目录下不存在，则会自动去下载，需等待 一定时间；如果已经存在了，则直接读取

## 3、数据集划分

total_num = len(train_images)
valid_split = 0.2
train_num = int(total_num*(1-valid_split))

train_x = train_images[:train_num]
train_y = train_labels[:train_num]

valid_x = train_images[train_num:]
valid_y = train_labels[train_num:]

test_x = test_images
test_y = test_labels

valid_x.shape


## 4、数据塑形

train_x = train_x.reshape(-1,784)
valid_x = valid_x.reshape(-1,784)
test_x = test_x.reshape(-1,784)


## 5、特征数据归一化

train_x = tf.cast(train_x/255.0,tf.float32)
valid_x = tf.cast(valid_x/255.0,tf.float32)
test_x = tf.cast(test_x/255.0,tf.float32)

train_x[1]


## 6、标签数据独热编码

train_y = tf.one_hot(train_y,depth=10)
valid_y = tf.one_hot(valid_y,depth=10)
test_y = tf.one_hot(test_y,depth=10)

train_y


## 7、构建模型

def model(x,w,b):
pred = tf.matmul(x,w)+b
return tf.nn.softmax(pred)


## 8、定义模型变量

W=tf.Variable(tf.random.normal([784,10],mean=0.0,stddev=1.0,dtype=tf.float32))

B=tf.Variable(tf.zeros([10]),dtype=tf.float32)


## 9、定义交叉熵损失函数

def loss(x,y,w,b):
pred = model(x,w,b)
loss_ = tf.keras.losses.categorical_crossentropy(y_true=y,y_pred = pred)
return tf.reduce_mean(loss_)


## 10、定义训练参数

training_epochs=20
batch_size=50
learning_rate=0.001


## 11、定义梯度计算函数

def grad(x,y,w,b):
loss_=loss(x,y,w,b)


## 12、选择优化器

optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)


## 13、定义准确率

def accuracy(x,y,w,b):
pred=model(x,w,b)
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
return tf.reduce_mean(tf.cast(correct_prediction,tf.float32))


## 14、训练模型

total_step = int(train_num/batch_size)

loss_list_train = []
loss_list_valid = []
acc_list_train = []
acc_list_valid = []

for epoch in range (training_epochs):
for step in range(total_step):
xs = train_x[step*batch_size:(step+1)*batch_size]
ys = train_y[step*batch_size:(step+1)*batch_size]

loss_train = loss(train_x,train_y,W,B).numpy()
loss_valid = loss(valid_x,valid_y,W,B).numpy()
acc_train = accuracy(train_x,train_y,W,B).numpy()
acc_valid = accuracy(valid_x,valid_y,W,B).numpy()
loss_list_train.append(loss_train)
loss_list_valid.append(loss_valid)
acc_list_train.append(acc_train)
acc_list_valid.append(acc_valid)
print("epoch={:3d},train_loss={:.4f},train_acc={:.4f},val_loss={:.4f},val_lacc={:.4f}".format(epoch+1,loss_train,acc_train,loss_valid,acc_valid))


## 15、显示训练过程数据

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(loss_list_train,'blue',label="Train Loss")
plt.plot(loss_list_valid,'red',label='Valid Loss')
plt.legend(loc=1)

plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(acc_list_train,'blue',label="Train Loss")
plt.plot(acc_list_valid,'red',label='Valid Loss')
plt.legend(loc=1)


## 16、评估模型

acc_test = accuracy(test_x,test_y,W,B).numpy()
print("Test accuracy:",acc_test)


## 17、模型应用与可视化

1. 应用模型
def predict(x,w,b):
pred=model(x,w,b)
result=tf.argmax(pred,1).numpy()
return result

pred_test=predict(test_x,W,B)

pred_test[0]


2. 定义可视化函数
import matplotlib.pyplot as plt
import numpy as np
def plot_images_label_prediction(images,
labels,
preds,
index=0,
num=10
):
fig = plt.gcf()
fig.set_size_inches(10,4)
if num > 10:
num = 10
for i in range(0,num):
ax = plt.subplot(2,5,i+1)

ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')

title = "label=" + str(labels[index])
if len(preds)>0:
title +=",predict=" + str(labels[index])

ax.set_title(title,fontsize=10)
ax.set_xticks([]);
ax.set_yticks([])
index = index + 1

plt.show()

1. 可视化预测结果
plot_images_label_prediction(test_images,test_labels,pred_test,10,10)


Original: https://blog.csdn.net/m0_59324564/article/details/124474111
Author: 小洁酱
Title: 构建逻辑回归模型识别MNIST手写字——单个神经元

(0)

### 大家都在看

• #### tensorflow-gpu安装过程中出现的tf.test.is_gpu_avaiable()返回false的一部分解决方法

说起安装tensorflow-gpu的时候出现的一些坑就有点郁闷写个博客记录一下这一些坑，也算给后人一点解决方法 Question Ⅰ 第一种出现在 import tensorfl…

人工智能 2023年5月24日
0220
• #### 关于修Bug的一些想法

0. 前言 八月份快要结束了，这个月也没有啥输出，今天下班较早，赶一篇学了一年多C++后的一些思考，关于修Bug的一些想法和思路。平时工作中，如果写代码花费一天时间，那调试解决Bu…

人工智能 2023年6月4日
0189
• #### 双十二买什么蓝牙耳机好？平价好用蓝牙耳机推荐

如果您正在寻找蓝牙耳机来接听电话或锻炼，而不必担心耳机线，这些蓝牙设备是您的完美选择。十年前，入耳式无线耳机风靡一时。但随着越来越多的人习惯于将手机放在耳朵旁边，技术和质量要求导致…

人工智能 2023年5月25日
0151
• #### 【菜菜的sklearn课堂笔记】支持向量机-SVC真实数据案例：预测明天是否会下雨-处理困难特征：日期

啊哦~你想找的内容离你而去了哦 内容不存在，可能为如下原因导致： ① 内容还在审核中 ② 内容以前存在，但是由于不符合新 的规定而被删除 ③ 内容地址错误 ④ 作者删除了内容。 可…

人工智能 2023年6月29日
0203
• #### Nuscenes 完整版数据集批量下载

Nuscenes 完整版数据集批量下载 需求： 高速下载Nuscenes完整版数据集。之前mini版本尝鲜版，采用google浏览器自带工具下载，速度慢，且容易断。 1. 数据地址…

人工智能 2023年6月15日
0178
• #### 【Unity】 脚本实现对象自由移动（第一人称）

建立一个脚本名称为：Move(与脚本内命名空间一致) using System.Collections; using System.Collections.Generic; usi…

人工智能 2023年6月4日
0172
• #### python合并根目录下所有表格文件并增加文件名索引

目录 前言 一、代码展示 二、主要函数 1.os.walk() 2.pd.concat(） 前言 遇到了批量合并根目录下大量不同格式文件并进行简单处理的需求，在网上没有搜到完全相同…

人工智能 2023年7月8日
0155
• #### 数据分类分级的深度思考

1、敏感数据识别 数据分类分级的准确度和效率取决于工具的识别能力是否强大，即”工具是不是真的能够看到数据、看懂数据”。 1.1落地难点 数据分类分类大多数安…

人工智能 2023年7月1日
0167
• #### 基于Java+Spring+Vue+elementUI大学生求职招聘系统详细设计实现

博主介绍： ✌全网粉丝20W+,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技…

人工智能 2023年7月31日
0128
• #### Matlab中镜头畸变矫正

matlab中其实自己带了镜头畸变矫正的代码。找了很久才发现原来兜兜转转还是回到原点 %% Correct Image for Lens Distortion%%%close al…

人工智能 2023年6月22日
0137
• #### 论文阅读笔记（四）【ACL 2021】FEW-NERD: A Few-shot Named Entity Recognition Dataset

过去的难点：过去的都是粗粒度的；Few-NERD：一个大规模的人类注释的小样本NERD数据集，它具有 8种粗粒度和66种细粒度实体类型的层次结构。Few-NERD由来自维基百科的1…

人工智能 2023年6月10日
0141
• #### 语音识别（利用python将语音转化为文字）

提示：文章写完后，目录可以自动生成，如何生成可参考右边的帮助文档 文章目录 前言 一、申请讯飞语音端口 * 1.点击链接进入讯飞平台主页面 2.在页面注册自己的个人账户 3.申请语…

人工智能 2023年5月27日
0226
• #### 机器学习系列4 使用Python创建Scikit-Learn回归模型

本文中包含的案例jupyter笔记本可在我的资源中免费下载：机器学习系列4 使用Python创建Scikit-learn线性回归模型.ipynb 图1 使用Python和Sciki…

人工智能 2023年6月17日
0163
• #### 【Anaconda3】笔记内容008：详解Anaconda3的安装、Conda虚拟环境创建和其他项目环境的布置

目录 摘要 一 将电脑中的原有的Anaconda3环境删除 二 进行Anaconda3安装 三 创建虚拟环境 四 如何在虚拟环境中复制原项目环境 五 补充下conda如何更全局源 …

人工智能 2023年7月17日
0155
• #### MATLAB 基础知识 数据类型 分组数组 创建分类数组

本文说明如何创建分类数组。categorical 是一个数据类型，用来存储值来自一组有限离散类别的数据。这些分类可以采用自然排序，但并不要求一定如此。分类数组可用来有效地存储并方便…

人工智能 2023年7月2日
0184
• #### 最新最全Diffusion Models论文、代码汇总(图像生成、图像分割、图像翻译、超分辨率重建、医疗影像、自然语言处理、视频生生成、时间序列生成、3D点云生成、文本语音转换、音频生成等)

抵扣说明： 1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。2.余额无法直接购买下载，可以购买VIP、C币套餐、付费专栏及课程。 Original: https:…

人工智能 2023年6月17日
0129