当前位置:网站首页>Built in method of tensorflow model training and evaluation
Built-in methods for TensorFlow model training and evaluation
2022-07-27 08:51:00 【qq_ twenty-seven million three hundred and ninety thousand and 】
TensorFlow provides many built-in optimizers, losses, and metrics. Generally speaking, you don't have to create your own losses, metrics, or optimizers from scratch, because what you need is probably already part of the Keras API.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Preprocess the data (these are NumPy arrays): flatten each image into a
# 784-element float32 vector and scale pixel values by 1/255.
# Using -1 lets NumPy infer the number of samples instead of hard-coding
# 60000 / 10000, so the code keeps working on a subset of the data.
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Reserve the last NUM_VAL_SAMPLES training samples for validation and train
# on the rest.
NUM_VAL_SAMPLES = 10000
x_val = x_train[-NUM_VAL_SAMPLES:]
y_val = y_train[-NUM_VAL_SAMPLES:]
x_train = x_train[:-NUM_VAL_SAMPLES]
y_train = y_train[:-NUM_VAL_SAMPLES]
# Model factory helpers shared by the examples below.
def get_uncompiled_model():
    """Build, but do not compile, a small MLP classifier.

    Layout: a (784,) input named "digits" -> two 64-unit ReLU Dense layers
    ("dense_1", "dense_2") -> a 10-unit softmax Dense layer ("predictions").

    Returns:
        keras.Model: the functional model, not yet compiled.
    """
    digits = keras.Input(shape=(784,), name="digits")
    hidden = digits
    for layer_name in ("dense_1", "dense_2"):
        hidden = layers.Dense(64, activation="relu", name=layer_name)(hidden)
    predictions = layers.Dense(10, activation="softmax", name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=predictions)
def get_compiled_model():
    """Return a fresh model from get_uncompiled_model(), compiled for training.

    Compiles with the RMSprop optimizer, sparse categorical cross-entropy loss
    (labels are class indices, outputs are class probabilities) and the sparse
    categorical accuracy metric.
    """
    compile_kwargs = dict(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )
    model = get_uncompiled_model()
    model.compile(**compile_kwargs)
    return model
### 1. Custom loss function (plain Python function)
def custom_mean_squared_error(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    The mean is taken over every element, so the result is a scalar tensor.
    """
    squared_error = tf.square(y_true - y_pred)
    return tf.math.reduce_mean(squared_error)
model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)

# MSE compares the prediction vector against a target vector, so the labels
# must be one-hot encoded. tf.one_hot only accepts integer indices, while
# y_train holds float32 labels (see the preprocessing step), so cast first.
y_train_one_hot = tf.one_hot(tf.cast(y_train, tf.int32), depth=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)
### 2. Custom loss function class
class CustomMSE(keras.losses.Loss):
    """Mean squared error plus a penalty that pulls predictions toward 0.5.

    loss = mean((y_true - y_pred)^2)
           + regularization_factor * mean((0.5 - y_pred)^2)

    Args:
        regularization_factor: weight of the penalty term (default 0.1).
        name: name of this loss instance.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, regularization_factor=0.1, name="custom_mse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor

    def get_config(self):
        # Include the constructor argument so from_config() can rebuild an
        # equivalent loss when a model using it is saved and reloaded.
        config = super().get_config()
        config.update({"regularization_factor": self.regularization_factor})
        return config
# Train the same MLP with the class-based loss (default regularization_factor).
model = get_uncompiled_model()
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=CustomMSE(),
)
model.fit(
    x_train,
    y_train_one_hot,
    batch_size=64,
    epochs=1,
)
### 3. User-defined evaluation metric class
class CategoricalTruePositives(keras.metrics.Metric):
    """Count samples whose arg-max prediction equals the label.

    ``y_true`` holds class indices; ``y_pred`` holds per-class scores with the
    classes along axis 1. The (optionally ``sample_weight``-weighted) count
    accumulates across batches until ``reset_state`` is called.
    """

    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten labels, predicted classes and weights to shape (batch,) so
        # they combine element-wise. Mixing (batch, 1) with (batch,) would
        # broadcast to (batch, batch) and grossly over-count.
        predicted = tf.cast(tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1,)), "int32")
        labels = tf.reshape(tf.cast(y_true, "int32"), shape=(-1,))
        values = tf.cast(tf.equal(labels, predicted), "float32")
        if sample_weight is not None:
            sample_weight = tf.reshape(tf.cast(sample_weight, "float32"), shape=(-1,))
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)
# Train with the custom metric reported alongside the loss.
model = get_uncompiled_model()
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    metrics=[CategoricalTruePositives()],
)
model.fit(x_train, y_train, epochs=3, batch_size=64)
### 4. Pass a validation dataset to fit()
model = get_compiled_model()

