TensorFlow 2.x Implementation of Iris Classification
2022-07-01 14:42:00 【NigeloYang】
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
# Import the data: input features and labels, respectively
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target
# Randomly shuffle the data (the raw data is ordered by class; training on it in order would hurt accuracy)
# seed: random number seed, an integer; once set, the same random sequence is generated on every run (used here so that everyone following along gets the same results)
np.random.seed(116)  # use the same seed so that input features and labels stay in one-to-one correspondence
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
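# (Illustration, not part of the original script: shuffling with the same seed applies
#  the same permutation to x_data and y_data, e.g.
#    np.random.seed(116); np.random.shuffle(a)
#    np.random.seed(116); np.random.shuffle(b)
#  reorders a and b identically, so x_data[i] still matches y_data[i] after shuffling.)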
# Split the shuffled dataset into a training set (the first 120 rows) and a test set (the last 30 rows)
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
print('x_train shape{} y_train shape{} x_test shape{} y_test shape{} '.format(
    x_train.shape, y_train.shape, x_test.shape, y_test.shape))
# Cast x to float32, otherwise the matrix multiplication below raises an error because of mismatched data types
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices pairs each input-feature row with its label; batch() then groups the dataset into batches (one batch per training step)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
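# (Note, added for clarity: with 120 training samples and batch size 32, train_db yields
#  4 batches per epoch of shapes (32, 4), (32, 4), (32, 4) and (24, 4); this is why the
#  training loop below divides the accumulated loss_all by 4.)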
# Generate the network parameters: 4 input features, so the input layer has 4 nodes; 3 classes, so the output layer has 3 neurons
# tf.Variable() marks the parameters as trainable; the seed makes the random initial values reproducible
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))
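# (Shape check, added for clarity: a batch x of shape (32, 4) times w1 of shape (4, 3)
#  gives (32, 3); adding b1 of shape (3,) broadcasts over the batch, so y has shape (32, 3),
#  one score per class for each sample.)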
lr = 0.1  # learning rate
train_loss_results = []  # record each epoch's loss here, to provide data for the loss curve
test_acc = []  # record each epoch's accuracy here, to provide data for the accuracy curve
epoch = 500  # train for 500 epochs
loss_all = 0  # each epoch runs 4 steps; loss_all accumulates the 4 step losses of one epoch
# Training part
for epoch in range(epoch):  # dataset-level loop: each epoch iterates over the whole training set once
    for step, (x_train, y_train) in enumerate(train_db):  # batch-level loop: each step processes one batch
        with tf.GradientTape() as tape:  # the with block records operations for gradient computation
            y = tf.matmul(x_train, w1) + b1  # forward pass: matrix multiplication plus bias
            y = tf.nn.softmax(y)  # turn the outputs into a probability distribution (comparable to the one-hot labels when computing the loss)
            y_onehot = tf.one_hot(y_train, depth=3)  # convert the labels to one-hot format to compute loss and accuracy
            loss = tf.reduce_mean(tf.square(y_onehot - y))  # mean squared error loss: mse = mean((y_onehot - y)^2)
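            # (Worked example, added for clarity: for one sample with y_onehot = [0, 1, 0]
            #  and y = [0.2, 0.6, 0.2], the squared errors are [0.04, 0.16, 0.04], whose mean
            #  is 0.08; tf.reduce_mean averages these squared errors over the whole batch.)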
            loss_all += loss.numpy()  # accumulate each step's loss, to compute the epoch-average loss later
        # Compute the gradient of the loss with respect to each parameter
        grads = tape.gradient(loss, [w1, b1])
        # Gradient descent update: w1 = w1 - lr * w1_grad, b1 = b1 - lr * b1_grad
        w1.assign_sub(lr * grads[0])  # update w1 in place
        b1.assign_sub(lr * grads[1])  # update b1 in place
    # Print the loss information once per epoch
    print("Epoch {}, loss: {}".format(epoch, loss_all / 4))
    train_loss_results.append(loss_all / 4)  # record the average of the 4 step losses
    loss_all = 0  # reset loss_all in preparation for the next epoch
    # Test part
    # total_correct: number of correctly predicted samples; total_number: total number of test samples; initialize both to 0
    total_correct, total_number = 0, 0
    for x_test, y_test in test_db:
        # Predict with the updated parameters
        y = tf.matmul(x_test, w1) + b1
        y = tf.nn.softmax(y)
        pred = tf.argmax(y, axis=1)  # index of the largest probability, i.e. the predicted class
        # Cast pred to the data type of y_test
        pred = tf.cast(pred, dtype=y_test.dtype)
        # correct is 1 where the prediction matches the label and 0 otherwise; cast the bool result to int
        correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
        # Sum the correct predictions within this batch
        correct = tf.reduce_sum(correct)
        # Accumulate the correct counts over all batches
        total_correct += int(correct)
        # total_number is the total number of test samples, i.e. the number of rows of x_test (shape[0] returns the row count)
        total_number += x_test.shape[0]
    # Overall accuracy = total_correct / total_number
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("--------------------------")
# Draw the loss curve
plt.subplot(1, 2, 1)
plt.title('Loss Function Curve')  # figure title
plt.xlabel('Epoch')  # x-axis label
plt.ylabel('Loss')  # y-axis label
plt.plot(train_loss_results, label="$Loss$")  # plot the train_loss_results values point by point and connect them; legend label is Loss
plt.legend()  # draw the legend
# Draw the accuracy curve
plt.subplot(1, 2, 2)
plt.title('Acc Curve')  # figure title
plt.xlabel('Epoch')  # x-axis label
plt.ylabel('Acc')  # y-axis label
plt.plot(test_acc, label="$Accuracy$")  # plot the test_acc values point by point and connect them; legend label is Accuracy
plt.legend()
plt.show()  # display the figure
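For comparison, the same classifier can also be written with the high-level Keras API. The sketch below is not part of the tutorial above; it only mirrors its hyperparameters (a single dense layer with softmax, SGD with learning rate 0.1, MSE loss on one-hot labels, batch size 32, 500 epochs) and reloads and splits the data itself so it can run on its own.

# Minimal Keras sketch (illustrative alternative, not the tutorial's low-level method)
import tensorflow as tf
import numpy as np
from sklearn import datasets

x_data = datasets.load_iris().data.astype(np.float32)
y_data = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
y_onehot = tf.one_hot(y_data, depth=3).numpy()

model = tf.keras.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', input_shape=(4,))
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
              loss='mse',
              metrics=['accuracy'])
model.fit(x_data[:-30], y_onehot[:-30],
          batch_size=32, epochs=500,
          validation_data=(x_data[-30:], y_onehot[-30:]))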