7 Custom Loss Functions
2022-06-26 15:30:00 【X1996_】
This experiment uses the mnist.npz dataset (a local copy of MNIST in NumPy's .npz format).
Training with a custom loop and training with the built-in fit() produce essentially the same results, so both approaches are shown below.
Custom training loop
Imports
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Allocate GPU memory on demand to avoid OOM errors
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
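On TensorFlow 2.x the same on-demand GPU memory behavior is available without the compat.v1 session; a minimal equivalent sketch:

# Grow GPU memory allocation as needed instead of reserving it all at startup.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)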
Load and preprocess the dataset
mnist = np.load("mnist.npz")
x_train, y_train, x_test, y_test = mnist['x_train'],mnist['y_train'],mnist['x_test'],mnist['y_test']
# Normalize pixel values to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = tf.one_hot(y_train,depth=10)
y_test = tf.one_hot(y_test,depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
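It is worth peeking at one batch to confirm the pipeline before training; continuing from the code above, this should print (32, 28, 28, 1) for the images and (32, 10) for the one-hot labels:

# Pull a single batch from the dataset and check its shapes.
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)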
Build the network
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)
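As a quick smoke test (my addition), pushing a dummy batch through an untrained instance should produce a (1, 10) probability vector that sums to 1, thanks to the softmax output layer:

# Verify the model's output shape with a single zero-valued image.
smoke_model = MyModel()
probs = smoke_model(tf.zeros([1, 28, 28, 1]))
print(probs.shape, float(tf.reduce_sum(probs)))  # (1, 10) 1.0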
Define the loss function. The multi-class focal loss is implemented below twice, once as a tf.keras.losses.Loss subclass and once as a closure; either form works.
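For reference, the quantity being implemented is the standard multi-class focal loss (Lin et al., Focal Loss for Dense Object Detection): for a sample whose true class receives predicted probability p, the loss is -α·(1 − p)^γ·log(p). The (1 − p)^γ factor shrinks the loss of confidently correct predictions towards zero, so training focuses on hard or misclassified examples, while α is an overall balancing factor.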
# Multi-class focal loss, implemented as a Loss subclass
# class FocalLoss(tf.keras.losses.Loss):
#     def __init__(self, gamma=2.0, alpha=0.25):
#         self.gamma = gamma
#         self.alpha = alpha
#         super(FocalLoss, self).__init__()
#
#     def call(self, y_true, y_pred):
#         # y_pred already comes from a softmax layer; clip only to avoid log(0)
#         epsilon = tf.keras.backend.epsilon()  # 1e-7
#         y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
#         y_true = tf.cast(y_true, tf.float32)
#         # alpha scales the loss; (1 - p)^gamma down-weights easy examples
#         loss = -self.alpha * y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
#         loss = tf.math.reduce_sum(loss, axis=1)
#         return loss
# Function-based implementation (returns a closure usable as a Keras loss)
def FocalLoss(gamma=2.0, alpha=0.25):
    def focal_loss_fixed(y_true, y_pred):
        # y_pred already comes from a softmax layer; clip only to avoid log(0)
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, tf.float32)
        # alpha scales the loss; (1 - p)^gamma down-weights easy examples
        loss = -alpha * y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
        loss = tf.math.reduce_sum(loss, axis=1)
        return loss
    return focal_loss_fixed
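A small sanity check of the closure version (the numbers are my own illustrative values): a confident correct prediction should yield a near-zero loss, a confident wrong one a much larger loss.

# True classes are 3 and 7; the "model" is 91% sure of class 3 and class 1.
y_true = tf.one_hot([3, 7], depth=10)
y_pred = tf.one_hot([3, 1], depth=10) * 0.90 + 0.01  # each row sums to 1.0
loss_fn = FocalLoss(gamma=2.0, alpha=0.25)
print(loss_fn(y_true, y_pred).numpy())  # approximately [0.0002, 1.13]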
Choose the optimizer, loss function, and metrics
model = MyModel()
# Built-in loss function:
# loss_object = tf.keras.losses.CategoricalCrossentropy()
# Our custom loss function:
loss_object = FocalLoss(gamma=2.0,alpha=0.25)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)

@tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)
Training
epochs = 5
for epoch in range(epochs):
    # Reset the metrics at the start of each epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))
Training with fit()
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
mnist = np.load("mnist.npz")
x_train, y_train, x_test, y_test = mnist['x_train'],mnist['y_train'],mnist['x_test'],mnist['y_test']
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = np.int32(y_train)
y_test = np.int32(y_test)
# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = tf.one_hot(y_train,depth=10)
y_test = tf.one_hot(y_test,depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(100).batch(32)
# Define the model (functional API this time)
def MyModel():
    inputs = tf.keras.Input(shape=(28, 28, 1), name='digits')
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax', name='predictions')(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model
# Multi-class focal loss, implemented as a Loss subclass
class FocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=0.25):
        self.gamma = gamma
        self.alpha = alpha
        super(FocalLoss, self).__init__()

    def call(self, y_true, y_pred):
        # y_pred already comes from a softmax layer; clip only to avoid log(0)
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, tf.float32)
        # alpha scales the loss; (1 - p)^gamma down-weights easy examples
        loss = -self.alpha * y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
        loss = tf.math.reduce_sum(loss, axis=1)
        return loss
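One practical difference from the closure version: because this class subclasses tf.keras.losses.Loss, Keras applies the default SUM_OVER_BATCH_SIZE reduction on top of the per-sample values returned by call(), so invoking the instance directly gives a scalar mean rather than a per-sample vector. A quick check with the same illustrative values as before:

loss_fn = FocalLoss(gamma=2.0, alpha=0.25)
y_true = tf.one_hot([3, 7], depth=10)
y_pred = tf.one_hot([3, 1], depth=10) * 0.90 + 0.01
print(loss_fn(y_true, y_pred))  # scalar, approximately (0.0002 + 1.13) / 2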
# Function-based implementation, equivalent when passed to compile()
# def FocalLoss(gamma=2.0, alpha=0.25):
#     def focal_loss_fixed(y_true, y_pred):
#         epsilon = tf.keras.backend.epsilon()
#         y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
#         y_true = tf.cast(y_true, tf.float32)
#         loss = -alpha * y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
#         loss = tf.math.reduce_sum(loss, axis=1)
#         return loss
#     return focal_loss_fixed
# Optimizer, loss function, and evaluation metrics
# The loss can be either the built-in one or our custom FocalLoss
model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),           # optimizer
              loss=FocalLoss(gamma=2.0, alpha=0.25),               # loss function
              metrics=[tf.keras.metrics.CategoricalAccuracy()])    # metrics
# Train the model
model.fit(train_ds, epochs=5,validation_data=test_ds)
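A note on saving (my addition, not part of the original walkthrough): gamma and alpha are not serialized, since FocalLoss does not implement get_config, so the simplest way to reload the model is with compile=False followed by re-compiling. A sketch assuming the hypothetical file name focal_model.h5:

# Save the trained model, then reload without deserializing the custom loss.
model.save('focal_model.h5')
restored = tf.keras.models.load_model('focal_model.h5', compile=False)
restored.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                 loss=FocalLoss(gamma=2.0, alpha=0.25),
                 metrics=[tf.keras.metrics.CategoricalAccuracy()])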