Custom Training Functions in TensorFlow
2022-07-30 14:41:00 [u012804784]
This article records a template for writing custom training functions in the TensorFlow framework and briefly discusses the advantages and disadvantages of using one.
First, a note on sources: the template recorded here is based on the answer at https://stackoverflow.com/questions/59438904/applying-callbacks-in-a-custom-training-loop-in-tensorflow-2-0 and on Section 12.3.9 of Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow. Corrections are welcome if anything is wrong or missing.
Why and when do you need a custom training function?
Unless you really need the extra flexibility, you should prefer the fit() method over implementing your own loop, especially when working in a team.
If you are still wondering why you would ever need a custom training function, you probably don't need one yet. Usually it is only when building models with unusual structures that model.fit() turns out not to fully meet our needs. Even then, the first thing to try is reading the relevant TensorFlow source code to check for parameters or methods you didn't know about; only after that should you consider a custom training function. Make no mistake: a custom training function makes the code longer, harder to maintain, and harder to understand.
However, the flexibility of a custom training function is something fit() cannot match. For example, in a custom function you can implement a training loop that uses several different optimizers, or run the validation loop over multiple datasets.
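As a minimal sketch of the multi-optimizer case: the layer split, optimizer choices, and toy batch below are assumptions for illustration only, not part of the template that follows.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical two-part model; the split into lower/upper variables is
# only to illustrate driving one loop with two optimizers.
model = keras.Sequential([
    keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    keras.layers.Dense(1),
])
lower_vars = model.layers[0].trainable_variables
upper_vars = model.layers[1].trainable_variables
opt_lower = keras.optimizers.SGD(learning_rate=1e-3)
opt_upper = keras.optimizers.Nadam(learning_rate=1e-3)
loss_fn = keras.losses.mean_squared_error

x = tf.random.normal((32, 8))       # toy batch standing in for real data
y_true = tf.random.normal((32, 1))

# persistent=True because tape.gradient() is called twice on the same tape
with tf.GradientTape(persistent=True) as tape:
    loss = tf.reduce_mean(loss_fn(y_true, model(x, training=True)))
opt_lower.apply_gradients(zip(tape.gradient(loss, lower_vars), lower_vars))
opt_upper.apply_gradients(zip(tape.gradient(loss, upper_vars), upper_vars))
del tape  # free the persistent tape's resources
```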
A custom training function template
The template is designed so that, by reusing the code skeleton and filling in the key blanks, we can put a custom training function together quickly. It lets us focus on the structure of the training function itself rather than on peripheral details (such as handling a training set of unknown length), while still supporting some of the features fit() provides (such as Callback classes).
```python
import tensorflow as tf
from tensorflow import keras


def train(model: keras.Model, train_batchs, epochs=1, initial_epoch=0,
          callbacks=None, steps_per_epoch=None, val_batchs=None):
    # Wrap user callbacks in a CallbackList so a History callback is added for us
    callbacks = tf.keras.callbacks.CallbackList(
        callbacks, add_history=True, model=model)
    logs_dict = {}

    # init optimizer, loss function and metrics
    optimizer = keras.optimizers.Nadam(learning_rate=0.0005)
    loss_fn = keras.losses.mean_squared_error  # per-sample MSE, reduced to a scalar in each step
    train_loss_tracker = keras.metrics.Mean(name="train_loss")
    val_loss_tracker = keras.metrics.Mean(name="val_loss")
    # train_acc_metric = tf.keras.metrics.BinaryAccuracy(name="train_acc")
    # val_acc_metric = tf.keras.metrics.BinaryAccuracy(name="val_acc")

    def count():  # infinite step counter, for a training set of unknown length
        x = 0
        while True:
            yield x
            x += 1

    def print_status_bar(iteration, total, metrics=None):
        metrics = " - ".join(["{}:{:.4f}".format(m.name, m.result())
                              for m in (metrics or [])])
        end = "" if iteration < total else "\n"
        print("\r{}/{} - ".format(iteration, total) + metrics, end=end)

    def train_step(x, y, loss_tracker: keras.metrics.Metric):
        with tf.GradientTape() as tape:
            outputs = model(x, training=True)
            main_loss = tf.reduce_mean(loss_fn(y, outputs))
            loss = tf.add_n([main_loss] + model.losses)  # add regularization losses
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    def val_step(x, y, loss_tracker: keras.metrics.Metric):
        outputs = model(x, training=False)
        main_loss = tf.reduce_mean(loss_fn(y, outputs))
        loss = tf.add_n([main_loss] + model.losses)
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    # init train_batchs
    train_iter = iter(train_batchs)
    callbacks.on_train_begin(logs=logs_dict)
    for i_epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(i_epoch, logs=logs_dict)
        # init steps
        infinite_flag = False
        if steps_per_epoch is None:
            infinite_flag = True
            step_iter = count()
        else:
            step_iter = range(steps_per_epoch)
        # train loop
        for i_step in step_iter:
            callbacks.on_batch_begin(i_step, logs=logs_dict)
            callbacks.on_train_batch_begin(i_step, logs=logs_dict)
            try:
                X_batch, y_batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_batchs)
                if infinite_flag is True:
                    # dataset exhausted: this marks the end of the epoch
                    break
                else:
                    X_batch, y_batch = next(train_iter)
            train_logs_dict = train_step(x=X_batch, y=y_batch,
                                         loss_tracker=train_loss_tracker)
            logs_dict.update(train_logs_dict)
            print_status_bar(i_step, steps_per_epoch or i_step,
                             [train_loss_tracker])
            callbacks.on_train_batch_end(i_step, logs=logs_dict)
            callbacks.on_batch_end(i_step, logs=logs_dict)
        if steps_per_epoch is None:
            print()
            steps_per_epoch = i_step  # remember the length for later epochs
        if val_batchs is not None:
            # val loop
            for i_step, (X_batch, y_batch) in enumerate(val_batchs):
                callbacks.on_batch_begin(i_step, logs=logs_dict)
                callbacks.on_test_batch_begin(i_step, logs=logs_dict)
                val_logs_dict = val_step(x=X_batch, y=y_batch,
                                         loss_tracker=val_loss_tracker)
                logs_dict.update(val_logs_dict)
                callbacks.on_test_batch_end(i_step, logs=logs_dict)
                callbacks.on_batch_end(i_step, logs=logs_dict)
            print_status_bar(steps_per_epoch, steps_per_epoch,
                             [train_loss_tracker, val_loss_tracker])
        callbacks.on_epoch_end(i_epoch, logs=logs_dict)
        for metric in [train_loss_tracker, val_loss_tracker]:
            metric.reset_states()
    callbacks.on_train_end(logs=logs_dict)

    # Fetch the History object we normally get from keras fit()
    history_object = None
    for cb in callbacks:
        if isinstance(cb, tf.keras.callbacks.History):
            history_object = cb
    return history_object
```
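To show how the template is meant to be called, here is a hedged usage sketch reusing the imports above; the random regression data, the network, and the CSV file name are hypothetical stand-ins:

```python
import numpy as np

# Hypothetical regression data, only to exercise the template.
X = np.random.rand(1000, 8).astype("float32")
y = np.random.rand(1000, 1).astype("float32")
train_ds = tf.data.Dataset.from_tensor_slices((X[:800], y[:800])).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((X[800:], y[800:])).batch(32)

model = keras.Sequential([
    keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    keras.layers.Dense(1),
])

# CSVLogger only relies on the epoch-level hooks, which the template fires.
history = train(model, train_ds, epochs=5,
                callbacks=[keras.callbacks.CSVLogger("train_log.csv")],
                val_batchs=val_ds)
print(history.history)  # same structure as the History returned by model.fit()
```

Note that callbacks which ask the loop to stop early (such as EarlyStopping, which sets model.stop_training) would additionally require the template to check that flag between epochs.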