当前位置:网站首页>Tensorflow模型训练和评估的内置方法
TensorFlow模型训练和评估的内置方法
2022-07-27 08:45:00 【qq_27390023】
TensorFlow 内置了许多可直接使用的优化器、损失函数和评估指标。一般来说,你不必从头开始创建自己的损失函数、指标或优化器,因为你需要的功能很可能已经是 Keras API 的一部分。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten the images into 784-feature float32 rows and divide the pixel
# values by 255 (these are NumPy arrays). Labels are cast to float32 too.
x_train, x_test = (
    x_train.reshape(60000, 784).astype("float32") / 255,
    x_test.reshape(10000, 784).astype("float32") / 255,
)
y_train, y_test = y_train.astype("float32"), y_test.astype("float32")

# Hold out the last 10,000 training samples as a validation set; the rest
# remain the training set.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]
# Model-building helpers shared by the examples below.
def get_uncompiled_model():
    """Return a fresh, uncompiled 784 -> 64 -> 64 -> 10 Keras classifier.

    The input layer is named "digits". Two ReLU hidden layers ("dense_1",
    "dense_2") feed a 10-unit softmax output layer named "predictions".
    """
    digits = keras.Input(shape=(784,), name="digits")
    hidden = digits
    for layer_name in ("dense_1", "dense_2"):
        hidden = layers.Dense(64, activation="relu", name=layer_name)(hidden)
    predictions = layers.Dense(10, activation="softmax", name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=predictions)
def get_compiled_model():
    """Return the model from ``get_uncompiled_model()``, compiled for training.

    Compiles with the RMSprop optimizer, sparse categorical cross-entropy loss
    and sparse categorical accuracy metric, i.e. targets are class indices.
    """
    training_config = {
        "optimizer": "rmsprop",
        "loss": "sparse_categorical_crossentropy",
        "metrics": ["sparse_categorical_accuracy"],
    }
    compiled = get_uncompiled_model()
    compiled.compile(**training_config)
    return compiled
### 1. Custom loss function
def custom_mean_squared_error(y_true, y_pred):
    """Return the mean of the element-wise squared difference of the inputs."""
    return tf.math.reduce_mean(tf.math.squared_difference(y_true, y_pred))
model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)
# We need to one-hot encode the labels to use MSE.
# tf.one_hot expects integer indices, but the preprocessing above casts the
# labels to float32, so cast them back to int32 before encoding.
y_train_one_hot = tf.one_hot(tf.cast(y_train, tf.int32), depth=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)
### 2. Custom loss class
class CustomMSE(keras.losses.Loss):
    """Mean squared error plus a penalty that pulls predictions toward 0.5.

    loss = mean((y_true - y_pred)**2)
           + regularization_factor * mean((0.5 - y_pred)**2)
    """

    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        squared_error = tf.math.reduce_mean(tf.square(y_true - y_pred))
        penalty = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return squared_error + self.regularization_factor * penalty


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)
### 3. Custom metric class
class CategoricalTruePositives(keras.metrics.Metric):
    """Count (optionally weighted) samples whose argmax prediction equals the label.

    Expects sparse integer-valued labels and per-class scores in ``y_pred``
    (argmax is taken over axis 1).
    """

    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten labels, predictions and weights to shape (batch,). Comparing
        # a (batch,) label vector with a (batch, 1) prediction column would
        # broadcast to (batch, batch) and count far too many "true positives".
        y_true = tf.reshape(tf.cast(y_true, "int32"), [-1])
        y_pred = tf.cast(tf.argmax(y_pred, axis=1), "int32")
        values = tf.cast(tf.equal(y_true, y_pred), "float32")
        if sample_weight is not None:
            sample_weight = tf.reshape(tf.cast(sample_weight, "float32"), [-1])
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)


model = get_uncompiled_model()
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[CategoricalTruePositives()],
)
model.fit(x_train, y_train, batch_size=64, epochs=3)
### 4. Use a validation dataset during training
model = get_compiled_model()

# Training pipeline: shuffle with a 1024-element buffer, then batch by 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)
# Validation pipeline: batched by 64, not shuffled.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(64)

