当前位置:网站首页>7 自定义损失函数
7 自定义损失函数
2022-06-26 15:30:00 【X1996_】
自定义损失函数
这个实验需要用到mnist.npz数据集
本文给出两种训练方式:自定义训练循环,以及 Keras 自带的 fit() 函数;两种方式的训练效果看起来差不多。
自定义训练
导入依赖库
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Allocate GPU memory on demand; intended to avoid out-of-memory (OOM) errors.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# Enable the allow_growth GPU option on the session configuration.
config.gpu_options.allow_growth = True
# Create the session with that configuration; the reference is kept in `session`.
session = InteractiveSession(config=config)
载入数据集并处理
# Load MNIST from a local .npz archive. The context manager closes the
# underlying file handle once the arrays have been read.
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']

# Normalize by 255 (assumes 8-bit pixel values -- TODO confirm) and store as
# float32 so the inputs match the model's dtype instead of being float64.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# One-hot encode the labels over 10 classes to match the categorical
# loss and accuracy metrics used later.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# Training data is shuffled (buffer of 10000) and batched by 32;
# test data is only batched.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
搭建网络
class MyModel(Model):
    """Small convolutional classifier that outputs 10 softmax probabilities.

    Layers, in order: Conv2D (32 filters, kernel size 3, ReLU), Flatten,
    Dense (128 units, ReLU) and Dense (10 units, softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return the softmax output of ``d2``."""
        features = self.flatten(self.conv1(x))
        hidden = self.d1(features)
        return self.d2(hidden)
定义损失函数:下面给出两种实现,一种用类实现,一种用函数实现,两种都可以使用。
# #多分类的focal loss 损失函数,类的实现
# class FocalLoss(tf.keras.losses.Loss):
# def __init__(self,gamma=2.0,alpha=0.25):
# self.gamma = gamma
# self.alpha = alpha
# super(FocalLoss, self).__init__()
# def call(self,y_true,y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()#1e-7
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# 函数的方式实现
def FocalLoss(gamma=2.0, alpha=0.25, from_logits=False):
    """Build a multi-class focal loss function.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        alpha: Scalar weight multiplied into the loss.
        from_logits: Set to True when ``y_pred`` holds raw logits, in which
            case a softmax is applied first. The default (False) treats
            ``y_pred`` as probabilities, e.g. the output of a softmax layer,
            so that the softmax is not applied a second time.

    Returns:
        A function ``focal_loss_fixed(y_true, y_pred)`` returning one loss
        value per example (the class axis is summed out).
    """
    def focal_loss_fixed(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Re-applying softmax to probabilities would flatten the distribution,
        # so only normalize when the caller passes logits.
        if from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # Clip away exact zeros so log() stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Focal term: -alpha * y * (1 - p)^gamma * log(p)
        loss = -alpha * y_true * tf.math.pow(1.0 - y_pred, gamma) * tf.math.log(y_pred)
        return tf.math.reduce_sum(loss, axis=-1)
    return focal_loss_fixed
选择优化器、损失函数和评估指标
# Instantiate the network that the training and test steps operate on.
model = MyModel()
# Built-in loss function (alternative)
# loss_object = tf.keras.losses.CategoricalCrossentropy()
# Custom loss function defined in this file
loss_object = FocalLoss(gamma=2.0,alpha=0.25)
# Adam optimizer with its default settings.
optimizer = tf.keras.optimizers.Adam()
# Metrics accumulated across batches: mean loss and categorical accuracy,
# kept separately for the training and test sets.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Matching one-hot labels.
    """
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
        # The custom loss returns one value per example. Differentiating a
        # non-scalar target sums its gradients, so reduce to the batch mean
        # explicitly; this keeps the gradient scale independent of batch size
        # and is a no-op for losses that already return a scalar.
        mean_loss = tf.reduce_mean(loss)
    gradients = tape.gradient(mean_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # The Mean metric still receives the unreduced loss, as before.
    train_loss(loss)
    train_accuracy(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch and fold the results into the test metrics.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Matching one-hot labels.
    """
    outputs = model(images)
    test_loss(loss_object(labels, outputs))
    test_accuracy(labels, outputs)
训练
epochs = 5
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
for epoch in range(epochs):
    # Reset every metric at the start of each epoch so the reported values
    # cover that epoch only.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    # One full pass over the training data, then one over the test data.
    for images, labels in train_ds:
        train_step(images, labels)
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Accuracies are reported as percentages.
    report = template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
    )
    print(report)
fit()训练
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# GPU session setup: enable the allow_growth GPU option before any model is built.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# Enable the allow_growth GPU option on the session configuration.
config.gpu_options.allow_growth = True
# Create the session with that configuration; the reference is kept in `session`.
session = InteractiveSession(config=config)
# Load MNIST from a local .npz archive. The context manager closes the
# underlying file handle once the arrays have been read.
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']

