Tensorflow customize the whole training process
2022-07-06 01:35:00 【@zhou】
Create a machine learning problem
f(x) = 3x + 7
For a machine learning problem, there are the following steps:
- Obtain training data.
- Define the model.
- Define the loss function.
- Loop over the training data, computing the loss against the target values.
- Compute the gradients of that loss and use an optimizer to adjust the variables to fit the data.
- Evaluate the results.
Build the data
Supervised learning uses inputs (usually denoted x) and outputs (denoted y, commonly called labels). The goal is to learn from paired inputs and outputs so that the output can be predicted from the input. In TensorFlow almost every input is represented as a tensor, usually a vector; in supervised learning the output (that is, the predicted value) is also a tensor. Here we synthesize data by adding Gaussian (normally distributed) noise to points on the line, and then visualize it.
import numpy as np
import matplotlib.pyplot as plt

# Cast to float32 so the data matches the dtype of the tf.float32 variables below
x = (np.random.random([1000]) * 5).astype(np.float32)    # 1000 inputs drawn uniformly from [0, 5)
noise = np.random.normal(size=[1000]).astype(np.float32) # Gaussian noise, as described above
y = 3 * x + 7 + noise                                    # noisy points on the target line

plt.scatter(x, y, c="b")
plt.show()
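As a quick sanity check (an addition, not from the original post), an ordinary least-squares fit on the synthesized data should recover a slope near 3 and an intercept near 7:

coeffs = np.polyfit(x, y, deg=1)  # fit a degree-1 polynomial: [slope, intercept]
print(coeffs)                     # approximately [3.0, 7.0]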
Customize the model we need
We subclass tf.Module and define two variables. Because they are created with trainable=True, both appear in the module's trainable_variables collection.
import tensorflow as tf

class selfmodel(tf.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize away from the true values (3 and 7) so training has work to do
        self.v1 = tf.Variable(1.0, trainable=True)  # slope
        self.v2 = tf.Variable(2.0, trainable=True)  # intercept

    def __call__(self, x):
        y = self.v1 * x + self.v2
        return y
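A quick check (a sketch added here, not in the original) that the module is wired up correctly: calling it applies v1 * x + v2 with the initial values, and tf.Module collects both variables automatically:

model = selfmodel()
print(model(tf.constant(2.0)).numpy())  # 1.0 * 2.0 + 2.0 = 4.0
print(len(model.trainable_variables))   # 2 (v1 and v2)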
Define the loss function
We use the mean squared error (MSE) here to compute the loss.
def loss(target_y, predicted_y):
    return tf.reduce_mean(tf.square(target_y - predicted_y))
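A tiny worked example (added for illustration) confirms the reduction: the squared errors are 1 and 4, so their mean is 2.5:

print(loss(tf.constant([1.0, 2.0]), tf.constant([2.0, 4.0])).numpy())  # (1 + 4) / 2 = 2.5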
Define the training loop
We run for a number of epochs to fit the two variables v1 and v2, recording v1 and v2 after each training step, and finally visualize them.
def train(model, x, y, epochs, optimizer):
    v1, v2 = [], []
    for j in range(epochs):
        with tf.GradientTape() as gd:
            y_pred = model(x)  # the forward pass must run inside the tape
            loss_score = loss(y, y_pred)
        grad = gd.gradient(loss_score, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        v1.append(model.v1.numpy())
        v2.append(model.v2.numpy())
    return (model, v1, v2)
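For intuition (a minimal sketch, not part of the original post; the learning rate lr is assumed here, matching the Keras SGD default of 0.01), a single SGD step is just variable <- variable - lr * gradient, which is exactly the bookkeeping optimizer.apply_gradients performs:

model = selfmodel()
lr = 0.01  # assumed learning rate for illustration
with tf.GradientTape() as tape:
    current_loss = loss(y, model(x))
g1, g2 = tape.gradient(current_loss, [model.v1, model.v2])
model.v1.assign_sub(lr * g1)  # v1 <- v1 - lr * dL/dv1
model.v2.assign_sub(lr * g2)  # v2 <- v2 - lr * dL/dv2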
Displaying the final results
When defining epochs, setting it too small means v1 and v2 will not converge to the correct values.
opt = tf.keras.optimizers.SGD()
model = selfmodel()
epochs = 1000
(model, v1, v2) = train(model, x, y, epochs, opt)

# Plot the trajectory of both variables against their true values
plt.plot(range(epochs), v1, "r",
         range(epochs), v2, "b")
plt.plot([3] * epochs, "r--",
         [7] * epochs, "b--")
plt.legend(["v1", "v2", "True v1", "True v2"])
plt.show()
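As a final check (an addition to the original), printing the trained variables should give values close to the true slope and intercept:

print(model.v1.numpy(), model.v2.numpy())  # should approach 3.0 and 7.0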
Problems in the code
# The version below raises an error because grad comes back as (None, None):
# y_pred = model(x) runs outside the with block, so the tape never records
# the forward pass, and no gradient with respect to model.trainable_variables
# can be found when differentiating the loss. The correct version follows.
y_pred = model(x)  # forward pass is NOT recorded by the tape
with tf.GradientTape() as t:
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)  # returns (None, None)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))  # fails here
# Correct version: run the forward pass inside the tape
with tf.GradientTape() as t:
    y_pred = model(x)
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)
print(model.trainable_variables)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))
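As a side note (an addition beyond the original post), once the step is written correctly it can be wrapped in tf.function to compile it into a graph and speed up the loop. A minimal sketch, assuming the model, loss, and optimizer defined above:

@tf.function
def train_step(model, x, y, optimizer):
    with tf.GradientTape() as tape:
        y_pred = model(x)  # forward pass recorded by the tape
        l = loss(y, y_pred)
    grads = tape.gradient(l, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return l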