Handwritten digit recognition with TensorFlow
2022-06-26 18:17:00 【Little fox dreams of going to fairy tale town】
import numpy as np
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # needed so that tf.placeholder works under TensorFlow 2.x
import matplotlib.pyplot as plt

import input_data  # MNIST helper script from the classic TensorFlow tutorials; downloads the data to a local directory
mnist = input_data.read_data_sets('data/', one_hot=True)
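# If input_data.py (from the old TensorFlow MNIST tutorials) is not available,
# a rough equivalent can be sketched with tf.keras. The function name
# load_mnist_arrays is illustrative and not part of the original script; it
# only reproduces the array shapes used below (flattened 784-dim images,
# one-hot labels), not the next_batch() helper of read_data_sets.
def load_mnist_arrays():
    from tensorflow.keras.datasets import mnist as keras_mnist  # assumes TF 2.x with bundled Keras
    (x_tr, y_tr), (x_te, y_te) = keras_mnist.load_data()
    x_tr = x_tr.reshape(-1, 784).astype("float32") / 255.0
    x_te = x_te.reshape(-1, 784).astype("float32") / 255.0
    y_tr = np.eye(10)[y_tr]  # one-hot encode the labels
    y_te = np.eye(10)[y_te]
    return (x_tr, y_tr), (x_te, y_te)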
# network topology
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784     # 28x28 pixels, flattened
n_classes = 10    # digits 0-9

# inputs and outputs
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])

# network parameters
stddev = 0.1
weights = {
    'w1': tf.Variable(tf.random_normal([n_input, n_hidden_1], stddev=stddev)),
    'w2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=stddev)),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes], stddev=stddev))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
print("NETWORK READY")
def multilayer_perceptron(_X, _weights, _biases):
    # two sigmoid hidden layers followed by a linear output layer (logits)
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
    return tf.matmul(layer_2, _weights['out']) + _biases['out']
# prediction
pred = multilayer_perceptron(x, weights, biases)

# loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float"))

# initializer
init = tf.global_variables_initializer()
print("FUNCTIONS READY")
# training settings
training_epochs = 20
batch_size = 100
display_step = 4

# launch the graph
sess = tf.Session()
sess.run(init)
# optimize
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = int(mnist.train.num_examples / batch_size)
    # iterate over all mini-batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)
        avg_cost += sess.run(cost, feed_dict=feeds)
    avg_cost = avg_cost / total_batch
    # display progress
    if (epoch + 1) % display_step == 0:
        print("Epoch:%03d/%03d cost:%.9f" % (epoch, training_epochs, avg_cost))
        feeds = {x: batch_xs, y: batch_ys}
        training_acc = sess.run(accr, feed_dict=feeds)
        print("Train Accuracy:%.3f" % (training_acc))
        feeds = {x: mnist.test.images, y: mnist.test.labels}
        test_acc = sess.run(accr, feed_dict=feeds)
        print("Test Accuracy:%.3f" % (test_acc))
print("Optimization Finished")