TensorFlow custom training function
2022-07-30 15:32:00 【u012804784】
This article records a template for a custom training function in the TensorFlow framework, and briefly discusses the advantages and disadvantages of using one.
First, it should be noted that the training-function template recorded here is based on the answer at https://stackoverflow.com/questions/59438904/applying-callbacks-in-a-custom-training-loop-in-tensorflow-2-0 and on Section 12.3.9 of Hands-On Machine Learning with Scikit-Learn, Keras, and Tensorflow; corrections for any errors or omissions are welcome.
Why and when you need a custom training function
Unless you really need the extra flexibility, you should prefer the fit() method over implementing your own loop, especially when working in a team.
If you are still wondering why anyone would need a custom training function, then you probably don't need one yet. It is usually only when building models with unusual structures that model.fit() turns out not to fully meet our needs. Even then, the first thing to try is reading the source code of the relevant parts of TensorFlow to check for parameters or methods you didn't know about; only after that should you consider a custom training function. Without question, a custom training function makes the code longer, harder to maintain, and harder to understand.
However, a custom training function offers flexibility that fit() cannot match. For example, in a custom function you can implement a training loop that uses several different optimizers, or run the validation loop over multiple datasets, as sketched below.
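As a concrete illustration of the first point, a training step driven by two optimizers could look like the following minimal sketch. This is my own example, not part of the article's template: the split of model.trainable_variables at a hypothetical split_index, and the optimizer choices, are assumptions.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical setup: one optimizer for the "body" variables,
# another for the "head" variables
body_optimizer = keras.optimizers.SGD(learning_rate=1e-3)
head_optimizer = keras.optimizers.Adam(learning_rate=1e-4)

def two_optimizer_step(model, x, y, loss_fn, split_index):
    # compute the loss once, then apply the gradients with two optimizers:
    # variables before split_index go to body_optimizer, the rest to
    # head_optimizer (split_index is an assumed convention for this sketch)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(loss_fn(y, model(x, training=True)))
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    body_optimizer.apply_gradients(
        zip(gradients[:split_index], variables[:split_index]))
    head_optimizer.apply_gradients(
        zip(gradients[split_index:], variables[split_index:]))
    return loss
```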
A custom training function template
The template is designed so that we can quickly produce a custom training function by reusing its code blocks and filling in the key blanks. This lets us focus on the structure of the training function itself rather than on peripheral details (such as handling a training set of unknown length), while keeping some of the features that fit() provides (such as support for the Callback classes).
```python
import tensorflow as tf
from tensorflow import keras


def train(model: keras.Model, train_batches, epochs=1, initial_epoch=0,
          callbacks=None, steps_per_epoch=None, val_batches=None):
    # Wrap the callbacks in a CallbackList so they can be notified
    # exactly as fit() would notify them
    callbacks = tf.keras.callbacks.CallbackList(
        callbacks, add_history=True, model=model)
    logs_dict = {}

    # init optimizer, loss function and metrics
    optimizer = keras.optimizers.Nadam(learning_rate=0.0005)
    loss_fn = keras.losses.MeanSquaredError()
    train_loss_tracker = keras.metrics.Mean(name="train_loss")
    val_loss_tracker = keras.metrics.Mean(name="val_loss")
    # train_acc_metric = tf.keras.metrics.BinaryAccuracy(name="train_acc")
    # val_acc_metric = tf.keras.metrics.BinaryAccuracy(name="val_acc")

    def count():  # infinite step counter, for datasets of unknown length
        x = 0
        while True:
            yield x
            x += 1

    def print_status_bar(iteration, total, metrics=None):
        metrics = " - ".join(["{}:{:.4f}".format(m.name, float(m.result()))
                              for m in (metrics or [])])
        # total is None while the dataset length is still unknown
        end = "" if total is None or iteration < total else "\n"
        print("\r{}/{} - ".format(iteration, total if total is not None else "?")
              + metrics, end=end)

    def train_step(x, y, loss_tracker: keras.metrics.Metric):
        with tf.GradientTape() as tape:
            outputs = model(x, training=True)
            main_loss = tf.reduce_mean(loss_fn(y, outputs))
            # include any regularization losses registered on the model
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    def val_step(x, y, loss_tracker: keras.metrics.Metric):
        outputs = model(x, training=False)
        main_loss = tf.reduce_mean(loss_fn(y, outputs))
        loss = tf.add_n([main_loss] + model.losses)
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    # init the training iterator
    train_iter = iter(train_batches)
    callbacks.on_train_begin(logs=logs_dict)
    for i_epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(i_epoch, logs=logs_dict)
        # init steps: iterate indefinitely when the epoch length is unknown
        infinite_flag = False
        if steps_per_epoch is None:
            infinite_flag = True
            step_iter = count()
        else:
            step_iter = range(steps_per_epoch)
        # train loop
        for i_step in step_iter:
            callbacks.on_batch_begin(i_step, logs=logs_dict)
            callbacks.on_train_batch_begin(i_step, logs=logs_dict)
            try:
                X_batch, y_batch = next(train_iter)
            except StopIteration:
                # dataset exhausted: rewind it; if its length was unknown,
                # this marks the end of the epoch
                train_iter = iter(train_batches)
                if infinite_flag is True:
                    break
                else:
                    X_batch, y_batch = next(train_iter)
            train_logs_dict = train_step(x=X_batch, y=y_batch,
                                         loss_tracker=train_loss_tracker)
            logs_dict.update(train_logs_dict)
            print_status_bar(i_step, steps_per_epoch, [train_loss_tracker])
            callbacks.on_train_batch_end(i_step, logs=logs_dict)
            callbacks.on_batch_end(i_step, logs=logs_dict)
        if steps_per_epoch is None:
            print()
            steps_per_epoch = i_step  # reuse the measured epoch length from now on
        if val_batches is not None:
            # val loop
            for i_step, (X_batch, y_batch) in enumerate(val_batches):
                callbacks.on_batch_begin(i_step, logs=logs_dict)
                callbacks.on_test_batch_begin(i_step, logs=logs_dict)
                val_logs_dict = val_step(x=X_batch, y=y_batch,
                                         loss_tracker=val_loss_tracker)
                logs_dict.update(val_logs_dict)
                callbacks.on_test_batch_end(i_step, logs=logs_dict)
                callbacks.on_batch_end(i_step, logs=logs_dict)
        print_status_bar(steps_per_epoch, steps_per_epoch,
                         [train_loss_tracker, val_loss_tracker])
        callbacks.on_epoch_end(i_epoch, logs=logs_dict)
        # reset metric state between epochs
        for metric in [train_loss_tracker, val_loss_tracker]:
            metric.reset_states()
    callbacks.on_train_end(logs=logs_dict)
    # Fetch the History object we normally get from fit()
    history_object = None
    for cb in callbacks:
        if isinstance(cb, tf.keras.callbacks.History):
            history_object = cb
    return history_object
```
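As a usage sketch (my own addition, not from the original article): assuming a toy regression model and tf.data pipelines that yield (x, y) batches, the template above can be driven like this. CSVLogger is used because it only relies on the on_train_begin/on_epoch_end/on_train_end hooks that the template fires.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical toy regression data: predict the sum of 8 features
X = np.random.rand(1000, 8).astype("float32")
y = X.sum(axis=1, keepdims=True)

model = keras.Sequential([
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1),
])

train_batches = tf.data.Dataset.from_tensor_slices((X[:800], y[:800])).batch(32)
val_batches = tf.data.Dataset.from_tensor_slices((X[800:], y[800:])).batch(32)

history = train(model, train_batches, epochs=3,
                callbacks=[keras.callbacks.CSVLogger("train_log.csv")],
                val_batches=val_batches)
# filled in by the History callback that CallbackList(add_history=True) adds
print(history.history)
```

Note that callbacks which request early termination by setting model.stop_training (such as EarlyStopping) would need an extra check of that flag in the epoch and batch loops, which the template above does not include.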