TensorFlow: customize the whole training process
2022-07-06 01:35:00 【@zhou】
Create a machine learning problem
f(x) = 3x + 7
For a machine learning problem, the steps are as follows:
- Get the training data.
- Define the model.
- Define the loss function.
- Iterate over the training data, computing the loss against the target values.
- Compute the gradients of that loss and use an optimizer to adjust the variables to fit the data.
- Evaluate the results.
Build data
Supervised learning uses inputs (usually denoted x) and outputs (denoted y, commonly called labels). The goal is to learn from paired inputs and outputs so that you can predict the output for a given input. In TensorFlow almost every input is represented as a tensor, usually a vector; in supervised learning the output (the value to predict) is also a tensor. Here we synthesize data by adding Gaussian (normally distributed) noise to points on the line, then visualize it.
import numpy as np
import matplotlib.pyplot as plt

x = np.random.random([1000]) * 5          # 1000 inputs drawn uniformly from [0, 5)
noise = np.random.normal(size=[1000])     # Gaussian noise, as described above
y = 3 * x + 7 + noise                     # points on the line, perturbed by the noise

plt.scatter(x, y, c="b")
plt.show()
Define the custom model
We subclass tf.Module and define two variables; because they are created with trainable=True, the model exposes them through its trainable_variables attribute.
import tensorflow as tf

class selfmodel(tf.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # tf.Module automatically tracks these as trainable_variables
        self.v1 = tf.Variable(1.0, trainable=True)
        self.v2 = tf.Variable(2.0, trainable=True)

    def __call__(self, x):
        # linear model: y = v1 * x + v2
        return self.v1 * x + self.v2
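A quick check (our own snippet, not part of the original walkthrough) confirms that tf.Module collects both variables automatically:

model = selfmodel()
# every tf.Variable assigned to an attribute is picked up by tf.Module,
# so v1 and v2 both appear without any extra registration
print(model.trainable_variables)   # (<tf.Variable ...=1.0>, <tf.Variable ...=2.0>)
print(model(2.0))                  # tf.Tensor(4.0, ...), since 1.0 * 2.0 + 2.0 = 4.0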
Define the loss function
We use the mean squared error (MSE) to compute the loss.
def loss(target_y, predicted_y):
    # mean squared error between targets and predictions
    return tf.reduce_mean(tf.square(target_y - predicted_y))
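As a quick sanity check (a toy example of our own): for targets [3., 5.] and predictions [2., 6.], the squared errors are 1 and 1, so the mean is 1.0:

print(loss(tf.constant([3.0, 5.0]), tf.constant([2.0, 6.0])))  # tf.Tensor(1.0, shape=(), dtype=float32)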
Define the training loop
We train for a given number of epochs, record the values of v1 and v2 after each epoch, and visualize how they evolve at the end.
def train(model, x, y, epochs, optimizer):
    v1, v2 = [], []
    for j in range(epochs):
        with tf.GradientTape() as gd:
            # the forward pass must run inside the tape's context
            # so that the operations on the variables are recorded
            y_pred = model(x)
            loss_score = loss(y, y_pred)
        grad = gd.gradient(loss_score, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        v1.append(model.v1.numpy())
        v2.append(model.v2.numpy())
    return (model, v1, v2)
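As an aside (this variation is our own, not part of the original walkthrough), the body of the loop can be factored into a step compiled with tf.function, which typically runs faster because TensorFlow traces it into a graph; the helper name train_step is an assumption:

@tf.function
def train_step(model, x, y, optimizer):
    # same logic as the loop body above, traced into a graph by tf.function
    with tf.GradientTape() as gd:
        loss_score = loss(y, model(x))
    grad = gd.gradient(loss_score, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))
    return loss_score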
Showing the final result
When choosing epochs, too small a value means v1 and v2 will not converge to the correct result.
opt = tf.keras.optimizers.SGD()
model = selfmodel()
epochs = 1000
(model, v1, v2) = train(model, x, y, epochs, opt)

# plot the trajectories of v1 and v2 against the true values 3 and 7
plt.plot(range(epochs), v1, "r",
         range(epochs), v2, "b")
plt.plot([3] * epochs, "r--",
         [7] * epochs, "b--")
plt.legend(["v1", "v2", "True v1", "True v2"])
plt.show()
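To check convergence numerically (our own addition), print the learned values; with enough epochs they should land close to the true slope 3 and intercept 7:

print("v1 =", model.v1.numpy())   # expected to be close to 3
print("v2 =", model.v2.numpy())   # expected to be close to 7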
Problems in the code
# The code below raises an error: grad comes back as (None, None),
# because y_pred = model(x) was written outside the with block.
# The tape only records operations executed inside its context, so when
# it differentiates the loss with respect to model.trainable_variables
# it finds no recorded path and returns None for every variable.
y_pred = model(x)   # WRONG: forward pass is not recorded by the tape
with tf.GradientTape() as t:
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)   # (None, None)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))   # error raised here

# Correct version: run the forward pass inside the tape
with tf.GradientTape() as t:
    y_pred = model(x)
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)
print(model.trainable_variables)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))
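The pitfall is easy to reproduce in isolation (a minimal sketch of ours):

v = tf.Variable(2.0)

# forward computation outside the tape: the multiplication is not recorded
u = v * 3.0
with tf.GradientTape() as tape:
    w = u + 1.0
print(tape.gradient(w, v))   # None: no recorded path from v to w

# forward computation inside the tape: the full chain is recorded
with tf.GradientTape() as tape:
    w = v * 3.0 + 1.0
print(tape.gradient(w, v))   # tf.Tensor(3.0, ...)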