当前位置:网站首页>TensorFlow自定义训练函数
TensorFlow自定义训练函数
2022-07-30 14:41:00 【u012804784】
优质资源分享
| 学习路线指引(点击解锁) | 知识定位 | 人群定位 |
|---|---|---|
| 🧡 Python实战微信订餐小程序 🧡 | 进阶级 | 本课程是python flask+微信小程序的完美结合,从项目搭建到腾讯云部署上线,打造一个全栈订餐系统。 |
| Python量化交易实战 | 入门级 | 手把手带你打造一个易扩展、更安全、效率更高的量化交易系统 |
本文记录了在TensorFlow框架中自定义训练函数的模板并简述了使用自定义训练函数的优势与劣势。
首先需要说明的是,本文中所记录的训练函数模板参考自https://stackoverflow.com/questions/59438904/applying-callbacks-in-a-custom-training-loop-in-tensorflow-2-0中的回答以及Hands-On Machine Learning with Scikit-Learn, Keras, and Tensorflow一书中第12.3.9节的内容,如有错漏,欢迎指正。
为什么和什么时候需要自定义训练函数
除非你真的需要额外的灵活性,否则应该更倾向使用fit()方法,为不是实现你自己的循环,尤其是在团队合作中。
如果你还在困惑为什么需要自定义训练函数的时候,那说明你还不需要自定义训练函数。通常只有在搭建一些结构奇特的模型时,我们才会发现model.fit()无法完全满足需求,接下来首先该尝试的方法是去看TensorFlow相关部分的源码,看看有没有认识之外的参数或方法,其次才是考虑使用自定义训练函数。毫无疑问,自定义训练函数会让代码更长、更难维护、更难懂。
但是,自定义训练函数的灵活性是fit()方法无法比拟的。比如,在自定义函数中你可以实现使用多个不同优化器的训练循环或是在多个数据集上计算验证循环。
自定义训练函数模板
模板设计的目的在于让我们通过对代码块的复用以及对关键部位的填空快速完成自定义训练函数,以使我们更专注于训练函数结构本身而非一些细枝末节的部分(如未知长度训练集的处理)并实现一些fit()方法支持的功能(如Callback类的使用)。
def train(model:keras.Model,train\_batchs,epochs=1,initial\_epoch=0,callbacks=None,steps\_per\_epoch=None,val\_batchs=None):
callbacks = tf.keras.callbacks.CallbackList(
callbacks, add_history=True, model=model)
logs_dict = {}
# init optimizer, loss function and metrics
optimizer = keras.optimizers.Nadam(learning_rate=0.0005)
loss_fn = keras.losses.MeanSquaredError
train_loss_tracker = keras.metrics.Mean(name="train\_loss")
val_loss_tracker = keras.metrics.Mean(name="val\_loss")
# train\_acc\_metric = tf.keras.metrics.BinaryAccuracy(name="train\_acc")
# val\_acc\_metric = tf.keras.metrics.BinaryAccuracy(name="val\_acc")
def count(): # infinite iter
x = 0
while True:yield x;x+=1
def print\_status\_bar(iteration, total, metrics=None):
metrics = " - ".join(["{}:{:.4f}".format(m.name,m.result()) for m in (metrics or [])])
end = "" if iteration < total or float('inf') else "\n"
print("\r{}/{} - ".format(iteration,total) + metrics, end=end)
def train\_step(x,y,loss\_tracker:keras.metrics.Metric):
with tf.GradientTape() as tape:
outputs = model(x)
main_loss = tf.reduce_mean(loss_fn(y,outputs))
loss = tf.add_n([main_loss] + model.losses)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients,model.trainable_variables))
loss_tracker.update_state(loss)
return {loss_tracker.name:loss_tracker.result()}
def val\_step(x,y,loss\_tracker:keras.metrics.Metric):
outputs = model.predict(x,verbose=0)
main_loss = tf.reduce_mean(loss_fn(y,outputs))
loss = tf.add_n([main_loss] + model.losses)
loss_tracker.update_state(loss)
return {loss_tracker.name:loss_tracker.result()}
# init train\_batchs
train_iter = iter(train_batchs)
callbacks.on_train_begin(logs=logs_dict)
for i_epoch in range(initial_epoch, epochs):
# init steps
infinite_flag = False
if steps_per_epoch is None:
infinite_flag = True
step_iter = count()
else:
step_iter = range(steps_per_epoch)
# train\_loop
for i_step in step_iter:
callbacks.on_batch_begin(i_step, logs=logs_dict)
callbacks.on_train_batch_begin(i_step, logs=logs_dict)
try:
X_batch, y_batch = train_iter.next()
except StopIteration:
train_iter = iter(train_batchs)
if infinite_flag is True:
break
else:
X_batch, y_batch = train_iter.next()
train_logs_dict = train_step(x=X_batch,y=y_batch,loss_tracker=train_loss_tracker)
logs_dict.update(train_logs_dict)
print_status_bar(i_step, steps_per_epoch or i_step, [train_loss_tracker])
callbacks.on_train_batch_end(i_step, logs=logs_dict)
callbacks.on_batch_end(i_step, logs=logs_dict)
if steps_per_epoch is None:
print()
steps_per_epoch = i_step
if val_batchs is not None:
# val\_loop
for i_step,(X_batch,y_batch) in enumerate(iter(val_batchs)):
callbacks.on_batch_begin(i_step, logs=logs_dict)
callbacks.on_test_batch_begin(i_step, logs=logs_dict)
val_logs_dict = val_step(x=X_batch,y=y_batch,loss_tracker=val_loss_tracker)
logs_dict.update(val_logs_dict)
callbacks.on_test_batch_end(i_step, logs=logs_dict)
callbacks.on_batch_end(i_step, logs=logs_dict)
logs_dict.update(val_logs_dict)
print_status_bar(steps_per_epoch, steps_per_epoch, [train_loss_tracker, val_loss_tracker])
callbacks.on_epoch_end(i_epoch, logs=logs_dict)
for metric in [train_loss_tracker, val_loss_tracker]:
metric.reset_states()
callbacks.on_train_end(logs=logs_dict)
# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
if isinstance(cb, tf.keras.callbacks.History):
history_object = cb
return history_object
折叠
边栏推荐
- 阿里CTO程立:阿里巴巴的开源历程、理念和实践
- What is Ts?
- GeoServer + openlayers
- 新时代背景下智慧城市的建设与5G技术有何关联
- [机缘参悟-53]:《素书》-3-修身养志[求人之志章第三]
- Flink实时仓库-DWS层(状态编程,windowall的使用,数据保存到clickhouse)模板代码
- 一文读懂网络效应对Web3的重要意义
- Flink real-time data warehouse completed
- 我们公司用了 6 年的网关服务,动态路由、鉴权、限流等都有,稳的一批!
- What is the relationship between the construction of smart cities and 5G technology in the new era
猜你喜欢

