当前位置:网站首页>Tensorflow模型训练和评估的内置方法
Tensorflow模型训练和评估的内置方法
2022-07-27 08:45:00 【qq_27390023】
TensorFlow 提供了许多内置的优化器、损失函数和评估指标。一般来说,你不必从头开始创建自己的损失函数、评估指标或优化器,因为你需要的东西很可能已经是 Keras API 的一部分。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Load MNIST and preprocess the data (these are NumPy arrays).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten each image into a 784-dim vector and scale pixels to [0, 1].
# Using -1 instead of the hard-coded sample counts (60000 / 10000) keeps the
# reshape correct for any number of samples.
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Reserve the last 10,000 training samples for validation.
NUM_VAL_SAMPLES = 10000
x_val = x_train[-NUM_VAL_SAMPLES:]
y_val = y_train[-NUM_VAL_SAMPLES:]
x_train = x_train[:-NUM_VAL_SAMPLES]
y_train = y_train[:-NUM_VAL_SAMPLES]
# Model factory shared by the examples below.
def get_uncompiled_model():
    """Return a new, uncompiled MLP classifier for flattened MNIST digits.

    Architecture: 784-dim input "digits" -> Dense(64, relu) "dense_1"
    -> Dense(64, relu) "dense_2" -> Dense(10, softmax) "predictions".
    """
    stack = [
        layers.Dense(64, activation="relu", name="dense_1"),
        layers.Dense(64, activation="relu", name="dense_2"),
        layers.Dense(10, activation="softmax", name="predictions"),
    ]
    inputs = keras.Input(shape=(784,), name="digits")
    hidden = inputs
    for layer in stack:
        hidden = layer(hidden)
    return keras.Model(inputs=inputs, outputs=hidden)
def get_compiled_model():
    """Return a fresh model from get_uncompiled_model(), compiled for training.

    Uses the RMSprop optimizer and sparse categorical cross-entropy
    (class-index labels), and reports sparse categorical accuracy.
    """
    compile_kwargs = dict(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )
    model = get_uncompiled_model()
    model.compile(**compile_kwargs)
    return model
### 1. Custom loss function
def custom_mean_squared_error(y_true, y_pred):
    """Return the mean squared error between y_true and y_pred.

    The mean is taken over every element of the difference, giving a scalar.
    """
    squared_error = tf.square(y_true - y_pred)
    return tf.math.reduce_mean(squared_error)
model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)
# MSE compares the 10-way softmax output against a target vector, so the
# labels must be one-hot encoded. tf.one_hot requires integer indices, but
# y_train was cast to float32 earlier, so cast it explicitly rather than
# relying on implicit conversion (which fails for float tensors).
y_train_one_hot = tf.one_hot(tf.cast(y_train, tf.int32), depth=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)
### 2. Custom loss class
class CustomMSE(keras.losses.Loss):
    """Mean squared error plus a penalty pulling predictions toward 0.5.

    loss = mean((y_true - y_pred)^2)
           + regularization_factor * mean((0.5 - y_pred)^2)

    Args:
        regularization_factor: weight of the penalty term (default 0.1).
        name: name of this loss instance (default "custom_mse").
    """

    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        penalty = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        return self.regularization_factor * penalty + mse
