当前位置:网站首页>7 user defined loss function
7 user defined loss function
2022-06-26 15:58:00 【X1996_】
Custom loss function
This experiment requires the mnist.npz dataset.
It demonstrates that a custom training loop and training with the built-in fit() function produce similar results.
Custom training
The header file
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Allocate GPU memory on demand ("allow growth") instead of reserving the
# whole GPU up front, to avoid OOM errors when the GPU is shared.
# NOTE(review): this uses the TF1 compat session API inside a TF2 script;
# the modern equivalent is tf.config.experimental.set_memory_growth.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
Load datasets and process
# Load the MNIST arrays from the local .npz archive.
archive = np.load("mnist.npz")
x_train, y_train = archive['x_train'], archive['y_train']
x_test, y_test = archive['x_test'], archive['y_test']

# Scale pixel values into [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0

# Give each image an explicit single-channel axis: (28, 28) -> (28, 28, 1).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# 10-way one-hot targets, matching CategoricalAccuracy / categorical losses.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# Batched input pipelines; training data is shuffled with a 10k-element buffer.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
Build network
class MyModel(Model):
    """Small CNN for MNIST: conv(32,3x3) -> flatten -> dense(128) -> softmax(10)."""

    def __init__(self):
        super(MyModel, self).__init__()
        # Attribute names are kept as-is: Keras tracks variables (and
        # checkpoints) by these attribute names.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        # Feed the input through the layers in order.
        out = x
        for layer in (self.conv1, self.flatten, self.d1, self.d2):
            out = layer(out)
        return out
Define the loss function. One version is implemented as a class and one as a function; either can be used.
# # Multi category focal loss Loss function , The realization of the class
# class FocalLoss(tf.keras.losses.Loss):
# def __init__(self,gamma=2.0,alpha=0.25):
# self.gamma = gamma
# self.alpha = alpha
# super(FocalLoss, self).__init__()
# def call(self,y_true,y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()#1e-7
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# Functions are implemented in the same way
def FocalLoss(gamma=2.0, alpha=0.25):
    """Return a multi-class focal-loss function: FL = -alpha * (1-p)^gamma * log(p).

    gamma down-weights easy (high-probability) examples; alpha scales the loss.
    The returned closure expects `y_pred` to already be probabilities, since
    the model's final layer uses a softmax activation.

    Fixes vs. the original: (1) `alpha` was accepted but never used;
    (2) softmax was re-applied to already-softmaxed model outputs, which
    distorted the probability distribution and hence the loss.
    """
    def focal_loss_fixed(y_true, y_pred):
        # Clip away exact zeros so log() stays finite.
        epsilon = tf.keras.backend.epsilon()  # 1e-7
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, tf.float32)
        # Only the true-class term survives the multiplication by one-hot y_true.
        loss = -alpha * y_true * tf.math.pow(1.0 - y_pred, gamma) * tf.math.log(y_pred)
        # Sum over the class axis -> one loss value per sample.
        return tf.math.reduce_sum(loss, axis=1)
    return focal_loss_fixed
