当前位置:网站首页>3 keras版本模型训练
3 keras版本模型训练
2022-06-26 15:30:00 【X1996_】
顺序模型
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
# 搭建模型
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))#第一层
model.add(layers.Dense(64, activation='relu'))#第二层
model.add(layers.Dense(10))#第三层
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
顺序模型 2
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
# 搭建模型
model = tf.keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),#第一层
layers.Dense(64, activation='relu'),#第二层
layers.Dense(10)#第三层
])
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
函数式
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
inputs = tf.keras.Input(shape=(32,))
# inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs) #第一层
x = layers.Dense(64, activation='relu')(x) #第二层
predictions = layers.Dense(10)(x) #第三层
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
子类化模型
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
class MyModel(tf.keras.Model):
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
# 定义自己需要的层
self.dense_1 = layers.Dense(32, activation='relu') #
self.dense_2 = layers.Dense(num_classes)
def call(self, inputs):
#定义前向传播
# 使用在 (in `__init__`)定义的层
x = self.dense_1(inputs)
x = self.dense_2(x)
return x
model = MyModel(num_classes=10)
# 指定损失函数优化器那些
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 回调函数
callbacks = [
# 早停
tf.keras.callbacks.EarlyStopping(
# 当‘val_loss’不再下降时候停止训练
monitor='val_loss',
# “不再下降”被定义为“减少不超过1e-2”
min_delta=1e-2,
# “不再改善”进一步定义为“至少2个epoch”
patience=2,
verbose=1),
# 保存权重
tf.keras.callbacks.ModelCheckpoint(
filepath='mymodel_{epoch}',
# 模型保存路径
#
# 下面的两个参数意味着当且仅当`val_loss`分数提高时,我们才会覆盖当前检查点。
save_best_only=True,
monitor='val_loss',
#加入这个仅仅保存模型权重
save_weights_only=True,
verbose=1),
# 动态调整学习率
tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss",
verbose=1,
mode='max',
factor=0.5,
patience=3)
]
# 训练
model.fit(data, labels,
epochs=30,
batch_size=64,
callbacks=callbacks,
validation_split=0.2
)
画图
tf.keras.utils.plot_model(model, 'multi_input_and_output_model.png', show_shapes=True,dpi=500)

模型训练:model.fit()
模型验证:model.evaluate()
模型预测: model.predict()
# Evaluate the model on the test data using `evaluate`
print('\n# Evaluate on test data')
results = model.evaluate(x_test, y_test, batch_size=128)
print('test loss, test acc:', results)
# Generate predictions (probabilities -- the output of the last layer)
# on new data using `predict`
print('\n# Generate predictions for 3 samples')
predictions = model.predict(x_test[:3])
print('predictions shape:', predictions.shape)
边栏推荐
- Nanopi duo2 connection WiFi
- NFT交易原理分析(1)
- SVG大写字母A动画js特效
- 8 自定义评估函数
- 北京房山区专精特新小巨人企业认定条件,补贴50万
- Seurat to h5ad summary
- [tcapulusdb knowledge base] tcapulusdb OMS business personnel permission introduction
- JS handwritten bind, apply, call
- el-dialog拖拽,边界问题完全修正,网上版本的bug修复
- El dialog drag and drop, the boundary problem is completely corrected, and the bug of the online version is fixed
猜你喜欢

全面解析Discord安全问题

js创意图标导航菜单切换背景色

NFT 项目的开发、部署、上线的流程(2)

Audio and video learning (III) -- SIP protocol

Have you ever had a Kindle with a keyboard?

「幹貨」NFT 上中下遊產業鏈全景分析

Utilisation d'abortcontroller

在重新格式化时不要删除自定义换行符(Don‘t remove custom line breaks on reformat)

A blog to thoroughly master the theory and practice of particle filter (PF) (matlab version)

评价——TOPSIS
随机推荐
Is it safe to open an account for mobile stock registration? Is there any risk?
NFT transaction principle analysis (1)
零知识 QAP 问题的转化
【微信小程序】事件绑定,你搞懂了吗?
买股票通过券商经理的开户二维码开户资金是否安全?想开户炒股
在重新格式化时不要删除自定义换行符(Don‘t remove custom line breaks on reformat)
【leetcode】112. Path sum - 113 Path sum II
AbortController的使用
【思考】在买NFT的时候你在买什么?
[tcapulusdb knowledge base] Introduction to tcapulusdb system management
CNN优化trick
IntelliJ idea -- Method for formatting SQL files
NFT 项目的开发、部署、上线的流程(2)
反射修改final
音视频学习(二)——帧率、码流和分辨率
9 Tensorboard的使用
还存在过有键盘的kindle?
Mr. Du said that the website was updated with illustrations
el-dialog拖拽,边界问题完全修正,网上版本的bug修复
JS之手写 bind、apply、call