TensorFlow 2.0 handwritten digit recognition (TensorFlow handwriting recognition)
2022-08-01 13:03:00 【Full stack programmer webmaster】
Hello everyone, good to see you again. I'm your friend 全栈君.
These notes serve as the Hello World of TensorFlow: they use MNIST handwritten digit recognition to explore TensorFlow. The content comes from the TensorFlow Chinese community and Huang Wenjian's 《Tensorflow 实战》; it is just a summary for my own review.
Environment:
- Windows 10
- Anaconda 4.3.0
- Spyder
These notes mainly use the Softmax Regression algorithm, building a neural network with no hidden layers to recognize MNIST handwritten digits.
1. Loading the MNIST dataset
The MNIST dataset can be downloaded from the official MNIST website. It can also be loaded through input_data.py, which TensorFlow provides.
Because downloading the dataset this way is relatively slow, I have uploaded the downloaded dataset to CSDN resources, where it can be downloaded directly.
Put the downloaded dataset in the directory C:/Users/Administrator/.spyder-py3/MNIST_data/. You can use any directory you like, as long as you change the code accordingly.
Load the dataset by running the code TensorFlow provides:
from tensorflow.examples.tutorials.mnist import input_data
# Load the data
mnist = input_data.read_data_sets("C:/Users/Administrator/.spyder-py3/MNIST_data/", one_hot=True)
The MNIST dataset contains a training set of 55,000 samples, a validation set of 5,000 samples, and a test set of 10,000 samples. input_data.py decompresses the downloaded data and restructures the image and label data into new dataset objects.
Each image is a 28×28-pixel grayscale image. Blank areas are 0, and areas with handwriting take a value between 0 and 1 depending on how dark the stroke is. Each sample therefore has 28×28=784 features, which is equivalent to flattening the image into one dimension.
So the training-set features form a 55000×784 Tensor: the first dimension indexes the image and the second indexes the pixel. The training-set labels (which digit from 0 to 9 the image represents) form a 55000×10 Tensor, where 10 is the number of classes. The labels are one-hot encoded, meaning exactly one value is 1 and the rest are 0; for example, the label for digit 0 is [1,0,0,0,0,0,0,0,0,0].
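The flattening and one-hot encoding described above can be illustrated with a minimal plain-Python sketch. The helper names `flatten` and `one_hot` and the sample image are hypothetical and are not part of input_data.py, which does this work internally.

```python
def one_hot(digit, num_classes=10):
    """Return a list with 1.0 at index `digit` and 0.0 elsewhere."""
    if not 0 <= digit < num_classes:
        raise ValueError("digit out of range")
    return [1.0 if i == digit else 0.0 for i in range(num_classes)]


def flatten(image):
    """Flatten a 2-D list of pixel rows into one 1-D feature list, row by row."""
    return [pixel for row in image for pixel in row]


# Hypothetical sample: a blank 28x28 image with one "ink" pixel at row 3, column 5.
image = [[0.0] * 28 for _ in range(28)]
image[3][5] = 1.0

features = flatten(image)
print(len(features))        # 784
print(features.index(1.0))  # 3 * 28 + 5 = 89
print(one_hot(0))           # 1.0 in position 0, 0.0 in the other nine positions
```

Note that flattening discards the 2-D layout of the image; the model only sees 784 independent pixel values.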
2. The Softmax Regression algorithm
The digits all lie between 0 and 9, so there are 10 classes. When predicting an image, Softmax Regression estimates a probability for each class and outputs the digit with the highest probability as the result.
Softmax Regression adds up the features that support a given class and then converts them into the probability that the image belongs to that class. We compute a weighted sum over all pixels of the image. If a large gray value at some pixel makes the image likely to be digit n, that pixel's weight for class n is large; conversely, if it makes digit n unlikely, the weight is likely to be negative.
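Feature formula:
$$\text{feature}_i = \sum_j W_{i,j}\, x_j + b_i$$
Here $x_j$ is the $j$-th pixel, $W_{i,j}$ is the weight of pixel $j$ for class $i$, and $b_i$ is the bias, which reflects a tendency of the data itself.
The softmax function then converts these features into probabilities $y$:
$$y = \text{softmax}(\text{feature})$$
Softmax is computed over all features and normalized (so that the output probabilities of all classes sum to 1). Writing $x$ for the vector of features in the next two formulas:
$$\text{softmax}(x) = \text{normalize}(\exp(x))$$
The probability of being judged as class $i$ is:
$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
The Softmax Regression flow is as follows:
[Figure: Softmax Regression flow; image not preserved]
Converted to matrix multiplication:
[Figure: the same computation as a matrix multiplication; image not preserved]
Written as a formula:
$$y = \text{softmax}(Wx + b)$$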
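As a rough illustration of the weighted sum followed by softmax, here is a small plain-Python sketch. The function names and the toy weights (3 classes, 2 pixels) are assumptions made for this example only; subtracting the maximum before `exp` is a common numerical-stability trick and is not something the article itself discusses.

```python
import math


def softmax(features):
    """Turn a list of raw scores into probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but avoids overflow in exp.
    shift = max(features)
    exps = [math.exp(f - shift) for f in features]
    total = sum(exps)
    return [e / total for e in exps]


def class_features(W, x, b):
    """feature_i = sum_j W[i][j] * x[j] + b[i] for every class i."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(W, b)]


# Hypothetical toy values: 3 classes, 2 input pixels.
W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b = [0.0, 0.0, 0.0]
x = [2.0, 1.0]

feats = class_features(W, x, b)  # [2.0, 1.0, -3.0]
probs = softmax(feats)
print(feats)
print(probs)
print(probs.index(max(probs)))   # 0: the class with the largest feature wins
```

Softmax preserves the ordering of the features, so the predicted class is always the one with the largest weighted sum; what softmax adds is a normalized probability for every class.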
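3. Implementing the model
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
First load the TensorFlow library and create a new InteractiveSession; subsequent operations run in this session by default. Each sample is one row of x, so the product is written as tf.matmul(x, W) with W of shape [784, 10].
- placeholder: where the input data is fed in. None means the number of inputs is unlimited, and each input is a 784-dimensional vector.
- Variable: stores the model parameters and is persistent.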
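4. Training the model
We define a loss function to describe how inaccurately the model classifies the problem: the smaller the loss, the more accurate the model. Here we use cross-entropy:
$$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$
where y is the predicted probability distribution and y' is the true distribution.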
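To make the cross-entropy concrete, the following sketch evaluates it for a single sample in plain Python. The function name, the example distributions, and the small `eps` guard against `log(0)` are assumptions of this sketch; the article's TensorFlow code calls `tf.log(y)` directly and additionally averages over the batch.

```python
import math


def cross_entropy(y_true, y_pred, eps=1e-12):
    """Cross-entropy -sum_i y'_i * log(y_i) for a single sample."""
    # eps guards against log(0); it is an addition of this sketch, not part of the formula.
    return -sum(t * math.log(p + eps) for t, p in zip(y_true, y_pred))


label = [0.0, 1.0, 0.0]         # one-hot: the true class is 1
confident = [0.05, 0.90, 0.05]  # hypothetical prediction, mostly right
unsure = [0.40, 0.30, 0.30]     # hypothetical prediction, mostly wrong

print(cross_entropy(label, confident))  # about 0.105
print(cross_entropy(label, unsure))     # about 1.204
```

Because the label is one-hot, only the predicted probability of the true class contributes, so the loss reduces to the negative log of that probability.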
y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
This defines a placeholder for feeding in the correct labels and computes the cross-entropy, averaged over the batch.
Then train with stochastic gradient descent, using a step size of 0.5.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
Train the model in a loop of 1000 iterations, each time randomly drawing 100 samples from the training set, which speeds up convergence compared with using the full training set at every step.
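# Initialize W and b before the first training step
tf.global_variables_initializer().run()
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})
5. Evaluating the model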
We evaluate the model by checking whether the predicted class matches the actual class and computing the accuracy; the higher the accuracy, the better the classification.
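The comparison can be sketched in plain Python as follows. This mirrors the idea behind the `tf.argmax` and `tf.equal` calls in the full code, but the helper names and the four sample predictions are hypothetical.

```python
def argmax(values):
    """Index of the largest entry (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(predictions, labels):
    """Fraction of samples whose predicted class equals the labelled class."""
    correct = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)


# Hypothetical predictions for 4 samples over 3 classes, with one-hot labels.
predictions = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.2, 0.6]]
labels = [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]

print(accuracy(predictions, labels))  # 3 of 4 correct -> 0.75
```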
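6. Summary
The whole workflow:
- Define the algorithm's formula, i.e. the forward computation of the neural network.
- Define the loss, choose an optimizer, and tell the optimizer to minimize the loss.
- Train on the data iteratively.
- Measure accuracy on the test set or validation set.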
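The four steps above can be sketched end to end without TensorFlow. The following is a minimal plain-Python softmax regression trained with mini-batch gradient descent on a made-up two-feature dataset; every name, the toy data, and the batch size of 20 are assumptions for illustration, and the gradient is written out by hand, whereas the article lets `GradientDescentOptimizer` derive it.

```python
import math
import random


def softmax(z):
    """Numerically stable softmax over a list of scores."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def predict(W, b, x):
    """Step 1, forward pass softmax(xW + b); W is stored as [n_inputs][n_classes]."""
    scores = [sum(x[j] * W[j][i] for j in range(len(x))) + b[i]
              for i in range(len(b))]
    return softmax(scores)


def sgd_step(W, b, batch, lr):
    """Steps 2 and 3: one gradient-descent step on the mean cross-entropy of a batch."""
    grad_W = [[0.0] * len(b) for _ in W]
    grad_b = [0.0] * len(b)
    for x, t in batch:
        y = predict(W, b, x)
        for i in range(len(b)):
            delta = (y[i] - t[i]) / len(batch)  # gradient of the loss w.r.t. score i
            grad_b[i] += delta
            for j in range(len(x)):
                grad_W[j][i] += delta * x[j]
    for i in range(len(b)):
        b[i] -= lr * grad_b[i]
        for j in range(len(W)):
            W[j][i] -= lr * grad_W[j][i]


def accuracy(W, b, samples):
    """Step 4: share of samples whose most probable class matches the one-hot label."""
    hits = 0
    for x, t in samples:
        y = predict(W, b, x)
        hits += y.index(max(y)) == t.index(1.0)
    return hits / len(samples)


# Hypothetical toy data: class 0 if the first feature is larger, else class 1.
random.seed(0)
data = []
for _ in range(200):
    x = [random.random(), random.random()]
    data.append((x, [1.0, 0.0] if x[0] > x[1] else [0.0, 1.0]))

W = [[0.0, 0.0], [0.0, 0.0]]  # zero initialisation, as in the article
b = [0.0, 0.0]
for _ in range(1000):
    sgd_step(W, b, random.sample(data, 20), lr=0.5)

acc = accuracy(W, b, data)
print(acc)
```

For brevity this sketch measures accuracy on the same data it trained on; the article evaluates on the separate MNIST test set, which is the more meaningful measurement.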
7. Full code
# Note: this code uses the TensorFlow 1.x API (placeholders, sessions);
# tensorflow.examples.tutorials is not available in TensorFlow 2.x.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the data
mnist = input_data.read_data_sets("C:/Users/Administrator/.spyder-py3/MNIST_data/", one_hot=True)
print('Training set info:')
print(mnist.train.images.shape,mnist.train.labels.shape)
print('Test set info:')
print(mnist.test.images.shape,mnist.test.labels.shape)
print('Validation set info:')
print(mnist.validation.images.shape,mnist.validation.labels.shape)
# Build the graph
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Train
tf.global_variables_initializer().run()
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})
# Evaluate the model
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('MNIST handwritten digit accuracy:')
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
Publisher: 全栈程序员栈长. Please credit the source when reposting: https://javaforall.cn/126515.html Original link: https://javaforall.cn