MaxWell scraped data

Office Automation | Office Software and Edraw MindMaster Shortcuts

学习 MySQL 需要知道的 28 个小技巧

惊艳!京东T8纯手码的Redis核心原理手册,基础与源码齐下

Redis cache penetration, breakdown, avalanche and consistency issues

Smart Contract Security - Private Data Access

5G-based Warehousing Informatization Solution 2022

Mysql数据库查询好慢,除了索引,还能因为什么?

Alluxio为Presto赋能跨云的自助服务能力

JUC常见的线程池源码学习 02 ( ThreadPoolExecutor 线程池 )
随机推荐
视频切换播放的例子(视频切换范例)代码
What is the relationship between the construction of smart cities and 5G technology in the new era
(科普文)什么是碎片化NFT(Fractional NFT)
一文读懂网络效应对Web3的重要意义
Understand the Chisel language. 29. Chisel advanced communication state machine (1) - communication state machine: take the flash as an example
去腾讯面试,直接让人出门左拐 :幂等性都不知道!
Container sorting case
v-model组件化编程应用
The website adds a live 2d kanban girl that can dress up and interact
MongoDB启动报错 Process: 29784 ExecStart=/usr/bin/mongod $OPTIONS (code=exited, status=14)
[机缘参悟-53]:《素书》-3-修身养志[求人之志章第三]
Flink本地UI运行
B+树索引页大小是如何确定的?
Smart Contract Security - Private Data Access
Flink实时仓库-DWS层(状态编程,windowall的使用,数据保存到clickhouse)模板代码
Huawei issues another summoning order for "Genius Boys"!He, who had given up an annual salary of 3.6 million, also made his debut
闭包和装饰器
DocuWare 文件管理与工作流程自动化案例研究——DocuWare 工作流程功能使在家工作的员工能够保持沟通和高效工作,支持混合环境
MongoDB starts an error Process: 29784 ExecStart=/usr/bin/mongod $OPTIONS (code=exited, status=14)
Installing and Uninstalling MySQL on Mac