Iris Classification with TensorFlow 2.x
2022-07-01 14:42:00 【NigeloYang】
```python
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np

# Load the data: input features and labels
iris = datasets.load_iris()
x_data = iris.data
y_data = iris.target

# Shuffle the data randomly. The raw data is ordered by class, and training on it
# in that order would hurt accuracy.
# seed: an integer random seed. With a fixed seed the same "random" numbers are
# generated on every run, so everyone following along gets identical results.
np.random.seed(116)  # Reuse the same seed so features and labels stay paired one-to-one
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

# Split the shuffled data: the first 120 rows form the training set,
# the last 30 rows form the test set
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
print('x_train shape {} y_train shape {} x_test shape {} y_test shape {}'.format(
    x_train.shape, y_train.shape, x_test.shape, y_test.shape))
```
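Re-seeding before the second shuffle works because `x_data` and `y_data` have the same number of rows (150): with an identical seed, the second shuffle consumes the same random draws and applies the same permutation, so each feature row keeps its label. The sketch below illustrates this in plain Python, using the standard `random` module as a stand-in for NumPy; the helper names (`shuffle_in_unison`, `shuffle_by_permutation`, `split_last`) and the toy data are hypothetical. The second helper shows a common alternative that does not rely on re-seeding: draw one index permutation and index both arrays with it (in NumPy, something like `idx = np.random.default_rng(116).permutation(len(x_data))`, then `x_data[idx]` and `y_data[idx]`; note that this would produce a different order than the script's seed-116 shuffle).

```python
import random


def shuffle_in_unison(features, labels, seed):
    """Shuffle copies of two equal-length lists with identically seeded RNGs.

    Both shuffles draw the same sequence of random indices, so row k of the
    features still belongs with label k afterwards.
    """
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    feats, labs = list(features), list(labels)
    random.Random(seed).shuffle(feats)
    random.Random(seed).shuffle(labs)
    return feats, labs


def shuffle_by_permutation(features, labels, seed):
    """Alternative: draw one index permutation and apply it to both lists."""
    idx = list(range(len(features)))
    random.Random(seed).shuffle(idx)
    return [features[i] for i in idx], [labels[i] for i in idx]


def split_last(features, labels, n_test=30):
    """Hold out the last n_test rows as the test set, as the script does."""
    return (features[:-n_test], labels[:-n_test],
            features[-n_test:], labels[-n_test:])


# Hypothetical toy data shaped like iris: 150 rows, 50 per class, sorted by class
features = [[float(i), float(i % 7)] for i in range(150)]
labels = [i // 50 for i in range(150)]

feats, labs = shuffle_in_unison(features, labels, seed=116)
x_tr, y_tr, x_te, y_te = split_last(feats, labs)
print(len(x_tr), len(x_te))  # 120 30
```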
```python
# Cast x to float32; otherwise the matrix multiplication below fails because the
# features (float64) and the weights (float32) have different dtypes
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

# from_tensor_slices pairs each feature row with its label, and .batch(32)
# groups the pairs into batches of 32 samples
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
```
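With 120 training samples and a batch size of 32, `train_db` yields four batches per epoch, the last one smaller; this is where the `4` in `loss_all / 4` comes from, and the 30 test samples fit in a single batch. A minimal sketch of that arithmetic, assuming the default `drop_remainder=False` of `tf.data.Dataset.batch` (`batch_sizes` is a hypothetical helper, not a TensorFlow function):

```python
def batch_sizes(num_samples, batch_size):
    """Batch sizes produced by batching with drop_remainder=False:
    as many full batches as fit, then one smaller final batch if needed."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(batch_sizes(120, 32))  # [32, 32, 32, 24] -> 4 training steps per epoch
print(batch_sizes(30, 32))   # [30] -> the whole test set is one batch
```

Because the epoch loss is the mean of four batch means, the 24-sample final batch counts as much as each 32-sample batch, so the recorded value is close to, but not exactly, the per-sample average.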
```python
# Create the network parameters: there are 4 input features, so the input layer
# has 4 nodes; there are 3 classes, so the output layer has 3 neurons.
# tf.Variable() marks the parameters as trainable; the seed makes the random
# initial values identical on every run.
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

lr = 0.1  # Learning rate
train_loss_results = []  # Loss of each epoch, used later to plot the loss curve
test_acc = []  # Test accuracy of each epoch, used later to plot the accuracy curve
epochs = 500  # Train for 500 epochs
loss_all = 0  # Each epoch has 4 steps (120 samples in batches of 32: 32+32+32+24); loss_all sums their losses

# Training
for epoch in range(epochs):  # Dataset-level loop: one pass over the training set per epoch
    for step, (x_batch, y_batch) in enumerate(train_db):  # Batch-level loop: one batch per step
        with tf.GradientTape() as tape:  # Record operations for automatic differentiation
            y = tf.matmul(x_batch, w1) + b1  # Forward pass: multiply-add
            y = tf.nn.softmax(y)  # Turn outputs into a probability distribution, comparable with one-hot labels
            y_onehot = tf.one_hot(y_batch, depth=3)  # Convert labels to one-hot vectors for the loss
            loss = tf.reduce_mean(tf.square(y_onehot - y))  # Mean squared error over all batch x class entries
        loss_all += loss.numpy()  # Accumulate the per-step loss to compute the epoch average
        # Gradients of the loss with respect to each parameter
        grads = tape.gradient(loss, [w1, b1])
        # Gradient descent update: w1 = w1 - lr * w1_grad, b1 = b1 - lr * b1_grad
        w1.assign_sub(lr * grads[0])  # Update w1 in place
        b1.assign_sub(lr * grads[1])  # Update b1 in place

    # Print the average loss once per epoch
    print("Epoch {}, loss: {}".format(epoch, loss_all / 4))
    train_loss_results.append(loss_all / 4)  # Record the mean of the 4 step losses
    loss_all = 0  # Reset before the next epoch

    # Testing (run after every epoch)
    # total_correct: number of correctly predicted samples;
    # total_number: number of test samples. Both start at 0.
    total_correct, total_number = 0, 0
    for x_batch, y_batch in test_db:
        # Predict with the updated parameters
        y = tf.matmul(x_batch, w1) + b1
        y = tf.nn.softmax(y)
        pred = tf.argmax(y, axis=1)  # Index of the largest value in each row, i.e. the predicted class
        # Cast pred to the labels' dtype so the two can be compared
        pred = tf.cast(pred, dtype=y_batch.dtype)
        # 1 where the prediction is correct, 0 otherwise (bool cast to int)
        correct = tf.cast(tf.equal(pred, y_batch), dtype=tf.int32)
        # Number of correct predictions in this batch
        correct = tf.reduce_sum(correct)
        # Accumulate over all batches
        total_correct += int(correct)
        # Samples in this batch = number of rows (shape[0])
        total_number += x_batch.shape[0]
    # Overall accuracy = total_correct / total_number
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("--------------------------")
```
```python
# Plot the loss curve
plt.subplot(1, 2, 1)
plt.title('Loss Function Curve')  # Subplot title
plt.xlabel('Epoch')  # x-axis label
plt.ylabel('Loss')  # y-axis label
plt.plot(train_loss_results, label="$Loss$")  # Plot train_loss_results point by point and connect them; legend label "Loss"
plt.legend()  # Draw the legend

# Plot the accuracy curve
plt.subplot(1, 2, 2)
plt.title('Acc Curve')  # Subplot title
plt.xlabel('Epoch')  # x-axis label
plt.ylabel('Acc')  # y-axis label
plt.plot(test_acc, label="$Accuracy$")  # Plot test_acc point by point and connect them; legend label "Accuracy"
plt.legend()
plt.show()  # Display the figure
```