当前位置:网站首页>8 自定义评估函数
8 自定义评估函数
2022-06-26 15:30:00 【X1996_】
自定义评估函数的写法与自定义损失函数类似。本文通过继承 tf.keras.metrics.Metric 自定义一个评估函数,返回预测正确的样本个数。
自定义训练
# --- Script 1: custom metric used inside a hand-written training loop ---
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Grow GPU memory on demand instead of grabbing it all up front, to avoid OOM
# when sharing the GPU (uses the TF1 compat session API).
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
# Data preparation: load a local MNIST archive (expects "mnist.npz" in cwd).
mnist = np.load("mnist.npz")
x_train, y_train, x_test, y_test = mnist['x_train'],mnist['y_train'],mnist['x_test'],mnist['y_test']
# Scale pixel values from [0, 255] to [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a channels dimension: (N, 28, 28) -> (N, 28, 28, 1) for Conv2D.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# Labels stay as sparse integer class ids (0-9) in this first script.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 搭建模型
class MyModel(Model):
    """Small CNN classifier: Conv2D -> Flatten -> Dense(128) -> softmax over 10 classes."""

    def __init__(self):
        super(MyModel, self).__init__()
        # Layers are declared here; weights are built lazily on the first call.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Forward pass; assumes NHWC image input (e.g. MNIST (batch, 28, 28, 1)).

        Returns per-class probabilities of shape (batch, 10).
        """
        hidden = self.d1(self.flatten(self.conv1(x)))
        return self.d2(hidden)
# 自定义评估函数
# 返回的是一个正确的个数
class CatgoricalTruePositives(tf.keras.metrics.Metric):
    """Custom metric: running COUNT of correctly classified samples.

    Unlike accuracy (a ratio), ``result()`` returns the raw number of hits
    accumulated since the last reset.
    """

    def __init__(self, name='categorical_true_positives', **kwargs):
        super(CatgoricalTruePositives, self).__init__(name=name, **kwargs)
        # Single scalar state variable holding the accumulated hit count.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add this batch's correct predictions to the running total.

        ``y_true``: integer class labels; ``y_pred``: per-class scores.
        ``sample_weight``, if given, weights each hit before summing.
        """
        predicted = tf.argmax(y_pred, axis=-1)
        hits = tf.equal(tf.cast(y_true, 'int32'), tf.cast(predicted, 'int32'))
        hits = tf.cast(hits, 'float32')
        if sample_weight is not None:
            hits = tf.multiply(hits, tf.cast(sample_weight, 'float32'))
        self.true_positives.assign_add(tf.reduce_sum(hits))

    def result(self):
        # The metric value is just the accumulated count.
        return self.true_positives

    def reset_states(self):
        # Called between epochs to clear the running total.
        self.true_positives.assign(0.)