Select optimizer loss function .....
model = MyModel()
# Built-in alternative loss (categorical cross-entropy), kept for comparison:
# loss_object = tf.keras.losses.CategoricalCrossentropy()
# Custom focal loss defined above.
loss_object = FocalLoss(gamma=2.0,alpha=0.25)
optimizer = tf.keras.optimizers.Adam()
# Metrics accumulate loss/accuracy across batches; they are reset at the
# start of every epoch in the training loop below.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(images, labels):
    """Run one gradient-descent step on a single batch and update metrics.

    Args:
        images: batch of input images.
        labels: matching one-hot labels.
    """
    with tf.GradientTape() as tape:
        # training=True so layers with distinct train/inference behavior
        # (dropout, batch norm) run in training mode — the original omitted
        # the flag; the official TF2 custom-training tutorial passes it.
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Accumulate this batch into the epoch-level metrics.
    train_loss(loss)
    train_accuracy(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch (no gradient updates) and update test metrics.

    Args:
        images: batch of input images.
        labels: matching one-hot labels.
    """
    # training=False: run inference-mode behavior for dropout/batch norm.
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)
Training
epochs = 5
for epoch in range(epochs):
    # Metrics accumulate across batches; clear them at the start of each epoch.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    # One pass over the training data, then one over the test data.
    for images, labels in train_ds:
        train_step(images, labels)
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    print(f'Epoch {epoch + 1}, '
          f'Loss: {train_loss.result()}, '
          f'Accuracy: {train_accuracy.result() * 100}, '
          f'Test Loss: {test_loss.result()}, '
          f'Test Accuracy: {test_accuracy.result() * 100}')
fit() Training
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Allocate GPU memory on demand instead of reserving it all at start,
# to avoid out-of-memory errors when sharing the GPU.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
# Load the MNIST arrays from the local .npz archive.
archive = np.load("mnist.npz")
x_train, y_train = archive['x_train'], archive['y_train']
x_test, y_test = archive['x_test'], archive['y_test']

# Scale pixel values into [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0

# one_hot expects integer class indices.
y_train = np.int32(y_train)
y_test = np.int32(y_test)

# Give each image an explicit single-channel axis: (28, 28) -> (28, 28, 1).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# 10-way one-hot targets to match the CategoricalAccuracy metric.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# Batched pipelines; training data shuffled with a 10k buffer, test with 100.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(100).batch(32)
# Defining models
def MyModel():
    """Build the MNIST classifier with the Keras functional API.

    Returns a tf.keras.Model: conv(32,3x3) -> flatten -> dense(128) -> softmax(10).
    """
    inputs = tf.keras.Input(shape=(28, 28, 1), name='digits')
    hidden = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    hidden = tf.keras.layers.Flatten()(hidden)
    hidden = tf.keras.layers.Dense(128, activation='relu')(hidden)
    outputs = tf.keras.layers.Dense(10, activation='softmax', name='predictions')(hidden)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
# # Multi category focal loss Loss function
class FocalLoss(tf.keras.losses.Loss):
    """Multi-class focal loss: FL = -alpha * (1 - p)^gamma * log(p).

    gamma down-weights easy (high-probability) examples; alpha scales the
    loss. Expects `y_pred` to already be probabilities, since the model's
    final layer uses a softmax activation.

    Fixes vs. the original: (1) `alpha` was stored but never used;
    (2) softmax was re-applied to already-softmaxed model outputs, which
    distorted the probability distribution and hence the loss.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        # Clip away exact zeros so log() stays finite.
        epsilon = tf.keras.backend.epsilon()  # 1e-7
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, tf.float32)
        # Only the true-class term survives the multiplication by one-hot y_true.
        loss = -self.alpha * y_true * tf.math.pow(1.0 - y_pred, self.gamma) * tf.math.log(y_pred)
        # Sum over the class axis -> per-sample loss; Keras applies the
        # reduction configured on the base Loss class.
        return tf.math.reduce_sum(loss, axis=1)
# def FocalLoss(gamma=2.0,alpha=0.25):
# def focal_loss_fixed(y_true, y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# return focal_loss_fixed
# Pick the optimizer, loss, and evaluation metrics, then train with the
# built-in fit() loop. The loss may be a custom one such as FocalLoss.
model = MyModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),          # optimizer
    loss=FocalLoss(gamma=2.0, alpha=0.25),              # custom loss defined above
    metrics=[tf.keras.metrics.CategoricalAccuracy()],   # evaluation metric
)
# Train for 5 epochs, validating on the test pipeline after each epoch.
model.fit(train_ds, epochs=5, validation_data=test_ds)
边栏推荐
- 在哪个平台买股票开户安全?求指导
- Audio and video learning (II) -- frame rate, code stream and resolution
- Solana capacity expansion mechanism analysis (1): an extreme attempt to sacrifice availability for efficiency | catchervc research
- 3 keras版本模型训练
- Stepn novice introduction and advanced
- How to configure and use the new single line lidar
- Common properties of XOR and addition
- Binding method of multiple sub control signal slots under QT
- Selenium chrome disable JS disable pictures
- Selenium saves elements as pictures
猜你喜欢

Svg rising Color Bubble animation

Audio and video learning (III) -- SIP protocol

查词翻译类应用使用数据接口api总结

NFT transaction principle analysis (1)

nanoPi Duo2连接wifi

How to handle 2gcsv files that cannot be opened? Use byzer

Stepn novice introduction and advanced

Summary of data interface API used in word search and translation applications

el-dialog拖拽,边界问题完全修正,网上版本的bug修复

OpenSea上如何创建自己的NFT(Polygon)
随机推荐
On which platform is it safe to buy shares and open an account? Ask for guidance
HW safety response
Solana capacity expansion mechanism analysis (1): an extreme attempt to sacrifice availability for efficiency | catchervc research
学习内存屏障
AbortController的使用
NFT交易原理分析(2)
Analyse panoramique de la chaîne industrielle en amont, en aval et en aval de la NFT « Dry goods»
Beijing Fangshan District specialized special new small giant enterprise recognition conditions, with a subsidy of 500000 yuan
NFT contract basic knowledge explanation
[problem solving] the loading / downloading time of the new version of webots texture and other resource files is too long
NFT 项目的开发、部署、上线的流程(2)
AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
NFT交易原理分析(1)
【leetcode】701. 二叉搜索树中的插入操作
【leetcode】112. 路径总和 - 113. 路径总和 II
[thinking] what were you buying when you bought NFT?
简单科普Ethereum的Transaction Input Data
[graduation season · advanced technology Er] what is a wechat applet, which will help you open the door of the applet
When a project with cmake is cross compiled to a link, an error cannot be found So dynamic library file
/etc/profile、/etc/bashrc、~/. Bashrc differences