# Training pipeline: shuffle with a 1024-element buffer, then batch by 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)

# Validation pipeline: batched by 64, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(64)

# validation_steps=10 limits each validation pass to the first 10 batches
# of val_dataset.
history = model.fit(
    train_dataset,
    epochs=3,
    validation_data=val_dataset,
    validation_steps=10,
)
print(history.history)
### 5. Multi-input, multi-output model
# Inputs: a (32, 32, 3) image tensor and a sequence of 10-feature steps whose
# length may vary (shape=(None, 10)).
image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

# Both heads are linear (no activation): score_output is trained as a
# regression target, class_output emits raw logits for 5 classes.
score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        # class_output has no softmax, so the loss must be told it receives
        # logits; with the default from_logits=False it would treat raw
        # scores as probabilities and compute a wrong loss.
        "class_output": keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
    # The total loss is 2 * score loss + 1 * class loss.
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)

# Generate dummy NumPy data
import numpy as np

img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
class_targets = np.random.random_sample(size=(100, 5))

# Fit on lists (ordered like the model's inputs/outputs)
model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

# Alternatively, fit on dicts keyed by input/output layer names
model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)
### 6. Use callbacks
# Callbacks in Keras are objects that are called at different points during
# training (at the start of an epoch, at the end of a batch, at the end of an
# epoch, etc.). They can be used to implement behaviors such as:
#   - running validation at different points during training (beyond the
#     built-in per-epoch validation);
#   - checkpointing the model at regular intervals or when it exceeds a
#     certain accuracy threshold;
#   - changing the learning rate of the model when training seems to plateau;
#   - fine-tuning the top layers when training seems to plateau;
#   - sending email or instant-message notifications when training ends or
#     when a certain performance threshold is exceeded, etc.
model = get_compiled_model()

callbacks = [
    keras.callbacks.EarlyStopping(
        # Watch the validation loss for improvement.
        monitor="val_loss",
        # A decrease smaller than 1e-2 does not count as an improvement.
        min_delta=1e-2,
        # Stop after 2 epochs in a row without such an improvement.
        patience=2,
        verbose=1,
    )
]

# validation_split=0.2 holds out 20% of (x_train, y_train) for validation,
# which is what produces the "val_loss" that EarlyStopping monitors.
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=20,
    validation_split=0.2,
    callbacks=callbacks,
)

# Reference resources:
# https://tensorflow.google.cn/guide/keras/train_and_evaluate?hl=en
边栏推荐
猜你喜欢

“鼓浪屿元宇宙”,能否成为中国文旅产业的“升级样本”

The following license SolidWorks Standard cannot be obtained, and the use license file cannot be found. (-1,359,2)。
![[I2C reading mpu6050 of Renesas ra6m4 development board]](/img/1b/c991dd0d798edbb7410a1e16f3a323.png)
[I2C reading mpu6050 of Renesas ra6m4 development board]

Flink1.15源码阅读flink-clients客户端执行流程(阅读较枯燥)
![Connection failed during installation of ros2 [ip: 91.189.91.39 80]](/img/7f/92b7d44cddc03c58364d8d3f19198a.png)
Connection failed during installation of ros2 [ip: 91.189.91.39 80]

Flask project configuration

4276. Good at C

Day5 - Flame restful request response and Sqlalchemy Foundation

NIO this.selector.select()

4279. Cartesian tree
随机推荐
Include error in vs Code (new header file)
List delete collection elements
四个开源的人脸识别项目分享
缓存一致性与内存屏障
Openresty + keepalived 实现负载均衡 + IPV6 验证
Arm undefined instruction exception assembly
Digital intelligence innovation
Make a game by yourself with pyGame 01
693. 行程排序
View 的滑动冲突
redis 网络IO
User management - restrictions
Full Permutation (depth first, permutation tree)
JS检测客户端软件是否安装
Matlab drawing skills and examples: stackedplot
4277. Block reversal
Unity3d 2021 software installation package download and installation tutorial
JS detects whether the client software is installed
微信安装包从0.5M暴涨到260M,为什么我们的程序越来越大?
Vertical align cannot align the picture and text vertically