# Normalize by 255 (assumes 8-bit pixel values -- TODO confirm) and store as
# float32 so the inputs match the model's dtype instead of being float64.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Labels are cast to int32 before one-hot encoding.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# One-hot encode the labels over 10 classes to match the categorical
# loss and accuracy metric used at compile time.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# Training data is shuffled (buffer of 10000) and batched by 32. The test
# set is only batched: the aggregated validation metrics do not depend on
# sample order, so shuffling it is unnecessary work.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 定义模型
def MyModel():
    """Build the functional-API classifier.

    Input named 'digits' with shape (28, 28, 1), followed by Conv2D
    (32 filters, kernel size 3, ReLU), Flatten, Dense (128 units, ReLU) and
    a 10-unit softmax Dense layer named 'predictions'.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    digits = tf.keras.Input(shape=(28, 28, 1), name='digits')
    hidden_layers = (
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
    )
    # Thread the input through the hidden layers in order.
    hidden = digits
    for layer in hidden_layers:
        hidden = layer(hidden)
    predictions = tf.keras.layers.Dense(
        10, activation='softmax', name='predictions'
    )(hidden)
    return tf.keras.Model(inputs=digits, outputs=predictions)
# #多分类的focal loss 损失函数
class FocalLoss(tf.keras.losses.Loss):
    """Multi-class focal loss as a Keras ``Loss`` subclass.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        alpha: Scalar weight multiplied into the loss.
        from_logits: Set to True when ``y_pred`` holds raw logits, in which
            case a softmax is applied first. The default (False) treats
            ``y_pred`` as probabilities, e.g. the output of a softmax layer,
            so that the softmax is not applied a second time.
    """

    def __init__(self, gamma=2.0, alpha=0.25, from_logits=False):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return one focal-loss value per example (class axis summed out)."""
        y_pred = tf.convert_to_tensor(y_pred)
        # Re-applying softmax to probabilities would flatten the distribution,
        # so only normalize when the caller passes logits.
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # Clip away exact zeros so log() stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Focal term: -alpha * y * (1 - p)^gamma * log(p)
        loss = (
            -self.alpha
            * y_true
            * tf.math.pow(1.0 - y_pred, self.gamma)
            * tf.math.log(y_pred)
        )
        return tf.math.reduce_sum(loss, axis=-1)
# def FocalLoss(gamma=2.0,alpha=0.25):
# def focal_loss_fixed(y_true, y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# return focal_loss_fixed
# Configure the optimizer, loss function and evaluation metrics.
# The loss function can be a user-defined one.
model = MyModel()
model.compile(optimizer = tf.keras.optimizers.Adam(0.001), # optimizer
loss = FocalLoss(gamma=2.0,alpha=0.25), # loss function
metrics = [tf.keras.metrics.CategoricalAccuracy()] # evaluation metrics
)
# Train for 5 epochs, validating on the test dataset.
model.fit(train_ds, epochs=5,validation_data=test_ds)
边栏推荐
- Golang temporary object pool optimization
- NFT 项目的开发、部署、上线的流程(1)
- CNN优化trick
- 10 tf.data
- svg canvas画布拖拽
- Solana capacity expansion mechanism analysis (1): an extreme attempt to sacrifice availability for efficiency | catchervc research
- AbortController的使用
- Restcloud ETL resolves shell script parameterization
- 手机上怎么开户?在线开户安全么?
- 【问题解决】新版webots纹理等资源文件加载/下载时间过长
猜你喜欢

Summary of students' learning career (2022)

一篇博客彻底掌握:粒子滤波 particle filter (PF) 的理论及实践(matlab版)

NFT 项目的开发、部署、上线的流程(2)

Solana扩容机制分析(2):牺牲可用性换取高效率的极端尝试 | CatcherVC Research
![[tcapulusdb knowledge base] Introduction to tcapulusdb system management](/img/5a/28aaf8b115cbf4798cf0b201e4c068.png)
[tcapulusdb knowledge base] Introduction to tcapulusdb system management

9 Tensorboard的使用

PCIe Capabilities List

Svg savage animation code

Particle filter PF - 3D CV target tracking with uniform motion (particle filter vs extended Kalman filter)

svg野人动画代码
随机推荐
[applet practice series] Introduction to the registration life cycle of the applet framework page
AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
CNN优化trick
Audio and video learning (II) -- frame rate, code stream and resolution
JS handwritten bind, apply, call
Interview pit summary I
CNN optimized trick
Analyse panoramique de la chaîne industrielle en amont, en aval et en aval de la NFT « Dry goods»
10 tf.data
TweenMax+SVG切换颜色动画场景
【leetcode】331. 验证二叉树的前序序列化
Beijing Fangshan District specialized special new small giant enterprise recognition conditions, with a subsidy of 500000 yuan
SQLite loads CSV files and performs data analysis
Evaluate:huggingface评价指标模块入门详细介绍
为什么图像分割任务中经常用到编码器和解码器结构?
Unable to download Plug-in after idea local agent
Why are encoder and decoder structures often used in image segmentation tasks?
el-dialog拖拽,边界问题完全修正,网上版本的bug修复
Summary of data interface API used in word search and translation applications
Golang 1.18 go work usage