Customizing the whole training process in TensorFlow
2022-07-06 01:35:00 【@zhou】
Create a machine learning problem
f(x) = 3x + 7
Solving a machine learning problem involves the following steps:
- Obtain the training data.
- Define the model.
- Define the loss function.
- Loop over the training data, computing the loss against the target values.
- Compute the gradient of that loss and use an optimizer to adjust the variables to fit the data.
- Evaluate the results.
Build data
Supervised learning uses inputs (usually denoted x) and outputs (denoted y, commonly called labels). The goal is to learn from paired inputs and outputs so that the output can be predicted from a new input. In TensorFlow, almost every input is represented as a tensor, usually a vector; in supervised learning the output (that is, the predicted value) is also a tensor. Here we synthesize data by adding Gaussian (normally distributed) noise to points on a line, and then visualize it.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Cast to float32 so the inputs match the dtype of the tf.Variables defined below
x = (np.random.random([1000]) * 5).astype(np.float32)
noise = np.random.normal(0, 1, [1000]).astype(np.float32)  # Gaussian noise
y = 3 * x + 7 + noise  # noisy targets scattered around the line y = 3x + 7

plt.scatter(x, y, c="b")
plt.show()
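As a side note, the same data can be generated directly with TensorFlow ops, which keeps everything as float32 tensors from the start; this is an optional variant, not how this post proceeds:

# Optional alternative: build the dataset with TensorFlow's own RNG ops
x = tf.random.uniform([1000], maxval=5.0)  # inputs drawn from [0, 5)
noise = tf.random.normal([1000])           # zero-mean Gaussian noise
y = 3.0 * x + 7.0 + noise                  # noisy targets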

Customize the model
We subclass tf.Module and define two trainable variables; tf.Module automatically collects them in the model's trainable_variables attribute.
class selfmodel(tf.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initial guesses for the slope and intercept
        self.v1 = tf.Variable(1.0, trainable=True)
        self.v2 = tf.Variable(2.0, trainable=True)

    def __call__(self, x):
        y = self.v1 * x + self.v2  # the linear model v1 * x + v2
        return y
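As a quick check (not in the original post), calling a freshly created model shows the untrained prediction produced by the initial values v1 = 1.0 and v2 = 2.0:

model = selfmodel()
print(model(3.0))                 # 1.0 * 3.0 + 2.0 = 5.0
print(model.trainable_variables)  # the two variables collected by tf.Module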
Define the loss function
Here we use the mean squared error to compute the loss:
def loss(target_y, predicted_y):
    return tf.reduce_mean(tf.square(target_y - predicted_y))
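For example (a small sanity check, not in the original post), the loss between targets [3., 7.] and predictions [2., 5.] is ((3 - 2)^2 + (7 - 5)^2) / 2 = 2.5:

print(loss(tf.constant([3.0, 7.0]), tf.constant([2.0, 5.0])).numpy())  # 2.5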
Define the training loop
We train for a given number of epochs to fit the two variables v1 and v2, recording their values after each epoch so that we can visualize them at the end.
def train(model, x, y, epochs, optimizer):
    v1, v2 = [], []
    for j in range(epochs):
        # The forward pass must run inside the tape's context so that
        # the tape records the operations applied to the variables
        with tf.GradientTape() as gd:
            y_pred = model(x)
            loss_score = loss(y, y_pred)
        grad = gd.gradient(loss_score, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        # Record the current parameter values for later visualization
        v1.append(model.v1.numpy())
        v2.append(model.v2.numpy())
    return (model, v1, v2)
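As an optional variant (not from the original post), the inner step can be wrapped in tf.function so TensorFlow compiles it into a graph, which is usually faster for larger models:

@tf.function
def train_step(model, x, y, optimizer):
    with tf.GradientTape() as tape:
        loss_score = loss(y, model(x))
    grad = tape.gradient(loss_score, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))
    return loss_score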
Show the final result
If epochs is set too small, v1 and v2 will not converge to the correct values.
opt = tf.keras.optimizers.SGD()
model = selfmodel()
epochs = 1000
(model, v1, v2) = train(model, x, y, epochs, opt)

# Plot the evolution of v1 and v2 against the true values
plt.plot(range(epochs), v1, "r",
         range(epochs), v2, "b")
plt.plot([3] * epochs, "r--",
         [7] * epochs, "b--")
plt.legend(["W", "b", "True W", "True b"])
plt.show()
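As a final check (not in the original post), you can also print the last values reached, which should land close to the true parameters 3 and 7:

print(model.v1.numpy(), model.v2.numpy())  # expected: roughly 3 and 7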

Problems in the code
# The code below raises an error: grad comes back as (None, None),
# because y_pred = model(x) should be written inside the with block.
# The loss is differentiated with respect to model.trainable_variables,
# but the tape never recorded the forward pass, so no gradient is found.
# The correct version follows afterwards.
y_pred = model(x)  # forward pass outside the tape: not recorded
with tf.GradientTape() as t:
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)  # (None, None)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))  # raises an error
# Correct version
with tf.GradientTape() as t:
    y_pred = model(x)  # forward pass inside the tape
    l = loss(y, y_pred)
grad = t.gradient(l, model.trainable_variables)
print(model.trainable_variables)
optimizer = tf.keras.optimizers.SGD()
optimizer.apply_gradients(zip(grad, model.trainable_variables))
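If you are ever unsure what the tape actually recorded, tf.GradientTape exposes watched_variables(); a minimal debugging sketch (not from the original post):

with tf.GradientTape() as t:
    y_pred = model(x)
    l = loss(y, y_pred)
# If this tuple is empty, gradient() will return None for every variable
print(t.watched_variables())
grad = t.gradient(l, model.trainable_variables)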