TensorFlow: customize the whole training process
2022-07-06 01:35:00 【@zhou】
Create a machine learning problem
f(x) = 3x + 7
A machine learning problem involves the following steps:
- Get the training data.
- Define the model.
- Define the loss function.
- Loop over the training data, computing the loss against the target values.
- Compute the gradient of that loss and use an optimizer to adjust the variables to fit the data.
- Evaluate the results.
Build the data
Supervised learning uses inputs (usually denoted x) and outputs (denoted y, commonly called labels). The goal is to learn from paired inputs and outputs so that the output can be predicted from a new input. In TensorFlow almost every input is represented as a tensor, usually a vector; in supervised learning the output (that is, the predicted value) is also a tensor. Here we synthesize some data by adding Gaussian (normally distributed) noise to points on a line, and then visualize it.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Use float32 so the inputs match the dtype of the model's tf.Variables below.
x = (np.random.random([1000]) * 5).astype(np.float32)
noise = np.random.normal(0, 1, size=[1000]).astype(np.float32)  # Gaussian noise
y = 3 * x + 7 + noise  # points on the line f(x) = 3x + 7, perturbed by the noise

plt.scatter(x, y, c="b")
plt.show()
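As noted above, TensorFlow represents inputs as tensors. NumPy arrays are converted automatically when passed to TensorFlow ops, but if you prefer explicit tensors, a cast like the following works (a minimal sketch; the names x_t and y_t are introduced here for illustration):

# Optional: convert the NumPy arrays to TensorFlow tensors up front.
# The dtype must match the model's float32 variables.
x_t = tf.constant(x, dtype=tf.float32)
y_t = tf.constant(y, dtype=tf.float32)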

Define a custom model
We subclass tf.Module and define two variables; because they are created with trainable=True, they show up in the module's trainable_variables attribute.
class selfmodel(tf.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Arbitrary initial values; training should drive v1 toward 3 and v2 toward 7.
        self.v1 = tf.Variable(1.0, trainable=True)
        self.v2 = tf.Variable(2.0, trainable=True)

    def __call__(self, x):
        # A simple linear model: y = v1 * x + v2
        return self.v1 * x + self.v2
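As a quick sanity check (a minimal sketch; the sample input values are arbitrary), calling the module applies the current parameters:

model = selfmodel()
print(model(tf.constant([0.0, 1.0, 2.0])))
# tf.Tensor([2. 3. 4.], shape=(3,), dtype=float32), since v1=1.0 and v2=2.0
print(model.trainable_variables)  # the two tf.Variables tracked by tf.Module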
Define the loss function
Here we use the mean squared error (MSE) to compute the loss.
def loss(target_y, predicted_y):
    # Mean squared error between the targets and the predictions
    return tf.reduce_mean(tf.square(target_y - predicted_y))
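For instance (a hand-checkable sketch with made-up numbers), the squared errors below are 1 and 4, so the loss is their mean, 2.5:

print(loss(tf.constant([1.0, 2.0]), tf.constant([2.0, 4.0])))
# tf.Tensor(2.5, shape=(), dtype=float32)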
Define the training loop
We train for a given number of epochs to fit the two variables v1 and v2, record v1 and v2 after each epoch, and visualize them at the end.
def train(model, x, y, epochs, optimizer):
    v1, v2 = [], []
    for j in range(epochs):
        with tf.GradientTape() as gd:
            y_pred = model(x)  # must be computed inside the tape (see below)
            loss_score = loss(y, y_pred)
        grad = gd.gradient(loss_score, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        v1.append(model.v1.numpy())
        v2.append(model.v2.numpy())
    return (model, v1, v2)
Show the final result
If epochs is set too small, v1 and v2 will not have time to converge to the correct values.
opt = tf.keras.optimizers.SGD()  # default learning rate 0.01
model = selfmodel()
epochs = 1000
(model, v1, v2) = train(model, x, y, epochs, opt)

# Plot the recorded parameter trajectories against the true values.
plt.plot(range(epochs), v1, "r",
         range(epochs), v2, "b")
plt.plot([3] * epochs, "r--",
         [7] * epochs, "b--")
plt.legend(["v1", "v2", "True v1", "True v2"])
plt.show()
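To confirm convergence numerically (a minimal sketch; the exact values vary from run to run because of the random noise), the final parameters can be printed directly:

print("v1 =", model.v1.numpy(), " v2 =", model.v2.numpy())
# Both should end up close to the true values 3 and 7.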

Problems in the code
# The code below raises an error: grad comes back as (None, None),
# because y_pred = model(x) was computed outside the `with` block.
# The tape never recorded the forward pass, so when it differentiates the
# loss with respect to model.trainable_variables it finds no gradients.
# The correct version follows.
y_pred = model(x)
with tf.GradientTape() as t:
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))
# Correct version: run the forward pass inside the tape
with tf.GradientTape() as t:
    y_pred = model(x)
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)
print(model.trainable_variables)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))
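One quick way to diagnose this kind of problem (a minimal sketch; watched_variables() is an existing tf.GradientTape method) is to check which variables the tape actually recorded:

with tf.GradientTape() as t:
    y_pred = model(x)
    # Should list v1 and v2; if the forward pass had run outside the
    # tape, this tuple would be empty.
    print(t.watched_variables())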