# `validation_steps=10` runs validation on only the first 10 batches of
# `val_dataset`.
history = model.fit(
    train_dataset,
    epochs=3,
    validation_data=val_dataset,
    validation_steps=10,
)
print(history.history)
### 5. Multi-input, multi-output model
# Two inputs: a 32x32x3 image and a variable-length sequence of 10-feature steps.
image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

# Both heads are linear (no activation): a scalar score and 5 unnormalized
# class logits.
score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        # "class_output" has no softmax, so the loss must treat it as logits;
        # with the default from_logits=False, raw (possibly negative) outputs
        # would be misread as probabilities.
        "class_output": keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)

# Generate dummy NumPy data (100 samples). class_targets are uniform random
# values rather than one-hot vectors; they only exercise the API.
import numpy as np

img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
class_targets = np.random.random_sample(size=(100, 5))

# Fit on lists (order matches the model's inputs/outputs).
model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

# Alternatively, fit on dicts keyed by input/output layer names.
model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)
### 6. Using callbacks
# Callbacks in Keras are objects that are called at different points during
# training (at the start of an epoch, at the end of a batch, at the end of an
# epoch, etc.). They can be used to implement behaviors such as:
# - doing validation at different points during training (beyond the
#   built-in per-epoch validation);
# - checkpointing the model at regular intervals or when it exceeds a
#   certain accuracy threshold;
# - changing the learning rate of the model when training seems to be
#   plateauing;
# - fine-tuning the top layers when training seems to be plateauing;
# - sending email or instant-message notifications when training ends or
#   when a certain performance threshold is exceeded, etc.
model = get_compiled_model()

callbacks = [
    keras.callbacks.EarlyStopping(
        # Stop training when `val_loss` is no longer improving
        monitor="val_loss",
        # "no longer improving" being defined as "no better than 1e-2 less"
        min_delta=1e-2,
        # "no longer improving" being further defined as "for at least 2 epochs"
        patience=2,
        verbose=1,
    )
]

# validation_split=0.2 holds out part of the training arrays for validation,
# which is what produces the `val_loss` monitored above.
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=64,
    callbacks=callbacks,
    validation_split=0.2,
)

# Reference: https://tensorflow.google.cn/guide/keras/train_and_evaluate?hl=en
边栏推荐
- ROS2安装时出现Connection failed [IP: 91.189.91.39 80]
- [penetration test tool sharing] [dnslog server building guidance]
- Flink1.15 source code reading Flink clients client execution process (reading is boring)
- Oppo self-developed large-scale knowledge map and its application in digital intelligence engineering
- Using ecological power, opengauss breaks through the performance bottleneck
- 如何在B站上快乐的学习?
- 4274. 后缀表达式
- 接口测试工具-Postman使用详解
- What are the differences or similarities between "demand fulfillment to settlement" and "purchase to payment"?
- Aruba学习笔记10-安全认证-Portal认证(web页面配置)
猜你喜欢

Day3 -- flag state holding, exception handling and request hook

Background image related applications - full, adaptive

Network IO summary

First experience of tryme in opengauss

MCDF顶层验证方案

Flink1.15源码阅读flink-clients客户端执行流程(阅读较枯燥)

UVM Introduction Experiment 1

How to merge multiple columns in an excel table into one column

Massive data Xiao Feng: jointly build and govern opengauss root community and share a thriving new ecosystem

微信安装包从0.5M暴涨到260M,为什么我们的程序越来越大?
随机推荐
4274. Suffix expression
Unity3D 2021软件安装包下载及安装教程
Matlab画图技巧与实例:堆叠图stackedplot
It's better to be full than delicious; It's better to be drunk than drunk
NIO this.selector.select()
4275. Dijkstra sequence
Day6 --- Sqlalchemy advanced
JS basic knowledge - daily learning summary ①
The shelf life you filled in has been less than 10 days until now, and it is not allowed to publish. If the actual shelf life is more than 10 days, please truthfully fill in the production date and pu
Implementation of registration function
User management - restrictions
4275. Dijkstra序列
Flask login implementation
数智革新
Flask request data acquisition and response
Realize SKU management in the background
NiO Summary - read and understand the whole NiO process
The wechat installation package has soared from 0.5m to 260m. Why are our programs getting bigger and bigger?
Background coupon management
The following license SolidWorks Standard cannot be obtained, and the use license file cannot be found. (-1,359,2)。