当前位置:网站首页>基于tensorflow的手写数字识别
基于tensorflow的手写数字识别
2022-06-26 18:08:00 【小狐狸梦想去童话镇】
import numpy as np
#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior() #解决tf.placeholder报错问题
import matplotlib.pyplot as plt
import input_data #使用的数据库是tensorflow内置数据库,可下载到本地
mnist = input_data.read_data_sets('data/',one_hot=True)
#network topologies 网络拓扑
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10
#inputs and outputs 输入 输出
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])
#network parameters 网络参数
stddev = 0.1
weights = {
'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev=stddev)),
'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))
}
biases = {
'b1':tf.Variable(tf.random_normal([n_hidden_1])),
'b2':tf.Variable(tf.random_normal([n_hidden_2])),
'out':tf.Variable(tf.random_normal([n_classes]))
}
print("NETWORK READY")
def multilayer_perceptron(_X,_weights,_biases):
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,_weights['w2']),_biases['b2']))
return (tf.matmul(layer_2,_weights['out'])+_biases['out'])
#prediction
pred = multilayer_perceptron(x,weights,biases)
#loss and optimizer 损失函数及优化器
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
#initializer
init = tf.global_variables_initializer()
print("FUNCTIONS READY")
#迭代
training_epochs = 20
batch_size = 100
display_step = 4
#launch the graph
sess = tf.Session()
sess.run(init)
#optimize
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
#iteration
for i in range(total_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
feeds = {
x:batch_xs,y:batch_ys}
sess.run(optm,feed_dict=feeds)
avg_cost +=sess.run(cost,feed_dict=feeds)
avg_cost = avg_cost/total_batch
#display
if (epoch+1)%display_step==0:
print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
feeds = {
x:batch_xs,y:batch_ys}
training_acc = sess.run(accr,feed_dict=feeds)
print("Train Accuracy:%.3f"%(training_acc))
feeds = {
x:mnist.test.images,y:mnist.test.labels}
test_acc = sess.run(accr,feed_dict=feeds)
print("Test Accuracy:%.3f"%(test_acc))
print("Optimization Finished")
边栏推荐
- CD-CompactDisk
- 利用递归实现求n位所有格雷码
- 如何创建并强制使用索引
- Deep understanding of MySQL lock and transaction isolation level
- 小程序设置按钮分享功能
- sql中的几种删除操作
- Clion编译catkin_ws(ROS工作空间包的简称)加载CMakeLists.txt出现的问题
- JVM entry door (1)
- Runtimeerror: CUDA error: out of memory own solution (it is estimated that it is not applicable to most people in special circumstances)
- 数字签名论述及生成与优点分析
猜你喜欢

RuntimeError: CUDA error: out of memory自己的解决方法(情况比较特殊估计对大部分人不适用)

In and exceptions, count (*) query optimization

非对称密码体制详解

Properties file garbled

Analysis of deep security definition and encryption technology

小程序设置按钮分享功能

ISO文件

博云,站在中国容器潮头

MySQL download and configuration MySQL remote control

Let torch cuda. is_ Experience of available() changing from false to true
随机推荐
Please advise tonghuashun which securities firm to choose for opening an account? Is it safe to open an account online now?
我想知道,我在肇庆,到哪里开户比较好?网上开户是否安全么?
你好,现在网上股票开户买股票安全吗?
No manual prior is required! HKU & Tongji & lunarai & Kuangshi proposed self supervised visual representation learning based on semantic grouping, which significantly improved the tasks of target dete
深入理解MySQL锁与事务隔离级别
图像二值化处理
腾讯钱智明:信息流业务中的预训练方法探索与应用实践
IDEA收藏代码、快速打开favorites收藏窗口
Knapsack problem with dependency
JNI的 静态注册与动态注册
Properties file garbled
数据加密标准(DES)概念及工作原理
map和filter方法对于稀缺数组的处理
CLion断点单步调试
KDD 2022 | how to use comparative learning in cross domain recommendation?
Deep understanding of MySQL lock and transaction isolation level
Clion编译catkin_ws(ROS工作空间包的简称)加载CMakeLists.txt出现的问题
VCD video disc
解决pycharm里面每个字母占一格空格的问题
零时科技 | 智能合约安全系列文章之反编译篇