3 Keras model training
2022-06-26 15:30:00 【X1996_】
Sequential model
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features and 10 random target values each
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model layer by layer
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))  # layer 1
model.add(layers.Dense(64, activation='relu'))  # layer 2
model.add(layers.Dense(10))                     # layer 3

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means an improvement smaller than 1e-2
        min_delta=1e-2,
        # ...sustained for at least 2 epochs
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two parameters below mean the current checkpoint is only
        # overwritten when `val_loss` improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights rather than the whole model
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when val_loss plateaus
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         # val_loss should decrease, so monitor it in 'min' mode
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
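
The ModelCheckpoint callback above writes a weights checkpoint each time val_loss improves. A minimal sketch of restoring such a checkpoint, assuming the files were written to the working directory; the epoch number 5 in the filename is hypothetical and depends on which epoch actually scored best:

# Rebuild an identically shaped model, then load weights written by ModelCheckpoint.
# 'mymodel_5' is a hypothetical filename; the real epoch number depends on the run.
restored = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])
restored.load_weights('mymodel_5')
restored.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
print(restored.evaluate(data, labels, batch_size=64))
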
Sequential model 2
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features and 10 random target values each
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model by passing a list of layers to the constructor
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),  # layer 1
    layers.Dense(64, activation='relu'),                     # layer 2
    layers.Dense(10)                                         # layer 3
])

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks (same as above)
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means an improvement smaller than 1e-2
        min_delta=1e-2,
        # ...sustained for at least 2 epochs
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two parameters below mean the current checkpoint is only
        # overwritten when `val_loss` improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights rather than the whole model
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when val_loss plateaus
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         # val_loss should decrease, so monitor it in 'min' mode
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
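
Because this variant passes input_shape to the first layer, the model is built as soon as it is constructed, so its architecture and parameter counts can be inspected before any training; a quick check:

# The model is already built (input_shape was given), so summary() works before fit().
model.summary()
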
Functional API
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features and 10 random target values each
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model as a graph of layers from inputs to outputs
inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)  # layer 1
x = layers.Dense(64, activation='relu')(x)       # layer 2
predictions = layers.Dense(10)(x)                # layer 3
model = tf.keras.Model(inputs=inputs, outputs=predictions)

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks (same as above)
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means an improvement smaller than 1e-2
        min_delta=1e-2,
        # ...sustained for at least 2 epochs
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two parameters below mean the current checkpoint is only
        # overwritten when `val_loss` improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights rather than the whole model
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when val_loss plateaus
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         # val_loss should decrease, so monitor it in 'min' mode
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
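
One practical point about the functional version: because its architecture is a static graph of layers, the whole model (structure, weights and optimizer state) can be saved and reloaded in one step. A minimal sketch, assuming the default SavedModel format of TensorFlow 2; the directory name is hypothetical:

# Save the whole functional model and reload it without re-running the model-building code.
model.save('my_functional_model')
reloaded = tf.keras.models.load_model('my_functional_model')
print(reloaded.predict(data[:3]).shape)  # (3, 10)
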
Subclassed model
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features and 10 random target values each
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # define the layers the model needs
        self.dense_1 = layers.Dense(32, activation='relu')
        self.dense_2 = layers.Dense(num_classes)

    def call(self, inputs):
        # define the forward pass using the layers created in `__init__`
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return x

model = MyModel(num_classes=10)

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks (same as above)
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means an improvement smaller than 1e-2
        min_delta=1e-2,
        # ...sustained for at least 2 epochs
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two parameters below mean the current checkpoint is only
        # overwritten when `val_loss` improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights rather than the whole model
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when val_loss plateaus
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         # val_loss should decrease, so monitor it in 'min' mode
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
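
Unlike the previous three versions, a subclassed model has no static architecture until it has processed data; after fit() it is built and can be inspected, while before training you would first call build() with an input shape. A short check:

# The model was built during fit(); before training you would call
# model.build(input_shape=(None, 32)) first.
model.summary()
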
Plotting the model
tf.keras.utils.plot_model(model, 'multi_input_and_output_model.png', show_shapes=True,dpi=500)
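Note that, as far as I know, plot_model renders the image through pydot, so it needs the pydot Python package and a system Graphviz installation; without them the call typically fails with an import error. model.summary() is a dependency-free, text-only alternative.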

Model training: model.fit()
Model evaluation: model.evaluate()
Model prediction: model.predict()
# x_test / y_test stand for a held-out test set with the same feature and
# label shapes as `data` / `labels` above.

# Evaluate the model on the test data using `evaluate`
print('\n# Evaluate on test data')
results = model.evaluate(x_test, y_test, batch_size=128)
print('test loss, test acc:', results)

# Generate predictions (logits -- the output of the last layer)
# on new data using `predict`
print('\n# Generate predictions for 3 samples')
predictions = model.predict(x_test[:3])
print('predictions shape:', predictions.shape)
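
All four models are compiled with from_logits=True and end in a plain Dense layer, so predict() returns raw logits rather than probabilities. A small sketch of turning them into probabilities and class indices:

# Convert raw logits into probabilities and predicted class indices.
probs = tf.nn.softmax(predictions, axis=-1).numpy()
pred_classes = np.argmax(probs, axis=-1)
print('predicted classes:', pred_classes)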