model = MyModel()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy() # loss function (integer labels)
optimizer = tf.keras.optimizers.Adam() # optimizer
# Metrics: tracked separately for the train and test passes.
train_loss = tf.keras.metrics.Mean(name='train_loss') # running mean of the loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') # accuracy ratio
train_tp = CatgoricalTruePositives(name="train_tp") # custom metric: count of correct predictions
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
test_tp = CatgoricalTruePositives(name='test_tp')
@tf.function
def train_step(images, labels):
    """Run one optimization step: forward pass, loss, backprop, metric updates."""
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Feed this batch into the running training metrics.
    train_loss(loss)
    train_accuracy(labels, predictions)
    train_tp(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch (no gradient update) and accumulate the test metrics."""
    predictions = model(images)
    batch_loss = loss_object(labels, predictions)
    test_loss(batch_loss)
    test_accuracy(labels, predictions)
    test_tp(labels, predictions)
EPOCHS = 5

for epoch in range(EPOCHS):
    # Metrics accumulate across batches, so clear every one of them
    # at the start of each epoch.
    for metric in (train_loss, train_accuracy, train_tp,
                   test_loss, test_accuracy, test_tp):
        metric.reset_states()

    for images, labels in train_ds:
        train_step(images, labels)
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Report epoch results; TP is the absolute count of correct predictions.
    template = 'Epoch {}, Loss: {}, Accuracy: {}, TP: {},Test Loss: {}, Test Accuracy: {}, Test TP:{}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          train_tp.result(),
                          test_loss.result(),
                          test_accuracy.result() * 100,
                          test_tp.result()))
使用 fit() 方法训练
# --- Script 2: the same custom metric, used with the built-in fit() loop.
# NOTE: this is a second standalone script, not a continuation of the one above.
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Grow GPU memory on demand instead of grabbing it all up front, to avoid OOM.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
# Load a local MNIST archive (expects "mnist.npz" in cwd).
mnist = np.load("mnist.npz")
x_train, y_train, x_test, y_test = mnist['x_train'],mnist['y_train'],mnist['x_test'],mnist['y_test']
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a channels dimension for Conv2D: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# Unlike script 1, labels are one-hot encoded here, to pair with
# CategoricalCrossentropy / CategoricalAccuracy in compile().
y_train = tf.one_hot(y_train,depth=10)
y_test = tf.one_hot(y_test,depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(100).batch(32)
class MyModel(Model):
    """CNN for MNIST classification: one conv layer, then two dense layers."""

    def __init__(self):
        super(MyModel, self).__init__()
        # Declare the layers; Keras builds the weights on first use.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Map NHWC image batches to (batch, 10) softmax probabilities."""
        features = self.conv1(x)
        features = self.flatten(features)
        features = self.d1(features)
        return self.d2(features)
# 自定义
#返回的是一个正确的个数
#y_true
#y_pred
class CatgoricalTruePositives(tf.keras.metrics.Metric):
    """Count of correctly classified samples, for ONE-HOT encoded labels.

    Both ``y_true`` and ``y_pred`` are reduced with argmax before comparing,
    so this variant expects one-hot targets (unlike the sparse-label version
    used with the manual training loop).
    """

    def __init__(self, name='categorical_true_positives', **kwargs):
        super(CatgoricalTruePositives, self).__init__(name=name, **kwargs)
        # Scalar state: accumulated number of correct predictions.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate this batch's hits (optionally weighted per sample)."""
        pred_ids = tf.argmax(y_pred, axis=-1)
        true_ids = tf.argmax(y_true, axis=-1)
        hits = tf.equal(tf.cast(true_ids, 'int32'), tf.cast(pred_ids, 'int32'))
        hits = tf.cast(hits, 'float32')
        if sample_weight is not None:
            hits = tf.multiply(hits, tf.cast(sample_weight, 'float32'))
        self.true_positives.assign_add(tf.reduce_sum(hits))

    def result(self):
        # Raw count, not a ratio.
        return self.true_positives

    def reset_states(self):
        # Clear the running total between epochs.
        self.true_positives.assign(0.)
model = MyModel()

# Wire optimizer, loss and metrics together, then train with the built-in
# fit() loop; the custom metric plugs in exactly like a built-in one.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        CatgoricalTruePositives(),
    ],
)
model.fit(train_ds, epochs=5, validation_data=test_ds)
边栏推荐
- 粒子滤波 PF——在机动目标跟踪中的应用(粒子滤波VS扩展卡尔曼滤波)
- 2022 Beijing Shijingshan District specializes in the application process for special new small and medium-sized enterprises, with a subsidy of 100000-200000 yuan
- [thinking] what were you buying when you bought NFT?
- JS events
- Comparative analysis of restcloud ETL and kettle
- [CEPH] cephfs internal implementation (II): example -- undigested
- 【leetcode】48.旋转图像
- golang 临时对象池优化
- [CEPH] MKDIR | mksnap process source code analysis | lock state switching example
- Audio and video learning (II) -- frame rate, code stream and resolution
猜你喜欢

IDEA本地代理后,无法下载插件

【问题解决】新版webots纹理等资源文件加载/下载时间过长

「干货」NFT 上中下游产业链全景分析

PCIe Capabilities List
![[tcapulusdb knowledge base] Introduction to tcapulusdb system management](/img/5a/28aaf8b115cbf4798cf0b201e4c068.png)
[tcapulusdb knowledge base] Introduction to tcapulusdb system management

svg环绕地球动画js特效
![[C language practice - printing hollow upper triangle and its deformation]](/img/56/6a88b3d8de32a3215399f915bba33e.png)
[C language practice - printing hollow upper triangle and its deformation]

IntelliJ idea -- Method for formatting SQL files

【小程序实战系列】小程序框架 页面注册 生命周期 介绍

Restcloud ETL resolves shell script parameterization
随机推荐
[CEPH] cephfs internal implementation (IV): how is MDS started-- Undigested
svg野人动画代码
Using restcloud ETL shell component to schedule dataX offline tasks
Mr. Du said that the website was updated with illustrations
Summary of data interface API used in word search and translation applications
Solana扩容机制分析(1):牺牲可用性换取高效率的极端尝试 | CatcherVC Research
2022北京石景山区专精特新中小企业申报流程,补贴10-20万
【leetcode】331. Verifying the preorder serialization of a binary tree
Summary of students' learning career (2022)
AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
IntelliJ idea -- Method for formatting SQL files
[problem solving] the loading / downloading time of the new version of webots texture and other resource files is too long
评价——模糊综合评价
【leetcode】701. 二叉搜索树中的插入操作
【问题解决】新版webots纹理等资源文件加载/下载时间过长
Summer camp is coming!!! Chongchongchong
Auto Sharding Policy will apply Data Sharding policy as it failed to apply file Sharding Policy
在重新格式化时不要删除自定义换行符(Don‘t remove custom line breaks on reformat)
Notes on brushing questions (19) -- binary tree: modification and construction of binary search tree
夏令营来啦!!!冲冲冲