Custom Training Functions in TensorFlow
2022-07-30 14:41:00 【u012804784】
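This post records a template for a custom training function in TensorFlow and briefly discusses the advantages and disadvantages of writing one.
Note first that the training-function template below is adapted from the answer at https://stackoverflow.com/questions/59438904/applying-callbacks-in-a-custom-training-loop-in-tensorflow-2-0 and from Section 12.3.9 of the book Hands-On Machine Learning with Scikit-Learn, Keras, and Tensorflow. If you find any errors or omissions, corrections are welcome.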
Why and when you need a custom training function
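Unless you really need the extra flexibility, you should prefer the fit() method over implementing your own training loop, especially when working in a team.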
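For comparison, the setup that the template below rebuilds by hand (Nadam with a learning rate of 0.0005, an MSE loss, callbacks, and validation data) takes only a few lines with `compile()` and `fit()`. This is a minimal sketch: the toy data, the layer sizes, and the `EarlyStopping` callback are assumptions chosen purely for illustration.

```python
import numpy as np
from tensorflow import keras

# Toy regression data (made up for illustration): y is a linear function of X
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4)).astype("float32")
y = X @ np.array([[1.0], [-2.0], [0.5], [3.0]], dtype="float32")
X_train, y_train = X[:200], y[:200]
X_val, y_val = X[200:], y[200:]

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1),
])
# Same optimizer and loss as the template's train() function
model.compile(optimizer=keras.optimizers.Nadam(learning_rate=0.0005), loss="mse")

# fit() handles batching, epochs, validation, callbacks and the History object
history = model.fit(
    X_train, y_train,
    batch_size=32,
    epochs=3,
    validation_data=(X_val, y_val),
    callbacks=[keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)],
    verbose=0,
)
```

Everything the custom template has to implement explicitly (the step loop, the loss trackers, the callback hooks, the returned `History`) comes for free here.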
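If you are still wondering why you would need a custom training function, you probably don't need one yet. Usually it is only when building models with unusual architectures that model.fit() turns out not to cover everything you need. The first thing to try then is to read the relevant parts of the TensorFlow source code to see whether there are parameters or methods you were not aware of; only after that should you consider writing a custom training function. There is no doubt that a custom training function makes the code longer, harder to maintain, and harder to understand.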
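However, fit() cannot match the flexibility of a custom training function. For example, in a custom function you can implement a training loop that uses several different optimizers, or run the validation loop on several datasets.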
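As a rough illustration of both cases, the sketch below splits a model into two hypothetical parts (`encoder` and `head`, names made up here), computes the gradients once, and hands each slice of them to its own optimizer. The helper `validate_on` (also hypothetical) then reports a separate mean loss for each named validation dataset. The layer sizes, optimizers, and learning rates are arbitrary assumptions, not recommendations.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical model split into two parts, each trained by its own optimizer
encoder = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(8, activation="relu")])
head = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
encoder_opt = keras.optimizers.SGD(learning_rate=1e-2)
head_opt = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.MeanSquaredError()


def two_optimizer_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, head(encoder(x, training=True), training=True))
    enc_vars = encoder.trainable_variables
    head_vars = head.trainable_variables
    grads = tape.gradient(loss, enc_vars + head_vars)
    # Route each slice of the gradients to the optimizer that owns those variables
    encoder_opt.apply_gradients(zip(grads[:len(enc_vars)], enc_vars))
    head_opt.apply_gradients(zip(grads[len(enc_vars):], head_vars))
    return loss


def validate_on(datasets):
    """Return one mean loss per named dataset, e.g. {"val_a_loss": 0.93}."""
    results = {}
    for name, batches in datasets.items():
        tracker = keras.metrics.Mean()
        for x, y in batches:
            y_pred = head(encoder(x, training=False), training=False)
            tracker.update_state(loss_fn(y, y_pred))
        results[f"val_{name}_loss"] = float(tracker.result())
    return results
```

The per-dataset results could be merged into the logs dictionary passed to the callbacks, so that, for instance, an `EarlyStopping` callback could monitor one specific validation set.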
Custom training function template
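The template is designed so that, by reusing its code blocks and filling in the blanks at a few key places, we can quickly put together a custom training function. This lets us concentrate on the structure of the training function itself rather than on minor details (such as handling a training set of unknown length), while still supporting some features that fit() provides (such as Callback classes).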
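Supporting callbacks in a custom loop essentially means calling the right `CallbackList` hooks at the right time and checking `model.stop_training` after each epoch. The minimal sketch below isolates that mechanism by feeding made-up validation losses to an `EarlyStopping` callback; the loss values and `patience=1` are assumptions for illustration only.

```python
from tensorflow import keras

# A tiny model: EarlyStopping only needs it to toggle model.stop_training
model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=1)
callbacks = keras.callbacks.CallbackList([early_stop], add_history=True, model=model)

# Made-up validation losses standing in for a real validation loop
fake_val_losses = [1.0, 0.8, 0.9, 0.95, 0.7]

model.stop_training = False
callbacks.on_train_begin()
for epoch, val_loss in enumerate(fake_val_losses):
    callbacks.on_epoch_begin(epoch)
    # ... the training and validation batches would run here ...
    callbacks.on_epoch_end(epoch, logs={"val_loss": val_loss})
    if model.stop_training:  # set by EarlyStopping once patience runs out
        break
callbacks.on_train_end()

# add_history=True appended a History callback to the list
history = next(cb for cb in callbacks.callbacks
               if isinstance(cb, keras.callbacks.History))
```

With `patience=1`, the loss rising at epoch 2 triggers the stop, so `history.history["val_loss"]` ends up holding only the three completed epochs.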
```python
import tensorflow as tf
from tensorflow import keras


def train(model: keras.Model, train_batchs, epochs=1, initial_epoch=0,
          callbacks=None, steps_per_epoch=None, val_batchs=None):
    # Wrap the user callbacks; add_history=True appends a History callback,
    # the same kind of object that fit() returns
    callbacks = tf.keras.callbacks.CallbackList(
        callbacks, add_history=True, model=model)
    logs_dict = {}
    # init optimizer, loss function and metrics
    optimizer = keras.optimizers.Nadam(learning_rate=0.0005)
    loss_fn = keras.losses.mean_squared_error  # per-sample MSE, averaged below
    train_loss_tracker = keras.metrics.Mean(name="train_loss")
    val_loss_tracker = keras.metrics.Mean(name="val_loss")
    # train_acc_metric = tf.keras.metrics.BinaryAccuracy(name="train_acc")
    # val_acc_metric = tf.keras.metrics.BinaryAccuracy(name="val_acc")

    def count():  # infinite step counter for an epoch of unknown length
        x = 0
        while True:
            yield x
            x += 1

    def print_status_bar(iteration, total, metrics=None):
        metrics = " - ".join("{}: {:.4f}".format(m.name, float(m.result()))
                             for m in (metrics or []))
        # Stay on the same line until the epoch is complete or while its length is unknown
        end = "" if total is None or iteration < total else "\n"
        print("\r{}/{} - ".format(iteration, "?" if total is None else total)
              + metrics, end=end)

    def train_step(x, y, loss_tracker: keras.metrics.Metric):
        with tf.GradientTape() as tape:
            outputs = model(x, training=True)
            main_loss = tf.reduce_mean(loss_fn(y, outputs))
            # Add the regularization losses collected by the model's layers
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    def val_step(x, y, loss_tracker: keras.metrics.Metric):
        # Call the model directly: model.predict() is meant for whole datasets
        # and is slow when invoked once per batch
        outputs = model(x, training=False)
        main_loss = tf.reduce_mean(loss_fn(y, outputs))
        loss = tf.add_n([main_loss] + model.losses)
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    # init train_batchs
    train_iter = iter(train_batchs)
    model.stop_training = False  # callbacks such as EarlyStopping set this flag
    callbacks.on_train_begin(logs=logs_dict)
    for i_epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(i_epoch, logs=logs_dict)
        # init steps: count without a bound while the epoch length is unknown
        infinite_flag = steps_per_epoch is None
        step_iter = count() if infinite_flag else range(steps_per_epoch)
        # train_loop
        for i_step in step_iter:
            try:
                X_batch, y_batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_batchs)  # restart the training data
                if infinite_flag:
                    break  # one full pass over the data is one epoch
                X_batch, y_batch = next(train_iter)
            callbacks.on_train_batch_begin(i_step, logs=logs_dict)
            train_logs = train_step(x=X_batch, y=y_batch,
                                    loss_tracker=train_loss_tracker)
            logs_dict.update({k: float(v) for k, v in train_logs.items()})
            print_status_bar(i_step, steps_per_epoch, [train_loss_tracker])
            callbacks.on_train_batch_end(i_step, logs=logs_dict)
        if infinite_flag:
            steps_per_epoch = i_step  # number of batches seen in the first pass
        status_metrics = [train_loss_tracker]
        if val_batchs is not None:
            # val_loop
            callbacks.on_test_begin(logs=logs_dict)
            for i_step, (X_batch, y_batch) in enumerate(val_batchs):
                callbacks.on_test_batch_begin(i_step, logs=logs_dict)
                val_logs = val_step(x=X_batch, y=y_batch,
                                    loss_tracker=val_loss_tracker)
                logs_dict.update({k: float(v) for k, v in val_logs.items()})
                callbacks.on_test_batch_end(i_step, logs=logs_dict)
            callbacks.on_test_end(logs=logs_dict)
            status_metrics.append(val_loss_tracker)
        print_status_bar(steps_per_epoch, steps_per_epoch, status_metrics)
        callbacks.on_epoch_end(i_epoch, logs=logs_dict)
        for metric in [train_loss_tracker, val_loss_tracker]:
            metric.reset_state()
        if model.stop_training:
            break
    callbacks.on_train_end(logs=logs_dict)
    # Fetch the History object we normally get from keras fit()
    history_object = None
    for cb in callbacks.callbacks:
        if isinstance(cb, tf.keras.callbacks.History):
            history_object = cb
    return history_object
```
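The least obvious part of the template is how it handles a training set of unknown length: during the first epoch it counts steps with an unbounded counter, treats `StopIteration` as the end of the epoch, and then reuses the counted number of steps as `steps_per_epoch`. The framework-free sketch below replays that logic on plain Python iterables so it can be checked without TensorFlow. The helper `run_epochs` is hypothetical, written only for this illustration, and it assumes `dataset` is re-iterable (a list or a `tf.data.Dataset`, for instance); a one-shot generator would already be exhausted after the first pass.

```python
import itertools
from typing import Iterable, List, Optional, Tuple


def run_epochs(dataset: Iterable, epochs: int,
               steps_per_epoch: Optional[int] = None) -> Tuple[List[list], Optional[int]]:
    """Replay the template's step logic, collecting the batches each epoch sees."""
    it = iter(dataset)
    seen_per_epoch = []
    for _ in range(epochs):
        unknown_length = steps_per_epoch is None
        steps = itertools.count() if unknown_length else range(steps_per_epoch)
        seen = []
        for step in steps:
            try:
                batch = next(it)
            except StopIteration:
                it = iter(dataset)   # restart the data source
                if unknown_length:
                    break            # one full pass is one epoch
                batch = next(it)     # known length: wrap around and continue
            seen.append(batch)
        if unknown_length:
            steps_per_epoch = step   # batches counted during the first pass
        seen_per_epoch.append(seen)
    return seen_per_epoch, steps_per_epoch
```

For example, `run_epochs([1, 2, 3], epochs=2)` returns `([[1, 2, 3], [1, 2, 3]], 3)`, while `run_epochs([1, 2, 3], epochs=2, steps_per_epoch=2)` wraps around the data and returns `([[1, 2], [3, 1]], 2)`.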