当前位置:网站首页>3 keras版本模型训练
3 keras版本模型训练
2022-06-26 15:30:00 【X1996_】
顺序模型
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
# 搭建模型
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))#第一层
model.add(layers.Dense(64, activation='relu'))#第二层
model.add(layers.Dense(10))#第三层
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
顺序模型 2
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
# 搭建模型
model = tf.keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),#第一层
layers.Dense(64, activation='relu'),#第二层
layers.Dense(10)#第三层
])
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
函数式
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
inputs = tf.keras.Input(shape=(32,))
# inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs) #第一层
x = layers.Dense(64, activation='relu')(x) #第二层
predictions = layers.Dense(10)(x) #第三层
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
子类化模型
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
class MyModel(tf.keras.Model):
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
# 定义自己需要的层
self.dense_1 = layers.Dense(32, activation='relu') #
self.dense_2 = layers.Dense(num_classes)
def call(self, inputs):
#定义前向传播
# 使用在 (in `__init__`)定义的层
x = self.dense_1(inputs)
x = self.dense_2(x)
return x
model = MyModel(num_classes=10)
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
画图
tf.keras.utils.plot_model(model, 'multi_input_and_output_model.png', show_shapes=True,dpi=500)

模型训练:model.fit()
模型验证:model.evaluate()
模型预测: model.predict()
# Evaluate the model on the test data using `evaluate`
print('\n# Evaluate on test data')
results = model.evaluate(x_test, y_test, batch_size=128)
print('test loss, test acc:', results)
# Generate predictions (probabilities -- the output of the last layer)
# on new data using `predict`
print('\n# Generate predictions for 3 samples')
predictions = model.predict(x_test[:3])
print('predictions shape:', predictions.shape)
边栏推荐
- 北京房山区专精特新小巨人企业认定条件,补贴50万
- Summer camp is coming!!! Chongchongchong
- Solana扩容机制分析(2):牺牲可用性换取高效率的极端尝试 | CatcherVC Research
- /etc/profile、/etc/bashrc、~/. Bashrc differences
- 一篇博客彻底掌握:粒子滤波 particle filter (PF) 的理论及实践(matlab版)
- 评价——模糊综合评价
- JS handwritten bind, apply, call
- PCIe Capabilities List
- 在重新格式化时不要删除自定义换行符(Don‘t remove custom line breaks on reformat)
- I want to know how to open an account through online stock? Is online account opening safe?
猜你喜欢

9 Tensorboard的使用

简单科普Ethereum的Transaction Input Data

SQLite loads CSV files and performs data analysis

10 tf.data

Panoramic analysis of upstream, middle and downstream industrial chain of "dry goods" NFT

AbortController的使用

在重新格式化时不要删除自定义换行符(Don‘t remove custom line breaks on reformat)

Svg animation around the earth JS special effects
![[tcapulusdb knowledge base] tcapulusdb doc acceptance - table creation approval introduction](/img/66/f3ab0514d691967ad88535ae1448c1.png)
[tcapulusdb knowledge base] tcapulusdb doc acceptance - table creation approval introduction

Utilisation d'abortcontroller
随机推荐
为什么图像分割任务中经常用到编码器和解码器结构?
CNN优化trick
一篇博客彻底掌握:粒子滤波 particle filter (PF) 的理论及实践(matlab版)
Utilisation d'abortcontroller
2Gcsv文件打不开怎么处理,使用byzer工具
Golang temporary object pool optimization
# 粒子滤波 PF——三维匀速运动CV目标跟踪(粒子滤波VS扩展卡尔曼滤波)
NFT 平台安全指南(2)
Development, deployment and online process of NFT project (1)
【leetcode】701. 二叉搜索树中的插入操作
NFT 平台安全指南(1)
High frequency interview 𞓜 Flink Shuangliu join
Ansible自动化的运用
[tcapulusdb knowledge base] tcapulusdb doc acceptance - Introduction to creating game area
Solana扩容机制分析(2):牺牲可用性换取高效率的极端尝试 | CatcherVC Research
el-dialog拖拽,边界问题完全修正,网上版本的bug修复
TweenMax+SVG切换颜色动画场景
Solana capacity expansion mechanism analysis (1): an extreme attempt to sacrifice availability for efficiency | catchervc research
粒子滤波 PF——在机动目标跟踪中的应用(粒子滤波VS扩展卡尔曼滤波)
Panoramic analysis of upstream, middle and downstream industrial chain of "dry goods" NFT