# Train with the class-based loss (default regularization_factor=0.1),
# reusing the one-hot targets built in section 1.
model = get_uncompiled_model()
model.compile(loss=CustomMSE(), optimizer=keras.optimizers.Adam())
model.fit(x_train, y_train_one_hot, epochs=1, batch_size=64)
### 3. Custom metric class
class CategoricalTruePositives(keras.metrics.Metric):
    """Count samples whose arg-max prediction equals the true class index.

    The count accumulates across update_state() calls until reset_state().
    When sample_weight is given, each correct sample contributes its weight
    instead of 1.
    """

    def __init__(self, name="categorical_true_positives", **kwargs):
        super(CategoricalTruePositives, self).__init__(name=name, **kwargs)
        # Running (weighted) count of correctly classified samples.
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add this batch's correct predictions to the running count.

        Args:
            y_true: class-index labels, shape (batch,) or (batch, 1).
            y_pred: per-class scores, shape (batch, num_classes).
            sample_weight: optional per-sample weights, shape (batch,) or
                (batch, 1).
        """
        # Flatten labels, predicted classes and weights to 1-D so they line
        # up element-wise. Comparing a (batch,) tensor with a (batch, 1)
        # tensor would broadcast to (batch, batch) and over-count.
        y_pred = tf.argmax(y_pred, axis=1)
        y_true = tf.reshape(y_true, [-1])
        values = tf.equal(tf.cast(y_true, "int32"), tf.cast(y_pred, "int32"))
        values = tf.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = tf.reshape(tf.cast(sample_weight, "float32"), [-1])
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)
# Train on class-index labels while tracking the custom metric.
model = get_uncompiled_model()
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    metrics=[CategoricalTruePositives()],
)
model.fit(x_train, y_train, epochs=3, batch_size=64)
### 4. Using a validation dataset during training
model = get_compiled_model()

# Training pipeline: shuffle with a 1024-element buffer, then batch by 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)

# Validation pipeline: batch by 64, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(64)

# `validation_steps=10` restricts validation to the first 10 batches
# of val_dataset.
history = model.fit(
    train_dataset,
    epochs=3,
    validation_data=val_dataset,
    validation_steps=10,
)
print(history.history)
### 5. Multi-input, multi-output model
image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

# Two heads: a 1-unit regression output and a 5-unit classification output.
# Neither Dense layer has an activation, so class_output emits raw logits.
score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        # class_output produces logits (no softmax), so the loss must be told
        # so. The default from_logits=False would treat them as probabilities.
        "class_output": keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)

# Generate dummy NumPy data
import numpy as np

img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
# Categorical cross-entropy expects one-hot targets, so pick a random class
# per sample and one-hot encode it (rather than arbitrary random reals).
class_targets = np.eye(5)[np.random.randint(0, 5, size=100)]

# Fit on lists
model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

# Alternatively, fit on dicts
model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)
### 6.使用回调
# Keras中的回调是在训练过程中的不同点(在一个历时的开始,在一个批次的结束,在一个历时的结束,等等)被调用的对象。
# 它们可以被用来实现某些行为,比如:
# 在训练过程中的不同点进行验证(超出内置的每周期验证)。
# 每隔一段时间或当模型超过某个精度阈值时对其进行检查。
# 当训练似乎趋于平稳时,改变模型的学习率
# 当训练似乎趋于平稳时,对顶层进行微调
# 当训练结束或超过某个性能阈值时,发送电子邮件或即时信息通知等等。
model = get_compiled_model()
callbacks = [
keras.callbacks.EarlyStopping(
# Stop training when `val_loss` is no longer improving
monitor="val_loss",
# "no longer improving" being defined as "no better than 1e-2 less"
min_delta=1e-2,
# "no longer improving" being further defined as "for at least 2 epochs"
patience=2,
verbose=1,
)
]
model.fit(
x_train,
y_train,
epochs=20,
batch_size=64,
callbacks=callbacks,
validation_split=0.2,
)参考:
https://tensorflow.google.cn/guide/keras/train_and_evaluate?hl=en
边栏推荐
- 低成本、低门槛、易部署,4800+万户中小企业数字化转型新选择
- 4276. 擅长C
- It's better to be full than delicious; It's better to be drunk than drunk
- 2034: [Blue Bridge Cup 2022 preliminary] pruning shrubs
- [uni app advanced practice] take you hand-in-hand to learn the development of a purely practical complex project 1/100
- Use of flask
- How to uninstall -- Qianxin secure terminal management system
- MCDF顶层验证方案
- JS检测客户端软件是否安装
- 4278. 峰会
猜你喜欢

View 的滑动冲突

Block, there is a gap between the block elements in the row

Aruba学习笔记10-安全认证-Portal认证(web页面配置)

Chapter 2 foreground data display

First experience of tryme in opengauss

NIO Summary - read and understand the whole NIO process

Hundreds of people participated. What are these people talking about in the opengauss open source community?

Have a good laugh

Node installation and debugging

NIO总结文——一篇读懂NIO整个流程
随机推荐
General view, DRF view review
NIO this.selector.select()
JS detects whether the client software is installed
4277. 区块反转
E. Split into two sets
Iterators and generators
CentOS 7: update MariaDB
How to permanently set source
Sliding conflict of view
MCDF top level verification scheme
User management - restrictions
User management - restrictions
Flask one to many database creation, basic addition, deletion, modification and query
Using ecological power, opengauss breaks through the performance bottleneck
02 linear structure 3 reversing linked list
Use of flask
Day6 --- Sqlalchemy advanced
4278. 峰会
众昂矿业:新能源行业快速发展,氟化工产品势头强劲
海关总署:这类产品暂停进口